diff --git a/.clang-format b/.clang-format new file mode 100644 index 00000000..22f8603a --- /dev/null +++ b/.clang-format @@ -0,0 +1,25 @@ +# A clang-format style that approximates Python's PEP 7 +# Useful for IDE integration +# +# Based on Paul Ganssle's version at +# https://gist.github.com/pganssle/0e3a5f828b4d07d79447f6ced8e7e4db +BasedOnStyle: Google +AlwaysBreakAfterReturnType: All +AllowShortIfStatementsOnASingleLine: false +AlignAfterOpenBracket: Align +AlignTrailingComments: true +BreakBeforeBraces: Stroustrup +ColumnLimit: 79 +DerivePointerAlignment: false +IndentWidth: 4 +Language: Cpp +PointerAlignment: Right +ReflowComments: true +SpaceBeforeParens: ControlStatements +SpacesInParentheses: false +TabWidth: 4 +UseCRLF: false +UseTab: Never +StatementMacros: + - Py_BEGIN_ALLOW_THREADS + - Py_END_ALLOW_THREADS \ No newline at end of file diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 00000000..5aced2f4 --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,14 @@ +FROM mcr.microsoft.com/devcontainers/base:jammy + +ENV PYTHONUNBUFFERED 1 + +# [Optional] If your requirements rarely change, uncomment this section to add them to the image. +# COPY requirements.txt /tmp/pip-tmp/ +# RUN pip3 --disable-pip-version-check --no-cache-dir install -r /tmp/pip-tmp/requirements.txt \ +# && rm -rf /tmp/pip-tmp + +# [Optional] Uncomment this section to install additional OS packages. +# RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ +# && apt-get -y install --no-install-recommends + +CMD ["sleep", "infinity"] diff --git a/.devcontainer/dev.env b/.devcontainer/dev.env new file mode 100644 index 00000000..996ee8d2 --- /dev/null +++ b/.devcontainer/dev.env @@ -0,0 +1,11 @@ +PGHOST=pg15 +PGPORT=5432 +PGDATABASE=test +PGUSER=test +PGPASSWORD=test + +PYGRESQL_DB=test +PYGRESQL_HOST=pg15 +PYGRESQL_PORT=5432 +PYGRESQL_USER=test +PYGRESQL_PASSWD=test diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 00000000..0333b8e6 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,64 @@ +// For format details, see https://aka.ms/devcontainer.json. For config options, see the +// README at: https://github.com/devcontainers/templates/tree/main/src/ubuntu +{ + "name": "PyGreSQL", + // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile + "dockerComposeFile": "docker-compose.yml", + "service": "dev", + "workspaceFolder": "/workspace", + "customizations": { + "vscode": { + // Set *default* container specific settings.json values on container create. + "settings": { + "terminal.integrated.profiles.linux": { + "bash": { + "path": "/bin/bash" + } + }, + "sqltools.connections": [ + { + "name": "Container database", + "driver": "PostgreSQL", + "previewLimit": 50, + "server": "pg15", + "port": 5432, + "database": "test", + "username": "test", + "password": "test" + } + ], + "python.pythonPath": "/usr/local/bin/python", + "python.analysis.typeCheckingMode": "basic", + "python.testing.unittestEnabled": true, + "editor.formatOnSave": true, + "editor.renderWhitespace": "all", + "editor.rulers": [ + 79 + ] + }, + // Add the IDs of extensions you want installed when the container is created. 
+            "extensions": [
+                "ms-azuretools.vscode-docker",
+                "ms-python.python",
+                "ms-vscode.cpptools",
+                "mtxr.sqltools",
+                "njpwerner.autodocstring",
+                "redhat.vscode-yaml",
+                "eamodio.gitlens",
+                "charliermarsh.ruff",
+                "streetsidesoftware.code-spell-checker",
+                "lextudio.restructuredtext"
+            ]
+        }
+    },
+    // Features to add to the dev container. More info: https://containers.dev/features.
+    // "features": {},
+    // Use 'forwardPorts' to make a list of ports inside the container available locally.
+    // "forwardPorts": [],
+    // Use 'postCreateCommand' to run commands after the container is created.
+    "postCreateCommand": "sudo bash /workspace/.devcontainer/provision.sh"
+    // Configure tool-specific properties.
+    // "customizations": {},
+    // Uncomment to connect as root instead. More info: https://aka.ms/dev-containers-non-root.
+    // "remoteUser": "root"
+}
\ No newline at end of file
diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml
new file mode 100644
index 00000000..541d63e9
--- /dev/null
+++ b/.devcontainer/docker-compose.yml
@@ -0,0 +1,80 @@
+services:
+  dev:
+    build:
+      context: .
+      dockerfile: ./Dockerfile
+
+    env_file: dev.env
+
+    volumes:
+      - ..:/workspace:cached
+
+    command: sleep infinity
+
+  pg10:
+    image: postgres:10
+    restart: unless-stopped
+    volumes:
+      - postgres-data-10:/var/lib/postgresql/data
+    environment:
+      POSTGRES_USER: postgres
+      POSTGRES_DB: postgres
+      POSTGRES_PASSWORD: postgres
+
+  pg12:
+    image: postgres:12
+    restart: unless-stopped
+    volumes:
+      - postgres-data-12:/var/lib/postgresql/data
+    environment:
+      POSTGRES_USER: postgres
+      POSTGRES_DB: postgres
+      POSTGRES_PASSWORD: postgres
+
+  pg14:
+    image: postgres:14
+    restart: unless-stopped
+    volumes:
+      - postgres-data-14:/var/lib/postgresql/data
+    environment:
+      POSTGRES_USER: postgres
+      POSTGRES_DB: postgres
+      POSTGRES_PASSWORD: postgres
+
+  pg15:
+    image: postgres:15
+    restart: unless-stopped
+    volumes:
+      - postgres-data-15:/var/lib/postgresql/data
+    environment:
+      POSTGRES_USER: postgres
+      POSTGRES_DB: postgres
+      POSTGRES_PASSWORD: postgres
+
+  pg16:
+    image: postgres:16
+    restart: unless-stopped
+    volumes:
+      - postgres-data-16:/var/lib/postgresql/data
+    environment:
+      POSTGRES_USER: postgres
+      POSTGRES_DB: postgres
+      POSTGRES_PASSWORD: postgres
+
+  pg17:
+    image: postgres:17
+    restart: unless-stopped
+    volumes:
+      - postgres-data-17:/var/lib/postgresql/data
+    environment:
+      POSTGRES_USER: postgres
+      POSTGRES_DB: postgres
+      POSTGRES_PASSWORD: postgres
+
+volumes:
+  postgres-data-10:
+  postgres-data-12:
+  postgres-data-14:
+  postgres-data-15:
+  postgres-data-16:
+  postgres-data-17:
diff --git a/.devcontainer/provision.sh b/.devcontainer/provision.sh
new file mode 100644
index 00000000..1ca7b020
--- /dev/null
+++ b/.devcontainer/provision.sh
@@ -0,0 +1,89 @@
+#!/usr/bin/bash
+
+# install development environment for PyGreSQL
+
+export DEBIAN_FRONTEND=noninteractive
+
+apt-get update
+apt-get -y upgrade
+
+# install base utilities and configure time zone
+
+ln -fs /usr/share/zoneinfo/UTC /etc/localtime
+apt-get install -y apt-utils software-properties-common
+apt-get install -y tzdata
+dpkg-reconfigure --frontend noninteractive tzdata
+
+apt-get install -y rpm wget zip
+
+# install all supported Python versions
+
+add-apt-repository -y ppa:deadsnakes/ppa
+apt-get update
+
+apt-get install -y python3.7 python3.7-dev python3.7-distutils
+apt-get install -y python3.8 python3.8-dev python3.8-distutils
+apt-get install -y python3.9 python3.9-dev python3.9-distutils
+apt-get install -y python3.10 python3.10-dev python3.10-distutils
+apt-get install -y python3.11 python3.11-dev python3.11-distutils
+apt-get install -y python3.12 python3.12-dev python3.12-distutils
+apt-get install -y python3.13 python3.13-dev python3.13-distutils
+
+# install build and testing tools
+
+python -m ensurepip -U
+
+python3.7 -m pip install -U pip setuptools wheel build
+python3.8 -m pip install -U pip setuptools wheel build
+python3.9 -m pip install -U pip setuptools wheel build
+python3.10 -m pip install -U pip setuptools wheel build
+python3.11 -m pip install -U pip setuptools wheel build
+python3.12 -m pip install -U pip setuptools wheel build
+python3.13 -m pip install -U pip setuptools wheel build
+
+pip install ruff
+
+apt-get install -y tox clang-format
+pip install -U tox
+
+# install PostgreSQL client tools
+
+apt-get install -y postgresql libpq-dev
+
+for pghost in pg10 pg12 pg14 pg15 pg16 pg17
+do
+    export PGHOST=$pghost
+    export PGDATABASE=postgres
+    export PGUSER=postgres
+    export PGPASSWORD=postgres
+
+    createdb -E UTF8 -T template0 test
+    createdb -E SQL_ASCII -T template0 test_ascii
+    createdb -E LATIN1 -l C -T template0 test_latin1
+    createdb -E LATIN9 -l C -T template0 test_latin9
+    createdb -E ISO_8859_5 -l C -T template0 test_cyrillic
+
+    psql -c "create user test with password 'test'"
+
+    psql -c "grant create on database test to test"
+    psql -c "grant create on database test_ascii to test"
+    psql -c "grant create on database test_latin1 to test"
+    psql -c "grant create on database test_latin9 to test"
+    psql -c "grant create on database test_cyrillic to test"
+
+    psql -c "grant create on schema public to test" test
+    psql -c "grant create on schema public to test" test_ascii
+    psql -c "grant create on schema public to test" test_latin1
+    psql -c "grant create on schema public to test" test_latin9
+    psql -c "grant create on schema public to test" test_cyrillic
+
+    psql -c "create extension hstore" test
+    psql -c "create extension hstore" test_ascii
+    psql -c "create extension hstore" test_latin1
+    psql -c "create extension hstore" test_latin9
+    psql -c "create extension hstore" test_cyrillic
+done
+
+export PGDATABASE=test
+export PGUSER=test
+export PGPASSWORD=test
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
new file mode 100644
index 00000000..d88cd64a
--- /dev/null
+++ b/.github/workflows/docs.yml
@@ -0,0 +1,38 @@
+name: Publish PyGreSQL documentation
+
+on:
+  push:
+    branches:
+      - main
+
+jobs:
+  docs:
+    name: Build documentation
+    runs-on: ubuntu-22.04
+
+    steps:
+      - name: Check out repository
+        uses: actions/checkout@v4
+      - name: Set up Python 3.13
+        uses: actions/setup-python@v5
+        with:
+          python-version: 3.13
+      - name: Install dependencies
+        run: |
+          sudo apt install libpq-dev
+          python -m pip install --upgrade pip
+          pip install .
+ pip install "sphinx>=8,<9" + - name: Create docs with Sphinx + run: | + cd docs + make html + - name: Deploy docs to GitHub pages + uses: peaceiris/actions-gh-pages@v3 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + publish_branch: gh-pages + publish_dir: docs/_build/html + cname: pygresql.org + enable_jekyll: false + force_orphan: true diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..66d79095 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,26 @@ +name: Run PyGreSQL quality checks + +on: + push: + pull_request: + +jobs: + checks: + name: Quality checks run + runs-on: ubuntu-22.04 + + strategy: + fail-fast: false + + steps: + - name: Check out repository + uses: actions/checkout@v4 + - name: Install tox + run: pip install tox + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: 3.13 + - name: Run quality checks + run: tox -e ruff,mypy,cformat,docs + timeout-minutes: 5 diff --git a/.github/workflows/release-docs.yml b/.github/workflows/release-docs.yml deleted file mode 100644 index 2b77e8db..00000000 --- a/.github/workflows/release-docs.yml +++ /dev/null @@ -1,38 +0,0 @@ -name: Release PyGreSQL documentation - -on: - push: - branches: - - master - -jobs: - build: - - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v1 - - name: Set up Python 3.7 - uses: actions/setup-python@v1 - with: - python-version: 3.7 - - name: Install dependencies - run: | - sudo apt install libpq-dev - python -m pip install --upgrade pip - pip install . - pip install "sphinx>=2.4,<3" - pip install "cloud_sptheme>=1.10,<2" - - name: Create docs with Sphinx - run: | - cd docs - make html - - name: Deploy docs to GitHub pages - uses: peaceiris/actions-gh-pages@v3 - with: - github_token: ${{ secrets.GITHUB_TOKEN }} - publish_branch: gh-pages - publish_dir: docs/_build/html - cname: www.pygresql.org - enable_jekyll: false - force_orphan: true diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 00000000..920e3f3e --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,69 @@ +name: Run PyGreSQL test matrix + +# this has been shamelessly copied from Psycopg + +on: + push: + pull_request: + +jobs: + tests: + name: Unit tests run + runs-on: ubuntu-22.04 + + strategy: + fail-fast: false + matrix: + include: + - { python: "3.7", postgres: "11" } + - { python: "3.8", postgres: "12" } + - { python: "3.9", postgres: "13" } + - { python: "3.10", postgres: "14" } + - { python: "3.11", postgres: "15" } + - { python: "3.12", postgres: "16" } + - { python: "3.13", postgres: "17" } + + # Opposite extremes of the supported Py/PG range, other architecture + - { python: "3.7", postgres: "17", architecture: "x86" } + - { python: "3.8", postgres: "16", architecture: "x86" } + - { python: "3.9", postgres: "15", architecture: "x86" } + - { python: "3.10", postgres: "14", architecture: "x86" } + - { python: "3.11", postgres: "13", architecture: "x86" } + - { python: "3.12", postgres: "12", architecture: "x86" } + - { python: "3.13", postgres: "11", architecture: "x86" } + + env: + PYGRESQL_DB: test + PYGRESQL_HOST: 127.0.0.1 + PYGRESQL_USER: test + PYGRESQL_PASSWD: test + + services: + postgresql: + image: postgres:${{ matrix.postgres }} + env: + POSTGRES_USER: test + POSTGRES_PASSWORD: test + ports: + - 5432:5432 + # Set health checks to wait until postgres has started + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + + 
steps: + - name: Check out repository + uses: actions/checkout@v4 + - name: Install tox + run: pip install tox + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + - name: Run tests + env: + MATRIX_PYTHON: ${{ matrix.python }} + run: tox -e py${MATRIX_PYTHON/./} + timeout-minutes: 5 diff --git a/.gitignore b/.gitignore index f826aa80..22c5ce3c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,12 +1,14 @@ *~ *.bak *.cache +*.dll *.egg-info *.log *.patch *.pid *.pstats -*.py[co] +*.py[cdo] +*.so *.swp __pycache__/ @@ -18,15 +20,16 @@ _build_doctrees/ /local/ /tests/LOCAL_*.py -Vagrantfile .coverage .tox/ .venv/ .vagrant/ +.vagrant-*/ Thumbs.db .DS_Store .idea/ .vs/ +.vscode/ diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 00000000..9712e405 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,22 @@ +# .readthedocs.yaml +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Set the version of Python and other tools you might need +build: + os: ubuntu-22.04 + tools: + python: "3.11" + +# Build documentation in the docs/ directory with Sphinx +sphinx: + configuration: docs/conf.py + +# We recommend specifying your dependencies to enable reproducible builds: +# https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html +python: + install: + - requirements: docs/requirements.txt diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index c9976f0d..00000000 --- a/.travis.yml +++ /dev/null @@ -1,25 +0,0 @@ -# Travis CI configuration -# see https://docs.travis-ci.com/user/languages/python - -language: python - -python: - - "2.7" - - "3.4" - - "3.5" - - "3.6" - - "3.7" - -install: - - pip install . - -script: python setup.py test - -addons: - postgresql: "10" - -services: - - postgresql - -before_script: - - psql -U postgres -c 'create database unittest' diff --git a/LICENSE.txt b/LICENSE.txt index 4ff09c11..e905706e 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -6,7 +6,7 @@ Copyright (c) 1995, Pascal Andre Further modifications copyright (c) 1997-2008 by D'Arcy J.M. Cain -Further modifications copyright (c) 2009-2020 by the PyGreSQL Development Team +Further modifications copyright (c) 2009-2025 by the PyGreSQL Development Team PyGreSQL is released under the PostgreSQL License, a liberal Open Source license, similar to the BSD or MIT licenses: diff --git a/MANIFEST.in b/MANIFEST.in index 239841c7..8d4bbd33 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,15 +1,18 @@ -include *.c -include *.h -include *.py +include setup.py + +recursive-include pg *.py *.pyi py.typed +recursive-include pgdb *.py py.typed +recursive-include tests *.py + +include ext/*.c +include ext/*.h include README.rst include LICENSE.txt include tox.ini - -recursive-include tests *.py -exclude tests/LOCAL_PyGreSQL.py +include pyproject.toml include docs/Makefile include docs/make.bat @@ -20,5 +23,4 @@ exclude docs/index.rst recursive-include docs/community *.rst recursive-include docs/contents *.rst recursive-include docs/download *.rst -recursive-include docs/_static *.css_t *.ico *.png -recursive-include docs/_templates *.html +recursive-include docs/_static *.ico *.png diff --git a/README.rst b/README.rst index b3950fdc..46a09c2b 100644 --- a/README.rst +++ b/README.rst @@ -2,14 +2,24 @@ PyGreSQL - Python interface for PostgreSQL ========================================== PyGreSQL is a Python module that interfaces to a PostgreSQL database. 
-It embeds the PostgreSQL query library to allow easy use of the powerful -PostgreSQL features from a Python script. +It wraps the lower level C API library libpq to allow easy use of the +powerful PostgreSQL features from Python. PyGreSQL should run on most platforms where PostgreSQL and Python is running. It is based on the PyGres95 code written by Pascal Andre. -D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with version 2.0 +D'Arcy J. M. Cain renamed it to PyGreSQL starting with version 2.0 and serves as the "BDFL" of PyGreSQL. -Starting with version 5.0, PyGreSQL also supports Python 3. +Christoph Zwerschke volunteered as another maintainer and has been the main +contributor since version 3.7 of PyGreSQL. + +The following Python versions are supported: + +* PyGreSQL 4.x and earlier: Python 2 only +* PyGreSQL 5.x: Python 2 and Python 3 +* PyGreSQL 6.x and newer: Python 3 only + +The current version of PyGreSQL supports Python versions 3.7 to 3.13 +and PostgreSQL versions 10 to 17 on the server. Installation ------------ @@ -21,9 +31,13 @@ The simplest way to install PyGreSQL is to type:: For other ways of installing PyGreSQL and requirements, see the documentation. +Note that PyGreSQL also requires the libpq shared library to be +installed and accessible on the client machine. + Documentation ------------- -The documentation is available at `www.pygresql.org `_. - -At mirror of the documentation can be found at `pygresql.readthedocs.io `_. +The documentation is available at +`pygresql.github.io/ `_ and at +`pygresql.readthedocs.io `_, +where you can also find the documentation for older versions. diff --git a/docs/.gitignore b/docs/.gitignore deleted file mode 100644 index 4a579446..00000000 --- a/docs/.gitignore +++ /dev/null @@ -1 +0,0 @@ -index.rst \ No newline at end of file diff --git a/docs/Makefile b/docs/Makefile index 0a1113c9..d4bb2cbb 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -1,192 +1,20 @@ -# Makefile for Sphinx documentation +# Minimal makefile for Sphinx documentation # -# You can set these variables from the command line. -SPHINXOPTS = -SPHINXBUILD = sphinx-build -PAPER = +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . BUILDDIR = _build -# User-friendly check for sphinx-build -ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) -$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) -endif - -# Internal variables. -PAPEROPT_a4 = -D latex_paper_size=a4 -PAPEROPT_letter = -D latex_paper_size=letter -ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . -# the i18n builder cannot share the environment and doctrees with the others -I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . - -.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest coverage gettext - +# Put it first so that "make" without argument is like "make help". 
help: - @echo "Please use \`make ' where is one of" - @echo " html to make standalone HTML files" - @echo " dirhtml to make HTML files named index.html in directories" - @echo " singlehtml to make a single large HTML file" - @echo " pickle to make pickle files" - @echo " json to make JSON files" - @echo " htmlhelp to make HTML files and a HTML help project" - @echo " qthelp to make HTML files and a qthelp project" - @echo " applehelp to make an Apple Help Book" - @echo " devhelp to make HTML files and a Devhelp project" - @echo " epub to make an epub" - @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" - @echo " latexpdf to make LaTeX files and run them through pdflatex" - @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" - @echo " text to make text files" - @echo " man to make manual pages" - @echo " texinfo to make Texinfo files" - @echo " info to make Texinfo files and run them through makeinfo" - @echo " gettext to make PO message catalogs" - @echo " changes to make an overview of all changed/added/deprecated items" - @echo " xml to make Docutils-native XML files" - @echo " pseudoxml to make pseudoxml-XML files for display purposes" - @echo " linkcheck to check all external links for integrity" - @echo " doctest to run all doctests embedded in the documentation (if enabled)" - @echo " coverage to run coverage check of the documentation (if enabled)" - -clean: - rm -rf $(BUILDDIR)/* - -html: - $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html - @echo - @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." - -dirhtml: - $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml - @echo - @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." - -singlehtml: - $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml - @echo - @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." - -pickle: - $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle - @echo - @echo "Build finished; now you can process the pickle files." - -json: - $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json - @echo - @echo "Build finished; now you can process the JSON files." - -htmlhelp: - $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp - @echo - @echo "Build finished; now you can run HTML Help Workshop with the" \ - ".hhp project file in $(BUILDDIR)/htmlhelp." - -qthelp: - $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp - @echo - @echo "Build finished; now you can run "qcollectiongenerator" with the" \ - ".qhcp project file in $(BUILDDIR)/qthelp, like this:" - @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/PyGreSQL.qhcp" - @echo "To view the help file:" - @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/PyGreSQL.qhc" - -applehelp: - $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp - @echo - @echo "Build finished. The help book is in $(BUILDDIR)/applehelp." - @echo "N.B. You won't be able to view it unless you put it in" \ - "~/Library/Documentation/Help or install it in your application" \ - "bundle." - -devhelp: - $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp - @echo - @echo "Build finished." - @echo "To view the help file:" - @echo "# mkdir -p $$HOME/.local/share/devhelp/PyGreSQL" - @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/PyGreSQL" - @echo "# devhelp" - -epub: - $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub - @echo - @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 
- -latex: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex - @echo - @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." - @echo "Run \`make' in that directory to run these through (pdf)latex" \ - "(use \`make latexpdf' here to do that automatically)." - -latexpdf: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex - @echo "Running LaTeX files through pdflatex..." - $(MAKE) -C $(BUILDDIR)/latex all-pdf - @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." - -latexpdfja: - $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex - @echo "Running LaTeX files through platex and dvipdfmx..." - $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja - @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." - -text: - $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text - @echo - @echo "Build finished. The text files are in $(BUILDDIR)/text." - -man: - $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man - @echo - @echo "Build finished. The manual pages are in $(BUILDDIR)/man." - -texinfo: - $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo - @echo - @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." - @echo "Run \`make' in that directory to run these through makeinfo" \ - "(use \`make info' here to do that automatically)." - -info: - $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo - @echo "Running Texinfo files through makeinfo..." - make -C $(BUILDDIR)/texinfo info - @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." - -gettext: - $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale - @echo - @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." - -changes: - $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes - @echo - @echo "The overview file is in $(BUILDDIR)/changes." - -linkcheck: - $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck - @echo - @echo "Link check complete; look for any errors in the above output " \ - "or in $(BUILDDIR)/linkcheck/output.txt." - -doctest: - $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest - @echo "Testing of doctests in the sources finished, look at the " \ - "results in $(BUILDDIR)/doctest/output.txt." - -coverage: - $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage - @echo "Testing of coverage in the sources finished, look at the " \ - "results in $(BUILDDIR)/coverage/python.txt." + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -xml: - $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml - @echo - @echo "Build finished. The XML files are in $(BUILDDIR)/xml." +.PHONY: help Makefile -pseudoxml: - $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml - @echo - @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
+%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/_static/pygresql.css_t b/docs/_static/pygresql.css_t deleted file mode 100644 index a3bc4de2..00000000 --- a/docs/_static/pygresql.css_t +++ /dev/null @@ -1,86 +0,0 @@ -{% macro experimental(keyword, value) %} - {% if value %} - -moz-{{keyword}}: {{value}}; - -webkit-{{keyword}}: {{value}}; - -o-{{keyword}}: {{value}}; - -ms-{{keyword}}: {{value}}; - {{keyword}}: {{value}}; - {% endif %} -{% endmacro %} - -{% macro border_radius(value) -%} - {{experimental("border-radius", value)}} -{% endmacro %} - -{% macro box_shadow(value) -%} - {{experimental("box-shadow", value)}} -{% endmacro %} - -.pageheader.related { - text-align: left; - padding: 10px 15px; - border: 1px solid #eeeeee; - margin-bottom: 10px; - {{border_radius("1em 1em 1em 1em")}} - {% if theme_borderless_decor | tobool %} - border-top: 0; - border-bottom: 0; - {% endif %} -} - -.pageheader.related .logo { - font-size: 36px; - font-style: italic; - letter-spacing: 5px; - margin-right: 2em; -} - -.pageheader.related .logo { - font-size: 36px; - font-style: italic; - letter-spacing: 5px; - margin-right: 2em; -} - -.pageheader.related .logo a, .pageheader.related .logo a:hover { - background: transparent; - color: {{ theme_relbarlinkcolor }}; - border: none; - text-decoration: none; - text-shadow: none; - {{box_shadow("none")}} -} - -.pageheader.related ul { - float: right; - margin: 2px 1em; -} - -.pageheader.related li { - float: left; - margin: 0 0 0 10px; -} - -.pageheader.related li a { - padding: 8px 12px; -} - -.norelbar .subtitle { - font-size: 14px; - line-height: 18px; - font-weight: bold; - letter-spacing: 4px; - text-align: right; - padding: 0 1em; - margin-top: -9px; -} - -.relbar-top .related.norelbar { - height: 22px; - border-bottom: 14px solid #eeeeee; -} - -.relbar-bottom .related.norelbar { - height: 22px; - border-top: 14px solid #eeeeee; -} diff --git a/docs/_templates/layout.html b/docs/_templates/layout.html deleted file mode 100644 index 1cb2ddee..00000000 --- a/docs/_templates/layout.html +++ /dev/null @@ -1,58 +0,0 @@ -{%- extends "cloud/layout.html" %} - -{% set css_files = css_files + ["_static/pygresql.css"] %} - -{# - This layout adds a page header above the standard layout. - It also removes the relbars from all pages that are not part - of the core documentation in the contents/ directory, - adapting the navigation bar (breadcrumb) appropriately. -#} - -{% set is_content = pagename.startswith(('contents/', 'genindex', 'modindex', 'py-', 'search')) %} -{% if is_content %} -{% set master_doc = 'contents/index' %} -{% set parents = parents[1:] %} -{% endif %} - -{% block header %} - - - -{% endblock %} - -{% block relbar1 -%} -{%- if is_content -%} - {{ super() }} -{% else %} -
-{%- endif -%} -{%- endblock %} - -{% block relbar2 -%} -{%- if is_content -%} - {{ super() }} -{%- else -%} -
-{%- endif -%} -{%- endblock %} - -{% block content -%} -{%- if is_content -%} -{{ super() }} -{%- else -%} -
{{ super() }}
-
-{%- endif -%}
-{%- endblock %}
diff --git a/docs/about.rst b/docs/about.rst
index 3e61d030..10ceaf59 100644
--- a/docs/about.rst
+++ b/docs/about.rst
@@ -1,4 +1,44 @@
 About PyGreSQL
 ==============
 
-.. include:: about.txt
\ No newline at end of file
+**PyGreSQL** is an *open-source* `Python <https://www.python.org>`_ module
+that interfaces to a `PostgreSQL <https://www.postgresql.org>`_ database.
+It wraps the lower level C API library libpq to allow easy use of the
+powerful PostgreSQL features from Python.
+
+ | This software is copyright © 1995, Pascal Andre.
+ | Further modifications are copyright © 1997-2008 by D'Arcy J.M. Cain.
+ | Further modifications are copyright © 2009-2025 by the PyGreSQL team.
+ | For licensing details, see the full :doc:`copyright`.
+
+**PostgreSQL** is a highly scalable, SQL compliant, open source
+object-relational database management system. With more than 20 years
+of development history, it is quickly becoming the de facto database
+for enterprise level open source solutions.
+Best of all, PostgreSQL's source code is available under the most liberal
+open source license: the BSD license.
+
+**Python** is an interpreted, interactive, object-oriented
+programming language. It is often compared to Tcl, Perl, Scheme or Java.
+Python combines remarkable power with very clear syntax. It has modules,
+classes, exceptions, very high level dynamic data types, and dynamic typing.
+There are interfaces to many system calls and libraries, as well as to
+various windowing systems (X11, Motif, Tk, Mac, MFC). New built-in modules
+are easily written in C or C++. Python is also usable as an extension
+language for applications that need a programmable interface.
+The Python implementation is copyrighted but freely usable and distributable,
+even for commercial use.
+
+**PyGreSQL** is a Python module that interfaces to a PostgreSQL database.
+It wraps the lower level C API library libpq to allow easy use of the
+powerful PostgreSQL features from Python.
+
+PyGreSQL is developed and tested on a NetBSD system, but it also runs on
+most other platforms where PostgreSQL and Python are running. It is based
+on the PyGres95 code written by Pascal Andre (andre@chimay.via.ecp.fr).
+D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with
+version 2.0 and serves as the "BDFL" of PyGreSQL.
+
+The current version, PyGreSQL |version|, needs PostgreSQL 10 to 17, and Python
+3.7 to 3.13. If you need to support older PostgreSQL or Python versions,
+you can resort to the PyGreSQL 5.x versions that still support them.
diff --git a/docs/about.txt b/docs/about.txt
deleted file mode 100644
index 54e98d39..00000000
--- a/docs/about.txt
+++ /dev/null
@@ -1,42 +0,0 @@
-**PyGreSQL** is an *open-source* `Python `_ module
-that interfaces to a `PostgreSQL `_ database.
-It embeds the PostgreSQL query library to allow easy use of the powerful
-PostgreSQL features from a Python script.
-
- | This software is copyright © 1995, Pascal Andre.
- | Further modifications are copyright © 1997-2008 by D'Arcy J.M. Cain.
- | Further modifications are copyright © 2009-2020 by the PyGreSQL team.
- | For licensing details, see the full :doc:`copyright`.
-
-**PostgreSQL** is a highly scalable, SQL compliant, open source
-object-relational database management system. With more than 20 years
-of development history, it is quickly becoming the de facto database
-for enterprise level open source solutions.
-Best of all, PostgreSQL's source code is available under the most liberal
-open source license: the BSD license.
- -**Python** Python is an interpreted, interactive, object-oriented -programming language. It is often compared to Tcl, Perl, Scheme or Java. -Python combines remarkable power with very clear syntax. It has modules, -classes, exceptions, very high level dynamic data types, and dynamic typing. -There are interfaces to many system calls and libraries, as well as to -various windowing systems (X11, Motif, Tk, Mac, MFC). New built-in modules -are easily written in C or C++. Python is also usable as an extension -language for applications that need a programmable interface. -The Python implementation is copyrighted but freely usable and distributable, -even for commercial use. - -**PyGreSQL** is a Python module that interfaces to a PostgreSQL database. -It embeds the PostgreSQL query library to allow easy use of the powerful -PostgreSQL features from a Python script or application. - -PyGreSQL is developed and tested on a NetBSD system, but it also runs on -most other platforms where PostgreSQL and Python is running. It is based -on the PyGres95 code written by Pascal Andre (andre@chimay.via.ecp.fr). -D'Arcy (darcy@druid.net) renamed it to PyGreSQL starting with -version 2.0 and serves as the "BDFL" of PyGreSQL. - -The current version PyGreSQL 5.1.2 needs PostgreSQL 9.0 to 9.6 or 10 to 12, and -Python 2.6, 2.7 or 3.3 to 3.8. If you need to support older PostgreSQL versions -or older Python 2.x versions, you can resort to the PyGreSQL 4.x versions that -still support them. diff --git a/docs/announce.rst b/docs/announce.rst deleted file mode 100644 index 73ddce93..00000000 --- a/docs/announce.rst +++ /dev/null @@ -1,29 +0,0 @@ -====================== -PyGreSQL Announcements -====================== - ---------------------------------- -Release of PyGreSQL version 5.1.2 ---------------------------------- - -Release 5.1.2 of PyGreSQL. - -It is available at: https://pypi.org/project/PyGreSQL/. - -If you are running NetBSD, look in the packages directory under databases. -There is also a package in the FreeBSD ports collection. - -Please refer to `changelog.txt `_ -for things that have changed in this version. - -This version has been built and unit tested on: - - NetBSD - - FreeBSD - - openSUSE - - Ubuntu - - Windows 7 and 10 with both MinGW and Visual Studio - - PostgreSQL 9.0 to 9.6 and 10 to 12 (32 and 64bit) - - Python 2.6, 2.7 and 3.3 to 3.8 (32 and 64bit) - -| D'Arcy J.M. Cain -| darcy@PyGreSQL.org diff --git a/docs/community/source.rst b/docs/community/source.rst index 224985fd..497f6280 100644 --- a/docs/community/source.rst +++ b/docs/community/source.rst @@ -4,12 +4,12 @@ Access to the source repository The source code of PyGreSQL is available as a `Git `_ repository on `GitHub `_. -The current master branch of the repository can be cloned with the command:: +The current main branch of the repository can be cloned with the command:: git clone https://github.com/PyGreSQL/PyGreSQL.git -You can also download the master branch as a -`zip archive `_. +You can also download the main branch as a +`zip archive `_. Contributions can be proposed as `pull requests `_ on GitHub. diff --git a/docs/conf.py b/docs/conf.py index 7c0919a5..f25d78e7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,92 +1,34 @@ -# -*- coding: utf-8 -*- +# Configuration file for the Sphinx documentation builder. # -# PyGreSQL documentation build configuration file. -# -# This file is execfile()d with the current directory set to its -# containing dir. 
-# -# Note that not all possible configuration values are present in this -# autogenerated file. -# -# All configuration values have a default; values that are commented out -# serve to show the default. +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html -import sys -import os -import shlex -import shutil +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information -# Import Cloud theme (this will also automatically add the theme directory). -# Note: We add a navigation bar to the cloud them using a custom layout. -if os.environ.get('READTHEDOCS', None) == 'True': - # We cannot use our custom layout here, since RTD overrides layout.html. - use_cloud_theme = False -else: - try: - import cloud_sptheme - use_cloud_theme = True - except ImportError: - use_cloud_theme = False +project = 'PyGreSQL' +author = 'The PyGreSQL team' +copyright = '2025, ' + author -shutil.copyfile('start.txt' if use_cloud_theme else 'toc.txt', 'index.rst') +def project_version(): + with open('../pyproject.toml') as f: + for d in f: + if d.startswith("version ="): + version = d.split("=")[1].strip().strip('"') + return version + raise Exception("Cannot determine PyGreSQL version") -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -#sys.path.insert(0, os.path.abspath('.')) +version = release = project_version() -# -- General configuration ------------------------------------------------ +language = 'en' -# If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. extensions = ['sphinx.ext.autodoc'] -# Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] if use_cloud_theme else [] - -# The suffix(es) of source filenames. -# You can specify multiple suffix as a list of string: -# source_suffix = ['.rst', '.md'] -source_suffix = '.rst' - -# The encoding of source files. -#source_encoding = 'utf-8-sig' - -# The master toctree document. -master_doc = 'index' - -# General information about the project. -project = 'PyGreSQL' -author = 'The PyGreSQL team' -copyright = '2020, ' + author - -# The version info for the project you're documenting, acts as replacement for -# |version| and |release|, also used in various other places throughout the -# built documents. -# -# The short X.Y version. -version = '5.1' -# The full version, including alpha/beta/rc tags. -release = '5.1.2' - -# The language for content autogenerated by Sphinx. Refer to documentation -# for a list of supported languages. -# This is also used if you do content translation via gettext catalogs. -# Usually you set "language" from the command line for these cases. -language = None - -# There are two options for replacing |today|: either, you set today to some -# non-false value, then it is used: -#today = '' -# Else, today_fmt is used as the format for a strftime call. 
-#today_fmt = '%B %d, %Y' - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -exclude_patterns = ['_build'] +templates_path = ['_templates'] +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] # List of pages which are included in other pages and therefore should # not appear in the toctree. @@ -95,220 +37,46 @@ 'community/mailinglist.rst', 'community/source.rst', 'community/issues.rst', 'community/support.rst', 'community/homes.rst'] -if use_cloud_theme: - exclude_patterns += ['about.rst'] - -# The reST default role (used for this markup: `text`) for all documents. -#default_role = None - -# If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True - -# If true, the current module name will be prepended to all description -# unit titles (such as .. function::). -#add_module_names = True -# If true, sectionauthor and moduleauthor directives will be shown in the -# output. They are ignored by default. -#show_authors = False - -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' - -# A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] - - -# If true, keep warnings as "system message" paragraphs in the built documents. -#keep_warnings = False - -# If true, `todo` and `todoList` produce output, else they produce nothing. -todo_include_todos = False - - -# -- Options for HTML output ---------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -html_theme = 'cloud' if use_cloud_theme else 'default' - -# Theme options are theme-specific and customize the look and feel of a theme -# further. For a list of options available for each theme, see the -# documentation. -if use_cloud_theme: - html_theme_options = { - 'roottarget': 'contents/index', - 'defaultcollapsed': True, - 'shaded_decor': True} -else: - html_theme_options = {} - -# Add any paths that contain custom themes here, relative to this directory. -html_theme_path = ['_themes'] - -# The name for this set of Sphinx documents. If None, it defaults to -# " v documentation". 
-html_title = 'PyGreSQL %s' % version -if use_cloud_theme: - html_title += ' documentation' +# ignore certain warnings +# (references to some of the Python names do not resolve correctly) +nitpicky = True +nitpick_ignore = [ + ('py:' + t, n) for t, names in { + 'attr': ('arraysize', 'error', 'sqlstate', 'DatabaseError.sqlstate'), + 'class': ('bool', 'bytes', 'callable', 'callables', 'class', + 'dict', 'float', 'function', 'int', 'iterable', + 'list', 'object', 'set', 'str', 'tuple', + 'False', 'True', 'None', + 'namedtuple', 'namedtuples', + 'decimal.Decimal', + 'bytes/str', 'list of namedtuples', 'tuple of callables', + 'first field', 'type of first field', + 'Notice', 'DATETIME'), + 'data': ('defbase', 'defhost', 'defopt', 'defpasswd', 'defport', + 'defuser'), + 'exc': ('Exception', 'IndexError', 'IOError', 'KeyError', + 'MemoryError', 'SyntaxError', 'TypeError', 'ValueError', + 'pg.InternalError', 'pg.InvalidResultError', + 'pg.MultipleResultsError', 'pg.NoResultError', + 'pg.OperationalError', 'pg.ProgrammingError'), + 'func': ('len', 'json.dumps', 'json.loads'), + 'meth': ('datetime.strptime', + 'cur.execute', + 'DB.close', 'DB.connection_handler', 'DB.get_regtypes', + 'DB.inserttable', 'DB.reopen'), + 'obj': ('False', 'True', 'None') + }.items() for n in names] + + + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_theme = 'alabaster' +html_static_path = ['_static'] -# A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +html_title = f'PyGreSQL {version}' -# The name of an image file (relative to this directory) to place at the top -# of the sidebar. html_logo = '_static/pygresql.png' - -# The name of an image file (within the static path) to use as favicon of the -# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 -# pixels large. html_favicon = '_static/favicon.ico' - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] - -# Add any extra paths that contain custom files (such as robots.txt or -# .htaccess) here, relative to this directory. These files are copied -# directly to the root of the documentation. -#html_extra_path = [] - -# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, -# using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' - -# If true, SmartyPants will be used to convert quotes and dashes to -# typographically correct entities. -#html_use_smartypants = True - -# Custom sidebar templates, maps document names to template names. -#html_sidebars = {} - -# Additional templates that should be rendered to pages, maps page names to -# template names. -#html_additional_pages = {} - -# If false, no module index is generated. -#html_domain_indices = True - -# If false, no index is generated. -#html_use_index = True - -# If true, the index is split into individual pages for each letter. -#html_split_index = False - -# If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True - -# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True - -# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 
-#html_show_copyright = True - -# If true, an OpenSearch description file will be output, and all pages will -# contain a tag referring to it. The value of this option must be the -# base URL from which the finished HTML is served. -#html_use_opensearch = '' - -# This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None - -# Language to be used for generating the HTML full-text search index. -# Sphinx supports the following languages: -# 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja' -# 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr' -#html_search_language = 'en' - -# A dictionary with options for the search language support, empty by default. -# Now only 'ja' uses this config value -#html_search_options = {'type': 'default'} - -# The name of a javascript file (relative to the configuration directory) that -# implements a search results scorer. If empty, the default will be used. -#html_search_scorer = 'scorer.js' - -# Output file base name for HTML help builder. -htmlhelp_basename = 'PyGreSQLdoc' - - -# -- Options for LaTeX output --------------------------------------------- - -latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', - -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', - -# Additional stuff for the LaTeX preamble. -#'preamble': '', - -# Latex figure (float) alignment -#'figure_align': 'htbp', -} - -# Grouping the document tree into LaTeX files. List of tuples -# (source start file, target name, title, -# author, documentclass [howto, manual, or own class]). -latex_documents = [ - (master_doc, 'PyGreSQL.tex', 'PyGreSQL Documentation', - author, 'manual'), -] - -# The name of an image file (relative to this directory) to place at the top of -# the title page. -#latex_logo = None - -# For "manual" documents, if this is true, then toplevel headings are parts, -# not chapters. -#latex_use_parts = False - -# If true, show page references after internal links. -#latex_show_pagerefs = False - -# If true, show URL addresses after external links. -#latex_show_urls = False - -# Documents to append as an appendix to all manuals. -#latex_appendices = [] - -# If false, no module index is generated. -#latex_domain_indices = True - - -# -- Options for manual page output --------------------------------------- - -# One entry per manual page. List of tuples -# (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'pygresql', 'PyGreSQL Documentation', [author], 1) -] - -# If true, show URL addresses after external links. -#man_show_urls = False - - -# -- Options for Texinfo output ------------------------------------------- - -# Grouping the document tree into Texinfo files. List of tuples -# (source start file, target name, title, author, -# dir menu entry, description, category) -texinfo_documents = [ - (master_doc, 'PyGreSQL', u'PyGreSQL Documentation', - author, 'PyGreSQL', 'One line description of project.', - 'Miscellaneous'), -] - -# Documents to append as an appendix to all manuals. -#texinfo_appendices = [] - -# If false, no module index is generated. -#texinfo_domain_indices = True - -# How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' - -# If true, do not generate a @detailmenu in the "Top" node's menu. 
-#texinfo_no_detailmenu = False diff --git a/docs/contents/changelog.rst b/docs/contents/changelog.rst index d5ef4f05..ad5f7f0e 100644 --- a/docs/contents/changelog.rst +++ b/docs/contents/changelog.rst @@ -1,6 +1,121 @@ ChangeLog ========= +Version 6.1.0 (2024-12-05) +-------------------------- +- Support Python 3.13 and PostgreSQL 17. + +Version 6.0.1 (2024-04-19) +-------------------------- +- Properly adapt falsy JSON values (#86) + +Version 6.0 (2023-10-03) +------------------------ +- Tested with the recent releases of Python 3.12 and PostgreSQL 16. +- Make pyproject.toml the only source of truth for the version number. +- Please also note the changes already made in version 6.0b1. + +Version 6.0b1 (2023-09-06) +-------------------------- +- Officially support Python 3.12 and PostgreSQL 16 (tested with rc versions). +- Removed support for Python versions older than 3.7 (released June 2017) + and PostgreSQL older than version 10 (released October 2017). +- Converted the standalone modules `pg` and `pgdb` to packages with + several submodules each. The C extension module is now part of the + `pg` package and wrapped into the pure Python module `pg.core`. +- Added type hints and included a stub file for the C extension module. +- Added method `pkeys()` to the `pg.DB` object. +- Removed deprecated function `pg.pgnotify()`. +- Removed deprecated method `ntuples()` of the `pg.Query` object. +- Renamed `pgdb.Type` to `pgdb.DbType` to avoid confusion with `typing.Type`. +- `pg` and `pgdb` now use a shared row factory cache. +- The function `set_row_factory_size()` has been removed. The row cache is now + available as a `RowCache` class with methods `change_size()` and `clear()`. +- Modernized code and tools for development, testing, linting and building. + +Version 5.2.5 (2023-08-28) +-------------------------- +- This version officially supports the new Python 3.11 and PostgreSQL 15. +- Two more improvements in the `inserttable()` method of the `pg` module + (thanks to Justin Pryzby for this contribution): + + - error handling has been improved (#72) + - the method now returns the number of inserted rows (#73) +- Another improvement in the `pg` module (#83): + - generated columns can be requested with the `get_generated()` method + - generated columns are ignored by the insert, update and upsert method +- Avoid internal query and error when casting the `sql_identifier` type (#82) +- Fix issue with multiple calls of `getresult()` after `send_query()` (#80) + +Version 5.2.4 (2022-03-26) +-------------------------- +- Three more fixes in the `inserttable()` method of the `pg` module: + - `inserttable()` failed to escape carriage return (#68) + - Allow larger row sizes up to 64 KB (#69) + - Fix use after free issue in `inserttable()` (#71) +- Replace obsolete functions for copy used internally (#59). + Therefore, `getline()` now does not return `\.` at the end any more. + +Version 5.2.3 (2022-01-30) +-------------------------- +- This version officially supports the new Python 3.10 and PostgreSQL 14. +- Some improvements and fixes in the `inserttable()` method of the `pg` module: + - Sync with `PQendcopy()` when there was an error (#60) + - Allow specifying a schema in the table name (#61) + - Improved check for internal result (#62) + - Catch buffer overflows when building the copy command + - Data can now be passed as an iterable, not just list or tuple (#66) +- Some more fixes in the `pg` module: + - Fix upsert with limited number of columns (#58). 
+ - Fix argument handling of `is/set_non_blocking()`. + - Add missing `get/set_typecasts` in list of exports. +- Fixed a reference counting issue when casting JSON columns (#57). + +Version 5.2.2 (2020-12-09) +-------------------------- +- Added a missing adapter method for UUIDs in the classic `pg` module. +- Performance optimizations for `fetchmany()` in the `pgdb` module (#51). +- Fixed a reference counting issue in the `cast_array/record` methods (#52). +- Ignore incompatible libpq.dll in Windows PATH for Python >= 3.8 (#53). + +Version 5.2.1 (2020-09-25) +-------------------------- +- This version officially supports the new Python 3.9 and PostgreSQL 13. +- The `copy_to()` and `copy_from()` methods in the pgdb module now also work + with table names containing schema qualifiers (#47). + +Version 5.2 (2020-06-21) +------------------------ +- We now require Python version 2.7 or 3.5 and newer. +- All Python code is now tested with flake8 and made PEP8 compliant. +- Changes to the classic PyGreSQL module (pg): + - New module level function `get_pqlib_version()` that gets the version + of the pqlib used by PyGreSQL (needs PostgreSQL >= 9.1 on the client). + - New query method `memsize()` that gets the memory size allocated by + the query (needs PostgreSQL >= 12 on the client). + - New query method `fieldinfo()` that gets name and type information for + one or all field(s) of the query. Contributed by Justin Pryzby (#39). + - Experimental support for asynchronous command processing. + Additional connection parameter `nowait`, and connection methods + `send_query()`, `poll()`, `set_non_blocking()`, `is_non_blocking()`. + Generously contributed by Patrick TJ McPhee (#19). + - The `types` parameter of `format_query` can now be passed as a string + that will be split on whitespace when values are passed as a sequence, + and the types can now also be specified using actual Python types + instead of type names. Suggested by Justin Pryzby (#38). + - The `inserttable()` method now accepts an optional column list that will + be passed on to the COPY command. Contributed by Justin Pryzby (#24). + - The `DBTypes` class now also includes the `typlen` attribute with + information about the size of the type (contributed by Justin Pryzby). + - Large objects on the server are not closed any more when they are + deallocated as Python objects, since this could cause several problems. + Bug report and analysis by Justin Pryzby (#30). +- Changes to the DB-API 2 module (pgdb): + - When using Python 2, errors are now derived from StandardError + instead of Exception, as required by the DB-API 2 compliance test. + - Connection arguments containing single quotes caused problems + (reported and fixed by Tyler Ramer and Jamie McAtamney). + Version 5.1.2 (2020-04-19) -------------------------- - Improved handling of build_ext options for disabling certain features. @@ -49,22 +164,21 @@ Version 5.1 (2019-05-17) and this function is not part of the official API. - Added new connection attributes `socket`, `backend_pid`, `ssl_in_use` and `ssl_attributes` (the latter need PostgreSQL >= 9.5 on the client). - - Changes to the DB-API 2 module (pgdb): - Connections now have an `autocommit` attribute which is set to `False` by default but can be set to `True` to switch to autocommit mode where no transactions are started and calling commit() is not required. Note that this is not part of the DB-API 2 standard. 
-Vesion 5.0.7 (2019-05-17) -------------------------- +Version 5.0.7 (2019-05-17) +-------------------------- - This version officially supports the new PostgreSQL 11. - Fixed a bug in parsing array subscript ranges (reported by Justin Pryzby). - Fixed an issue when deleting a DB wrapper object with the underlying connection already closed (bug report by Jacob Champion). -Vesion 5.0.6 (2018-07-29) -------------------------- +Version 5.0.6 (2018-07-29) +-------------------------- - This version officially supports the new Python 3.7. - Correct trove classifier for the PostgreSQL License. diff --git a/docs/contents/install.rst b/docs/contents/install.rst index 6cdee9da..23694528 100644 --- a/docs/contents/install.rst +++ b/docs/contents/install.rst @@ -7,17 +7,19 @@ General You must first install Python and PostgreSQL on your system. If you want to access remote databases only, you don't need to install the full PostgreSQL server, but only the libpq C-interface library. -If you are on Windows, make sure that the directory that contains -libpq.dll is part of your ``PATH`` environment variable. +On Windows, this library is called ``libpq.dll`` and is for instance contained +in the PostgreSQL ODBC driver (search for "psqlodbc"). On Linux, it is called +``libpq.so`` and usually provided in a package called "libpq" or "libpq5". +On Windows, you also need to make sure that the directory that contains +``libpq.dll`` is part of your ``PATH`` environment variable. The current version of PyGreSQL has been tested with Python versions -2.6, 2.7 and 3.3 to 3.8, and PostgreSQL versions 9.0 to 9.6 and 10 to 12. +3.7 to 3.13, and PostgreSQL versions 10 to 17. -PyGreSQL will be installed as three modules, a shared library called -_pg.so (on Linux) or a DLL called _pg.pyd (on Windows), and two pure -Python wrapper modules called pg.py and pgdb.py. -All three files will be installed directly into the Python site-packages -directory. To uninstall PyGreSQL, simply remove these three files. +PyGreSQL will be installed as two packages named ``pg`` (for the classic +interface) and ``pgdb`` (for the DB API 2 compliant interface). The former +also contains a shared library called ``_pg.so`` (on Linux) or a DLL called +``_pg.pyd`` (on Windows) and a stub file ``_pg.pyi`` for this library. Installing with Pip @@ -32,6 +34,9 @@ This will automatically try to find and download a distribution on the `Python Package Index `_ that matches your operating system and Python version and install it. +Note that you still need to have the libpq interface installed on your system +(see the general remarks above). + Installing from a Binary Distribution ------------------------------------- @@ -85,24 +90,29 @@ Now you should be ready to use PyGreSQL. You can also run the build step separately if you want to create a distribution to be installed on a different system or explicitly enable or disable certain -features. For instance, in order to build PyGreSQL without support for the SSL -info functions, run:: +features. For instance, in order to build PyGreSQL without support for the +memory size functions, run:: - python setup.py build_ext --no-ssl-info + python setup.py build_ext --no-memory-size By default, PyGreSQL is compiled with support for all features available in the installed PostgreSQL version, and you will get warnings for the features that are not supported in this version. 
You can also explicitly require a feature in order to get an error if it is not available, for instance:: - python setup.py build_ext --ssl-info + python setup.py build_ext --memory-size You can find out all possible build options with:: python setup.py build_ext --help Alternatively, you can also use the corresponding C preprocessor macros like -``SSL_INFO`` directly (see the next section). +``MEMORY_SIZE`` directly (see the next section). + +Note that if you build PyGreSQL with support for newer features that are not +available in the libpq installed on the runtime system, you may get an error +when importing PyGreSQL, since these features are missing in the shared library, +which will prevent Python from loading it. Compiling Manually ~~~~~~~~~~~~~~~~~~ @@ -143,11 +153,7 @@ Stand-Alone Some options may be added to this line:: - -DDEFAULT_VARS default variables support - -DDIRECT_ACCESS direct access methods - -DLARGE_OBJECTS large object support - -DESCAPING_FUNCS support for newer escaping functions - -DSSL_INFO support SSL information + -DMEMORY_SIZE support memory size function (PostgreSQL 12 or newer) On some systems you may need to include ``-lcrypt`` in the list of libraries to make it compile. @@ -189,11 +195,7 @@ Built-in to Python interpreter Some options may be added to this line:: - -DDEFAULT_VARS default variables support - -DDIRECT_ACCESS direct access methods - -DLARGE_OBJECTS large object support - -DESCAPING_FUNCS support for newer escaping functions - -DSSL_INFO support SSL information + -DMEMORY_SIZE support memory size function (PostgreSQL 12 or newer) On some systems you may need to include ``-lcrypt`` in the list of libraries to make it compile. diff --git a/docs/contents/pg/adaptation.rst b/docs/contents/pg/adaptation.rst index 6ed6e779..de82cbfa 100644 --- a/docs/contents/pg/adaptation.rst +++ b/docs/contents/pg/adaptation.rst @@ -1,7 +1,7 @@ Remarks on Adaptation and Typecasting ===================================== -.. py:currentmodule:: pg +.. currentmodule:: pg Both PostgreSQL and Python have the concept of data types, but there are of course differences between the two type systems. Therefore PyGreSQL @@ -26,7 +26,7 @@ PostgreSQL Python char, bpchar, name, text, varchar str bool bool bytea bytes -int2, int4, int8, oid, serial int [#int8]_ +int2, int4, int8, oid, serial int int2vector list of int float4, float8 float numeric, money Decimal @@ -45,8 +45,6 @@ record tuple Elements of arrays and records will also be converted accordingly. - .. [#int8] int8 is converted to long in Python 2 - .. [#array] The first element of the array will always be the first element of the Python list, no matter what the lower bound of the PostgreSQL array is. The information about the start index of the array (which is @@ -233,7 +231,7 @@ our values:: ... self.price = price ... ... def __str__(self): - ... return '%s (from %s, at $%s)' % ( + ... return '{} (from {}, at ${})'.format( ... self.name, self.supplier_id, self.price) But when we try to insert an instance of this class in the same way, we @@ -248,7 +246,7 @@ PostgreSQL by adding a "magic" method with the name ``__pg_str__``, like so:: ... ... ... ... def __str__(self): - ... return '%s (from %s, at $%s)' % ( + ... return '{} (from {}, at ${})'.format( ... self.name, self.supplier_id, self.price) ... ... 
def __pg_str__(self, typ): @@ -363,7 +361,7 @@ With PostgreSQL we can easily calculate that these two circles overlap:: True However, calculating the intersection points between the two circles using the -``#`` operator does not work (at least not as of PostgreSQL version 12). +``#`` operator does not work (at least not as of PostgreSQL version 14). So let's resort to SymPy to find out. To ease importing circles from PostgreSQL to SymPy, we create and register the following typecast function:: diff --git a/docs/contents/pg/connection.rst b/docs/contents/pg/connection.rst index 8556c5d2..e4a08591 100644 --- a/docs/contents/pg/connection.rst +++ b/docs/contents/pg/connection.rst @@ -1,7 +1,7 @@ Connection -- The connection object =================================== -.. py:currentmodule:: pg +.. currentmodule:: pg .. class:: Connection @@ -13,17 +13,8 @@ significant parameters in function calls. Some methods give direct access to the connection socket. *Do not use them unless you really know what you are doing.* - If you prefer disabling them, - do not set the ``direct_access`` option in the Python setup file. - These methods are specified by the tag [DA]. - -.. note:: - - Some other methods give access to large objects - (refer to PostgreSQL user manual for more information about these). - If you want to forbid access to these from the module, - set the ``large_objects`` option in the Python setup file. - These methods are specified by the tag [LO]. + Some other methods give access to large objects. + Refer to the PostgreSQL user manual for more information about these. query -- execute a SQL command string ------------------------------------- @@ -81,6 +72,102 @@ Example:: phone = con.query("select phone from employees where name=$1", (name,)).getresult() + +send_query -- execute a SQL command string asynchronously +--------------------------------------------------------- + +.. method:: Connection.send_query(command, [args]) + + Submits a command to the server without waiting for the result(s). + + :param str command: SQL command + :param args: optional parameter values + :returns: a query object, as described below + :rtype: :class:`Query` + :raises TypeError: bad argument type, or too many arguments + :raises TypeError: invalid connection + :raises ValueError: empty SQL query or lost connection + :raises pg.ProgrammingError: error in query + +This method is much the same as :meth:`Connection.query`, except that it +returns without waiting for the query to complete. The database connection +cannot be used for other operations until the query completes, but the +application can do other things, including executing queries using other +database connections. The application can call ``select()`` using the +``fileno`` obtained by the connection's :meth:`Connection.fileno` method +to determine when the query has results to return. + +This method always returns a :class:`Query` object. This object differs +from the :class:`Query` object returned by :meth:`Connection.query` in a +few ways. Most importantly, when :meth:`Connection.send_query` is used, the +application must call one of the result-returning methods such as +:meth:`Query.getresult` or :meth:`Query.dictresult` until it either raises +an exception or returns ``None``; otherwise, the database connection will be +left in an unusable state. 
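
For instance, a minimal sketch of the required drain loop (``con`` is assumed
to be an open connection, and ``handle`` is a hypothetical callback for
processing each result set, not part of PyGreSQL)::

    query = con.send_query("select version()")
    while True:
        result = query.getresult()
        if result is None:
            break  # all pending results have been consumed
        handle(result)  # hypothetical handler for one result set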
+ +In cases when :meth:`Connection.query` would return something other than +a :class:`Query` object, that result will be returned by calling one of +the result-returning methods on the :class:`Query` object returned by +:meth:`Connection.send_query`. There's one important difference in these +result codes: if :meth:`Connection.query` returns ``None``, the result-returning +methods will return an empty string (``''``). It's still necessary to call a +result-returning method until it returns ``None``. + +:meth:`Query.listfields`, :meth:`Query.fieldname` and :meth:`Query.fieldnum` +only work after a call to a result-returning method with a non-``None`` return +value. Calling ``len()`` on a :class:`Query` object returns the number of rows +returned by the previous result-returning method. + +If multiple semicolon-delimited statements are passed to +:meth:`Connection.query`, only the results of the last statement are returned +in the :class:`Query` object. With :meth:`Connection.send_query`, all results +are returned. Each result set will be returned by a separate call to +:meth:`Query.getresult()` or other result-returning methods. + +.. versionadded:: 5.2 + +Examples:: + + name = input("Name? ") + query = con.send_query("select phone from employees where name=$1", + (name,)) + phone = query.getresult() + query.getresult() # to close the query + + # Run two queries in one round trip: + # (Note that you cannot use a union here + # when the result sets have different row types.) + query = con.send_query("select a,b,c from x where d=e;" + "select e,f from y where g") + result_x = query.dictresult() + result_y = query.dictresult() + query.dictresult() # to close the query + + # Using select() to wait for the query to be ready: + query = con.send_query("select pg_sleep(20)") + r, w, e = select([con.fileno(), other, sockets], [], []) + if con.fileno() in r: + results = query.getresult() + query.getresult() # to close the query + + # Concurrent queries on separate connections: + con1 = connect() + con2 = connect() + s = con1.query("begin; set transaction isolation level repeatable read;" + "select pg_export_snapshot();").single() + con2.query("begin; set transaction isolation level repeatable read;" + f"set transaction snapshot '{s}'") + q1 = con1.send_query("select a,b,c from x where d=e") + q2 = con2.send_query("select e,f from y where g") + r1 = q1.getresult() + q1.getresult() + r2 = q2.getresult() + q2.getresult() + con1.query("commit") + con2.query("commit") + + query_prepared -- execute a prepared statement ---------------------------------------------- @@ -169,6 +256,56 @@ reset -- reset the connection This method resets the current database connection. +poll -- complete an asynchronous connection +------------------------------------------- + +.. method:: Connection.poll() + + Complete an asynchronous :mod:`pg` connection and get its state + + :returns: state of the connection + :rtype: int + :raises TypeError: too many (any) arguments + :raises TypeError: invalid connection + :raises pg.InternalError: some error occurred during pg connection + +The database connection can be established without any blocking calls. +This allows the application mainline to perform other operations or perhaps +connect to multiple databases concurrently. Once the connection is established, +it's no different from a connection made using blocking calls. 
+ +The required steps are to pass the parameter ``nowait=True`` to the +:meth:`pg.connect` call, then call :meth:`Connection.poll` until it either +returns :const:`POLLING_OK` or raises an exception. To avoid blocking +in :meth:`Connection.poll`, use ``select()`` or ``poll()`` to wait for the +connection to be readable or writable, depending on the return code of the +previous call to :meth:`Connection.poll`. The initial state of the connection +is :const:`POLLING_WRITING`. The possible states are defined as constants in +the :mod:`pg` module (:const:`POLLING_OK`, :const:`POLLING_FAILED`, +:const:`POLLING_READING` and :const:`POLLING_WRITING`). + +.. versionadded:: 5.2 + +Example:: + + con = pg.connect('testdb', nowait=True) + fileno = con.fileno() + rd = [] + wt = [fileno] + rc = pg.POLLING_WRITING + while rc not in (pg.POLLING_OK, pg.POLLING_FAILED): + ra, wa, xa = select(rd, wt, [], timeout) + if not ra and not wa: + timedout() + rc = con.poll() + if rc == pg.POLLING_READING: + rd = [fileno] + wt = [] + else: + rd = [] + wt = [fileno] + + cancel -- abandon processing of current SQL command --------------------------------------------------- @@ -281,6 +418,40 @@ fileno -- get the socket used to connect to the database This method returns the underlying socket id used to connect to the database. This is useful for use in select calls, etc. +set_non_blocking -- set the non-blocking status of the connection +------------------------------------------------------------------ + +.. method:: Connection.set_non_blocking(nb) + + Set the non-blocking mode of the connection + + :param bool nb: True to put the connection into non-blocking mode. + False to put it into blocking mode. + :raises TypeError: too many parameters + :raises TypeError: invalid connection + +Puts the socket connection into non-blocking mode or into blocking mode. +This affects copy commands and large object operations, but not queries. + +.. versionadded:: 5.2 + +is_non_blocking -- report the blocking status of the connection ----------------------------------------------------------------- + +.. method:: Connection.is_non_blocking() + + Get the non-blocking mode of the connection + + :returns: True if the connection is in non-blocking mode. + False if it is in blocking mode. + :rtype: bool + :raises TypeError: too many parameters + :raises TypeError: invalid connection + +Returns True if the connection is in non-blocking mode, False otherwise. + +.. versionadded:: 5.2 + getnotify -- get the last notify from the server ------------------------------------------------ @@ -304,25 +475,29 @@ first, otherwise :meth:`Connection.getnotify` will always return ``None``. .. versionchanged:: 4.1 Support for payload strings was added in version 4.1. -inserttable -- insert a list into a table ------------------------------------------ +inserttable -- insert an iterable into a table +---------------------------------------------- -.. method:: Connection.inserttable(table, values) +.. 
method:: Connection.inserttable(table, values, [columns]) - Insert a Python list into a database table + Insert a Python iterable into a database table :param str table: the table name - :param list values: list of rows values - :rtype: None + :param list values: iterable of row values, which must be lists or tuples + :param list columns: list or tuple of column names + :rtype: int :raises TypeError: invalid connection, bad argument type, or too many arguments :raises MemoryError: insert buffer could not be allocated :raises ValueError: unsupported values -This method allows to *quickly* insert large blocks of data in a table: -It inserts the whole values list into the given table. Internally, it -uses the COPY command of the PostgreSQL database. The list is a list -of tuples/lists that define the values for each inserted row. The rows -values may contain string, integer, long or double (real) values. +This method allows you to *quickly* insert large blocks of data into a table. +Internally, it uses the COPY command of the PostgreSQL database. +The method takes an iterable of row values which must be tuples or lists +of the same size, containing the values for each inserted row. +These may contain string, integer, long or double (real) values. +``columns`` is an optional tuple or list of column names to be passed on +to the COPY command. +The number of rows affected is returned. .. warning:: @@ -421,8 +596,8 @@ attributes: .. versionadded:: 4.1 -putline -- write a line to the server socket [DA] -------------------------------------------------- +putline -- write a line to the server socket +-------------------------------------------- .. method:: Connection.putline(line) @@ -434,14 +609,14 @@ putline -- write a line to the server socket [DA] This method allows writing a string directly to the server socket. -getline -- get a line from server socket [DA] ---------------------------------------------- +getline -- get a line from server socket +---------------------------------------- .. method:: Connection.getline() Get a line from server socket - :returns: the line read  + :returns: the line read :rtype: str :raises TypeError: invalid connection :raises TypeError: too many parameters @@ -449,8 +624,8 @@ getline -- get a line from server socket [DA] This method allows reading a string directly from the server socket. -endcopy -- synchronize client and server [DA] ---------------------------------------------- +endcopy -- synchronize client and server ---------------------------------------- .. method:: Connection.endcopy() @@ -463,8 +638,8 @@ endcopy -- synchronize client and server [DA] The use of direct access methods may desynchronize client and server. This method ensures that client and server will be synchronized. -locreate -- create a large object in the database [LO] ------------------------------------------------------- +locreate -- create a large object in the database ------------------------------------------------- .. method:: Connection.locreate(mode) @@ -478,11 +653,11 @@ locreate -- create a large object in the database [LO] This method creates a large object in the database. The mode can be defined by OR-ing the constants defined in the :mod:`pg` module (:const:`INV_READ`, -:const:`INV_WRITE` and :const:`INV_ARCHIVE`). Please refer to PostgreSQL -user manual for a description of the mode values. +and :const:`INV_WRITE`). Please refer to the PostgreSQL user manual for a +description of the mode values. 
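
For instance, a minimal sketch of creating and filling a large object
(assuming an open connection ``con``, ``import pg``, and that the code runs
inside a transaction, which PostgreSQL generally requires for large object
manipulation)::

    lo = con.locreate(pg.INV_READ | pg.INV_WRITE)
    lo.open(pg.INV_WRITE)  # open the newly created large object for writing
    lo.write(b'some binary data')
    lo.close()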
-getlo -- build a large object from given oid [LO] -------------------------------------------------- +getlo -- build a large object from given oid +-------------------------------------------- .. method:: Connection.getlo(oid) @@ -491,14 +666,14 @@ getlo -- build a large object from given oid [LO] :param int oid: OID of the existing large object :returns: object handling the PostgreSQL large object :rtype: :class:`LargeObject` - :raises TypeError: invalid connection, bad parameter type, or too many parameters  + :raises TypeError: invalid connection, bad parameter type, or too many parameters :raises ValueError: bad OID value (0 is invalid_oid) This method allows reusing a previously created large object through the :class:`LargeObject` interface, provided the user has its OID. -loimport -- import a file to a large object [LO] ------------------------------------------------ +loimport -- import a file to a large object ------------------------------------------- .. method:: Connection.loimport(name) @@ -546,7 +721,7 @@ the connection and its status. These attributes are: .. attribute:: Connection.server_version - the backend version (int, e.g. 90305 for 9.3.5) + the backend version (int, e.g. 150004 for 15.4) .. versionadded:: 4.0 @@ -574,10 +749,10 @@ the connection and its status. These attributes are: this is True if the connection uses SSL, False if not -.. versionadded:: 5.1 (needs PostgreSQL >= 9.5) +.. versionadded:: 5.1 .. attribute:: Connection.ssl_attributes SSL-related information about the connection (dict) -.. versionadded:: 5.1 (needs PostgreSQL >= 9.5) +.. versionadded:: 5.1 diff --git a/docs/contents/pg/db_types.rst b/docs/contents/pg/db_types.rst index 3318fd06..d7333a41 100644 --- a/docs/contents/pg/db_types.rst +++ b/docs/contents/pg/db_types.rst @@ -1,7 +1,7 @@ DbTypes -- The internal cache for database types ================================================ -.. py:currentmodule:: pg +.. currentmodule:: pg .. class:: DbTypes @@ -13,14 +13,15 @@ returned by :meth:`DB.get_attnames` as dictionary values). These type names are strings which are equal to either the simple PyGreSQL names or to the more fine-grained registered PostgreSQL type names if these -have been enabled with :meth:`DB.use_regtypes`. Besides being strings, they -carry additional information about the associated PostgreSQL type in the -following attributes: +have been enabled with :meth:`DB.use_regtypes`. Type names are strings that +are augmented with additional information about the associated PostgreSQL +type that can be inspected using the following attributes: - *oid* -- the PostgreSQL type OID - *pgtype* -- the internal PostgreSQL data type name - *regtype* -- the registered PostgreSQL data type name - *simple* -- the more coarse-grained PyGreSQL type name + - *typlen* -- internal size of the type, negative if variable - *typtype* -- `b` = base type, `c` = composite type etc. - *category* -- `A` = Array, `B` = Boolean, `C` = Composite etc. - *delim* -- delimiter for array types diff --git a/docs/contents/pg/db_wrapper.rst b/docs/contents/pg/db_wrapper.rst index 2727f8fc..b9e72b69 100644 --- a/docs/contents/pg/db_wrapper.rst +++ b/docs/contents/pg/db_wrapper.rst @@ -1,7 +1,7 @@ The DB wrapper class ==================== -.. py:currentmodule:: pg +.. currentmodule:: pg .. 
class:: DB @@ -16,7 +16,7 @@ The preferred way to use this module is as follows:: for r in db.query( # just for example "SELECT foo, bar FROM foo_bar_table WHERE foo !~ bar" ).dictresult(): - print('%(foo)s %(bar)s' % r) + print('{foo} {bar}'.format(**r)) This class can be subclassed as in this example:: @@ -48,8 +48,7 @@ You can also initialize the DB class with an existing :mod:`pg` or :mod:`pgdb` connection. Pass this connection as a single unnamed parameter, or as a single parameter named ``db``. This allows you to use all of the methods of the DB class with a DB-API 2 compliant connection. Note that the -:meth:`Connection.close` and :meth:`Connection.reopen` methods are inoperative -in this case. +:meth:`DB.close` and :meth:`DB.reopen` methods are inoperative in this case. pkey -- return the primary key of a table ----------------------------------------- @@ -59,7 +58,7 @@ pkey -- return the primary key of a table Return the primary key of a table :param str table: name of table - :returns: Name of the field which is the primary key of the table + :returns: Name of the field that is the primary key of the table :rtype: str :raises KeyError: the table does not have a primary key @@ -68,6 +67,24 @@ returned as strings unless you set the composite flag. Composite primary keys are always represented as tuples. Note that this raises a KeyError if the table does not have a primary key. +pkeys -- return the primary keys of a table +------------------------------------------- + +.. method:: DB.pkeys(table) + + Return the primary keys of a table as a tuple + + :param str table: name of table + :returns: Names of the fields that are the primary keys of the table + :rtype: tuple + :raises KeyError: the table does not have a primary key + +This method returns the primary keys of a table as a tuple, i.e. +single primary keys are also returned as a tuple with one item. +Note that this raises a KeyError if the table does not have a primary key. + +.. versionadded:: 6.0 + get_databases -- get list of databases in the system ---------------------------------------------------- @@ -137,6 +154,20 @@ By default, only a limited number of simple types will be returned. You can get the registered types instead, if enabled by calling the :meth:`DB.use_regtypes` method. +get_generated -- get the generated columns of a table +----------------------------------------------------- + +.. method:: DB.get_generated(table) + + Get the generated columns of a table + + :param str table: name of table + :returns: a frozenset of column names + +Given the name of a table, digs out the set of generated columns. + +.. versionadded:: 5.2.5 + has_table_privilege -- check table privilege -------------------------------------------- @@ -456,11 +487,11 @@ Example:: name = input("Name? ") phone = input("Phone? ") - rows = db.query("update employees set phone=$2 where name=$1", - name, phone).getresult()[0][0] + num_rows = db.query("update employees set phone=$2 where name=$1", - name, phone) # or - rows = db.query("update employees set phone=$2 where name=$1", - (name, phone)).getresult()[0][0] + num_rows = db.query("update employees set phone=$2 where name=$1", + (name, phone)) query_formatted -- execute a formatted SQL command string --------------------------------------------------------- @@ -507,13 +538,31 @@ Example:: name = input("Name? ") phone = input("Phone? 
") - rows = db.query_formatted( + num_rows = db.query_formatted( "update employees set phone=%s where name=%s", - (phone, name)).getresult()[0][0] + (phone, name)) # or - rows = db.query_formatted( + num_rows = db.query_formatted( "update employees set phone=%(phone)s where name=%(name)s", - dict(name=name, phone=phone)).getresult()[0][0] + dict(name=name, phone=phone)) + +Example with specification of types:: + + db.query_formatted( + "update orders set info=%s where id=%s", + ({'customer': 'Joe', 'product': 'beer'}, 'id': 7), + types=('json', 'int')) + # or + db.query_formatted( + "update orders set info=%s where id=%s", + ({'customer': 'Joe', 'product': 'beer'}, 'id': 7), + types=('json int')) + # or + db.query_formatted( + "update orders set info=%(info)s where id=%(id)s", + {'info': {'customer': 'Joe', 'product': 'beer'}, 'id': 7}, + types={'info': 'json', 'id': 'int'}) + query_prepared -- execute a prepared statement ---------------------------------------------- @@ -666,7 +715,7 @@ delete -- delete a row from a database table Delete a row from a database table :param str table: name of table - :param dict d: optional dictionary of values + :param dict row: optional dictionary of values :param col: optional keyword arguments for updating the dictionary :rtype: None :raises pg.ProgrammingError: table has no primary key, @@ -774,7 +823,7 @@ has only one column anyway. :param int offset: number of rows to be skipped (the OFFSET clause) :param bool scalar: whether only the first column shall be returned :returns: the content of the table as a list - :rtype: dict or OrderedDict + :rtype: dict :raises TypeError: the table name has not been specified :raises KeyError: keyname(s) are invalid or not part of the result :raises pg.ProgrammingError: no keyname(s) and table has no primary key @@ -788,10 +837,9 @@ The rows will be also named tuples unless the *scalar* option has been set to *True*. With the optional parameter *keyname* you can specify a different set of columns to be used as the keys of the dictionary. -If the Python version supports it, the dictionary will be an *OrderedDict* -using the order specified with the *order* parameter or the key column(s) -if not specified. You can set *order* to *False* if you don't care about the -ordering. In this case the returned dictionary will be an ordinary one. +The dictionary will be ordered using the order specified with the *order* +parameter or the key column(s) if not specified. You can set *order* to +*False* if you don't care about the ordering. .. versionadded:: 5.0 @@ -799,7 +847,7 @@ escape_literal/identifier/string/bytea -- escape for SQL -------------------------------------------------------- The following methods escape text or binary strings so that they can be -inserted directly into an SQL command. Except for :meth:`DB.escape_byte`, +inserted directly into an SQL command. Except for :meth:`DB.escape_bytea`, you don't need to call these methods for the strings passed as parameters to :meth:`DB.query`. You also don't need to call any of these methods when storing data using :meth:`DB.insert` and similar. @@ -851,9 +899,9 @@ properties (such as character encoding). 
Escape binary data for use within SQL as type ``bytea`` - :param str datastring: string containing the binary data that is to be escaped + :param bytes/str datastring: the binary data that is to be escaped :returns: the escaped string - :rtype: str + :rtype: bytes/str Similar to the module function :func:`pg.escape_bytea` with the same name, but the behavior of this method is adjusted depending on the connection @@ -866,7 +914,7 @@ unescape_bytea -- unescape data retrieved from the database Unescape ``bytea`` data that has been retrieved as text - :param datastring: the ``bytea`` data string that has been retrieved as text + :param str string: the ``bytea`` string that has been retrieved as text :returns: byte string containing the binary data :rtype: bytes diff --git a/docs/contents/pg/introduction.rst b/docs/contents/pg/introduction.rst index 6a4ca7b8..1e369e12 100644 --- a/docs/contents/pg/introduction.rst +++ b/docs/contents/pg/introduction.rst @@ -1,6 +1,8 @@ Introduction ============ +.. currentmodule:: pg + You may either choose to use the "classic" PyGreSQL interface provided by the :mod:`pg` module or else the newer DB-API 2.0 compliant interface provided by the :mod:`pgdb` module. diff --git a/docs/contents/pg/large_objects.rst b/docs/contents/pg/large_objects.rst index d195eb4c..037b2128 100644 --- a/docs/contents/pg/large_objects.rst +++ b/docs/contents/pg/large_objects.rst @@ -1,26 +1,27 @@ LargeObject -- Large Objects ============================ -.. py:currentmodule:: pg +.. currentmodule:: pg .. class:: LargeObject -Objects that are instances of the class :class:`LargeObject` are used to handle -all the requests concerning a PostgreSQL large object. These objects embed -and hide all the "recurrent" variables (object OID and connection), exactly -in the same way :class:`Connection` instances do, thus only keeping significant -parameters in function calls. The :class:`LargeObject` instance keeps a -reference to the :class:`Connection` object used for its creation, sending -requests though with its parameters. Any modification but dereferencing the +Instances of the class :class:`LargeObject` are used to handle all the +requests concerning a PostgreSQL large object. These objects embed and hide +all the recurring variables (object OID and connection), in the same way +:class:`Connection` instances do, thus only keeping significant parameters +in function calls. The :class:`LargeObject` instance keeps a reference to +the :class:`Connection` object used for its creation, sending requests +through with its parameters. Any modification other than dereferencing the :class:`Connection` object will thus affect the :class:`LargeObject` instance. Dereferencing the initial :class:`Connection` object is not a problem since Python won't deallocate it before the :class:`LargeObject` instance -dereferences it. All functions return a generic error message on call error, -whatever the exact error was. The :attr:`error` attribute of the object allows -to get the exact error message. +dereferences it. All functions return a generic error message on error. +The exact error message is provided by the object's :attr:`error` attribute. -See also the PostgreSQL programmer's guide for more information about the -large object interface. +See also the PostgreSQL documentation for more information about the +`large object interface`__. 
+ +__ https://www.postgresql.org/docs/current/largeobjects.html open -- open a large object --------------------------- @@ -34,9 +35,10 @@ open -- open a large object :raises TypeError: invalid connection, bad parameter type, or too many parameters :raises IOError: already opened object, or open error -This method opens a large object for reading/writing, in the same way than the -Unix open() function. The mode value can be obtained by OR-ing the constants -defined in the :mod:`pg` module (:const:`INV_READ`, :const:`INV_WRITE`). +This method opens a large object for reading/writing, in a similar manner to +the Unix open() function for files. The mode value can be obtained by +OR-ing the constants defined in the :mod:`pg` module (:const:`INV_READ`, +:const:`INV_WRITE`). close -- close a large object ----------------------------- @@ -50,7 +52,7 @@ close -- close a large object :raises TypeError: too many parameters :raises IOError: object is not opened, or close error -This method closes a previously opened large object, in the same way than +This method closes a previously opened large object, in a similar manner to the Unix close() function. read, write, tell, seek, unlink -- file-like large object handling @@ -60,7 +62,7 @@ read, write, tell, seek, unlink -- file-like large object handling Read data from large object - :param int size: maximal size of the buffer to be read + :param int size: maximum size of the buffer to be read :returns: the read buffer :rtype: bytes :raises TypeError: invalid connection, invalid object, @@ -68,19 +70,19 @@ read, write, tell, seek, unlink -- file-like large object handling :raises ValueError: if `size` is negative :raises IOError: object is not opened, or read error -This function allows to read data from a large object, starting at current -position. +This function allows reading data from a large object, starting at the +current position. .. method:: LargeObject.write(data) - Read data to large object + Write data to large object - :param bytes string: string buffer to be written + :param bytes data: buffer of bytes to be written :rtype: None :raises TypeError: invalid connection, bad parameter type, or too many parameters :raises IOError: object is not opened, or write error -This function allows to write data to a large object, starting at current +This function allows writing data to a large object, starting at the current position. .. method:: LargeObject.seek(offset, whence) @@ -95,9 +97,9 @@ position. bad parameter type, or too many parameters :raises IOError: object is not opened, or seek error -This method allows to move the position cursor in the large object. -The valid values for the whence parameter are defined as constants in the -:mod:`pg` module (:const:`SEEK_SET`, :const:`SEEK_CUR`, :const:`SEEK_END`). +This method updates the position offset in the large object. The valid values +for the whence parameter are defined as constants in the :mod:`pg` module +(:const:`SEEK_SET`, :const:`SEEK_CUR`, :const:`SEEK_END`). .. method:: LargeObject.tell() @@ -109,7 +111,7 @@ The valid values for the whence parameter are defined as constants in the :raises TypeError: too many parameters :raises IOError: object is not opened, or seek error -This method allows to get the current position in the large object. +This method returns the current position offset in the large object. .. 
method:: LargeObject.unlink() @@ -135,7 +137,7 @@ size -- get the large object size :raises TypeError: too many parameters :raises IOError: object is not opened, or seek/tell error -This (composite) method allows to get the size of a large object. It was +This (composite) method returns the size of a large object. It was implemented because this function is very useful for a web-interfaced database. Currently, the large object needs to be opened first. @@ -152,14 +154,14 @@ export -- save a large object to a file bad parameter type, or too many parameters :raises IOError: object is not closed, or export error -This methods allows to dump the content of a large object in a very simple -way. The exported file is created on the host of the program, not the -server host. +This method allows saving the content of a large object to a file in a +very simple way. The file is created on the host running the PyGreSQL +interface, not on the server host. Object attributes ----------------- -:class:`LargeObject` objects define a read-only set of attributes that allow -to get some information about it. These attributes are: +:class:`LargeObject` objects define a read-only set of attributes exposing +some information about them. These attributes are: .. attribute:: LargeObject.oid @@ -175,9 +177,10 @@ to get some information about it. These attributes are: .. warning:: - In multi-threaded environments, :attr:`LargeObject.error` may be modified by - another thread using the same :class:`Connection`. Remember these object - are shared, not duplicated. You should provide some locking to be able - if you want to check this. The :attr:`LargeObject.oid` attribute is very + In multi-threaded environments, :attr:`LargeObject.error` may be modified + by another thread using the same :class:`Connection`. Remember these + objects are shared, not duplicated. You should provide some locking if you + want to use this information in a program in which it's shared between + multiple threads. The :attr:`LargeObject.oid` attribute is very interesting, because it allows you to reuse the OID later, creating the :class:`LargeObject` object with a :meth:`Connection.getlo` method call. diff --git a/docs/contents/pg/module.rst b/docs/contents/pg/module.rst index 15b1824e..acf75f93 100644 --- a/docs/contents/pg/module.rst +++ b/docs/contents/pg/module.rst @@ -1,7 +1,7 @@ Module functions and constants ============================== -.. py:currentmodule:: pg +.. currentmodule:: pg The :mod:`pg` module defines a few functions that allow you to connect to a database and to define "default variables" that override the environment variables used by PostgreSQL. These "default variables" were designed to allow you to handle general connection parameters without heavy code in your programs. You can prompt the user for a value, put it in the default variable, and forget it, without -having to modify your environment. The support for default variables can be -disabled by not setting the ``default_vars`` option in the Python setup file. -Methods relative to this are specified by the tag [DV]. +having to modify your environment. All variables are set to ``None`` at module initialization, specifying that standard environment variables should be used. connect -- Open a PostgreSQL connection --------------------------------------- -.. function:: connect([dbname], [host], [port], [opt], [user], [passwd]) +.. 
function:: connect([dbname], [host], [port], [opt], [user], [passwd], [nowait]) Open a :mod:`pg` connection :param dbname: name of connected database (*None* = :data:`defbase`) :type dbname: str or None :param host: name of the server host (*None* = :data:`defhost`) :type host: str or None :param port: port used by the database server (-1 = :data:`defport`) :type port: int :param opt: connection options (*None* = :data:`defopt`) :type opt: str or None :param user: PostgreSQL user (*None* = :data:`defuser`) :type user: str or None :param passwd: password for user (*None* = :data:`defpasswd`) :type passwd: str or None + :param nowait: whether the connection should happen asynchronously + :type nowait: bool :returns: If successful, the :class:`Connection` handling the connection :rtype: :class:`Connection` :raises TypeError: bad argument type, or too many arguments @@ -49,11 +49,15 @@ Python tutorial. The names of the keywords are the names of the parameters given in the syntax line. The ``opt`` parameter can be used to pass command-line options to the server. For a precise description of the parameters, please refer to the PostgreSQL user manual. +See :meth:`Connection.poll` for a description of the ``nowait`` parameter. If you want to add additional parameters not specified here, you must pass a connection string or a connection URI instead of the ``dbname`` (as in ``con3`` and ``con4`` in the following example). +.. versionchanged:: 5.2 + Support for asynchronous connections via the ``nowait`` parameter. + Example:: import pg con1 = ... con3 = pg.connect('host=myhost user=bob dbname=testdb connect_timeout=10') con4 = pg.connect('postgresql://bob@myhost/testdb?connect_timeout=10') -get/set_defhost -- default server host [DV] -------------------------------------------- + +get_pqlib_version -- get the version of libpq +--------------------------------------------- + +.. function:: get_pqlib_version() + + Get the version of libpq that is being used by PyGreSQL + + :returns: the version of libpq + :rtype: int + :raises TypeError: too many arguments + +The number is formed by multiplying the major version number of the libpq +version by 10000 and adding the minor version number. +For example, version 15.4 will be returned as 150004. + +.. versionadded:: 5.2 + +get/set_defhost -- default server host +-------------------------------------- .. function:: get_defhost(host) @@ -93,8 +115,8 @@ If ``None`` is supplied as parameter, environment variables will be used in future connections. It returns the previous setting for default host. -get/set_defport -- default server port [DV] -------------------------------------------- +get/set_defport -- default server port +-------------------------------------- .. function:: get_defport() @@ -121,8 +143,8 @@ This method sets the default port value for new connections. If -1 is supplied as parameter, environment variables will be used in future connections. It returns the previous setting for default port. -get/set_defopt -- default connection options [DV] --------------------------------------------------- +get/set_defopt -- default connection options +--------------------------------------------- .. function:: get_defopt() @@ -150,8 +172,8 @@ This method sets the default connection options value for new connections. If ``None`` is supplied as parameter, environment variables will be used in future connections. It returns the previous setting for default options. -get/set_defbase -- default database name [DV] ---------------------------------------------- +get/set_defbase -- default database name ---------------------------------------- .. function:: get_defbase() @@ -179,8 +201,8 @@ This method sets the default database name value for new connections. If ``None`` is supplied as parameter, environment variables will be used in future connections. 
It returns the previous setting for the default database name. -get/set_defuser -- default database user [DV] ---------------------------------------------- +get/set_defuser -- default database user ---------------------------------------- .. function:: get_defuser() @@ -208,8 +230,8 @@ This method sets the default database user name for new connections. If ``None`` is supplied as parameter, environment variables will be used in future connections. It returns the previous setting for the default user. -get/set_defpasswd -- default database password [DV] ---------------------------------------------------- +get/set_defpasswd -- default database password +---------------------------------------------- .. function:: get_defpasswd() @@ -267,8 +289,8 @@ which takes connection properties into account. Example:: name = input("Name? ") - phone = con.query("select phone from employees where name='%s'" - % escape_string(name)).getresult() + phone = con.query("select phone from employees" + f" where name='{escape_string(name)}'").singlescalar() escape_bytea -- escape binary data for use within SQL ----------------------------------------------------- escape binary data for use within SQL as type ``bytea`` - :param str datastring: string containing the binary data that is to be escaped + :param bytes/str datastring: the binary data that is to be escaped :returns: the escaped string - :rtype: str + :rtype: bytes/str :raises TypeError: bad argument type, or too many arguments Escapes binary data for use within an SQL command with the type ``bytea``. +The return value will have the same type as the given *datastring*. As with :func:`escape_string`, this is only used when inserting data directly into an SQL command string. @@ -292,8 +315,8 @@ which takes connection properties into account. Example:: picture = open('garfield.gif', 'rb').read() - con.query("update pictures set img='%s' where name='Garfield'" - % escape_bytea(picture)) + con.query(f"update pictures set img='{escape_bytea(picture).decode('ascii')}'" + " where name='Garfield'") unescape_bytea -- unescape data that has been retrieved as text --------------------------------------------------------------- Unescape ``bytea`` data that has been retrieved as text - :param str datastring: the ``bytea`` data string that has been retrieved as text + :param str string: the ``bytea`` string that has been retrieved as text :returns: byte string containing the binary data :rtype: bytes :raises TypeError: bad argument type, or too many arguments @@ -326,8 +349,7 @@ get/set_decimal -- decimal type to be used for numeric values :rtype: class This function returns the Python class that is used by PyGreSQL to hold -PostgreSQL numeric values. The default class is :class:`decimal.Decimal` -if available, otherwise the :class:`float` type is used. +PostgreSQL numeric values. The default class is :class:`decimal.Decimal`. .. function:: set_decimal(cls) This function can be used to specify the Python class that shall be used by PyGreSQL to hold PostgreSQL numeric values. -The default class is :class:`decimal.Decimal` if available, -otherwise the :class:`float` type is used. +The default class is :class:`decimal.Decimal`. 
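
For example, a minimal sketch that switches numeric values to plain floats,
trading exactness for speed::

    import pg

    pg.set_decimal(float)  # numeric columns will now be returned as float
    assert pg.get_decimal() is float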
get/set_decimal_point -- decimal mark used for monetary values -------------------------------------------------------------- @@ -616,7 +637,7 @@ are not supported by default in PostgreSQL. :param str string: the string with the text representation of the array :param cast: a typecast function for the elements of the array :type cast: callable or None - :param delim: delimiter character between adjacent elements + :param bytes delim: delimiter character between adjacent elements (a single-character byte string) :returns: a list representing the PostgreSQL array in Python :rtype: list @@ -644,7 +665,7 @@ then a comma will be used by default. :param str string: the string with the text representation of the record :param cast: typecast function(s) for the elements of the record :type cast: callable, list or tuple of callables, or None - :param delim: delimiter character between adjacent elements + :param bytes delim: delimiter character between adjacent elements (a single-character byte string) :returns: a tuple representing the PostgreSQL record in Python :rtype: tuple @@ -728,6 +749,13 @@ for more information about them. These constants are: large objects access modes, used by :meth:`Connection.locreate` and :meth:`LargeObject.open` +.. data:: POLLING_OK +.. data:: POLLING_FAILED +.. data:: POLLING_READING +.. data:: POLLING_WRITING + + polling states, returned by :meth:`Connection.poll` + .. data:: SEEK_SET .. data:: SEEK_CUR .. data:: SEEK_END diff --git a/docs/contents/pg/notification.rst b/docs/contents/pg/notification.rst index a37df668..05b04a16 100644 --- a/docs/contents/pg/notification.rst +++ b/docs/contents/pg/notification.rst @@ -1,7 +1,7 @@ The Notification Handler ======================== -.. py:currentmodule:: pg +.. currentmodule:: pg PyGreSQL comes with a client-side asynchronous notification handler that was based on the ``pgnotify`` module written by Ng Pheng Siong. @@ -25,7 +25,7 @@ Instantiating the notification handler :param str stop_event: an optional different name to be used as stop event You can also create an instance of the NotificationHandler using the -:class:`DB.connection_handler` method. In this case you don't need to +:meth:`DB.connection_handler` method. In this case you don't need to pass a database connection because the :class:`DB` connection itself will be used as the database connection for the notification handler. @@ -116,4 +116,4 @@ or when it is closed or deleted. You can call this method instead of :meth:`NotificationHandler.unlisten` if you want to close not only the handler, but also the database connection -it was created with. \ No newline at end of file +it was created with. diff --git a/docs/contents/pg/query.rst b/docs/contents/pg/query.rst index 2d3b7abb..fcee193f 100644 --- a/docs/contents/pg/query.rst +++ b/docs/contents/pg/query.rst @@ -1,7 +1,7 @@ Query methods ============= -.. py:currentmodule:: pg +.. currentmodule:: pg .. class:: Query @@ -43,6 +43,9 @@ You can also call :func:`len` on a query to find the number of rows in the result, and access row tuples using their index directly on the :class:`Query` object. +When the :class:`Query` object was returned by :meth:`Connection.send_query`, +other return values are also possible, as documented there. + dictresult/dictiter -- get query values as dictionaries ------------------------------------------------------- @@ -81,6 +84,9 @@ fetched from the server anyway when the query is executed. 
If the query has duplicate field names, you will get the value for the field with the highest index in the query. +When the :class:`Query` object was returned by :meth:`Connection.send_query`, +other return values are also possible, as documented there. + .. versionadded:: 5.1 namedresult/namediter -- get query values as named tuples @@ -127,6 +133,9 @@ Column names in the database that are not valid as field names for named tuples (particularly, names starting with an underscore) are automatically renamed to valid positional names. +When the :class:`Query` object was returned by :meth:`Connection.send_query`, +other return values are also possible, as documented there. + .. versionadded:: 5.1 scalarresult/scalariter -- get query values as scalars @@ -204,7 +213,7 @@ It returns None if the result does not contain one more row. Get one row from the result of a query as named tuple :returns: next row from the query results as a named tuple - :rtype: named tuple or None + :rtype: namedtuple or None :raises TypeError: too many (any) parameters :raises MemoryError: internal memory error @@ -244,7 +253,7 @@ single/singledict/singlenamed/singlescalar -- get single result of a query :returns: single row from the query results as a tuple of fields :rtype: tuple - :raises InvalidResultError: result does not have exactly one row + :raises pg.InvalidResultError: result does not have exactly one row :raises TypeError: too many (any) parameters :raises MemoryError: internal memory error @@ -263,7 +272,7 @@ is empty and of type :exc:`pg.MultipleResultsError` if it has multiple rows. :returns: single row from the query results as a dictionary :rtype: dict - :raises InvalidResultError: result does not have exactly one row + :raises pg.InvalidResultError: result does not have exactly one row :raises TypeError: too many (any) parameters :raises MemoryError: internal memory error @@ -282,8 +291,8 @@ is empty and of type :exc:`pg.MultipleResultsError` if it has multiple rows. Get single row from the result of a query as named tuple :returns: single row from the query results as a named tuple - :rtype: named tuple - :raises InvalidResultError: result does not have exactly one row + :rtype: namedtuple + :raises pg.InvalidResultError: result does not have exactly one row :raises TypeError: too many (any) parameters :raises MemoryError: internal memory error @@ -306,7 +315,7 @@ is empty and of type :exc:`pg.MultipleResultsError` if it has multiple rows. :returns: single row from the query results as a scalar value :rtype: type of first field - :raises InvalidResultError: result does not have exactly one row + :raises pg.InvalidResultError: result does not have exactly one row :raises TypeError: too many (any) parameters :raises MemoryError: internal memory error @@ -319,19 +328,19 @@ is empty and of type :exc:`pg.MultipleResultsError` if it has multiple rows. .. versionadded:: 5.1 -listfields -- list fields names of previous query result --------------------------------------------------------- +listfields -- list field names of query result +---------------------------------------------- .. method:: Query.listfields() - List fields names of previous query result + List field names of query result :returns: field names - :rtype: list + :rtype: tuple :raises TypeError: too many parameters -This method returns the list of field names defined for the -query result. The fields are in the same order as the result values. +This method returns the tuple of field names defined for the query result. 
+The fields are in the same order as the result values. fieldname, fieldnum -- field name/number conversion --------------------------------------------------- @@ -365,18 +374,43 @@ build a function that converts result list strings to their correct type, using a hardcoded table definition. The number returned is the field rank in the query result. -ntuples -- return number of tuples in query object --------------------------------------------------- +fieldinfo -- detailed info about query result fields +---------------------------------------------------- + +.. method:: Query.fieldinfo([field]) + + Get information on one or all fields of the query + + :param field: a column number or name (optional) + :type field: int or str + :returns: field info tuple(s) for all fields or given field + :rtype: tuple + :raises IndexError: field does not exist + :raises TypeError: too many parameters + +If the ``field`` is specified by passing either a column number or a field +name, a four-tuple with information for the specified field of the query +result will be returned. If no ``field`` is specified, a tuple of four-tuples +for every field of the previous query result will be returned, in the same +order as they appear in the query result. + +The four-tuples contain the following information: The field name, the +internal OID number of the field type, the size in bytes of the column or a +negative value if it is of variable size, and a type-specific modifier value. + +.. versionadded:: 5.2 + +memsize -- return number of bytes allocated by query result +----------------------------------------------------------- -.. method:: Query.ntuples() +.. method:: Query.memsize() - Return number of tuples in query object + Return number of bytes allocated by query result - :returns: number of tuples in :class:`Query` + :returns: number of bytes allocated for the query result :rtype: int :raises TypeError: Too many arguments. -This method returns the number of tuples in the query result. +This method returns the number of bytes allocated for the query result. -.. deprecated:: 5.1 - You can use the normal :func:`len` function instead. +.. versionadded:: 5.2 (needs PostgreSQL >= 12) diff --git a/docs/contents/pgdb/adaptation.rst b/docs/contents/pgdb/adaptation.rst index 1295b44f..ac649a21 100644 --- a/docs/contents/pgdb/adaptation.rst +++ b/docs/contents/pgdb/adaptation.rst @@ -1,7 +1,7 @@ Remarks on Adaptation and Typecasting ===================================== -.. py:currentmodule:: pgdb +.. currentmodule:: pgdb Both PostgreSQL and Python have the concept of data types, but there are of course differences between the two type systems. Therefore PyGreSQL @@ -26,7 +26,7 @@ PostgreSQL Python char, bpchar, name, text, varchar str bool bool bytea bytes -int2, int4, int8, oid, serial int [#int8]_ +int2, int4, int8, oid, serial int int2vector list of int float4, float8 float numeric, money Decimal @@ -45,8 +45,6 @@ record tuple Elements of arrays and records will also be converted accordingly. - .. [#int8] int8 is converted to long in Python 2 - .. [#array] The first element of the array will always be the first element of the Python list, no matter what the lower bound of the PostgreSQL array is. The information about the start index of the array (which is @@ -211,7 +209,7 @@ to hold our values, like this one:: ... self.price = price ... ... def __str__(self): - ... return '%s (from %s, at $%s)' % ( + ... return '{} (from {}, at ${})'.format( ... 
self.name, self.supplier_id, self.price) But when we try to insert an instance of this class in the same way, we @@ -233,7 +231,7 @@ with the name ``__pg_repr__``, like this:: ... ... ... ... def __str__(self): - ... return '%s (from %s, at $%s)' % ( + ... return '{} (from {}, at ${})'.format( ... self.name, self.supplier_id, self.price) ... ... def __pg_repr__(self): diff --git a/docs/contents/pgdb/connection.rst b/docs/contents/pgdb/connection.rst index 958108b7..71492847 100644 --- a/docs/contents/pgdb/connection.rst +++ b/docs/contents/pgdb/connection.rst @@ -1,7 +1,7 @@ Connection -- The connection object =================================== -.. py:currentmodule:: pgdb +.. currentmodule:: pgdb .. class:: Connection diff --git a/docs/contents/pgdb/cursor.rst b/docs/contents/pgdb/cursor.rst index a2ac63e8..72473057 100644 --- a/docs/contents/pgdb/cursor.rst +++ b/docs/contents/pgdb/cursor.rst @@ -1,7 +1,7 @@ Cursor -- The cursor object =========================== -.. py:currentmodule:: pgdb +.. currentmodule:: pgdb .. class:: Cursor @@ -150,7 +150,7 @@ fetchone -- fetch next row of the query result Fetch the next row of a query result set :returns: the next row of the query result set - :rtype: named tuple or None + :rtype: namedtuple or None Fetch the next row of a query result set, returning a single named tuple, or ``None`` when no more data is available. The field names of the named @@ -176,7 +176,7 @@ fetchmany -- fetch next set of rows of the query result :param keep: if set to true, will keep the passed arraysize :type keep: bool :returns: the next set of rows of the query result - :rtype: list of named tuples + :rtype: list of namedtuples Fetch the next set of rows of a query result, returning a list of named tuples. An empty sequence is returned when no more rows are available. @@ -212,7 +212,7 @@ fetchall -- fetch all rows of the query result Fetch all (remaining) rows of a query result :returns: the set of all rows of the query result - :rtype: list of named tuples + :rtype: list of namedtuples Fetch all (remaining) rows of a query result, returning them as list of named tuples. The field names of the named tuple are the same as the column @@ -295,7 +295,7 @@ specified, all of them will be copied. :param str null: the textual representation of the ``NULL`` value, can also be an empty string (the default is ``'\\N'``) :param bool decode: whether decoded strings shall be returned - for non-binary formats (the default is True in Python 3) + for non-binary formats (the default is ``True``) :param list column: an optional list of column names :returns: a generator if stream is set to ``None``, otherwise the cursor @@ -340,8 +340,8 @@ be used for all result sets. If you overwrite this method, the method will be ignored. Note that named tuples are very efficient and can be easily converted to -dicts (even OrderedDicts) by calling ``row._asdict()``. If you still want -to return rows as dicts, you can create a custom cursor class like this:: +dicts by calling ``row._asdict()``. If you still want to return rows as dicts, +you can create a custom cursor class like this:: class DictCursor(pgdb.Cursor): diff --git a/docs/contents/pgdb/module.rst b/docs/contents/pgdb/module.rst index 884ac4dc..5220193c 100644 --- a/docs/contents/pgdb/module.rst +++ b/docs/contents/pgdb/module.rst @@ -1,7 +1,7 @@ Module functions and constants ============================== -.. py:currentmodule:: pgdb +.. 
currentmodule:: pgdb The :mod:`pgdb` module defines a :func:`connect` function that allows you to connect to a database, some global constants describing the capabilities diff --git a/docs/contents/pgdb/typecache.rst b/docs/contents/pgdb/typecache.rst index a8b203ab..f0861a23 100644 --- a/docs/contents/pgdb/typecache.rst +++ b/docs/contents/pgdb/typecache.rst @@ -1,7 +1,7 @@ TypeCache -- The internal cache for database types ================================================== -.. py:currentmodule:: pgdb +.. currentmodule:: pgdb .. class:: TypeCache diff --git a/docs/contents/pgdb/types.rst b/docs/contents/pgdb/types.rst index 0c13ec6b..d739df32 100644 --- a/docs/contents/pgdb/types.rst +++ b/docs/contents/pgdb/types.rst @@ -1,7 +1,7 @@ Type -- Type objects and constructors ===================================== -.. py:currentmodule:: pgdb +.. currentmodule:: pgdb .. _type_constructors: @@ -101,15 +101,15 @@ Example for using a type constructor:: Type objects ------------ -.. class:: Type +.. class:: DbType The :attr:`Cursor.description` attribute returns information about each of the result columns of a query. The *type_code* must compare equal to one -of the :class:`Type` objects defined below. Type objects can be equal to +of the :class:`DbType` objects defined below. Type objects can be equal to more than one type code (e.g. :class:`DATETIME` is equal to the type codes for ``date``, ``time`` and ``timestamp`` columns). -The pgdb module exports the following :class:`Type` objects as part of the +The pgdb module exports the following :class:`DbType` objects as part of the DB-API 2 standard: .. object:: STRING diff --git a/docs/contents/postgres/advanced.rst b/docs/contents/postgres/advanced.rst index 38c8a473..d7627312 100644 --- a/docs/contents/postgres/advanced.rst +++ b/docs/contents/postgres/advanced.rst @@ -1,7 +1,7 @@ Examples for advanced features ============================== -.. py:currentmodule:: pg +.. currentmodule:: pg In this section, we show how to use some advanced features of PostgreSQL using the classic PyGreSQL interface. @@ -27,7 +27,7 @@ all data fields from cities):: ... "'Las Vegas', 2.583E+5, 2174", ... "'Mariposa', 1200, 1953"]), ... ('capitals', [ - ... "'Sacramento',3.694E+5,30,'CA'", + ... "'Sacramento', 3.694E+5, 30, 'CA'", ... "'Madison', 1.913E+5, 845, 'WI'"])] Now, let's populate the tables:: @@ -37,11 +37,11 @@ Now, let's populate the tables:: ... "'Las Vegas', 2.583E+5, 2174", ... "'Mariposa', 1200, 1953"], ... 'capitals', [ - ... "'Sacramento',3.694E+5,30,'CA'", + ... "'Sacramento', 3.694E+5, 30, 'CA'", ... "'Madison', 1.913E+5, 845, 'WI'"]] >>> for table, rows in data: ... for row in rows: - ... query("INSERT INTO %s VALUES (%s)" % (table, row)) + ... query(f"INSERT INTO {table} VALUES ({row})") >>> print(query("SELECT * FROM cities")) name |population|altitude -------------+----------+-------- diff --git a/docs/contents/postgres/basic.rst b/docs/contents/postgres/basic.rst index e6973442..b137351e 100644 --- a/docs/contents/postgres/basic.rst +++ b/docs/contents/postgres/basic.rst @@ -1,7 +1,7 @@ Basic examples ============== -.. py:currentmodule:: pg +.. currentmodule:: pg In this section, we demonstrate how to use some of the very basic features of PostgreSQL using the classic PyGreSQL interface.
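For a first impression, here is a minimal session with the classic interface (a sketch for illustration; it assumes a local database named ``testdb`` that you can reach with your default credentials)::

    >>> import pg
    >>> con = pg.connect(dbname='testdb')
    >>> con.query("SELECT upper($1), 2 + $2", ('hello', 3)).getresult()
    [('HELLO', 5)]
    >>> con.close()

Passing values as positional parameters, as in this query, leaves all quoting and escaping to PyGreSQL and the server, which is preferable to interpolating them into the SQL string by hand.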
diff --git a/docs/contents/postgres/func.rst b/docs/contents/postgres/func.rst index b35e5ff7..3bfcfd98 100644 --- a/docs/contents/postgres/func.rst +++ b/docs/contents/postgres/func.rst @@ -1,7 +1,7 @@ Examples for using SQL functions ================================ -.. py:currentmodule:: pg +.. currentmodule:: pg We assume that you have already created a connection to the PostgreSQL database, as explained in the :doc:`basic`:: @@ -62,7 +62,7 @@ Before we create more sophisticated functions, let's populate an EMP table:: ... "'Bill', 4200, 36, 'shoe'", ... "'Ginger', 4800, 30, 'candy'"] >>> for emp in emps: - ... query("INSERT INTO EMP VALUES (%s)" % emp) + ... query(f"INSERT INTO EMP VALUES ({emp})") Every INSERT statement will return a '1' indicating that it has inserted one row into the EMP table. diff --git a/docs/contents/postgres/syscat.rst b/docs/contents/postgres/syscat.rst index 13740203..80718afb 100644 --- a/docs/contents/postgres/syscat.rst +++ b/docs/contents/postgres/syscat.rst @@ -1,7 +1,7 @@ Examples for using the system catalogs ====================================== -.. py:currentmodule:: pg +.. currentmodule:: pg The system catalogs are regular tables where PostgreSQL stores schema metadata, such as information about tables and columns, and internal bookkeeping diff --git a/docs/contents/tutorial.rst b/docs/contents/tutorial.rst index 0ce05430..79273c7c 100644 --- a/docs/contents/tutorial.rst +++ b/docs/contents/tutorial.rst @@ -11,7 +11,7 @@ with both flavors of the PyGreSQL interface. Please choose your flavor: First Steps with the classic PyGreSQL Interface ----------------------------------------------- -.. py:currentmodule:: pg +.. currentmodule:: pg Before doing anything else, it's necessary to create a database connection. @@ -117,13 +117,8 @@ Using the method :meth:`DB.get_as_dict`, you can easily import the whole table into a Python dictionary mapping the primary key *id* to the *name*:: >>> db.get_as_dict('fruits', scalar=True) - OrderedDict([(1, 'apple'), - (2, 'banana'), - (3, 'cherimaya'), - (4, 'durian'), - (5, 'eggfruit'), - (6, 'fig'), - (7, 'grapefruit')]) + {1: 'apple', 2: 'banana', 3: 'cherimaya', 4: 'durian', 5: 'eggfruit', + 6: 'fig', 7: 'grapefruit', 8: 'apple', 9: 'banana'} To change a single row in the database, you can use the :meth:`DB.update` method. For instance, if you want to capitalize the name 'banana':: @@ -190,7 +185,7 @@ For more advanced features and details, see the reference: :doc:`pg/index` First Steps with the DB-API 2.0 Interface ----------------------------------------- -.. py:currentmodule:: pgdb +.. currentmodule:: pgdb As with the classic interface, the first thing you need to do is to create a database connection. To do this, use the function :func:`pgdb.connect` diff --git a/docs/copyright.rst b/docs/copyright.rst index 77d9ef83..bf7d9b04 100644 --- a/docs/copyright.rst +++ b/docs/copyright.rst @@ -10,7 +10,7 @@ Copyright (c) 1995, Pascal Andre Further modifications copyright (c) 1997-2008 by D'Arcy J.M. Cain (darcy@PyGreSQL.org) -Further modifications copyright (c) 2009-2020 by the PyGreSQL team. +Further modifications copyright (c) 2009-2025 by the PyGreSQL team. 
Permission to use, copy, modify, and distribute this software and its documentation for any purpose, without fee, and without a written agreement diff --git a/docs/download/files.rst b/docs/download/files.rst index 4f4741fd..fc3ad26f 100644 --- a/docs/download/files.rst +++ b/docs/download/files.rst @@ -3,27 +3,13 @@ Distribution files ============== = -pgmodule.c the main source file for the C extension module (_pg) -pgconn.c the connection object -pginternal.c internal functions -pglarge.c large object support -pgnotice.c the notice object -pgquery.c the query object -pgsource.c the source object +pg/ the "classic" PyGreSQL package -pgtypes.h PostgreSQL type definitions -py3c.h Python 2/3 compatibility layer for the C extension +pgdb/ a DB-SIG DB-API 2.0 compliant API wrapper for PyGreSQL -pg.py the "classic" PyGreSQL module -pgdb.py a DB-SIG DB-API 2.0 compliant API wrapper for PyGreSQL +ext/ the source files for the C extension module -setup.py the Python setup script - - To install PyGreSQL, you can run "python setup.py install". - -setup.cfg the Python setup configuration - -docs/ documentation directory +docs/ the documentation directory The documentation has been created with Sphinx. All text files are in ReST format; a HTML version of @@ -31,4 +17,12 @@ docs/ documentation directory tests/ a suite of unit tests for PyGreSQL +pyproject.toml contains project metadata and the build system requirements + +setup.py the Python setup script used for building the C extension + +LICENSE.text contains the license information for PyGreSQL + +README.rst a summary of the PyGreSQL project + ============== = diff --git a/docs/download/index.rst b/docs/download/index.rst index c4735826..88bf77b0 100644 --- a/docs/download/index.rst +++ b/docs/download/index.rst @@ -3,10 +3,8 @@ Download information .. include:: download.rst -News, Changes and Future Development ------------------------------------- - -See the :doc:`../announce` for current news. +Changes and Future Development +------------------------------ For a list of all changes in the current version |version| and in past versions, have a look at the :doc:`../contents/changelog`. diff --git a/docs/toc.txt b/docs/index.rst similarity index 58% rename from docs/toc.txt rename to docs/index.rst index 441021b4..88292059 100644 --- a/docs/toc.txt +++ b/docs/index.rst @@ -1,5 +1,3 @@ -.. PyGreSQL index page with toc (for use without cloud theme) - Welcome to PyGreSQL =================== @@ -8,7 +6,6 @@ Welcome to PyGreSQL about copyright - announce download/index contents/index - community/index \ No newline at end of file + community/index diff --git a/docs/make.bat b/docs/make.bat index b8571b60..954237b9 100644 --- a/docs/make.bat +++ b/docs/make.bat @@ -1,62 +1,16 @@ @ECHO OFF +pushd %~dp0 + REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) +set SOURCEDIR=. set BUILDDIR=_build -set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . -set I18NSPHINXOPTS=%SPHINXOPTS% . -if NOT "%PAPER%" == "" ( - set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% - set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% -) - -if "%1" == "" goto help - -if "%1" == "help" ( - :help - echo.Please use `make ^` where ^ is one of - echo. html to make standalone HTML files - echo. dirhtml to make HTML files named index.html in directories - echo. singlehtml to make a single large HTML file - echo. pickle to make pickle files - echo. json to make JSON files - echo. 
htmlhelp to make HTML files and a HTML help project - echo. qthelp to make HTML files and a qthelp project - echo. devhelp to make HTML files and a Devhelp project - echo. epub to make an epub - echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter - echo. text to make text files - echo. man to make manual pages - echo. texinfo to make Texinfo files - echo. gettext to make PO message catalogs - echo. changes to make an overview over all changed/added/deprecated items - echo. xml to make Docutils-native XML files - echo. pseudoxml to make pseudoxml-XML files for display purposes - echo. linkcheck to check all external links for integrity - echo. doctest to run all doctests embedded in the documentation if enabled - echo. coverage to run coverage check of the documentation if enabled - goto end -) - -if "%1" == "clean" ( - for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i - del /q /s %BUILDDIR%\* - goto end -) - - -REM Check if sphinx-build is available and fallback to Python version if any -%SPHINXBUILD% 1>NUL 2>NUL -if errorlevel 9009 goto sphinx_python -goto sphinx_ok -:sphinx_python - -set SPHINXBUILD=python -m sphinx.__init__ -%SPHINXBUILD% 2> nul +%SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found. Make sure you have Sphinx @@ -65,199 +19,17 @@ if errorlevel 9009 ( echo.may add the Sphinx directory to PATH. echo. echo.If you don't have Sphinx installed, grab it from - echo.http://sphinx-doc.org/ + echo.https://www.sphinx-doc.org/ exit /b 1 ) -:sphinx_ok - - -if "%1" == "html" ( - %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/html. - goto end -) - -if "%1" == "dirhtml" ( - %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. - goto end -) - -if "%1" == "singlehtml" ( - %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. - goto end -) - -if "%1" == "pickle" ( - %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can process the pickle files. - goto end -) - -if "%1" == "json" ( - %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can process the JSON files. - goto end -) - -if "%1" == "htmlhelp" ( - %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can run HTML Help Workshop with the ^ -.hhp project file in %BUILDDIR%/htmlhelp. - goto end -) - -if "%1" == "qthelp" ( - %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can run "qcollectiongenerator" with the ^ -.qhcp project file in %BUILDDIR%/qthelp, like this: - echo.^> qcollectiongenerator %BUILDDIR%\qthelp\PyGreSQL.qhcp - echo.To view the help file: - echo.^> assistant -collectionFile %BUILDDIR%\qthelp\PyGreSQL.ghc - goto end -) - -if "%1" == "devhelp" ( - %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. - goto end -) - -if "%1" == "epub" ( - %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The epub file is in %BUILDDIR%/epub. 
- goto end -) - -if "%1" == "latex" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "latexpdf" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - cd %BUILDDIR%/latex - make all-pdf - cd %~dp0 - echo. - echo.Build finished; the PDF files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "latexpdfja" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - cd %BUILDDIR%/latex - make all-pdf-ja - cd %~dp0 - echo. - echo.Build finished; the PDF files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "text" ( - %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The text files are in %BUILDDIR%/text. - goto end -) - -if "%1" == "man" ( - %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The manual pages are in %BUILDDIR%/man. - goto end -) - -if "%1" == "texinfo" ( - %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. - goto end -) - -if "%1" == "gettext" ( - %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The message catalogs are in %BUILDDIR%/locale. - goto end -) - -if "%1" == "changes" ( - %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes - if errorlevel 1 exit /b 1 - echo. - echo.The overview file is in %BUILDDIR%/changes. - goto end -) - -if "%1" == "linkcheck" ( - %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck - if errorlevel 1 exit /b 1 - echo. - echo.Link check complete; look for any errors in the above output ^ -or in %BUILDDIR%/linkcheck/output.txt. - goto end -) - -if "%1" == "doctest" ( - %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest - if errorlevel 1 exit /b 1 - echo. - echo.Testing of doctests in the sources finished, look at the ^ -results in %BUILDDIR%/doctest/output.txt. - goto end -) +if "%1" == "" goto help -if "%1" == "coverage" ( - %SPHINXBUILD% -b coverage %ALLSPHINXOPTS% %BUILDDIR%/coverage - if errorlevel 1 exit /b 1 - echo. - echo.Testing of coverage in the sources finished, look at the ^ -results in %BUILDDIR%/coverage/python.txt. - goto end -) +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end -if "%1" == "xml" ( - %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The XML files are in %BUILDDIR%/xml. - goto end -) - -if "%1" == "pseudoxml" ( - %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. - goto end -) +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% :end +popd diff --git a/docs/requirements.txt b/docs/requirements.txt index c354e8d9..9cd8b2f5 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1 +1 @@ -cloud_sptheme>=1.7.1 \ No newline at end of file +sphinx>=7,<8 diff --git a/docs/start.txt b/docs/start.txt deleted file mode 100644 index 5166896a..00000000 --- a/docs/start.txt +++ /dev/null @@ -1,15 +0,0 @@ -.. PyGreSQL index page without toc (for use with cloud theme) - -Welcome to PyGreSQL -=================== - -.. toctree:: - :hidden: - - copyright - announce - download/index - contents/index - community/index - -.. 
include:: about.txt \ No newline at end of file diff --git a/ext/pgconn.c b/ext/pgconn.c new file mode 100644 index 00000000..783eaffc --- /dev/null +++ b/ext/pgconn.c @@ -0,0 +1,1822 @@ +/* + * PyGreSQL - a Python interface for the PostgreSQL database. + * + * The connection object - this file is part a of the C extension module. + * + * Copyright (c) 2025 by the PyGreSQL Development Team + * + * Please see the LICENSE.TXT file for specific restrictions. + */ + +/* Deallocate connection object. */ +static void +conn_dealloc(connObject *self) +{ + if (self->cnx) { + Py_BEGIN_ALLOW_THREADS + PQfinish(self->cnx); + Py_END_ALLOW_THREADS + } + Py_XDECREF(self->cast_hook); + Py_XDECREF(self->notice_receiver); + PyObject_Del(self); +} + +/* Get connection attributes. */ +static PyObject * +conn_getattr(connObject *self, PyObject *nameobj) +{ + const char *name = PyUnicode_AsUTF8(nameobj); + + /* + * Although we could check individually, there are only a few + * attributes that don't require a live connection and unless someone + * has an urgent need, this will have to do. + */ + + /* first exception - close which returns a different error */ + if (strcmp(name, "close") && !self->cnx) { + PyErr_SetString(PyExc_TypeError, "Connection is not valid"); + return NULL; + } + + /* list PostgreSQL connection fields */ + + /* postmaster host */ + if (!strcmp(name, "host")) { + char *r = PQhost(self->cnx); + if (!r || r[0] == '/') /* this can return a Unix socket path */ + r = "localhost"; + return PyUnicode_FromString(r); + } + + /* postmaster port */ + if (!strcmp(name, "port")) + return PyLong_FromLong(atol(PQport(self->cnx))); + + /* selected database */ + if (!strcmp(name, "db")) + return PyUnicode_FromString(PQdb(self->cnx)); + + /* selected options */ + if (!strcmp(name, "options")) + return PyUnicode_FromString(PQoptions(self->cnx)); + + /* error (status) message */ + if (!strcmp(name, "error")) + return PyUnicode_FromString(PQerrorMessage(self->cnx)); + + /* connection status : 1 - OK, 0 - BAD */ + if (!strcmp(name, "status")) + return PyLong_FromLong(PQstatus(self->cnx) == CONNECTION_OK ? 1 : 0); + + /* provided user name */ + if (!strcmp(name, "user")) + return PyUnicode_FromString(PQuser(self->cnx)); + + /* protocol version */ + if (!strcmp(name, "protocol_version")) + return PyLong_FromLong(PQprotocolVersion(self->cnx)); + + /* backend version */ + if (!strcmp(name, "server_version")) + return PyLong_FromLong(PQserverVersion(self->cnx)); + + /* descriptor number of connection socket */ + if (!strcmp(name, "socket")) { + return PyLong_FromLong(PQsocket(self->cnx)); + } + + /* PID of backend process */ + if (!strcmp(name, "backend_pid")) { + return PyLong_FromLong(PQbackendPID(self->cnx)); + } + + /* whether the connection uses SSL */ + if (!strcmp(name, "ssl_in_use")) { + if (PQsslInUse(self->cnx)) { + Py_INCREF(Py_True); + return Py_True; + } + else { + Py_INCREF(Py_False); + return Py_False; + } + } + + /* SSL attributes */ + if (!strcmp(name, "ssl_attributes")) { + return get_ssl_attributes(self->cnx); + } + + return PyObject_GenericGetAttr((PyObject *)self, nameobj); +} + +/* Check connection validity. */ +static int +_check_cnx_obj(connObject *self) +{ + if (!self || !self->valid || !self->cnx) { + set_error_msg(OperationalError, "Connection has been closed"); + return 0; + } + return 1; +} + +/* Create source object. 
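+ The returned source object uses this connection; like all objects derived + from a connection, it stops working once the connection is closed or reset.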
*/ +static char conn_source__doc__[] = + "source() -- create a new source object for this connection"; + +static PyObject * +conn_source(connObject *self, PyObject *noargs) +{ + sourceObject *source_obj; + + /* checks validity */ + if (!_check_cnx_obj(self)) { + return NULL; + } + + /* allocates new query object */ + if (!(source_obj = PyObject_New(sourceObject, &sourceType))) { + return NULL; + } + + /* initializes internal parameters */ + Py_XINCREF(self); + source_obj->pgcnx = self; + source_obj->result = NULL; + source_obj->valid = 1; + source_obj->arraysize = PG_ARRAYSIZE; + + return (PyObject *)source_obj; +} + +/* For a non-query result, set the appropriate error status, + return the appropriate value, and free the result set. */ +static PyObject * +_conn_non_query_result(int status, PGresult *result, PGconn *cnx) +{ + switch (status) { + case PGRES_EMPTY_QUERY: + PyErr_SetString(PyExc_ValueError, "Empty query"); + break; + case PGRES_BAD_RESPONSE: + case PGRES_FATAL_ERROR: + case PGRES_NONFATAL_ERROR: + set_error(ProgrammingError, "Cannot execute query", cnx, result); + break; + case PGRES_COMMAND_OK: { /* INSERT, UPDATE, DELETE */ + Oid oid = PQoidValue(result); + + if (oid == InvalidOid) { /* not a single insert */ + char *ret = PQcmdTuples(result); + + if (ret[0]) { /* return number of rows affected */ + PyObject *obj = PyUnicode_FromString(ret); + PQclear(result); + return obj; + } + PQclear(result); + Py_INCREF(Py_None); + return Py_None; + } + /* for a single insert, return the oid */ + PQclear(result); + return PyLong_FromLong((long)oid); + } + case PGRES_COPY_OUT: /* no data will be received */ + case PGRES_COPY_IN: + PQclear(result); + Py_INCREF(Py_None); + return Py_None; + default: + set_error_msg(InternalError, "Unknown result status"); + } + + PQclear(result); + return NULL; /* error detected on query */ +} + +/* Base method for execution of all different kinds of queries */ +static PyObject * +_conn_query(connObject *self, PyObject *args, int prepared, int async) +{ + PyObject *query_str_obj, *param_obj = NULL; + PGresult *result; + queryObject *query_obj; + char *query; + int encoding, status, nparms = 0; + + if (!self->cnx) { + PyErr_SetString(PyExc_TypeError, "Connection is not valid"); + return NULL; + } + + /* get query args */ + if (!PyArg_ParseTuple(args, "O|O", &query_str_obj, ¶m_obj)) { + return NULL; + } + + encoding = PQclientEncoding(self->cnx); + + if (PyBytes_Check(query_str_obj)) { + query = PyBytes_AsString(query_str_obj); + query_str_obj = NULL; + } + else if (PyUnicode_Check(query_str_obj)) { + query_str_obj = get_encoded_string(query_str_obj, encoding); + if (!query_str_obj) + return NULL; /* pass the UnicodeEncodeError */ + query = PyBytes_AsString(query_str_obj); + } + else { + PyErr_SetString(PyExc_TypeError, + "Method query() expects a string as first argument"); + return NULL; + } + + /* If param_obj is passed, ensure it's a non-empty tuple. We want to treat + * an empty tuple the same as no argument since we'll get that when the + * caller passes no arguments to db.query(), and historic behaviour was + * to call PQexec() in that case, which can execute multiple commands. */ + if (param_obj) { + param_obj = PySequence_Fast( + param_obj, "Method query() expects a sequence as second argument"); + if (!param_obj) { + Py_XDECREF(query_str_obj); + return NULL; + } + nparms = (int)PySequence_Fast_GET_SIZE(param_obj); + + /* if there's a single argument and it's a list or tuple, it + * contains the positional arguments. 
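+ * In other words, query(sql, [(a, b)]) behaves like query(sql, (a, b)).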
*/ + if (nparms == 1) { + PyObject *first_obj = PySequence_Fast_GET_ITEM(param_obj, 0); + if (PyList_Check(first_obj) || PyTuple_Check(first_obj)) { + Py_DECREF(param_obj); + param_obj = PySequence_Fast(first_obj, NULL); + nparms = (int)PySequence_Fast_GET_SIZE(param_obj); + } + } + } + + /* gets result */ + if (nparms) { + /* prepare arguments */ + PyObject **str, **s; + const char **parms, **p; + register int i; + + str = (PyObject **)PyMem_Malloc((size_t)nparms * sizeof(*str)); + parms = (const char **)PyMem_Malloc((size_t)nparms * sizeof(*parms)); + if (!str || !parms) { + PyMem_Free((void *)parms); + PyMem_Free(str); + Py_XDECREF(query_str_obj); + Py_XDECREF(param_obj); + return PyErr_NoMemory(); + } + + /* convert optional args to a list of strings -- this allows + * the caller to pass whatever they like, and prevents us + * from having to map types to OIDs */ + for (i = 0, s = str, p = parms; i < nparms; ++i, ++p) { + PyObject *obj = PySequence_Fast_GET_ITEM(param_obj, i); + + if (obj == Py_None) { + *p = NULL; + } + else if (PyBytes_Check(obj)) { + *p = PyBytes_AsString(obj); + } + else if (PyUnicode_Check(obj)) { + PyObject *str_obj = get_encoded_string(obj, encoding); + if (!str_obj) { + PyMem_Free((void *)parms); + while (s != str) { + s--; + Py_DECREF(*s); + } + PyMem_Free(str); + Py_XDECREF(query_str_obj); + Py_XDECREF(param_obj); + /* pass the UnicodeEncodeError */ + return NULL; + } + *s++ = str_obj; + *p = PyBytes_AsString(str_obj); + } + else { + PyObject *str_obj = PyObject_Str(obj); + if (!str_obj) { + PyMem_Free((void *)parms); + while (s != str) { + s--; + Py_DECREF(*s); + } + PyMem_Free(str); + Py_XDECREF(query_str_obj); + Py_XDECREF(param_obj); + PyErr_SetString( + PyExc_TypeError, + "Query parameter has no string representation"); + return NULL; + } + *s++ = str_obj; + *p = PyUnicode_AsUTF8(str_obj); + } + } + + Py_BEGIN_ALLOW_THREADS + if (async) { + status = + PQsendQueryParams(self->cnx, query, nparms, NULL, + (const char *const *)parms, NULL, NULL, 0); + result = NULL; + } + else { + result = prepared ? PQexecPrepared(self->cnx, query, nparms, parms, + NULL, NULL, 0) + : PQexecParams(self->cnx, query, nparms, NULL, + parms, NULL, NULL, 0); + status = result != NULL; + } + Py_END_ALLOW_THREADS + + PyMem_Free((void *)parms); + while (s != str) { + s--; + Py_DECREF(*s); + } + PyMem_Free(str); + } + else { + Py_BEGIN_ALLOW_THREADS + if (async) { + status = PQsendQuery(self->cnx, query); + result = NULL; + } + else { + result = prepared ? 
PQexecPrepared(self->cnx, query, 0, NULL, NULL, + NULL, 0) + : PQexec(self->cnx, query); + status = result != NULL; + } + Py_END_ALLOW_THREADS + } + + /* we don't need the query and its params any more */ + Py_XDECREF(query_str_obj); + Py_XDECREF(param_obj); + + /* checks result validity */ + if (!status) { + PyErr_SetString(PyExc_ValueError, PQerrorMessage(self->cnx)); + return NULL; + } + + /* this may have changed the datestyle, so we reset the date format + in order to force fetching it newly when next time requested */ + self->date_format = date_format; /* this is normally NULL */ + + /* checks result status */ + if (result && (status = PQresultStatus(result)) != PGRES_TUPLES_OK) + return _conn_non_query_result(status, result, self->cnx); + + if (!(query_obj = PyObject_New(queryObject, &queryType))) + return PyErr_NoMemory(); + + /* stores result and returns object */ + Py_XINCREF(self); + query_obj->pgcnx = self; + query_obj->result = result; + query_obj->async = async; + query_obj->encoding = encoding; + query_obj->current_row = 0; + if (async) { + query_obj->max_row = 0; + query_obj->num_fields = 0; + query_obj->col_types = NULL; + } + else { + query_obj->max_row = PQntuples(result); + query_obj->num_fields = PQnfields(result); + query_obj->col_types = get_col_types(result, query_obj->num_fields); + if (!query_obj->col_types) { + Py_DECREF(query_obj); + Py_DECREF(self); + return NULL; + } + } + + return (PyObject *)query_obj; +} + +/* Database query */ +static char conn_query__doc__[] = + "query(sql, [arg]) -- create a new query object for this connection\n\n" + "You must pass the SQL (string) request and you can optionally pass\n" + "a tuple with positional parameters.\n"; + +static PyObject * +conn_query(connObject *self, PyObject *args) +{ + return _conn_query(self, args, 0, 0); +} + +/* Asynchronous database query */ +static char conn_send_query__doc__[] = + "send_query(sql, [arg]) -- create a new asynchronous query for this " + "connection\n\n" + "You must pass the SQL (string) request and you can optionally pass\n" + "a tuple with positional parameters.\n"; + +static PyObject * +conn_send_query(connObject *self, PyObject *args) +{ + return _conn_query(self, args, 0, 1); +} + +/* Execute prepared statement. */ +static char conn_query_prepared__doc__[] = + "query_prepared(name, [arg]) -- execute a prepared statement\n\n" + "You must pass the name (string) of the prepared statement and you can\n" + "optionally pass a tuple with positional parameters.\n"; + +static PyObject * +conn_query_prepared(connObject *self, PyObject *args) +{ + return _conn_query(self, args, 1, 0); +} + +/* Create prepared statement. 
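+ The statement is created via PQprepare() with no pre-declared parameter + types (0, NULL), so the backend infers them when the statement is executed.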
*/ +static char conn_prepare__doc__[] = + "prepare(name, sql) -- create a prepared statement\n\n" + "You must pass the name (string) of the prepared statement and the\n" + "SQL (string) request for later execution.\n"; + +static PyObject * +conn_prepare(connObject *self, PyObject *args) +{ + char *name, *query; + Py_ssize_t name_length, query_length; + PGresult *result; + + if (!self->cnx) { + PyErr_SetString(PyExc_TypeError, "Connection is not valid"); + return NULL; + } + + /* reads args */ + if (!PyArg_ParseTuple(args, "s#s#", &name, &name_length, &query, + &query_length)) { + PyErr_SetString(PyExc_TypeError, + "Method prepare() takes two string arguments"); + return NULL; + } + + /* create prepared statement */ + Py_BEGIN_ALLOW_THREADS + result = PQprepare(self->cnx, name, query, 0, NULL); + Py_END_ALLOW_THREADS + if (result && PQresultStatus(result) == PGRES_COMMAND_OK) { + PQclear(result); + Py_INCREF(Py_None); + return Py_None; /* success */ + } + set_error(ProgrammingError, "Cannot create prepared statement", self->cnx, + result); + if (result) + PQclear(result); + return NULL; /* error */ +} + +/* Describe prepared statement. */ +static char conn_describe_prepared__doc__[] = + "describe_prepared(name) -- describe a prepared statement\n\n" + "You must pass the name (string) of the prepared statement.\n"; + +static PyObject * +conn_describe_prepared(connObject *self, PyObject *args) +{ + char *name; + Py_ssize_t name_length; + PGresult *result; + + if (!self->cnx) { + PyErr_SetString(PyExc_TypeError, "Connection is not valid"); + return NULL; + } + + /* reads args */ + if (!PyArg_ParseTuple(args, "s#", &name, &name_length)) { + PyErr_SetString(PyExc_TypeError, + "Method describe_prepared() takes a string argument"); + return NULL; + } + + /* describe prepared statement */ + Py_BEGIN_ALLOW_THREADS + result = PQdescribePrepared(self->cnx, name); + Py_END_ALLOW_THREADS + if (result && PQresultStatus(result) == PGRES_COMMAND_OK) { + queryObject *query_obj = PyObject_New(queryObject, &queryType); + if (!query_obj) + return PyErr_NoMemory(); + Py_XINCREF(self); + query_obj->pgcnx = self; + query_obj->result = result; + query_obj->encoding = PQclientEncoding(self->cnx); + query_obj->current_row = 0; + query_obj->max_row = PQntuples(result); + query_obj->num_fields = PQnfields(result); + query_obj->col_types = get_col_types(result, query_obj->num_fields); + return (PyObject *)query_obj; + } + set_error(ProgrammingError, "Cannot describe prepared statement", + self->cnx, result); + if (result) + PQclear(result); + return NULL; /* error */ +} + +static char conn_putline__doc__[] = + "putline(line) -- send a line directly to the backend"; + +/* Direct access function: putline. */ +static PyObject * +conn_putline(connObject *self, PyObject *args) +{ + char *line; + Py_ssize_t line_length; + int ret; + + if (!self->cnx) { + PyErr_SetString(PyExc_TypeError, "Connection is not valid"); + return NULL; + } + + /* read args */ + if (!PyArg_ParseTuple(args, "s#", &line, &line_length)) { + PyErr_SetString(PyExc_TypeError, + "Method putline() takes a string argument"); + return NULL; + } + + /* send line to backend */ + ret = PQputCopyData(self->cnx, line, (int)line_length); + if (ret != 1) { + PyErr_SetString( + PyExc_IOError, + ret == -1 + ? PQerrorMessage(self->cnx) + : "Line cannot be queued, wait for write-ready and try again"); + return NULL; + } + Py_INCREF(Py_None); + return Py_None; +} + +/* Direct access function: getline. 
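+ Used after a COPY ... TO STDOUT query; each call returns one data row, + and None is returned once the copy is finished.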
*/ +static char conn_getline__doc__[] = + "getline() -- get a line directly from the backend"; + +static PyObject * +conn_getline(connObject *self, PyObject *noargs) +{ + char *line = NULL; + PyObject *str = NULL; + int ret; + + if (!self->cnx) { + PyErr_SetString(PyExc_TypeError, "Connection is not valid"); + return NULL; + } + + /* get line synchronously */ + ret = PQgetCopyData(self->cnx, &line, 0); + + /* check result */ + if (ret <= 0) { + if (line != NULL) + PQfreemem(line); + if (ret == -1) { + PQgetResult(self->cnx); + Py_INCREF(Py_None); + return Py_None; + } + PyErr_SetString( + PyExc_MemoryError, + ret == -2 + ? PQerrorMessage(self->cnx) + : "No line available, wait for read-ready and try again"); + return NULL; + } + if (line == NULL) { + Py_INCREF(Py_None); + return Py_None; + } + /* for backward compatibility, convert terminating newline to zero byte */ + if (*line) + line[strlen(line) - 1] = '\0'; + str = PyUnicode_FromString(line); + PQfreemem(line); + return str; +} + +/* Direct access function: end copy. */ +static char conn_endcopy__doc__[] = + "endcopy() -- synchronize client and server"; + +static PyObject * +conn_endcopy(connObject *self, PyObject *noargs) +{ + int ret; + + if (!self->cnx) { + PyErr_SetString(PyExc_TypeError, "Connection is not valid"); + return NULL; + } + + /* end direct copy */ + ret = PQputCopyEnd(self->cnx, NULL); + if (ret != 1) { + PyErr_SetString(PyExc_IOError, + ret == -1 ? PQerrorMessage(self->cnx) + : "Termination message cannot be queued," + " wait for write-ready and try again"); + return NULL; + } + Py_INCREF(Py_None); + return Py_None; +} + +/* Direct access function: set blocking status. */ +static char conn_set_non_blocking__doc__[] = + "set_non_blocking() -- set the non-blocking status of the connection"; + +static PyObject * +conn_set_non_blocking(connObject *self, PyObject *args) +{ + int non_blocking; + + if (!self->cnx) { + PyErr_SetString(PyExc_TypeError, "Connection is not valid"); + return NULL; + } + + if (!PyArg_ParseTuple(args, "i", &non_blocking)) { + PyErr_SetString( + PyExc_TypeError, + "set_non_blocking() expects a boolean value as argument"); + return NULL; + } + + if (PQsetnonblocking(self->cnx, non_blocking) < 0) { + PyErr_SetString(PyExc_IOError, PQerrorMessage(self->cnx)); + return NULL; + } + Py_INCREF(Py_None); + return Py_None; +} + +/* Direct access function: get blocking status. 
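+ Thin wrapper around PQisnonblocking(), the counterpart of + set_non_blocking() above.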
*/ +static char conn_is_non_blocking__doc__[] = + "is_non_blocking() -- report the blocking status of the connection"; + +static PyObject * +conn_is_non_blocking(connObject *self, PyObject *noargs) +{ + int rc; + + if (!self->cnx) { + PyErr_SetString(PyExc_TypeError, "Connection is not valid"); + return NULL; + } + + rc = PQisnonblocking(self->cnx); + if (rc < 0) { + PyErr_SetString(PyExc_IOError, PQerrorMessage(self->cnx)); + return NULL; + } + + return PyBool_FromLong((long)rc); +} + +/* Insert table */ +static char conn_inserttable__doc__[] = + "inserttable(table, data, [columns]) -- insert iterable into table\n\n" + "The fields in the iterable must be in the same order as in the table\n" + "or in the list or tuple of columns if one is specified.\n"; + +static PyObject * +conn_inserttable(connObject *self, PyObject *args) +{ + PGresult *result; + char *table, *buffer, *bufpt, *bufmax, *s, *t; + int encoding, ret; + size_t bufsiz; + PyObject *rows, *iter_row, *item, *columns = NULL; + Py_ssize_t i, j, m, n; + + if (!self->cnx) { + PyErr_SetString(PyExc_TypeError, "Connection is not valid"); + return NULL; + } + + /* gets arguments */ + if (!PyArg_ParseTuple(args, "sO|O", &table, &rows, &columns)) { + PyErr_SetString( + PyExc_TypeError, + "Method inserttable() expects a string and a list as arguments"); + return NULL; + } + + /* checks list type */ + if (!(iter_row = PyObject_GetIter(rows))) { + PyErr_SetString( + PyExc_TypeError, + "Method inserttable() expects an iterable as second argument"); + return NULL; + } + m = PySequence_Check(rows) ? PySequence_Size(rows) : -1; + if (!m) { + /* no rows specified, nothing to do */ + Py_DECREF(iter_row); + Py_INCREF(Py_None); + return Py_None; + } + + /* checks columns type */ + if (columns) { + if (!(PyTuple_Check(columns) || PyList_Check(columns))) { + PyErr_SetString(PyExc_TypeError, + "Method inserttable() expects a tuple or a list" + " as third argument"); + return NULL; + } + + n = PySequence_Fast_GET_SIZE(columns); + if (!n) { + /* no columns specified, nothing to do */ + Py_DECREF(iter_row); + Py_INCREF(Py_None); + return Py_None; + } + } + else { + n = -1; /* number of columns not yet known */ + } + + /* allocate buffer */ + if (!(buffer = PyMem_Malloc(MAX_BUFFER_SIZE))) { + Py_DECREF(iter_row); + return PyErr_NoMemory(); + } + + encoding = PQclientEncoding(self->cnx); + + /* starts query */ + bufpt = buffer; + bufmax = bufpt + MAX_BUFFER_SIZE; + bufpt += snprintf(bufpt, (size_t)(bufmax - bufpt), "copy "); + + s = table; + do { + t = strchr(s, '.'); + if (!t) + t = s + strlen(s); + table = PQescapeIdentifier(self->cnx, s, (size_t)(t - s)); + if (bufpt < bufmax) + bufpt += snprintf(bufpt, (size_t)(bufmax - bufpt), "%s", table); + PQfreemem(table); + s = t; + if (*s && bufpt < bufmax) + *bufpt++ = *s++; + } while (*s); + + if (columns) { + /* adds a string like f" ({','.join(columns)})" */ + if (bufpt < bufmax) + bufpt += snprintf(bufpt, (size_t)(bufmax - bufpt), " ("); + for (j = 0; j < n; ++j) { + PyObject *obj = PySequence_Fast_GET_ITEM(columns, j); + Py_ssize_t slen; + char *col; + + if (PyBytes_Check(obj)) { + Py_INCREF(obj); + } + else if (PyUnicode_Check(obj)) { + obj = get_encoded_string(obj, encoding); + if (!obj) { + PyMem_Free(buffer); + Py_DECREF(iter_row); + return NULL; /* pass the UnicodeEncodeError */ + } + } + else { + PyErr_SetString( + PyExc_TypeError, + "The third argument must contain only strings"); + PyMem_Free(buffer); + Py_DECREF(iter_row); + return NULL; + } + PyBytes_AsStringAndSize(obj, &col, &slen); + col = 
PQescapeIdentifier(self->cnx, col, (size_t)slen); + Py_DECREF(obj); + if (bufpt < bufmax) + bufpt += snprintf(bufpt, (size_t)(bufmax - bufpt), "%s%s", col, + j == n - 1 ? ")" : ","); + PQfreemem(col); + } + } + if (bufpt < bufmax) + snprintf(bufpt, (size_t)(bufmax - bufpt), " from stdin"); + if (bufpt >= bufmax) { + PyMem_Free(buffer); + Py_DECREF(iter_row); + return PyErr_NoMemory(); + } + + Py_BEGIN_ALLOW_THREADS + result = PQexec(self->cnx, buffer); + Py_END_ALLOW_THREADS + + if (!result || PQresultStatus(result) != PGRES_COPY_IN) { + PyMem_Free(buffer); + Py_DECREF(iter_row); + PyErr_SetString(PyExc_ValueError, PQerrorMessage(self->cnx)); + return NULL; + } + + PQclear(result); + + /* feed table */ + for (i = 0; m < 0 || i < m; ++i) { + if (!(columns = PyIter_Next(iter_row))) + break; + + if (!(PyTuple_Check(columns) || PyList_Check(columns))) { + PQputCopyEnd(self->cnx, "Invalid arguments"); + PyMem_Free(buffer); + Py_DECREF(columns); + Py_DECREF(iter_row); + PyErr_SetString( + PyExc_TypeError, + "The second argument must contain tuples or lists"); + return NULL; + } + + j = PySequence_Fast_GET_SIZE(columns); + if (n < 0) { + n = j; + } + else if (j != n) { + PQputCopyEnd(self->cnx, "Invalid arguments"); + PyMem_Free(buffer); + Py_DECREF(columns); + Py_DECREF(iter_row); + PyErr_SetString( + PyExc_TypeError, + "The second arg must contain sequences of the same size"); + return NULL; + } + + /* builds insert line */ + bufpt = buffer; + bufsiz = MAX_BUFFER_SIZE - 1; + + for (j = 0; j < n; ++j) { + if (j) { + *bufpt++ = '\t'; + --bufsiz; + } + + item = PySequence_Fast_GET_ITEM(columns, j); + + /* convert item to string and append to buffer */ + if (item == Py_None) { + if (bufsiz > 2) { + *bufpt++ = '\\'; + *bufpt++ = 'N'; + bufsiz -= 2; + } + else + bufsiz = 0; + } + else if (PyBytes_Check(item)) { + const char *t = PyBytes_AsString(item); + + while (*t && bufsiz) { + switch (*t) { + case '\\': + *bufpt++ = '\\'; + if (--bufsiz) + *bufpt++ = '\\'; + break; + case '\t': + *bufpt++ = '\\'; + if (--bufsiz) + *bufpt++ = 't'; + break; + case '\r': + *bufpt++ = '\\'; + if (--bufsiz) + *bufpt++ = 'r'; + break; + case '\n': + *bufpt++ = '\\'; + if (--bufsiz) + *bufpt++ = 'n'; + break; + default: + *bufpt++ = *t; + } + ++t; + --bufsiz; + } + } + else if (PyUnicode_Check(item)) { + PyObject *s = get_encoded_string(item, encoding); + if (!s) { + PQputCopyEnd(self->cnx, "Encoding error"); + PyMem_Free(buffer); + Py_DECREF(item); + Py_DECREF(columns); + Py_DECREF(iter_row); + return NULL; /* pass the UnicodeEncodeError */ + } + else { + const char *t = PyBytes_AsString(s); + + while (*t && bufsiz) { + switch (*t) { + case '\\': + *bufpt++ = '\\'; + if (--bufsiz) + *bufpt++ = '\\'; + break; + case '\t': + *bufpt++ = '\\'; + if (--bufsiz) + *bufpt++ = 't'; + break; + case '\r': + *bufpt++ = '\\'; + if (--bufsiz) + *bufpt++ = 'r'; + break; + case '\n': + *bufpt++ = '\\'; + if (--bufsiz) + *bufpt++ = 'n'; + break; + default: + *bufpt++ = *t; + } + ++t; + --bufsiz; + } + Py_DECREF(s); + } + } + else if (PyLong_Check(item)) { + PyObject *s = PyObject_Str(item); + const char *t = PyUnicode_AsUTF8(s); + + while (*t && bufsiz) { + *bufpt++ = *t++; + --bufsiz; + } + Py_DECREF(s); + } + else { + PyObject *s = PyObject_Repr(item); + const char *t = PyUnicode_AsUTF8(s); + + while (*t && bufsiz) { + switch (*t) { + case '\\': + *bufpt++ = '\\'; + if (--bufsiz) + *bufpt++ = '\\'; + break; + case '\t': + *bufpt++ = '\\'; + if (--bufsiz) + *bufpt++ = 't'; + break; + case '\r': + *bufpt++ = '\\';
+ if (--bufsiz) + *bufpt++ = 'r'; + break; + case '\n': + *bufpt++ = '\\'; + if (--bufsiz) + *bufpt++ = 'n'; + break; + default: + *bufpt++ = *t; + } + ++t; + --bufsiz; + } + Py_DECREF(s); + } + + if (bufsiz <= 0) { + PQputCopyEnd(self->cnx, "Memory error"); + PyMem_Free(buffer); + Py_DECREF(columns); + Py_DECREF(iter_row); + return PyErr_NoMemory(); + } + } + + Py_DECREF(columns); + + *bufpt++ = '\n'; + + /* sends data */ + ret = PQputCopyData(self->cnx, buffer, (int)(bufpt - buffer)); + if (ret != 1) { + char *errormsg = ret == -1 ? PQerrorMessage(self->cnx) + : "Data cannot be queued"; + PyErr_SetString(PyExc_IOError, errormsg); + PQputCopyEnd(self->cnx, errormsg); + PyMem_Free(buffer); + Py_DECREF(iter_row); + return NULL; + } + } + + Py_DECREF(iter_row); + if (PyErr_Occurred()) { + PyMem_Free(buffer); + return NULL; /* pass the iteration error */ + } + + ret = PQputCopyEnd(self->cnx, NULL); + if (ret != 1) { + PyErr_SetString(PyExc_IOError, ret == -1 ? PQerrorMessage(self->cnx) + : "Data cannot be queued"); + PyMem_Free(buffer); + return NULL; + } + + PyMem_Free(buffer); + + Py_BEGIN_ALLOW_THREADS + result = PQgetResult(self->cnx); + Py_END_ALLOW_THREADS + if (PQresultStatus(result) != PGRES_COMMAND_OK) { + PyErr_SetString(PyExc_ValueError, PQerrorMessage(self->cnx)); + PQclear(result); + return NULL; + } + else { + long ntuples = atol(PQcmdTuples(result)); + PQclear(result); + return PyLong_FromLong(ntuples); + } +} + +/* Get transaction state. */ +static char conn_transaction__doc__[] = + "transaction() -- return the current transaction status"; + +static PyObject * +conn_transaction(connObject *self, PyObject *noargs) +{ + if (!self->cnx) { + PyErr_SetString(PyExc_TypeError, "Connection is not valid"); + return NULL; + } + + return PyLong_FromLong(PQtransactionStatus(self->cnx)); +} + +/* Get parameter setting. */ +static char conn_parameter__doc__[] = + "parameter(name) -- look up a current parameter setting"; + +static PyObject * +conn_parameter(connObject *self, PyObject *args) +{ + const char *name; + + if (!self->cnx) { + PyErr_SetString(PyExc_TypeError, "Connection is not valid"); + return NULL; + } + + /* get query args */ + if (!PyArg_ParseTuple(args, "s", &name)) { + PyErr_SetString(PyExc_TypeError, + "Method parameter() takes a string as argument"); + return NULL; + } + + name = PQparameterStatus(self->cnx, name); + + if (name) + return PyUnicode_FromString(name); + + /* unknown parameter, return None */ + Py_INCREF(Py_None); + return Py_None; +} + +/* Get current date format. 
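+ The format is cached in the connection object; the cache is reset after + every query, since a query may change the DateStyle setting.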
*/ +static char conn_date_format__doc__[] = + "date_format() -- return the current date format"; + +static PyObject * +conn_date_format(connObject *self, PyObject *noargs) +{ + const char *fmt; + + if (!self->cnx) { + PyErr_SetString(PyExc_TypeError, "Connection is not valid"); + return NULL; + } + + /* check if the date format is cached in the connection */ + fmt = self->date_format; + if (!fmt) { + fmt = date_style_to_format(PQparameterStatus(self->cnx, "DateStyle")); + self->date_format = fmt; /* cache the result */ + } + + return PyUnicode_FromString(fmt); +} + +/* Escape literal */ +static char conn_escape_literal__doc__[] = + "escape_literal(str) -- escape a literal constant for use within SQL"; + +static PyObject * +conn_escape_literal(connObject *self, PyObject *string) +{ + PyObject *tmp_obj = NULL, /* auxiliary string object */ + *to_obj; /* string object to return */ + char *from, /* our string argument as encoded string */ + *to; /* the result as encoded string */ + Py_ssize_t from_length; /* length of string */ + size_t to_length; /* length of result */ + int encoding = -1; /* client encoding */ + + if (PyBytes_Check(string)) { + PyBytes_AsStringAndSize(string, &from, &from_length); + } + else if (PyUnicode_Check(string)) { + encoding = PQclientEncoding(self->cnx); + tmp_obj = get_encoded_string(string, encoding); + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ + PyBytes_AsStringAndSize(tmp_obj, &from, &from_length); + } + else { + PyErr_SetString( + PyExc_TypeError, + "Method escape_literal() expects a string as argument"); + return NULL; + } + + to = PQescapeLiteral(self->cnx, from, (size_t)from_length); + to_length = strlen(to); + + Py_XDECREF(tmp_obj); + + if (encoding == -1) + to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t)to_length); + else + to_obj = get_decoded_string(to, (Py_ssize_t)to_length, encoding); + if (to) + PQfreemem(to); + return to_obj; +} + +/* Escape identifier */ +static char conn_escape_identifier__doc__[] = + "escape_identifier(str) -- escape an identifier for use within SQL"; + +static PyObject * +conn_escape_identifier(connObject *self, PyObject *string) +{ + PyObject *tmp_obj = NULL, /* auxiliary string object */ + *to_obj; /* string object to return */ + char *from, /* our string argument as encoded string */ + *to; /* the result as encoded string */ + Py_ssize_t from_length; /* length of string */ + size_t to_length; /* length of result */ + int encoding = -1; /* client encoding */ + + if (PyBytes_Check(string)) { + PyBytes_AsStringAndSize(string, &from, &from_length); + } + else if (PyUnicode_Check(string)) { + encoding = PQclientEncoding(self->cnx); + tmp_obj = get_encoded_string(string, encoding); + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ + PyBytes_AsStringAndSize(tmp_obj, &from, &from_length); + } + else { + PyErr_SetString( + PyExc_TypeError, + "Method escape_identifier() expects a string as argument"); + return NULL; + } + + to = PQescapeIdentifier(self->cnx, from, (size_t)from_length); + to_length = strlen(to); + + Py_XDECREF(tmp_obj); + + if (encoding == -1) + to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t)to_length); + else + to_obj = get_decoded_string(to, (Py_ssize_t)to_length, encoding); + if (to) + PQfreemem(to); + return to_obj; +} + +/* Escape string */ +static char conn_escape_string__doc__[] = + "escape_string(str) -- escape a string for use within SQL"; + +static PyObject * +conn_escape_string(connObject *self, PyObject *string) +{ + PyObject *tmp_obj = NULL, /* auxiliary string 
object */ + *to_obj; /* string object to return */ + char *from, /* our string argument as encoded string */ + *to; /* the result as encoded string */ + Py_ssize_t from_length; /* length of string */ + size_t to_length; /* length of result */ + int encoding = -1; /* client encoding */ + + if (PyBytes_Check(string)) { + PyBytes_AsStringAndSize(string, &from, &from_length); + } + else if (PyUnicode_Check(string)) { + encoding = PQclientEncoding(self->cnx); + tmp_obj = get_encoded_string(string, encoding); + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ + PyBytes_AsStringAndSize(tmp_obj, &from, &from_length); + } + else { + PyErr_SetString(PyExc_TypeError, + "Method escape_string() expects a string as argument"); + return NULL; + } + + to_length = 2 * (size_t)from_length + 1; + if ((Py_ssize_t)to_length < from_length) { /* overflow */ + to_length = (size_t)from_length; + from_length = (from_length - 1) / 2; + } + to = (char *)PyMem_Malloc(to_length); + to_length = + PQescapeStringConn(self->cnx, to, from, (size_t)from_length, NULL); + + Py_XDECREF(tmp_obj); + + if (encoding == -1) + to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t)to_length); + else + to_obj = get_decoded_string(to, (Py_ssize_t)to_length, encoding); + PyMem_Free(to); + return to_obj; +} + +/* Escape bytea */ +static char conn_escape_bytea__doc__[] = + "escape_bytea(data) -- escape binary data for use within SQL as type " + "bytea"; + +static PyObject * +conn_escape_bytea(connObject *self, PyObject *data) +{ + PyObject *tmp_obj = NULL, /* auxiliary string object */ + *to_obj; /* string object to return */ + char *from, /* our string argument as encoded string */ + *to; /* the result as encoded string */ + Py_ssize_t from_length; /* length of string */ + size_t to_length; /* length of result */ + int encoding = -1; /* client encoding */ + + if (PyBytes_Check(data)) { + PyBytes_AsStringAndSize(data, &from, &from_length); + } + else if (PyUnicode_Check(data)) { + encoding = PQclientEncoding(self->cnx); + tmp_obj = get_encoded_string(data, encoding); + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ + PyBytes_AsStringAndSize(tmp_obj, &from, &from_length); + } + else { + PyErr_SetString(PyExc_TypeError, + "Method escape_bytea() expects a string as argument"); + return NULL; + } + + to = (char *)PQescapeByteaConn(self->cnx, (unsigned char *)from, + (size_t)from_length, &to_length); + + Py_XDECREF(tmp_obj); + + if (encoding == -1) + to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t)to_length - 1); + else + to_obj = get_decoded_string(to, (Py_ssize_t)to_length - 1, encoding); + if (to) + PQfreemem(to); + return to_obj; +} + +/* Constructor for large objects (internal use only) */ +static largeObject * +large_new(connObject *pgcnx, Oid oid) +{ + largeObject *large_obj; + + if (!(large_obj = PyObject_New(largeObject, &largeType))) { + return NULL; + } + + Py_XINCREF(pgcnx); + large_obj->pgcnx = pgcnx; + large_obj->lo_fd = -1; + large_obj->lo_oid = oid; + + return large_obj; +} + +/* Create large object. 
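+ The mode argument is a bit mask combining INV_READ and INV_WRITE from the + libpq large object interface.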
*/ +static char conn_locreate__doc__[] = + "locreate(mode) -- create a new large object in the database"; + +static PyObject * +conn_locreate(connObject *self, PyObject *args) +{ + int mode; + Oid lo_oid; + + /* checks validity */ + if (!_check_cnx_obj(self)) { + return NULL; + } + + /* gets arguments */ + if (!PyArg_ParseTuple(args, "i", &mode)) { + PyErr_SetString(PyExc_TypeError, + "Method locreate() takes an integer argument"); + return NULL; + } + + /* creates large object */ + lo_oid = lo_creat(self->cnx, mode); + if (lo_oid == 0) { + set_error_msg(OperationalError, "Can't create large object"); + return NULL; + } + + return (PyObject *)large_new(self, lo_oid); +} + +/* Init from already known oid. */ +static char conn_getlo__doc__[] = + "getlo(oid) -- create a large object instance for the specified oid"; + +static PyObject * +conn_getlo(connObject *self, PyObject *args) +{ + int oid; + Oid lo_oid; + + /* checks validity */ + if (!_check_cnx_obj(self)) { + return NULL; + } + + /* gets arguments */ + if (!PyArg_ParseTuple(args, "i", &oid)) { + PyErr_SetString(PyExc_TypeError, + "Method getlo() takes an integer argument"); + return NULL; + } + + lo_oid = (Oid)oid; + if (lo_oid == 0) { + PyErr_SetString(PyExc_ValueError, "The object oid can't be null"); + return NULL; + } + + /* creates object */ + return (PyObject *)large_new(self, lo_oid); +} + +/* Import unix file. */ +static char conn_loimport__doc__[] = + "loimport(name) -- create a new large object from specified file"; + +static PyObject * +conn_loimport(connObject *self, PyObject *args) +{ + char *name; + Oid lo_oid; + + /* checks validity */ + if (!_check_cnx_obj(self)) { + return NULL; + } + + /* gets arguments */ + if (!PyArg_ParseTuple(args, "s", &name)) { + PyErr_SetString(PyExc_TypeError, + "Method loimport() takes a string argument"); + return NULL; + } + + /* imports file and checks result */ + lo_oid = lo_import(self->cnx, name); + if (lo_oid == 0) { + set_error_msg(OperationalError, "Can't create large object"); + return NULL; + } + + return (PyObject *)large_new(self, lo_oid); +} + +/* Reset connection. */ +static char conn_reset__doc__[] = + "reset() -- reset connection with current parameters\n\n" + "All derived queries and large objects derived from this connection\n" + "will not be usable after this call.\n"; + +static PyObject * +conn_reset(connObject *self, PyObject *noargs) +{ + if (!self->cnx) { + PyErr_SetString(PyExc_TypeError, "Connection is not valid"); + return NULL; + } + + /* resets the connection */ + PQreset(self->cnx); + Py_INCREF(Py_None); + return Py_None; +} + +/* Cancel current command. */ +static char conn_cancel__doc__[] = + "cancel() -- abandon processing of the current command"; + +static PyObject * +conn_cancel(connObject *self, PyObject *noargs) +{ + if (!self->cnx) { + PyErr_SetString(PyExc_TypeError, "Connection is not valid"); + return NULL; + } + + /* request that the server abandon processing of the current command */ + return PyLong_FromLong((long)PQrequestCancel(self->cnx)); +} + +/* Get connection socket. */ +static char conn_fileno__doc__[] = + "fileno() -- return database connection socket file handle"; + +static PyObject * +conn_fileno(connObject *self, PyObject *noargs) +{ + if (!self->cnx) { + PyErr_SetString(PyExc_TypeError, "Connection is not valid"); + return NULL; + } + + return PyLong_FromLong((long)PQsocket(self->cnx)); +} + +/* Set external typecast callback function. 
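+ Passing None removes the hook; otherwise the callable is stored and used + as a fallback typecast for values the module cannot cast itself.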
*/ +static char conn_set_cast_hook__doc__[] = + "set_cast_hook(func) -- set a fallback typecast function"; + +static PyObject * +conn_set_cast_hook(connObject *self, PyObject *func) +{ + PyObject *ret = NULL; + + if (func == Py_None) { + Py_XDECREF(self->cast_hook); + self->cast_hook = NULL; + Py_INCREF(Py_None); + ret = Py_None; + } + else if (PyCallable_Check(func)) { + Py_XINCREF(func); + Py_XDECREF(self->cast_hook); + self->cast_hook = func; + Py_INCREF(Py_None); + ret = Py_None; + } + else { + PyErr_SetString(PyExc_TypeError, + "Method set_cast_hook() expects" + " a callable or None as argument"); + } + + return ret; +} + +/* Get external typecast callback function. */ +static char conn_get_cast_hook__doc__[] = + "get_cast_hook() -- get the fallback typecast function"; + +static PyObject * +conn_get_cast_hook(connObject *self, PyObject *noargs) +{ + PyObject *ret = self->cast_hook; + + if (!ret) + ret = Py_None; + Py_INCREF(ret); + + return ret; +} + +/* Get asynchronous connection state. */ +static char conn_poll__doc__[] = + "poll() -- complete an asynchronous connection"; + +static PyObject * +conn_poll(connObject *self, PyObject *noargs) +{ + int rc; + + if (!self->cnx) { + PyErr_SetString(PyExc_TypeError, "Connection is not valid"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + rc = PQconnectPoll(self->cnx); + Py_END_ALLOW_THREADS + + if (rc == PGRES_POLLING_FAILED) { + set_error(InternalError, "Polling failed", self->cnx, NULL); + return NULL; + } + + return PyLong_FromLong(rc); +} + +/* Set notice receiver callback function. */ +static char conn_set_notice_receiver__doc__[] = + "set_notice_receiver(func) -- set the current notice receiver"; + +static PyObject * +conn_set_notice_receiver(connObject *self, PyObject *func) +{ + PyObject *ret = NULL; + + if (func == Py_None) { + Py_XDECREF(self->notice_receiver); + self->notice_receiver = NULL; + Py_INCREF(Py_None); + ret = Py_None; + } + else if (PyCallable_Check(func)) { + Py_XINCREF(func); + Py_XDECREF(self->notice_receiver); + self->notice_receiver = func; + PQsetNoticeReceiver(self->cnx, notice_receiver, self); + Py_INCREF(Py_None); + ret = Py_None; + } + else { + PyErr_SetString(PyExc_TypeError, + "Method set_notice_receiver() expects" + " a callable or None as argument"); + } + + return ret; +} + +/* Get notice receiver callback function. */ +static char conn_get_notice_receiver__doc__[] = + "get_notice_receiver() -- get the current notice receiver"; + +static PyObject * +conn_get_notice_receiver(connObject *self, PyObject *noargs) +{ + PyObject *ret = self->notice_receiver; + + if (!ret) + ret = Py_None; + Py_INCREF(ret); + + return ret; +} + +/* Close without deleting. */ +static char conn_close__doc__[] = + "close() -- close connection\n\n" + "All instances of the connection object and derived objects\n" + "(queries and large objects) can no longer be used after this call.\n"; + +static PyObject * +conn_close(connObject *self, PyObject *noargs) +{ + /* connection object cannot already be closed */ + if (!self->cnx) { + set_error_msg(InternalError, "Connection already closed"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + PQfinish(self->cnx); + Py_END_ALLOW_THREADS + + self->cnx = NULL; + Py_INCREF(Py_None); + return Py_None; +} + +/* Get asynchronous notify.
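+ Returns a (relname, pid, extra) tuple for the next pending NOTIFY message, + or None if there is none.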
*/ +static char conn_get_notify__doc__[] = + "getnotify() -- get database notify for this connection"; + +static PyObject * +conn_get_notify(connObject *self, PyObject *noargs) +{ + PGnotify *notify; + + if (!self->cnx) { + PyErr_SetString(PyExc_TypeError, "Connection is not valid"); + return NULL; + } + + /* checks for NOTIFY messages */ + PQconsumeInput(self->cnx); + + if (!(notify = PQnotifies(self->cnx))) { + Py_INCREF(Py_None); + return Py_None; + } + else { + PyObject *notify_result, *tmp; + + if (!(tmp = PyUnicode_FromString(notify->relname))) { + return NULL; + } + + if (!(notify_result = PyTuple_New(3))) { + return NULL; + } + + PyTuple_SET_ITEM(notify_result, 0, tmp); + + if (!(tmp = PyLong_FromLong(notify->be_pid))) { + Py_DECREF(notify_result); + return NULL; + } + + PyTuple_SET_ITEM(notify_result, 1, tmp); + + /* extra exists even in old versions that did not support it */ + if (!(tmp = PyUnicode_FromString(notify->extra))) { + Py_DECREF(notify_result); + return NULL; + } + + PyTuple_SET_ITEM(notify_result, 2, tmp); + + PQfreemem(notify); + + return notify_result; + } +} + +/* Get the list of connection attributes. */ +static PyObject * +conn_dir(connObject *self, PyObject *noargs) +{ + PyObject *attrs; + + attrs = PyObject_Dir(PyObject_Type((PyObject *)self)); + PyObject_CallMethod(attrs, "extend", "[sssssssssssss]", "host", "port", + "db", "options", "error", "status", "user", + "protocol_version", "server_version", "socket", + "backend_pid", "ssl_in_use", "ssl_attributes"); + + return attrs; +} + +/* Connection object methods */ +static struct PyMethodDef conn_methods[] = { + {"__dir__", (PyCFunction)conn_dir, METH_NOARGS, NULL}, + + {"source", (PyCFunction)conn_source, METH_NOARGS, conn_source__doc__}, + {"query", (PyCFunction)conn_query, METH_VARARGS, conn_query__doc__}, + {"send_query", (PyCFunction)conn_send_query, METH_VARARGS, + conn_send_query__doc__}, + {"query_prepared", (PyCFunction)conn_query_prepared, METH_VARARGS, + conn_query_prepared__doc__}, + {"prepare", (PyCFunction)conn_prepare, METH_VARARGS, conn_prepare__doc__}, + {"describe_prepared", (PyCFunction)conn_describe_prepared, METH_VARARGS, + conn_describe_prepared__doc__}, + {"poll", (PyCFunction)conn_poll, METH_NOARGS, conn_poll__doc__}, + {"reset", (PyCFunction)conn_reset, METH_NOARGS, conn_reset__doc__}, + {"cancel", (PyCFunction)conn_cancel, METH_NOARGS, conn_cancel__doc__}, + {"close", (PyCFunction)conn_close, METH_NOARGS, conn_close__doc__}, + {"fileno", (PyCFunction)conn_fileno, METH_NOARGS, conn_fileno__doc__}, + {"get_cast_hook", (PyCFunction)conn_get_cast_hook, METH_NOARGS, + conn_get_cast_hook__doc__}, + {"set_cast_hook", (PyCFunction)conn_set_cast_hook, METH_O, + conn_set_cast_hook__doc__}, + {"get_notice_receiver", (PyCFunction)conn_get_notice_receiver, METH_NOARGS, + conn_get_notice_receiver__doc__}, + {"set_notice_receiver", (PyCFunction)conn_set_notice_receiver, METH_O, + conn_set_notice_receiver__doc__}, + {"getnotify", (PyCFunction)conn_get_notify, METH_NOARGS, + conn_get_notify__doc__}, + {"inserttable", (PyCFunction)conn_inserttable, METH_VARARGS, + conn_inserttable__doc__}, + {"transaction", (PyCFunction)conn_transaction, METH_NOARGS, + conn_transaction__doc__}, + {"parameter", (PyCFunction)conn_parameter, METH_VARARGS, + conn_parameter__doc__}, + {"date_format", (PyCFunction)conn_date_format, METH_NOARGS, + conn_date_format__doc__}, + + {"escape_literal", (PyCFunction)conn_escape_literal, METH_O, + conn_escape_literal__doc__}, + {"escape_identifier", 
(PyCFunction)conn_escape_identifier, METH_O, + conn_escape_identifier__doc__}, + {"escape_string", (PyCFunction)conn_escape_string, METH_O, + conn_escape_string__doc__}, + {"escape_bytea", (PyCFunction)conn_escape_bytea, METH_O, + conn_escape_bytea__doc__}, + + {"putline", (PyCFunction)conn_putline, METH_VARARGS, conn_putline__doc__}, + {"getline", (PyCFunction)conn_getline, METH_NOARGS, conn_getline__doc__}, + {"endcopy", (PyCFunction)conn_endcopy, METH_NOARGS, conn_endcopy__doc__}, + {"set_non_blocking", (PyCFunction)conn_set_non_blocking, METH_VARARGS, + conn_set_non_blocking__doc__}, + {"is_non_blocking", (PyCFunction)conn_is_non_blocking, METH_NOARGS, + conn_is_non_blocking__doc__}, + + {"locreate", (PyCFunction)conn_locreate, METH_VARARGS, + conn_locreate__doc__}, + {"getlo", (PyCFunction)conn_getlo, METH_VARARGS, conn_getlo__doc__}, + {"loimport", (PyCFunction)conn_loimport, METH_VARARGS, + conn_loimport__doc__}, + + {NULL, NULL} /* sentinel */ +}; + +static char conn__doc__[] = "PostgreSQL connection object"; + +/* Connection type definition */ +static PyTypeObject connType = { + PyVarObject_HEAD_INIT(NULL, 0) "pg.Connection", /* tp_name */ + sizeof(connObject), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)conn_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + (getattrofunc)conn_getattr, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + conn__doc__, /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + conn_methods, /* tp_methods */ +}; diff --git a/pginternal.c b/ext/pginternal.c similarity index 70% rename from pginternal.c rename to ext/pginternal.c index e1d36692..25290950 100644 --- a/pginternal.c +++ b/ext/pginternal.c @@ -3,7 +3,7 @@ * * Internal functions - this file is part a of the C extension module. * - * Copyright (c) 2020 by the PyGreSQL Development Team + * Copyright (c) 2025 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ @@ -37,8 +37,8 @@ get_decoded_string(const char *str, Py_ssize_t size, int encoding) if (encoding == pg_encoding_ascii) return PyUnicode_DecodeASCII(str, size, "strict"); /* encoding name should be properly translated to Python here */ - return PyUnicode_Decode(str, size, - pg_encoding_to_char(encoding), "strict"); + return PyUnicode_Decode(str, size, pg_encoding_to_char(encoding), + "strict"); } static PyObject * @@ -52,7 +52,7 @@ get_encoded_string(PyObject *unicode_obj, int encoding) return PyUnicode_AsASCIIString(unicode_obj); /* encoding name should be properly translated to Python here */ return PyUnicode_AsEncodedString(unicode_obj, - pg_encoding_to_char(encoding), "strict"); + pg_encoding_to_char(encoding), "strict"); } /* Helper functions */ @@ -64,7 +64,7 @@ get_type(Oid pgtype) int t; switch (pgtype) { - /* simple types */ + /* simple types */ case INT2OID: case INT4OID: @@ -113,7 +113,7 @@ get_type(Oid pgtype) t = PYGRES_TEXT; break; - /* array types */ + /* array types */ case INT2ARRAYOID: case INT4ARRAYOID: @@ -137,8 +137,9 @@ get_type(Oid pgtype) break; case MONEYARRAYOID: - t = array_as_text ? PYGRES_TEXT : ((decimal_point ? - PYGRES_MONEY : PYGRES_TEXT) | PYGRES_ARRAY); + t = array_as_text ? 
PYGRES_TEXT + : ((decimal_point ? PYGRES_MONEY : PYGRES_TEXT) | + PYGRES_ARRAY); break; case BOOLARRAYOID: @@ -146,14 +147,16 @@ get_type(Oid pgtype) break; case BYTEAARRAYOID: - t = array_as_text ? PYGRES_TEXT : ((bytea_escaped ? - PYGRES_TEXT : PYGRES_BYTEA) | PYGRES_ARRAY); + t = array_as_text ? PYGRES_TEXT + : ((bytea_escaped ? PYGRES_TEXT : PYGRES_BYTEA) | + PYGRES_ARRAY); break; case JSONARRAYOID: case JSONBARRAYOID: - t = array_as_text ? PYGRES_TEXT : ((jsondecode ? - PYGRES_JSON : PYGRES_TEXT) | PYGRES_ARRAY); + t = array_as_text ? PYGRES_TEXT + : ((jsondecode ? PYGRES_JSON : PYGRES_TEXT) | + PYGRES_ARRAY); break; case BPCHARARRAYOID: @@ -178,8 +181,8 @@ get_col_types(PGresult *result, int nfields) { int *types, *t, j; - if (!(types = PyMem_Malloc(sizeof(int) * (size_t) nfields))) { - return (int*) PyErr_NoMemory(); + if (!(types = PyMem_Malloc(sizeof(int) * (size_t)nfields))) { + return (int *)PyErr_NoMemory(); } for (j = 0, t = types; j < nfields; ++j) { @@ -199,8 +202,8 @@ cast_bytea_text(char *s) size_t str_len; /* this function should not be called when bytea_escaped is set */ - tmp_str = (char *) PQunescapeBytea((unsigned char*) s, &str_len); - obj = PyBytes_FromStringAndSize(tmp_str, (Py_ssize_t) str_len); + tmp_str = (char *)PQunescapeBytea((unsigned char *)s, &str_len); + obj = PyBytes_FromStringAndSize(tmp_str, (Py_ssize_t)str_len); if (tmp_str) { PQfreemem(tmp_str); } @@ -221,16 +224,18 @@ cast_sized_text(char *s, Py_ssize_t size, int encoding, int type) case PYGRES_BYTEA: /* this type should not be passed when bytea_escaped is set */ /* we need to add a null byte */ - tmp_str = (char *) PyMem_Malloc((size_t) size + 1); + tmp_str = (char *)PyMem_Malloc((size_t)size + 1); if (!tmp_str) { return PyErr_NoMemory(); } - memcpy(tmp_str, s, (size_t) size); - s = tmp_str; *(s + size) = '\0'; - tmp_str = (char *) PQunescapeBytea((unsigned char*) s, &str_len); + memcpy(tmp_str, s, (size_t)size); + s = tmp_str; + *(s + size) = '\0'; + tmp_str = (char *)PQunescapeBytea((unsigned char *)s, &str_len); PyMem_Free(s); - if (!tmp_str) return PyErr_NoMemory(); - obj = PyBytes_FromStringAndSize(tmp_str, (Py_ssize_t) str_len); + if (!tmp_str) + return PyErr_NoMemory(); + obj = PyBytes_FromStringAndSize(tmp_str, (Py_ssize_t)str_len); if (tmp_str) { PQfreemem(tmp_str); } @@ -240,18 +245,17 @@ cast_sized_text(char *s, Py_ssize_t size, int encoding, int type) /* this type should only be passed when jsondecode is set */ obj = get_decoded_string(s, size, encoding); if (obj && jsondecode) { /* was able to decode */ - tmp_obj = Py_BuildValue("(O)", obj); - obj = PyObject_CallObject(jsondecode, tmp_obj); + tmp_obj = obj; + obj = PyObject_CallFunction(jsondecode, "(O)", obj); Py_DECREF(tmp_obj); } break; - default: /* PYGRES_TEXT */ -#if IS_PY3 + default: /* PYGRES_TEXT */ obj = get_decoded_string(s, size, encoding); - if (!obj) /* cannot decode */ -#endif - obj = PyBytes_FromStringAndSize(s, size); + if (!obj) { /* cannot decode */ + obj = PyBytes_FromStringAndSize(s, size); + } } return obj; @@ -289,20 +293,20 @@ cast_sized_simple(char *s, Py_ssize_t size, int type) case PYGRES_INT: n = sizeof(buf) / sizeof(buf[0]) - 1; - if ((int) size < n) { - n = (int) size; + if ((int)size < n) { + n = (int)size; } for (i = 0, t = buf; i < n; ++i) { *t++ = *s++; } *t = '\0'; - obj = PyInt_FromString(buf, NULL, 10); + obj = PyLong_FromString(buf, NULL, 10); break; case PYGRES_LONG: n = sizeof(buf) / sizeof(buf[0]) - 1; - if ((int) size < n) { - n = (int) size; + if ((int)size < n) { + n = (int)size; } for (i = 
0, t = buf; i < n; ++i) { *t++ = *s++; @@ -312,7 +316,7 @@ cast_sized_simple(char *s, Py_ssize_t size, int type) break; case PYGRES_FLOAT: - tmp_obj = PyStr_FromStringAndSize(s, size); + tmp_obj = PyUnicode_FromStringAndSize(s, size); obj = PyFloat_FromString(tmp_obj); Py_DECREF(tmp_obj); break; @@ -336,24 +340,24 @@ cast_sized_simple(char *s, Py_ssize_t size, int type) obj = PyObject_CallFunction(decimal, "(s)", buf); } else { - tmp_obj = PyStr_FromString(buf); + tmp_obj = PyUnicode_FromString(buf); obj = PyFloat_FromString(tmp_obj); Py_DECREF(tmp_obj); - } break; case PYGRES_DECIMAL: - tmp_obj = PyStr_FromStringAndSize(s, size); - obj = decimal ? PyObject_CallFunctionObjArgs( - decimal, tmp_obj, NULL) : PyFloat_FromString(tmp_obj); + tmp_obj = PyUnicode_FromStringAndSize(s, size); + obj = decimal + ? PyObject_CallFunctionObjArgs(decimal, tmp_obj, NULL) + : PyFloat_FromString(tmp_obj); Py_DECREF(tmp_obj); break; case PYGRES_BOOL: /* convert to bool only if bool_as_text is not set */ if (bool_as_text) { - obj = PyStr_FromString(*s == 't' ? "t" : "f"); + obj = PyUnicode_FromString(*s == 't' ? "t" : "f"); } else { obj = *s == 't' ? Py_True : Py_False; @@ -363,7 +367,7 @@ cast_sized_simple(char *s, Py_ssize_t size, int type) default: /* other types should never be passed, use cast_sized_text */ - obj = PyStr_FromStringAndSize(s, size); + obj = PyUnicode_FromStringAndSize(s, size); } return obj; @@ -381,15 +385,12 @@ cast_unsized_simple(char *s, int type) switch (type) { /* this must be the PyGreSQL internal type */ case PYGRES_INT: - obj = PyInt_FromString(s, NULL, 10); - break; - case PYGRES_LONG: obj = PyLong_FromString(s, NULL, 10); break; case PYGRES_FLOAT: - tmp_obj = PyStr_FromString(s); + tmp_obj = PyUnicode_FromString(s); obj = PyFloat_FromString(tmp_obj); Py_DECREF(tmp_obj); break; @@ -408,7 +409,8 @@ cast_unsized_simple(char *s, int type) buf[j++] = '-'; } } - buf[j] = '\0'; s = buf; + buf[j] = '\0'; + s = buf; /* FALLTHROUGH */ /* no break here */ case PYGRES_DECIMAL: @@ -416,7 +418,7 @@ cast_unsized_simple(char *s, int type) obj = PyObject_CallFunction(decimal, "(s)", s); } else { - tmp_obj = PyStr_FromString(s); + tmp_obj = PyUnicode_FromString(s); obj = PyFloat_FromString(tmp_obj); Py_DECREF(tmp_obj); } @@ -425,7 +427,7 @@ cast_unsized_simple(char *s, int type) case PYGRES_BOOL: /* convert to bool only if bool_as_text is not set */ if (bool_as_text) { - obj = PyStr_FromString(*s == 't' ? "t" : "f"); + obj = PyUnicode_FromString(*s == 't' ? "t" : "f"); } else { obj = *s == 't' ? Py_True : Py_False; @@ -435,18 +437,17 @@ cast_unsized_simple(char *s, int type) default: /* other types should never be passed, use cast_sized_text */ - obj = PyStr_FromString(s); + obj = PyUnicode_FromString(s); } return obj; } /* Quick case insensitive check if given sized string is null. */ -#define STR_IS_NULL(s, n) (n == 4 && \ - (s[0] == 'n' || s[0] == 'N') && \ - (s[1] == 'u' || s[1] == 'U') && \ - (s[2] == 'l' || s[2] == 'L') && \ - (s[3] == 'l' || s[3] == 'L')) +#define STR_IS_NULL(s, n) \ + (n == 4 && (s[0] == 'n' || s[0] == 'N') && \ + (s[1] == 'u' || s[1] == 'U') && (s[2] == 'l' || s[2] == 'L') && \ + (s[3] == 'l' || s[3] == 'L')) /* Cast string s with size and encoding to a Python list, using the input and output syntax for arrays. @@ -454,8 +455,8 @@ cast_unsized_simple(char *s, int type) The parameter delim specifies the delimiter for the elements, since some types do not use the default delimiter of a comma. 
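
   For illustration, a few inputs and the Python values they parse to
   (the element types shown assume default integer/text casting; the
   exact types depend on the type and cast arguments):

       '{1,2,3}'             ->  [1, 2, 3]
       '{{1,2},{3,4}}'       ->  [[1, 2], [3, 4]]
       '[1:2]={NULL,"a b"}'  ->  [None, 'a b']   (the dimension prefix
                                                  is validated, then
                                                  skipped)
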
*/ static PyObject * -cast_array(char *s, Py_ssize_t size, int encoding, - int type, PyObject *cast, char delim) +cast_array(char *s, Py_ssize_t size, int encoding, int type, PyObject *cast, + char delim) { PyObject *result, *stack[MAX_ARRAY_DEPTH]; char *end = s + size, *t; @@ -463,12 +464,13 @@ cast_array(char *s, Py_ssize_t size, int encoding, if (type) { type &= ~PYGRES_ARRAY; /* get the base type */ - if (!type) type = PYGRES_TEXT; + if (!type) + type = PYGRES_TEXT; } if (!delim) { delim = ','; } - else if (delim == '{' || delim =='}' || delim=='\\') { + else if (delim == '{' || delim == '}' || delim == '\\') { PyErr_SetString(PyExc_ValueError, "Invalid array delimiter"); return NULL; } @@ -479,20 +481,28 @@ cast_array(char *s, Py_ssize_t size, int encoding, int valid; for (valid = 0; !valid;) { - if (s == end || *s++ != '[') break; + if (s == end || *s++ != '[') + break; while (s != end && *s == ' ') ++s; - if (s != end && (*s == '+' || *s == '-')) ++s; - if (s == end || *s < '0' || *s > '9') break; + if (s != end && (*s == '+' || *s == '-')) + ++s; + if (s == end || *s < '0' || *s > '9') + break; while (s != end && *s >= '0' && *s <= '9') ++s; - if (s == end || *s++ != ':') break; - if (s != end && (*s == '+' || *s == '-')) ++s; - if (s == end || *s < '0' || *s > '9') break; + if (s == end || *s++ != ':') + break; + if (s != end && (*s == '+' || *s == '-')) + ++s; + if (s == end || *s < '0' || *s > '9') + break; while (s != end && *s >= '0' && *s <= '9') ++s; - if (s == end || *s++ != ']') break; + if (s == end || *s++ != ']') + break; while (s != end && *s == ' ') ++s; ++ranges; if (s != end && *s == '=') { - do ++s; while (s != end && *s == ' '); + do ++s; + while (s != end && *s == ' '); valid = 1; } } @@ -502,7 +512,8 @@ cast_array(char *s, Py_ssize_t size, int encoding, } } for (t = s, depth = 0; t != end && (*t == '{' || *t == ' '); ++t) { - if (*t == '{') ++depth; + if (*t == '{') + ++depth; } if (!depth) { PyErr_SetString(PyExc_ValueError, @@ -520,30 +531,40 @@ cast_array(char *s, Py_ssize_t size, int encoding, } depth--; /* next level of parsing */ result = PyList_New(0); - if (!result) return NULL; - do ++s; while (s != end && *s == ' '); + if (!result) + return NULL; + do ++s; + while (s != end && *s == ' '); /* everything is set up, start parsing the array */ while (s != end) { if (*s == '}') { PyObject *subresult; - if (!level) break; /* top level array ended */ - do ++s; while (s != end && *s == ' '); - if (s == end) break; /* error */ + if (!level) + break; /* top level array ended */ + do ++s; + while (s != end && *s == ' '); + if (s == end) + break; /* error */ if (*s == delim) { - do ++s; while (s != end && *s == ' '); - if (s == end) break; /* error */ + do ++s; + while (s != end && *s == ' '); + if (s == end) + break; /* error */ if (*s != '{') { PyErr_SetString(PyExc_ValueError, "Subarray expected but not found"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } } - else if (*s != '}') break; /* error */ + else if (*s != '}') + break; /* error */ subresult = result; result = stack[--level]; if (PyList_Append(result, subresult)) { - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } } else if (level == depth) { /* we expect elements at this level */ @@ -555,40 +576,48 @@ cast_array(char *s, Py_ssize_t size, int encoding, if (*s == '{') { PyErr_SetString(PyExc_ValueError, "Subarray found where not expected"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } if (*s == '"') { /* quoted element */ estr 
= ++s; while (s != end && *s != '"') { if (*s == '\\') { - ++s; if (s == end) break; + ++s; + if (s == end) + break; escaped = 1; } ++s; } esize = s - estr; - do ++s; while (s != end && *s == ' '); + do ++s; + while (s != end && *s == ' '); } else { /* unquoted element */ estr = s; /* can contain blanks inside */ - while (s != end && *s != '"' && - *s != '{' && *s != '}' && *s != delim) - { + while (s != end && *s != '"' && *s != '{' && *s != '}' && + *s != delim) { if (*s == '\\') { - ++s; if (s == end) break; + ++s; + if (s == end) + break; escaped = 1; } ++s; } - t = s; while (t > estr && *(t - 1) == ' ') --t; + t = s; + while (t > estr && *(t - 1) == ' ') --t; if (!(esize = t - estr)) { - s = end; break; /* error */ + s = end; + break; /* error */ } if (STR_IS_NULL(estr, esize)) /* NULL gives None */ estr = NULL; } - if (s == end) break; /* error */ + if (s == end) + break; /* error */ if (estr) { if (escaped) { char *r; @@ -596,12 +625,14 @@ cast_array(char *s, Py_ssize_t size, int encoding, /* create unescaped string */ t = estr; - estr = (char *) PyMem_Malloc((size_t) esize); + estr = (char *)PyMem_Malloc((size_t)esize); if (!estr) { - Py_DECREF(result); return PyErr_NoMemory(); + Py_DECREF(result); + return PyErr_NoMemory(); } for (i = 0, r = estr; i < esize; ++i) { - if (*t == '\\') ++t, ++i; + if (*t == '\\') + ++t, ++i; *r++ = *t++; } esize = r - estr; @@ -613,59 +644,73 @@ cast_array(char *s, Py_ssize_t size, int encoding, element = cast_sized_simple(estr, esize, type); } else { /* external casting of base type */ -#if IS_PY3 - element = encoding == pg_encoding_ascii ? NULL : - get_decoded_string(estr, esize, encoding); - if (!element) /* no decoding necessary or possible */ -#endif - element = PyBytes_FromStringAndSize(estr, esize); + element = encoding == pg_encoding_ascii + ? 
NULL + : get_decoded_string(estr, esize, encoding); + if (!element) { /* no decoding necessary or possible */ + element = PyBytes_FromStringAndSize(estr, esize); + } if (element && cast) { PyObject *tmp = element; - element = PyObject_CallFunctionObjArgs( - cast, element, NULL); + element = + PyObject_CallFunctionObjArgs(cast, element, NULL); Py_DECREF(tmp); } } - if (escaped) PyMem_Free(estr); + if (escaped) + PyMem_Free(estr); if (!element) { - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } } else { - Py_INCREF(Py_None); element = Py_None; + Py_INCREF(Py_None); + element = Py_None; } if (PyList_Append(result, element)) { - Py_DECREF(element); Py_DECREF(result); return NULL; + Py_DECREF(element); + Py_DECREF(result); + return NULL; } Py_DECREF(element); if (*s == delim) { - do ++s; while (s != end && *s == ' '); - if (s == end) break; /* error */ + do ++s; + while (s != end && *s == ' '); + if (s == end) + break; /* error */ } - else if (*s != '}') break; /* error */ + else if (*s != '}') + break; /* error */ } else { /* we expect arrays at this level */ if (*s != '{') { PyErr_SetString(PyExc_ValueError, "Subarray must start with a left brace"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } - do ++s; while (s != end && *s == ' '); - if (s == end) break; /* error */ + do ++s; + while (s != end && *s == ' '); + if (s == end) + break; /* error */ stack[level++] = result; - if (!(result = PyList_New(0))) return NULL; + if (!(result = PyList_New(0))) + return NULL; } } if (s == end || *s != '}') { - PyErr_SetString(PyExc_ValueError, - "Unexpected end of array"); - Py_DECREF(result); return NULL; + PyErr_SetString(PyExc_ValueError, "Unexpected end of array"); + Py_DECREF(result); + return NULL; } - do ++s; while (s != end && *s == ' '); + do ++s; + while (s != end && *s == ' '); if (s != end) { PyErr_SetString(PyExc_ValueError, "Unexpected characters after end of array"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } return result; } @@ -677,8 +722,8 @@ cast_array(char *s, Py_ssize_t size, int encoding, The parameter delim can specify a delimiter for the elements, although composite types always use a comma as delimiter. 
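
   For illustration, one composite value and the Python tuple it parses
   to (per-column types come from the type and cast arguments; the last
   column is shown uncast):

       '(42,"a ""b""",,t)'  ->  (42, 'a "b"', None, 't')

   The empty field becomes None, and doubled quotes inside a quoted
   field collapse to a single quote.
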
*/ static PyObject * -cast_record(char *s, Py_ssize_t size, int encoding, - int *type, PyObject *cast, Py_ssize_t len, char delim) +cast_record(char *s, Py_ssize_t size, int encoding, int *type, PyObject *cast, + Py_ssize_t len, char delim) { PyObject *result, *ret; char *end = s + size, *t; @@ -687,7 +732,7 @@ cast_record(char *s, Py_ssize_t size, int encoding, if (!delim) { delim = ','; } - else if (delim == '(' || delim ==')' || delim=='\\') { + else if (delim == '(' || delim == ')' || delim == '\\') { PyErr_SetString(PyExc_ValueError, "Invalid record delimiter"); return NULL; } @@ -700,14 +745,16 @@ cast_record(char *s, Py_ssize_t size, int encoding, return NULL; } result = PyList_New(0); - if (!result) return NULL; + if (!result) + return NULL; i = 0; /* everything is set up, start parsing the record */ while (++s != end) { PyObject *element; if (*s == ')' || *s == delim) { - Py_INCREF(Py_None); element = Py_None; + Py_INCREF(Py_None); + element = Py_None; } else { char *estr; @@ -716,32 +763,40 @@ cast_record(char *s, Py_ssize_t size, int encoding, estr = s; quoted = *s == '"'; - if (quoted) ++s; + if (quoted) + ++s; esize = 0; while (s != end) { if (!quoted && (*s == ')' || *s == delim)) break; if (*s == '"') { - ++s; if (s == end) break; + ++s; + if (s == end) + break; if (!(quoted && *s == '"')) { - quoted = !quoted; continue; + quoted = !quoted; + continue; } } if (*s == '\\') { - ++s; if (s == end) break; + ++s; + if (s == end) + break; } ++s, ++esize; } - if (s == end) break; /* error */ + if (s == end) + break; /* error */ if (estr + esize != s) { char *r; escaped = 1; /* create unescaped string */ t = estr; - estr = (char *) PyMem_Malloc((size_t) esize); + estr = (char *)PyMem_Malloc((size_t)esize); if (!estr) { - Py_DECREF(result); return PyErr_NoMemory(); + Py_DECREF(result); + return PyErr_NoMemory(); } quoted = 0; r = estr; @@ -749,10 +804,12 @@ cast_record(char *s, Py_ssize_t size, int encoding, if (*t == '"') { ++t; if (!(quoted && *t == '"')) { - quoted = !quoted; continue; + quoted = !quoted; + continue; } } - if (*t == '\\') ++t; + if (*t == '\\') + ++t; *r++ = *t++; } } @@ -760,20 +817,20 @@ cast_record(char *s, Py_ssize_t size, int encoding, int etype = type[i]; if (etype & PYGRES_ARRAY) - element = cast_array( - estr, esize, encoding, etype, NULL, 0); + element = + cast_array(estr, esize, encoding, etype, NULL, 0); else if (etype & PYGRES_TEXT) element = cast_sized_text(estr, esize, encoding, etype); else element = cast_sized_simple(estr, esize, etype); } else { /* external casting of base type */ -#if IS_PY3 - element = encoding == pg_encoding_ascii ? NULL : - get_decoded_string(estr, esize, encoding); - if (!element) /* no decoding necessary or possible */ -#endif - element = PyBytes_FromStringAndSize(estr, esize); + element = encoding == pg_encoding_ascii + ? 
NULL + : get_decoded_string(estr, esize, encoding); + if (!element) { /* no decoding necessary or possible */ + element = PyBytes_FromStringAndSize(estr, esize); + } if (element && cast) { if (len) { PyObject *ecast = PySequence_GetItem(cast, i); @@ -787,46 +844,58 @@ cast_record(char *s, Py_ssize_t size, int encoding, } } else { - Py_DECREF(element); element = NULL; + Py_DECREF(element); + element = NULL; } } else { PyObject *tmp = element; - element = PyObject_CallFunctionObjArgs( - cast, element, NULL); + element = + PyObject_CallFunctionObjArgs(cast, element, NULL); Py_DECREF(tmp); } } } - if (escaped) PyMem_Free(estr); + if (escaped) + PyMem_Free(estr); if (!element) { - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } } if (PyList_Append(result, element)) { - Py_DECREF(element); Py_DECREF(result); return NULL; + Py_DECREF(element); + Py_DECREF(result); + return NULL; } Py_DECREF(element); - if (len) ++i; - if (*s != delim) break; /* no next record */ + if (len) + ++i; + if (*s != delim) + break; /* no next record */ if (len && i >= len) { PyErr_SetString(PyExc_ValueError, "Too many columns"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } } if (s == end || *s != ')') { PyErr_SetString(PyExc_ValueError, "Unexpected end of record"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } - do ++s; while (s != end && *s == ' '); + do ++s; + while (s != end && *s == ' '); if (s != end) { PyErr_SetString(PyExc_ValueError, "Unexpected characters after end of record"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } if (len && i < len) { PyErr_SetString(PyExc_ValueError, "Too few columns"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } ret = PyList_AsTuple(result); @@ -852,94 +921,116 @@ cast_hstore(char *s, Py_ssize_t size, int encoding) int quoted; while (s != end && *s == ' ') ++s; - if (s == end) break; + if (s == end) + break; quoted = *s == '"'; if (quoted) { key = ++s; while (s != end) { - if (*s == '"') break; + if (*s == '"') + break; if (*s == '\\') { - if (++s == end) break; + if (++s == end) + break; ++key_esc; } ++s; } if (s == end) { PyErr_SetString(PyExc_ValueError, "Unterminated quote"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } } else { key = s; while (s != end) { - if (*s == '=' || *s == ' ') break; + if (*s == '=' || *s == ' ') + break; if (*s == '\\') { - if (++s == end) break; + if (++s == end) + break; ++key_esc; } ++s; } if (s == key) { PyErr_SetString(PyExc_ValueError, "Missing key"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } } size = s - key - key_esc; if (key_esc) { char *r = key, *t; - key = (char *) PyMem_Malloc((size_t) size); + key = (char *)PyMem_Malloc((size_t)size); if (!key) { - Py_DECREF(result); return PyErr_NoMemory(); + Py_DECREF(result); + return PyErr_NoMemory(); } t = key; while (r != s) { if (*r == '\\') { - ++r; if (r == s) break; + ++r; + if (r == s) + break; } *t++ = *r++; } } key_obj = cast_sized_text(key, size, encoding, PYGRES_TEXT); - if (key_esc) PyMem_Free(key); + if (key_esc) + PyMem_Free(key); if (!key_obj) { - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } - if (quoted) ++s; + if (quoted) + ++s; while (s != end && *s == ' ') ++s; if (s == end || *s++ != '=' || s == end || *s++ != '>') { PyErr_SetString(PyExc_ValueError, "Invalid characters after key"); - Py_DECREF(key_obj); Py_DECREF(result); return NULL; + Py_DECREF(key_obj); + 
Py_DECREF(result); + return NULL; } while (s != end && *s == ' ') ++s; quoted = *s == '"'; if (quoted) { val = ++s; while (s != end) { - if (*s == '"') break; + if (*s == '"') + break; if (*s == '\\') { - if (++s == end) break; + if (++s == end) + break; ++val_esc; } ++s; } if (s == end) { PyErr_SetString(PyExc_ValueError, "Unterminated quote"); - Py_DECREF(result); return NULL; + Py_DECREF(result); + return NULL; } } else { val = s; while (s != end) { - if (*s == ',' || *s == ' ') break; + if (*s == ',' || *s == ' ') + break; if (*s == '\\') { - if (++s == end) break; + if (++s == end) + break; ++val_esc; } ++s; } if (s == val) { PyErr_SetString(PyExc_ValueError, "Missing value"); - Py_DECREF(key_obj); Py_DECREF(result); return NULL; + Py_DECREF(key_obj); + Py_DECREF(result); + return NULL; } if (STR_IS_NULL(val, s - val)) val = NULL; @@ -948,46 +1039,59 @@ cast_hstore(char *s, Py_ssize_t size, int encoding) size = s - val - val_esc; if (val_esc) { char *r = val, *t; - val = (char *) PyMem_Malloc((size_t) size); + val = (char *)PyMem_Malloc((size_t)size); if (!val) { - Py_DECREF(key_obj); Py_DECREF(result); + Py_DECREF(key_obj); + Py_DECREF(result); return PyErr_NoMemory(); } t = val; while (r != s) { if (*r == '\\') { - ++r; if (r == s) break; + ++r; + if (r == s) + break; } *t++ = *r++; } } val_obj = cast_sized_text(val, size, encoding, PYGRES_TEXT); - if (val_esc) PyMem_Free(val); + if (val_esc) + PyMem_Free(val); if (!val_obj) { - Py_DECREF(key_obj); Py_DECREF(result); return NULL; + Py_DECREF(key_obj); + Py_DECREF(result); + return NULL; } } else { - Py_INCREF(Py_None); val_obj = Py_None; + Py_INCREF(Py_None); + val_obj = Py_None; } - if (quoted) ++s; + if (quoted) + ++s; while (s != end && *s == ' ') ++s; if (s != end) { if (*s++ != ',') { PyErr_SetString(PyExc_ValueError, "Invalid characters after val"); - Py_DECREF(key_obj); Py_DECREF(val_obj); - Py_DECREF(result); return NULL; + Py_DECREF(key_obj); + Py_DECREF(val_obj); + Py_DECREF(result); + return NULL; } while (s != end && *s == ' ') ++s; if (s == end) { PyErr_SetString(PyExc_ValueError, "Missing entry"); - Py_DECREF(key_obj); Py_DECREF(val_obj); - Py_DECREF(result); return NULL; + Py_DECREF(key_obj); + Py_DECREF(val_obj); + Py_DECREF(result); + return NULL; } } PyDict_SetItem(result, key_obj, val_obj); - Py_DECREF(key_obj); Py_DECREF(val_obj); + Py_DECREF(key_obj); + Py_DECREF(val_obj); } return result; } @@ -1060,25 +1164,24 @@ get_error_type(const char *sqlstate) /* Set database error message and sqlstate attribute. */ static void -set_error_msg_and_state(PyObject *type, - const char *msg, int encoding, const char *sqlstate) +set_error_msg_and_state(PyObject *type, const char *msg, int encoding, + const char *sqlstate) { PyObject *err_obj, *msg_obj, *sql_obj = NULL; -#if IS_PY3 if (encoding == -1) /* unknown */ msg_obj = PyUnicode_DecodeLocale(msg, NULL); else - msg_obj = get_decoded_string(msg, (Py_ssize_t) strlen(msg), encoding); + msg_obj = get_decoded_string(msg, (Py_ssize_t)strlen(msg), encoding); if (!msg_obj) /* cannot decode */ -#endif - msg_obj = PyBytes_FromString(msg); + msg_obj = PyBytes_FromString(msg); if (sqlstate) { - sql_obj = PyStr_FromStringAndSize(sqlstate, 5); + sql_obj = PyUnicode_FromStringAndSize(sqlstate, 5); } else { - Py_INCREF(Py_None); sql_obj = Py_None; + Py_INCREF(Py_None); + sql_obj = Py_None; } err_obj = PyObject_CallFunctionObjArgs(type, msg_obj, NULL); @@ -1103,7 +1206,7 @@ set_error_msg(PyObject *type, const char *msg) /* Set database error from connection and/or result. 
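
   When a result is passed in, its SQLSTATE (if any) is looked up via
   get_error_type() so that the generic type is upgraded to the matching
   DB-API exception class before message and sqlstate are set. A hedged
   Python-level sketch of the observable effect (table and values are
   illustrative):

       try:
           cnx.query("insert into t values (1)")  # duplicate key
       except pg.IntegrityError as error:         # SQLSTATE class '23'
           print(error.sqlstate)                  # e.g. '23505'
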
*/ static void -set_error(PyObject *type, const char * msg, PGconn *cnx, PGresult *result) +set_error(PyObject *type, const char *msg, PGconn *cnx, PGresult *result) { char *sqlstate = NULL; int encoding = pg_encoding_ascii; @@ -1117,19 +1220,19 @@ set_error(PyObject *type, const char * msg, PGconn *cnx, PGresult *result) } if (result) { sqlstate = PQresultErrorField(result, PG_DIAG_SQLSTATE); - if (sqlstate) type = get_error_type(sqlstate); + if (sqlstate) + type = get_error_type(sqlstate); } set_error_msg_and_state(type, msg, encoding, sqlstate); } -#ifdef SSL_INFO - /* Get SSL attributes and values as a dictionary. */ static PyObject * -get_ssl_attributes(PGconn *cnx) { +get_ssl_attributes(PGconn *cnx) +{ PyObject *attr_dict = NULL; - const char * const *s; + const char *const *s; if (!(attr_dict = PyDict_New())) { return NULL; @@ -1139,7 +1242,7 @@ get_ssl_attributes(PGconn *cnx) { const char *val = PQsslAttribute(cnx, *s); if (val) { - PyObject * val_obj = PyStr_FromString(val); + PyObject *val_obj = PyUnicode_FromString(val); PyDict_SetItemString(attr_dict, *s, val_obj); Py_DECREF(val_obj); @@ -1152,8 +1255,6 @@ get_ssl_attributes(PGconn *cnx) { return attr_dict; } -#endif /* SSL_INFO */ - /* Format result (mostly useful for debugging). Note: This is similar to the Postgres function PQprint(). PQprint() is not used because handing over a stream from Python to @@ -1165,10 +1266,10 @@ format_result(const PGresult *res) const int n = PQnfields(res); if (n > 0) { - char * const aligns = (char *) PyMem_Malloc( - (unsigned int) n * sizeof(char)); - size_t * const sizes = (size_t *) PyMem_Malloc( - (unsigned int) n * sizeof(size_t)); + char *const aligns = + (char *)PyMem_Malloc((unsigned int)n * sizeof(char)); + size_t *const sizes = + (size_t *)PyMem_Malloc((unsigned int)n * sizeof(size_t)); if (aligns && sizes) { const int m = PQntuples(res); @@ -1178,7 +1279,7 @@ format_result(const PGresult *res) /* calculate sizes and alignments */ for (j = 0; j < n; ++j) { - const char * const s = PQfname(res, j); + const char *const s = PQfname(res, j); const int format = PQfformat(res, j); sizes[j] = s ? strlen(s) : 0; @@ -1214,9 +1315,9 @@ format_result(const PGresult *res) if (aligns[j]) { const int k = PQgetlength(res, i, j); - if (sizes[j] < (size_t) k) + if (sizes[j] < (size_t)k) /* value must fit */ - sizes[j] = (size_t) k; + sizes[j] = (size_t)k; } } } @@ -1224,23 +1325,23 @@ format_result(const PGresult *res) /* size of one row */ for (j = 0; j < n; ++j) size += sizes[j] + 1; /* times number of rows incl. 
heading */ - size *= (size_t) m + 2; + size *= (size_t)m + 2; /* plus size of footer */ size += 40; /* is the buffer size that needs to be allocated */ - buffer = (char *) PyMem_Malloc(size); + buffer = (char *)PyMem_Malloc(size); if (buffer) { char *p = buffer; PyObject *result; /* create the header */ for (j = 0; j < n; ++j) { - const char * const s = PQfname(res, j); + const char *const s = PQfname(res, j); const size_t k = sizes[j]; - const size_t h = (k - (size_t) strlen(s)) / 2; + const size_t h = (k - (size_t)strlen(s)) / 2; - sprintf(p, "%*s", (int) h, ""); - sprintf(p + h, "%-*s", (int) (k - h), s); + sprintf(p, "%*s", (int)h, ""); + sprintf(p + h, "%-*s", (int)(k - h), s); p += k; if (j + 1 < n) *p++ = '|'; @@ -1249,8 +1350,7 @@ format_result(const PGresult *res) for (j = 0; j < n; ++j) { size_t k = sizes[j]; - while (k--) - *p++ = '-'; + while (k--) *p++ = '-'; if (j + 1 < n) *p++ = '+'; } @@ -1262,11 +1362,11 @@ format_result(const PGresult *res) const size_t k = sizes[j]; if (align) { - sprintf(p, align == 'r' ? "%*s" : "%-*s", (int) k, + sprintf(p, align == 'r' ? "%*s" : "%-*s", (int)k, PQgetvalue(res, i, j)); } else { - sprintf(p, "%-*s", (int) k, + sprintf(p, "%-*s", (int)k, PQgetisnull(res, i, j) ? "" : ""); } p += k; @@ -1276,52 +1376,60 @@ format_result(const PGresult *res) *p++ = '\n'; } /* free memory */ - PyMem_Free(aligns); PyMem_Free(sizes); + PyMem_Free(aligns); + PyMem_Free(sizes); /* create the footer */ sprintf(p, "(%d row%s)", m, m == 1 ? "" : "s"); /* return the result */ - result = PyStr_FromString(buffer); + result = PyUnicode_FromString(buffer); PyMem_Free(buffer); return result; } else { - PyMem_Free(aligns); PyMem_Free(sizes); return PyErr_NoMemory(); + PyMem_Free(aligns); + PyMem_Free(sizes); + return PyErr_NoMemory(); } } else { - PyMem_Free(aligns); PyMem_Free(sizes); return PyErr_NoMemory(); + PyMem_Free(aligns); + PyMem_Free(sizes); + return PyErr_NoMemory(); } } else - return PyStr_FromString("(nothing selected)"); + return PyUnicode_FromString("(nothing selected)"); } /* Internal function converting a Postgres datestyles to date formats. */ static const char * date_style_to_format(const char *s) { - static const char *formats[] = - { - "%Y-%m-%d", /* 0 = ISO */ - "%m-%d-%Y", /* 1 = Postgres, MDY */ - "%d-%m-%Y", /* 2 = Postgres, DMY */ - "%m/%d/%Y", /* 3 = SQL, MDY */ - "%d/%m/%Y", /* 4 = SQL, DMY */ - "%d.%m.%Y" /* 5 = German */ + static const char *formats[] = { + "%Y-%m-%d", /* 0 = ISO */ + "%m-%d-%Y", /* 1 = Postgres, MDY */ + "%d-%m-%Y", /* 2 = Postgres, DMY */ + "%m/%d/%Y", /* 3 = SQL, MDY */ + "%d/%m/%Y", /* 4 = SQL, DMY */ + "%d.%m.%Y" /* 5 = German */ }; switch (s ? *s : 'I') { case 'P': /* Postgres */ s = strchr(s + 1, ','); - if (s) do ++s; while (*s && *s == ' '); + if (s) + do ++s; + while (*s && *s == ' '); return formats[s && *s == 'D' ? 2 : 1]; case 'S': /* SQL */ s = strchr(s + 1, ','); - if (s) do ++s; while (*s && *s == ' '); + if (s) + do ++s; + while (*s && *s == ' '); return formats[s && *s == 'D' ? 
4 : 3]; case 'G': /* German */ return formats[5]; - default: /* ISO */ + default: /* ISO */ return formats[0]; /* ISO is the default */ } } @@ -1330,14 +1438,13 @@ date_style_to_format(const char *s) static const char * date_format_to_style(const char *s) { - static const char *datestyle[] = - { - "ISO, YMD", /* 0 = %Y-%m-%d */ - "Postgres, MDY", /* 1 = %m-%d-%Y */ - "Postgres, DMY", /* 2 = %d-%m-%Y */ - "SQL, MDY", /* 3 = %m/%d/%Y */ - "SQL, DMY", /* 4 = %d/%m/%Y */ - "German, DMY" /* 5 = %d.%m.%Y */ + static const char *datestyle[] = { + "ISO, YMD", /* 0 = %Y-%m-%d */ + "Postgres, MDY", /* 1 = %m-%d-%Y */ + "Postgres, DMY", /* 2 = %d-%m-%Y */ + "SQL, MDY", /* 3 = %m/%d/%Y */ + "SQL, DMY", /* 4 = %d/%m/%Y */ + "German, DMY" /* 5 = %d.%m.%Y */ }; switch (s ? s[1] : 'Y') { @@ -1367,7 +1474,7 @@ static void notice_receiver(void *arg, const PGresult *res) { PyGILState_STATE gstate = PyGILState_Ensure(); - connObject *self = (connObject*) arg; + connObject *self = (connObject *)arg; PyObject *func = self->notice_receiver; if (func) { @@ -1379,7 +1486,7 @@ notice_receiver(void *arg, const PGresult *res) } else { Py_INCREF(Py_None); - notice = (noticeObject *)(void *) Py_None; + notice = (noticeObject *)(void *)Py_None; } ret = PyObject_CallFunction(func, "(O)", notice); Py_XDECREF(ret); diff --git a/pglarge.c b/ext/pglarge.c similarity index 63% rename from pglarge.c rename to ext/pglarge.c index 4f9e0f3e..1b817b25 100644 --- a/pglarge.c +++ b/ext/pglarge.c @@ -3,7 +3,7 @@ * * Large object support - this file is part a of the C extension module. * - * Copyright (c) 2020 by the PyGreSQL Development Team + * Copyright (c) 2025 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ @@ -12,8 +12,12 @@ static void large_dealloc(largeObject *self) { - if (self->lo_fd >= 0 && self->pgcnx->valid) - lo_close(self->pgcnx->cnx, self->lo_fd); + /* Note: We do not try to close the large object here anymore, + since the server automatically closes it at the end of the + transaction in which it was created. So the object might already + be closed, which will then cause error messages on the server. + In other situations we might close the object too early here + if the Python object falls out of scope but is still needed. */ Py_XDECREF(self->pgcnx); PyObject_Del(self); @@ -24,10 +28,11 @@ static PyObject * large_str(largeObject *self) { char str[80]; - sprintf(str, self->lo_fd >= 0 ? - "Opened large object, oid %ld" : - "Closed large object, oid %ld", (long) self->lo_oid); - return PyStr_FromString(str); + sprintf(str, + self->lo_fd >= 0 ? "Opened large object, oid %ld" + : "Closed large object, oid %ld", + (long)self->lo_oid); + return PyUnicode_FromString(str); } /* Check validity of large object. 
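
   The level argument presumably combines the CHECK_* flags defined in
   pgmodule.c (CHECK_OPEN to require an open object, CHECK_CLOSE to
   require a closed one, plus the basic connection check), so the
   methods below can share one guard before touching lo_fd.
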
*/ @@ -63,7 +68,7 @@ _check_lo_obj(largeObject *self, int level) static PyObject * large_getattr(largeObject *self, PyObject *nameobj) { - const char *name = PyStr_AsString(nameobj); + const char *name = PyUnicode_AsUTF8(nameobj); /* list postgreSQL large object fields */ @@ -71,7 +76,7 @@ large_getattr(largeObject *self, PyObject *nameobj) if (!strcmp(name, "pgcnx")) { if (_check_lo_obj(self, 0)) { Py_INCREF(self->pgcnx); - return (PyObject *) (self->pgcnx); + return (PyObject *)(self->pgcnx); } PyErr_Clear(); Py_INCREF(Py_None); @@ -81,7 +86,7 @@ large_getattr(largeObject *self, PyObject *nameobj) /* large object oid */ if (!strcmp(name, "oid")) { if (_check_lo_obj(self, 0)) - return PyInt_FromLong(self->lo_oid); + return PyLong_FromLong((long)self->lo_oid); PyErr_Clear(); Py_INCREF(Py_None); return Py_None; @@ -89,10 +94,10 @@ large_getattr(largeObject *self, PyObject *nameobj) /* error (status) message */ if (!strcmp(name, "error")) - return PyStr_FromString(PQerrorMessage(self->pgcnx->cnx)); + return PyUnicode_FromString(PQerrorMessage(self->pgcnx->cnx)); /* seeks name in methods (fallback) */ - return PyObject_GenericGetAttr((PyObject *) self, nameobj); + return PyObject_GenericGetAttr((PyObject *)self, nameobj); } /* Get the list of large object attributes. */ @@ -101,17 +106,16 @@ large_dir(largeObject *self, PyObject *noargs) { PyObject *attrs; - attrs = PyObject_Dir(PyObject_Type((PyObject *) self)); - PyObject_CallMethod( - attrs, "extend", "[sss]", "oid", "pgcnx", "error"); + attrs = PyObject_Dir(PyObject_Type((PyObject *)self)); + PyObject_CallMethod(attrs, "extend", "[sss]", "oid", "pgcnx", "error"); return attrs; } /* Open large object. */ static char large_open__doc__[] = -"open(mode) -- open access to large object with specified mode\n\n" -"The mode must be one of INV_READ, INV_WRITE (module level constants).\n"; + "open(mode) -- open access to large object with specified mode\n\n" + "The mode must be one of INV_READ, INV_WRITE (module level constants).\n"; static PyObject * large_open(largeObject *self, PyObject *args) @@ -144,7 +148,7 @@ large_open(largeObject *self, PyObject *args) /* Close large object. */ static char large_close__doc__[] = -"close() -- close access to large object data"; + "close() -- close access to large object data"; static PyObject * large_close(largeObject *self, PyObject *noargs) @@ -168,8 +172,8 @@ large_close(largeObject *self, PyObject *noargs) /* Read from large object. */ static char large_read__doc__[] = -"read(size) -- read from large object to sized string\n\n" -"Object must be opened in read mode before calling this method.\n"; + "read(size) -- read from large object to sized string\n\n" + "Object must be opened in read mode before calling this method.\n"; static PyObject * large_read(largeObject *self, PyObject *args) @@ -196,11 +200,11 @@ large_read(largeObject *self, PyObject *args) } /* allocate buffer and runs read */ - buffer = PyBytes_FromStringAndSize((char *) NULL, size); + buffer = PyBytes_FromStringAndSize((char *)NULL, size); if ((size = lo_read(self->pgcnx->cnx, self->lo_fd, - PyBytes_AS_STRING((PyBytesObject *) (buffer)), (size_t) size)) == -1) - { + PyBytes_AS_STRING((PyBytesObject *)(buffer)), + (size_t)size)) == -1) { PyErr_SetString(PyExc_IOError, "Error while reading"); Py_XDECREF(buffer); return NULL; @@ -213,8 +217,8 @@ large_read(largeObject *self, PyObject *args) /* Write to large object. 
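
   Reads and writes need the object opened first, and the descriptor is
   only valid inside the surrounding transaction (see the dealloc note
   above). A minimal sketch (assuming cnx is an open pg.Connection):

       cnx.query("begin")
       lo = cnx.locreate(pg.INV_READ | pg.INV_WRITE)
       lo.open(pg.INV_WRITE)
       lo.write(b'binary data')
       lo.close()
       cnx.query("end")
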
*/ static char large_write__doc__[] = -"write(string) -- write sized string to large object\n\n" -"Object must be opened in read mode before calling this method.\n"; + "write(string) -- write sized string to large object\n\n" + "Object must be opened in write mode before calling this method.\n"; static PyObject * large_write(largeObject *self, PyObject *args) @@ -237,8 +241,7 @@ large_write(largeObject *self, PyObject *args) /* sends query */ if ((size = lo_write(self->pgcnx->cnx, self->lo_fd, buffer, - (size_t) bufsize)) != bufsize) - { + (size_t)bufsize)) != bufsize) { PyErr_SetString(PyExc_IOError, "Buffer truncated during write"); return NULL; } @@ -250,9 +253,9 @@ /* Go to position in large object. */ static char large_seek__doc__[] = -"seek(offset, whence) -- move to specified position\n\n" -"Object must be opened before calling this method. The whence option\n" -"can be SEEK_SET, SEEK_CUR or SEEK_END (module level constants).\n"; + "seek(offset, whence) -- move to specified position\n\n" + "Object must be opened before calling this method. The whence option\n" + "can be SEEK_SET, SEEK_CUR or SEEK_END (module level constants).\n"; static PyObject * large_seek(largeObject *self, PyObject *args) @@ -273,21 +276,20 @@ large_seek(largeObject *self, PyObject *args) } /* sends query */ - if ((ret = lo_lseek( - self->pgcnx->cnx, self->lo_fd, offset, whence)) == -1) - { + if ((ret = lo_lseek(self->pgcnx->cnx, self->lo_fd, offset, whence)) == + -1) { PyErr_SetString(PyExc_IOError, "Error while moving cursor"); return NULL; } /* returns position */ - return PyInt_FromLong(ret); + return PyLong_FromLong(ret); } /* Get large object size. */ static char large_size__doc__[] = -"size() -- return large object size\n\n" -"The object must be opened before calling this method.\n"; + "size() -- return large object size\n\n" + "The object must be opened before calling this method.\n"; static PyObject * large_size(largeObject *self, PyObject *noargs) @@ -312,22 +314,21 @@ } /* move back to start position */ - if ((start = lo_lseek( - self->pgcnx->cnx, self->lo_fd, start, SEEK_SET)) == -1) - { + if ((start = lo_lseek(self->pgcnx->cnx, self->lo_fd, start, SEEK_SET)) == + -1) { PyErr_SetString(PyExc_IOError, "Error while moving back to first position"); return NULL; } /* returns size */ - return PyInt_FromLong(end); + return PyLong_FromLong(end); } /* Get large object cursor position. */ static char large_tell__doc__[] = -"tell() -- give current position in large object\n\n" -"The object must be opened before calling this method.\n"; + "tell() -- give current position in large object\n\n" + "The object must be opened before calling this method.\n"; static PyObject * large_tell(largeObject *self, PyObject *noargs) @@ -346,13 +347,13 @@ } /* returns size */ - return PyInt_FromLong(start); + return PyLong_FromLong(start); } /* Export large object as unix file. */ static char large_export__doc__[] = -"export(filename) -- export large object data to specified file\n\n" -"The object must be closed when calling this method.\n"; + "export(filename) -- export large object data to specified file\n\n" + "The object must be closed when calling this method.\n"; static PyObject * large_export(largeObject *self, PyObject *args) @@ -383,8 +384,8 @@ /* Delete a large object. 
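
   export() above and unlink() below both require a closed object, as
   their docstrings state. A short sketch (assuming cnx is an open
   pg.Connection, oid names an existing large object, and the file name
   is illustrative):

       lo = cnx.getlo(oid)
       lo.export('/tmp/blob.bin')  # copy the contents to a local file
       lo.unlink()                 # then remove it from the database
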
*/ static char large_unlink__doc__[] = -"unlink() -- destroy large object\n\n" -"The object must be closed when calling this method.\n"; + "unlink() -- destroy large object\n\n" + "The object must be closed when calling this method.\n"; static PyObject * large_unlink(largeObject *self, PyObject *noargs) @@ -407,51 +408,49 @@ large_unlink(largeObject *self, PyObject *noargs) /* Large object methods */ static struct PyMethodDef large_methods[] = { - {"__dir__", (PyCFunction) large_dir, METH_NOARGS, NULL}, - {"open", (PyCFunction) large_open, METH_VARARGS, large_open__doc__}, - {"close", (PyCFunction) large_close, METH_NOARGS, large_close__doc__}, - {"read", (PyCFunction) large_read, METH_VARARGS, large_read__doc__}, - {"write", (PyCFunction) large_write, METH_VARARGS, large_write__doc__}, - {"seek", (PyCFunction) large_seek, METH_VARARGS, large_seek__doc__}, - {"size", (PyCFunction) large_size, METH_NOARGS, large_size__doc__}, - {"tell", (PyCFunction) large_tell, METH_NOARGS, large_tell__doc__}, - {"export",(PyCFunction) large_export, METH_VARARGS, large_export__doc__}, - {"unlink",(PyCFunction) large_unlink, METH_NOARGS, large_unlink__doc__}, - {NULL, NULL} -}; + {"__dir__", (PyCFunction)large_dir, METH_NOARGS, NULL}, + {"open", (PyCFunction)large_open, METH_VARARGS, large_open__doc__}, + {"close", (PyCFunction)large_close, METH_NOARGS, large_close__doc__}, + {"read", (PyCFunction)large_read, METH_VARARGS, large_read__doc__}, + {"write", (PyCFunction)large_write, METH_VARARGS, large_write__doc__}, + {"seek", (PyCFunction)large_seek, METH_VARARGS, large_seek__doc__}, + {"size", (PyCFunction)large_size, METH_NOARGS, large_size__doc__}, + {"tell", (PyCFunction)large_tell, METH_NOARGS, large_tell__doc__}, + {"export", (PyCFunction)large_export, METH_VARARGS, large_export__doc__}, + {"unlink", (PyCFunction)large_unlink, METH_NOARGS, large_unlink__doc__}, + {NULL, NULL}}; static char large__doc__[] = "PostgreSQL large object"; /* Large object type definition */ static PyTypeObject largeType = { - PyVarObject_HEAD_INIT(NULL, 0) - "pg.LargeObject", /* tp_name */ - sizeof(largeObject), /* tp_basicsize */ - 0, /* tp_itemsize */ + PyVarObject_HEAD_INIT(NULL, 0) "pg.LargeObject", /* tp_name */ + sizeof(largeObject), /* tp_basicsize */ + 0, /* tp_itemsize */ /* methods */ - (destructor) large_dealloc, /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_compare */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - (reprfunc) large_str, /* tp_str */ - (getattrofunc) large_getattr, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - large__doc__, /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - large_methods, /* tp_methods */ + (destructor)large_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + (reprfunc)large_str, /* tp_str */ + (getattrofunc)large_getattr, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + large__doc__, /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + 
large_methods, /* tp_methods */ }; diff --git a/pgmodule.c b/ext/pgmodule.c similarity index 57% rename from pgmodule.c rename to ext/pgmodule.c index 3a1c70be..916adda2 100644 --- a/pgmodule.c +++ b/ext/pgmodule.c @@ -3,7 +3,7 @@ * * This is the main file for the C extension module. * - * Copyright (c) 2020 by the PyGreSQL Development Team + * Copyright (c) 2025 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. */ @@ -12,20 +12,16 @@ #define PY_SSIZE_T_CLEAN #include - #include #include /* The type definitions from */ #include "pgtypes.h" -/* Macros for single-source Python 2/3 compatibility */ -#include "py3c.h" - static PyObject *Error, *Warning, *InterfaceError, *DatabaseError, - *InternalError, *OperationalError, *ProgrammingError, - *IntegrityError, *DataError, *NotSupportedError, - *InvalidResultError, *NoResultError, *MultipleResultsError; + *InternalError, *OperationalError, *ProgrammingError, *IntegrityError, + *DataError, *NotSupportedError, *InvalidResultError, *NoResultError, + *MultipleResultsError, *Connection, *Query, *LargeObject; #define _TOSTRING(x) #x #define TOSTRING(x) _TOSTRING(x) @@ -39,49 +35,48 @@ static const char *PyPgVersion = TOSTRING(PYGRESQL_VERSION); #define PG_ARRAYSIZE 1 /* Flags for object validity checks */ -#define CHECK_OPEN 1 -#define CHECK_CLOSE 2 -#define CHECK_CNX 4 +#define CHECK_OPEN 1 +#define CHECK_CLOSE 2 +#define CHECK_CNX 4 #define CHECK_RESULT 8 -#define CHECK_DQL 16 +#define CHECK_DQL 16 /* Query result types */ #define RESULT_EMPTY 1 -#define RESULT_DML 2 -#define RESULT_DDL 3 -#define RESULT_DQL 4 +#define RESULT_DML 2 +#define RESULT_DDL 3 +#define RESULT_DQL 4 /* Flags for move methods */ #define QUERY_MOVEFIRST 1 -#define QUERY_MOVELAST 2 -#define QUERY_MOVENEXT 3 -#define QUERY_MOVEPREV 4 +#define QUERY_MOVELAST 2 +#define QUERY_MOVENEXT 3 +#define QUERY_MOVEPREV 4 -#define MAX_BUFFER_SIZE 8192 /* maximum transaction size */ +#define MAX_BUFFER_SIZE 65536 /* maximum transaction size */ #define MAX_ARRAY_DEPTH 16 /* maximum allowed depth of an array */ /* MODULE GLOBAL VARIABLES */ -#ifdef DEFAULT_VARS static PyObject *pg_default_host; /* default database host */ static PyObject *pg_default_base; /* default database name */ static PyObject *pg_default_opt; /* default connection options */ static PyObject *pg_default_port; /* default connection port */ static PyObject *pg_default_user; /* default username */ static PyObject *pg_default_passwd; /* default password */ -#endif /* DEFAULT_VARS */ static PyObject *decimal = NULL, /* decimal type */ - *dictiter = NULL, /* function for getting named results */ - *namediter = NULL, /* function for getting named results */ - *namednext = NULL, /* function for getting one named result */ + *dictiter = NULL, /* function for getting dict results */ + *namediter = NULL, /* function for getting named results */ + *namednext = NULL, /* function for getting one named result */ *scalariter = NULL, /* function for getting scalar results */ - *jsondecode = NULL; /* function for decoding json strings */ + *jsondecode = + NULL; /* function for decoding json strings */ static const char *date_format = NULL; /* date format that is always assumed */ -static char decimal_point = '.'; /* decimal point used in money values */ -static int bool_as_text = 0; /* whether bool shall be returned as text */ -static int array_as_text = 0; /* whether arrays shall be returned as text */ -static int bytea_escaped = 0; /* whether bytea shall be returned escaped */ +static char 
decimal_point = '.'; /* decimal point used in money values */ +static int bool_as_text = 0; /* whether bool shall be returned as text */ +static int array_as_text = 0; /* whether arrays shall be returned as text */ +static int bytea_escaped = 0; /* whether bytea shall be returned escaped */ static int pg_encoding_utf8 = 0; static int pg_encoding_latin1 = 0; @@ -111,67 +106,57 @@ OBJECTS static PyTypeObject connType, sourceType, queryType, noticeType, largeType; /* Forward static declarations */ -static void notice_receiver(void *, const PGresult *); +static void +notice_receiver(void *, const PGresult *); /* Object declarations */ -typedef struct -{ - PyObject_HEAD - int valid; /* validity flag */ - PGconn *cnx; /* Postgres connection handle */ - const char *date_format; /* date format derived from datestyle */ - PyObject *cast_hook; /* external typecast method */ - PyObject *notice_receiver; /* current notice receiver */ -} connObject; +typedef struct { + PyObject_HEAD int valid; /* validity flag */ + PGconn *cnx; /* Postgres connection handle */ + const char *date_format; /* date format derived from datestyle */ + PyObject *cast_hook; /* external typecast method */ + PyObject *notice_receiver; /* current notice receiver */ +} connObject; #define is_connObject(v) (PyType(v) == &connType) -typedef struct -{ - PyObject_HEAD - int valid; /* validity flag */ +typedef struct { + PyObject_HEAD int valid; /* validity flag */ connObject *pgcnx; /* parent connection object */ - PGresult *result; /* result content */ - int encoding; /* client encoding */ - int result_type; /* result type (DDL/DML/DQL) */ - long arraysize; /* array size for fetch method */ - int current_row; /* currently selected row */ - int max_row; /* number of rows in the result */ - int num_fields; /* number of fields in each row */ -} sourceObject; + PGresult *result; /* result content */ + int encoding; /* client encoding */ + int result_type; /* result type (DDL/DML/DQL) */ + long arraysize; /* array size for fetch method */ + int current_row; /* currently selected row */ + int max_row; /* number of rows in the result */ + int num_fields; /* number of fields in each row */ +} sourceObject; #define is_sourceObject(v) (PyType(v) == &sourceType) -typedef struct -{ - PyObject_HEAD - connObject *pgcnx; /* parent connection object */ - PGresult const *res; /* an error or warning */ -} noticeObject; +typedef struct { + PyObject_HEAD connObject *pgcnx; /* parent connection object */ + PGresult const *res; /* an error or warning */ +} noticeObject; #define is_noticeObject(v) (PyType(v) == ¬iceType) -typedef struct -{ - PyObject_HEAD - connObject *pgcnx; /* parent connection object */ - PGresult *result; /* result content */ - int encoding; /* client encoding */ - int current_row; /* currently selected row */ - int max_row; /* number of rows in the result */ - int num_fields; /* number of fields in each row */ - int *col_types; /* PyGreSQL column types */ -} queryObject; +typedef struct { + PyObject_HEAD connObject *pgcnx; /* parent connection object */ + PGresult *result; /* result content */ + int async; /* flag for asynchronous queries */ + int encoding; /* client encoding */ + int current_row; /* currently selected row */ + int max_row; /* number of rows in the result */ + int num_fields; /* number of fields in each row */ + int *col_types; /* PyGreSQL column types */ +} queryObject; #define is_queryObject(v) (PyType(v) == &queryType) -#ifdef LARGE_OBJECTS -typedef struct -{ - PyObject_HEAD - connObject *pgcnx; /* parent 
connection object */ - Oid lo_oid; /* large object oid */ - int lo_fd; /* large object fd */ -} largeObject; +typedef struct { + PyObject_HEAD connObject *pgcnx; /* parent connection object */ + Oid lo_oid; /* large object oid */ + int lo_fd; /* large object fd */ +} largeObject; #define is_largeObject(v) (PyType(v) == &largeType) -#endif /* LARGE_OBJECTS */ /* Internal functions */ #include "pginternal.c" @@ -189,32 +174,31 @@ typedef struct #include "pgnotice.c" /* Large objects */ -#ifdef LARGE_OBJECTS #include "pglarge.c" -#endif /* MODULE FUNCTIONS */ /* Connect to a database. */ static char pg_connect__doc__[] = -"connect(dbname, host, port, opt) -- connect to a PostgreSQL database\n\n" -"The connection uses the specified parameters (optional, keywords aware).\n"; + "connect(dbname, host, port, opt, user, passwd, nowait) -- connect to a " + "PostgreSQL database\n\n" + "The connection uses the specified parameters (optional, keywords " + "aware).\n"; static PyObject * pg_connect(PyObject *self, PyObject *args, PyObject *dict) { - static const char *kwlist[] = - { - "dbname", "host", "port", "opt", "user", "passwd", NULL - }; + static const char *kwlist[] = {"dbname", "host", "port", "opt", + "user", "passwd", "nowait", NULL}; char *pghost, *pgopt, *pgdbname, *pguser, *pgpasswd; - int pgport; + int pgport = -1, nowait = 0, nkw = 0; char port_buffer[20]; + const char *keywords[sizeof(kwlist) / sizeof(*kwlist) + 1], + *values[sizeof(kwlist) / sizeof(*kwlist) + 1]; connObject *conn_obj; pghost = pgopt = pgdbname = pguser = pgpasswd = NULL; - pgport = -1; /* * parses standard arguments With the right compiler warnings, this @@ -222,20 +206,18 @@ pg_connect(PyObject *self, PyObject *args, PyObject *dict) * don't declare kwlist as const char *kwlist[] then it complains when * I try to assign all those constant strings to it. 
*/ - if (!PyArg_ParseTupleAndKeywords( - args, dict, "|zzizzz", (char**)kwlist, - &pgdbname, &pghost, &pgport, &pgopt, &pguser, &pgpasswd)) - { + if (!PyArg_ParseTupleAndKeywords(args, dict, "|zzizzzi", (char **)kwlist, + &pgdbname, &pghost, &pgport, &pgopt, + &pguser, &pgpasswd, &nowait)) { return NULL; } -#ifdef DEFAULT_VARS /* handles defaults variables (for uninitialised vars) */ if ((!pghost) && (pg_default_host != Py_None)) pghost = PyBytes_AsString(pg_default_host); if ((pgport == -1) && (pg_default_port != Py_None)) - pgport = (int) PyInt_AsLong(pg_default_port); + pgport = (int)PyLong_AsLong(pg_default_port); if ((!pgopt) && (pg_default_opt != Py_None)) pgopt = PyBytes_AsString(pg_default_opt); @@ -248,7 +230,6 @@ pg_connect(PyObject *self, PyObject *args, PyObject *dict) if ((!pgpasswd) && (pg_default_passwd != Py_None)) pgpasswd = PyBytes_AsString(pg_default_passwd); -#endif /* DEFAULT_VARS */ if (!(conn_obj = PyObject_New(connObject, &connType))) { set_error_msg(InternalError, "Can't create new connection object"); @@ -261,14 +242,38 @@ pg_connect(PyObject *self, PyObject *args, PyObject *dict) conn_obj->cast_hook = NULL; conn_obj->notice_receiver = NULL; + if (pghost) { + keywords[nkw] = "host"; + values[nkw++] = pghost; + } + if (pgopt) { + keywords[nkw] = "options"; + values[nkw++] = pgopt; + } + if (pgdbname) { + keywords[nkw] = "dbname"; + values[nkw++] = pgdbname; + } + if (pguser) { + keywords[nkw] = "user"; + values[nkw++] = pguser; + } + if (pgpasswd) { + keywords[nkw] = "password"; + values[nkw++] = pgpasswd; + } if (pgport != -1) { memset(port_buffer, 0, sizeof(port_buffer)); sprintf(port_buffer, "%d", pgport); + + keywords[nkw] = "port"; + values[nkw++] = port_buffer; } + keywords[nkw] = values[nkw] = NULL; Py_BEGIN_ALLOW_THREADS - conn_obj->cnx = PQsetdbLogin(pghost, pgport == -1 ? NULL : port_buffer, - pgopt, NULL, pgdbname, pguser, pgpasswd); + conn_obj->cnx = nowait ? 
PQconnectStartParams(keywords, values, 1) + : PQconnectdbParams(keywords, values, 1); Py_END_ALLOW_THREADS if (PQstatus(conn_obj->cnx) == CONNECTION_BAD) { @@ -277,23 +282,33 @@ pg_connect(PyObject *self, PyObject *args, PyObject *dict) return NULL; } - return (PyObject *) conn_obj; + return (PyObject *)conn_obj; +} + +/* Get version of libpq that is being used */ +static char pg_get_pqlib_version__doc__[] = + "get_pqlib_version() -- get the version of libpq that is being used"; + +static PyObject * +pg_get_pqlib_version(PyObject *self, PyObject *noargs) +{ + return PyLong_FromLong(PQlibVersion()); } /* Escape string */ static char pg_escape_string__doc__[] = -"escape_string(string) -- escape a string for use within SQL"; + "escape_string(string) -- escape a string for use within SQL"; static PyObject * pg_escape_string(PyObject *self, PyObject *string) { - PyObject *tmp_obj = NULL, /* auxiliary string object */ - *to_obj; /* string object to return */ - char *from, /* our string argument as encoded string */ - *to; /* the result as encoded string */ - Py_ssize_t from_length; /* length of string */ - size_t to_length; /* length of result */ - int encoding = -1; /* client encoding */ + PyObject *tmp_obj = NULL, /* auxiliary string object */ + *to_obj; /* string object to return */ + char *from, /* our string argument as encoded string */ + *to; /* the result as encoded string */ + Py_ssize_t from_length; /* length of string */ + size_t to_length; /* length of result */ + int encoding = -1; /* client encoding */ if (PyBytes_Check(string)) { PyBytes_AsStringAndSize(string, &from, &from_length); @@ -301,7 +316,8 @@ pg_escape_string(PyObject *self, PyObject *string) else if (PyUnicode_Check(string)) { encoding = pg_encoding_ascii; tmp_obj = get_encoded_string(string, encoding); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(tmp_obj, &from, &from_length); } else { @@ -310,38 +326,39 @@ pg_escape_string(PyObject *self, PyObject *string) return NULL; } - to_length = 2 * (size_t) from_length + 1; - if ((Py_ssize_t ) to_length < from_length) { /* overflow */ - to_length = (size_t) from_length; - from_length = (from_length - 1)/2; + to_length = 2 * (size_t)from_length + 1; + if ((Py_ssize_t)to_length < from_length) { /* overflow */ + to_length = (size_t)from_length; + from_length = (from_length - 1) / 2; } - to = (char *) PyMem_Malloc(to_length); - to_length = (size_t) PQescapeString(to, from, (size_t) from_length); + to = (char *)PyMem_Malloc(to_length); + to_length = (size_t)PQescapeString(to, from, (size_t)from_length); Py_XDECREF(tmp_obj); if (encoding == -1) - to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t) to_length); + to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t)to_length); else - to_obj = get_decoded_string(to, (Py_ssize_t) to_length, encoding); + to_obj = get_decoded_string(to, (Py_ssize_t)to_length, encoding); PyMem_Free(to); return to_obj; } /* Escape bytea */ static char pg_escape_bytea__doc__[] = -"escape_bytea(data) -- escape binary data for use within SQL as type bytea"; + "escape_bytea(data) -- escape binary data for use within SQL as type " + "bytea"; static PyObject * pg_escape_bytea(PyObject *self, PyObject *data) { - PyObject *tmp_obj = NULL, /* auxiliary string object */ - *to_obj; /* string object to return */ - char *from, /* our string argument as encoded string */ - *to; /* the result as encoded string */ - Py_ssize_t from_length; /* length of string */ - size_t 
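
For orientation, a usage sketch of the new nowait flag from Python; it assumes the connection object exposes a poll() wrapper around PQconnectPoll() (added elsewhere in this changeset, not shown in this excerpt) and uses the POLLING_* constants exported near the end of this module:

    import pg

    conn = pg.connect(dbname='test', nowait=1)  # uses PQconnectStartParams()
    state = pg.POLLING_WRITING
    while state not in (pg.POLLING_OK, pg.POLLING_FAILED):
        # a production loop would wait on the connection socket between polls
        state = conn.poll()  # assumed wrapper around PQconnectPoll()
    print(pg.get_pqlib_version())  # e.g. 150004 for libpq 15.4
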
to_length; /* length of result */ - int encoding = -1; /* client encoding */ + PyObject *tmp_obj = NULL, /* auxiliary string object */ + *to_obj; /* string object to return */ + char *from, /* our string argument as encoded string */ + *to; /* the result as encoded string */ + Py_ssize_t from_length; /* length of string */ + size_t to_length; /* length of result */ + int encoding = -1; /* client encoding */ if (PyBytes_Check(data)) { PyBytes_AsStringAndSize(data, &from, &from_length); @@ -349,7 +366,8 @@ pg_escape_bytea(PyObject *self, PyObject *data) else if (PyUnicode_Check(data)) { encoding = pg_encoding_ascii; tmp_obj = get_encoded_string(data, encoding); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(tmp_obj, &from, &from_length); } else { @@ -358,15 +376,15 @@ pg_escape_bytea(PyObject *self, PyObject *data) return NULL; } - to = (char *) PQescapeBytea( - (unsigned char*) from, (size_t) from_length, &to_length); + to = (char *)PQescapeBytea((unsigned char *)from, (size_t)from_length, + &to_length); Py_XDECREF(tmp_obj); if (encoding == -1) - to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t) to_length - 1); + to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t)to_length - 1); else - to_obj = get_decoded_string(to, (Py_ssize_t) to_length - 1, encoding); + to_obj = get_decoded_string(to, (Py_ssize_t)to_length - 1, encoding); if (to) PQfreemem(to); return to_obj; @@ -374,24 +392,25 @@ pg_escape_bytea(PyObject *self, PyObject *data) /* Unescape bytea */ static char pg_unescape_bytea__doc__[] = -"unescape_bytea(string) -- unescape bytea data retrieved as text"; + "unescape_bytea(string) -- unescape bytea data retrieved as text"; static PyObject * pg_unescape_bytea(PyObject *self, PyObject *data) { - PyObject *tmp_obj = NULL, /* auxiliary string object */ - *to_obj; /* string object to return */ - char *from, /* our string argument as encoded string */ - *to; /* the result as encoded string */ - Py_ssize_t from_length; /* length of string */ - size_t to_length; /* length of result */ + PyObject *tmp_obj = NULL, /* auxiliary string object */ + *to_obj; /* string object to return */ + char *from, /* our string argument as encoded string */ + *to; /* the result as encoded string */ + Py_ssize_t from_length; /* length of string */ + size_t to_length; /* length of result */ if (PyBytes_Check(data)) { PyBytes_AsStringAndSize(data, &from, &from_length); } else if (PyUnicode_Check(data)) { tmp_obj = get_encoded_string(data, pg_encoding_ascii); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(tmp_obj, &from, &from_length); } else { @@ -401,13 +420,14 @@ pg_unescape_bytea(PyObject *self, PyObject *data) return NULL; } - to = (char *) PQunescapeBytea((unsigned char*) from, &to_length); + to = (char *)PQunescapeBytea((unsigned char *)from, &to_length); Py_XDECREF(tmp_obj); - if (!to) return PyErr_NoMemory(); + if (!to) + return PyErr_NoMemory(); - to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t) to_length); + to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t)to_length); PQfreemem(to); return to_obj; @@ -415,7 +435,7 @@ pg_unescape_bytea(PyObject *self, PyObject *data) /* Set fixed datestyle. 
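
A short usage sketch of the three escaping helpers above; as module-level functions they do not see the connection encoding, which is why the code assumes ASCII-compatible input:

    import pg

    pg.escape_string("O'Reilly")  # "O''Reilly", safe between single quotes
    pg.escape_bytea(b'\x00\xff')  # text form usable inside a bytea literal
    pg.unescape_bytea(r'\x00ff')  # b'\x00\xff' from a bytea value read as text
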
*/ static char pg_set_datestyle__doc__[] = -"set_datestyle(style) -- set which style is assumed"; + "set_datestyle(style) -- set which style is assumed"; static PyObject * pg_set_datestyle(PyObject *self, PyObject *args) @@ -432,27 +452,29 @@ pg_set_datestyle(PyObject *self, PyObject *args) date_format = datestyle ? date_style_to_format(datestyle) : NULL; - Py_INCREF(Py_None); return Py_None; + Py_INCREF(Py_None); + return Py_None; } /* Get fixed datestyle. */ static char pg_get_datestyle__doc__[] = -"get_datestyle() -- get which date style is assumed"; + "get_datestyle() -- get which date style is assumed"; static PyObject * pg_get_datestyle(PyObject *self, PyObject *noargs) { if (date_format) { - return PyStr_FromString(date_format_to_style(date_format)); + return PyUnicode_FromString(date_format_to_style(date_format)); } else { - Py_INCREF(Py_None); return Py_None; + Py_INCREF(Py_None); + return Py_None; } } /* Get decimal point. */ static char pg_get_decimal_point__doc__[] = -"get_decimal_point() -- get decimal point to be used for money values"; + "get_decimal_point() -- get decimal point to be used for money values"; static PyObject * pg_get_decimal_point(PyObject *self, PyObject *noargs) @@ -461,11 +483,13 @@ pg_get_decimal_point(PyObject *self, PyObject *noargs) char s[2]; if (decimal_point) { - s[0] = decimal_point; s[1] = '\0'; - ret = PyStr_FromString(s); + s[0] = decimal_point; + s[1] = '\0'; + ret = PyUnicode_FromString(s); } else { - Py_INCREF(Py_None); ret = Py_None; + Py_INCREF(Py_None); + ret = Py_None; } return ret; @@ -473,7 +497,7 @@ pg_get_decimal_point(PyObject *self, PyObject *noargs) /* Set decimal point. */ static char pg_set_decimal_point__doc__[] = -"set_decimal_point(char) -- set decimal point to be used for money values"; + "set_decimal_point(char) -- set decimal point to be used for money values"; static PyObject * pg_set_decimal_point(PyObject *self, PyObject *args) @@ -485,13 +509,14 @@ pg_set_decimal_point(PyObject *self, PyObject *args) if (PyArg_ParseTuple(args, "z", &s)) { if (!s) s = "\0"; - else if (*s && (*(s+1) || !strchr(".,;: '*/_`|", *s))) + else if (*s && (*(s + 1) || !strchr(".,;: '*/_`|", *s))) s = NULL; } if (s) { decimal_point = *s; - Py_INCREF(Py_None); ret = Py_None; + Py_INCREF(Py_None); + ret = Py_None; } else { PyErr_SetString(PyExc_TypeError, @@ -503,7 +528,7 @@ pg_set_decimal_point(PyObject *self, PyObject *args) /* Get decimal type. */ static char pg_get_decimal__doc__[] = -"get_decimal() -- get the decimal type to be used for numeric values"; + "get_decimal() -- get the decimal type to be used for numeric values"; static PyObject * pg_get_decimal(PyObject *self, PyObject *noargs) @@ -518,7 +543,7 @@ pg_get_decimal(PyObject *self, PyObject *noargs) /* Set decimal type. 
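
A usage sketch of the datestyle and decimal point switches defined above:

    import pg

    pg.set_datestyle('ISO')    # assume ISO dates instead of auto-detection
    pg.get_datestyle()         # 'ISO'
    pg.set_decimal_point(',')  # parse money values written like '1.234,56'
    pg.set_datestyle(None)     # derive the format from DateStyle again
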
*/ static char pg_set_decimal__doc__[] = -"set_decimal(cls) -- set a decimal type to be used for numeric values"; + "set_decimal(cls) -- set a decimal type to be used for numeric values"; static PyObject * pg_set_decimal(PyObject *self, PyObject *cls) @@ -526,12 +551,17 @@ pg_set_decimal(PyObject *self, PyObject *cls) PyObject *ret = NULL; if (cls == Py_None) { - Py_XDECREF(decimal); decimal = NULL; - Py_INCREF(Py_None); ret = Py_None; + Py_XDECREF(decimal); + decimal = NULL; + Py_INCREF(Py_None); + ret = Py_None; } else if (PyCallable_Check(cls)) { - Py_XINCREF(cls); Py_XDECREF(decimal); decimal = cls; - Py_INCREF(Py_None); ret = Py_None; + Py_XINCREF(cls); + Py_XDECREF(decimal); + decimal = cls; + Py_INCREF(Py_None); + ret = Py_None; } else { PyErr_SetString(PyExc_TypeError, @@ -544,7 +574,7 @@ pg_set_decimal(PyObject *self, PyObject *cls) /* Get usage of bool values. */ static char pg_get_bool__doc__[] = -"get_bool() -- check whether boolean values are converted to bool"; + "get_bool() -- check whether boolean values are converted to bool"; static PyObject * pg_get_bool(PyObject *self, PyObject *noargs) @@ -559,7 +589,7 @@ pg_get_bool(PyObject *self, PyObject *noargs) /* Set usage of bool values. */ static char pg_set_bool__doc__[] = -"set_bool(on) -- set whether boolean values should be converted to bool"; + "set_bool(on) -- set whether boolean values should be converted to bool"; static PyObject * pg_set_bool(PyObject *self, PyObject *args) @@ -570,7 +600,8 @@ pg_set_bool(PyObject *self, PyObject *args) /* gets arguments */ if (PyArg_ParseTuple(args, "i", &i)) { bool_as_text = i ? 0 : 1; - Py_INCREF(Py_None); ret = Py_None; + Py_INCREF(Py_None); + ret = Py_None; } else { PyErr_SetString( @@ -583,7 +614,7 @@ pg_set_bool(PyObject *self, PyObject *args) /* Get conversion of arrays to lists. */ static char pg_get_array__doc__[] = -"get_array() -- check whether arrays are converted as lists"; + "get_array() -- check whether arrays are converted as lists"; static PyObject * pg_get_array(PyObject *self, PyObject *noargs) @@ -598,18 +629,19 @@ pg_get_array(PyObject *self, PyObject *noargs) /* Set conversion of arrays to lists. */ static char pg_set_array__doc__[] = -"set_array(on) -- set whether arrays should be converted to lists"; + "set_array(on) -- set whether arrays should be converted to lists"; static PyObject * -pg_set_array(PyObject* self, PyObject* args) +pg_set_array(PyObject *self, PyObject *args) { - PyObject* ret = NULL; + PyObject *ret = NULL; int i; /* gets arguments */ if (PyArg_ParseTuple(args, "i", &i)) { array_as_text = i ? 0 : 1; - Py_INCREF(Py_None); ret = Py_None; + Py_INCREF(Py_None); + ret = Py_None; } else { PyErr_SetString( @@ -622,7 +654,7 @@ pg_set_array(PyObject* self, PyObject* args) /* Check whether bytea values are unescaped. */ static char pg_get_bytea_escaped__doc__[] = -"get_bytea_escaped() -- check whether bytea will be returned escaped"; + "get_bytea_escaped() -- check whether bytea will be returned escaped"; static PyObject * pg_get_bytea_escaped(PyObject *self, PyObject *noargs) @@ -637,7 +669,7 @@ pg_get_bytea_escaped(PyObject *self, PyObject *noargs) /* Set usage of bool values. 
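
A usage sketch of these type-mapping switches; that the pg wrapper module installs Decimal as the default on import is an assumption here, since that code lies outside this file:

    from decimal import Decimal

    import pg

    pg.set_decimal(float)    # return NUMERIC columns as float, not Decimal
    pg.set_bool(False)       # return BOOL columns as 't'/'f' text
    pg.set_array(False)      # return ARRAY columns as raw '{...}' text
    pg.set_decimal(Decimal)  # restore the usual default
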
*/ static char pg_set_bytea_escaped__doc__[] = -"set_bytea_escaped(on) -- set whether bytea will be returned escaped"; + "set_bytea_escaped(on) -- set whether bytea will be returned escaped"; static PyObject * pg_set_bytea_escaped(PyObject *self, PyObject *args) @@ -648,7 +680,8 @@ pg_set_bytea_escaped(PyObject *self, PyObject *args) /* gets arguments */ if (PyArg_ParseTuple(args, "i", &i)) { bytea_escaped = i ? 1 : 0; - Py_INCREF(Py_None); ret = Py_None; + Py_INCREF(Py_None); + ret = Py_None; } else { PyErr_SetString(PyExc_TypeError, @@ -662,18 +695,15 @@ pg_set_bytea_escaped(PyObject *self, PyObject *args) /* set query helper functions (not part of public API) */ static char pg_set_query_helpers__doc__[] = -"set_query_helpers(*helpers) -- set internal query helper functions"; + "set_query_helpers(*helpers) -- set internal query helper functions"; static PyObject * pg_set_query_helpers(PyObject *self, PyObject *args) { /* gets arguments */ - if (!PyArg_ParseTuple(args, "O!O!O!O!", - &PyFunction_Type, &dictiter, - &PyFunction_Type, &namediter, - &PyFunction_Type, &namednext, - &PyFunction_Type, &scalariter)) - { + if (!PyArg_ParseTuple(args, "O!O!O!O!", &PyFunction_Type, &dictiter, + &PyFunction_Type, &namediter, &PyFunction_Type, + &namednext, &PyFunction_Type, &scalariter)) { return NULL; } @@ -683,7 +713,7 @@ pg_set_query_helpers(PyObject *self, PyObject *args) /* Get json decode function. */ static char pg_get_jsondecode__doc__[] = -"get_jsondecode() -- get the function used for decoding json results"; + "get_jsondecode() -- get the function used for decoding json results"; static PyObject * pg_get_jsondecode(PyObject *self, PyObject *noargs) @@ -700,7 +730,8 @@ pg_get_jsondecode(PyObject *self, PyObject *noargs) /* Set json decode function. */ static char pg_set_jsondecode__doc__[] = -"set_jsondecode(func) -- set a function to be used for decoding json results"; + "set_jsondecode(func) -- set a function to be used for decoding json " + "results"; static PyObject * pg_set_jsondecode(PyObject *self, PyObject *func) @@ -708,12 +739,17 @@ pg_set_jsondecode(PyObject *self, PyObject *func) PyObject *ret = NULL; if (func == Py_None) { - Py_XDECREF(jsondecode); jsondecode = NULL; - Py_INCREF(Py_None); ret = Py_None; + Py_XDECREF(jsondecode); + jsondecode = NULL; + Py_INCREF(Py_None); + ret = Py_None; } else if (PyCallable_Check(func)) { - Py_XINCREF(func); Py_XDECREF(jsondecode); jsondecode = func; - Py_INCREF(Py_None); ret = Py_None; + Py_XINCREF(func); + Py_XDECREF(jsondecode); + jsondecode = func; + Py_INCREF(Py_None); + ret = Py_None; } else { PyErr_SetString(PyExc_TypeError, @@ -724,11 +760,9 @@ pg_set_jsondecode(PyObject *self, PyObject *func) return ret; } -#ifdef DEFAULT_VARS - /* Get default host. */ static char pg_get_defhost__doc__[] = -"get_defhost() -- return default database host"; + "get_defhost() -- return default database host"; static PyObject * pg_get_defhost(PyObject *self, PyObject *noargs) @@ -739,7 +773,8 @@ pg_get_defhost(PyObject *self, PyObject *noargs) /* Set default host. 
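
A usage sketch of the JSON decoding hook (and the bytea switch) defined above:

    import json

    import pg

    pg.set_jsondecode(json.loads)  # decode JSON columns to Python objects
    pg.get_jsondecode()            # <function loads ...>
    pg.set_jsondecode(None)        # leave JSON values as plain strings
    pg.set_bytea_escaped(False)    # hand back bytea values already unescaped
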
*/ static char pg_set_defhost__doc__[] = -"set_defhost(string) -- set default database host and return previous value"; + "set_defhost(string) -- set default database host and return previous " + "value"; static PyObject * pg_set_defhost(PyObject *self, PyObject *args) @@ -759,7 +794,7 @@ pg_set_defhost(PyObject *self, PyObject *args) old = pg_default_host; if (tmp) { - pg_default_host = PyStr_FromString(tmp); + pg_default_host = PyUnicode_FromString(tmp); } else { Py_INCREF(Py_None); @@ -771,7 +806,7 @@ pg_set_defhost(PyObject *self, PyObject *args) /* Get default database. */ static char pg_get_defbase__doc__[] = -"get_defbase() -- return default database name"; + "get_defbase() -- return default database name"; static PyObject * pg_get_defbase(PyObject *self, PyObject *noargs) @@ -782,7 +817,8 @@ pg_get_defbase(PyObject *self, PyObject *noargs) /* Set default database. */ static char pg_set_defbase__doc__[] = -"set_defbase(string) -- set default database name and return previous value"; + "set_defbase(string) -- set default database name and return previous " + "value"; static PyObject * pg_set_defbase(PyObject *self, PyObject *args) @@ -802,7 +838,7 @@ pg_set_defbase(PyObject *self, PyObject *args) old = pg_default_base; if (tmp) { - pg_default_base = PyStr_FromString(tmp); + pg_default_base = PyUnicode_FromString(tmp); } else { Py_INCREF(Py_None); @@ -814,7 +850,7 @@ pg_set_defbase(PyObject *self, PyObject *args) /* Get default options. */ static char pg_get_defopt__doc__[] = -"get_defopt() -- return default database options"; + "get_defopt() -- return default database options"; static PyObject * pg_get_defopt(PyObject *self, PyObject *noargs) @@ -825,7 +861,7 @@ pg_get_defopt(PyObject *self, PyObject *noargs) /* Set default options. */ static char pg_set_defopt__doc__[] = -"set_defopt(string) -- set default options and return previous value"; + "set_defopt(string) -- set default options and return previous value"; static PyObject * pg_setdefopt(PyObject *self, PyObject *args) @@ -845,7 +881,7 @@ pg_setdefopt(PyObject *self, PyObject *args) old = pg_default_opt; if (tmp) { - pg_default_opt = PyStr_FromString(tmp); + pg_default_opt = PyUnicode_FromString(tmp); } else { Py_INCREF(Py_None); @@ -857,7 +893,7 @@ pg_setdefopt(PyObject *self, PyObject *args) /* Get default username. */ static char pg_get_defuser__doc__[] = -"get_defuser() -- return default database username"; + "get_defuser() -- return default database username"; static PyObject * pg_get_defuser(PyObject *self, PyObject *noargs) @@ -869,7 +905,7 @@ pg_get_defuser(PyObject *self, PyObject *noargs) /* Set default username. */ static char pg_set_defuser__doc__[] = -"set_defuser(name) -- set default username and return previous value"; + "set_defuser(name) -- set default username and return previous value"; static PyObject * pg_set_defuser(PyObject *self, PyObject *args) @@ -889,7 +925,7 @@ pg_set_defuser(PyObject *self, PyObject *args) old = pg_default_user; if (tmp) { - pg_default_user = PyStr_FromString(tmp); + pg_default_user = PyUnicode_FromString(tmp); } else { Py_INCREF(Py_None); @@ -901,7 +937,7 @@ pg_set_defuser(PyObject *self, PyObject *args) /* Set default password. 
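
A usage sketch of the default-variable helpers: pg_connect() falls back to these values for omitted arguments, and each setter returns the previous value (the host name below is hypothetical):

    import pg

    pg.set_defhost('db.example.com')  # returns the previous default
    pg.set_defbase('test')
    pg.set_defuser('scott')
    conn = pg.connect()  # falls back to the defaults set above
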
*/ static char pg_set_defpasswd__doc__[] = -"set_defpasswd(password) -- set default database password"; + "set_defpasswd(password) -- set default database password"; static PyObject * pg_set_defpasswd(PyObject *self, PyObject *args) @@ -917,7 +953,7 @@ pg_set_defpasswd(PyObject *self, PyObject *args) } if (tmp) { - pg_default_passwd = PyStr_FromString(tmp); + pg_default_passwd = PyUnicode_FromString(tmp); } else { Py_INCREF(Py_None); @@ -930,7 +966,7 @@ pg_set_defpasswd(PyObject *self, PyObject *args) /* Get default port. */ static char pg_get_defport__doc__[] = -"get_defport() -- return default database port"; + "get_defport() -- return default database port"; static PyObject * pg_get_defport(PyObject *self, PyObject *noargs) @@ -941,7 +977,7 @@ pg_get_defport(PyObject *self, PyObject *noargs) /* Set default port. */ static char pg_set_defport__doc__[] = -"set_defport(port) -- set default port and return previous value"; + "set_defport(port) -- set default port and return previous value"; static PyObject * pg_set_defport(PyObject *self, PyObject *args) @@ -961,7 +997,7 @@ pg_set_defport(PyObject *self, PyObject *args) old = pg_default_port; if (port != -1) { - pg_default_port = PyInt_FromLong(port); + pg_default_port = PyLong_FromLong(port); } else { Py_INCREF(Py_None); @@ -970,11 +1006,10 @@ pg_set_defport(PyObject *self, PyObject *args) return old; } -#endif /* DEFAULT_VARS */ /* Cast a string with a text representation of an array to a list. */ static char pg_cast_array__doc__[] = -"cast_array(string, cast=None, delim=',') -- cast a string as an array"; + "cast_array(string, cast=None, delim=',') -- cast a string as an array"; PyObject * pg_cast_array(PyObject *self, PyObject *args, PyObject *dict) @@ -985,10 +1020,8 @@ pg_cast_array(PyObject *self, PyObject *args, PyObject *dict) Py_ssize_t size; int encoding; - if (!PyArg_ParseTupleAndKeywords( - args, dict, "O|Oc", - (char**) kwlist, &string_obj, &cast_obj, &delim)) - { + if (!PyArg_ParseTupleAndKeywords(args, dict, "O|Oc", (char **)kwlist, + &string_obj, &cast_obj, &delim)) { return NULL; } @@ -999,7 +1032,8 @@ pg_cast_array(PyObject *self, PyObject *args, PyObject *dict) } else if (PyUnicode_Check(string_obj)) { string_obj = PyUnicode_AsUTF8String(string_obj); - if (!string_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!string_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(string_obj, &string, &size); encoding = pg_encoding_utf8; } @@ -1010,12 +1044,10 @@ pg_cast_array(PyObject *self, PyObject *args, PyObject *dict) return NULL; } - if (!cast_obj || cast_obj == Py_None) { - if (cast_obj) { - Py_DECREF(cast_obj); cast_obj = NULL; - } + if (cast_obj == Py_None) { + cast_obj = NULL; } - else if (!PyCallable_Check(cast_obj)) { + else if (cast_obj && !PyCallable_Check(cast_obj)) { PyErr_SetString( PyExc_TypeError, "Function cast_array() expects a callable as second argument"); @@ -1031,7 +1063,7 @@ pg_cast_array(PyObject *self, PyObject *args, PyObject *dict) /* Cast a string with a text representation of a record to a tuple. 
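
A usage sketch of cast_array() as defined above; cast_record() below follows the same pattern but yields a tuple:

    import pg

    pg.cast_array('{1,2,3}', int)        # [1, 2, 3]
    pg.cast_array('{{a,b},{c,d}}')       # [['a', 'b'], ['c', 'd']]
    pg.cast_array('{1;2;3}', int, b';')  # custom delimiter
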
*/ static char pg_cast_record__doc__[] = -"cast_record(string, cast=None, delim=',') -- cast a string as a record"; + "cast_record(string, cast=None, delim=',') -- cast a string as a record"; PyObject * pg_cast_record(PyObject *self, PyObject *args, PyObject *dict) @@ -1042,10 +1074,8 @@ pg_cast_record(PyObject *self, PyObject *args, PyObject *dict) Py_ssize_t size, len; int encoding; - if (!PyArg_ParseTupleAndKeywords( - args, dict, "O|Oc", - (char**) kwlist, &string_obj, &cast_obj, &delim)) - { + if (!PyArg_ParseTupleAndKeywords(args, dict, "O|Oc", (char **)kwlist, + &string_obj, &cast_obj, &delim)) { return NULL; } @@ -1056,7 +1086,8 @@ pg_cast_record(PyObject *self, PyObject *args, PyObject *dict) } else if (PyUnicode_Check(string_obj)) { string_obj = PyUnicode_AsUTF8String(string_obj); - if (!string_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!string_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(string_obj, &string, &size); encoding = pg_encoding_utf8; } @@ -1071,12 +1102,13 @@ pg_cast_record(PyObject *self, PyObject *args, PyObject *dict) len = 0; } else if (cast_obj == Py_None) { - Py_DECREF(cast_obj); cast_obj = NULL; len = 0; + cast_obj = NULL; + len = 0; } else if (PyTuple_Check(cast_obj) || PyList_Check(cast_obj)) { len = PySequence_Size(cast_obj); if (!len) { - Py_DECREF(cast_obj); cast_obj = NULL; + cast_obj = NULL; } } else { @@ -1095,7 +1127,7 @@ pg_cast_record(PyObject *self, PyObject *args, PyObject *dict) /* Cast a string with a text representation of an hstore to a dict. */ static char pg_cast_hstore__doc__[] = -"cast_hstore(string) -- cast a string as an hstore"; + "cast_hstore(string) -- cast a string as an hstore"; PyObject * pg_cast_hstore(PyObject *self, PyObject *string) @@ -1111,7 +1143,8 @@ pg_cast_hstore(PyObject *self, PyObject *string) } else if (PyUnicode_Check(string)) { tmp_obj = PyUnicode_AsUTF8String(string); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(tmp_obj, &s, &size); encoding = pg_encoding_utf8; } @@ -1132,52 +1165,47 @@ pg_cast_hstore(PyObject *self, PyObject *string) /* The list of functions defined in the module */ static struct PyMethodDef pg_methods[] = { - {"connect", (PyCFunction) pg_connect, - METH_VARARGS|METH_KEYWORDS, pg_connect__doc__}, - {"escape_string", (PyCFunction) pg_escape_string, - METH_O, pg_escape_string__doc__}, - {"escape_bytea", (PyCFunction) pg_escape_bytea, - METH_O, pg_escape_bytea__doc__}, - {"unescape_bytea", (PyCFunction) pg_unescape_bytea, - METH_O, pg_unescape_bytea__doc__}, - {"get_datestyle", (PyCFunction) pg_get_datestyle, - METH_NOARGS, pg_get_datestyle__doc__}, - {"set_datestyle", (PyCFunction) pg_set_datestyle, - METH_VARARGS, pg_set_datestyle__doc__}, - {"get_decimal_point", (PyCFunction) pg_get_decimal_point, - METH_NOARGS, pg_get_decimal_point__doc__}, - {"set_decimal_point", (PyCFunction) pg_set_decimal_point, - METH_VARARGS, pg_set_decimal_point__doc__}, - {"get_decimal", (PyCFunction) pg_get_decimal, - METH_NOARGS, pg_get_decimal__doc__}, - {"set_decimal", (PyCFunction) pg_set_decimal, - METH_O, pg_set_decimal__doc__}, - {"get_bool", (PyCFunction) pg_get_bool, - METH_NOARGS, pg_get_bool__doc__}, - {"set_bool", (PyCFunction) pg_set_bool, - METH_VARARGS, pg_set_bool__doc__}, - {"get_array", (PyCFunction) pg_get_array, - METH_NOARGS, pg_get_array__doc__}, - {"set_array", (PyCFunction) pg_set_array, - METH_VARARGS, pg_set_array__doc__}, - 
{"set_query_helpers", (PyCFunction) pg_set_query_helpers, - METH_VARARGS, pg_set_query_helpers__doc__}, - {"get_bytea_escaped", (PyCFunction) pg_get_bytea_escaped, - METH_NOARGS, pg_get_bytea_escaped__doc__}, - {"set_bytea_escaped", (PyCFunction) pg_set_bytea_escaped, - METH_VARARGS, pg_set_bytea_escaped__doc__}, - {"get_jsondecode", (PyCFunction) pg_get_jsondecode, - METH_NOARGS, pg_get_jsondecode__doc__}, - {"set_jsondecode", (PyCFunction) pg_set_jsondecode, - METH_O, pg_set_jsondecode__doc__}, - {"cast_array", (PyCFunction) pg_cast_array, - METH_VARARGS|METH_KEYWORDS, pg_cast_array__doc__}, - {"cast_record", (PyCFunction) pg_cast_record, - METH_VARARGS|METH_KEYWORDS, pg_cast_record__doc__}, - {"cast_hstore", (PyCFunction) pg_cast_hstore, - METH_O, pg_cast_hstore__doc__}, - -#ifdef DEFAULT_VARS + {"connect", (PyCFunction)pg_connect, METH_VARARGS | METH_KEYWORDS, + pg_connect__doc__}, + {"escape_string", (PyCFunction)pg_escape_string, METH_O, + pg_escape_string__doc__}, + {"escape_bytea", (PyCFunction)pg_escape_bytea, METH_O, + pg_escape_bytea__doc__}, + {"unescape_bytea", (PyCFunction)pg_unescape_bytea, METH_O, + pg_unescape_bytea__doc__}, + {"get_datestyle", (PyCFunction)pg_get_datestyle, METH_NOARGS, + pg_get_datestyle__doc__}, + {"set_datestyle", (PyCFunction)pg_set_datestyle, METH_VARARGS, + pg_set_datestyle__doc__}, + {"get_decimal_point", (PyCFunction)pg_get_decimal_point, METH_NOARGS, + pg_get_decimal_point__doc__}, + {"set_decimal_point", (PyCFunction)pg_set_decimal_point, METH_VARARGS, + pg_set_decimal_point__doc__}, + {"get_decimal", (PyCFunction)pg_get_decimal, METH_NOARGS, + pg_get_decimal__doc__}, + {"set_decimal", (PyCFunction)pg_set_decimal, METH_O, + pg_set_decimal__doc__}, + {"get_bool", (PyCFunction)pg_get_bool, METH_NOARGS, pg_get_bool__doc__}, + {"set_bool", (PyCFunction)pg_set_bool, METH_VARARGS, pg_set_bool__doc__}, + {"get_array", (PyCFunction)pg_get_array, METH_NOARGS, pg_get_array__doc__}, + {"set_array", (PyCFunction)pg_set_array, METH_VARARGS, + pg_set_array__doc__}, + {"set_query_helpers", (PyCFunction)pg_set_query_helpers, METH_VARARGS, + pg_set_query_helpers__doc__}, + {"get_bytea_escaped", (PyCFunction)pg_get_bytea_escaped, METH_NOARGS, + pg_get_bytea_escaped__doc__}, + {"set_bytea_escaped", (PyCFunction)pg_set_bytea_escaped, METH_VARARGS, + pg_set_bytea_escaped__doc__}, + {"get_jsondecode", (PyCFunction)pg_get_jsondecode, METH_NOARGS, + pg_get_jsondecode__doc__}, + {"set_jsondecode", (PyCFunction)pg_set_jsondecode, METH_O, + pg_set_jsondecode__doc__}, + {"cast_array", (PyCFunction)pg_cast_array, METH_VARARGS | METH_KEYWORDS, + pg_cast_array__doc__}, + {"cast_record", (PyCFunction)pg_cast_record, METH_VARARGS | METH_KEYWORDS, + pg_cast_record__doc__}, + {"cast_hstore", (PyCFunction)pg_cast_hstore, METH_O, + pg_cast_hstore__doc__}, {"get_defhost", pg_get_defhost, METH_NOARGS, pg_get_defhost__doc__}, {"set_defhost", pg_set_defhost, METH_VARARGS, pg_set_defhost__doc__}, {"get_defbase", pg_get_defbase, METH_NOARGS, pg_get_defbase__doc__}, @@ -1189,22 +1217,26 @@ static struct PyMethodDef pg_methods[] = { {"get_defuser", pg_get_defuser, METH_NOARGS, pg_get_defuser__doc__}, {"set_defuser", pg_set_defuser, METH_VARARGS, pg_set_defuser__doc__}, {"set_defpasswd", pg_set_defpasswd, METH_VARARGS, pg_set_defpasswd__doc__}, -#endif /* DEFAULT_VARS */ + {"get_pqlib_version", (PyCFunction)pg_get_pqlib_version, METH_NOARGS, + pg_get_pqlib_version__doc__}, {NULL, NULL} /* sentinel */ }; static char pg__doc__[] = "Python interface to PostgreSQL DB"; static struct PyModuleDef 
moduleDef = { - PyModuleDef_HEAD_INIT, - "_pg", /* m_name */ - pg__doc__, /* m_doc */ - -1, /* m_size */ - pg_methods /* m_methods */ + PyModuleDef_HEAD_INIT, "_pg", /* m_name */ + pg__doc__, /* m_doc */ + -1, /* m_size */ + pg_methods /* m_methods */ }; /* Initialization function for the module */ -MODULE_INIT_FUNC(_pg) +PyMODINIT_FUNC +PyInit__pg(void); + +PyMODINIT_FUNC +PyInit__pg(void) { PyObject *mod, *dict, *s; @@ -1213,29 +1245,13 @@ MODULE_INIT_FUNC(_pg) mod = PyModule_Create(&moduleDef); /* Initialize here because some Windows platforms get confused otherwise */ -#if IS_PY3 - connType.tp_base = noticeType.tp_base = - queryType.tp_base = sourceType.tp_base = &PyBaseObject_Type; -#ifdef LARGE_OBJECTS + connType.tp_base = noticeType.tp_base = queryType.tp_base = + sourceType.tp_base = &PyBaseObject_Type; largeType.tp_base = &PyBaseObject_Type; -#endif -#else - connType.ob_type = noticeType.ob_type = - queryType.ob_type = sourceType.ob_type = &PyType_Type; -#ifdef LARGE_OBJECTS - largeType.ob_type = &PyType_Type; -#endif -#endif - if (PyType_Ready(&connType) - || PyType_Ready(¬iceType) - || PyType_Ready(&queryType) - || PyType_Ready(&sourceType) -#ifdef LARGE_OBJECTS - || PyType_Ready(&largeType) -#endif - ) - { + if (PyType_Ready(&connType) || PyType_Ready(¬iceType) || + PyType_Ready(&queryType) || PyType_Ready(&sourceType) || + PyType_Ready(&largeType)) { return NULL; } @@ -1248,81 +1264,97 @@ MODULE_INIT_FUNC(_pg) Warning = PyErr_NewException("pg.Warning", PyExc_Exception, NULL); PyDict_SetItemString(dict, "Warning", Warning); - InterfaceError = PyErr_NewException( - "pg.InterfaceError", Error, NULL); + InterfaceError = PyErr_NewException("pg.InterfaceError", Error, NULL); PyDict_SetItemString(dict, "InterfaceError", InterfaceError); - DatabaseError = PyErr_NewException( - "pg.DatabaseError", Error, NULL); + DatabaseError = PyErr_NewException("pg.DatabaseError", Error, NULL); PyDict_SetItemString(dict, "DatabaseError", DatabaseError); - InternalError = PyErr_NewException( - "pg.InternalError", DatabaseError, NULL); + InternalError = + PyErr_NewException("pg.InternalError", DatabaseError, NULL); PyDict_SetItemString(dict, "InternalError", InternalError); - OperationalError = PyErr_NewException( - "pg.OperationalError", DatabaseError, NULL); + OperationalError = + PyErr_NewException("pg.OperationalError", DatabaseError, NULL); PyDict_SetItemString(dict, "OperationalError", OperationalError); - ProgrammingError = PyErr_NewException( - "pg.ProgrammingError", DatabaseError, NULL); + ProgrammingError = + PyErr_NewException("pg.ProgrammingError", DatabaseError, NULL); PyDict_SetItemString(dict, "ProgrammingError", ProgrammingError); - IntegrityError = PyErr_NewException( - "pg.IntegrityError", DatabaseError, NULL); + IntegrityError = + PyErr_NewException("pg.IntegrityError", DatabaseError, NULL); PyDict_SetItemString(dict, "IntegrityError", IntegrityError); - DataError = PyErr_NewException( - "pg.DataError", DatabaseError, NULL); + DataError = PyErr_NewException("pg.DataError", DatabaseError, NULL); PyDict_SetItemString(dict, "DataError", DataError); - NotSupportedError = PyErr_NewException( - "pg.NotSupportedError", DatabaseError, NULL); + NotSupportedError = + PyErr_NewException("pg.NotSupportedError", DatabaseError, NULL); PyDict_SetItemString(dict, "NotSupportedError", NotSupportedError); - InvalidResultError = PyErr_NewException( - "pg.InvalidResultError", DataError, NULL); + InvalidResultError = + PyErr_NewException("pg.InvalidResultError", DataError, NULL); 
PyDict_SetItemString(dict, "InvalidResultError", InvalidResultError); - NoResultError = PyErr_NewException( - "pg.NoResultError", InvalidResultError, NULL); + NoResultError = + PyErr_NewException("pg.NoResultError", InvalidResultError, NULL); PyDict_SetItemString(dict, "NoResultError", NoResultError); - MultipleResultsError = PyErr_NewException( - "pg.MultipleResultsError", InvalidResultError, NULL); + MultipleResultsError = PyErr_NewException("pg.MultipleResultsError", + InvalidResultError, NULL); PyDict_SetItemString(dict, "MultipleResultsError", MultipleResultsError); + /* Types */ + Connection = (PyObject *)&connType; + PyDict_SetItemString(dict, "Connection", Connection); + Query = (PyObject *)&queryType; + PyDict_SetItemString(dict, "Query", Query); + LargeObject = (PyObject *)&largeType; + PyDict_SetItemString(dict, "LargeObject", LargeObject); + /* Make the version available */ - s = PyStr_FromString(PyPgVersion); + s = PyUnicode_FromString(PyPgVersion); PyDict_SetItemString(dict, "version", s); PyDict_SetItemString(dict, "__version__", s); Py_DECREF(s); /* Result types for queries */ - PyDict_SetItemString(dict, "RESULT_EMPTY", PyInt_FromLong(RESULT_EMPTY)); - PyDict_SetItemString(dict, "RESULT_DML", PyInt_FromLong(RESULT_DML)); - PyDict_SetItemString(dict, "RESULT_DDL", PyInt_FromLong(RESULT_DDL)); - PyDict_SetItemString(dict, "RESULT_DQL", PyInt_FromLong(RESULT_DQL)); + PyDict_SetItemString(dict, "RESULT_EMPTY", PyLong_FromLong(RESULT_EMPTY)); + PyDict_SetItemString(dict, "RESULT_DML", PyLong_FromLong(RESULT_DML)); + PyDict_SetItemString(dict, "RESULT_DDL", PyLong_FromLong(RESULT_DDL)); + PyDict_SetItemString(dict, "RESULT_DQL", PyLong_FromLong(RESULT_DQL)); /* Transaction states */ - PyDict_SetItemString(dict,"TRANS_IDLE",PyInt_FromLong(PQTRANS_IDLE)); - PyDict_SetItemString(dict,"TRANS_ACTIVE",PyInt_FromLong(PQTRANS_ACTIVE)); - PyDict_SetItemString(dict,"TRANS_INTRANS",PyInt_FromLong(PQTRANS_INTRANS)); - PyDict_SetItemString(dict,"TRANS_INERROR",PyInt_FromLong(PQTRANS_INERROR)); - PyDict_SetItemString(dict,"TRANS_UNKNOWN",PyInt_FromLong(PQTRANS_UNKNOWN)); + PyDict_SetItemString(dict, "TRANS_IDLE", PyLong_FromLong(PQTRANS_IDLE)); + PyDict_SetItemString(dict, "TRANS_ACTIVE", + PyLong_FromLong(PQTRANS_ACTIVE)); + PyDict_SetItemString(dict, "TRANS_INTRANS", + PyLong_FromLong(PQTRANS_INTRANS)); + PyDict_SetItemString(dict, "TRANS_INERROR", + PyLong_FromLong(PQTRANS_INERROR)); + PyDict_SetItemString(dict, "TRANS_UNKNOWN", + PyLong_FromLong(PQTRANS_UNKNOWN)); + + /* Polling results */ + PyDict_SetItemString(dict, "POLLING_OK", + PyLong_FromLong(PGRES_POLLING_OK)); + PyDict_SetItemString(dict, "POLLING_FAILED", + PyLong_FromLong(PGRES_POLLING_FAILED)); + PyDict_SetItemString(dict, "POLLING_READING", + PyLong_FromLong(PGRES_POLLING_READING)); + PyDict_SetItemString(dict, "POLLING_WRITING", + PyLong_FromLong(PGRES_POLLING_WRITING)); -#ifdef LARGE_OBJECTS /* Create mode for large objects */ - PyDict_SetItemString(dict, "INV_READ", PyInt_FromLong(INV_READ)); - PyDict_SetItemString(dict, "INV_WRITE", PyInt_FromLong(INV_WRITE)); + PyDict_SetItemString(dict, "INV_READ", PyLong_FromLong(INV_READ)); + PyDict_SetItemString(dict, "INV_WRITE", PyLong_FromLong(INV_WRITE)); /* Position flags for lo_lseek */ - PyDict_SetItemString(dict, "SEEK_SET", PyInt_FromLong(SEEK_SET)); - PyDict_SetItemString(dict, "SEEK_CUR", PyInt_FromLong(SEEK_CUR)); - PyDict_SetItemString(dict, "SEEK_END", PyInt_FromLong(SEEK_END)); -#endif /* LARGE_OBJECTS */ + PyDict_SetItemString(dict, "SEEK_SET", 
PyLong_FromLong(SEEK_SET)); + PyDict_SetItemString(dict, "SEEK_CUR", PyLong_FromLong(SEEK_CUR)); + PyDict_SetItemString(dict, "SEEK_END", PyLong_FromLong(SEEK_END)); -#ifdef DEFAULT_VARS /* Prepare default values */ Py_INCREF(Py_None); pg_default_host = Py_None; @@ -1336,7 +1368,6 @@ MODULE_INIT_FUNC(_pg) pg_default_user = Py_None; Py_INCREF(Py_None); pg_default_passwd = Py_None; -#endif /* DEFAULT_VARS */ /* Store common pg encoding ids */ diff --git a/ext/pgnotice.c b/ext/pgnotice.c new file mode 100644 index 00000000..c56b249f --- /dev/null +++ b/ext/pgnotice.c @@ -0,0 +1,121 @@ +/* + * PyGreSQL - a Python interface for the PostgreSQL database. + * + * The notice object - this file is a part of the C extension module. + * + * Copyright (c) 2025 by the PyGreSQL Development Team + * + * Please see the LICENSE.TXT file for specific restrictions. + */ + +/* Get notice object attributes. */ +static PyObject * +notice_getattr(noticeObject *self, PyObject *nameobj) +{ + PGresult const *res = self->res; + const char *name = PyUnicode_AsUTF8(nameobj); + int fieldcode; + + if (!res) { + PyErr_SetString(PyExc_TypeError, "Cannot get current notice"); + return NULL; + } + + /* pg connection object */ + if (!strcmp(name, "pgcnx")) { + if (self->pgcnx && _check_cnx_obj(self->pgcnx)) { + Py_INCREF(self->pgcnx); + return (PyObject *)self->pgcnx; + } + else { + Py_INCREF(Py_None); + return Py_None; + } + } + + /* full message */ + if (!strcmp(name, "message")) { + return PyUnicode_FromString(PQresultErrorMessage(res)); + } + + /* other possible fields */ + fieldcode = 0; + if (!strcmp(name, "severity")) + fieldcode = PG_DIAG_SEVERITY; + else if (!strcmp(name, "primary")) + fieldcode = PG_DIAG_MESSAGE_PRIMARY; + else if (!strcmp(name, "detail")) + fieldcode = PG_DIAG_MESSAGE_DETAIL; + else if (!strcmp(name, "hint")) + fieldcode = PG_DIAG_MESSAGE_HINT; + if (fieldcode) { + char *s = PQresultErrorField(res, fieldcode); + if (s) { + return PyUnicode_FromString(s); + } + else { + Py_INCREF(Py_None); + return Py_None; + } + } + + return PyObject_GenericGetAttr((PyObject *)self, nameobj); +} + +/* Get the list of notice attributes. */ +static PyObject * +notice_dir(noticeObject *self, PyObject *noargs) +{ + PyObject *attrs; + + attrs = PyObject_Dir(PyObject_Type((PyObject *)self)); + PyObject_CallMethod(attrs, "extend", "[ssssss]", "pgcnx", "severity", + "message", "primary", "detail", "hint"); + + return attrs; +} + +/* Return notice as string in human readable form.
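
A usage sketch of how these notice attributes surface in Python; the set_notice_receiver() and query() connection methods are assumed from pgconn.c, which is not part of this excerpt:

    import pg

    conn = pg.connect(dbname='test')

    def on_notice(notice):
        # attribute names are resolved by notice_getattr() above
        print(notice.severity, notice.primary)
        print(notice.message)  # full text, including detail and hint

    conn.set_notice_receiver(on_notice)        # defined in pgconn.c
    conn.query('drop table if exists nosuch')  # emits a NOTICE
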
*/ +static PyObject * +notice_str(noticeObject *self) +{ + return notice_getattr(self, PyUnicode_FromString("message")); +} + +/* Notice object methods */ +static struct PyMethodDef notice_methods[] = { + {"__dir__", (PyCFunction)notice_dir, METH_NOARGS, NULL}, {NULL, NULL}}; + +static char notice__doc__[] = "PostgreSQL notice object"; + +/* Notice type definition */ +static PyTypeObject noticeType = { + PyVarObject_HEAD_INIT(NULL, 0) "pg.Notice", /* tp_name */ + sizeof(noticeObject), /* tp_basicsize */ + 0, /* tp_itemsize */ + /* methods */ + 0, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + (reprfunc)notice_str, /* tp_str */ + (getattrofunc)notice_getattr, /* tp_getattro */ + PyObject_GenericSetAttr, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + notice__doc__, /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + notice_methods, /* tp_methods */ +}; diff --git a/ext/pgquery.c b/ext/pgquery.c new file mode 100644 index 00000000..b87eba18 --- /dev/null +++ b/ext/pgquery.c @@ -0,0 +1,1002 @@ +/* + * PyGreSQL - a Python interface for the PostgreSQL database. + * + * The query object - this file is a part of the C extension module. + * + * Copyright (c) 2025 by the PyGreSQL Development Team + * + * Please see the LICENSE.TXT file for specific restrictions. + */ + +/* Deallocate the query object. */ +static void +query_dealloc(queryObject *self) +{ + Py_XDECREF(self->pgcnx); + if (self->col_types) { + PyMem_Free(self->col_types); + } + if (self->result) { + PQclear(self->result); + } + + PyObject_Del(self); +} + +/* Return query as string in human readable form. */ +static PyObject * +query_str(queryObject *self) +{ + return format_result(self->result); +} + +/* Return length of a query object. */ +static Py_ssize_t +query_len(PyObject *self) +{ + PyObject *tmp; + Py_ssize_t len; + + tmp = PyLong_FromLong(((queryObject *)self)->max_row); + len = PyLong_AsSsize_t(tmp); + Py_DECREF(tmp); + return len; +} + +/* Return the value in the given column of the current row. */ +static PyObject * +_query_value_in_column(queryObject *self, int column) +{ + char *s; + int type; + + if (PQgetisnull(self->result, self->current_row, column)) { + Py_INCREF(Py_None); + return Py_None; + } + + /* get the string representation of the value */ + /* note: this is always null-terminated text format */ + s = PQgetvalue(self->result, self->current_row, column); + /* get the PyGreSQL type of the column */ + type = self->col_types[column]; + + /* cast the string representation into a Python object */ + if (type & PYGRES_ARRAY) + return cast_array(s, + PQgetlength(self->result, self->current_row, column), + self->encoding, type, NULL, 0); + if (type == PYGRES_BYTEA) + return cast_bytea_text(s); + if (type == PYGRES_OTHER) + return cast_other(s, + PQgetlength(self->result, self->current_row, column), + self->encoding, PQftype(self->result, column), + self->pgcnx->cast_hook); + if (type & PYGRES_TEXT) + return cast_sized_text( + s, PQgetlength(self->result, self->current_row, column), + self->encoding, type); + return cast_unsized_simple(s, type); +} + +/* Return the current row as a tuple.
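
A usage sketch of the sequence behavior implemented by query_len() and query_str(); conn.query() is assumed from pgconn.c:

    import pg

    conn = pg.connect(dbname='test')
    q = conn.query("select 1 as one, 'x' as letter")
    len(q)    # 1, the number of rows (query_len)
    print(q)  # pretty-printed result table (query_str)
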
*/ +static PyObject * +_query_row_as_tuple(queryObject *self) +{ + PyObject *row_tuple = NULL; + int j; + + if (!(row_tuple = PyTuple_New(self->num_fields))) { + return NULL; + } + + for (j = 0; j < self->num_fields; ++j) { + PyObject *val = _query_value_in_column(self, j); + if (!val) { + Py_DECREF(row_tuple); + return NULL; + } + PyTuple_SET_ITEM(row_tuple, j, val); + } + + return row_tuple; +} + +/* Fetch the result if this is an asynchronous query and it has not yet + been fetched in this round-trip. Also mark whether the result should + be kept for this round-trip (e.g. to be used in an iterator). + If this is a normal query result, the query itself will be returned, + otherwise a result value will be returned that shall be passed on. */ +static PyObject * +_get_async_result(queryObject *self, int keep) +{ + int fetch = 0; + + if (self->async) { + if (self->async == 1) { + fetch = 1; + if (keep) { + /* mark query as fetched, do not fetch again */ + self->async = 2; + } + } + else if (!keep) { + self->async = 1; + } + } + + if (fetch) { + int status; + + if (!self->pgcnx) { + PyErr_SetString(PyExc_TypeError, "Connection is not valid"); + return NULL; + } + + Py_BEGIN_ALLOW_THREADS + if (self->result) { + PQclear(self->result); + } + self->result = PQgetResult(self->pgcnx->cnx); + Py_END_ALLOW_THREADS + if (!self->result) { + /* end of result set, return None */ + self->max_row = 0; + self->num_fields = 0; + self->col_types = NULL; + Py_INCREF(Py_None); + return Py_None; + } + + if ((status = PQresultStatus(self->result)) != PGRES_TUPLES_OK) { + PyObject *result = + _conn_non_query_result(status, self->result, self->pgcnx->cnx); + self->result = NULL; /* since this has been already cleared */ + if (!result) { + /* Raise an error. We need to call PQgetResult() to clear the + connection state. This should return NULL the first time. */ + self->result = PQgetResult(self->pgcnx->cnx); + while (self->result) { + PQclear(self->result); + self->result = PQgetResult(self->pgcnx->cnx); + } + Py_DECREF(self->pgcnx); + self->pgcnx = NULL; + } + else if (result == Py_None) { + /* It would be confusing to return None here because the + caller has to call again until we return None. We can't + just consume that final None because we don't know if there + are additional statements following this one, so we return + an empty string where query() would return None. */ + Py_DECREF(result); + result = PyUnicode_FromString(""); + } + return result; + } + + self->max_row = PQntuples(self->result); + self->num_fields = PQnfields(self->result); + self->col_types = get_col_types(self->result, self->num_fields); + if (!self->col_types) { + return NULL; + } + } + else if (self->async == 2 && !self->max_row && !self->num_fields && + !self->col_types) { + Py_INCREF(Py_None); + return Py_None; + } + + /* return the query object itself as sentinel for a normal query result */ + return (PyObject *)self; +} + +/* Return given item from a query object.
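
A sketch of the round-trip behavior that _get_async_result() implements; it assumes the asynchronous send_query() wrapper around PQsendQuery() that this changeset adds in pgconn.c:

    import pg

    conn = pg.connect(dbname='test')
    q = conn.send_query('select generate_series(1, 3)')  # asynchronous
    q.getresult()  # first round-trip fetches: [(1,), (2,), (3,)]
    q.getresult()  # fetches again; None once the result set is exhausted
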
*/ +static PyObject * +query_getitem(PyObject *self, Py_ssize_t i) +{ + queryObject *q = (queryObject *)self; + PyObject *tmp; + long row; + + if ((tmp = _get_async_result(q, 0)) != (PyObject *)self) + return tmp; + + tmp = PyLong_FromSize_t((size_t)i); + row = PyLong_AsLong(tmp); + Py_DECREF(tmp); + + if (row < 0 || row >= q->max_row) { + PyErr_SetNone(PyExc_IndexError); + return NULL; + } + + q->current_row = (int)row; + return _query_row_as_tuple(q); +} + +/* __iter__() method of the queryObject: + Returns the default iterator yielding rows as tuples. */ +static PyObject * +query_iter(queryObject *self) +{ + PyObject *res; + + if ((res = _get_async_result(self, 0)) != (PyObject *)self) + return res; + + self->current_row = 0; + Py_INCREF(self); + return (PyObject *)self; +} + +/* __next__() method of the queryObject: + Returns the current row as a tuple and moves to the next one. */ +static PyObject * +query_next(queryObject *self, PyObject *noargs) +{ + PyObject *row_tuple = NULL; + + if (self->current_row >= self->max_row) { + PyErr_SetNone(PyExc_StopIteration); + return NULL; + } + + row_tuple = _query_row_as_tuple(self); + if (row_tuple) + ++self->current_row; + return row_tuple; +} + +/* Get number of bytes allocated for PGresult object */ +static char query_memsize__doc__[] = + "memsize() -- return number of bytes allocated by query result"; +static PyObject * +query_memsize(queryObject *self, PyObject *noargs) +{ +#ifdef MEMORY_SIZE + return PyLong_FromSize_t(PQresultMemorySize(self->result)); +#else + set_error_msg(NotSupportedError, "Memory size functions not supported"); + return NULL; +#endif /* MEMORY_SIZE */ +} + +/* List field names from query result. */ +static char query_listfields__doc__[] = + "listfields() -- List field names from result"; + +static PyObject * +query_listfields(queryObject *self, PyObject *noargs) +{ + int i; + char *name; + PyObject *fieldstuple, *str; + + /* builds tuple */ + fieldstuple = PyTuple_New(self->num_fields); + if (fieldstuple) { + for (i = 0; i < self->num_fields; ++i) { + name = PQfname(self->result, i); + str = PyUnicode_FromString(name); + PyTuple_SET_ITEM(fieldstuple, i, str); + } + } + return fieldstuple; +} + +/* Get field name from number in last result. */ +static char query_fieldname__doc__[] = + "fieldname(num) -- return name of field from result from its position"; + +static PyObject * +query_fieldname(queryObject *self, PyObject *args) +{ + int i; + char *name; + + /* gets args */ + if (!PyArg_ParseTuple(args, "i", &i)) { + PyErr_SetString(PyExc_TypeError, + "Method fieldname() takes an integer as argument"); + return NULL; + } + + /* checks number validity */ + if (i >= self->num_fields) { + PyErr_SetString(PyExc_ValueError, "Invalid field number"); + return NULL; + } + + /* gets fields name and builds object */ + name = PQfname(self->result, i); + return PyUnicode_FromString(name); +} + +/* Get field number from name in last result. 
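
The field helpers above in Python terms, assuming a connection as before; memsize() raises NotSupportedError when built against a libpq without PQresultMemorySize():

    import pg

    conn = pg.connect(dbname='test')
    q = conn.query("select 1 as one, 'x' as letter")
    q.listfields()     # ('one', 'letter')
    q.fieldname(1)     # 'letter'
    q.fieldnum('one')  # 0
    q.memsize()        # bytes held by the underlying PGresult
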
*/ +static char query_fieldnum__doc__[] = + "fieldnum(name) -- return position in query for field from its name"; + +static PyObject * +query_fieldnum(queryObject *self, PyObject *args) +{ + int num; + char *name; + + /* gets args */ + if (!PyArg_ParseTuple(args, "s", &name)) { + PyErr_SetString(PyExc_TypeError, + "Method fieldnum() takes a string as argument"); + return NULL; + } + + /* gets field number */ + if ((num = PQfnumber(self->result, name)) == -1) { + PyErr_SetString(PyExc_ValueError, "Unknown field"); + return NULL; + } + + return PyLong_FromLong(num); +} + +/* Build a tuple with info for query field with given number. */ +static PyObject * +_query_build_field_info(PGresult *res, int col_num) +{ + PyObject *info; + + info = PyTuple_New(4); + if (info) { + PyTuple_SET_ITEM(info, 0, PyUnicode_FromString(PQfname(res, col_num))); + PyTuple_SET_ITEM(info, 1, + PyLong_FromLong((long)PQftype(res, col_num))); + PyTuple_SET_ITEM(info, 2, PyLong_FromLong(PQfsize(res, col_num))); + PyTuple_SET_ITEM(info, 3, PyLong_FromLong(PQfmod(res, col_num))); + } + return info; +} + +/* Get information on one or all fields of the query result. */ +static char query_fieldinfo__doc__[] = + "fieldinfo([name]) -- return information about field(s) in query result"; + +static PyObject * +query_fieldinfo(queryObject *self, PyObject *args) +{ + PyObject *result, *field = NULL; + int num; + + /* gets args */ + if (!PyArg_ParseTuple(args, "|O", &field)) { + PyErr_SetString(PyExc_TypeError, + "Method fieldinfo() takes one optional argument only"); + return NULL; + } + + /* check optional field arg */ + if (field) { + /* gets field number */ + if (PyBytes_Check(field)) { + num = PQfnumber(self->result, PyBytes_AsString(field)); + } + else if (PyUnicode_Check(field)) { + PyObject *tmp = get_encoded_string(field, self->encoding); + if (!tmp) + return NULL; + num = PQfnumber(self->result, PyBytes_AsString(tmp)); + Py_DECREF(tmp); + } + else if (PyLong_Check(field)) { + num = (int)PyLong_AsLong(field); + } + else { + PyErr_SetString(PyExc_TypeError, + "Field should be given as column number or name"); + return NULL; + } + if (num < 0 || num >= self->num_fields) { + PyErr_SetString(PyExc_IndexError, "Unknown field"); + return NULL; + } + return _query_build_field_info(self->result, num); + } + + if (!(result = PyTuple_New(self->num_fields))) { + return NULL; + } + for (num = 0; num < self->num_fields; ++num) { + PyObject *info = _query_build_field_info(self->result, num); + if (!info) { + Py_DECREF(result); + return NULL; + } + PyTuple_SET_ITEM(result, num, info); + } + return result; +} + +/* Retrieve one row from the result as a tuple. */ +static char query_one__doc__[] = + "one() -- Get one row from the result of a query\n\n" + "Only one row from the result is returned as a tuple of fields.\n" + "This method can be called multiple times to return more rows.\n" + "It returns None if the result does not contain one more row.\n"; + +static PyObject * +query_one(queryObject *self, PyObject *noargs) +{ + PyObject *row_tuple; + + if ((row_tuple = _get_async_result(self, 0)) == (PyObject *)self) { + if (self->current_row >= self->max_row) { + Py_INCREF(Py_None); + return Py_None; + } + + row_tuple = _query_row_as_tuple(self); + if (row_tuple) + ++self->current_row; + } + + return row_tuple; +} + +/* Retrieve the single row from the result as a tuple. 
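
A usage sketch of fieldinfo(); each tuple holds (name, type OID, size, type modifier), and the concrete numbers in the comments are what PostgreSQL typically reports for varchar(10):

    import pg

    conn = pg.connect(dbname='test')
    q = conn.query("select 'abc'::varchar(10) as s")
    q.fieldinfo('s')  # ('s', 1043, -1, 14): name, type OID, size, fmod
    q.fieldinfo()     # one such tuple per column
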
*/ +static char query_single__doc__[] = + "single() -- Get the result of a query as single row\n\n" + "The single row from the query result is returned as a tuple of fields.\n" + "This method returns the same single row when called multiple times.\n" + "It raises an InvalidResultError if the result doesn't have exactly one " + "row,\n" + "which will be of type NoResultError or MultipleResultsError " + "specifically.\n"; + +static PyObject * +query_single(queryObject *self, PyObject *noargs) +{ + PyObject *row_tuple; + + if ((row_tuple = _get_async_result(self, 0)) == (PyObject *)self) { + if (self->max_row != 1) { + if (self->max_row) + set_error_msg(MultipleResultsError, "Multiple results found"); + else + set_error_msg(NoResultError, "No result found"); + return NULL; + } + + self->current_row = 0; + row_tuple = _query_row_as_tuple(self); + if (row_tuple) + ++self->current_row; + } + + return row_tuple; +} + +/* Retrieve the last query result as a list of tuples. */ +static char query_getresult__doc__[] = + "getresult() -- Get the result of a query\n\n" + "The result is returned as a list of rows, each one a tuple of fields\n" + "in the order returned by the server.\n"; + +static PyObject * +query_getresult(queryObject *self, PyObject *noargs) +{ + PyObject *result_list; + int i; + + if ((result_list = _get_async_result(self, 0)) == (PyObject *)self) { + if (!(result_list = PyList_New(self->max_row))) { + return NULL; + } + + for (i = self->current_row = 0; i < self->max_row; ++i) { + PyObject *row_tuple = query_next(self, noargs); + + if (!row_tuple) { + Py_DECREF(result_list); + return NULL; + } + PyList_SET_ITEM(result_list, i, row_tuple); + } + } + + return result_list; +} + +/* Return the current row as a dict. */ +static PyObject * +_query_row_as_dict(queryObject *self) +{ + PyObject *row_dict = NULL; + int j; + + if (!(row_dict = PyDict_New())) { + return NULL; + } + + for (j = 0; j < self->num_fields; ++j) { + PyObject *val = _query_value_in_column(self, j); + + if (!val) { + Py_DECREF(row_dict); + return NULL; + } + PyDict_SetItemString(row_dict, PQfname(self->result, j), val); + Py_DECREF(val); + } + + return row_dict; +} + +/* Return the current row as a dict and move to the next one. */ +static PyObject * +query_next_dict(queryObject *self, PyObject *noargs) +{ + PyObject *row_dict = NULL; + + if (self->current_row >= self->max_row) { + PyErr_SetNone(PyExc_StopIteration); + return NULL; + } + + row_dict = _query_row_as_dict(self); + if (row_dict) + ++self->current_row; + return row_dict; +} + +/* Retrieve one row from the result as a dictionary. */ +static char query_onedict__doc__[] = + "onedict() -- Get one row from the result of a query\n\n" + "Only one row from the result is returned as a dictionary with\n" + "the field names used as the keys.\n" + "This method can be called multiple times to return more rows.\n" + "It returns None if the result does not contain one more row.\n"; + +static PyObject * +query_onedict(queryObject *self, PyObject *noargs) +{ + PyObject *row_dict; + + if ((row_dict = _get_async_result(self, 0)) == (PyObject *)self) { + if (self->current_row >= self->max_row) { + Py_INCREF(Py_None); + return Py_None; + } + + row_dict = _query_row_as_dict(self); + if (row_dict) + ++self->current_row; + } + + return row_dict; +} + +/* Retrieve the single row from the result as a dictionary. 
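
A sketch of the contract spelled out in the docstrings above: one() consumes the result row by row, while single() demands exactly one row:

    import pg

    conn = pg.connect(dbname='test')
    q = conn.query('select generate_series(1, 2)')
    q.one()  # (1,)
    q.one()  # (2,)
    q.one()  # None, the result is exhausted

    conn.query('select 1').single()  # (1,)
    conn.query('select 1 where false').single()  # raises pg.NoResultError
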
*/ +static char query_singledict__doc__[] = + "singledict() -- Get the result of a query as single row\n\n" + "The single row from the query result is returned as a dictionary with\n" + "the field names used as the keys.\n" + "This method returns the same single row when called multiple times.\n" + "It raises an InvalidResultError if the result doesn't have exactly one " + "row,\n" + "which will be of type NoResultError or MultipleResultsError " + "specifically.\n"; + +static PyObject * +query_singledict(queryObject *self, PyObject *noargs) +{ + PyObject *row_dict; + + if ((row_dict = _get_async_result(self, 0)) == (PyObject *)self) { + if (self->max_row != 1) { + if (self->max_row) + set_error_msg(MultipleResultsError, "Multiple results found"); + else + set_error_msg(NoResultError, "No result found"); + return NULL; + } + + self->current_row = 0; + row_dict = _query_row_as_dict(self); + if (row_dict) + ++self->current_row; + } + + return row_dict; +} + +/* Retrieve the last query result as a list of dictionaries. */ +static char query_dictresult__doc__[] = + "dictresult() -- Get the result of a query\n\n" + "The result is returned as a list of rows, each one a dictionary with\n" + "the field names used as the keys.\n"; + +static PyObject * +query_dictresult(queryObject *self, PyObject *noargs) +{ + PyObject *result_list; + int i; + + if ((result_list = _get_async_result(self, 0)) == (PyObject *)self) { + if (!(result_list = PyList_New(self->max_row))) { + return NULL; + } + + for (i = self->current_row = 0; i < self->max_row; ++i) { + PyObject *row_dict = query_next_dict(self, noargs); + + if (!row_dict) { + Py_DECREF(result_list); + return NULL; + } + PyList_SET_ITEM(result_list, i, row_dict); + } + } + + return result_list; +} + +/* Retrieve last result as iterator of dictionaries. */ +static char query_dictiter__doc__[] = + "dictiter() -- Get the result of a query\n\n" + "The result is returned as an iterator of rows, each one a dictionary\n" + "with the field names used as the keys.\n"; + +static PyObject * +query_dictiter(queryObject *self, PyObject *noargs) +{ + PyObject *res; + + if (!dictiter) { + return query_dictresult(self, noargs); + } + + if ((res = _get_async_result(self, 1)) != (PyObject *)self) + return res; + + return PyObject_CallFunction(dictiter, "(O)", self); +} + +/* Retrieve one row from the result as a named tuple. */ +static char query_onenamed__doc__[] = + "onenamed() -- Get one row from the result of a query\n\n" + "Only one row from the result is returned as a named tuple of fields.\n" + "This method can be called multiple times to return more rows.\n" + "It returns None if the result does not contain one more row.\n"; + +static PyObject * +query_onenamed(queryObject *self, PyObject *noargs) +{ + PyObject *res; + + if (!namednext) { + return query_one(self, noargs); + } + + if ((res = _get_async_result(self, 1)) != (PyObject *)self) + return res; + + if (self->current_row >= self->max_row) { + Py_INCREF(Py_None); + return Py_None; + } + + return PyObject_CallFunction(namednext, "(O)", self); +} + +/* Retrieve the single row from the result as a named tuple. 
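+ * + * Usage sketch (illustrative; db as above): + * + *     row = db.query("select 1 as id, 'joe' as name").singlenamed() + *     row.id, row.name   # -> (1, 'joe') + *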
*/ +static char query_singlenamed__doc__[] = + "singlenamed() -- Get the result of a query as single row\n\n" + "The single row from the query result is returned as named tuple of " + "fields.\n" + "This method returns the same single row when called multiple times.\n" + "It raises an InvalidResultError if the result doesn't have exactly one " + "row,\n" + "which will be of type NoResultError or MultipleResultsError " + "specifically.\n"; + +static PyObject * +query_singlenamed(queryObject *self, PyObject *noargs) +{ + PyObject *res; + + if (!namednext) { + return query_single(self, noargs); + } + + if ((res = _get_async_result(self, 1)) != (PyObject *)self) + return res; + + if (self->max_row != 1) { + if (self->max_row) + set_error_msg(MultipleResultsError, "Multiple results found"); + else + set_error_msg(NoResultError, "No result found"); + return NULL; + } + + self->current_row = 0; + return PyObject_CallFunction(namednext, "(O)", self); +} + +/* Retrieve last result as list of named tuples. */ +static char query_namedresult__doc__[] = + "namedresult() -- Get the result of a query\n\n" + "The result is returned as a list of rows, each one a named tuple of " + "fields\n" + "in the order returned by the server.\n"; + +static PyObject * +query_namedresult(queryObject *self, PyObject *noargs) +{ + PyObject *res, *res_list; + + if (!namediter) { + return query_getresult(self, noargs); + } + + if ((res_list = _get_async_result(self, 1)) == (PyObject *)self) { + res = PyObject_CallFunction(namediter, "(O)", self); + if (!res) + return NULL; + if (PyList_Check(res)) + return res; + res_list = PySequence_List(res); + Py_DECREF(res); + } + + return res_list; +} + +/* Retrieve last result as iterator of named tuples. */ +static char query_namediter__doc__[] = + "namediter() -- Get the result of a query\n\n" + "The result is returned as an iterator of rows, each one a named tuple\n" + "of fields in the order returned by the server.\n"; + +static PyObject * +query_namediter(queryObject *self, PyObject *noargs) +{ + PyObject *res, *res_iter; + + if (!namediter) { + return query_iter(self); + } + + if ((res_iter = _get_async_result(self, 1)) == (PyObject *)self) { + res = PyObject_CallFunction(namediter, "(O)", self); + if (!res) + return NULL; + if (!PyList_Check(res)) + return res; + /* build an iterator over the materialized list of rows */ + res_iter = (Py_TYPE(res)->tp_iter)(res); + Py_DECREF(res); + } + + return res_iter; +} + +/* Retrieve the last query result as a list of scalar values. */ +static char query_scalarresult__doc__[] = + "scalarresult() -- Get query result as scalars\n\n" + "The result is returned as a list of scalar values where the values\n" + "are the first fields of the rows in the order returned by the server.\n"; + +static PyObject * +query_scalarresult(queryObject *self, PyObject *noargs) +{ + PyObject *result_list; + + if ((result_list = _get_async_result(self, 0)) == (PyObject *)self) { + if (!self->num_fields) { + set_error_msg(ProgrammingError, "No fields in result"); + return NULL; + } + + if (!(result_list = PyList_New(self->max_row))) { + return NULL; + } + + for (self->current_row = 0; self->current_row < self->max_row; + ++self->current_row) { + PyObject *value = _query_value_in_column(self, 0); + + if (!value) { + Py_DECREF(result_list); + return NULL; + } + PyList_SET_ITEM(result_list, self->current_row, value); + } + } + + return result_list; +} + +/* Retrieve the last query result as iterator of scalar values. 
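+ * + * Usage sketch (illustrative; db as above): + * + *     q = db.query("select generate_series(1, 3)") + *     q.scalarresult()        # -> [1, 2, 3] + *     list(q.scalariter())    # same values, produced lazily + *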
*/ +static char query_scalariter__doc__[] = + "scalariter() -- Get query result as scalars\n\n" + "The result is returned as an iterator of scalar values where the values\n" + "are the first fields of the rows in the order returned by the server.\n"; + +static PyObject * +query_scalariter(queryObject *self, PyObject *noargs) +{ + PyObject *res; + + if (!scalariter) { + return query_scalarresult(self, noargs); + } + + if ((res = _get_async_result(self, 1)) != (PyObject *)self) + return res; + + if (!self->num_fields) { + set_error_msg(ProgrammingError, "No fields in result"); + return NULL; + } + + return PyObject_CallFunction(scalariter, "(O)", self); +} + +/* Retrieve one result as scalar value. */ +static char query_onescalar__doc__[] = + "onescalar() -- Get one scalar value from the result of a query\n\n" + "Returns the first field of the next row from the result as a scalar " + "value.\n" + "This method can be called multiple times to return more rows as " + "scalars.\n" + "It returns None if the result does not contain one more row.\n"; + +static PyObject * +query_onescalar(queryObject *self, PyObject *noargs) +{ + PyObject *value; + + if ((value = _get_async_result(self, 0)) == (PyObject *)self) { + if (!self->num_fields) { + set_error_msg(ProgrammingError, "No fields in result"); + return NULL; + } + + if (self->current_row >= self->max_row) { + Py_INCREF(Py_None); + return Py_None; + } + + value = _query_value_in_column(self, 0); + if (value) + ++self->current_row; + } + + return value; +} + +/* Retrieve the single row from the result as a scalar value. */ +static char query_singlescalar__doc__[] = + "singlescalar() -- Get scalar value from single result of a query\n\n" + "Returns the first field of the single row from the result as a scalar " + "value.\n" + "This method returns the same single row when called multiple times.\n" + "It raises an InvalidResultError if the result doesn't have exactly one " + "row,\n" + "which will be of type NoResultError or MultipleResultsError " + "specifically.\n"; + +static PyObject * +query_singlescalar(queryObject *self, PyObject *noargs) +{ + PyObject *value; + + if ((value = _get_async_result(self, 0)) == (PyObject *)self) { + if (!self->num_fields) { + set_error_msg(ProgrammingError, "No fields in result"); + return NULL; + } + + if (self->max_row != 1) { + if (self->max_row) + set_error_msg(MultipleResultsError, "Multiple results found"); + else + set_error_msg(NoResultError, "No result found"); + return NULL; + } + + self->current_row = 0; + value = _query_value_in_column(self, 0); + if (value) + ++self->current_row; + } + + return value; +} + +/* Query sequence protocol methods */ +static PySequenceMethods query_sequence_methods = { + (lenfunc)query_len, /* sq_length */ + 0, /* sq_concat */ + 0, /* sq_repeat */ + (ssizeargfunc)query_getitem, /* sq_item */ + 0, /* sq_ass_item */ + 0, /* sq_contains */ + 0, /* sq_inplace_concat */ + 0, /* sq_inplace_repeat */ +}; + +/* Query object methods */ +static struct PyMethodDef query_methods[] = { + {"getresult", (PyCFunction)query_getresult, METH_NOARGS, + query_getresult__doc__}, + {"dictresult", (PyCFunction)query_dictresult, METH_NOARGS, + query_dictresult__doc__}, + {"dictiter", (PyCFunction)query_dictiter, METH_NOARGS, + query_dictiter__doc__}, + {"namedresult", (PyCFunction)query_namedresult, METH_NOARGS, + query_namedresult__doc__}, + {"namediter", (PyCFunction)query_namediter, METH_NOARGS, + query_namediter__doc__}, + {"one", (PyCFunction)query_one, METH_NOARGS, query_one__doc__}, + {"single", 
(PyCFunction)query_single, METH_NOARGS, query_single__doc__}, + {"onedict", (PyCFunction)query_onedict, METH_NOARGS, query_onedict__doc__}, + {"singledict", (PyCFunction)query_singledict, METH_NOARGS, + query_singledict__doc__}, + {"onenamed", (PyCFunction)query_onenamed, METH_NOARGS, + query_onenamed__doc__}, + {"singlenamed", (PyCFunction)query_singlenamed, METH_NOARGS, + query_singlenamed__doc__}, + {"scalarresult", (PyCFunction)query_scalarresult, METH_NOARGS, + query_scalarresult__doc__}, + {"scalariter", (PyCFunction)query_scalariter, METH_NOARGS, + query_scalariter__doc__}, + {"onescalar", (PyCFunction)query_onescalar, METH_NOARGS, + query_onescalar__doc__}, + {"singlescalar", (PyCFunction)query_singlescalar, METH_NOARGS, + query_singlescalar__doc__}, + {"fieldname", (PyCFunction)query_fieldname, METH_VARARGS, + query_fieldname__doc__}, + {"fieldnum", (PyCFunction)query_fieldnum, METH_VARARGS, + query_fieldnum__doc__}, + {"listfields", (PyCFunction)query_listfields, METH_NOARGS, + query_listfields__doc__}, + {"fieldinfo", (PyCFunction)query_fieldinfo, METH_VARARGS, + query_fieldinfo__doc__}, + {"memsize", (PyCFunction)query_memsize, METH_NOARGS, query_memsize__doc__}, + {NULL, NULL}}; + +static char query__doc__[] = "PyGreSQL query object"; + +/* Query type definition */ +static PyTypeObject queryType = { + PyVarObject_HEAD_INIT(NULL, 0) "pg.Query", /* tp_name */ + sizeof(queryObject), /* tp_basicsize */ + 0, /* tp_itemsize */ + /* methods */ + (destructor)query_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + &query_sequence_methods, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + (reprfunc)query_str, /* tp_str */ + PyObject_GenericGetAttr, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + query__doc__, /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + (getiterfunc)query_iter, /* tp_iter */ + (iternextfunc)query_next, /* tp_iternext */ + query_methods, /* tp_methods */ +}; diff --git a/pgsource.c b/ext/pgsource.c similarity index 67% rename from pgsource.c rename to ext/pgsource.c index 9ab94e36..bbec2f86 100644 --- a/pgsource.c +++ b/ext/pgsource.c @@ -3,7 +3,7 @@ * * The source object - this file is part a of the C extension module. * - * Copyright (c) 2020 by the PyGreSQL Development Team + * Copyright (c) 2025 by the PyGreSQL Development Team * * Please see the LICENSE.TXT file for specific restrictions. 
*/ @@ -28,10 +28,10 @@ source_str(sourceObject *self) return format_result(self->result); case RESULT_DDL: case RESULT_DML: - return PyStr_FromString(PQcmdStatus(self->result)); + return PyUnicode_FromString(PQcmdStatus(self->result)); case RESULT_EMPTY: default: - return PyStr_FromString("(empty PostgreSQL source object)"); + return PyUnicode_FromString("(empty PostgreSQL source object)"); } } @@ -65,13 +65,13 @@ _check_source_obj(sourceObject *self, int level) static PyObject * source_getattr(sourceObject *self, PyObject *nameobj) { - const char *name = PyStr_AsString(nameobj); + const char *name = PyUnicode_AsUTF8(nameobj); /* pg connection object */ if (!strcmp(name, "pgcnx")) { if (_check_source_obj(self, 0)) { Py_INCREF(self->pgcnx); - return (PyObject *) (self->pgcnx); + return (PyObject *)(self->pgcnx); } Py_INCREF(Py_None); return Py_None; @@ -79,22 +79,22 @@ source_getattr(sourceObject *self, PyObject *nameobj) /* arraysize */ if (!strcmp(name, "arraysize")) - return PyInt_FromLong(self->arraysize); + return PyLong_FromLong(self->arraysize); /* resulttype */ if (!strcmp(name, "resulttype")) - return PyInt_FromLong(self->result_type); + return PyLong_FromLong(self->result_type); /* ntuples */ if (!strcmp(name, "ntuples")) - return PyInt_FromLong(self->max_row); + return PyLong_FromLong(self->max_row); /* nfields */ if (!strcmp(name, "nfields")) - return PyInt_FromLong(self->num_fields); + return PyLong_FromLong(self->num_fields); /* seeks name in methods (fallback) */ - return PyObject_GenericGetAttr((PyObject *) self, nameobj); + return PyObject_GenericGetAttr((PyObject *)self, nameobj); } /* Set source object attributes. */ @@ -103,12 +103,12 @@ source_setattr(sourceObject *self, char *name, PyObject *v) { /* arraysize */ if (!strcmp(name, "arraysize")) { - if (!PyInt_Check(v)) { + if (!PyLong_Check(v)) { PyErr_SetString(PyExc_TypeError, "arraysize must be integer"); return -1; } - self->arraysize = PyInt_AsLong(v); + self->arraysize = PyLong_AsLong(v); return 0; } @@ -119,8 +119,9 @@ source_setattr(sourceObject *self, char *name, PyObject *v) /* Close object. */ static char source_close__doc__[] = -"close() -- close query object without deleting it\n\n" -"All instances of the query object can no longer be used after this call.\n"; + "close() -- close source object without deleting it\n\n" + "All instances of the source object can no longer be used after this " + "call.\n"; static PyObject * source_close(sourceObject *self, PyObject *noargs) @@ -141,15 +142,15 @@ source_close(sourceObject *self, PyObject *noargs) /* Database query. */ static char source_execute__doc__[] = -"execute(sql) -- execute a SQL statement (string)\n\n" -"On success, this call returns the number of affected rows, or None\n" -"for DQL (SELECT, ...) statements. The fetch (fetch(), fetchone()\n" -"and fetchall()) methods can be used to get result rows.\n"; + "execute(sql) -- execute a SQL statement (string)\n\n" + "On success, this call returns the number of affected rows, or None\n" + "for DQL (SELECT, ...) statements. 
The fetch (fetch(), fetchone()\n" + "and fetchall()) methods can be used to get result rows.\n"; static PyObject * source_execute(sourceObject *self, PyObject *sql) { - PyObject *tmp_obj = NULL; /* auxiliary string object */ + PyObject *tmp_obj = NULL; /* auxiliary string object */ char *query; int encoding; @@ -165,7 +166,8 @@ source_execute(sourceObject *self, PyObject *sql) } else if (PyUnicode_Check(sql)) { tmp_obj = get_encoded_string(sql, encoding); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ query = PyBytes_AsString(tmp_obj); } else { @@ -205,30 +207,29 @@ source_execute(sourceObject *self, PyObject *sql) /* checks result status */ switch (PQresultStatus(self->result)) { /* query succeeded */ - case PGRES_TUPLES_OK: /* DQL: returns None (DB-SIG compliant) */ + case PGRES_TUPLES_OK: /* DQL: returns None (DB-SIG compliant) */ self->result_type = RESULT_DQL; self->max_row = PQntuples(self->result); self->num_fields = PQnfields(self->result); Py_INCREF(Py_None); return Py_None; - case PGRES_COMMAND_OK: /* other requests */ + case PGRES_COMMAND_OK: /* other requests */ case PGRES_COPY_OUT: - case PGRES_COPY_IN: - { - long num_rows; - char *tmp; - - tmp = PQcmdTuples(self->result); - if (tmp[0]) { - self->result_type = RESULT_DML; - num_rows = atol(tmp); - } - else { - self->result_type = RESULT_DDL; - num_rows = -1; - } - return PyInt_FromLong(num_rows); + case PGRES_COPY_IN: { + long num_rows; + char *tmp; + + tmp = PQcmdTuples(self->result); + if (tmp[0]) { + self->result_type = RESULT_DML; + num_rows = atol(tmp); + } + else { + self->result_type = RESULT_DDL; + num_rows = -1; } + return PyLong_FromLong(num_rows); + } /* query failed */ case PGRES_EMPTY_QUERY: @@ -238,7 +239,7 @@ source_execute(sourceObject *self, PyObject *sql) case PGRES_FATAL_ERROR: case PGRES_NONFATAL_ERROR: set_error(ProgrammingError, "Cannot execute command", - self->pgcnx->cnx, self->result); + self->pgcnx->cnx, self->result); break; default: set_error_msg(InternalError, @@ -254,7 +255,7 @@ source_execute(sourceObject *self, PyObject *sql) /* Get oid status for last query (valid for INSERTs, 0 for other). */ static char source_oidstatus__doc__[] = -"oidstatus() -- return oid of last inserted row (if available)"; + "oidstatus() -- return oid of last inserted row (if available)"; static PyObject * source_oidstatus(sourceObject *self, PyObject *noargs) @@ -272,14 +273,14 @@ source_oidstatus(sourceObject *self, PyObject *noargs) return Py_None; } - return PyInt_FromLong(oid); + return PyLong_FromLong((long)oid); } /* Fetch rows from last result. 
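+ * + * Usage sketch (illustrative; assumes cnx = pg.connect(dbname='test') and + * that a table named 'mytable' exists): + * + *     src = cnx.source() + *     src.execute("select * from mytable") + *     rows = src.fetch(2)    # next two rows as a list of tuples + *     rest = src.fetch(-1)   # -1 fetches all remaining rows + *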
*/ static char source_fetch__doc__[] = -"fetch(num) -- return the next num rows from the last result in a list\n\n" -"If num parameter is omitted arraysize attribute value is used.\n" -"If size equals -1, all rows are fetched.\n"; + "fetch(num) -- return the next num rows from the last result in a list\n\n" + "If num parameter is omitted arraysize attribute value is used.\n" + "If size equals -1, all rows are fetched.\n"; static PyObject * source_fetch(sourceObject *self, PyObject *args) @@ -287,9 +288,7 @@ source_fetch(sourceObject *self, PyObject *args) PyObject *res_list; int i, k; long size; -#if IS_PY3 int encoding; -#endif /* checks validity */ if (!_check_source_obj(self, CHECK_RESULT | CHECK_DQL | CHECK_CNX)) { @@ -311,11 +310,10 @@ source_fetch(sourceObject *self, PyObject *args) } /* allocate list for result */ - if (!(res_list = PyList_New(0))) return NULL; + if (!(res_list = PyList_New(0))) + return NULL; -#if IS_PY3 encoding = self->encoding; -#endif /* builds result */ for (i = 0, k = self->current_row; i < size; ++i, ++k) { @@ -323,7 +321,8 @@ source_fetch(sourceObject *self, PyObject *args) int j; if (!(rowtuple = PyTuple_New(self->num_fields))) { - Py_DECREF(res_list); return NULL; + Py_DECREF(res_list); + return NULL; } for (j = 0; j < self->num_fields; ++j) { @@ -336,21 +335,22 @@ source_fetch(sourceObject *self, PyObject *args) else { char *s = PQgetvalue(self->result, k, j); Py_ssize_t size = PQgetlength(self->result, k, j); -#if IS_PY3 if (PQfformat(self->result, j) == 0) { /* textual format */ str = get_decoded_string(s, size, encoding); if (!str) /* cannot decode */ str = PyBytes_FromStringAndSize(s, size); } - else -#endif - str = PyBytes_FromStringAndSize(s, size); + else { + str = PyBytes_FromStringAndSize(s, size); + } } PyTuple_SET_ITEM(rowtuple, j, str); } if (PyList_Append(res_list, rowtuple)) { - Py_DECREF(rowtuple); Py_DECREF(res_list); return NULL; + Py_DECREF(rowtuple); + Py_DECREF(res_list); + return NULL; } Py_DECREF(rowtuple); } @@ -392,7 +392,7 @@ _source_move(sourceObject *self, int move) /* Move to first result row. */ static char source_movefirst__doc__[] = -"movefirst() -- move to first result row"; + "movefirst() -- move to first result row"; static PyObject * source_movefirst(sourceObject *self, PyObject *noargs) @@ -402,7 +402,7 @@ source_movefirst(sourceObject *self, PyObject *noargs) /* Move to last result row. */ static char source_movelast__doc__[] = -"movelast() -- move to last valid result row"; + "movelast() -- move to last valid result row"; static PyObject * source_movelast(sourceObject *self, PyObject *noargs) @@ -411,8 +411,7 @@ source_movelast(sourceObject *self, PyObject *noargs) } /* Move to next result row. */ -static char source_movenext__doc__[] = -"movenext() -- move to next result row"; +static char source_movenext__doc__[] = "movenext() -- move to next result row"; static PyObject * source_movenext(sourceObject *self, PyObject *noargs) @@ -422,7 +421,7 @@ source_movenext(sourceObject *self, PyObject *noargs) /* Move to previous result row. */ static char source_moveprev__doc__[] = -"moveprev() -- move to previous result row"; + "moveprev() -- move to previous result row"; static PyObject * source_moveprev(sourceObject *self, PyObject *noargs) @@ -432,17 +431,17 @@ source_moveprev(sourceObject *self, PyObject *noargs) /* Put copy data. 
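+ * + * Usage sketch (illustrative; src as above; that passing None ends the + * copy normally is an assumption here, while an exception instance aborts + * it as handled below): + * + *     src.execute("copy mytable from stdin") + *     src.putdata("1\tjoe\n") + *     src.putdata("2\tann\n") + *     src.putdata(None)   # assumed end marker; returns the copied row count + *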
*/ static char source_putdata__doc__[] = -"putdata(buffer) -- send data to server during copy from stdin"; + "putdata(buffer) -- send data to server during copy from stdin"; static PyObject * source_putdata(sourceObject *self, PyObject *buffer) { - PyObject *tmp_obj = NULL; /* an auxiliary object */ - char *buf; /* the buffer as encoded string */ - Py_ssize_t nbytes; /* length of string */ - char *errormsg = NULL; /* error message */ - int res; /* direct result of the operation */ - PyObject *ret; /* return value */ + PyObject *tmp_obj = NULL; /* an auxiliary object */ + char *buf; /* the buffer as encoded string */ + Py_ssize_t nbytes; /* length of string */ + char *errormsg = NULL; /* error message */ + int res; /* direct result of the operation */ + PyObject *ret; /* return value */ /* checks validity */ if (!_check_source_obj(self, CHECK_CNX)) { @@ -464,9 +463,10 @@ source_putdata(sourceObject *self, PyObject *buffer) } else if (PyUnicode_Check(buffer)) { /* or pass a unicode string */ - tmp_obj = get_encoded_string( - buffer, PQclientEncoding(self->pgcnx->cnx)); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + tmp_obj = + get_encoded_string(buffer, PQclientEncoding(self->pgcnx->cnx)); + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ PyBytes_AsStringAndSize(tmp_obj, &buf, &nbytes); } else if (PyErr_GivenExceptionMatches(buffer, PyExc_BaseException)) { @@ -475,10 +475,11 @@ source_putdata(sourceObject *self, PyObject *buffer) if (PyUnicode_Check(tmp_obj)) { PyObject *obj = tmp_obj; - tmp_obj = get_encoded_string( - obj, PQclientEncoding(self->pgcnx->cnx)); + tmp_obj = + get_encoded_string(obj, PQclientEncoding(self->pgcnx->cnx)); Py_DECREF(obj); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ + if (!tmp_obj) + return NULL; /* pass the UnicodeEncodeError */ } errormsg = PyBytes_AsString(tmp_obj); buf = NULL; @@ -492,8 +493,7 @@ source_putdata(sourceObject *self, PyObject *buffer) /* checks validity */ if (!_check_source_obj(self, CHECK_CNX | CHECK_RESULT) || - PQresultStatus(self->result) != PGRES_COPY_IN) - { + PQresultStatus(self->result) != PGRES_COPY_IN) { PyErr_SetString(PyExc_IOError, "Connection is invalid or not in copy_in state"); Py_XDECREF(tmp_obj); @@ -501,7 +501,7 @@ source_putdata(sourceObject *self, PyObject *buffer) } if (buf) { - res = nbytes ? PQputCopyData(self->pgcnx->cnx, buf, (int) nbytes) : 1; + res = nbytes ? PQputCopyData(self->pgcnx->cnx, buf, (int)nbytes) : 1; } else { res = PQputCopyEnd(self->pgcnx->cnx, errormsg); @@ -518,7 +518,7 @@ source_putdata(sourceObject *self, PyObject *buffer) ret = Py_None; Py_INCREF(ret); } - else { /* copy is done */ + else { /* copy is done */ PGresult *result; /* final result of the operation */ Py_BEGIN_ALLOW_THREADS; @@ -531,10 +531,11 @@ source_putdata(sourceObject *self, PyObject *buffer) tmp = PQcmdTuples(result); num_rows = tmp[0] ? atol(tmp) : -1; - ret = PyInt_FromLong(num_rows); + ret = PyLong_FromLong(num_rows); } else { - if (!errormsg) errormsg = PQerrorMessage(self->pgcnx->cnx); + if (!errormsg) + errormsg = PQerrorMessage(self->pgcnx->cnx); PyErr_SetString(PyExc_IOError, errormsg); ret = NULL; } @@ -549,15 +550,15 @@ source_putdata(sourceObject *self, PyObject *buffer) /* Get copy data. 
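+ * + * Usage sketch (illustrative; src as above); getdata() returns row data + * until the copy is done, then the row count as an int, as coded below: + * + *     src.execute("copy mytable to stdout") + *     while True: + *         data = src.getdata(True)    # decode=True yields str, not bytes + *         if isinstance(data, int):   # the row count signals completion + *             break + *         handle(data)                # 'handle' is a placeholder + *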
*/ static char source_getdata__doc__[] = -"getdata(decode) -- receive data to server during copy to stdout"; + "getdata(decode) -- receive data from server during copy to stdout"; static PyObject * source_getdata(sourceObject *self, PyObject *args) { - int *decode = 0; /* decode flag */ - char *buffer; /* the copied buffer as encoded byte string */ - Py_ssize_t nbytes; /* length of the byte string */ - PyObject *ret; /* return value */ + int *decode = 0; /* decode flag */ + char *buffer; /* the copied buffer as encoded byte string */ + Py_ssize_t nbytes; /* length of the byte string */ + PyObject *ret; /* return value */ /* checks validity */ if (!_check_source_obj(self, CHECK_CNX)) { @@ -575,8 +576,7 @@ /* checks validity */ if (!_check_source_obj(self, CHECK_CNX | CHECK_RESULT) || - PQresultStatus(self->result) != PGRES_COPY_OUT) - { + PQresultStatus(self->result) != PGRES_COPY_OUT) { PyErr_SetString(PyExc_IOError, "Connection is invalid or not in copy_out state"); return NULL; @@ -589,7 +589,7 @@ return NULL; } - if (nbytes == -1) { /* copy is done */ + if (nbytes == -1) { /* copy is done */ PGresult *result; /* final result of the operation */ Py_BEGIN_ALLOW_THREADS; @@ -602,7 +602,7 @@ tmp = PQcmdTuples(result); num_rows = tmp[0] ? atol(tmp) : -1; - ret = PyInt_FromLong(num_rows); + ret = PyLong_FromLong(num_rows); } else { PyErr_SetString(PyExc_IOError, PQerrorMessage(self->pgcnx->cnx)); @@ -614,9 +614,9 @@ self->result_type = RESULT_EMPTY; } else { /* a row has been returned */ - ret = decode ? get_decoded_string( - buffer, nbytes, PQclientEncoding(self->pgcnx->cnx)) : - PyBytes_FromStringAndSize(buffer, nbytes); + ret = decode ? get_decoded_string(buffer, nbytes, + PQclientEncoding(self->pgcnx->cnx)) + : PyBytes_FromStringAndSize(buffer, nbytes); PQfreemem(buffer); } @@ -634,11 +634,11 @@ return -1; /* gets field number */ - if (PyStr_Check(param)) { + if (PyUnicode_Check(param)) { num = PQfnumber(self->result, PyBytes_AsString(param)); } - else if (PyInt_Check(param)) { - num = (int) PyInt_AsLong(param); + else if (PyLong_Check(param)) { + num = (int)PyLong_AsLong(param); } else { PyErr_SetString(PyExc_TypeError, usage); @@ -667,22 +667,21 @@ } /* affects field information */ - PyTuple_SET_ITEM(result, 0, PyInt_FromLong(num)); + PyTuple_SET_ITEM(result, 0, PyLong_FromLong(num)); PyTuple_SET_ITEM(result, 1, - PyStr_FromString(PQfname(self->result, num))); + PyUnicode_FromString(PQfname(self->result, num))); PyTuple_SET_ITEM(result, 2, - PyInt_FromLong(PQftype(self->result, num))); - PyTuple_SET_ITEM(result, 3, - PyInt_FromLong(PQfsize(self->result, num))); - PyTuple_SET_ITEM(result, 4, - PyInt_FromLong(PQfmod(self->result, num))); + PyLong_FromLong((long)PQftype(self->result, num))); + PyTuple_SET_ITEM(result, 3, PyLong_FromLong(PQfsize(self->result, num))); + PyTuple_SET_ITEM(result, 4, PyLong_FromLong(PQfmod(self->result, num))); return result; } /* Lists fields info. 
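+ * + * Usage sketch (illustrative; src as above): + * + *     src.execute("select 1 as id") + *     src.listinfo()   # one 5-tuple per field: + *                      # (position, name, type oid, size, type modifier) + *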
*/ static char source_listinfo__doc__[] = -"listinfo() -- get information for all fields (position, name, type oid)"; + "listinfo() -- get information for all fields" + " (position, name, type oid, size, type modifier)"; static PyObject * source_listInfo(sourceObject *self, PyObject *noargs) @@ -715,7 +714,7 @@ /* List fields information for last result. */ static char source_fieldinfo__doc__[] = -"fieldinfo(desc) -- get specified field info (position, name, type oid)"; + "fieldinfo(desc) -- get specified field info" + " (position, name, type oid, size, type modifier)"; static PyObject * source_fieldinfo(sourceObject *self, PyObject *desc) @@ -724,9 +723,9 @@ /* checks args and validity */ if ((num = _source_fieldindex( - self, desc, - "Method fieldinfo() needs a string or integer as argument")) == -1) - { + self, desc, + "Method fieldinfo() needs a string or integer as argument")) == + -1) { return NULL; } @@ -736,7 +735,7 @@ /* Retrieve field value. */ static char source_field__doc__[] = -"field(desc) -- return specified field value"; + "field(desc) -- return specified field value"; static PyObject * source_field(sourceObject *self, PyObject *desc) @@ -745,13 +744,12 @@ /* checks args and validity */ if ((num = _source_fieldindex( - self, desc, - "Method field() needs a string or integer as argument")) == -1) - { + self, desc, + "Method field() needs a string or integer as argument")) == -1) { return NULL; } - return PyStr_FromString( + return PyUnicode_FromString( PQgetvalue(self->result, self->current_row, num)); } @@ -761,78 +759,70 @@ source_dir(connObject *self, PyObject *noargs) { PyObject *attrs; - attrs = PyObject_Dir(PyObject_Type((PyObject *) self)); - PyObject_CallMethod( - attrs, "extend", "[sssss]", - "pgcnx", "arraysize", "resulttype", "ntuples", "nfields"); + attrs = PyObject_Dir(PyObject_Type((PyObject *)self)); + PyObject_CallMethod(attrs, "extend", "[sssss]", "pgcnx", "arraysize", + "resulttype", "ntuples", "nfields"); return attrs; } /* Source object methods */ static PyMethodDef source_methods[] = { - {"__dir__", (PyCFunction) source_dir, METH_NOARGS, NULL}, - - {"close", (PyCFunction) source_close, - METH_NOARGS, source_close__doc__}, - {"execute", (PyCFunction) source_execute, - METH_O, source_execute__doc__}, - {"oidstatus", (PyCFunction) source_oidstatus, - METH_NOARGS, source_oidstatus__doc__}, - {"fetch", (PyCFunction) source_fetch, - METH_VARARGS, source_fetch__doc__}, - {"movefirst", (PyCFunction) source_movefirst, - METH_NOARGS, source_movefirst__doc__}, - {"movelast", (PyCFunction) source_movelast, - METH_NOARGS, source_movelast__doc__}, - {"movenext", (PyCFunction) source_movenext, - METH_NOARGS, source_movenext__doc__}, - {"moveprev", (PyCFunction) source_moveprev, - METH_NOARGS, source_moveprev__doc__}, - {"putdata", (PyCFunction) source_putdata, - METH_O, source_putdata__doc__}, - {"getdata", (PyCFunction) source_getdata, - METH_VARARGS, source_getdata__doc__}, - {"field", (PyCFunction) source_field, - METH_O, source_field__doc__}, - {"fieldinfo", (PyCFunction) source_fieldinfo, - METH_O, source_fieldinfo__doc__}, - {"listinfo", (PyCFunction) source_listInfo, - METH_NOARGS, source_listinfo__doc__}, - {NULL, NULL} -}; + {"__dir__", (PyCFunction)source_dir, METH_NOARGS, NULL}, + + {"close", (PyCFunction)source_close, METH_NOARGS, source_close__doc__}, + {"execute", 
(PyCFunction)source_execute, METH_O, source_execute__doc__}, + {"oidstatus", (PyCFunction)source_oidstatus, METH_NOARGS, + source_oidstatus__doc__}, + {"fetch", (PyCFunction)source_fetch, METH_VARARGS, source_fetch__doc__}, + {"movefirst", (PyCFunction)source_movefirst, METH_NOARGS, + source_movefirst__doc__}, + {"movelast", (PyCFunction)source_movelast, METH_NOARGS, + source_movelast__doc__}, + {"movenext", (PyCFunction)source_movenext, METH_NOARGS, + source_movenext__doc__}, + {"moveprev", (PyCFunction)source_moveprev, METH_NOARGS, + source_moveprev__doc__}, + {"putdata", (PyCFunction)source_putdata, METH_O, source_putdata__doc__}, + {"getdata", (PyCFunction)source_getdata, METH_VARARGS, + source_getdata__doc__}, + {"field", (PyCFunction)source_field, METH_O, source_field__doc__}, + {"fieldinfo", (PyCFunction)source_fieldinfo, METH_O, + source_fieldinfo__doc__}, + {"listinfo", (PyCFunction)source_listInfo, METH_NOARGS, + source_listinfo__doc__}, + {NULL, NULL}}; static char source__doc__[] = "PyGreSQL source object"; /* Source type definition */ static PyTypeObject sourceType = { - PyVarObject_HEAD_INIT(NULL, 0) - "pgdb.Source", /* tp_name */ - sizeof(sourceObject), /* tp_basicsize */ - 0, /* tp_itemsize */ + PyVarObject_HEAD_INIT(NULL, 0) "pgdb.Source", /* tp_name */ + sizeof(sourceObject), /* tp_basicsize */ + 0, /* tp_itemsize */ /* methods */ - (destructor) source_dealloc, /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - (setattrfunc) source_setattr, /* tp_setattr */ - 0, /* tp_compare */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - (reprfunc) source_str, /* tp_str */ - (getattrofunc) source_getattr, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - source__doc__, /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - source_methods, /* tp_methods */ + (destructor)source_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + (setattrfunc)source_setattr, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + (reprfunc)source_str, /* tp_str */ + (getattrofunc)source_getattr, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + source__doc__, /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + source_methods, /* tp_methods */ }; diff --git a/pgtypes.h b/ext/pgtypes.h similarity index 100% rename from pgtypes.h rename to ext/pgtypes.h diff --git a/pg.py b/pg.py deleted file mode 100644 index b0fa6674..00000000 --- a/pg.py +++ /dev/null @@ -1,2759 +0,0 @@ -#!/usr/bin/python -# -# PyGreSQL - a Python interface for the PostgreSQL database. -# -# This file contains the classic pg module. -# -# Copyright (c) 2020 by the PyGreSQL Development Team -# -# The notification handler is based on pgnotify which is -# Copyright (c) 2001 Ng Pheng Siong. All rights reserved. -# -# Please see the LICENSE.TXT file for specific restrictions. - -"""PyGreSQL classic interface. - -This pg module implements some basic database management stuff. -It includes the _pg module and builds on it, providing the higher -level wrapper class named DB with additional functionality. 
-This is known as the "classic" ("old style") PyGreSQL interface. -For a DB-API 2 compliant interface use the newer pgdb module. -""" - -from __future__ import print_function, division - -try: - from _pg import * -except ImportError: - import os - import sys - # see https://docs.python.org/3/whatsnew/3.8.html#ctypes - if os.name == 'nt' and sys.version_info >= (3, 8): - for path in os.environ["PATH"].split(os.pathsep): - if os.path.exists(os.path.join(path, 'libpq.dll')): - with os.add_dll_directory(os.path.abspath(path)): - from _pg import * - break - else: - raise - else: - raise - -__version__ = version - -__all__ = [ - 'DB', 'Adapter', - 'NotificationHandler', 'Typecasts', - 'Bytea', 'Hstore', 'Json', 'Literal', - 'Error', 'Warning', - 'DataError', 'DatabaseError', - 'IntegrityError', 'InterfaceError', 'InternalError', - 'InvalidResultError', 'MultipleResultsError', - 'NoResultError', 'NotSupportedError', - 'OperationalError', 'ProgrammingError', - 'INV_READ', 'INV_WRITE', - 'SEEK_CUR', 'SEEK_END', 'SEEK_SET', - 'TRANS_ACTIVE', 'TRANS_IDLE', 'TRANS_INERROR', - 'TRANS_INTRANS', 'TRANS_UNKNOWN', - 'cast_array', 'cast_hstore', 'cast_record', - 'connect', 'escape_bytea', 'escape_string', 'unescape_bytea', - 'get_array', 'get_bool', 'get_bytea_escaped', - 'get_datestyle', 'get_decimal', 'get_decimal_point', - 'get_defbase', 'get_defhost', 'get_defopt', 'get_defport', 'get_defuser', - 'get_jsondecode', - 'set_array', 'set_bool', 'set_bytea_escaped', - 'set_datestyle', 'set_decimal', 'set_decimal_point', - 'set_defbase', 'set_defhost', 'set_defopt', - 'set_defpasswd', 'set_defport', 'set_defuser', - 'set_jsondecode', 'set_query_helpers', - 'version', '__version__'] - -import select -import warnings -import weakref - -from datetime import date, time, datetime, timedelta, tzinfo -from decimal import Decimal -from math import isnan, isinf -from collections import namedtuple -from keyword import iskeyword -from operator import itemgetter -from functools import partial -from re import compile as regex -from json import loads as jsondecode, dumps as jsonencode -from uuid import UUID - -try: # noinspection PyUnresolvedReferences - long -except NameError: # Python >= 3.0 - long = int - -try: # noinspection PyUnresolvedReferences - basestring -except NameError: # Python >= 3.0 - basestring = (str, bytes) - -try: - from functools import lru_cache -except ImportError: # Python < 3.2 - from functools import update_wrapper - try: - from _thread import RLock - except ImportError: - class RLock: # for builds without threads - def __enter__(self): pass - - def __exit__(self, exctype, excinst, exctb): pass - - def lru_cache(maxsize=128): - """Simplified functools.lru_cache decorator for one argument.""" - - def decorator(function): - sentinel = object() - cache = {} - get = cache.get - lock = RLock() - root = [] - root_full = [root, False] - root[:] = [root, root, None, None] - - if maxsize == 0: - - def wrapper(arg): - res = function(arg) - return res - - elif maxsize is None: - - def wrapper(arg): - res = get(arg, sentinel) - if res is not sentinel: - return res - res = function(arg) - cache[arg] = res - return res - - else: - - def wrapper(arg): - with lock: - link = get(arg) - if link is not None: - root = root_full[0] - prev, next, _arg, res = link - prev[1] = next - next[0] = prev - last = root[0] - last[1] = root[0] = link - link[0] = last - link[1] = root - return res - res = function(arg) - with lock: - root, full = root_full - if arg in cache: - pass - elif full: - oldroot = root - oldroot[2] = arg 
- oldroot[3] = res - root = root_full[0] = oldroot[1] - oldarg = root[2] - oldres = root[3] # keep reference - root[2] = root[3] = None - del cache[oldarg] - cache[arg] = oldroot - else: - last = root[0] - link = [last, root, arg, res] - last[1] = root[0] = cache[arg] = link - if len(cache) >= maxsize: - root_full[1] = True - return res - - wrapper.__wrapped__ = function - return update_wrapper(wrapper, function) - - return decorator - - -# Auxiliary classes and functions that are independent from a DB connection: - -try: - from collections import OrderedDict -except ImportError: # Python 2.6 or 3.0 - OrderedDict = dict - - - class AttrDict(dict): - """Simple read-only ordered dictionary for storing attribute names.""" - - def __init__(self, *args, **kw): - if len(args) > 1 or kw: - raise TypeError - items = args[0] if args else [] - if isinstance(items, dict): - raise TypeError - items = list(items) - self._keys = [item[0] for item in items] - dict.__init__(self, items) - self._read_only = True - error = self._read_only_error - self.clear = self.update = error - self.pop = self.setdefault = self.popitem = error - - def __setitem__(self, key, value): - if self._read_only: - self._read_only_error() - dict.__setitem__(self, key, value) - - def __delitem__(self, key): - if self._read_only: - self._read_only_error() - dict.__delitem__(self, key) - - def __iter__(self): - return iter(self._keys) - - def keys(self): - return list(self._keys) - - def values(self): - return [self[key] for key in self] - - def items(self): - return [(key, self[key]) for key in self] - - def iterkeys(self): - return self.__iter__() - - def itervalues(self): - return iter(self.values()) - - def iteritems(self): - return iter(self.items()) - - @staticmethod - def _read_only_error(*args, **kw): - raise TypeError('This object is read-only') - -else: - - class AttrDict(OrderedDict): - """Simple read-only ordered dictionary for storing attribute names.""" - - def __init__(self, *args, **kw): - self._read_only = False - OrderedDict.__init__(self, *args, **kw) - self._read_only = True - error = self._read_only_error - self.clear = self.update = error - self.pop = self.setdefault = self.popitem = error - - def __setitem__(self, key, value): - if self._read_only: - self._read_only_error() - OrderedDict.__setitem__(self, key, value) - - def __delitem__(self, key): - if self._read_only: - self._read_only_error() - OrderedDict.__delitem__(self, key) - - @staticmethod - def _read_only_error(*args, **kw): - raise TypeError('This object is read-only') - -try: - from inspect import signature -except ImportError: # Python < 3.3 - from inspect import getargspec - - def get_args(func): - return getargspec(func).args -else: - - def get_args(func): - return list(signature(func).parameters) - -try: - from datetime import timezone -except ImportError: # Python < 3.2 - - class timezone(tzinfo): - """Simple timezone implementation.""" - - def __init__(self, offset, name=None): - self.offset = offset - if not name: - minutes = self.offset.days * 1440 + self.offset.seconds // 60 - if minutes < 0: - hours, minutes = divmod(-minutes, 60) - hours = -hours - else: - hours, minutes = divmod(minutes, 60) - name = 'UTC%+03d:%02d' % (hours, minutes) - self.name = name - - def utcoffset(self, dt): - return self.offset - - def tzname(self, dt): - return self.name - - def dst(self, dt): - return None - - timezone.utc = timezone(timedelta(0), 'UTC') - - _has_timezone = False -else: - _has_timezone = True - -# time zones used in Postgres timestamptz output 
-_timezones = dict(CET='+0100', EET='+0200', EST='-0500', - GMT='+0000', HST='-1000', MET='+0100', MST='-0700', - UCT='+0000', UTC='+0000', WET='+0000') - - -def _timezone_as_offset(tz): - if tz.startswith(('+', '-')): - if len(tz) < 5: - return tz + '00' - return tz.replace(':', '') - return _timezones.get(tz, '+0000') - - -def _get_timezone(tz): - tz = _timezone_as_offset(tz) - minutes = 60 * int(tz[1:3]) + int(tz[3:5]) - if tz[0] == '-': - minutes = -minutes - return timezone(timedelta(minutes=minutes), tz) - - -def _oid_key(table): - """Build oid key from a table name.""" - return 'oid(%s)' % table - - -class _SimpleTypes(dict): - """Dictionary mapping pg_type names to simple type names.""" - - _types = {'bool': 'bool', - 'bytea': 'bytea', - 'date': 'date interval time timetz timestamp timestamptz' - ' abstime reltime', # these are very old - 'float': 'float4 float8', - 'int': 'cid int2 int4 int8 oid xid', - 'hstore': 'hstore', 'json': 'json jsonb', 'uuid': 'uuid', - 'num': 'numeric', 'money': 'money', - 'text': 'bpchar char name text varchar'} - - def __init__(self): - for typ, keys in self._types.items(): - for key in keys.split(): - self[key] = typ - self['_%s' % key] = '%s[]' % typ - - # this could be a static method in Python > 2.6 - def __missing__(self, key): - return 'text' - -_simpletypes = _SimpleTypes() - - -def _quote_if_unqualified(param, name): - """Quote parameter representing a qualified name. - - Puts a quote_ident() call around the give parameter unless - the name contains a dot, in which case the name is ambiguous - (could be a qualified name or just a name with a dot in it) - and must be quoted manually by the caller. - """ - if isinstance(name, basestring) and '.' not in name: - return 'quote_ident(%s)' % (param,) - return param - - -class _ParameterList(list): - """Helper class for building typed parameter lists.""" - - def add(self, value, typ=None): - """Typecast value with known database type and build parameter list. - - If this is a literal value, it will be returned as is. Otherwise, a - placeholder will be returned and the parameter list will be augmented. 
- """ - value = self.adapt(value, typ) - if isinstance(value, Literal): - return value - self.append(value) - return '$%d' % len(self) - - -class Bytea(bytes): - """Wrapper class for marking Bytea values.""" - - -class Hstore(dict): - """Wrapper class for marking hstore values.""" - - _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]') - - @classmethod - def _quote(cls, s): - if s is None: - return 'NULL' - if not s: - return '""' - s = s.replace('"', '\\"') - if cls._re_quote.search(s): - s = '"%s"' % s - return s - - def __str__(self): - q = self._quote - return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items()) - - -class Json: - """Wrapper class for marking Json values.""" - - def __init__(self, obj): - self.obj = obj - - -class Literal(str): - """Wrapper class for marking literal SQL values.""" - - -class Adapter: - """Class providing methods for adapting parameters to the database.""" - - _bool_true_values = frozenset('t true 1 y yes on'.split()) - - _date_literals = frozenset('current_date current_time' - ' current_timestamp localtime localtimestamp'.split()) - - _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$') - _re_record_quote = regex(r'[(,"\\]') - _re_array_escape = _re_record_escape = regex(r'(["\\])') - - def __init__(self, db): - self.db = weakref.proxy(db) - - @classmethod - def _adapt_bool(cls, v): - """Adapt a boolean parameter.""" - if isinstance(v, basestring): - if not v: - return None - v = v.lower() in cls._bool_true_values - return 't' if v else 'f' - - @classmethod - def _adapt_date(cls, v): - """Adapt a date parameter.""" - if not v: - return None - if isinstance(v, basestring) and v.lower() in cls._date_literals: - return Literal(v) - return v - - @staticmethod - def _adapt_num(v): - """Adapt a numeric parameter.""" - if not v and v != 0: - return None - return v - - _adapt_int = _adapt_float = _adapt_money = _adapt_num - - def _adapt_bytea(self, v): - """Adapt a bytea parameter.""" - return self.db.escape_bytea(v) - - def _adapt_json(self, v): - """Adapt a json parameter.""" - if not v: - return None - if isinstance(v, basestring): - return v - return self.db.encode_json(v) - - @classmethod - def _adapt_text_array(cls, v): - """Adapt a text type array parameter.""" - if isinstance(v, list): - adapt = cls._adapt_text_array - return '{%s}' % ','.join(adapt(v) for v in v) - if v is None: - return 'null' - if not v: - return '""' - v = str(v) - if cls._re_array_quote.search(v): - v = '"%s"' % cls._re_array_escape.sub(r'\\\1', v) - return v - - _adapt_date_array = _adapt_text_array - - @classmethod - def _adapt_bool_array(cls, v): - """Adapt a boolean array parameter.""" - if isinstance(v, list): - adapt = cls._adapt_bool_array - return '{%s}' % ','.join(adapt(v) for v in v) - if v is None: - return 'null' - if isinstance(v, basestring): - if not v: - return 'null' - v = v.lower() in cls._bool_true_values - return 't' if v else 'f' - - @classmethod - def _adapt_num_array(cls, v): - """Adapt a numeric array parameter.""" - if isinstance(v, list): - adapt = cls._adapt_num_array - return '{%s}' % ','.join(adapt(v) for v in v) - if not v and v != 0: - return 'null' - return str(v) - - _adapt_int_array = _adapt_float_array = _adapt_money_array = \ - _adapt_num_array - - def _adapt_bytea_array(self, v): - """Adapt a bytea array parameter.""" - if isinstance(v, list): - return b'{' + b','.join( - self._adapt_bytea_array(v) for v in v) + b'}' - if v is None: - return b'null' - return self.db.escape_bytea(v).replace(b'\\', b'\\\\') - - def _adapt_json_array(self, 
v): - """Adapt a json array parameter.""" - if isinstance(v, list): - adapt = self._adapt_json_array - return '{%s}' % ','.join(adapt(v) for v in v) - if not v: - return 'null' - if not isinstance(v, basestring): - v = self.db.encode_json(v) - if self._re_array_quote.search(v): - v = '"%s"' % self._re_array_escape.sub(r'\\\1', v) - return v - - def _adapt_record(self, v, typ): - """Adapt a record parameter with given type.""" - typ = self.get_attnames(typ).values() - if len(typ) != len(v): - raise TypeError('Record parameter %s has wrong size' % v) - adapt = self.adapt - value = [] - for v, t in zip(v, typ): - v = adapt(v, t) - if v is None: - v = '' - elif not v: - v = '""' - else: - if isinstance(v, bytes): - if str is not bytes: - v = v.decode('ascii') - else: - v = str(v) - if self._re_record_quote.search(v): - v = '"%s"' % self._re_record_escape.sub(r'\\\1', v) - value.append(v) - return '(%s)' % ','.join(value) - - def adapt(self, value, typ=None): - """Adapt a value with known database type.""" - if value is not None and not isinstance(value, Literal): - if typ: - simple = self.get_simple_name(typ) - else: - typ = simple = self.guess_simple_type(value) or 'text' - pg_str = getattr(value, '__pg_str__', None) - if pg_str: - value = pg_str(typ) - if simple == 'text': - pass - elif simple == 'record': - if isinstance(value, tuple): - value = self._adapt_record(value, typ) - elif simple.endswith('[]'): - if isinstance(value, list): - adapt = getattr(self, '_adapt_%s_array' % simple[:-2]) - value = adapt(value) - else: - adapt = getattr(self, '_adapt_%s' % simple) - value = adapt(value) - return value - - @staticmethod - def simple_type(name): - """Create a simple database type with given attribute names.""" - typ = DbType(name) - typ.simple = name - return typ - - @staticmethod - def get_simple_name(typ): - """Get the simple name of a database type.""" - if isinstance(typ, DbType): - return typ.simple - return _simpletypes[typ] - - @staticmethod - def get_attnames(typ): - """Get the attribute names of a composite database type.""" - if isinstance(typ, DbType): - return typ.attnames - return {} - - _frequent_simple_types = { - Bytea: 'bytea', - str: 'text', - bytes: 'text', - bool: 'bool', - int: 'int', - long: 'int', - float: 'float', - Decimal: 'num', - date: 'date', - time: 'date', - datetime: 'date', - timedelta: 'date' - } - - @classmethod - def guess_simple_type(cls, value): - """Try to guess which database type the given value has.""" - # optimize for most frequent types - try: - return cls._frequent_simple_types[type(value)] - except KeyError: - pass - if isinstance(value, Bytea): - return 'bytea' - if isinstance(value, basestring): - return 'text' - if isinstance(value, bool): - return 'bool' - if isinstance(value, (int, long)): - return 'int' - if isinstance(value, float): - return 'float' - if isinstance(value, Decimal): - return 'num' - if isinstance(value, (date, time, datetime, timedelta)): - return 'date' - if isinstance(value, list): - return '%s[]' % (cls.guess_simple_base_type(value) or 'text',) - if isinstance(value, tuple): - simple_type = cls.simple_type - guess = cls.guess_simple_type - - def get_attnames(self): - return AttrDict((str(n + 1), simple_type(guess(v))) - for n, v in enumerate(value)) - - typ = simple_type('record') - typ._get_attnames = get_attnames - return typ - - @classmethod - def guess_simple_base_type(cls, value): - """Try to guess the base type of a given array.""" - for v in value: - if isinstance(v, list): - typ = cls.guess_simple_base_type(v) - 
else: - typ = cls.guess_simple_type(v) - if typ: - return typ - - def adapt_inline(self, value, nested=False): - """Adapt a value that is put into the SQL and needs to be quoted.""" - if value is None: - return 'NULL' - if isinstance(value, Literal): - return value - if isinstance(value, Bytea): - value = self.db.escape_bytea(value) - if bytes is not str: # Python >= 3.0 - value = value.decode('ascii') - elif isinstance(value, Json): - if value.encode: - return value.encode() - value = self.db.encode_json(value) - elif isinstance(value, (datetime, date, time, timedelta)): - value = str(value) - if isinstance(value, basestring): - value = self.db.escape_string(value) - return "'%s'" % value - if isinstance(value, bool): - return 'true' if value else 'false' - if isinstance(value, float): - if isinf(value): - return "'-Infinity'" if value < 0 else "'Infinity'" - if isnan(value): - return "'NaN'" - return value - if isinstance(value, (int, long, Decimal)): - return value - if isinstance(value, list): - q = self.adapt_inline - s = '[%s]' if nested else 'ARRAY[%s]' - return s % ','.join(str(q(v, nested=True)) for v in value) - if isinstance(value, tuple): - q = self.adapt_inline - return '(%s)' % ','.join(str(q(v)) for v in value) - pg_repr = getattr(value, '__pg_repr__', None) - if not pg_repr: - raise InterfaceError( - 'Do not know how to adapt type %s' % type(value)) - value = pg_repr() - if isinstance(value, (tuple, list)): - value = self.adapt_inline(value) - return value - - def parameter_list(self): - """Return a parameter list for parameters with known database types. - - The list has an add(value, typ) method that will build up the - list and return either the literal value or a placeholder. - """ - params = _ParameterList() - params.adapt = self.adapt - return params - - def format_query(self, command, values=None, types=None, inline=False): - """Format a database query using the given values and types.""" - if not values: - return command, [] - if inline and types: - raise ValueError('Typed parameters must be sent separately') - params = self.parameter_list() - if isinstance(values, (list, tuple)): - if inline: - adapt = self.adapt_inline - literals = [adapt(value) for value in values] - else: - add = params.add - if types: - if (not isinstance(types, (list, tuple)) or - len(types) != len(values)): - raise TypeError('The values and types do not match') - literals = [add(value, typ) - for value, typ in zip(values, types)] - else: - literals = [add(value) for value in values] - command %= tuple(literals) - elif isinstance(values, dict): - # we want to allow extra keys in the dictionary, - # so we first must find the values actually used in the command - used_values = {} - literals = dict.fromkeys(values, '') - for key in values: - del literals[key] - try: - command % literals - except KeyError: - used_values[key] = values[key] - literals[key] = '' - values = used_values - if inline: - adapt = self.adapt_inline - literals = dict((key, adapt(value)) - for key, value in values.items()) - else: - add = params.add - if types: - if not isinstance(types, dict): - raise TypeError('The values and types do not match') - literals = dict((key, add(values[key], types.get(key))) - for key in sorted(values)) - else: - literals = dict((key, add(values[key])) - for key in sorted(values)) - command %= literals - else: - raise TypeError('The values must be passed as tuple, list or dict') - return command, params - - -def cast_bool(value): - """Cast a boolean value.""" - if not get_bool(): - return value 
- return value[0] == 't' - - -def cast_json(value): - """Cast a JSON value.""" - cast = get_jsondecode() - if not cast: - return value - return cast(value) - - -def cast_num(value): - """Cast a numeric value.""" - return (get_decimal() or float)(value) - - -def cast_money(value): - """Cast a money value.""" - point = get_decimal_point() - if not point: - return value - if point != '.': - value = value.replace(point, '.') - value = value.replace('(', '-') - value = ''.join(c for c in value if c.isdigit() or c in '.-') - return (get_decimal() or float)(value) - - -def cast_int2vector(value): - """Cast an int2vector value.""" - return [int(v) for v in value.split()] - - -def cast_date(value, connection): - """Cast a date value.""" - # The output format depends on the server setting DateStyle. The default - # setting ISO and the setting for German are actually unambiguous. The - # order of days and months in the other two settings is however ambiguous, - # so at least here we need to consult the setting to properly parse values. - if value == '-infinity': - return date.min - if value == 'infinity': - return date.max - value = value.split() - if value[-1] == 'BC': - return date.min - value = value[0] - if len(value) > 10: - return date.max - fmt = connection.date_format() - return datetime.strptime(value, fmt).date() - - -def cast_time(value): - """Cast a time value.""" - fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' - return datetime.strptime(value, fmt).time() - - -_re_timezone = regex('(.*)([+-].*)') - - -def cast_timetz(value): - """Cast a timetz value.""" - tz = _re_timezone.match(value) - if tz: - value, tz = tz.groups() - else: - tz = '+0000' - fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' - if _has_timezone: - value += _timezone_as_offset(tz) - fmt += '%z' - return datetime.strptime(value, fmt).timetz() - return datetime.strptime(value, fmt).timetz().replace( - tzinfo=_get_timezone(tz)) - - -def cast_timestamp(value, connection): - """Cast a timestamp value.""" - if value == '-infinity': - return datetime.min - if value == 'infinity': - return datetime.max - value = value.split() - if value[-1] == 'BC': - return datetime.min - fmt = connection.date_format() - if fmt.endswith('-%Y') and len(value) > 2: - value = value[1:5] - if len(value[3]) > 4: - return datetime.max - fmt = ['%d %b' if fmt.startswith('%d') else '%b %d', - '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y'] - else: - if len(value[0]) > 10: - return datetime.max - fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S'] - return datetime.strptime(' '.join(value), ' '.join(fmt)) - - -def cast_timestamptz(value, connection): - """Cast a timestamptz value.""" - if value == '-infinity': - return datetime.min - if value == 'infinity': - return datetime.max - value = value.split() - if value[-1] == 'BC': - return datetime.min - fmt = connection.date_format() - if fmt.endswith('-%Y') and len(value) > 2: - value = value[1:] - if len(value[3]) > 4: - return datetime.max - fmt = ['%d %b' if fmt.startswith('%d') else '%b %d', - '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y'] - value, tz = value[:-1], value[-1] - else: - if fmt.startswith('%Y-'): - tz = _re_timezone.match(value[1]) - if tz: - value[1], tz = tz.groups() - else: - tz = '+0000' - else: - value, tz = value[:-1], value[-1] - if len(value[0]) > 10: - return datetime.max - fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S'] - if _has_timezone: - value.append(_timezone_as_offset(tz)) - fmt.append('%z') - return 
datetime.strptime(' '.join(value), ' '.join(fmt)) - return datetime.strptime(' '.join(value), ' '.join(fmt)).replace( - tzinfo=_get_timezone(tz)) - - -_re_interval_sql_standard = regex( - '(?:([+-])?([0-9]+)-([0-9]+) ?)?' - '(?:([+-]?[0-9]+)(?!:) ?)?' - '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?') - -_re_interval_postgres = regex( - '(?:([+-]?[0-9]+) ?years? ?)?' - '(?:([+-]?[0-9]+) ?mons? ?)?' - '(?:([+-]?[0-9]+) ?days? ?)?' - '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?') - -_re_interval_postgres_verbose = regex( - '@ ?(?:([+-]?[0-9]+) ?years? ?)?' - '(?:([+-]?[0-9]+) ?mons? ?)?' - '(?:([+-]?[0-9]+) ?days? ?)?' - '(?:([+-]?[0-9]+) ?hours? ?)?' - '(?:([+-]?[0-9]+) ?mins? ?)?' - '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?') - -_re_interval_iso_8601 = regex( - 'P(?:([+-]?[0-9]+)Y)?' - '(?:([+-]?[0-9]+)M)?' - '(?:([+-]?[0-9]+)D)?' - '(?:T(?:([+-]?[0-9]+)H)?' - '(?:([+-]?[0-9]+)M)?' - '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?') - - -def cast_interval(value): - """Cast an interval value.""" - # The output format depends on the server setting IntervalStyle, but it's - # not necessary to consult this setting to parse it. It's faster to just - # check all possible formats, and there is no ambiguity here. - m = _re_interval_iso_8601.match(value) - if m: - m = [d or '0' for d in m.groups()] - secs_ago = m.pop(5) == '-' - m = [int(d) for d in m] - years, mons, days, hours, mins, secs, usecs = m - if secs_ago: - secs = -secs - usecs = -usecs - else: - m = _re_interval_postgres_verbose.match(value) - if m: - m, ago = [d or '0' for d in m.groups()[:8]], m.group(9) - secs_ago = m.pop(5) == '-' - m = [-int(d) for d in m] if ago else [int(d) for d in m] - years, mons, days, hours, mins, secs, usecs = m - if secs_ago: - secs = - secs - usecs = -usecs - else: - m = _re_interval_postgres.match(value) - if m and any(m.groups()): - m = [d or '0' for d in m.groups()] - hours_ago = m.pop(3) == '-' - m = [int(d) for d in m] - years, mons, days, hours, mins, secs, usecs = m - if hours_ago: - hours = -hours - mins = -mins - secs = -secs - usecs = -usecs - else: - m = _re_interval_sql_standard.match(value) - if m and any(m.groups()): - m = [d or '0' for d in m.groups()] - years_ago = m.pop(0) == '-' - hours_ago = m.pop(3) == '-' - m = [int(d) for d in m] - years, mons, days, hours, mins, secs, usecs = m - if years_ago: - years = -years - mons = -mons - if hours_ago: - hours = -hours - mins = -mins - secs = -secs - usecs = -usecs - else: - raise ValueError('Cannot parse interval: %s' % value) - days += 365 * years + 30 * mons - return timedelta(days=days, hours=hours, minutes=mins, - seconds=secs, microseconds=usecs) - - -class Typecasts(dict): - """Dictionary mapping database types to typecast functions. - - The cast functions get passed the string representation of a value in - the database which they need to convert to a Python object. The - passed string will never be None since NULL values are already - handled before the cast function is called. - - Note that the basic types are already handled by the C extension. - They only need to be handled here as record or array components. 
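# The cast functions above all follow the contract described here: they
# receive the server's string representation of a value and return a
# Python object. A minimal sketch of a custom cast registered for a
# made-up enum type 'mood' (the type name and its values are assumptions):

import pg

def cast_mood(value):
    # value is the string sent by the server, never None
    return {'sad': -1, 'ok': 0, 'happy': 1}[value]

pg.set_typecast('mood', cast_mood)    # picked up by new connections
# on an already open connection: db.dbtypes.set_typecast('mood', cast_mood)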
- """ - - # the default cast functions - # (str functions are ignored but have been added for faster access) - defaults = {'char': str, 'bpchar': str, 'name': str, - 'text': str, 'varchar': str, - 'bool': cast_bool, 'bytea': unescape_bytea, - 'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int, - 'hstore': cast_hstore, 'json': cast_json, 'jsonb': cast_json, - 'float4': float, 'float8': float, - 'numeric': cast_num, 'money': cast_money, - 'date': cast_date, 'interval': cast_interval, - 'time': cast_time, 'timetz': cast_timetz, - 'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz, - 'int2vector': cast_int2vector, 'uuid': UUID, - 'anyarray': cast_array, 'record': cast_record} - - connection = None # will be set in a connection specific instance - - def __missing__(self, typ): - """Create a cast function if it is not cached. - - Note that this class never raises a KeyError, - but returns None when no special cast function exists. - """ - if not isinstance(typ, str): - raise TypeError('Invalid type: %s' % typ) - cast = self.defaults.get(typ) - if cast: - # store default for faster access - cast = self._add_connection(cast) - self[typ] = cast - elif typ.startswith('_'): - base_cast = self[typ[1:]] - cast = self.create_array_cast(base_cast) - if base_cast: - self[typ] = cast - else: - attnames = self.get_attnames(typ) - if attnames: - casts = [self[v.pgtype] for v in attnames.values()] - cast = self.create_record_cast(typ, attnames, casts) - self[typ] = cast - return cast - - @staticmethod - def _needs_connection(func): - """Check if a typecast function needs a connection argument.""" - try: - args = get_args(func) - except (TypeError, ValueError): - return False - else: - return 'connection' in args[1:] - - def _add_connection(self, cast): - """Add a connection argument to the typecast function if necessary.""" - if not self.connection or not self._needs_connection(cast): - return cast - return partial(cast, connection=self.connection) - - def get(self, typ, default=None): - """Get the typecast function for the given database type.""" - return self[typ] or default - - def set(self, typ, cast): - """Set a typecast function for the specified database type(s).""" - if isinstance(typ, basestring): - typ = [typ] - if cast is None: - for t in typ: - self.pop(t, None) - self.pop('_%s' % t, None) - else: - if not callable(cast): - raise TypeError("Cast parameter must be callable") - for t in typ: - self[t] = self._add_connection(cast) - self.pop('_%s' % t, None) - - def reset(self, typ=None): - """Reset the typecasts for the specified type(s) to their defaults. - - When no type is specified, all typecasts will be reset. - """ - if typ is None: - self.clear() - else: - if isinstance(typ, basestring): - typ = [typ] - for t in typ: - self.pop(t, None) - - @classmethod - def get_default(cls, typ): - """Get the default typecast function for the given database type.""" - return cls.defaults.get(typ) - - @classmethod - def set_default(cls, typ, cast): - """Set a default typecast function for the given database type(s).""" - if isinstance(typ, basestring): - typ = [typ] - defaults = cls.defaults - if cast is None: - for t in typ: - defaults.pop(t, None) - defaults.pop('_%s' % t, None) - else: - if not callable(cast): - raise TypeError("Cast parameter must be callable") - for t in typ: - defaults[t] = cast - defaults.pop('_%s' % t, None) - - def get_attnames(self, typ): - """Return the fields for the given record type. 
- - This method will be replaced with the get_attnames() method of DbTypes. - """ - return {} - - def dateformat(self): - """Return the current date format. - - This method will be replaced with the dateformat() method of DbTypes. - """ - return '%Y-%m-%d' - - def create_array_cast(self, basecast): - """Create an array typecast for the given base cast.""" - cast_array = self['anyarray'] - def cast(v): - return cast_array(v, basecast) - return cast - - def create_record_cast(self, name, fields, casts): - """Create a named record typecast for the given fields and casts.""" - cast_record = self['record'] - record = namedtuple(name, fields) - def cast(v): - return record(*cast_record(v, casts)) - return cast - - -def get_typecast(typ): - """Get the global typecast function for the given database type(s).""" - return Typecasts.get_default(typ) - - -def set_typecast(typ, cast): - """Set a global typecast function for the given database type(s). - - Note that connections cache cast functions. To be sure a global change - is picked up by a running connection, call db.db_types.reset_typecast(). - """ - Typecasts.set_default(typ, cast) - - -class DbType(str): - """Class augmenting the simple type name with additional info. - - The following additional information is provided: - - oid: the PostgreSQL type OID - pgtype: the internal PostgreSQL data type name - regtype: the registered PostgreSQL data type name - simple: the more coarse-grained PyGreSQL type name - typtype: b = base type, c = composite type etc. - category: A = Array, b = Boolean, C = Composite etc. - delim: delimiter for array types - relid: corresponding table for composite types - attnames: attributes for composite types - """ - - @property - def attnames(self): - """Get names and types of the fields of a composite type.""" - return self._get_attnames(self) - - -class DbTypes(dict): - """Cache for PostgreSQL data types. - - This cache maps type OIDs and names to DbType objects containing - information on the associated database type. 
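# A short sketch of how this cache is typically used; the connection
# parameters and the inspected type name are assumptions:

import pg

db = pg.DB(dbname='test')
typ = db.dbtypes['varchar']     # look up by name or OID, cached afterwards
print(typ)                      # the coarse-grained simple name, e.g. 'text'
print(typ.oid, typ.pgtype, typ.regtype, typ.typtype, typ.category)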
- """ - - _num_types = frozenset('int float num money' - ' int2 int4 int8 float4 float8 numeric money'.split()) - - def __init__(self, db): - """Initialize type cache for connection.""" - super(DbTypes, self).__init__() - self._db = weakref.proxy(db) - self._regtypes = False - self._typecasts = Typecasts() - self._typecasts.get_attnames = self.get_attnames - self._typecasts.connection = self._db - if db.server_version < 80400: - # older remote databases (not officially supported) - self._query_pg_type = ( - "SELECT oid, typname, typname::text::regtype," - " typtype, null as typcategory, typdelim, typrelid" - " FROM pg_catalog.pg_type" - " WHERE oid OPERATOR(pg_catalog.=) %s::regtype") - else: - self._query_pg_type = ( - "SELECT oid, typname, typname::regtype," - " typtype, typcategory, typdelim, typrelid" - " FROM pg_catalog.pg_type" - " WHERE oid OPERATOR(pg_catalog.=) %s::regtype") - - def add(self, oid, pgtype, regtype, - typtype, category, delim, relid): - """Create a PostgreSQL type name with additional info.""" - if oid in self: - return self[oid] - simple = 'record' if relid else _simpletypes[pgtype] - typ = DbType(regtype if self._regtypes else simple) - typ.oid = oid - typ.simple = simple - typ.pgtype = pgtype - typ.regtype = regtype - typ.typtype = typtype - typ.category = category - typ.delim = delim - typ.relid = relid - typ._get_attnames = self.get_attnames - return typ - - def __missing__(self, key): - """Get the type info from the database if it is not cached.""" - try: - q = self._query_pg_type % (_quote_if_unqualified('$1', key),) - res = self._db.query(q, (key,)).getresult() - except ProgrammingError: - res = None - if not res: - raise KeyError('Type %s could not be found' % key) - res = res[0] - typ = self.add(*res) - self[typ.oid] = self[typ.pgtype] = typ - return typ - - def get(self, key, default=None): - """Get the type even if it is not cached.""" - try: - return self[key] - except KeyError: - return default - - def get_attnames(self, typ): - """Get names and types of the fields of a composite type.""" - if not isinstance(typ, DbType): - typ = self.get(typ) - if not typ: - return None - if not typ.relid: - return None - return self._db.get_attnames(typ.relid, with_oid=False) - - def get_typecast(self, typ): - """Get the typecast function for the given database type.""" - return self._typecasts.get(typ) - - def set_typecast(self, typ, cast): - """Set a typecast function for the specified database type(s).""" - self._typecasts.set(typ, cast) - - def reset_typecast(self, typ=None): - """Reset the typecast function for the specified database type(s).""" - self._typecasts.reset(typ) - - def typecast(self, value, typ): - """Cast the given value according to the given database type.""" - if value is None: - # for NULL values, no typecast is necessary - return None - if not isinstance(typ, DbType): - typ = self.get(typ) - if typ: - typ = typ.pgtype - cast = self.get_typecast(typ) if typ else None - if not cast or cast is str: - # no typecast is necessary - return value - return cast(value) - - -_re_fieldname = regex('^[A-Za-z][_a-zA-Z0-9]*$') - -# The result rows for database operations are returned as named tuples -# by default. Since creating namedtuple classes is a somewhat expensive -# operation, we cache up to 1024 of these classes by default. 
- -@lru_cache(maxsize=1024) -def _row_factory(names): - """Get a namedtuple factory for row results with the given names.""" - try: - try: - return namedtuple('Row', names, rename=True)._make - except TypeError: # Python 2.6 and 3.0 do not support rename - names = [v if _re_fieldname.match(v) and not iskeyword(v) - else 'column_%d' % (n,) - for n, v in enumerate(names)] - return namedtuple('Row', names)._make - except ValueError: # there is still a problem with the field names - names = ['column_%d' % (n,) for n in range(len(names))] - return namedtuple('Row', names)._make - - -def set_row_factory_size(maxsize): - """Change the size of the namedtuple factory cache. - - If maxsize is set to None, the cache can grow without bound. - """ - global _row_factory - _row_factory = lru_cache(maxsize)(_row_factory.__wrapped__) - - -# Helper functions used by the query object - -def _dictiter(q): - """Get query result as an iterator of dictionaries.""" - fields = q.listfields() - for r in q: - yield dict(zip(fields, r)) - - -def _namediter(q): - """Get query result as an iterator of named tuples.""" - row = _row_factory(q.listfields()) - for r in q: - yield row(r) - - -def _namednext(q): - """Get next row from query result as a named tuple.""" - return _row_factory(q.listfields())(next(q)) - - -def _scalariter(q): - """Get query result as an iterator of scalar values.""" - for r in q: - yield r[0] - - -class _MemoryQuery: - """Class that embodies a given query result.""" - - def __init__(self, result, fields): - """Create query from given result rows and field names.""" - self.result = result - self.fields = tuple(fields) - - def listfields(self): - """Return the stored field names of this query.""" - return self.fields - - def getresult(self): - """Return the stored result of this query.""" - return self.result - - def __iter__(self): - return iter(self.result) - - -def _db_error(msg, cls=DatabaseError): - """Return DatabaseError with empty sqlstate attribute.""" - error = cls(msg) - error.sqlstate = None - return error - - -def _int_error(msg): - """Return InternalError.""" - return _db_error(msg, InternalError) - - -def _prg_error(msg): - """Return ProgrammingError.""" - return _db_error(msg, ProgrammingError) - - -# Initialize the C module - -set_decimal(Decimal) -set_jsondecode(jsondecode) -set_query_helpers(_dictiter, _namediter, _namednext, _scalariter) - - -# The notification handler - -class NotificationHandler(object): - """A PostgreSQL client-side asynchronous notification handler.""" - - def __init__(self, db, event, callback=None, - arg_dict=None, timeout=None, stop_event=None): - """Initialize the notification handler. - - You must pass a PyGreSQL database connection, the name of an - event (notification channel) to listen for and a callback function. - - You can also specify a dictionary arg_dict that will be passed as - the single argument to the callback function, and a timeout value - in seconds (a floating point number denotes fractions of seconds). - If it is absent or None, the callers will never time out. If the - timeout is reached, the callback function will be called with a - single argument that is None. If you set the timeout to zero, - the handler will poll notifications synchronously and return. - - You can specify the name of the event that will be used to signal - the handler to stop listening as stop_event. By default, it will - be the event name prefixed with 'stop_'. 
- """ - self.db = db - self.event = event - self.stop_event = stop_event or 'stop_%s' % event - self.listening = False - self.callback = callback - if arg_dict is None: - arg_dict = {} - self.arg_dict = arg_dict - self.timeout = timeout - - def __del__(self): - self.unlisten() - - def close(self): - """Stop listening and close the connection.""" - if self.db: - self.unlisten() - self.db.close() - self.db = None - - def listen(self): - """Start listening for the event and the stop event.""" - if not self.listening: - self.db.query('listen "%s"' % self.event) - self.db.query('listen "%s"' % self.stop_event) - self.listening = True - - def unlisten(self): - """Stop listening for the event and the stop event.""" - if self.listening: - self.db.query('unlisten "%s"' % self.event) - self.db.query('unlisten "%s"' % self.stop_event) - self.listening = False - - def notify(self, db=None, stop=False, payload=None): - """Generate a notification. - - Optionally, you can pass a payload with the notification. - - If you set the stop flag, a stop notification will be sent that - will cause the handler to stop listening. - - Note: If the notification handler is running in another thread, you - must pass a different database connection since PyGreSQL database - connections are not thread-safe. - """ - if self.listening: - if not db: - db = self.db - q = 'notify "%s"' % (self.stop_event if stop else self.event) - if payload: - q += ", '%s'" % payload - return db.query(q) - - def __call__(self): - """Invoke the notification handler. - - The handler is a loop that listens for notifications on the event - and stop event channels. When either of these notifications are - received, its associated 'pid', 'event' and 'extra' (the payload - passed with the notification) are inserted into its arg_dict - dictionary and the callback is invoked with this dictionary as - a single argument. When the handler receives a stop event, it - stops listening to both events and return. - - In the special case that the timeout of the handler has been set - to zero, the handler will poll all events synchronously and return. - If will keep listening until it receives a stop event. - - Note: If you run this loop in another thread, don't use the same - database connection for database operations in the main thread. 
- """ - self.listen() - poll = self.timeout == 0 - if not poll: - rlist = [self.db.fileno()] - while self.listening: - if poll or select.select(rlist, [], [], self.timeout)[0]: - while self.listening: - notice = self.db.getnotify() - if not notice: # no more messages - break - event, pid, extra = notice - if event not in (self.event, self.stop_event): - self.unlisten() - raise _db_error( - 'Listening for "%s" and "%s", but notified of "%s"' - % (self.event, self.stop_event, event)) - if event == self.stop_event: - self.unlisten() - self.arg_dict.update(pid=pid, event=event, extra=extra) - self.callback(self.arg_dict) - if poll: - break - else: # we timed out - self.unlisten() - self.callback(None) - - -def pgnotify(*args, **kw): - """Same as NotificationHandler, under the traditional name.""" - warnings.warn("pgnotify is deprecated, use NotificationHandler instead", - DeprecationWarning, stacklevel=2) - return NotificationHandler(*args, **kw) - - -# The actual PostgreSQL database connection interface: - -class DB: - """Wrapper class for the _pg connection type.""" - - db = None # invalid fallback for underlying connection - - def __init__(self, *args, **kw): - """Create a new connection - - You can pass either the connection parameters or an existing - _pg or pgdb connection. This allows you to use the methods - of the classic pg interface with a DB-API 2 pgdb connection. - """ - if not args and len(kw) == 1: - db = kw.get('db') - elif not kw and len(args) == 1: - db = args[0] - else: - db = None - if db: - if isinstance(db, DB): - db = db.db - else: - try: - db = db._cnx - except AttributeError: - pass - if not db or not hasattr(db, 'db') or not hasattr(db, 'query'): - db = connect(*args, **kw) - self._db_args = args, kw - self._closeable = True - else: - self._db_args = db - self._closeable = False - self.db = db - self.dbname = db.db - self._regtypes = False - self._attnames = {} - self._pkeys = {} - self._privileges = {} - self.adapter = Adapter(self) - self.dbtypes = DbTypes(self) - if db.server_version < 80400: - # support older remote data bases (not officially supported) - self._query_attnames = ( - "SELECT a.attname, t.oid, t.typname, t.typname::text::regtype," - " t.typtype, null as typcategory, t.typdelim, t.typrelid" - " FROM pg_catalog.pg_attribute a" - " JOIN pg_catalog.pg_type t" - " ON t.oid OPERATOR(pg_catalog.=) a.atttypid" - " WHERE a.attrelid OPERATOR(pg_catalog.=) %s::regclass AND %s" - " AND NOT a.attisdropped ORDER BY a.attnum") - else: - self._query_attnames = ( - "SELECT a.attname, t.oid, t.typname, t.typname::regtype," - " t.typtype, t.typcategory, t.typdelim, t.typrelid" - " FROM pg_catalog.pg_attribute a" - " JOIN pg_catalog.pg_type t" - " ON t.oid OPERATOR(pg_catalog.=) a.atttypid" - " WHERE a.attrelid OPERATOR(pg_catalog.=) %s::regclass AND %s" - " AND NOT a.attisdropped ORDER BY a.attnum") - db.set_cast_hook(self.dbtypes.typecast) - self.debug = None # For debugging scripts, this can be set - # * to a string format specification (e.g. in CGI set to "%s
"), - # * to a file object to write debug statements or - # * to a callable object which takes a string argument - # * to any other true value to just print debug statements - - def __getattr__(self, name): - # All undefined members are same as in underlying connection: - if self.db: - return getattr(self.db, name) - else: - raise _int_error('Connection is not valid') - - def __dir__(self): - # Custom dir function including the attributes of the connection: - attrs = set(self.__class__.__dict__) - attrs.update(self.__dict__) - attrs.update(dir(self.db)) - return sorted(attrs) - - # Context manager methods - - def __enter__(self): - """Enter the runtime context. This will start a transaction.""" - self.begin() - return self - - def __exit__(self, et, ev, tb): - """Exit the runtime context. This will end the transaction.""" - if et is None and ev is None and tb is None: - self.commit() - else: - self.rollback() - - def __del__(self): - try: - db = self.db - except AttributeError: - db = None - if db: - try: - db.set_cast_hook(None) - except TypeError: - pass # probably already closed - if self._closeable: - try: - db.close() - except InternalError: - pass # probably already closed - - # Auxiliary methods - - def _do_debug(self, *args): - """Print a debug message""" - if self.debug: - s = '\n'.join(str(arg) for arg in args) - if isinstance(self.debug, basestring): - print(self.debug % s) - elif hasattr(self.debug, 'write'): - self.debug.write(s + '\n') - elif callable(self.debug): - self.debug(s) - else: - print(s) - - def _escape_qualified_name(self, s): - """Escape a qualified name. - - Escapes the name for use as an SQL identifier, unless the - name contains a dot, in which case the name is ambiguous - (could be a qualified name or just a name with a dot in it) - and must be quoted manually by the caller. - """ - if '.' not in s: - s = self.escape_identifier(s) - return s - - @staticmethod - def _make_bool(d): - """Get boolean value corresponding to d.""" - return bool(d) if get_bool() else ('t' if d else 'f') - - def _list_params(self, params): - """Create a human readable parameter list.""" - return ', '.join('$%d=%r' % (n, v) for n, v in enumerate(params, 1)) - - # Public methods - - # escape_string and escape_bytea exist as methods, - # so we define unescape_bytea as a method as well - unescape_bytea = staticmethod(unescape_bytea) - - def decode_json(self, s): - """Decode a JSON string coming from the database.""" - return (get_jsondecode() or jsondecode)(s) - - def encode_json(self, d): - """Encode a JSON string for use within SQL.""" - return jsonencode(d) - - def close(self): - """Close the database connection.""" - # Wraps shared library function so we can track state. - db = self.db - if db: - try: - db.set_cast_hook(None) - except TypeError: - pass # probably already closed - if self._closeable: - db.close() - self.db = None - else: - raise _int_error('Connection already closed') - - def reset(self): - """Reset connection with current parameters. - - All derived queries and large objects derived from this connection - will not be usable after this call. - - """ - if self.db: - self.db.reset() - else: - raise _int_error('Connection already closed') - - def reopen(self): - """Reopen connection to the database. - - Used in case we need another connection to the same database. - Note that we can still reopen a database that we have closed. - - """ - # There is no such shared library function. 
- if self._closeable: - db = connect(*self._db_args[0], **self._db_args[1]) - if self.db: - self.db.set_cast_hook(None) - self.db.close() - db.set_cast_hook(self.dbtypes.typecast) - self.db = db - else: - self.db = self._db_args - - def begin(self, mode=None): - """Begin a transaction.""" - qstr = 'BEGIN' - if mode: - qstr += ' ' + mode - return self.query(qstr) - - start = begin - - def commit(self): - """Commit the current transaction.""" - return self.query('COMMIT') - - end = commit - - def rollback(self, name=None): - """Roll back the current transaction.""" - qstr = 'ROLLBACK' - if name: - qstr += ' TO ' + name - return self.query(qstr) - - abort = rollback - - def savepoint(self, name): - """Define a new savepoint within the current transaction.""" - return self.query('SAVEPOINT ' + name) - - def release(self, name): - """Destroy a previously defined savepoint.""" - return self.query('RELEASE ' + name) - - def get_parameter(self, parameter): - """Get the value of a run-time parameter. - - If the parameter is a string, the return value will also be a string - that is the current setting of the run-time parameter with that name. - - You can get several parameters at once by passing a list, set or dict. - When passing a list of parameter names, the return value will be a - corresponding list of parameter settings. When passing a set of - parameter names, a new dict will be returned, mapping these parameter - names to their settings. Finally, if you pass a dict as parameter, - its values will be set to the current parameter settings corresponding - to its keys. - - By passing the special name 'all' as the parameter, you can get a dict - of all existing configuration parameters. - """ - if isinstance(parameter, basestring): - parameter = [parameter] - values = None - elif isinstance(parameter, (list, tuple)): - values = [] - elif isinstance(parameter, (set, frozenset)): - values = {} - elif isinstance(parameter, dict): - values = parameter - else: - raise TypeError( - 'The parameter must be a string, list, set or dict') - if not parameter: - raise TypeError('No parameter has been specified') - params = {} if isinstance(values, dict) else [] - for key in parameter: - param = key.strip().lower() if isinstance( - key, basestring) else None - if not param: - raise TypeError('Invalid parameter') - if param == 'all': - q = 'SHOW ALL' - values = self.db.query(q).getresult() - values = dict(value[:2] for value in values) - break - if isinstance(values, dict): - params[param] = key - else: - params.append(param) - else: - for param in params: - q = 'SHOW %s' % (param,) - value = self.db.query(q).getresult()[0][0] - if values is None: - values = value - elif isinstance(values, list): - values.append(value) - else: - values[params[param]] = value - return values - - def set_parameter(self, parameter, value=None, local=False): - """Set the value of a run-time parameter. - - If the parameter and the value are strings, the run-time parameter - will be set to that value. If no value or None is passed as a value, - then the run-time parameter will be restored to its default value. - - You can set several parameters at once by passing a list of parameter - names, together with a single value that all parameters should be - set to or with a corresponding list of values. You can also pass - the parameters as a set if you only provide a single value. - Finally, you can pass a dict with parameter names as keys. 
In this - case, you should not pass a value, since the values for the parameters - will be taken from the dict. - - By passing the special name 'all' as the parameter, you can reset - all existing settable run-time parameters to their default values. - - If you set local to True, then the command takes effect for only the - current transaction. After commit() or rollback(), the session-level - setting takes effect again. Setting local to True will appear to - have no effect if it is executed outside a transaction, since the - transaction will end immediately. - """ - if isinstance(parameter, basestring): - parameter = {parameter: value} - elif isinstance(parameter, (list, tuple)): - if isinstance(value, (list, tuple)): - parameter = dict(zip(parameter, value)) - else: - parameter = dict.fromkeys(parameter, value) - elif isinstance(parameter, (set, frozenset)): - if isinstance(value, (list, tuple, set, frozenset)): - value = set(value) - if len(value) == 1: - value = value.pop() - if not (value is None or isinstance(value, basestring)): - raise ValueError('A single value must be specified' - ' when parameter is a set') - parameter = dict.fromkeys(parameter, value) - elif isinstance(parameter, dict): - if value is not None: - raise ValueError('A value must not be specified' - ' when parameter is a dictionary') - else: - raise TypeError( - 'The parameter must be a string, list, set or dict') - if not parameter: - raise TypeError('No parameter has been specified') - params = {} - for key, value in parameter.items(): - param = key.strip().lower() if isinstance( - key, basestring) else None - if not param: - raise TypeError('Invalid parameter') - if param == 'all': - if value is not None: - raise ValueError('A value must not be specified' - " when parameter is 'all'") - params = {'all': None} - break - params[param] = value - local = ' LOCAL' if local else '' - for param, value in params.items(): - if value is None: - q = 'RESET%s %s' % (local, param) - else: - q = 'SET%s %s TO %s' % (local, param, value) - self._do_debug(q) - self.db.query(q) - - def query(self, command, *args): - """Execute a SQL command string. - - This method simply sends a SQL query to the database. If the query is - an insert statement that inserted exactly one row into a table that - has OIDs, the return value is the OID of the newly inserted row. - If the query is an update or delete statement, or an insert statement - that did not insert exactly one row in a table with OIDs, then the - number of rows affected is returned as a string. If it is a statement - that returns rows as a result (usually a select statement, but maybe - also an "insert/update ... returning" statement), this method returns - a Query object that can be accessed via getresult() or dictresult() - or simply printed. Otherwise, it returns `None`. - - The query can contain numbered parameters of the form $1 in place - of any data constant. Arguments given after the query string will - be substituted for the corresponding numbered parameter. Parameter - values can also be given as a single list or tuple argument. - """ - # Wraps shared library function for debugging. - if not self.db: - raise _int_error('Connection is not valid') - if args: - self._do_debug(command, args) - return self.db.query(command, args) - self._do_debug(command) - return self.db.query(command) - - def query_formatted(self, command, - parameters=None, types=None, inline=False): - """Execute a formatted SQL command string.
- - Similar to query, but using Python format placeholders of the form - %s or %(names)s instead of PostgreSQL placeholders of the form $1. - The parameters must be passed as a tuple, list or dict. You can - also pass a corresponding tuple, list or dict of database types in - order to format the parameters properly in case there is ambiguity. - - If you set inline to True, the parameters will be sent to the database - embedded in the SQL command, otherwise they will be sent separately. - """ - return self.query(*self.adapter.format_query( - command, parameters, types, inline)) - - def query_prepared(self, name, *args): - """Execute a prepared SQL statement. - - This works like the query() method, except that instead of passing - the SQL command, you pass the name of a prepared statement. If you - pass an empty name, the unnamed statement will be executed. - """ - if not self.db: - raise _int_error('Connection is not valid') - if name is None: - name = '' - if args: - self._do_debug('EXECUTE', name, args) - return self.db.query_prepared(name, args) - self._do_debug('EXECUTE', name) - return self.db.query_prepared(name) - - def prepare(self, name, command): - """Create a prepared SQL statement. - - This creates a prepared statement for the given command with the - given name for later execution with the query_prepared() method. - - The name can be empty to create an unnamed statement, in which case - any pre-existing unnamed statement is automatically replaced; - otherwise it is an error if the statement name is already - defined in the current database session. We recommend always using - named queries, since unnamed queries have a limited lifetime and - can be automatically replaced or destroyed by various operations. - """ - if not self.db: - raise _int_error('Connection is not valid') - if name is None: - name = '' - self._do_debug('prepare', name, command) - return self.db.prepare(name, command) - - def describe_prepared(self, name=None): - """Describe a prepared SQL statement. - - This method returns a Query object describing the result columns of - the prepared statement with the given name. If you omit the name, - the unnamed statement will be described if you created one before. - """ - if name is None: - name = '' - return self.db.describe_prepared(name) - - def delete_prepared(self, name=None): - """Delete a prepared SQL statement - - This deallocates a previously prepared SQL statement with the given - name, or deallocates all prepared statements if you do not specify a - name. Note that prepared statements are also deallocated automatically - when the current session ends. - """ - q = "DEALLOCATE %s" % (name or 'ALL',) - self._do_debug(q) - return self.db.query(q) - - def pkey(self, table, composite=False, flush=False): - """Get or set the primary key of a table. - - Single primary keys are returned as strings unless you - set the composite flag. Composite primary keys are always - represented as tuples. Note that this raises a KeyError - if the table does not have a primary key. - - If flush is set then the internal cache for primary keys will - be flushed. This may be necessary after the database schema or - the search path has been changed. 
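# For example (a sketch assuming a table 'orders' with a composite
# primary key and a table 'customers' with a single-column one):

import pg

db = pg.DB(dbname='test')
db.pkey('orders')                     # -> e.g. ('order_id', 'line_no')
db.pkey('customers')                  # -> 'id' (a plain string)
db.pkey('customers', composite=True)  # -> ('id',)
db.pkey('customers', flush=True)      # refresh after schema changes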
- """ - pkeys = self._pkeys - if flush: - pkeys.clear() - self._do_debug('The pkey cache has been flushed') - try: # cache lookup - pkey = pkeys[table] - except KeyError: # cache miss, check the database - q = ("SELECT a.attname, a.attnum, i.indkey" - " FROM pg_catalog.pg_index i" - " JOIN pg_catalog.pg_attribute a" - " ON a.attrelid OPERATOR(pg_catalog.=) i.indrelid" - " AND a.attnum OPERATOR(pg_catalog.=) ANY(i.indkey)" - " AND NOT a.attisdropped" - " WHERE i.indrelid OPERATOR(pg_catalog.=) %s::regclass" - " AND i.indisprimary ORDER BY a.attnum") % ( - _quote_if_unqualified('$1', table),) - pkey = self.db.query(q, (table,)).getresult() - if not pkey: - raise KeyError('Table %s has no primary key' % table) - # we want to use the order defined in the primary key index here, - # not the order as defined by the columns in the table - if len(pkey) > 1: - indkey = pkey[0][2] - pkey = sorted(pkey, key=lambda row: indkey.index(row[1])) - pkey = tuple(row[0] for row in pkey) - else: - pkey = pkey[0][0] - pkeys[table] = pkey # cache it - if composite and not isinstance(pkey, tuple): - pkey = (pkey,) - return pkey - - def get_databases(self): - """Get list of databases in the system.""" - return [s[0] for s in - self.db.query( - 'SELECT datname FROM pg_catalog.pg_database').getresult()] - - def get_relations(self, kinds=None, system=False): - """Get list of relations in connected database of specified kinds. - - If kinds is None or empty, all kinds of relations are returned. - Otherwise kinds can be a string or sequence of type letters - specifying which kind of relations you want to list. - - Set the system flag if you want to get the system relations as well. - """ - where = [] - if kinds: - where.append("r.relkind IN (%s)" % - ','.join("'%s'" % k for k in kinds)) - if not system: - where.append("s.nspname NOT SIMILAR" - " TO 'pg/_%|information/_schema' ESCAPE '/'") - where = " WHERE %s" % ' AND '.join(where) if where else '' - q = ("SELECT pg_catalog.quote_ident(s.nspname) OPERATOR(pg_catalog.||)" - " '.' OPERATOR(pg_catalog.||) pg_catalog.quote_ident(r.relname)" - " FROM pg_catalog.pg_class r" - " JOIN pg_catalog.pg_namespace s" - " ON s.oid OPERATOR(pg_catalog.=) r.relnamespace%s" - " ORDER BY s.nspname, r.relname") % where - return [r[0] for r in self.db.query(q).getresult()] - - def get_tables(self, system=False): - """Return list of tables in connected database. - - Set the system flag if you want to get the system tables as well. - """ - return self.get_relations('r', system) - - def get_attnames(self, table, with_oid=True, flush=False): - """Given the name of a table, dig out the set of attribute names. - - Returns a read-only dictionary of attribute names (the names are - the keys, the values are the names of the attributes' types) - with the column names in the proper order if you iterate over it. - - If flush is set, then the internal cache for attribute names will - be flushed. This may be necessary after the database schema or - the search path has been changed. - - By default, only a limited number of simple types will be returned. - You can get the registered types after calling use_regtypes(True). 
- """ - attnames = self._attnames - if flush: - attnames.clear() - self._do_debug('The attnames cache has been flushed') - try: # cache lookup - names = attnames[table] - except KeyError: # cache miss, check the database - q = "a.attnum OPERATOR(pg_catalog.>) 0" - if with_oid: - q = "(%s OR a.attname OPERATOR(pg_catalog.=) 'oid')" % q - q = self._query_attnames % (_quote_if_unqualified('$1', table), q) - names = self.db.query(q, (table,)).getresult() - types = self.dbtypes - names = ((name[0], types.add(*name[1:])) for name in names) - names = AttrDict(names) - attnames[table] = names # cache it - return names - - def use_regtypes(self, regtypes=None): - """Use registered type names instead of simplified type names.""" - if regtypes is None: - return self.dbtypes._regtypes - else: - regtypes = bool(regtypes) - if regtypes != self.dbtypes._regtypes: - self.dbtypes._regtypes = regtypes - self._attnames.clear() - self.dbtypes.clear() - return regtypes - - def has_table_privilege(self, table, privilege='select', flush=False): - """Check whether current user has specified table privilege. - - If flush is set, then the internal cache for table privileges will - be flushed. This may be necessary after privileges have been changed. - """ - privileges = self._privileges - if flush: - privileges.clear() - self._do_debug('The privileges cache has been flushed') - privilege = privilege.lower() - try: # ask cache - ret = privileges[table, privilege] - except KeyError: # cache miss, ask the database - q = "SELECT pg_catalog.has_table_privilege(%s, $2)" % ( - _quote_if_unqualified('$1', table),) - q = self.db.query(q, (table, privilege)) - ret = q.getresult()[0][0] == self._make_bool(True) - privileges[table, privilege] = ret # cache it - return ret - - def get(self, table, row, keyname=None): - """Get a row from a database table or view. - - This method is the basic mechanism to get a single row. It assumes - that the keyname specifies a unique row. It must be the name of a - single column or a tuple of column names. If the keyname is not - specified, then the primary key for the table is used. - - If row is a dictionary, then the value for the key is taken from it. - Otherwise, the row must be a single value or a tuple of values - corresponding to the passed keyname or primary key. The fetched row - from the table will be returned as a new dictionary or used to replace - the existing values when row was passed as a dictionary. - - The OID is also put into the dictionary if the table has one, but - in order to allow the caller to work with multiple tables, it is - munged as "oid(table)" using the actual name of the table. 
- """ - if table.endswith('*'): # hint for descendant tables can be ignored - table = table[:-1].rstrip() - attnames = self.get_attnames(table) - qoid = _oid_key(table) if 'oid' in attnames else None - if keyname and isinstance(keyname, basestring): - keyname = (keyname,) - if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row: - row['oid'] = row[qoid] - if not keyname: - try: # if keyname is not specified, try using the primary key - keyname = self.pkey(table, True) - except KeyError: # the table has no primary key - # try using the oid instead - if qoid and isinstance(row, dict) and 'oid' in row: - keyname = ('oid',) - else: - raise _prg_error('Table %s has no primary key' % table) - else: # the table has a primary key - # check whether all key columns have values - if isinstance(row, dict) and not set(keyname).issubset(row): - # try using the oid instead - if qoid and 'oid' in row: - keyname = ('oid',) - else: - raise KeyError( - 'Missing value in row for specified keyname') - if not isinstance(row, dict): - if not isinstance(row, (tuple, list)): - row = [row] - if len(keyname) != len(row): - raise KeyError( - 'Differing number of items in keyname and row') - row = dict(zip(keyname, row)) - params = self.adapter.parameter_list() - adapt = params.add - col = self.escape_identifier - what = 'oid, *' if qoid else '*' - where = ' AND '.join('%s OPERATOR(pg_catalog.=) %s' % ( - col(k), adapt(row[k], attnames[k])) for k in keyname) - if 'oid' in row: - if qoid: - row[qoid] = row['oid'] - del row['oid'] - q = 'SELECT %s FROM %s WHERE %s LIMIT 1' % ( - what, self._escape_qualified_name(table), where) - self._do_debug(q, params) - q = self.db.query(q, params) - res = q.dictresult() - if not res: - # make where clause in error message better readable - where = where.replace('OPERATOR(pg_catalog.=)', '=') - raise _db_error('No such record in %s\nwhere %s\nwith %s' % ( - table, where, self._list_params(params))) - for n, value in res[0].items(): - if qoid and n == 'oid': - n = qoid - row[n] = value - return row - - def insert(self, table, row=None, **kw): - """Insert a row into a database table. - - This method inserts a row into a table. The name of the table must - be passed as the first parameter. The other parameters are used for - providing the data of the row that shall be inserted into the table. - If a dictionary is supplied as the second parameter, it starts with - that. Otherwise it uses a blank dictionary. Either way the dictionary - is updated from the keywords. - - The dictionary is then reloaded with the values actually inserted in - order to pick up values modified by rules, triggers, etc. 
- """ - if table.endswith('*'): # hint for descendant tables can be ignored - table = table[:-1].rstrip() - if row is None: - row = {} - row.update(kw) - if 'oid' in row: - del row['oid'] # do not insert oid - attnames = self.get_attnames(table) - qoid = _oid_key(table) if 'oid' in attnames else None - params = self.adapter.parameter_list() - adapt = params.add - col = self.escape_identifier - names, values = [], [] - for n in attnames: - if n in row: - names.append(col(n)) - values.append(adapt(row[n], attnames[n])) - if not names: - raise _prg_error('No column found that can be inserted') - names, values = ', '.join(names), ', '.join(values) - ret = 'oid, *' if qoid else '*' - q = 'INSERT INTO %s (%s) VALUES (%s) RETURNING %s' % ( - self._escape_qualified_name(table), names, values, ret) - self._do_debug(q, params) - q = self.db.query(q, params) - res = q.dictresult() - if res: # this should always be true - for n, value in res[0].items(): - if qoid and n == 'oid': - n = qoid - row[n] = value - return row - - def update(self, table, row=None, **kw): - """Update an existing row in a database table. - - Similar to insert, but updates an existing row. The update is based - on the primary key of the table or the OID value as munged by get() - or passed as keyword. The OID will take precedence if provided, so - that it is possible to update the primary key itself. - - The dictionary is then modified to reflect any changes caused by the - update due to triggers, rules, default values, etc. - """ - if table.endswith('*'): - table = table[:-1].rstrip() # need parent table name - attnames = self.get_attnames(table) - qoid = _oid_key(table) if 'oid' in attnames else None - if row is None: - row = {} - elif 'oid' in row: - del row['oid'] # only accept oid key from named args for safety - row.update(kw) - if qoid and qoid in row and 'oid' not in row: - row['oid'] = row[qoid] - if qoid and 'oid' in row: # try using the oid - keyname = ('oid',) - else: # try using the primary key - try: - keyname = self.pkey(table, True) - except KeyError: # the table has no primary key - raise _prg_error('Table %s has no primary key' % table) - # check whether all key columns have values - if not set(keyname).issubset(row): - raise KeyError('Missing value for primary key in row') - params = self.adapter.parameter_list() - adapt = params.add - col = self.escape_identifier - where = ' AND '.join('%s OPERATOR(pg_catalog.=) %s' % ( - col(k), adapt(row[k], attnames[k])) for k in keyname) - if 'oid' in row: - if qoid: - row[qoid] = row['oid'] - del row['oid'] - values = [] - keyname = set(keyname) - for n in attnames: - if n in row and n not in keyname: - values.append('%s = %s' % (col(n), adapt(row[n], attnames[n]))) - if not values: - return row - values = ', '.join(values) - ret = 'oid, *' if qoid else '*' - q = 'UPDATE %s SET %s WHERE %s RETURNING %s' % ( - self._escape_qualified_name(table), values, where, ret) - self._do_debug(q, params) - q = self.db.query(q, params) - res = q.dictresult() - if res: # may be empty when row does not exist - for n, value in res[0].items(): - if qoid and n == 'oid': - n = qoid - row[n] = value - return row - - def upsert(self, table, row=None, **kw): - """Insert a row into a database table with conflict resolution - - This method inserts a row into a table, but instead of raising a - ProgrammingError exception in case a row with the same primary key - already exists, an update will be executed instead. 
This will be - performed as a single atomic operation on the database, so race - conditions can be avoided. - - Like the insert method, the first parameter is the name of the - table and the second parameter can be used to pass the values to - be inserted as a dictionary. - - Unlike the insert and update methods, keyword parameters are not - used to modify the dictionary, but to specify which columns shall - be updated in case of a conflict, and in which way: - - A value of False or None means the column shall not be updated, - a value of True means the column shall be updated with the value - that has been proposed for insertion, i.e. has been passed as value - in the dictionary. Columns that are not specified by keywords but - appear as keys in the dictionary are also updated as if keywords - had been passed with the value True. - - So if in the case of a conflict you want to update every column that - has been passed in the dictionary row, you would call upsert(table, row). - If you don't want to do anything in case of a conflict, i.e. leave - the existing row as it is, call upsert(table, row, **dict.fromkeys(row)). - - If you need more fine-grained control of what gets updated, you can - also pass strings in the keyword parameters. These strings will - be used as SQL expressions for the update columns. In these - expressions you can refer to the value that already exists in - the table by prefixing the column name with "included.", and to - the value that has been proposed for insertion by prefixing the - column name with "excluded.". - - The dictionary is modified in any case to reflect the values in - the database after the operation has completed. - - Note: The method uses the PostgreSQL "upsert" feature which is - only available since PostgreSQL 9.5.
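# For example (a sketch; the table 'visits' with primary key
# (user_id, url) and a counter column 'hits' is an assumption):

import pg

db = pg.DB(dbname='test')
row = {'user_id': 1, 'url': '/home', 'hits': 1}
# on conflict, bump the counter with an SQL expression:
db.upsert('visits', row, hits='included.hits + 1')
# on conflict, update nothing and leave the existing row as it is:
db.upsert('visits', row, **dict.fromkeys(row))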
- """ - if table.endswith('*'): # hint for descendant tables can be ignored - table = table[:-1].rstrip() - if row is None: - row = {} - if 'oid' in row: - del row['oid'] # do not insert oid - if 'oid' in kw: - del kw['oid'] # do not update oid - attnames = self.get_attnames(table) - qoid = _oid_key(table) if 'oid' in attnames else None - params = self.adapter.parameter_list() - adapt = params.add - col = self.escape_identifier - names, values, updates = [], [], [] - for n in attnames: - if n in row: - names.append(col(n)) - values.append(adapt(row[n], attnames[n])) - names, values = ', '.join(names), ', '.join(values) - try: - keyname = self.pkey(table, True) - except KeyError: - raise _prg_error('Table %s has no primary key' % table) - target = ', '.join(col(k) for k in keyname) - update = [] - keyname = set(keyname) - keyname.add('oid') - for n in attnames: - if n not in keyname: - value = kw.get(n, True) - if value: - if not isinstance(value, basestring): - value = 'excluded.%s' % col(n) - update.append('%s = %s' % (col(n), value)) - if not values: - return row - do = 'update set %s' % ', '.join(update) if update else 'nothing' - ret = 'oid, *' if qoid else '*' - q = ('INSERT INTO %s AS included (%s) VALUES (%s)' - ' ON CONFLICT (%s) DO %s RETURNING %s') % ( - self._escape_qualified_name(table), names, values, - target, do, ret) - self._do_debug(q, params) - try: - q = self.db.query(q, params) - except ProgrammingError: - if self.server_version < 90500: - raise _prg_error( - 'Upsert operation is not supported by PostgreSQL version') - raise # re-raise original error - res = q.dictresult() - if res: # may be empty with "do nothing" - for n, value in res[0].items(): - if qoid and n == 'oid': - n = qoid - row[n] = value - else: - self.get(table, row) - return row - - def clear(self, table, row=None): - """Clear all the attributes to values determined by the types. - - Numeric types are set to 0, Booleans are set to false, and everything - else is set to the empty string. If the row argument is present, - it is used as the row dictionary and any entries matching attribute - names are cleared with everything else left unchanged. - """ - # At some point we will need a way to get defaults from a table. - if row is None: - row = {} # empty if argument is not present - attnames = self.get_attnames(table) - for n, t in attnames.items(): - if n == 'oid': - continue - t = t.simple - if t in DbTypes._num_types: - row[n] = 0 - elif t == 'bool': - row[n] = self._make_bool(False) - else: - row[n] = '' - return row - - def delete(self, table, row=None, **kw): - """Delete an existing row in a database table. - - This method deletes the row from a table. It deletes based on the - primary key of the table or the OID value as munged by get() or - passed as keyword. The OID will take precedence if provided. - - The return value is the number of deleted rows (i.e. 0 if the row - did not exist and 1 if the row was deleted). - - Note that if the row cannot be deleted because e.g. it is still - referenced by another table, this method raises a ProgrammingError. 
- """ - if table.endswith('*'): # hint for descendant tables can be ignored - table = table[:-1].rstrip() - attnames = self.get_attnames(table) - qoid = _oid_key(table) if 'oid' in attnames else None - if row is None: - row = {} - elif 'oid' in row: - del row['oid'] # only accept oid key from named args for safety - row.update(kw) - if qoid and qoid in row and 'oid' not in row: - row['oid'] = row[qoid] - if qoid and 'oid' in row: # try using the oid - keyname = ('oid',) - else: # try using the primary key - try: - keyname = self.pkey(table, True) - except KeyError: # the table has no primary key - raise _prg_error('Table %s has no primary key' % table) - # check whether all key columns have values - if not set(keyname).issubset(row): - raise KeyError('Missing value for primary key in row') - params = self.adapter.parameter_list() - adapt = params.add - col = self.escape_identifier - where = ' AND '.join('%s OPERATOR(pg_catalog.=) %s' % ( - col(k), adapt(row[k], attnames[k])) for k in keyname) - if 'oid' in row: - if qoid: - row[qoid] = row['oid'] - del row['oid'] - q = 'DELETE FROM %s WHERE %s' % ( - self._escape_qualified_name(table), where) - self._do_debug(q, params) - res = self.db.query(q, params) - return int(res) - - def truncate(self, table, restart=False, cascade=False, only=False): - """Empty a table or set of tables. - - This method quickly removes all rows from the given table or set - of tables. It has the same effect as an unqualified DELETE on each - table, but since it does not actually scan the tables it is faster. - Furthermore, it reclaims disk space immediately, rather than requiring - a subsequent VACUUM operation. This is most useful on large tables. - - If restart is set to True, sequences owned by columns of the truncated - table(s) are automatically restarted. If cascade is set to True, it - also truncates all tables that have foreign-key references to any of - the named tables. If the parameter only is not set to True, all the - descendant tables (if any) will also be truncated. Optionally, a '*' - can be specified after the table name to explicitly indicate that - descendant tables are included. - """ - if isinstance(table, basestring): - only = {table: only} - table = [table] - elif isinstance(table, (list, tuple)): - if isinstance(only, (list, tuple)): - only = dict(zip(table, only)) - else: - only = dict.fromkeys(table, only) - elif isinstance(table, (set, frozenset)): - only = dict.fromkeys(table, only) - else: - raise TypeError('The table must be a string, list or set') - if not (restart is None or isinstance(restart, (bool, int))): - raise TypeError('Invalid type for the restart option') - if not (cascade is None or isinstance(cascade, (bool, int))): - raise TypeError('Invalid type for the cascade option') - tables = [] - for t in table: - u = only.get(t) - if not (u is None or isinstance(u, (bool, int))): - raise TypeError('Invalid type for the only option') - if t.endswith('*'): - if u: - raise ValueError( - 'Contradictory table name and only options') - t = t[:-1].rstrip() - t = self._escape_qualified_name(t) - if u: - t = 'ONLY %s' % t - tables.append(t) - q = ['TRUNCATE', ', '.join(tables)] - if restart: - q.append('RESTART IDENTITY') - if cascade: - q.append('CASCADE') - q = ' '.join(q) - self._do_debug(q) - return self.db.query(q) - - def get_as_list(self, table, what=None, where=None, - order=None, limit=None, offset=None, scalar=False): - """Get a table as a list. 
- - This gets a convenient representation of the table as a list - of named tuples in Python. You only need to pass the name of - the table (or any other SQL expression returning rows). Note that - by default this will return the full content of the table which - can be huge and overflow your memory. However, you can control - the amount of data returned using the other optional parameters. - - The parameter 'what' can restrict the query to only return a - subset of the table columns. It can be a string, list or a tuple. - The parameter 'where' can restrict the query to only return a - subset of the table rows. It can be a string, list or a tuple - of SQL expressions that all need to be fulfilled. The parameter - 'order' specifies the ordering of the rows. It can also be a - string, list or a tuple. If no ordering is specified, - the result will be ordered by the primary key(s) or all columns - if no primary key exists. You can set 'order' to False if you - don't care about the ordering. The parameters 'limit' and 'offset' - can be integers specifying the maximum number of rows returned - and a number of rows skipped over. - - If you set the 'scalar' option to True, then instead of the - named tuples you will get the first items of these tuples. - This is useful if the result has only one column anyway. - """ - if not table: - raise TypeError('The table name is missing') - if what: - if isinstance(what, (list, tuple)): - what = ', '.join(map(str, what)) - if order is None: - order = what - else: - what = '*' - q = ['SELECT', what, 'FROM', table] - if where: - if isinstance(where, (list, tuple)): - where = ' AND '.join(map(str, where)) - q.extend(['WHERE', where]) - if order is None: - try: - order = self.pkey(table, True) - except (KeyError, ProgrammingError): - try: - order = list(self.get_attnames(table)) - except (KeyError, ProgrammingError): - pass - if order: - if isinstance(order, (list, tuple)): - order = ', '.join(map(str, order)) - q.extend(['ORDER BY', order]) - if limit: - q.append('LIMIT %d' % limit) - if offset: - q.append('OFFSET %d' % offset) - q = ' '.join(q) - self._do_debug(q) - q = self.db.query(q) - res = q.namedresult() - if res and scalar: - res = [row[0] for row in res] - return res - - def get_as_dict(self, table, keyname=None, what=None, where=None, - order=None, limit=None, offset=None, scalar=False): - """Get a table as a dictionary. - - This method is similar to get_as_list(), but returns the table - as a Python dict instead of a Python list, which can be even - more convenient. The primary key column(s) of the table will - be used as the keys of the dictionary, while the other column(s) - will be the corresponding values. The keys will be named tuples - if the table has a composite primary key. The rows will also be - named tuples unless the 'scalar' option has been set to True. - With the optional parameter 'keyname' you can specify an alternative - set of columns to be used as the keys of the dictionary. It must - be set as a string, list or a tuple. - - If the Python version supports it, the dictionary will be an - OrderedDict using the order specified with the 'order' parameter - or the key column(s) if not specified. You can set 'order' to False - if you don't care about the ordering. In this case the returned - dictionary will be an ordinary one.
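# For example (a sketch; the table and its columns are assumptions):

import pg

db = pg.DB(dbname='test')
# map primary keys to named tuples of the remaining columns:
customers = db.get_as_dict('customers', limit=100)
# map ids to bare name strings instead of tuples:
names = db.get_as_dict('customers', what=['id', 'name'], scalar=True)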
- """ - if not table: - raise TypeError('The table name is missing') - if not keyname: - try: - keyname = self.pkey(table, True) - except (KeyError, ProgrammingError): - raise _prg_error('Table %s has no primary key' % table) - if isinstance(keyname, basestring): - keyname = [keyname] - elif not isinstance(keyname, (list, tuple)): - raise KeyError('The keyname must be a string, list or tuple') - if what: - if isinstance(what, (list, tuple)): - what = ', '.join(map(str, what)) - if order is None: - order = what - else: - what = '*' - q = ['SELECT', what, 'FROM', table] - if where: - if isinstance(where, (list, tuple)): - where = ' AND '.join(map(str, where)) - q.extend(['WHERE', where]) - if order is None: - order = keyname - if order: - if isinstance(order, (list, tuple)): - order = ', '.join(map(str, order)) - q.extend(['ORDER BY', order]) - if limit: - q.append('LIMIT %d' % limit) - if offset: - q.append('OFFSET %d' % offset) - q = ' '.join(q) - self._do_debug(q) - q = self.db.query(q) - res = q.getresult() - cls = OrderedDict if order else dict - if not res: - return cls() - keyset = set(keyname) - fields = q.listfields() - if not keyset.issubset(fields): - raise KeyError('Missing keyname in row') - keyind, rowind = [], [] - for i, f in enumerate(fields): - (keyind if f in keyset else rowind).append(i) - keytuple = len(keyind) > 1 - getkey = itemgetter(*keyind) - keys = map(getkey, res) - if scalar: - rowind = rowind[:1] - rowtuple = False - else: - rowtuple = len(rowind) > 1 - if scalar or rowtuple: - getrow = itemgetter(*rowind) - else: - rowind = rowind[0] - getrow = lambda row: (row[rowind],) - rowtuple = True - rows = map(getrow, res) - if keytuple or rowtuple: - if keytuple: - keys = _namediter(_MemoryQuery(keys, keyname)) - if rowtuple: - fields = [f for f in fields if f not in keyset] - rows = _namediter(_MemoryQuery(rows, fields)) - return cls(zip(keys, rows)) - - def notification_handler(self, - event, callback, arg_dict=None, timeout=None, stop_event=None): - """Get notification handler that will run the given callback.""" - return NotificationHandler(self, - event, callback, arg_dict, timeout, stop_event) - - -# if run as script, print some information - -if __name__ == '__main__': - print('PyGreSQL version' + version) - print('') - print(__doc__) diff --git a/pg/__init__.py b/pg/__init__.py new file mode 100644 index 00000000..c3b7f4e9 --- /dev/null +++ b/pg/__init__.py @@ -0,0 +1,186 @@ +#!/usr/bin/python +# +# PyGreSQL - a Python interface for the PostgreSQL database. +# +# This file contains the classic pg module. +# +# Copyright (c) 2025 by the PyGreSQL Development Team +# +# The notification handler is based on pgnotify which is +# Copyright (c) 2001 Ng Pheng Siong. All rights reserved. +# +# Please see the LICENSE.TXT file for specific restrictions. + +"""PyGreSQL classic interface. + +This pg module implements some basic database management stuff. +It includes the _pg module and builds on it, providing the higher +level wrapper class named DB with additional functionality. +This is known as the "classic" ("old style") PyGreSQL interface. +For a DB-API 2 compliant interface use the newer pgdb module. 
+""" + +from __future__ import annotations + +from .adapt import Adapter, Bytea, Hstore, Json, Literal +from .cast import Typecasts, get_typecast, set_typecast +from .core import ( + INV_READ, + INV_WRITE, + POLLING_FAILED, + POLLING_OK, + POLLING_READING, + POLLING_WRITING, + RESULT_DDL, + RESULT_DML, + RESULT_DQL, + RESULT_EMPTY, + SEEK_CUR, + SEEK_END, + SEEK_SET, + TRANS_ACTIVE, + TRANS_IDLE, + TRANS_INERROR, + TRANS_INTRANS, + TRANS_UNKNOWN, + Connection, + DatabaseError, + DataError, + Error, + IntegrityError, + InterfaceError, + InternalError, + InvalidResultError, + MultipleResultsError, + NoResultError, + NotSupportedError, + OperationalError, + ProgrammingError, + Query, + Warning, + cast_array, + cast_hstore, + cast_record, + connect, + escape_bytea, + escape_string, + get_array, + get_bool, + get_bytea_escaped, + get_datestyle, + get_decimal, + get_decimal_point, + get_defbase, + get_defhost, + get_defopt, + get_defport, + get_defuser, + get_jsondecode, + get_pqlib_version, + set_array, + set_bool, + set_bytea_escaped, + set_datestyle, + set_decimal, + set_decimal_point, + set_defbase, + set_defhost, + set_defopt, + set_defpasswd, + set_defport, + set_defuser, + set_jsondecode, + set_query_helpers, + unescape_bytea, + version, +) +from .db import DB +from .helpers import RowCache, init_core +from .notify import NotificationHandler + +__all__ = [ + 'DB', + 'INV_READ', + 'INV_WRITE', + 'POLLING_FAILED', + 'POLLING_OK', + 'POLLING_READING', + 'POLLING_WRITING', + 'RESULT_DDL', + 'RESULT_DML', + 'RESULT_DQL', + 'RESULT_EMPTY', + 'SEEK_CUR', + 'SEEK_END', + 'SEEK_SET', + 'TRANS_ACTIVE', + 'TRANS_IDLE', + 'TRANS_INERROR', + 'TRANS_INTRANS', + 'TRANS_UNKNOWN', + 'Adapter', + 'Bytea', + 'Connection', + 'DataError', + 'DatabaseError', + 'Error', + 'Hstore', + 'IntegrityError', + 'InterfaceError', + 'InternalError', + 'InvalidResultError', + 'Json', + 'Literal', + 'MultipleResultsError', + 'NoResultError', + 'NotSupportedError', + 'NotificationHandler', + 'OperationalError', + 'ProgrammingError', + 'Query', + 'RowCache', + 'Typecasts', + 'Warning', + '__version__', + 'cast_array', + 'cast_hstore', + 'cast_record', + 'connect', + 'escape_bytea', + 'escape_string', + 'get_array', + 'get_bool', + 'get_bytea_escaped', + 'get_datestyle', + 'get_decimal', + 'get_decimal_point', + 'get_defbase', + 'get_defhost', + 'get_defopt', + 'get_defport', + 'get_defuser', + 'get_jsondecode', + 'get_pqlib_version', + 'get_typecast', + 'set_array', + 'set_bool', + 'set_bytea_escaped', + 'set_datestyle', + 'set_decimal', + 'set_decimal_point', + 'set_defbase', + 'set_defhost', + 'set_defopt', + 'set_defpasswd', + 'set_defport', + 'set_defuser', + 'set_jsondecode', + 'set_query_helpers', + 'set_typecast', + 'unescape_bytea', + 'version', +] + +__version__ = version + +init_core() diff --git a/pg/_pg.pyi b/pg/_pg.pyi new file mode 100644 index 00000000..b14bd5fc --- /dev/null +++ b/pg/_pg.pyi @@ -0,0 +1,638 @@ +"""Type hints for the PyGreSQL C extension.""" + +from __future__ import annotations + +from typing import Any, Callable, Iterable, Sequence, TypeVar + +try: + AnyStr = TypeVar('AnyStr', str, bytes, str | bytes) +except TypeError: # Python < 3.10 + AnyStr = Any # type: ignore +SomeNamedTuple = Any # alias for accessing arbitrary named tuples + +version: str +__version__: str + +RESULT_EMPTY: int +RESULT_DML: int +RESULT_DDL: int +RESULT_DQL: int + +TRANS_IDLE: int +TRANS_ACTIVE: int +TRANS_INTRANS: int +TRANS_INERROR: int +TRANS_UNKNOWN: int + +POLLING_OK: int +POLLING_FAILED: int +POLLING_READING: int 
+POLLING_WRITING: int
+
+INV_READ: int
+INV_WRITE: int
+
+SEEK_SET: int
+SEEK_CUR: int
+SEEK_END: int
+
+
+class Error(Exception):
+    """Exception that is the base class of all other error exceptions."""
+
+
+class Warning(Exception):  # noqa: N818
+    """Exception raised for important warnings."""
+
+
+class InterfaceError(Error):
+    """Exception raised for errors related to the database interface."""
+
+
+class DatabaseError(Error):
+    """Exception raised for errors that are related to the database."""
+
+    sqlstate: str | None
+
+
+class InternalError(DatabaseError):
+    """Exception raised when the database encounters an internal error."""
+
+
+class OperationalError(DatabaseError):
+    """Exception raised for errors related to the operation of the database."""
+
+
+class ProgrammingError(DatabaseError):
+    """Exception raised for programming errors."""
+
+
+class IntegrityError(DatabaseError):
+    """Exception raised when the relational integrity is affected."""
+
+
+class DataError(DatabaseError):
+    """Exception raised for errors due to problems with the processed data."""
+
+
+class NotSupportedError(DatabaseError):
+    """Exception raised when a method or database API is not supported."""
+
+
+class InvalidResultError(DataError):
+    """Exception when a database operation produced an invalid result."""
+
+
+class NoResultError(InvalidResultError):
+    """Exception when a database operation did not produce any result."""
+
+
+class MultipleResultsError(InvalidResultError):
+    """Exception when a database operation produced multiple results."""
+
+
+class Source:
+    """Source object."""
+
+    arraysize: int
+    resulttype: int
+    ntuples: int
+    nfields: int
+
+    def execute(self, sql: str) -> int | None:
+        """Execute a SQL statement."""
+        ...
+
+    def fetch(self, num: int) -> list[tuple]:
+        """Return the next num rows from the last result in a list."""
+        ...
+
+    def listinfo(self) -> tuple[tuple[int, str, int, int, int], ...]:
+        """Get information for all fields."""
+        ...
+
+    def oidstatus(self) -> int | None:
+        """Return oid of last inserted row (if available)."""
+        ...
+
+    def putdata(self, buffer: str | bytes | BaseException | None
+                ) -> int | None:
+        """Send data to server during copy from stdin."""
+        ...
+
+    def getdata(self, decode: bool | None = None) -> str | bytes | int:
+        """Receive data from server during copy to stdout."""
+        ...
+
+    def close(self) -> None:
+        """Close query object without deleting it."""
+        ...
+
+
+class LargeObject:
+    """Large object."""
+
+    oid: int
+    pgcnx: Connection
+    error: str
+
+    def open(self, mode: int) -> None:
+        """Open a large object.
+
+        The valid values for 'mode' parameter are defined as the module level
+        constants INV_READ and INV_WRITE.
+        """
+        ...
+
+    def close(self) -> None:
+        """Close a large object."""
+        ...
+
+    def read(self, size: int) -> bytes:
+        """Read data from large object."""
+        ...
+
+    def write(self, data: bytes) -> None:
+        """Write data to large object."""
+        ...
+
+    def seek(self, offset: int, whence: int) -> int:
+        """Change current position in large object.
+
+        The valid values for the 'whence' parameter are defined as the
+        module level constants SEEK_SET, SEEK_CUR and SEEK_END.
+        """
+        ...
+
+    def unlink(self) -> None:
+        """Delete large object."""
+        ...
+
+    def size(self) -> int:
+        """Return the large object size."""
+        ...
+
+    def export(self, filename: str) -> None:
+        """Export a large object to a file."""
+        ...
+
+
+class Connection:
+    """Connection object.
+
+    This object handles a connection to a PostgreSQL database.
+    It embeds and hides all the parameters that define this connection,
+    thus just leaving really significant parameters in function calls.
+    """
+
+    host: str
+    port: int
+    db: str
+    options: str
+    error: str
+    status: int
+    user: str
+    protocol_version: int
+    server_version: int
+    socket: int
+    backend_pid: int
+    ssl_in_use: bool
+    ssl_attributes: dict[str, str | None]
+
+    def source(self) -> Source:
+        """Create a new source object for this connection."""
+        ...
+
+    def query(self, cmd: str, args: Sequence | None = None) -> Query:
+        """Create a new query object for this connection.
+
+        Note that if the command is something other than DQL, this method
+        can return an int, str or None instead of a Query.
+        """
+        ...
+
+    def send_query(self, cmd: str, args: Sequence | None = None) -> Query:
+        """Create a new asynchronous query object for this connection."""
+        ...
+
+    def query_prepared(self, name: str, args: Sequence | None = None) -> Query:
+        """Execute a prepared statement."""
+        ...
+
+    def prepare(self, name: str, cmd: str) -> None:
+        """Create a prepared statement."""
+        ...
+
+    def describe_prepared(self, name: str) -> Query:
+        """Describe a prepared statement."""
+        ...
+
+    def poll(self) -> int:
+        """Complete an asynchronous connection and get its state."""
+        ...
+
+    def reset(self) -> None:
+        """Reset the connection."""
+        ...
+
+    def cancel(self) -> None:
+        """Abandon processing of current SQL command."""
+        ...
+
+    def close(self) -> None:
+        """Close the database connection."""
+        ...
+
+    def fileno(self) -> int:
+        """Get the socket used to connect to the database."""
+        ...
+
+    def get_cast_hook(self) -> Callable | None:
+        """Get the function that handles all external typecasting."""
+        ...
+
+    def set_cast_hook(self, hook: Callable | None) -> None:
+        """Set a function that will handle all external typecasting."""
+        ...
+
+    def get_notice_receiver(self) -> Callable | None:
+        """Get the current notice receiver."""
+        ...
+
+    def set_notice_receiver(self, receiver: Callable | None) -> None:
+        """Set a custom notice receiver."""
+        ...
+
+    def getnotify(self) -> tuple[str, int, str] | None:
+        """Get the last notify from the server."""
+        ...
+
+    def inserttable(self, table: str, values: Sequence[list | tuple],
+                    columns: list[str] | tuple[str, ...] | None = None) -> int:
+        """Insert a Python iterable into a database table."""
+        ...
+
+    def transaction(self) -> int:
+        """Get the current in-transaction status of the server.
+
+        The status returned by this method can be TRANS_IDLE (currently idle),
+        TRANS_ACTIVE (a command is in progress), TRANS_INTRANS (idle, in a
+        valid transaction block), or TRANS_INERROR (idle, in a failed
+        transaction block). TRANS_UNKNOWN is reported if the connection is
+        bad. The status TRANS_ACTIVE is reported only when a query has been
+        sent to the server and not yet completed.
+        """
+        ...
+
+    def parameter(self, name: str) -> str | None:
+        """Look up a current parameter setting of the server."""
+        ...
+
+    def date_format(self) -> str:
+        """Look up the date format currently being used by the database."""
+        ...
+
+    def escape_literal(self, s: AnyStr) -> AnyStr:
+        """Escape a literal constant for use within SQL."""
+        ...
+
+    def escape_identifier(self, s: AnyStr) -> AnyStr:
+        """Escape an identifier for use within SQL."""
+        ...
+
+    def escape_string(self, s: AnyStr) -> AnyStr:
+        """Escape a string for use within SQL."""
+        ...
+
+    def escape_bytea(self, s: AnyStr) -> AnyStr:
+        """Escape binary data for use within SQL as type 'bytea'."""
+        ...
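For illustration, a minimal sketch of how the low-level Connection API above fits together; the connection parameters are placeholders for a reachable test database, and the parameterized form should be preferred over manual escaping:

    from pg import connect

    con = connect('test', host='localhost', user='test')
    name = con.escape_string("O'Reilly")  # doubles the quote: O''Reilly
    q = con.query(f"SELECT '{name}'::text")  # returns a Query object
    print(q.getresult())  # [("O'Reilly",)]
    # the parameterized form avoids manual escaping altogether:
    q = con.query('SELECT $1::text', ["O'Reilly"])
    print(q.onescalar())  # O'Reilly
    con.close()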
+
+    def putline(self, line: str) -> None:
+        """Write a line to the server socket."""
+        ...
+
+    def getline(self) -> str:
+        """Get a line from server socket."""
+        ...
+
+    def endcopy(self) -> None:
+        """Synchronize client and server."""
+        ...
+
+    def set_non_blocking(self, nb: bool) -> None:
+        """Set the non-blocking mode of the connection."""
+        ...
+
+    def is_non_blocking(self) -> bool:
+        """Get the non-blocking mode of the connection."""
+        ...
+
+    def locreate(self, mode: int) -> LargeObject:
+        """Create a large object in the database.
+
+        The valid values for 'mode' parameter are defined as the module level
+        constants INV_READ and INV_WRITE.
+        """
+        ...
+
+    def getlo(self, oid: int) -> LargeObject:
+        """Build a large object from given oid."""
+        ...
+
+    def loimport(self, filename: str) -> LargeObject:
+        """Import a file to a large object."""
+        ...
+
+
+class Query:
+    """Query object.
+
+    The Query object returned by Connection.query and DB.query can be used
+    as an iterable returning rows as tuples. You can also directly access
+    row tuples using their index, and get the number of rows with the
+    len() function. The Query class also provides several methods
+    for accessing the results of the query.
+    """
+
+    def __len__(self) -> int:
+        ...
+
+    def __getitem__(self, key: int) -> object:
+        ...
+
+    def __iter__(self) -> Query:
+        ...
+
+    def __next__(self) -> tuple:
+        ...
+
+    def getresult(self) -> list[tuple]:
+        """Get query values as list of tuples."""
+        ...
+
+    def dictresult(self) -> list[dict[str, object]]:
+        """Get query values as list of dictionaries."""
+        ...
+
+    def dictiter(self) -> Iterable[dict[str, object]]:
+        """Get query values as iterable of dictionaries."""
+        ...
+
+    def namedresult(self) -> list[SomeNamedTuple]:
+        """Get query values as list of named tuples."""
+        ...
+
+    def namediter(self) -> Iterable[SomeNamedTuple]:
+        """Get query values as iterable of named tuples."""
+        ...
+
+    def one(self) -> tuple | None:
+        """Get one row from the result of a query as a tuple."""
+        ...
+
+    def single(self) -> tuple:
+        """Get single row from the result of a query as a tuple."""
+        ...
+
+    def onedict(self) -> dict[str, object] | None:
+        """Get one row from the result of a query as a dictionary."""
+        ...
+
+    def singledict(self) -> dict[str, object]:
+        """Get single row from the result of a query as a dictionary."""
+        ...
+
+    def onenamed(self) -> SomeNamedTuple | None:
+        """Get one row from the result of a query as named tuple."""
+        ...
+
+    def singlenamed(self) -> SomeNamedTuple:
+        """Get single row from the result of a query as named tuple."""
+        ...
+
+    def scalarresult(self) -> list:
+        """Get first fields from query result as list of scalar values."""
+        ...
+
+    def scalariter(self) -> Iterable:
+        """Get first fields from query result as iterable of scalar values."""
+        ...
+
+    def onescalar(self) -> object | None:
+        """Get one row from the result of a query as scalar value."""
+        ...
+
+    def singlescalar(self) -> object:
+        """Get single row from the result of a query as scalar value."""
+        ...
+
+    def fieldname(self, num: int) -> str:
+        """Get field name from its number."""
+        ...
+
+    def fieldnum(self, name: str) -> int:
+        """Get field number from its name."""
+        ...
+
+    def listfields(self) -> tuple[str, ...]:
+        """List field names of query result."""
+        ...
+
+    def fieldinfo(self, column: int | str | None) -> tuple[str, int, int, int]:
+        """Get information on one or all fields of the query.
+
+        The four-tuples contain the following information:
+        The field name, the internal OID number of the field type,
+        the size in bytes of the column or a negative value if it is
+        of variable size, and a type-specific modifier value.
+        """
+        ...
+
+    def memsize(self) -> int:
+        """Return number of bytes allocated by query result."""
+        ...
+
+
+def connect(dbname: str | None = None,
+            host: str | None = None,
+            port: int | None = None,
+            opt: str | None = None,
+            user: str | None = None,
+            passwd: str | None = None,
+            nowait: int | None = None) -> Connection:
+    """Connect to a PostgreSQL database."""
+    ...
+
+
+def cast_array(s: str, cast: Callable | None = None,
+               delim: bytes | None = None) -> list:
+    """Cast a string representing a PostgreSQL array to a Python list."""
+    ...
+
+
+def cast_record(s: str,
+                cast: Callable | list[Callable | None] |
+                tuple[Callable | None, ...] | None = None,
+                delim: bytes | None = None) -> tuple:
+    """Cast a string representing a PostgreSQL record to a Python tuple."""
+    ...
+
+
+def cast_hstore(s: str) -> dict[str, str | None]:
+    """Cast a string as a hstore."""
+    ...
+
+
+def escape_bytea(s: AnyStr) -> AnyStr:
+    """Escape binary data for use within SQL as type 'bytea'."""
+    ...
+
+
+def unescape_bytea(s: AnyStr) -> bytes:
+    """Unescape 'bytea' data that has been retrieved as text."""
+    ...
+
+
+def escape_string(s: AnyStr) -> AnyStr:
+    """Escape a string for use within SQL."""
+    ...
+
+
+def get_pqlib_version() -> int:
+    """Get the version of libpq that is being used by PyGreSQL."""
+    ...
+
+
+def get_array() -> bool:
+    """Check whether arrays are returned as list objects."""
+    ...
+
+
+def set_array(on: bool) -> None:
+    """Set whether arrays are returned as list objects."""
+    ...
+
+
+def get_bool() -> bool:
+    """Check whether boolean values are returned as bool objects."""
+    ...
+
+
+def set_bool(on: bool | int) -> None:
+    """Set whether boolean values are returned as bool objects."""
+    ...
+
+
+def get_bytea_escaped() -> bool:
+    """Check whether 'bytea' values are returned as escaped strings."""
+    ...
+
+
+def set_bytea_escaped(on: bool | int) -> None:
+    """Set whether 'bytea' values are returned as escaped strings."""
+    ...
+
+
+def get_datestyle() -> str | None:
+    """Get the assumed date style for typecasting."""
+    ...
+
+
+def set_datestyle(datestyle: str | None) -> None:
+    """Set a fixed date style that shall be assumed when typecasting."""
+    ...
+
+
+def get_decimal() -> type:
+    """Get the decimal type to be used for numeric values."""
+    ...
+
+
+def set_decimal(cls: type) -> None:
+    """Set the decimal type to be used for numeric values."""
+    ...
+
+
+def get_decimal_point() -> str | None:
+    """Get the decimal mark used for monetary values."""
+    ...
+
+
+def set_decimal_point(mark: str | None) -> None:
+    """Specify which decimal mark is used for interpreting monetary values."""
+    ...
+
+
+def get_jsondecode() -> Callable[[str], object] | None:
+    """Get the function that deserializes JSON formatted strings."""
+    ...
+
+
+def set_jsondecode(decode: Callable[[str], object] | None) -> None:
+    """Set a function that will deserialize JSON formatted strings."""
+    ...
+
+
+def get_defbase() -> str | None:
+    """Get the default database name."""
+    ...
+
+
+def set_defbase(base: str | None) -> None:
+    """Set the default database name."""
+    ...
+
+
+def get_defhost() -> str | None:
+    """Get the default host."""
+    ...
+
+
+def set_defhost(host: str | None) -> None:
+    """Set the default host."""
+    ...
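The get_def*/set_def* helpers store module-level fallbacks that connect() uses for any parameter not passed explicitly. A minimal sketch, with placeholder values:

    import pg

    pg.set_defhost('localhost')
    pg.set_defbase('test')
    pg.set_defuser('test')
    pg.set_defpasswd('test')

    con = pg.connect()  # falls back to the defaults set above
    print(con.db, con.host, con.port)
    con.close()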
+
+
+def get_defport() -> int | None:
+    """Get the default port."""
+    ...
+
+
+def set_defport(port: int | None) -> None:
+    """Set the default port."""
+    ...
+
+
+def get_defopt() -> str | None:
+    """Get the default connection options."""
+    ...
+
+
+def set_defopt(opt: str | None) -> None:
+    """Set the default connection options."""
+    ...
+
+
+def get_defuser() -> str | None:
+    """Get the default database user."""
+    ...
+
+
+def set_defuser(user: str | None) -> None:
+    """Set the default database user."""
+    ...
+
+
+def get_defpasswd() -> str | None:
+    """Get the default database password."""
+    ...
+
+
+def set_defpasswd(passwd: str | None) -> None:
+    """Set the default database password."""
+    ...
+
+
+def set_query_helpers(*helpers: Callable) -> None:
+    """Set internal query helper functions."""
+    ...
diff --git a/pg/adapt.py b/pg/adapt.py
new file mode 100644
index 00000000..97e0391c
--- /dev/null
+++ b/pg/adapt.py
@@ -0,0 +1,686 @@
+"""Adaptation of parameters."""
+
+from __future__ import annotations
+
+import weakref
+from datetime import date, datetime, time, timedelta
+from decimal import Decimal
+from json import dumps as jsonencode
+from math import isinf, isnan
+from re import compile as regex
+from types import MappingProxyType
+from typing import TYPE_CHECKING, Any, Callable, List, Mapping, Sequence
+from uuid import UUID
+
+from .attrs import AttrDict
+from .cast import Typecasts
+from .core import InterfaceError, ProgrammingError
+from .helpers import quote_if_unqualified
+
+if TYPE_CHECKING:
+    from .db import DB
+
+__all__ = [
+    'UUID',
+    'Adapter',
+    'Bytea',
+    'DbType',
+    'DbTypes',
+    'Hstore',
+    'Json',
+    'Literal'
+]
+
+
+class Bytea(bytes):
+    """Wrapper class for marking Bytea values."""
+
+
+class Hstore(dict):
+    """Wrapper class for marking hstore values."""
+
+    _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]')
+
+    @classmethod
+    def _quote(cls, s: Any) -> str:
+        if s is None:
+            return 'NULL'
+        if not isinstance(s, str):
+            s = str(s)
+        if not s:
+            return '""'
+        s = s.replace('"', '\\"')
+        if cls._re_quote.search(s):
+            s = f'"{s}"'
+        return s
+
+    def __str__(self) -> str:
+        """Create a printable representation of the hstore value."""
+        q = self._quote
+        return ','.join(f'{q(k)}=>{q(v)}' for k, v in self.items())
+
+
+class Json:
+    """Wrapper class for marking Json values."""
+
+    def __init__(self, obj: Any,
+                 encode: Callable[[Any], str] | None = None) -> None:
+        """Initialize the JSON object."""
+        self.obj = obj
+        self.encode = encode or jsonencode
+
+    def __str__(self) -> str:
+        """Create a printable representation of the JSON object."""
+        obj = self.obj
+        if isinstance(obj, str):
+            return obj
+        return self.encode(obj)
+
+
+class Literal(str):
+    """Wrapper class for marking literal SQL values."""
+
+
+class _SimpleTypes(dict):
+    """Dictionary mapping pg_type names to simple type names.
+
+    The corresponding Python types and simple names are also mapped.
+ """ + + _type_aliases: Mapping[str, list[str | type]] = MappingProxyType({ + 'bool': [bool], + 'bytea': [Bytea], + 'date': ['interval', 'time', 'timetz', 'timestamp', 'timestamptz', + 'abstime', 'reltime', # these are very old + 'datetime', 'timedelta', # these do not really exist + date, time, datetime, timedelta], + 'float': ['float4', 'float8', float], + 'int': ['cid', 'int2', 'int4', 'int8', 'oid', 'xid', int], + 'hstore': [Hstore], 'json': ['jsonb', Json], 'uuid': [UUID], + 'num': ['numeric', Decimal], 'money': [], + 'text': ['bpchar', 'char', 'name', 'varchar', bytes, str] + }) + + # noinspection PyMissingConstructor + def __init__(self) -> None: + """Initialize type mapping.""" + for typ, keys in self._type_aliases.items(): + keys = [typ, *keys] + for key in keys: + self[key] = typ + if isinstance(key, str): + self[f'_{key}'] = f'{typ}[]' + elif not isinstance(key, tuple): + self[List[key]] = f'{typ}[]' # type: ignore + + @staticmethod + def __missing__(key: str) -> str: + """Unmapped types are interpreted as text.""" + return 'text' + + def get_type_dict(self) -> dict[type, str]: + """Get a plain dictionary of only the types.""" + return {key: typ for key, typ in self.items() + if not isinstance(key, (str, tuple))} + + +_simpletypes = _SimpleTypes() +_simple_type_dict = _simpletypes.get_type_dict() + + +class _ParameterList(list): + """Helper class for building typed parameter lists.""" + + adapt: Callable + + def add(self, value: Any, typ:Any = None) -> str: + """Typecast value with known database type and build parameter list. + + If this is a literal value, it will be returned as is. Otherwise, a + placeholder will be returned and the parameter list will be augmented. + """ + # noinspection PyUnresolvedReferences + value = self.adapt(value, typ) + if isinstance(value, Literal): + return value + self.append(value) + return f'${len(self)}' + + + +class DbType(str): + """Class augmenting the simple type name with additional info. + + The following additional information is provided: + + oid: the PostgreSQL type OID + pgtype: the internal PostgreSQL data type name + regtype: the registered PostgreSQL data type name + simple: the more coarse-grained PyGreSQL type name + typlen: the internal size, negative if variable + typtype: b = base type, c = composite type etc. + category: A = Array, b = Boolean, C = Composite etc. + delim: delimiter for array types + relid: corresponding table for composite types + attnames: attributes for composite types + """ + + oid: int + pgtype: str + regtype: str + simple: str + typlen: int + typtype: str + category: str + delim: str + relid: int + + _get_attnames: Callable[[DbType], AttrDict] + + @property + def attnames(self) -> AttrDict: + """Get names and types of the fields of a composite type.""" + # noinspection PyUnresolvedReferences + return self._get_attnames(self) + + +class DbTypes(dict): + """Cache for PostgreSQL data types. + + This cache maps type OIDs and names to DbType objects containing + information on the associated database type. 
+ """ + + _num_types = frozenset('int float num money int2 int4 int8' + ' float4 float8 numeric money'.split()) + + def __init__(self, db: DB) -> None: + """Initialize type cache for connection.""" + super().__init__() + self._db = weakref.proxy(db) + self._regtypes = False + self._typecasts = Typecasts() + self._typecasts.get_attnames = self.get_attnames # type: ignore + self._typecasts.connection = self._db.db + self._query_pg_type = ( + "SELECT oid, typname, oid::pg_catalog.regtype," + " typlen, typtype, typcategory, typdelim, typrelid" + " FROM pg_catalog.pg_type" + " WHERE oid OPERATOR(pg_catalog.=) {}::pg_catalog.regtype") + + def add(self, oid: int, pgtype: str, regtype: str, + typlen: int, typtype: str, category: str, delim: str, relid: int + ) -> DbType: + """Create a PostgreSQL type name with additional info.""" + if oid in self: + return self[oid] + simple = 'record' if relid else _simpletypes[pgtype] + typ = DbType(regtype if self._regtypes else simple) + typ.oid = oid + typ.simple = simple + typ.pgtype = pgtype + typ.regtype = regtype + typ.typlen = typlen + typ.typtype = typtype + typ.category = category + typ.delim = delim + typ.relid = relid + typ._get_attnames = self.get_attnames # type: ignore + return typ + + def __missing__(self, key: int | str) -> DbType: + """Get the type info from the database if it is not cached.""" + try: + cmd = self._query_pg_type.format(quote_if_unqualified('$1', key)) + res = self._db.query(cmd, (key,)).getresult() + except ProgrammingError: + res = None + if not res: + raise KeyError(f'Type {key} could not be found') + res = res[0] + typ = self.add(*res) + self[typ.oid] = self[typ.pgtype] = typ + return typ + + def get(self, key: int | str, # type: ignore + default: DbType | None = None) -> DbType | None: + """Get the type even if it is not cached.""" + try: + return self[key] + except KeyError: + return default + + def get_attnames(self, typ: Any) -> AttrDict | None: + """Get names and types of the fields of a composite type.""" + if not isinstance(typ, DbType): + typ = self.get(typ) + if not typ: + return None + if not typ.relid: + return None + return self._db.get_attnames(typ.relid, with_oid=False) + + def get_typecast(self, typ: Any) -> Callable | None: + """Get the typecast function for the given database type.""" + return self._typecasts.get(typ) + + def set_typecast(self, typ: str | Sequence[str], cast: Callable) -> None: + """Set a typecast function for the specified database type(s).""" + self._typecasts.set(typ, cast) + + def reset_typecast(self, typ: str | Sequence[str] | None = None) -> None: + """Reset the typecast function for the specified database type(s).""" + self._typecasts.reset(typ) + + def typecast(self, value: Any, typ: str) -> Any: + """Cast the given value according to the given database type.""" + if value is None: + # for NULL values, no typecast is necessary + return None + if not isinstance(typ, DbType): + db_type = self.get(typ) + if db_type: + typ = db_type.pgtype + cast = self.get_typecast(typ) if typ else None + if not cast or cast is str: + # no typecast is necessary + return value + return cast(value) + + +class Adapter: + """Class providing methods for adapting parameters to the database.""" + + _bool_true_values = frozenset('t true 1 y yes on'.split()) + + _date_literals = frozenset( + 'current_date current_time' + ' current_timestamp localtime localtimestamp'.split()) + + _re_array_quote = regex(r'[{},"\\\s]|^[Nn][Uu][Ll][Ll]$') + _re_record_quote = regex(r'[(,"\\]') + _re_array_escape = _re_record_escape 
= regex(r'(["\\])') + + def __init__(self, db: DB): + """Initialize the adapter object with the given connection.""" + self.db = weakref.proxy(db) + + @classmethod + def _adapt_bool(cls, v: Any) -> str | None: + """Adapt a boolean parameter.""" + if isinstance(v, str): + if not v: + return None + v = v.lower() in cls._bool_true_values + return 't' if v else 'f' + + @classmethod + def _adapt_date(cls, v: Any) -> Any: + """Adapt a date parameter.""" + if not v: + return None + if isinstance(v, str) and v.lower() in cls._date_literals: + return Literal(v) + return v + + @staticmethod + def _adapt_num(v: Any) -> Any: + """Adapt a numeric parameter.""" + if not v and v != 0: + return None + return v + + _adapt_int = _adapt_float = _adapt_money = _adapt_num + + def _adapt_bytea(self, v: Any) -> str: + """Adapt a bytea parameter.""" + return self.db.escape_bytea(v) + + def _adapt_json(self, v: Any) -> str | None: + """Adapt a json parameter.""" + if v is None: + return None + if isinstance(v, str): + return v + if isinstance(v, Json): + return str(v) + return self.db.encode_json(v) + + def _adapt_hstore(self, v: Any) -> str | None: + """Adapt a hstore parameter.""" + if not v: + return None + if isinstance(v, str): + return v + if isinstance(v, Hstore): + return str(v) + if isinstance(v, dict): + return str(Hstore(v)) + raise TypeError(f'Hstore parameter {v} has wrong type') + + def _adapt_uuid(self, v: Any) -> str | None: + """Adapt a UUID parameter.""" + if not v: + return None + if isinstance(v, str): + return v + return str(v) + + @classmethod + def _adapt_text_array(cls, v: Any) -> str: + """Adapt a text type array parameter.""" + if isinstance(v, list): + adapt = cls._adapt_text_array + return '{' + ','.join(adapt(v) for v in v) + '}' + if v is None: + return 'null' + if not v: + return '""' + v = str(v) + if cls._re_array_quote.search(v): + v = cls._re_array_escape.sub(r'\\\1', v) + v = f'"{v}"' + return v + + _adapt_date_array = _adapt_text_array + + @classmethod + def _adapt_bool_array(cls, v: Any) -> str: + """Adapt a boolean array parameter.""" + if isinstance(v, list): + adapt = cls._adapt_bool_array + return '{' + ','.join(adapt(v) for v in v) + '}' + if v is None: + return 'null' + if isinstance(v, str): + if not v: + return 'null' + v = v.lower() in cls._bool_true_values + return 't' if v else 'f' + + @classmethod + def _adapt_num_array(cls, v: Any) -> str: + """Adapt a numeric array parameter.""" + if isinstance(v, list): + adapt = cls._adapt_num_array + v = '{' + ','.join(adapt(v) for v in v) + '}' + if not v and v != 0: + return 'null' + return str(v) + + _adapt_int_array = _adapt_float_array = _adapt_money_array = \ + _adapt_num_array + + def _adapt_bytea_array(self, v: Any) -> bytes: + """Adapt a bytea array parameter.""" + if isinstance(v, list): + return b'{' + b','.join( + self._adapt_bytea_array(v) for v in v) + b'}' + if v is None: + return b'null' + return self.db.escape_bytea(v).replace(b'\\', b'\\\\') + + def _adapt_json_array(self, v: Any) -> str: + """Adapt a json array parameter.""" + if isinstance(v, list): + adapt = self._adapt_json_array + return '{' + ','.join(adapt(v) for v in v) + '}' + if not v: + return 'null' + if not isinstance(v, str): + v = self.db.encode_json(v) + if self._re_array_quote.search(v): + v = self._re_array_escape.sub(r'\\\1', v) + v = f'"{v}"' + return v + + def _adapt_record(self, v: Any, typ: Any) -> str: + """Adapt a record parameter with given type.""" + typ = self.get_attnames(typ).values() + if len(typ) != len(v): + raise 
TypeError(f'Record parameter {v} has wrong size') + adapt = self.adapt + value = [] + for v, t in zip(v, typ): # noqa: B020 + v = adapt(v, t) + if v is None: + v = '' + else: + if isinstance(v, bytes): + v = v.decode('ascii') + elif not isinstance(v, str): + v = str(v) + if v: + if self._re_record_quote.search(v): + v = self._re_record_escape.sub(r'\\\1', v) + v = f'"{v}"' + else: + v = '""' + value.append(v) + v = ','.join(value) + return f'({v})' + + def adapt(self, value: Any, typ: Any = None) -> str: + """Adapt a value with known database type.""" + if value is not None and not isinstance(value, Literal): + if typ: + simple = self.get_simple_name(typ) + else: + typ = simple = self.guess_simple_type(value) or 'text' + pg_str = getattr(value, '__pg_str__', None) + if pg_str: + value = pg_str(typ) + if simple == 'text': + pass + elif simple == 'record': + if isinstance(value, tuple): + value = self._adapt_record(value, typ) + elif simple.endswith('[]'): + if isinstance(value, list): + adapt = getattr(self, f'_adapt_{simple[:-2]}_array') + value = adapt(value) + else: + adapt = getattr(self, f'_adapt_{simple}') + value = adapt(value) + return value + + @staticmethod + def simple_type(name: str) -> DbType: + """Create a simple database type with given attribute names.""" + typ = DbType(name) + typ.simple = name + return typ + + @staticmethod + def get_simple_name(typ: Any) -> str: + """Get the simple name of a database type.""" + if isinstance(typ, DbType): + # noinspection PyUnresolvedReferences + return typ.simple + return _simpletypes[typ] + + @staticmethod + def get_attnames(typ: Any) -> dict[str, dict[str, str]]: + """Get the attribute names of a composite database type.""" + if isinstance(typ, DbType): + return typ.attnames + return {} + + @classmethod + def guess_simple_type(cls, value: Any) -> str | None: + """Try to guess which database type the given value has.""" + # optimize for most frequent types + try: + return _simple_type_dict[type(value)] + except KeyError: + pass + if isinstance(value, (bytes, str)): + return 'text' + if isinstance(value, bool): + return 'bool' + if isinstance(value, int): + return 'int' + if isinstance(value, float): + return 'float' + if isinstance(value, Decimal): + return 'num' + if isinstance(value, (date, time, datetime, timedelta)): + return 'date' + if isinstance(value, Bytea): + return 'bytea' + if isinstance(value, Json): + return 'json' + if isinstance(value, Hstore): + return 'hstore' + if isinstance(value, UUID): + return 'uuid' + if isinstance(value, list): + return (cls.guess_simple_base_type(value) or 'text') + '[]' + if isinstance(value, tuple): + simple_type = cls.simple_type + guess = cls.guess_simple_type + + # noinspection PyUnusedLocal + def get_attnames(self: DbType) -> AttrDict: + return AttrDict((str(n + 1), simple_type(guess(v) or 'text')) + for n, v in enumerate(value)) + + typ = simple_type('record') + typ._get_attnames = get_attnames + return typ + return None + + @classmethod + def guess_simple_base_type(cls, value: Any) -> str | None: + """Try to guess the base type of a given array.""" + for v in value: + if isinstance(v, list): + typ = cls.guess_simple_base_type(v) + else: + typ = cls.guess_simple_type(v) + if typ: + return typ + return None + + def adapt_inline(self, value: Any, nested: bool=False) -> Any: + """Adapt a value that is put into the SQL and needs to be quoted.""" + if value is None: + return 'NULL' + if isinstance(value, Literal): + return value + if isinstance(value, Bytea): + value = 
self.db.escape_bytea(value).decode('ascii') + elif isinstance(value, (datetime, date, time, timedelta)): + value = str(value) + if isinstance(value, (bytes, str)): + value = self.db.escape_string(value) + return f"'{value}'" + if isinstance(value, bool): + return 'true' if value else 'false' + if isinstance(value, float): + if isinf(value): + return "'-Infinity'" if value < 0 else "'Infinity'" + if isnan(value): + return "'NaN'" + return value + if isinstance(value, (int, Decimal)): + return value + if isinstance(value, list): + q = self.adapt_inline + s = '[{}]' if nested else 'ARRAY[{}]' + return s.format(','.join(str(q(v, nested=True)) for v in value)) + if isinstance(value, tuple): + q = self.adapt_inline + return '({})'.format(','.join(str(q(v)) for v in value)) + if isinstance(value, Json): + value = self.db.escape_string(str(value)) + return f"'{value}'::json" + if isinstance(value, Hstore): + value = self.db.escape_string(str(value)) + return f"'{value}'::hstore" + pg_repr = getattr(value, '__pg_repr__', None) + if not pg_repr: + raise InterfaceError( + f'Do not know how to adapt type {type(value)}') + value = pg_repr() + if isinstance(value, (tuple, list)): + value = self.adapt_inline(value) + return value + + def parameter_list(self) -> _ParameterList: + """Return a parameter list for parameters with known database types. + + The list has an add(value, typ) method that will build up the + list and return either the literal value or a placeholder. + """ + params = _ParameterList() + params.adapt = self.adapt + return params + + def format_query(self, command: str, + values: list | tuple | dict | None = None, + types: list | tuple | dict | None = None, + inline: bool=False + ) -> tuple[str, _ParameterList]: + """Format a database query using the given values and types. + + The optional types describe the values and must be passed as a list, + tuple or string (that will be split on whitespace) when values are + passed as a list or tuple, or as a dict if values are passed as a dict. + + If inline is set to True, then parameters will be passed inline + together with the query string. 
+ """ + params = self.parameter_list() + if not values: + return command, params + if inline and types: + raise ValueError('Typed parameters must be sent separately') + if isinstance(values, (list, tuple)): + if inline: + adapt = self.adapt_inline + seq_literals = [adapt(value) for value in values] + else: + add = params.add + if types: + if isinstance(types, str): + types = types.split() + if (not isinstance(types, (list, tuple)) + or len(types) != len(values)): + raise TypeError('The values and types do not match') + seq_literals = [add(value, typ) + for value, typ in zip(values, types)] + else: + seq_literals = [add(value) for value in values] + command %= tuple(seq_literals) + elif isinstance(values, dict): + # we want to allow extra keys in the dictionary, + # so we first must find the values actually used in the command + used_values = {} + map_literals = dict.fromkeys(values, '') + for key in values: + del map_literals[key] + try: + command % map_literals + except KeyError: + used_values[key] = values[key] # pyright: ignore + map_literals[key] = '' + if inline: + adapt = self.adapt_inline + map_literals = {key: adapt(value) + for key, value in used_values.items()} + else: + add = params.add + if types: + if not isinstance(types, dict): + raise TypeError('The values and types do not match') + map_literals = {key: add(used_values[key], types.get(key)) + for key in sorted(used_values)} + else: + map_literals = {key: add(used_values[key]) + for key in sorted(used_values)} + command %= map_literals + else: + raise TypeError('The values must be passed as tuple, list or dict') + return command, params diff --git a/pg/attrs.py b/pg/attrs.py new file mode 100644 index 00000000..7a5e6c41 --- /dev/null +++ b/pg/attrs.py @@ -0,0 +1,35 @@ +"""Helpers for memorizing attributes.""" + +from typing import Any + +__all__ = ['AttrDict'] + + +class AttrDict(dict): + """Simple read-only ordered dictionary for storing attribute names.""" + + def __init__(self, *args: Any, **kw: Any) -> None: + """Initialize the dictionary.""" + self._read_only = False + super().__init__(*args, **kw) + self._read_only = True + error = self._read_only_error + self.clear = self.update = error # type: ignore + self.pop = self.setdefault = self.popitem = error # type: ignore + + def __setitem__(self, key: str, value: Any) -> None: + """Set a value.""" + if self._read_only: + self._read_only_error() + super().__setitem__(key, value) + + def __delitem__(self, key: str) -> None: + """Delete a value.""" + if self._read_only: + self._read_only_error() + super().__delitem__(key) + + @staticmethod + def _read_only_error(*_args: Any, **_kw: Any) -> Any: + """Raise error for write operations.""" + raise TypeError('This object is read-only') diff --git a/pg/cast.py b/pg/cast.py new file mode 100644 index 00000000..98baa8f6 --- /dev/null +++ b/pg/cast.py @@ -0,0 +1,446 @@ +"""Typecasting mechanisms.""" + +from __future__ import annotations + +from collections import namedtuple +from datetime import date, datetime, timedelta +from functools import partial +from inspect import signature +from re import compile as regex +from typing import Any, Callable, ClassVar, Sequence +from uuid import UUID + +from .attrs import AttrDict +from .core import ( + Connection, + cast_array, + cast_hstore, + cast_record, + get_bool, + get_decimal, + get_decimal_point, + get_jsondecode, + unescape_bytea, +) +from .tz import timezone_as_offset + +__all__ = [ + 'Typecasts', + 'cast_bool', + 'cast_date', + 'cast_int2vector', + 'cast_interval', + 'cast_json', + 
'cast_money', + 'cast_num', + 'cast_time', + 'cast_timestamp', + 'cast_timestamptz', + 'cast_timetz', + 'get_typecast', + 'set_typecast' +] + +def get_args(func: Callable) -> list: + """Get the arguments of a function.""" + return list(signature(func).parameters) + + +def cast_bool(value: str) -> Any: + """Cast a boolean value.""" + if not get_bool(): + return value + return value[0] == 't' + + +def cast_json(value: str) -> Any: + """Cast a JSON value.""" + cast = get_jsondecode() + if not cast: + return value + return cast(value) + + +def cast_num(value: str) -> Any: + """Cast a numeric value.""" + return (get_decimal() or float)(value) + + +def cast_money(value: str) -> Any: + """Cast a money value.""" + point = get_decimal_point() + if not point: + return value + if point != '.': + value = value.replace(point, '.') + value = value.replace('(', '-') + value = ''.join(c for c in value if c.isdigit() or c in '.-') + return (get_decimal() or float)(value) + + +def cast_int2vector(value: str) -> list[int]: + """Cast an int2vector value.""" + return [int(v) for v in value.split()] + + +def cast_date(value: str, connection: Connection) -> Any: + """Cast a date value.""" + # The output format depends on the server setting DateStyle. The default + # setting ISO and the setting for German are actually unambiguous. The + # order of days and months in the other two settings is however ambiguous, + # so at least here we need to consult the setting to properly parse values. + if value == '-infinity': + return date.min + if value == 'infinity': + return date.max + values = value.split() + if values[-1] == 'BC': + return date.min + value = values[0] + if len(value) > 10: + return date.max + format = connection.date_format() + return datetime.strptime(value, format).date() + + +def cast_time(value: str) -> Any: + """Cast a time value.""" + format = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' + return datetime.strptime(value, format).time() + + +_re_timezone = regex('(.*)([+-].*)') + + +def cast_timetz(value: str) -> Any: + """Cast a timetz value.""" + m = _re_timezone.match(value) + if m: + value, tz = m.groups() + else: + tz = '+0000' + format = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' + value += timezone_as_offset(tz) + format += '%z' + return datetime.strptime(value, format).timetz() + + +def cast_timestamp(value: str, connection: Connection) -> Any: + """Cast a timestamp value.""" + if value == '-infinity': + return datetime.min + if value == 'infinity': + return datetime.max + values = value.split() + if values[-1] == 'BC': + return datetime.min + format = connection.date_format() + if format.endswith('-%Y') and len(values) > 2: + values = values[1:5] + if len(values[3]) > 4: + return datetime.max + formats = ['%d %b' if format.startswith('%d') else '%b %d', + '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] + else: + if len(values[0]) > 10: + return datetime.max + formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] + return datetime.strptime(' '.join(values), ' '.join(formats)) + + +def cast_timestamptz(value: str, connection: Connection) -> Any: + """Cast a timestamptz value.""" + if value == '-infinity': + return datetime.min + if value == 'infinity': + return datetime.max + values = value.split() + if values[-1] == 'BC': + return datetime.min + format = connection.date_format() + if format.endswith('-%Y') and len(values) > 2: + values = values[1:] + if len(values[3]) > 4: + return datetime.max + formats = ['%d %b' if format.startswith('%d') else '%b %d', 
+ '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] + values, tz = values[:-1], values[-1] + else: + if format.startswith('%Y-'): + m = _re_timezone.match(values[1]) + if m: + values[1], tz = m.groups() + else: + tz = '+0000' + else: + values, tz = values[:-1], values[-1] + if len(values[0]) > 10: + return datetime.max + formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] + values.append(timezone_as_offset(tz)) + formats.append('%z') + return datetime.strptime(' '.join(values), ' '.join(formats)) + + +_re_interval_sql_standard = regex( + '(?:([+-])?([0-9]+)-([0-9]+) ?)?' + '(?:([+-]?[0-9]+)(?!:) ?)?' + '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?') + +_re_interval_postgres = regex( + '(?:([+-]?[0-9]+) ?years? ?)?' + '(?:([+-]?[0-9]+) ?mons? ?)?' + '(?:([+-]?[0-9]+) ?days? ?)?' + '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?') + +_re_interval_postgres_verbose = regex( + '@ ?(?:([+-]?[0-9]+) ?years? ?)?' + '(?:([+-]?[0-9]+) ?mons? ?)?' + '(?:([+-]?[0-9]+) ?days? ?)?' + '(?:([+-]?[0-9]+) ?hours? ?)?' + '(?:([+-]?[0-9]+) ?mins? ?)?' + '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?') + +_re_interval_iso_8601 = regex( + 'P(?:([+-]?[0-9]+)Y)?' + '(?:([+-]?[0-9]+)M)?' + '(?:([+-]?[0-9]+)D)?' + '(?:T(?:([+-]?[0-9]+)H)?' + '(?:([+-]?[0-9]+)M)?' + '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?') + + +def cast_interval(value: str) -> timedelta: + """Cast an interval value.""" + # The output format depends on the server setting IntervalStyle, but it's + # not necessary to consult this setting to parse it. It's faster to just + # check all possible formats, and there is no ambiguity here. + m = _re_interval_iso_8601.match(value) + if m: + s = [v or '0' for v in m.groups()] + secs_ago = s.pop(5) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d + if secs_ago: + secs = -secs + usecs = -usecs + else: + m = _re_interval_postgres_verbose.match(value) + if m: + s, ago = [v or '0' for v in m.groups()[:8]], m.group(9) + secs_ago = s.pop(5) == '-' + d = [-int(v) for v in s] if ago else [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d + if secs_ago: + secs = - secs + usecs = -usecs + else: + m = _re_interval_postgres.match(value) + if m and any(m.groups()): + s = [v or '0' for v in m.groups()] + hours_ago = s.pop(3) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d + if hours_ago: + hours = -hours + mins = -mins + secs = -secs + usecs = -usecs + else: + m = _re_interval_sql_standard.match(value) + if m and any(m.groups()): + s = [v or '0' for v in m.groups()] + years_ago = s.pop(0) == '-' + hours_ago = s.pop(3) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d + if years_ago: + years = -years + mons = -mons + if hours_ago: + hours = -hours + mins = -mins + secs = -secs + usecs = -usecs + else: + raise ValueError(f'Cannot parse interval: {value}') + days += 365 * years + 30 * mons + return timedelta(days=days, hours=hours, minutes=mins, + seconds=secs, microseconds=usecs) + + +class Typecasts(dict): + """Dictionary mapping database types to typecast functions. + + The cast functions get passed the string representation of a value in + the database which they need to convert to a Python object. The + passed string will never be None since NULL values are already + handled before the cast function is called. + + Note that the basic types are already handled by the C extension. + They only need to be handled here as record or array components. 
+ """ + + # the default cast functions + # (str functions are ignored but have been added for faster access) + defaults: ClassVar[dict[str, Callable]] = { + 'char': str, 'bpchar': str, 'name': str, + 'text': str, 'varchar': str, 'sql_identifier': str, + 'bool': cast_bool, 'bytea': unescape_bytea, + 'int2': int, 'int4': int, 'serial': int, 'int8': int, 'oid': int, + 'hstore': cast_hstore, 'json': cast_json, 'jsonb': cast_json, + 'float4': float, 'float8': float, + 'numeric': cast_num, 'money': cast_money, + 'date': cast_date, 'interval': cast_interval, + 'time': cast_time, 'timetz': cast_timetz, + 'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz, + 'int2vector': cast_int2vector, 'uuid': UUID, + 'anyarray': cast_array, 'record': cast_record} # pyright: ignore + + connection: Connection | None = None # set in connection specific instance + + def __missing__(self, typ: str) -> Callable | None: + """Create a cast function if it is not cached. + + Note that this class never raises a KeyError, + but returns None when no special cast function exists. + """ + if not isinstance(typ, str): + raise TypeError(f'Invalid type: {typ}') + cast: Callable | None = self.defaults.get(typ) + if cast: + # store default for faster access + cast = self._add_connection(cast) + self[typ] = cast + elif typ.startswith('_'): + base_cast = self[typ[1:]] + cast = self.create_array_cast(base_cast) + if base_cast: + self[typ] = cast + else: + attnames = self.get_attnames(typ) + if attnames: + casts = [self[v.pgtype] for v in attnames.values()] + cast = self.create_record_cast(typ, attnames, casts) + self[typ] = cast + return cast + + @staticmethod + def _needs_connection(func: Callable) -> bool: + """Check if a typecast function needs a connection argument.""" + try: + args = get_args(func) + except (TypeError, ValueError): + return False + return 'connection' in args[1:] + + def _add_connection(self, cast: Callable) -> Callable: + """Add a connection argument to the typecast function if necessary.""" + if not self.connection or not self._needs_connection(cast): + return cast + return partial(cast, connection=self.connection) + + def get(self, typ: str, default: Callable | None = None # type: ignore + ) -> Callable | None: + """Get the typecast function for the given database type.""" + return self[typ] or default + + def set(self, typ: str | Sequence[str], cast: Callable | None) -> None: + """Set a typecast function for the specified database type(s).""" + if isinstance(typ, str): + typ = [typ] + if cast is None: + for t in typ: + self.pop(t, None) + self.pop(f'_{t}', None) + else: + if not callable(cast): + raise TypeError("Cast parameter must be callable") + for t in typ: + self[t] = self._add_connection(cast) + self.pop(f'_{t}', None) + + def reset(self, typ: str | Sequence[str] | None = None) -> None: + """Reset the typecasts for the specified type(s) to their defaults. + + When no type is specified, all typecasts will be reset. 
+ """ + if typ is None: + self.clear() + else: + if isinstance(typ, str): + typ = [typ] + for t in typ: + self.pop(t, None) + + @classmethod + def get_default(cls, typ: str) -> Any: + """Get the default typecast function for the given database type.""" + return cls.defaults.get(typ) + + @classmethod + def set_default(cls, typ: str | Sequence[str], + cast: Callable | None) -> None: + """Set a default typecast function for the given database type(s).""" + if isinstance(typ, str): + typ = [typ] + defaults = cls.defaults + if cast is None: + for t in typ: + defaults.pop(t, None) + defaults.pop(f'_{t}', None) + else: + if not callable(cast): + raise TypeError("Cast parameter must be callable") + for t in typ: + defaults[t] = cast + defaults.pop(f'_{t}', None) + + # noinspection PyMethodMayBeStatic,PyUnusedLocal + def get_attnames(self, typ: Any) -> AttrDict: + """Return the fields for the given record type. + + This method will be replaced with the get_attnames() method of DbTypes. + """ + return AttrDict() + + # noinspection PyMethodMayBeStatic + def dateformat(self) -> str: + """Return the current date format. + + This method will be replaced with the dateformat() method of DbTypes. + """ + return '%Y-%m-%d' + + def create_array_cast(self, basecast: Callable) -> Callable: + """Create an array typecast for the given base cast.""" + cast_array = self['anyarray'] + + def cast(v: Any) -> list: + return cast_array(v, basecast) + return cast + + def create_record_cast(self, name: str, fields: AttrDict, + casts: list[Callable]) -> Callable: + """Create a named record typecast for the given fields and casts.""" + cast_record = self['record'] + record = namedtuple(name, fields) # type: ignore + + def cast(v: Any) -> record: + # noinspection PyArgumentList + return record(*cast_record(v, casts)) + return cast + + +def get_typecast(typ: str) -> Callable | None: + """Get the global typecast function for the given database type.""" + return Typecasts.get_default(typ) + + +def set_typecast(typ: str | Sequence[str], cast: Callable | None) -> None: + """Set a global typecast function for the given database type(s). + + Note that connections cache cast functions. To be sure a global change + is picked up by a running connection, call db.db_types.reset_typecast(). + """ + Typecasts.set_default(typ, cast) diff --git a/pg/core.py b/pg/core.py new file mode 100644 index 00000000..4d0c03c0 --- /dev/null +++ b/pg/core.py @@ -0,0 +1,180 @@ +"""Core functionality from extension module.""" + +try: + from ._pg import version +except ImportError as e: # noqa: F841 + import os + libpq = 'libpq.' 
+ if os.name == 'nt': + libpq += 'dll' + import sys + paths = [path for path in os.environ["PATH"].split(os.pathsep) + if os.path.exists(os.path.join(path, libpq))] + if sys.version_info >= (3, 8): + # see https://docs.python.org/3/whatsnew/3.8.html#ctypes + add_dll_dir = os.add_dll_directory # type: ignore + for path in paths: + with add_dll_dir(os.path.abspath(path)): + try: + from ._pg import version + except ImportError: + pass + else: + del version + e = None # type: ignore + break + if paths: + libpq = 'compatible ' + libpq + else: + libpq += 'so' + if e: + raise ImportError( + "Cannot import shared library for PyGreSQL,\n" + f"probably because no {libpq} is installed.\n{e}") from e +else: + del version + +# import objects from extension module +from ._pg import ( + INV_READ, + INV_WRITE, + POLLING_FAILED, + POLLING_OK, + POLLING_READING, + POLLING_WRITING, + RESULT_DDL, + RESULT_DML, + RESULT_DQL, + RESULT_EMPTY, + SEEK_CUR, + SEEK_END, + SEEK_SET, + TRANS_ACTIVE, + TRANS_IDLE, + TRANS_INERROR, + TRANS_INTRANS, + TRANS_UNKNOWN, + Connection, + DatabaseError, + DataError, + Error, + IntegrityError, + InterfaceError, + InternalError, + InvalidResultError, + LargeObject, + MultipleResultsError, + NoResultError, + NotSupportedError, + OperationalError, + ProgrammingError, + Query, + Warning, + cast_array, + cast_hstore, + cast_record, + connect, + escape_bytea, + escape_string, + get_array, + get_bool, + get_bytea_escaped, + get_datestyle, + get_decimal, + get_decimal_point, + get_defbase, + get_defhost, + get_defopt, + get_defport, + get_defuser, + get_jsondecode, + get_pqlib_version, + set_array, + set_bool, + set_bytea_escaped, + set_datestyle, + set_decimal, + set_decimal_point, + set_defbase, + set_defhost, + set_defopt, + set_defpasswd, + set_defport, + set_defuser, + set_jsondecode, + set_query_helpers, + unescape_bytea, + version, +) + +__all__ = [ + 'INV_READ', + 'INV_WRITE', + 'POLLING_FAILED', + 'POLLING_OK', + 'POLLING_READING', + 'POLLING_WRITING', + 'RESULT_DDL', + 'RESULT_DML', + 'RESULT_DQL', + 'RESULT_EMPTY', + 'SEEK_CUR', + 'SEEK_END', + 'SEEK_SET', + 'TRANS_ACTIVE', + 'TRANS_IDLE', + 'TRANS_INERROR', + 'TRANS_INTRANS', + 'TRANS_UNKNOWN', + 'Connection', + 'DataError', + 'DatabaseError', + 'Error', + 'IntegrityError', + 'InterfaceError', + 'InternalError', + 'InvalidResultError', + 'LargeObject', + 'MultipleResultsError', + 'NoResultError', + 'NotSupportedError', + 'OperationalError', + 'ProgrammingError', + 'Query', + 'Warning', + 'cast_array', + 'cast_hstore', + 'cast_record', + 'connect', + 'escape_bytea', + 'escape_string', + 'get_array', + 'get_bool', + 'get_bytea_escaped', + 'get_datestyle', + 'get_decimal', + 'get_decimal_point', + 'get_defbase', + 'get_defhost', + 'get_defopt', + 'get_defport', + 'get_defuser', + 'get_jsondecode', + 'get_pqlib_version', + 'set_array', + 'set_bool', + 'set_bytea_escaped', + 'set_datestyle', + 'set_decimal', + 'set_decimal_point', + 'set_defbase', + 'set_defhost', + 'set_defopt', + 'set_defpasswd', + 'set_defport', + 'set_defuser', + 'set_jsondecode', + 'set_query_helpers', + 'unescape_bytea', + 'version', +] diff --git a/pg/db.py b/pg/db.py new file mode 100644 index 00000000..5c8beea7 --- /dev/null +++ b/pg/db.py @@ -0,0 +1,1502 @@ +"""Connection wrapper.""" + +from __future__ import annotations + +from contextlib import suppress +from json import dumps as jsonencode +from json import loads as jsondecode +from operator import itemgetter +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterator, + Sequence, + TypeVar, 
+    overload,
+)
+
+from . import Connection, connect
+from .adapt import Adapter, DbTypes
+from .attrs import AttrDict
+from .core import (
+    InternalError,
+    LargeObject,
+    ProgrammingError,
+    Query,
+    get_bool,
+    get_jsondecode,
+    unescape_bytea,
+)
+from .error import db_error, int_error, prg_error
+from .helpers import namediter, oid_key, quote_if_unqualified
+from .notify import NotificationHandler
+
+if TYPE_CHECKING:
+    from pgdb.connection import Connection as DbApi2Connection
+
+try:
+    AnyStr = TypeVar('AnyStr', str, bytes, str | bytes)
+except TypeError:  # Python < 3.10
+    AnyStr = Any  # type: ignore
+
+__all__ = ['DB']
+
+
+# The actual PostgreSQL database connection interface:
+
+class DB:
+    """Wrapper class for the core connection type."""
+
+    dbname: str
+    host: str
+    port: int
+    options: str
+    error: str
+    status: int
+    user: str
+    protocol_version: int
+    server_version: int
+    socket: int
+    backend_pid: int
+    ssl_in_use: bool
+    ssl_attributes: dict[str, str | None]
+
+    db: Connection | None = None  # invalid fallback for underlying connection
+    _db_args: Any  # either the connect args or the underlying connection
+
+    @overload
+    def __init__(self, dbname: str | None = None,
+                 host: str | None = None, port: int = -1,
+                 opt: str | None = None,
+                 user: str | None = None, passwd: str | None = None,
+                 nowait: bool = False) -> None:
+        ...  # create a new connection using the specified parameters
+
+    @overload
+    def __init__(self, db: Connection | DB | DbApi2Connection) -> None:
+        ...  # create a connection wrapper based on an existing connection
+
+    def __init__(self, *args: Any, **kw: Any) -> None:
+        """Create a new connection.
+
+        You can pass either the connection parameters or an existing
+        pg or pgdb Connection. This allows you to use the methods
+        of the classic pg interface with a DB-API 2 pgdb Connection.
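+
+        A minimal usage sketch (the connection parameters shown here
+        are assumptions, not defaults):
+
+            db = DB(dbname='test', host='localhost', user='test')
+            db2 = DB(db)  # wrap an already existing connection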
+ """ + if kw: + db = kw.get('db') + if db is not None and (args or len(kw) > 1): + raise TypeError("Conflicting connection parameters") + elif len(args) == 1 and not isinstance(args[0], str): + db = args[0] + else: + db = None + if db: + if isinstance(db, DB): + db = db.db # allow db to be a wrapped Connection + else: + with suppress(AttributeError): + db = db._cnx # allow db to be a pgdb Connection + if not isinstance(db, Connection): + raise TypeError( + "The 'db' argument must be a valid database connection.") + self._db_args = db + self._closeable = False + else: + db = connect(*args, **kw) + self._db_args = args, kw + self._closeable = True + self.db = db + self.dbname = db.db + self._regtypes = False + self._attnames: dict[str, AttrDict] = {} + self._generated: dict[str, frozenset[str]] = {} + self._pkeys: dict[str, str | tuple[str, ...]] = {} + self._privileges: dict[tuple[str, str], bool] = {} + self.adapter = Adapter(self) + self.dbtypes = DbTypes(self) + self._query_attnames = ( + "SELECT a.attname," + " t.oid, t.typname, t.oid::pg_catalog.regtype," + " t.typlen, t.typtype, t.typcategory, t.typdelim, t.typrelid" + " FROM pg_catalog.pg_attribute a" + " JOIN pg_catalog.pg_type t" + " ON t.oid OPERATOR(pg_catalog.=) a.atttypid" + " WHERE a.attrelid OPERATOR(pg_catalog.=)" + " {}::pg_catalog.regclass" + " AND {} AND NOT a.attisdropped ORDER BY a.attnum") + if db.server_version < 120000: + self._query_generated = ( + "a.attidentity OPERATOR(pg_catalog.=) 'a'" + ) + else: + self._query_generated = ( + "(a.attidentity OPERATOR(pg_catalog.=) 'a' OR" + " a.attgenerated OPERATOR(pg_catalog.!=) '')" + ) + db.set_cast_hook(self.dbtypes.typecast) + # For debugging scripts, self.debug can be set + # * to a string format specification (e.g. in CGI set to "%s
"), + # * to a file object to write debug statements or + # * to a callable object which takes a string argument + # * to any other true value to just print debug statements + self.debug: Any = None + + def __getattr__(self, name: str) -> Any: + """Get the specified attribute of the connection.""" + # All undefined members are same as in underlying connection: + if self.db: + return getattr(self.db, name) + else: + raise int_error('Connection is not valid') + + def __dir__(self) -> list[str]: + """List all attributes of the connection.""" + # Custom dir function including the attributes of the connection: + attrs = set(self.__class__.__dict__) + attrs.update(self.__dict__) + attrs.update(dir(self.db)) + return sorted(attrs) + + # Context manager methods + + def __enter__(self) -> DB: + """Enter the runtime context. This will start a transaction.""" + self.begin() + return self + + def __exit__(self, et: type[BaseException] | None, + ev: BaseException | None, tb: Any) -> None: + """Exit the runtime context. This will end the transaction.""" + if et is None and ev is None and tb is None: + self.commit() + else: + self.rollback() + + def __del__(self) -> None: + """Delete the connection.""" + try: + db = self.db + except AttributeError: + db = None + if db: + with suppress(TypeError): # when already closed + db.set_cast_hook(None) + if self._closeable: + with suppress(InternalError): # when already closed + db.close() + + # Auxiliary methods + + def _do_debug(self, *args: Any) -> None: + """Print a debug message.""" + if self.debug: + s = '\n'.join(str(arg) for arg in args) + if isinstance(self.debug, str): + print(self.debug % s) + elif hasattr(self.debug, 'write'): + # noinspection PyCallingNonCallable + self.debug.write(s + '\n') + elif callable(self.debug): + self.debug(s) + else: + print(s) + + def _escape_qualified_name(self, s: str) -> str: + """Escape a qualified name. + + Escapes the name for use as an SQL identifier, unless the + name contains a dot, in which case the name is ambiguous + (could be a qualified name or just a name with a dot in it) + and must be quoted manually by the caller. + """ + if '.' not in s: + s = self.escape_identifier(s) + return s + + @staticmethod + def _make_bool(d: Any) -> bool | str: + """Get boolean value corresponding to d.""" + return bool(d) if get_bool() else ('t' if d else 'f') + + @staticmethod + def _list_params(params: Sequence) -> str: + """Create a human readable parameter list.""" + return ', '.join(f'${n}={v!r}' for n, v in enumerate(params, 1)) + + @property + def _valid_db(self) -> Connection: + """Get underlying connection and make sure it is not closed.""" + db = self.db + if not db: + raise int_error('Connection already closed') + return db + + # Public methods + + # escape_string and escape_bytea exist as methods, + # so we define unescape_bytea as a method as well + unescape_bytea = staticmethod(unescape_bytea) + + @staticmethod + def decode_json(s: str) -> Any: + """Decode a JSON string coming from the database.""" + return (get_jsondecode() or jsondecode)(s) + + @staticmethod + def encode_json(d: Any) -> str: + """Encode a JSON string for use within SQL.""" + return jsonencode(d) + + def close(self) -> None: + """Close the database connection.""" + # Wraps shared library function so we can track state. + db = self._valid_db + with suppress(TypeError): # when already closed + db.set_cast_hook(None) + if self._closeable: + db.close() + self.db = None + + def reset(self) -> None: + """Reset connection with current parameters. 
+ + All derived queries and large objects derived from this connection + will not be usable after this call. + """ + self._valid_db.reset() + + def reopen(self) -> None: + """Reopen connection to the database. + + Used in case we need another connection to the same database. + Note that we can still reopen a database that we have closed. + """ + # There is no such shared library function. + if self._closeable: + args, kw = self._db_args + db = connect(*args, **kw) + if self.db: + self.db.set_cast_hook(None) + self.db.close() + db.set_cast_hook(self.dbtypes.typecast) + self.db = db + else: + self.db = self._db_args + + def begin(self, mode: str | None = None) -> Query: + """Begin a transaction.""" + qstr = 'BEGIN' + if mode: + qstr += ' ' + mode + return self.query(qstr) + + start = begin + + def commit(self) -> Query: + """Commit the current transaction.""" + return self.query('COMMIT') + + end = commit + + def rollback(self, name: str | None = None) -> Query: + """Roll back the current transaction.""" + qstr = 'ROLLBACK' + if name: + qstr += ' TO ' + name + return self.query(qstr) + + abort = rollback + + def savepoint(self, name: str) -> Query: + """Define a new savepoint within the current transaction.""" + return self.query('SAVEPOINT ' + name) + + def release(self, name: str) -> Query: + """Destroy a previously defined savepoint.""" + return self.query('RELEASE ' + name) + + def get_parameter(self, + parameter: str | list[str] | tuple[str, ...] | + set[str] | frozenset[str] | dict[str, Any] + ) -> str | list[str] | dict[str, str]: + """Get the value of a run-time parameter. + + If the parameter is a string, the return value will also be a string + that is the current setting of the run-time parameter with that name. + + You can get several parameters at once by passing a list, set or dict. + When passing a list of parameter names, the return value will be a + corresponding list of parameter settings. When passing a set of + parameter names, a new dict will be returned, mapping these parameter + names to their settings. Finally, if you pass a dict as parameter, + its values will be set to the current parameter settings corresponding + to its keys. + + By passing the special name 'all' as the parameter, you can get a dict + of all existing configuration parameters. + """ + values: Any + if isinstance(parameter, str): + parameter = [parameter] + values = None + elif isinstance(parameter, (list, tuple)): + values = [] + elif isinstance(parameter, (set, frozenset)): + values = {} + elif isinstance(parameter, dict): + values = parameter + else: + raise TypeError( + 'The parameter must be a string, list, set or dict') + if not parameter: + raise TypeError('No parameter has been specified') + query = self._valid_db.query + params: Any = {} if isinstance(values, dict) else [] + for param_key in parameter: + param = param_key.strip().lower() if isinstance( + param_key, (bytes, str)) else None + if not param: + raise TypeError('Invalid parameter') + if param == 'all': + cmd = 'SHOW ALL' + values = query(cmd).getresult() + values = {value[0]: value[1] for value in values} + break + if isinstance(params, dict): + params[param] = param_key + else: + params.append(param) + else: + for param in params: + cmd = f'SHOW {param}' + value = query(cmd).singlescalar() + if values is None: + values = value + elif isinstance(values, list): + values.append(value) + else: + values[params[param]] = value + return values + + def set_parameter(self, + parameter: str | list[str] | tuple[str, ...] 
| + set[str] | frozenset[str] | dict[str, Any], + value: str | list[str] | tuple[str, ...] | + set[str] | frozenset[str]| None = None, + local: bool = False) -> None: + """Set the value of a run-time parameter. + + If the parameter and the value are strings, the run-time parameter + will be set to that value. If no value or None is passed as a value, + then the run-time parameter will be restored to its default value. + + You can set several parameters at once by passing a list of parameter + names, together with a single value that all parameters should be + set to or with a corresponding list of values. You can also pass + the parameters as a set if you only provide a single value. + Finally, you can pass a dict with parameter names as keys. In this + case, you should not pass a value, since the values for the parameters + will be taken from the dict. + + By passing the special name 'all' as the parameter, you can reset + all existing settable run-time parameters to their default values. + + If you set local to True, then the command takes effect for only the + current transaction. After commit() or rollback(), the session-level + setting takes effect again. Setting local to True will appear to + have no effect if it is executed outside a transaction, since the + transaction will end immediately. + """ + if isinstance(parameter, str): + parameter = {parameter: value} + elif isinstance(parameter, (list, tuple)): + if isinstance(value, (list, tuple)): + parameter = dict(zip(parameter, value)) + else: + parameter = dict.fromkeys(parameter, value) + elif isinstance(parameter, (set, frozenset)): + if isinstance(value, (list, tuple, set, frozenset)): + value = set(value) + if len(value) == 1: + value = next(iter(value)) + if not (value is None or isinstance(value, str)): + raise ValueError( + 'A single value must be specified' + ' when parameter is a set') + parameter = dict.fromkeys(parameter, value) + elif isinstance(parameter, dict): + if value is not None: + raise ValueError( + 'A value must not be specified' + ' when parameter is a dictionary') + else: + raise TypeError( + 'The parameter must be a string, list, set or dict') + if not parameter: + raise TypeError('No parameter has been specified') + params: dict[str, str | None] = {} + for param, param_value in parameter.items(): + param = param.strip().lower() if isinstance(param, str) else None + if not param: + raise TypeError('Invalid parameter') + if param == 'all': + if param_value is not None: + raise ValueError( + 'A value must not be specified' + " when parameter is 'all'") + params = {'all': None} + break + params[param] = param_value + local_clause = ' LOCAL' if local else '' + for param, param_value in params.items(): + cmd = (f'RESET{local_clause} {param}' + if param_value is None else + f'SET{local_clause} {param} TO {param_value}') + self._do_debug(cmd) + self._valid_db.query(cmd) + + def query(self, command: str, *args: Any) -> Query: + """Execute a SQL command string. + + This method simply sends a SQL query to the database. If the query is + an insert statement that inserted exactly one row into a table that + has OIDs, the return value is the OID of the newly inserted row. + If the query is an update or delete statement, or an insert statement + that did not insert exactly one row in a table with OIDs, then the + number of rows affected is returned as a string. If it is a statement + that returns rows as a result (usually a select statement, but maybe + also an "insert/update ... 
returning" statement), this method returns + a Query object that can be accessed via getresult() or dictresult() + or simply printed. Otherwise, it returns `None`. + + The query can contain numbered parameters of the form $1 in place + of any data constant. Arguments given after the query string will + be substituted for the corresponding numbered parameter. Parameter + values can also be given as a single list or tuple argument. + """ + # Wraps shared library function for debugging. + db = self._valid_db + if args: + self._do_debug(command, args) + return db.query(command, args) + self._do_debug(command) + return db.query(command) + + def query_formatted(self, command: str, + parameters: tuple | list | dict | None = None, + types: tuple | list | dict | None = None, + inline: bool =False) -> Query: + """Execute a formatted SQL command string. + + Similar to query, but using Python format placeholders of the form + %s or %(names)s instead of PostgreSQL placeholders of the form $1. + The parameters must be passed as a tuple, list or dict. You can + also pass a corresponding tuple, list or dict of database types in + order to format the parameters properly in case there is ambiguity. + + If you set inline to True, the parameters will be sent to the database + embedded in the SQL command, otherwise they will be sent separately. + """ + return self.query(*self.adapter.format_query( + command, parameters, types, inline)) + + def query_prepared(self, name: str, *args: Any) -> Query: + """Execute a prepared SQL statement. + + This works like the query() method, except that instead of passing + the SQL command, you pass the name of a prepared statement. If you + pass an empty name, the unnamed statement will be executed. + """ + if name is None: + name = '' + db = self._valid_db + if args: + self._do_debug('EXECUTE', name, args) + return db.query_prepared(name, args) + self._do_debug('EXECUTE', name) + return db.query_prepared(name) + + def prepare(self, name: str, command: str) -> None: + """Create a prepared SQL statement. + + This creates a prepared statement for the given command with the + given name for later execution with the query_prepared() method. + + The name can be empty to create an unnamed statement, in which case + any pre-existing unnamed statement is automatically replaced; + otherwise it is an error if the statement name is already + defined in the current database session. We recommend always using + named queries, since unnamed queries have a limited lifetime and + can be automatically replaced or destroyed by various operations. + """ + if name is None: + name = '' + self._do_debug('prepare', name, command) + self._valid_db.prepare(name, command) + + def describe_prepared(self, name: str | None = None) -> Query: + """Describe a prepared SQL statement. + + This method returns a Query object describing the result columns of + the prepared statement with the given name. If you omit the name, + the unnamed statement will be described if you created one before. + """ + if name is None: + name = '' + return self._valid_db.describe_prepared(name) + + def delete_prepared(self, name: str | None = None) -> Query: + """Delete a prepared SQL statement. + + This deallocates a previously prepared SQL statement with the given + name, or deallocates all prepared statements if you do not specify a + name. Note that prepared statements are also deallocated automatically + when the current session ends. 
+ """ + if not name: + name = 'ALL' + cmd = f"DEALLOCATE {name}" + self._do_debug(cmd) + return self._valid_db.query(cmd) + + def pkey(self, table: str, composite: bool = False, flush: bool = False + ) -> str | tuple[str, ...]: + """Get the primary key of a table. + + Single primary keys are returned as strings unless you + set the composite flag. Composite primary keys are always + represented as tuples. Note that this raises a KeyError + if the table does not have a primary key. + + If flush is set then the internal cache for primary keys will + be flushed. This may be necessary after the database schema or + the search path has been changed. + """ + pkeys = self._pkeys + if flush: + pkeys.clear() + self._do_debug('The pkey cache has been flushed') + try: # cache lookup + pkey = pkeys[table] + except KeyError as e: # cache miss, check the database + cmd = ("SELECT" # noqa: S608 + " a.attname, a.attnum, i.indkey" + " FROM pg_catalog.pg_index i" + " JOIN pg_catalog.pg_attribute a" + " ON a.attrelid OPERATOR(pg_catalog.=) i.indrelid" + " AND a.attnum OPERATOR(pg_catalog.=) ANY(i.indkey)" + " AND NOT a.attisdropped" + " WHERE i.indrelid OPERATOR(pg_catalog.=)" + " {}::pg_catalog.regclass" + " AND i.indisprimary ORDER BY a.attnum").format( + quote_if_unqualified('$1', table)) + res = self._valid_db.query(cmd, (table,)).getresult() + if not res: + raise KeyError(f'Table {table} has no primary key') from e + # we want to use the order defined in the primary key index here, + # not the order as defined by the columns in the table + if len(res) > 1: + indkey = res[0][2] + pkey = tuple(row[0] for row in sorted( + res, key=lambda row: indkey.index(row[1]))) + else: + pkey = res[0][0] + pkeys[table] = pkey # cache it + if composite and not isinstance(pkey, tuple): + pkey = (pkey,) + return pkey + + def pkeys(self, table: str) -> tuple[str, ...]: + """Get the primary key of a table as a tuple. + + Same as pkey() with 'composite' set to True. + """ + return self.pkey(table, True) # type: ignore + + def get_databases(self) -> list[str]: + """Get list of databases in the system.""" + return [r[0] for r in self._valid_db.query( + 'SELECT datname FROM pg_catalog.pg_database').getresult()] + + def get_relations(self, kinds: str | Sequence[str] | None = None, + system: bool = False) -> list[str]: + """Get list of relations in connected database of specified kinds. + + If kinds is None or empty, all kinds of relations are returned. + Otherwise, kinds can be a string or sequence of type letters + specifying which kind of relations you want to list. + + Set the system flag if you want to get the system relations as well. + """ + where_parts = [] + if kinds: + where_parts.append( + "r.relkind IN ({})".format(','.join(f"'{k}'" for k in kinds))) + if not system: + where_parts.append("s.nspname NOT SIMILAR" + " TO 'pg/_%|information/_schema' ESCAPE '/'") + where = " WHERE " + ' AND '.join(where_parts) if where_parts else '' + cmd = ("SELECT" # noqa: S608 + " pg_catalog.quote_ident(s.nspname) OPERATOR(pg_catalog.||)" + " '.' OPERATOR(pg_catalog.||) pg_catalog.quote_ident(r.relname)" + " FROM pg_catalog.pg_class r" + " JOIN pg_catalog.pg_namespace s" + f" ON s.oid OPERATOR(pg_catalog.=) r.relnamespace{where}" + " ORDER BY s.nspname, r.relname") + return [r[0] for r in self._valid_db.query(cmd).getresult()] + + def get_tables(self, system: bool = False) -> list[str]: + """Return list of tables in connected database. + + Set the system flag if you want to get the system tables as well. 
+ """ + return self.get_relations('r', system) + + def get_attnames(self, table: str, with_oid: bool=True, flush: bool=False + ) -> AttrDict: + """Given the name of a table, dig out the set of attribute names. + + Returns a read-only dictionary of attribute names (the names are + the keys, the values are the names of the attributes' types) + with the column names in the proper order if you iterate over it. + + If flush is set, then the internal cache for attribute names will + be flushed. This may be necessary after the database schema or + the search path has been changed. + + By default, only a limited number of simple types will be returned. + You can get the registered types after calling use_regtypes(True). + """ + attnames = self._attnames + if flush: + attnames.clear() + self._do_debug('The attnames cache has been flushed') + try: # cache lookup + names = attnames[table] + except KeyError: # cache miss, check the database + cmd = "a.attnum OPERATOR(pg_catalog.>) 0" + if with_oid: + cmd = f"({cmd} OR a.attname OPERATOR(pg_catalog.=) 'oid')" + cmd = self._query_attnames.format( + quote_if_unqualified('$1', table), cmd) + res = self._valid_db.query(cmd, (table,)).getresult() + types = self.dbtypes + names = AttrDict((name[0], types.add(*name[1:])) for name in res) + attnames[table] = names # cache it + return names + + def get_generated(self, table: str, flush: bool = False) -> frozenset[str]: + """Given the name of a table, dig out the set of generated columns. + + Returns a set of column names that are generated and unalterable. + + If flush is set, then the internal cache for generated columns will + be flushed. This may be necessary after the database schema or + the search path has been changed. + """ + generated = self._generated + if flush: + generated.clear() + self._do_debug('The generated cache has been flushed') + try: # cache lookup + names = generated[table] + except KeyError: # cache miss, check the database + cmd = "a.attnum OPERATOR(pg_catalog.>) 0" + cmd = f"{cmd} AND {self._query_generated}" + cmd = self._query_attnames.format( + quote_if_unqualified('$1', table), cmd) + res = self._valid_db.query(cmd, (table,)).getresult() + names = frozenset(name[0] for name in res) + generated[table] = names # cache it + return names + + def use_regtypes(self, regtypes: bool | None = None) -> bool: + """Use registered type names instead of simplified type names.""" + if regtypes is None: + return self.dbtypes._regtypes + regtypes = bool(regtypes) + if regtypes != self.dbtypes._regtypes: + self.dbtypes._regtypes = regtypes + self._attnames.clear() + self.dbtypes.clear() + return regtypes + + def has_table_privilege(self, table: str, privilege: str = 'select', + flush: bool = False) -> bool: + """Check whether current user has specified table privilege. + + If flush is set, then the internal cache for table privileges will + be flushed. This may be necessary after privileges have been changed. 
+ """ + privileges = self._privileges + if flush: + privileges.clear() + self._do_debug('The privileges cache has been flushed') + privilege = privilege.lower() + try: # ask cache + ret = privileges[table, privilege] + except KeyError: # cache miss, ask the database + cmd = "SELECT pg_catalog.has_table_privilege({}, $2)".format( + quote_if_unqualified('$1', table)) + query = self._valid_db.query(cmd, (table, privilege)) + ret = query.singlescalar() == self._make_bool(True) + privileges[table, privilege] = ret # cache it + return ret + + def get(self, table: str, row: Any, + keyname: str | tuple[str, ...] | None = None) -> dict[str, Any]: + """Get a row from a database table or view. + + This method is the basic mechanism to get a single row. It assumes + that the keyname specifies a unique row. It must be the name of a + single column or a tuple of column names. If the keyname is not + specified, then the primary key for the table is used. + + If row is a dictionary, then the value for the key is taken from it. + Otherwise, the row must be a single value or a tuple of values + corresponding to the passed keyname or primary key. The fetched row + from the table will be returned as a new dictionary or used to replace + the existing values when row was passed as a dictionary. + + The OID is also put into the dictionary if the table has one, but + in order to allow the caller to work with multiple tables, it is + munged as "oid(table)" using the actual name of the table. + """ + if table.endswith('*'): # hint for descendant tables can be ignored + table = table[:-1].rstrip() + attnames = self.get_attnames(table) + qoid = oid_key(table) if 'oid' in attnames else None + if keyname and isinstance(keyname, str): + keyname = (keyname,) + if qoid and isinstance(row, dict) and qoid in row and 'oid' not in row: + row['oid'] = row[qoid] + if not keyname: + try: # if keyname is not specified, try using the primary key + keyname = self.pkeys(table) + except KeyError as e: # the table has no primary key + # try using the oid instead + if qoid and isinstance(row, dict) and 'oid' in row: + keyname = ('oid',) + else: + raise prg_error( + f'Table {table} has no primary key') from e + else: # the table has a primary key + # check whether all key columns have values + if isinstance(row, dict) and not set(keyname).issubset(row): + # try using the oid instead + if qoid and 'oid' in row: + keyname = ('oid',) + else: + raise KeyError( + 'Missing value in row for specified keyname') + if not isinstance(row, dict): + if not isinstance(row, (tuple, list)): + row = [row] + if len(keyname) != len(row): + raise KeyError( + 'Differing number of items in keyname and row') + row = dict(zip(keyname, row)) + params = self.adapter.parameter_list() + adapt = params.add + col = self.escape_identifier + what = 'oid, *' if qoid else '*' + where = ' AND '.join( + f'{col(k)} OPERATOR(pg_catalog.=) {adapt(row[k], attnames[k])}' + for k in keyname) + if 'oid' in row: + if qoid: + row[qoid] = row['oid'] + del row['oid'] + t = self._escape_qualified_name(table) + cmd = f'SELECT {what} FROM {t} WHERE {where} LIMIT 1' # noqa: S608s + self._do_debug(cmd, params) + query = self._valid_db.query(cmd, params) + res = query.dictresult() + if not res: + # make where clause in error message better readable + where = where.replace('OPERATOR(pg_catalog.=)', '=') + raise db_error( + f'No such record in {table}\nwhere {where}\nwith ' + + self._list_params(params)) + for n, value in res[0].items(): + if qoid and n == 'oid': + n = qoid + row[n] = value 
+ return row + + def insert(self, table: str, row: dict[str, Any] | None = None, **kw: Any + ) -> dict[str, Any]: + """Insert a row into a database table. + + This method inserts a row into a table. The name of the table must + be passed as the first parameter. The other parameters are used for + providing the data of the row that shall be inserted into the table. + If a dictionary is supplied as the second parameter, it starts with + that. Otherwise, it uses a blank dictionary. + Either way the dictionary is updated from the keywords. + + The dictionary is then reloaded with the values actually inserted in + order to pick up values modified by rules, triggers, etc. + """ + if table.endswith('*'): # hint for descendant tables can be ignored + table = table[:-1].rstrip() + if row is None: + row = {} + row.update(kw) + if 'oid' in row: + del row['oid'] # do not insert oid + attnames = self.get_attnames(table) + generated = self.get_generated(table) + qoid = oid_key(table) if 'oid' in attnames else None + params = self.adapter.parameter_list() + adapt = params.add + col = self.escape_identifier + name_list, value_list = [], [] + for n in attnames: + if n in row and n not in generated: + name_list.append(col(n)) + value_list.append(adapt(row[n], attnames[n])) + if not name_list: + raise prg_error('No column found that can be inserted') + names, values = ', '.join(name_list), ', '.join(value_list) + ret = 'oid, *' if qoid else '*' + t = self._escape_qualified_name(table) + cmd = (f'INSERT INTO {t} ({names})' # noqa: S608 + f' VALUES ({values}) RETURNING {ret}') + self._do_debug(cmd, params) + query = self._valid_db.query(cmd, params) + res = query.dictresult() + if res: # this should always be true + for n, value in res[0].items(): + if qoid and n == 'oid': + n = qoid + row[n] = value + return row + + def update(self, table: str, row: dict[str, Any] | None = None, **kw : Any + ) -> dict[str, Any]: + """Update an existing row in a database table. + + Similar to insert, but updates an existing row. The update is based + on the primary key of the table or the OID value as munged by get() + or passed as keyword. The OID will take precedence if provided, so + that it is possible to update the primary key itself. + + The dictionary is then modified to reflect any changes caused by the + update due to triggers, rules, default values, etc. + """ + if table.endswith('*'): + table = table[:-1].rstrip() # need parent table name + attnames = self.get_attnames(table) + generated = self.get_generated(table) + qoid = oid_key(table) if 'oid' in attnames else None + if row is None: + row = {} + elif 'oid' in row: + del row['oid'] # only accept oid key from named args for safety + row.update(kw) + if qoid and qoid in row and 'oid' not in row: + row['oid'] = row[qoid] + if qoid and 'oid' in row: # try using the oid + keynames: tuple[str, ...] 
= ('oid',)
+            keyset = set(keynames)
+        else:  # try using the primary key
+            try:
+                keynames = self.pkeys(table)
+            except KeyError as e:  # the table has no primary key
+                raise prg_error(f'Table {table} has no primary key') from e
+            keyset = set(keynames)
+            # check whether all key columns have values
+            if not keyset.issubset(row):
+                raise KeyError('Missing value for primary key in row')
+        params = self.adapter.parameter_list()
+        adapt = params.add
+        col = self.escape_identifier
+        where = ' AND '.join(
+            f'{col(k)} OPERATOR(pg_catalog.=) {adapt(row[k], attnames[k])}'
+            for k in keynames)
+        if 'oid' in row:
+            if qoid:
+                row[qoid] = row['oid']
+            del row['oid']
+        values_list = []
+        for n in attnames:
+            if n in row and n not in keyset and n not in generated:
+                values_list.append(f'{col(n)} = {adapt(row[n], attnames[n])}')
+        if not values_list:
+            return row
+        values = ', '.join(values_list)
+        ret = 'oid, *' if qoid else '*'
+        t = self._escape_qualified_name(table)
+        cmd = (f'UPDATE {t} SET {values}'  # noqa: S608
+               f' WHERE {where} RETURNING {ret}')
+        self._do_debug(cmd, params)
+        query = self._valid_db.query(cmd, params)
+        res = query.dictresult()
+        if res:  # may be empty when row does not exist
+            for n, value in res[0].items():
+                if qoid and n == 'oid':
+                    n = qoid
+                row[n] = value
+        return row
+
+    def upsert(self, table: str, row: dict[str, Any] | None = None, **kw: Any
+               ) -> dict[str, Any]:
+        """Insert a row into a database table with conflict resolution.
+
+        This method inserts a row into a table, but instead of raising a
+        ProgrammingError exception in case a row with the same primary key
+        already exists, an update will be executed instead. This will be
+        performed as a single atomic operation on the database, so race
+        conditions can be avoided.
+
+        Like the insert method, the first parameter is the name of the
+        table and the second parameter can be used to pass the values to
+        be inserted as a dictionary.
+
+        Unlike the insert and update methods, keyword parameters are not
+        used to modify the dictionary, but to specify which columns shall
+        be updated in case of a conflict, and in which way:
+
+        A value of False or None means the column shall not be updated,
+        a value of True means the column shall be updated with the value
+        that has been proposed for insertion, i.e. has been passed as value
+        in the dictionary. Columns that are not specified by keywords but
+        appear as keys in the dictionary are also updated, as if keywords
+        had been passed with the value True.
+
+        So if in the case of a conflict you want to update every column
+        that has been passed in the dictionary row, you would call
+        upsert(table, row). If you don't want to do anything in case
+        of a conflict, i.e. leave the existing row as it is, call
+        upsert(table, row, **dict.fromkeys(row)).
+
+        If you need more fine-grained control of what gets updated, you can
+        also pass strings in the keyword parameters. These strings will
+        be used as SQL expressions for the update columns. In these
+        expressions you can refer to the value that already exists in
+        the table by prefixing the column name with "included.", and to
+        the value that has been proposed for insertion by prefixing the
+        column name with "excluded.".
+
+        The dictionary is modified in any case to reflect the values in
+        the database after the operation has completed.
+
+        Note: The method uses the PostgreSQL "upsert" feature which is
+        only available since PostgreSQL 9.5.
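+
+        Example sketch (assuming a hypothetical table 'counters' with
+        primary key 'name' and an integer column 'n'):
+
+            db.upsert('counters', dict(name='visits', n=1),
+                      n='included.n + excluded.n')  # add up on conflict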
+ """ + if table.endswith('*'): # hint for descendant tables can be ignored + table = table[:-1].rstrip() + if row is None: + row = {} + if 'oid' in row: + del row['oid'] # do not insert oid + if 'oid' in kw: + del kw['oid'] # do not update oid + attnames = self.get_attnames(table) + generated = self.get_generated(table) + qoid = oid_key(table) if 'oid' in attnames else None + params = self.adapter.parameter_list() + adapt = params.add + col = self.escape_identifier + name_list, value_list = [], [] + for n in attnames: + if n in row and n not in generated: + name_list.append(col(n)) + value_list.append(adapt(row[n], attnames[n])) + names, values = ', '.join(name_list), ', '.join(value_list) + try: + keynames = self.pkeys(table) + except KeyError as e: + raise prg_error(f'Table {table} has no primary key') from e + target = ', '.join(col(k) for k in keynames) + update = [] + keyset = set(keynames) + keyset.add('oid') + for n in attnames: + if n not in keyset and n not in generated: + value = kw.get(n, n in row) + if value: + if not isinstance(value, str): + value = f'excluded.{col(n)}' + update.append(f'{col(n)} = {value}') + if not values: + return row + do = 'update set ' + ', '.join(update) if update else 'nothing' + ret = 'oid, *' if qoid else '*' + t = self._escape_qualified_name(table) + cmd = (f'INSERT INTO {t} AS included ({names})' # noqa: S608 + f' VALUES ({values})' + f' ON CONFLICT ({target}) DO {do} RETURNING {ret}') + self._do_debug(cmd, params) + query = self._valid_db.query(cmd, params) + res = query.dictresult() + if res: # may be empty with "do nothing" + for n, value in res[0].items(): + if qoid and n == 'oid': + n = qoid + row[n] = value + else: + self.get(table, row) + return row + + def clear(self, table: str, row: dict[str, Any] | None = None + ) -> dict[str, Any]: + """Clear all the attributes to values determined by the types. + + Numeric types are set to 0, Booleans are set to false, and everything + else is set to the empty string. If the row argument is present, + it is used as the row dictionary and any entries matching attribute + names are cleared with everything else left unchanged. + """ + # At some point we will need a way to get defaults from a table. + if row is None: + row = {} # empty if argument is not present + attnames = self.get_attnames(table) + for n, t in attnames.items(): + if n == 'oid': + continue + t = t.simple + if t in DbTypes._num_types: + row[n] = 0 + elif t == 'bool': + row[n] = self._make_bool(False) + else: + row[n] = '' + return row + + def delete(self, table: str, row: dict[str, Any] | None = None, **kw: Any + ) -> int: + """Delete an existing row in a database table. + + This method deletes the row from a table. It deletes based on the + primary key of the table or the OID value as munged by get() or + passed as keyword. The OID will take precedence if provided. + + The return value is the number of deleted rows (i.e. 0 if the row + did not exist and 1 if the row was deleted). + + Note that if the row cannot be deleted because e.g. it is still + referenced by another table, this method raises a ProgrammingError. 
+ """ + if table.endswith('*'): # hint for descendant tables can be ignored + table = table[:-1].rstrip() + attnames = self.get_attnames(table) + qoid = oid_key(table) if 'oid' in attnames else None + if row is None: + row = {} + elif 'oid' in row: + del row['oid'] # only accept oid key from named args for safety + row.update(kw) + if qoid and qoid in row and 'oid' not in row: + row['oid'] = row[qoid] + if qoid and 'oid' in row: # try using the oid + keynames: tuple[str, ...] = ('oid',) + else: # try using the primary key + try: + keynames = self.pkeys(table) + except KeyError as e: # the table has no primary key + raise prg_error(f'Table {table} has no primary key') from e + # check whether all key columns have values + if not set(keynames).issubset(row): + raise KeyError('Missing value for primary key in row') + params = self.adapter.parameter_list() + adapt = params.add + col = self.escape_identifier + where = ' AND '.join( + f'{col(k)} OPERATOR(pg_catalog.=) {adapt(row[k], attnames[k])}' + for k in keynames) + if 'oid' in row: + if qoid: + row[qoid] = row['oid'] + del row['oid'] + t = self._escape_qualified_name(table) + cmd = f'DELETE FROM {t} WHERE {where}' # noqa: S608 + self._do_debug(cmd, params) + res = self._valid_db.query(cmd, params) + return int(res) # type: ignore + + def truncate(self, table: str | list[str] | tuple[str, ...] | + set[str] | frozenset[str], restart: bool = False, + cascade: bool = False, only: bool = False) -> Query: + """Empty a table or set of tables. + + This method quickly removes all rows from the given table or set + of tables. It has the same effect as an unqualified DELETE on each + table, but since it does not actually scan the tables it is faster. + Furthermore, it reclaims disk space immediately, rather than requiring + a subsequent VACUUM operation. This is most useful on large tables. + + If restart is set to True, sequences owned by columns of the truncated + table(s) are automatically restarted. If cascade is set to True, it + also truncates all tables that have foreign-key references to any of + the named tables. If the parameter 'only' is not set to True, all the + descendant tables (if any) will also be truncated. Optionally, a '*' + can be specified after the table name to explicitly indicate that + descendant tables are included. 
+ """ + if isinstance(table, str): + table_only = {table: only} + table = [table] + elif isinstance(table, (list, tuple)): + if isinstance(only, (list, tuple)): + table_only = dict(zip(table, only)) + else: + table_only = dict.fromkeys(table, only) + elif isinstance(table, (set, frozenset)): + table_only = dict.fromkeys(table, only) + else: + raise TypeError('The table must be a string, list or set') + if not (restart is None or isinstance(restart, (bool, int))): + raise TypeError('Invalid type for the restart option') + if not (cascade is None or isinstance(cascade, (bool, int))): + raise TypeError('Invalid type for the cascade option') + tables = [] + for t in table: + u = table_only.get(t) + if not (u is None or isinstance(u, (bool, int))): + raise TypeError('Invalid type for the only option') + if t.endswith('*'): + if u: + raise ValueError( + 'Contradictory table name and only options') + t = t[:-1].rstrip() + t = self._escape_qualified_name(t) + if u: + t = f'ONLY {t}' + tables.append(t) + cmd_parts = ['TRUNCATE', ', '.join(tables)] + if restart: + cmd_parts.append('RESTART IDENTITY') + if cascade: + cmd_parts.append('CASCADE') + cmd = ' '.join(cmd_parts) + self._do_debug(cmd) + return self._valid_db.query(cmd) + + def get_as_list( + self, table: str, + what: str | list[str] | tuple[str, ...] | None = None, + where: str | list[str] | tuple[str, ...] | None = None, + order: str | list[str] | tuple[str, ...] | bool | None = None, + limit: int | None = None, offset: int | None = None, + scalar: bool = False) -> list: + """Get a table as a list. + + This gets a convenient representation of the table as a list + of named tuples in Python. You only need to pass the name of + the table (or any other SQL expression returning rows). Note that + by default this will return the full content of the table which + can be huge and overflow your memory. However, you can control + the amount of data returned using the other optional parameters. + + The parameter 'what' can restrict the query to only return a + subset of the table columns. It can be a string, list or a tuple. + + The parameter 'where' can restrict the query to only return a + subset of the table rows. It can be a string, list or a tuple + of SQL expressions that all need to be fulfilled. + + The parameter 'order' specifies the ordering of the rows. It can + also be a string, list or a tuple. If no ordering is specified, + the result will be ordered by the primary key(s) or all columns if + no primary key exists. You can set 'order' to False if you don't + care about the ordering. The parameters 'limit' and 'offset' can be + integers specifying the maximum number of rows returned and a number + of rows skipped over. + + If you set the 'scalar' option to True, then instead of the + named tuples you will get the first items of these tuples. + This is useful if the result has only one column anyway. 
+ """ + if not table: + raise TypeError('The table name is missing') + if what: + if isinstance(what, (list, tuple)): + what = ', '.join(map(str, what)) + if order is None: + order = what + else: + what = '*' + cmd_parts = ['SELECT', what, 'FROM', table] + if where: + if isinstance(where, (list, tuple)): + where = ' AND '.join(map(str, where)) + cmd_parts.extend(['WHERE', where]) + if order is None or order is True: + try: + order = self.pkeys(table) + except (KeyError, ProgrammingError): + with suppress(KeyError, ProgrammingError): + order = list(self.get_attnames(table)) + if order and not isinstance(order, bool): + if isinstance(order, (list, tuple)): + order = ', '.join(map(str, order)) + cmd_parts.extend(['ORDER BY', order]) + if limit: + cmd_parts.append(f'LIMIT {limit}') + if offset: + cmd_parts.append(f'OFFSET {offset}') + cmd = ' '.join(cmd_parts) + self._do_debug(cmd) + query = self._valid_db.query(cmd) + res = query.namedresult() + if res and scalar: + res = [row[0] for row in res] + return res + + def get_as_dict( + self, table: str, + keyname: str | list[str] | tuple[str, ...] | None = None, + what: str | list[str] | tuple[str, ...] | None = None, + where: str | list[str] | tuple[str, ...] | None = None, + order: str | list[str] | tuple[str, ...] | bool | None = None, + limit: int | None = None, offset: int | None = None, + scalar: bool = False) -> dict: + """Get a table as a dictionary. + + This method is similar to get_as_list(), but returns the table + as a Python dict instead of a Python list, which can be even + more convenient. The primary key column(s) of the table will + be used as the keys of the dictionary, while the other column(s) + will be the corresponding values. The keys will be named tuples + if the table has a composite primary key. The rows will be also + named tuples unless the 'scalar' option has been set to True. + With the optional parameter 'keyname' you can specify an alternative + set of columns to be used as the keys of the dictionary. It must + be set as a string, list or a tuple. + + The dictionary will be ordered using the order specified with the + 'order' parameter or the key column(s) if not specified. You can + set 'order' to False if you don't care about the ordering. + """ + if not table: + raise TypeError('The table name is missing') + if not keyname: + try: + keyname = self.pkeys(table) + except (KeyError, ProgrammingError) as e: + raise prg_error(f'Table {table} has no primary key') from e + if isinstance(keyname, str): + keynames: list[str] | tuple[str, ...] 
= (keyname,) + elif isinstance(keyname, (list, tuple)): + keynames = keyname + else: + raise KeyError('The keyname must be a string, list or tuple') + if what: + if isinstance(what, (list, tuple)): + what = ', '.join(map(str, what)) + if order is None: + order = what + else: + what = '*' + cmd_parts = ['SELECT', what, 'FROM', table] + if where: + if isinstance(where, (list, tuple)): + where = ' AND '.join(map(str, where)) + cmd_parts.extend(['WHERE', where]) + if order is None or order is True: + order = keyname + if order and not isinstance(order, bool): + if isinstance(order, (list, tuple)): + order = ', '.join(map(str, order)) + cmd_parts.extend(['ORDER BY', order]) + if limit: + cmd_parts.append(f'LIMIT {limit}') + if offset: + cmd_parts.append(f'OFFSET {offset}') + cmd = ' '.join(cmd_parts) + self._do_debug(cmd) + query = self._valid_db.query(cmd) + res = query.getresult() + if not res: + return {} + keyset = set(keynames) + fields = query.listfields() + if not keyset.issubset(fields): + raise KeyError('Missing keyname in row') + key_index: list[int] = [] + row_index: list[int] = [] + for i, f in enumerate(fields): + (key_index if f in keyset else row_index).append(i) + key_tuple = len(key_index) > 1 + get_key = itemgetter(*key_index) + keys = map(get_key, res) + if scalar: + row_index = row_index[:1] + row_is_tuple = False + else: + row_is_tuple = len(row_index) > 1 + if scalar or row_is_tuple: + get_row: Callable[[tuple], tuple] = itemgetter( # pyright: ignore + *row_index) + else: + frst_index = row_index[0] + + def get_row(row : tuple) -> tuple: + return row[frst_index], # tuple with one item + + row_is_tuple = True + rows = map(get_row, res) + if key_tuple or row_is_tuple: + if key_tuple: + keys = namediter(_MemoryQuery(keys, keynames)) # type: ignore + if row_is_tuple: + fields = tuple(f for f in fields if f not in keyset) + rows = namediter(_MemoryQuery(rows, fields)) # type: ignore + # noinspection PyArgumentList + return dict(zip(keys, rows)) + + def notification_handler(self, event: str, callback: Callable, + arg_dict: dict | None = None, + timeout: int | float | None = None, + stop_event: str | None = None + ) -> NotificationHandler: + """Get notification handler that will run the given callback.""" + return NotificationHandler(self, event, callback, + arg_dict, timeout, stop_event) + + # immediately wrapped methods + + def send_query(self, cmd: str, args: Sequence | None = None) -> Query: + """Create a new asynchronous query object for this connection.""" + if args is None: + return self._valid_db.send_query(cmd) + return self._valid_db.send_query(cmd, args) + + def poll(self) -> int: + """Complete an asynchronous connection and get its state.""" + return self._valid_db.poll() + + def cancel(self) -> None: + """Abandon processing of current SQL command.""" + self._valid_db.cancel() + + def fileno(self) -> int: + """Get the socket used to connect to the database.""" + return self._valid_db.fileno() + + def get_cast_hook(self) -> Callable | None: + """Get the function that handles all external typecasting.""" + return self._valid_db.get_cast_hook() + + def set_cast_hook(self, hook: Callable | None) -> None: + """Set a function that will handle all external typecasting.""" + self._valid_db.set_cast_hook(hook) + + def get_notice_receiver(self) -> Callable | None: + """Get the current notice receiver.""" + return self._valid_db.get_notice_receiver() + + def set_notice_receiver(self, receiver: Callable | None) -> None: + """Set a custom notice receiver.""" + 
self._valid_db.set_notice_receiver(receiver)
+
+    def getnotify(self) -> tuple[str, int, str] | None:
+        """Get the last notify from the server."""
+        return self._valid_db.getnotify()
+
+    def inserttable(self, table: str, values: Sequence[list | tuple],
+                    columns: list[str] | tuple[str, ...] | None = None) -> int:
+        """Insert a Python iterable into a database table."""
+        if columns is None:
+            return self._valid_db.inserttable(table, values)
+        return self._valid_db.inserttable(table, values, columns)
+
+    def transaction(self) -> int:
+        """Get the current in-transaction status of the server.
+
+        The status returned by this method can be TRANS_IDLE (currently idle),
+        TRANS_ACTIVE (a command is in progress), TRANS_INTRANS (idle, in a
+        valid transaction block), or TRANS_INERROR (idle, in a failed
+        transaction block). TRANS_UNKNOWN is reported if the connection is
+        bad. The status TRANS_ACTIVE is reported only when a query has been
+        sent to the server and not yet completed.
+        """
+        return self._valid_db.transaction()
+
+    def parameter(self, name: str) -> str | None:
+        """Look up a current parameter setting of the server."""
+        return self._valid_db.parameter(name)
+
+    def date_format(self) -> str:
+        """Look up the date format currently being used by the database."""
+        return self._valid_db.date_format()
+
+    def escape_literal(self, s: AnyStr) -> AnyStr:
+        """Escape a literal constant for use within SQL."""
+        return self._valid_db.escape_literal(s)
+
+    def escape_identifier(self, s: AnyStr) -> AnyStr:
+        """Escape an identifier for use within SQL."""
+        return self._valid_db.escape_identifier(s)
+
+    def escape_string(self, s: AnyStr) -> AnyStr:
+        """Escape a string for use within SQL."""
+        return self._valid_db.escape_string(s)
+
+    def escape_bytea(self, s: AnyStr) -> AnyStr:
+        """Escape binary data for use within SQL as type 'bytea'."""
+        return self._valid_db.escape_bytea(s)
+
+    def putline(self, line: str) -> None:
+        """Write a line to the server socket."""
+        self._valid_db.putline(line)
+
+    def getline(self) -> str:
+        """Get a line from server socket."""
+        return self._valid_db.getline()
+
+    def endcopy(self) -> None:
+        """Synchronize client and server."""
+        self._valid_db.endcopy()
+
+    def set_non_blocking(self, nb: bool) -> None:
+        """Set the non-blocking mode of the connection."""
+        self._valid_db.set_non_blocking(nb)
+
+    def is_non_blocking(self) -> bool:
+        """Get the non-blocking mode of the connection."""
+        return self._valid_db.is_non_blocking()
+
+    def locreate(self, mode: int) -> LargeObject:
+        """Create a large object in the database.
+
+        The valid values for 'mode' parameter are defined as the module level
+        constants INV_READ and INV_WRITE.
+        """
+        return self._valid_db.locreate(mode)
+
+    def getlo(self, oid: int) -> LargeObject:
+        """Build a large object from given oid."""
+        return self._valid_db.getlo(oid)
+
+    def loimport(self, filename: str) -> LargeObject:
+        """Import a file to a large object."""
+        return self._valid_db.loimport(filename)
+
+
+class _MemoryQuery:
+    """Class that embodies a given query result."""
+
+    result: Any
+    fields: tuple[str, ...]
+ + def __init__(self, result: Any, fields: Sequence[str]) -> None: + """Create query from given result rows and field names.""" + self.result = result + self.fields = tuple(fields) + + def listfields(self) -> tuple[str, ...]: + """Return the stored field names of this query.""" + return self.fields + + def getresult(self) -> Any: + """Return the stored result of this query.""" + return self.result + + def __iter__(self) -> Iterator[Any]: + return iter(self.result) \ No newline at end of file diff --git a/pg/error.py b/pg/error.py new file mode 100644 index 00000000..f4b9fd0f --- /dev/null +++ b/pg/error.py @@ -0,0 +1,59 @@ +"""Error helpers.""" + +from __future__ import annotations + +from typing import TypeVar + +from .core import ( + DatabaseError, + Error, + InterfaceError, + InternalError, + OperationalError, + ProgrammingError, +) + +__all__ = [ + 'db_error', + 'error', + 'if_error', + 'int_error', + 'op_error', + 'prg_error' +] + +# Error messages + +E = TypeVar('E', bound=Error) + +def error(msg: str, cls: type[E]) -> E: + """Return specified error object with empty sqlstate attribute.""" + error = cls(msg) + if isinstance(error, DatabaseError): + error.sqlstate = None + return error + + +def db_error(msg: str) -> DatabaseError: + """Return DatabaseError.""" + return error(msg, DatabaseError) + + +def int_error(msg: str) -> InternalError: + """Return InternalError.""" + return error(msg, InternalError) + + +def prg_error(msg: str) -> ProgrammingError: + """Return ProgrammingError.""" + return error(msg, ProgrammingError) + + +def if_error(msg: str) -> InterfaceError: + """Return InterfaceError.""" + return error(msg, InterfaceError) + + +def op_error(msg: str) -> OperationalError: + """Return OperationalError.""" + return error(msg, OperationalError) diff --git a/pg/helpers.py b/pg/helpers.py new file mode 100644 index 00000000..9d176740 --- /dev/null +++ b/pg/helpers.py @@ -0,0 +1,124 @@ +"""Helper functions.""" + +from __future__ import annotations + +from collections import namedtuple +from decimal import Decimal +from functools import lru_cache +from json import loads as jsondecode +from typing import Any, Callable, Generator, NamedTuple, Sequence + +from .core import Query, set_decimal, set_jsondecode, set_query_helpers + +SomeNamedTuple = Any # alias for accessing arbitrary named tuples + +__all__ = [ + 'QuoteDict', + 'RowCache', + 'dictiter', + 'namediter', + 'namednext', + 'oid_key', + 'quote_if_unqualified', + 'scalariter' +] + + +# Small helper functions + +def quote_if_unqualified(param: str, name: int | str) -> str: + """Quote parameter representing a qualified name. + + Puts a quote_ident() call around the given parameter unless + the name contains a dot, in which case the name is ambiguous + (could be a qualified name or just a name with a dot in it) + and must be quoted manually by the caller. + """ + if isinstance(name, str) and '.' not in name: + return f'quote_ident({param})' + return param + +def oid_key(table: str) -> str: + """Build oid key from a table name.""" + return f'oid({table})' + +class QuoteDict(dict): + """Dictionary with auto quoting of its items. + + The quote attribute must be set to the desired quote function. + """ + + quote: Callable[[str], str] + + def __getitem__(self, key: str) -> str: + """Get a quoted value.""" + return self.quote(super().__getitem__(key)) + + +class RowCache: + """Global cache for the named tuples used for table rows. + + The result rows for database operations are returned as named tuples + by default. 
Since creating namedtuple classes is a somewhat expensive + operation, we cache up to 1024 of these classes by default. + """ + + @staticmethod + @lru_cache(maxsize=1024) + def row_factory(names: Sequence[str]) -> Callable[[Sequence], NamedTuple]: + """Get a namedtuple factory for row results with the given names.""" + try: + return namedtuple('Row', names, rename=True)._make # type: ignore + except ValueError: # there is still a problem with the field names + names = [f'column_{n}' for n in range(len(names))] + return namedtuple('Row', names)._make # type: ignore + + @classmethod + def clear(cls) -> None: + """Clear the namedtuple factory cache.""" + cls.row_factory.cache_clear() + + @classmethod + def change_size(cls, maxsize: int | None) -> None: + """Change the size of the namedtuple factory cache. + + If maxsize is set to None, the cache can grow without bound. + """ + row_factory = cls.row_factory.__wrapped__ + cls.row_factory = lru_cache(maxsize)(row_factory) # type: ignore + + +# Helper functions used by the query object + +def dictiter(q: Query) -> Generator[dict[str, Any], None, None]: + """Get query result as an iterator of dictionaries.""" + fields: tuple[str, ...] = q.listfields() + for r in q: + yield dict(zip(fields, r)) + + +def namediter(q: Query) -> Generator[SomeNamedTuple, None, None]: + """Get query result as an iterator of named tuples.""" + row = RowCache.row_factory(q.listfields()) + for r in q: + yield row(r) + + +def namednext(q: Query) -> SomeNamedTuple: + """Get next row from query result as a named tuple.""" + return RowCache.row_factory(q.listfields())(next(q)) + + +def scalariter(q: Query) -> Generator[Any, None, None]: + """Get query result as an iterator of scalar values.""" + for r in q: + yield r[0] + + +# Initialization + +def init_core() -> None: + """Initialize the C extension module.""" + set_decimal(Decimal) + set_jsondecode(jsondecode) + set_query_helpers(dictiter, namediter, namednext, scalariter) diff --git a/pg/notify.py b/pg/notify.py new file mode 100644 index 00000000..e273c521 --- /dev/null +++ b/pg/notify.py @@ -0,0 +1,149 @@ +"""Handling of notifications.""" + +from __future__ import annotations + +import select +from typing import TYPE_CHECKING, Callable + +from .core import Query +from .error import db_error + +if TYPE_CHECKING: + from .db import DB + +__all__ = ['NotificationHandler'] + +# The notification handler + +class NotificationHandler: + """A PostgreSQL client-side asynchronous notification handler.""" + + def __init__(self, db: DB, event: str, callback: Callable, + arg_dict: dict | None = None, + timeout: int | float | None = None, + stop_event: str | None = None): + """Initialize the notification handler. + + You must pass a PyGreSQL database connection, the name of an + event (notification channel) to listen for and a callback function. + + You can also specify a dictionary arg_dict that will be passed as + the single argument to the callback function, and a timeout value + in seconds (a floating point number denotes fractions of seconds). + If it is absent or None, the callers will never time out. If the + timeout is reached, the callback function will be called with a + single argument that is None. If you set the timeout to zero, + the handler will poll notifications synchronously and return. + + You can specify the name of the event that will be used to signal + the handler to stop listening as stop_event. By default, it will + be the event name prefixed with 'stop_'. 
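+
+        A minimal usage sketch (channel name and callback are
+        assumptions):
+
+            def on_event(arg_dict):
+                print('notified:', arg_dict)
+
+            handler = db.notification_handler('my_event', on_event)
+            handler()  # dispatches callbacks until a stop event arrives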
diff --git a/pg/py.typed b/pg/py.typed
new file mode 100644
index 00000000..ea6e1ace
--- /dev/null
+++ b/pg/py.typed
@@ -0,0 +1,4 @@
+# Marker file for PEP 561.
+
+# The pg package uses inline types,
+# except for the _pg extension module which uses a stub file.
diff --git a/pg/tz.py b/pg/tz.py
new file mode 100644
index 00000000..7f22e049
--- /dev/null
+++ b/pg/tz.py
@@ -0,0 +1,21 @@
+"""Timezone helpers."""
+
+from __future__ import annotations
+
+__all__ = ['timezone_as_offset']
+
+# time zones used in Postgres timestamptz output
+_timezone_offsets: dict[str, str] = {
+    'CET': '+0100', 'EET': '+0200', 'EST': '-0500',
+    'GMT': '+0000', 'HST': '-1000', 'MET': '+0100', 'MST': '-0700',
+    'UCT': '+0000', 'UTC': '+0000', 'WET': '+0000'
+}
+
+
+def timezone_as_offset(tz: str) -> str:
+    """Convert timezone abbreviation to offset."""
+    if tz.startswith(('+', '-')):
+        if len(tz) < 5:
+            return tz + '00'
+        return tz.replace(':', '')
+    return _timezone_offsets.get(tz, '+0000')
\ No newline at end of file
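For reference, the conversion above behaves as follows (a few illustrative calls; 'XYZ' is a deliberately unknown zone):

```python
from pg.tz import timezone_as_offset

assert timezone_as_offset('+02') == '+0200'     # short offsets get minutes
assert timezone_as_offset('-05:30') == '-0530'  # colons are stripped
assert timezone_as_offset('EST') == '-0500'     # known abbreviation
assert timezone_as_offset('XYZ') == '+0000'     # unknown zones fall back to UTC
```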
diff --git a/pgconn.c b/pgconn.c
deleted file mode 100644
index e16fd68e..00000000
--- a/pgconn.c
+++ /dev/null
@@ -1,1517 +0,0 @@
-/*
- * PyGreSQL - a Python interface for the PostgreSQL database.
- *
- * The connection object - this file is part a of the C extension module.
- *
- * Copyright (c) 2020 by the PyGreSQL Development Team
- *
- * Please see the LICENSE.TXT file for specific restrictions.
- */
-
-/* Deallocate connection object. */
-static void
-conn_dealloc(connObject *self)
-{
-    if (self->cnx) {
-        Py_BEGIN_ALLOW_THREADS
-        PQfinish(self->cnx);
-        Py_END_ALLOW_THREADS
-    }
-    Py_XDECREF(self->cast_hook);
-    Py_XDECREF(self->notice_receiver);
-    PyObject_Del(self);
-}
-
-/* Get connection attributes. */
-static PyObject *
-conn_getattr(connObject *self, PyObject *nameobj)
-{
-    const char *name = PyStr_AsString(nameobj);
-
-    /*
-     * Although we could check individually, there are only a few
-     * attributes that don't require a live connection and unless someone
-     * has an urgent need, this will have to do.
-     */
-
-    /* first exception - close which returns a different error */
-    if (strcmp(name, "close") && !self->cnx) {
-        PyErr_SetString(PyExc_TypeError, "Connection is not valid");
-        return NULL;
-    }
-
-    /* list PostgreSQL connection fields */
-
-    /* postmaster host */
-    if (!strcmp(name, "host")) {
-        char *r = PQhost(self->cnx);
-        if (!r || r[0] == '/')  /* Pg >= 9.6 can return a Unix socket path */
-            r = "localhost";
-        return PyStr_FromString(r);
-    }
-
-    /* postmaster port */
-    if (!strcmp(name, "port"))
-        return PyInt_FromLong(atol(PQport(self->cnx)));
-
-    /* selected database */
-    if (!strcmp(name, "db"))
-        return PyStr_FromString(PQdb(self->cnx));
-
-    /* selected options */
-    if (!strcmp(name, "options"))
-        return PyStr_FromString(PQoptions(self->cnx));
-
-    /* error (status) message */
-    if (!strcmp(name, "error"))
-        return PyStr_FromString(PQerrorMessage(self->cnx));
-
-    /* connection status : 1 - OK, 0 - BAD */
-    if (!strcmp(name, "status"))
-        return PyInt_FromLong(PQstatus(self->cnx) == CONNECTION_OK ?
1 : 0); - - /* provided user name */ - if (!strcmp(name, "user")) - return PyStr_FromString(PQuser(self->cnx)); - - /* protocol version */ - if (!strcmp(name, "protocol_version")) - return PyInt_FromLong(PQprotocolVersion(self->cnx)); - - /* backend version */ - if (!strcmp(name, "server_version")) - return PyInt_FromLong(PQserverVersion(self->cnx)); - - /* descriptor number of connection socket */ - if (!strcmp(name, "socket")) { - return PyInt_FromLong(PQsocket(self->cnx)); - } - - /* PID of backend process */ - if (!strcmp(name, "backend_pid")) { - return PyInt_FromLong(PQbackendPID(self->cnx)); - } - - /* whether the connection uses SSL */ - if (!strcmp(name, "ssl_in_use")) { -#ifdef SSL_INFO - if (PQsslInUse(self->cnx)) { - Py_INCREF(Py_True); return Py_True; - } - else { - Py_INCREF(Py_False); return Py_False; - } -#else - set_error_msg(NotSupportedError, "SSL info functions not supported"); - return NULL; -#endif - } - - /* SSL attributes */ - if (!strcmp(name, "ssl_attributes")) { -#ifdef SSL_INFO - return get_ssl_attributes(self->cnx); -#else - set_error_msg(NotSupportedError, "SSL info functions not supported"); - return NULL; -#endif - } - - return PyObject_GenericGetAttr((PyObject *) self, nameobj); -} - -/* Check connection validity. */ -static int -_check_cnx_obj(connObject *self) -{ - if (!self || !self->valid || !self->cnx) { - set_error_msg(OperationalError, "Connection has been closed"); - return 0; - } - return 1; -} - -/* Create source object. */ -static char conn_source__doc__[] = -"source() -- create a new source object for this connection"; - -static PyObject * -conn_source(connObject *self, PyObject *noargs) -{ - sourceObject *source_obj; - - /* checks validity */ - if (!_check_cnx_obj(self)) { - return NULL; - } - - /* allocates new query object */ - if (!(source_obj = PyObject_New(sourceObject, &sourceType))) { - return NULL; - } - - /* initializes internal parameters */ - Py_XINCREF(self); - source_obj->pgcnx = self; - source_obj->result = NULL; - source_obj->valid = 1; - source_obj->arraysize = PG_ARRAYSIZE; - - return (PyObject *) source_obj; -} - -/* Base method for execution of both unprepared and prepared queries */ -static PyObject * -_conn_query(connObject *self, PyObject *args, int prepared) -{ - PyObject *query_str_obj, *param_obj = NULL; - PGresult* result; - queryObject* query_obj; - char *query; - int encoding, status, nparms = 0; - - if (!self->cnx) { - PyErr_SetString(PyExc_TypeError, "Connection is not valid"); - return NULL; - } - - /* get query args */ - if (!PyArg_ParseTuple(args, "O|O", &query_str_obj, ¶m_obj)) { - return NULL; - } - - encoding = PQclientEncoding(self->cnx); - - if (PyBytes_Check(query_str_obj)) { - query = PyBytes_AsString(query_str_obj); - query_str_obj = NULL; - } - else if (PyUnicode_Check(query_str_obj)) { - query_str_obj = get_encoded_string(query_str_obj, encoding); - if (!query_str_obj) return NULL; /* pass the UnicodeEncodeError */ - query = PyBytes_AsString(query_str_obj); - } - else { - PyErr_SetString(PyExc_TypeError, - "Method query() expects a string as first argument"); - return NULL; - } - - /* If param_obj is passed, ensure it's a non-empty tuple. We want to treat - * an empty tuple the same as no argument since we'll get that when the - * caller passes no arguments to db.query(), and historic behaviour was - * to call PQexec() in that case, which can execute multiple commands. 
*/ - if (param_obj) { - param_obj = PySequence_Fast( - param_obj, "Method query() expects a sequence as second argument"); - if (!param_obj) { - Py_XDECREF(query_str_obj); - return NULL; - } - nparms = (int) PySequence_Fast_GET_SIZE(param_obj); - - /* if there's a single argument and it's a list or tuple, it - * contains the positional arguments. */ - if (nparms == 1) { - PyObject *first_obj = PySequence_Fast_GET_ITEM(param_obj, 0); - if (PyList_Check(first_obj) || PyTuple_Check(first_obj)) { - Py_DECREF(param_obj); - param_obj = PySequence_Fast(first_obj, NULL); - nparms = (int) PySequence_Fast_GET_SIZE(param_obj); - } - } - } - - /* gets result */ - if (nparms) { - /* prepare arguments */ - PyObject **str, **s; - const char **parms, **p; - register int i; - - str = (PyObject **) PyMem_Malloc((size_t) nparms * sizeof(*str)); - parms = (const char **) PyMem_Malloc((size_t) nparms * sizeof(*parms)); - if (!str || !parms) { - PyMem_Free((void *) parms); PyMem_Free(str); - Py_XDECREF(query_str_obj); Py_XDECREF(param_obj); - return PyErr_NoMemory(); - } - - /* convert optional args to a list of strings -- this allows - * the caller to pass whatever they like, and prevents us - * from having to map types to OIDs */ - for (i = 0, s = str, p = parms; i < nparms; ++i, ++p) { - PyObject *obj = PySequence_Fast_GET_ITEM(param_obj, i); - - if (obj == Py_None) { - *p = NULL; - } - else if (PyBytes_Check(obj)) { - *p = PyBytes_AsString(obj); - } - else if (PyUnicode_Check(obj)) { - PyObject *str_obj = get_encoded_string(obj, encoding); - if (!str_obj) { - PyMem_Free((void *) parms); - while (s != str) { s--; Py_DECREF(*s); } - PyMem_Free(str); - Py_XDECREF(query_str_obj); - Py_XDECREF(param_obj); - /* pass the UnicodeEncodeError */ - return NULL; - } - *s++ = str_obj; - *p = PyBytes_AsString(str_obj); - } - else { - PyObject *str_obj = PyObject_Str(obj); - if (!str_obj) { - PyMem_Free((void *) parms); - while (s != str) { s--; Py_DECREF(*s); } - PyMem_Free(str); - Py_XDECREF(query_str_obj); - Py_XDECREF(param_obj); - PyErr_SetString( - PyExc_TypeError, - "Query parameter has no string representation"); - return NULL; - } - *s++ = str_obj; - *p = PyStr_AsString(str_obj); - } - } - - Py_BEGIN_ALLOW_THREADS - result = prepared ? - PQexecPrepared(self->cnx, query, nparms, - parms, NULL, NULL, 0) : - PQexecParams(self->cnx, query, nparms, - NULL, parms, NULL, NULL, 0); - Py_END_ALLOW_THREADS - - PyMem_Free((void *) parms); - while (s != str) { s--; Py_DECREF(*s); } - PyMem_Free(str); - } - else { - Py_BEGIN_ALLOW_THREADS - result = prepared ? 
- PQexecPrepared(self->cnx, query, 0, - NULL, NULL, NULL, 0) : - PQexec(self->cnx, query); - Py_END_ALLOW_THREADS - } - - /* we don't need the query and its params any more */ - Py_XDECREF(query_str_obj); - Py_XDECREF(param_obj); - - /* checks result validity */ - if (!result) { - PyErr_SetString(PyExc_ValueError, PQerrorMessage(self->cnx)); - return NULL; - } - - /* this may have changed the datestyle, so we reset the date format - in order to force fetching it newly when next time requested */ - self->date_format = date_format; /* this is normally NULL */ - - /* checks result status */ - if ((status = PQresultStatus(result)) != PGRES_TUPLES_OK) { - switch (status) { - case PGRES_EMPTY_QUERY: - PyErr_SetString(PyExc_ValueError, "Empty query"); - break; - case PGRES_BAD_RESPONSE: - case PGRES_FATAL_ERROR: - case PGRES_NONFATAL_ERROR: - set_error(ProgrammingError, "Cannot execute query", - self->cnx, result); - break; - case PGRES_COMMAND_OK: - { /* INSERT, UPDATE, DELETE */ - Oid oid = PQoidValue(result); - - if (oid == InvalidOid) { /* not a single insert */ - char *ret = PQcmdTuples(result); - - if (ret[0]) { /* return number of rows affected */ - PyObject *obj = PyStr_FromString(ret); - PQclear(result); - return obj; - } - PQclear(result); - Py_INCREF(Py_None); - return Py_None; - } - /* for a single insert, return the oid */ - PQclear(result); - return PyInt_FromLong(oid); - } - case PGRES_COPY_OUT: /* no data will be received */ - case PGRES_COPY_IN: - PQclear(result); - Py_INCREF(Py_None); - return Py_None; - default: - set_error_msg(InternalError, "Unknown result status"); - } - - PQclear(result); - return NULL; /* error detected on query */ - } - - if (!(query_obj = PyObject_New(queryObject, &queryType))) - return PyErr_NoMemory(); - - /* stores result and returns object */ - Py_XINCREF(self); - query_obj->pgcnx = self; - query_obj->result = result; - query_obj->encoding = encoding; - query_obj->current_row = 0; - query_obj->max_row = PQntuples(result); - query_obj->num_fields = PQnfields(result); - query_obj->col_types = get_col_types(result, query_obj->num_fields); - if (!query_obj->col_types) { - Py_DECREF(query_obj); - Py_DECREF(self); - return NULL; - } - - return (PyObject *) query_obj; -} - -/* Database query */ -static char conn_query__doc__[] = -"query(sql, [arg]) -- create a new query object for this connection\n\n" -"You must pass the SQL (string) request and you can optionally pass\n" -"a tuple with positional parameters.\n"; - -static PyObject * -conn_query(connObject *self, PyObject *args) -{ - return _conn_query(self, args, 0); -} - -/* Execute prepared statement. */ -static char conn_query_prepared__doc__[] = -"query_prepared(name, [arg]) -- execute a prepared statement\n\n" -"You must pass the name (string) of the prepared statement and you can\n" -"optionally pass a tuple with positional parameters.\n"; - -static PyObject * -conn_query_prepared(connObject *self, PyObject *args) -{ - return _conn_query(self, args, 1); -} - -/* Create prepared statement. 
*/ -static char conn_prepare__doc__[] = -"prepare(name, sql) -- create a prepared statement\n\n" -"You must pass the name (string) of the prepared statement and the\n" -"SQL (string) request for later execution.\n"; - -static PyObject * -conn_prepare(connObject *self, PyObject *args) -{ - char *name, *query; - Py_ssize_t name_length, query_length; - PGresult *result; - - if (!self->cnx) { - PyErr_SetString(PyExc_TypeError, "Connection is not valid"); - return NULL; - } - - /* reads args */ - if (!PyArg_ParseTuple(args, "s#s#", - &name, &name_length, &query, &query_length)) - { - PyErr_SetString(PyExc_TypeError, - "Method prepare() takes two string arguments"); - return NULL; - } - - /* create prepared statement */ - Py_BEGIN_ALLOW_THREADS - result = PQprepare(self->cnx, name, query, 0, NULL); - Py_END_ALLOW_THREADS - if (result && PQresultStatus(result) == PGRES_COMMAND_OK) { - PQclear(result); - Py_INCREF(Py_None); - return Py_None; /* success */ - } - set_error(ProgrammingError, "Cannot create prepared statement", - self->cnx, result); - if (result) - PQclear(result); - return NULL; /* error */ -} - -/* Describe prepared statement. */ -static char conn_describe_prepared__doc__[] = -"describe_prepared(name) -- describe a prepared statement\n\n" -"You must pass the name (string) of the prepared statement.\n"; - -static PyObject * -conn_describe_prepared(connObject *self, PyObject *args) -{ - char *name; - Py_ssize_t name_length; - PGresult *result; - - if (!self->cnx) { - PyErr_SetString(PyExc_TypeError, "Connection is not valid"); - return NULL; - } - - /* reads args */ - if (!PyArg_ParseTuple(args, "s#", &name, &name_length)) { - PyErr_SetString(PyExc_TypeError, - "Method prepare() takes a string argument"); - return NULL; - } - - /* describe prepared statement */ - Py_BEGIN_ALLOW_THREADS - result = PQdescribePrepared(self->cnx, name); - Py_END_ALLOW_THREADS - if (result && PQresultStatus(result) == PGRES_COMMAND_OK) { - queryObject *query_obj = PyObject_New(queryObject, &queryType); - if (!query_obj) - return PyErr_NoMemory(); - Py_XINCREF(self); - query_obj->pgcnx = self; - query_obj->result = result; - query_obj->encoding = PQclientEncoding(self->cnx); - query_obj->current_row = 0; - query_obj->max_row = PQntuples(result); - query_obj->num_fields = PQnfields(result); - query_obj->col_types = get_col_types(result, query_obj->num_fields); - return (PyObject *) query_obj; - } - set_error(ProgrammingError, "Cannot describe prepared statement", - self->cnx, result); - if (result) - PQclear(result); - return NULL; /* error */ -} - -#ifdef DIRECT_ACCESS -static char conn_putline__doc__[] = -"putline(line) -- send a line directly to the backend"; - -/* Direct access function: putline. */ -static PyObject * -conn_putline(connObject *self, PyObject *args) -{ - char *line; - Py_ssize_t line_length; - - if (!self->cnx) { - PyErr_SetString(PyExc_TypeError, "Connection is not valid"); - return NULL; - } - - /* reads args */ - if (!PyArg_ParseTuple(args, "s#", &line, &line_length)) { - PyErr_SetString(PyExc_TypeError, - "Method putline() takes a string argument"); - return NULL; - } - - /* sends line to backend */ - if (PQputline(self->cnx, line)) { - PyErr_SetString(PyExc_IOError, PQerrorMessage(self->cnx)); - return NULL; - } - Py_INCREF(Py_None); - return Py_None; -} - -/* Direct access function: getline. 
*/ -static char conn_getline__doc__[] = -"getline() -- get a line directly from the backend"; - -static PyObject * -conn_getline(connObject *self, PyObject *noargs) -{ - char line[MAX_BUFFER_SIZE]; - PyObject *str = NULL; /* GCC */ - - if (!self->cnx) { - PyErr_SetString(PyExc_TypeError, "Connection is not valid"); - return NULL; - } - - /* gets line */ - switch (PQgetline(self->cnx, line, MAX_BUFFER_SIZE)) { - case 0: - str = PyStr_FromString(line); - break; - case 1: - PyErr_SetString(PyExc_MemoryError, "Buffer overflow"); - str = NULL; - break; - case EOF: - Py_INCREF(Py_None); - str = Py_None; - break; - } - - return str; -} - -/* Direct access function: end copy. */ -static char conn_endcopy__doc__[] = -"endcopy() -- synchronize client and server"; - -static PyObject * -conn_endcopy(connObject *self, PyObject *noargs) -{ - if (!self->cnx) { - PyErr_SetString(PyExc_TypeError, "Connection is not valid"); - return NULL; - } - - /* ends direct copy */ - if (PQendcopy(self->cnx)) { - PyErr_SetString(PyExc_IOError, PQerrorMessage(self->cnx)); - return NULL; - } - Py_INCREF(Py_None); - return Py_None; -} -#endif /* DIRECT_ACCESS */ - - -/* Insert table */ -static char conn_inserttable__doc__[] = -"inserttable(table, data) -- insert list into table\n\n" -"The fields in the list must be in the same order as in the table.\n"; - -static PyObject * -conn_inserttable(connObject *self, PyObject *args) -{ - PGresult *result; - char *table, *buffer, *bufpt; - int encoding; - size_t bufsiz; - PyObject *list, *sublist, *item; - PyObject *(*getitem) (PyObject *, Py_ssize_t); - PyObject *(*getsubitem) (PyObject *, Py_ssize_t); - Py_ssize_t i, j, m, n; - - if (!self->cnx) { - PyErr_SetString(PyExc_TypeError, "Connection is not valid"); - return NULL; - } - - /* gets arguments */ - if (!PyArg_ParseTuple(args, "sO:filter", &table, &list)) { - PyErr_SetString( - PyExc_TypeError, - "Method inserttable() expects a string and a list as arguments"); - return NULL; - } - - /* checks list type */ - if (PyList_Check(list)) { - m = PyList_Size(list); - getitem = PyList_GetItem; - } - else if (PyTuple_Check(list)) { - m = PyTuple_Size(list); - getitem = PyTuple_GetItem; - } - else { - PyErr_SetString( - PyExc_TypeError, - "Method inserttable() expects a list or a tuple" - " as second argument"); - return NULL; - } - - /* allocate buffer */ - if (!(buffer = PyMem_Malloc(MAX_BUFFER_SIZE))) - return PyErr_NoMemory(); - - /* starts query */ - sprintf(buffer, "copy %s from stdin", table); - - Py_BEGIN_ALLOW_THREADS - result = PQexec(self->cnx, buffer); - Py_END_ALLOW_THREADS - - if (!result) { - PyMem_Free(buffer); - PyErr_SetString(PyExc_ValueError, PQerrorMessage(self->cnx)); - return NULL; - } - - encoding = PQclientEncoding(self->cnx); - - PQclear(result); - - n = 0; /* not strictly necessary but avoids warning */ - - /* feed table */ - for (i = 0; i < m; ++i) { - sublist = getitem(list, i); - if (PyTuple_Check(sublist)) { - j = PyTuple_Size(sublist); - getsubitem = PyTuple_GetItem; - } - else if (PyList_Check(sublist)) { - j = PyList_Size(sublist); - getsubitem = PyList_GetItem; - } - else { - PyErr_SetString( - PyExc_TypeError, - "The second argument must contain a tuple or a list"); - return NULL; - } - if (i) { - if (j != n) { - PyMem_Free(buffer); - PyErr_SetString( - PyExc_TypeError, - "Arrays contained in second arg must have same size"); - return NULL; - } - } - else { - n = j; /* never used before this assignment */ - } - - /* builds insert line */ - bufpt = buffer; - bufsiz = MAX_BUFFER_SIZE - 1; - - for (j 
= 0; j < n; ++j) { - if (j) { - *bufpt++ = '\t'; --bufsiz; - } - - item = getsubitem(sublist, j); - - /* convert item to string and append to buffer */ - if (item == Py_None) { - if (bufsiz > 2) { - *bufpt++ = '\\'; *bufpt++ = 'N'; - bufsiz -= 2; - } - else - bufsiz = 0; - } - else if (PyBytes_Check(item)) { - const char* t = PyBytes_AsString(item); - - while (*t && bufsiz) { - if (*t == '\\' || *t == '\t' || *t == '\n') { - *bufpt++ = '\\'; --bufsiz; - if (!bufsiz) break; - } - *bufpt++ = *t++; --bufsiz; - } - } - else if (PyUnicode_Check(item)) { - PyObject *s = get_encoded_string(item, encoding); - if (!s) { - PyMem_Free(buffer); - return NULL; /* pass the UnicodeEncodeError */ - } - else { - const char* t = PyBytes_AsString(s); - - while (*t && bufsiz) { - if (*t == '\\' || *t == '\t' || *t == '\n') { - *bufpt++ = '\\'; --bufsiz; - if (!bufsiz) break; - } - *bufpt++ = *t++; --bufsiz; - } - Py_DECREF(s); - } - } - else if (PyInt_Check(item) || PyLong_Check(item)) { - PyObject* s = PyObject_Str(item); - const char* t = PyStr_AsString(s); - - while (*t && bufsiz) { - *bufpt++ = *t++; --bufsiz; - } - Py_DECREF(s); - } - else { - PyObject* s = PyObject_Repr(item); - const char* t = PyStr_AsString(s); - - while (*t && bufsiz) { - if (*t == '\\' || *t == '\t' || *t == '\n') { - *bufpt++ = '\\'; --bufsiz; - if (!bufsiz) break; - } - *bufpt++ = *t++; --bufsiz; - } - Py_DECREF(s); - } - - if (bufsiz <= 0) { - PyMem_Free(buffer); return PyErr_NoMemory(); - } - - } - - *bufpt++ = '\n'; *bufpt = '\0'; - - /* sends data */ - if (PQputline(self->cnx, buffer)) { - PyErr_SetString(PyExc_IOError, PQerrorMessage(self->cnx)); - PQendcopy(self->cnx); - PyMem_Free(buffer); - return NULL; - } - } - - /* ends query */ - if (PQputline(self->cnx, "\\.\n")) { - PyErr_SetString(PyExc_IOError, PQerrorMessage(self->cnx)); - PQendcopy(self->cnx); - PyMem_Free(buffer); - return NULL; - } - - if (PQendcopy(self->cnx)) { - PyErr_SetString(PyExc_IOError, PQerrorMessage(self->cnx)); - PyMem_Free(buffer); - return NULL; - } - - PyMem_Free(buffer); - - /* no error : returns nothing */ - Py_INCREF(Py_None); - return Py_None; -} - -/* Get transaction state. */ -static char conn_transaction__doc__[] = -"transaction() -- return the current transaction status"; - -static PyObject * -conn_transaction(connObject *self, PyObject *noargs) -{ - if (!self->cnx) { - PyErr_SetString(PyExc_TypeError, "Connection is not valid"); - return NULL; - } - - return PyInt_FromLong(PQtransactionStatus(self->cnx)); -} - -/* Get parameter setting. */ -static char conn_parameter__doc__[] = -"parameter(name) -- look up a current parameter setting"; - -static PyObject * -conn_parameter(connObject *self, PyObject *args) -{ - const char *name; - - if (!self->cnx) { - PyErr_SetString(PyExc_TypeError, "Connection is not valid"); - return NULL; - } - - /* get query args */ - if (!PyArg_ParseTuple(args, "s", &name)) { - PyErr_SetString(PyExc_TypeError, - "Method parameter() takes a string as argument"); - return NULL; - } - - name = PQparameterStatus(self->cnx, name); - - if (name) - return PyStr_FromString(name); - - /* unknown parameter, return None */ - Py_INCREF(Py_None); - return Py_None; -} - -/* Get current date format. 
*/ -static char conn_date_format__doc__[] = -"date_format() -- return the current date format"; - -static PyObject * -conn_date_format(connObject *self, PyObject *noargs) -{ - const char *fmt; - - if (!self->cnx) { - PyErr_SetString(PyExc_TypeError, "Connection is not valid"); - return NULL; - } - - /* check if the date format is cached in the connection */ - fmt = self->date_format; - if (!fmt) { - fmt = date_style_to_format(PQparameterStatus(self->cnx, "DateStyle")); - self->date_format = fmt; /* cache the result */ - } - - return PyStr_FromString(fmt); -} - -#ifdef ESCAPING_FUNCS - -/* Escape literal */ -static char conn_escape_literal__doc__[] = -"escape_literal(str) -- escape a literal constant for use within SQL"; - -static PyObject * -conn_escape_literal(connObject *self, PyObject *string) -{ - PyObject *tmp_obj = NULL, /* auxiliary string object */ - *to_obj; /* string object to return */ - char *from, /* our string argument as encoded string */ - *to; /* the result as encoded string */ - Py_ssize_t from_length; /* length of string */ - size_t to_length; /* length of result */ - int encoding = -1; /* client encoding */ - - if (PyBytes_Check(string)) { - PyBytes_AsStringAndSize(string, &from, &from_length); - } - else if (PyUnicode_Check(string)) { - encoding = PQclientEncoding(self->cnx); - tmp_obj = get_encoded_string(string, encoding); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ - PyBytes_AsStringAndSize(tmp_obj, &from, &from_length); - } - else { - PyErr_SetString( - PyExc_TypeError, - "Method escape_literal() expects a string as argument"); - return NULL; - } - - to = PQescapeLiteral(self->cnx, from, (size_t) from_length); - to_length = strlen(to); - - Py_XDECREF(tmp_obj); - - if (encoding == -1) - to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t) to_length); - else - to_obj = get_decoded_string(to, (Py_ssize_t) to_length, encoding); - if (to) - PQfreemem(to); - return to_obj; -} - -/* Escape identifier */ -static char conn_escape_identifier__doc__[] = -"escape_identifier(str) -- escape an identifier for use within SQL"; - -static PyObject * -conn_escape_identifier(connObject *self, PyObject *string) -{ - PyObject *tmp_obj = NULL, /* auxiliary string object */ - *to_obj; /* string object to return */ - char *from, /* our string argument as encoded string */ - *to; /* the result as encoded string */ - Py_ssize_t from_length; /* length of string */ - size_t to_length; /* length of result */ - int encoding = -1; /* client encoding */ - - if (PyBytes_Check(string)) { - PyBytes_AsStringAndSize(string, &from, &from_length); - } - else if (PyUnicode_Check(string)) { - encoding = PQclientEncoding(self->cnx); - tmp_obj = get_encoded_string(string, encoding); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ - PyBytes_AsStringAndSize(tmp_obj, &from, &from_length); - } - else { - PyErr_SetString( - PyExc_TypeError, - "Method escape_identifier() expects a string as argument"); - return NULL; - } - - to = PQescapeIdentifier(self->cnx, from, (size_t) from_length); - to_length = strlen(to); - - Py_XDECREF(tmp_obj); - - if (encoding == -1) - to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t) to_length); - else - to_obj = get_decoded_string(to, (Py_ssize_t) to_length, encoding); - if (to) - PQfreemem(to); - return to_obj; -} - -#endif /* ESCAPING_FUNCS */ - -/* Escape string */ -static char conn_escape_string__doc__[] = -"escape_string(str) -- escape a string for use within SQL"; - -static PyObject * -conn_escape_string(connObject *self, PyObject *string) -{ - 
PyObject *tmp_obj = NULL, /* auxiliary string object */ - *to_obj; /* string object to return */ - char *from, /* our string argument as encoded string */ - *to; /* the result as encoded string */ - Py_ssize_t from_length; /* length of string */ - size_t to_length; /* length of result */ - int encoding = -1; /* client encoding */ - - if (PyBytes_Check(string)) { - PyBytes_AsStringAndSize(string, &from, &from_length); - } - else if (PyUnicode_Check(string)) { - encoding = PQclientEncoding(self->cnx); - tmp_obj = get_encoded_string(string, encoding); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ - PyBytes_AsStringAndSize(tmp_obj, &from, &from_length); - } - else { - PyErr_SetString( - PyExc_TypeError, - "Method escape_string() expects a string as argument"); - return NULL; - } - - to_length = 2 * (size_t) from_length + 1; - if ((Py_ssize_t) to_length < from_length) { /* overflow */ - to_length = (size_t) from_length; - from_length = (from_length - 1)/2; - } - to = (char *) PyMem_Malloc(to_length); - to_length = PQescapeStringConn(self->cnx, - to, from, (size_t) from_length, NULL); - - Py_XDECREF(tmp_obj); - - if (encoding == -1) - to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t) to_length); - else - to_obj = get_decoded_string(to, (Py_ssize_t) to_length, encoding); - PyMem_Free(to); - return to_obj; -} - -/* Escape bytea */ -static char conn_escape_bytea__doc__[] = -"escape_bytea(data) -- escape binary data for use within SQL as type bytea"; - -static PyObject * -conn_escape_bytea(connObject *self, PyObject *data) -{ - PyObject *tmp_obj = NULL, /* auxiliary string object */ - *to_obj; /* string object to return */ - char *from, /* our string argument as encoded string */ - *to; /* the result as encoded string */ - Py_ssize_t from_length; /* length of string */ - size_t to_length; /* length of result */ - int encoding = -1; /* client encoding */ - - if (PyBytes_Check(data)) { - PyBytes_AsStringAndSize(data, &from, &from_length); - } - else if (PyUnicode_Check(data)) { - encoding = PQclientEncoding(self->cnx); - tmp_obj = get_encoded_string(data, encoding); - if (!tmp_obj) return NULL; /* pass the UnicodeEncodeError */ - PyBytes_AsStringAndSize(tmp_obj, &from, &from_length); - } - else { - PyErr_SetString( - PyExc_TypeError, - "Method escape_bytea() expects a string as argument"); - return NULL; - } - - to = (char *) PQescapeByteaConn(self->cnx, - (unsigned char *) from, (size_t) from_length, &to_length); - - Py_XDECREF(tmp_obj); - - if (encoding == -1) - to_obj = PyBytes_FromStringAndSize(to, (Py_ssize_t) to_length - 1); - else - to_obj = get_decoded_string(to, (Py_ssize_t) to_length - 1, encoding); - if (to) - PQfreemem(to); - return to_obj; -} - -#ifdef LARGE_OBJECTS - -/* Constructor for large objects (internal use only) */ -static largeObject * -large_new(connObject *pgcnx, Oid oid) -{ - largeObject *large_obj; - - if (!(large_obj = PyObject_New(largeObject, &largeType))) { - return NULL; - } - - Py_XINCREF(pgcnx); - large_obj->pgcnx = pgcnx; - large_obj->lo_fd = -1; - large_obj->lo_oid = oid; - - return large_obj; -} - -/* Create large object. 
*/ -static char conn_locreate__doc__[] = -"locreate(mode) -- create a new large object in the database"; - -static PyObject * -conn_locreate(connObject *self, PyObject *args) -{ - int mode; - Oid lo_oid; - - /* checks validity */ - if (!_check_cnx_obj(self)) { - return NULL; - } - - /* gets arguments */ - if (!PyArg_ParseTuple(args, "i", &mode)) { - PyErr_SetString(PyExc_TypeError, - "Method locreate() takes an integer argument"); - return NULL; - } - - /* creates large object */ - lo_oid = lo_creat(self->cnx, mode); - if (lo_oid == 0) { - set_error_msg(OperationalError, "Can't create large object"); - return NULL; - } - - return (PyObject *) large_new(self, lo_oid); -} - -/* Init from already known oid. */ -static char conn_getlo__doc__[] = -"getlo(oid) -- create a large object instance for the specified oid"; - -static PyObject * -conn_getlo(connObject *self, PyObject *args) -{ - int oid; - Oid lo_oid; - - /* checks validity */ - if (!_check_cnx_obj(self)) { - return NULL; - } - - /* gets arguments */ - if (!PyArg_ParseTuple(args, "i", &oid)) { - PyErr_SetString(PyExc_TypeError, - "Method getlo() takes an integer argument"); - return NULL; - } - - lo_oid = (Oid) oid; - if (lo_oid == 0) { - PyErr_SetString(PyExc_ValueError, "The object oid can't be null"); - return NULL; - } - - /* creates object */ - return (PyObject *) large_new(self, lo_oid); -} - -/* Import unix file. */ -static char conn_loimport__doc__[] = -"loimport(name) -- create a new large object from specified file"; - -static PyObject * -conn_loimport(connObject *self, PyObject *args) -{ - char *name; - Oid lo_oid; - - /* checks validity */ - if (!_check_cnx_obj(self)) { - return NULL; - } - - /* gets arguments */ - if (!PyArg_ParseTuple(args, "s", &name)) { - PyErr_SetString(PyExc_TypeError, - "Method loimport() takes a string argument"); - return NULL; - } - - /* imports file and checks result */ - lo_oid = lo_import(self->cnx, name); - if (lo_oid == 0) { - set_error_msg(OperationalError, "Can't create large object"); - return NULL; - } - - return (PyObject *) large_new(self, lo_oid); -} - -#endif /* LARGE_OBJECTS */ - -/* Reset connection. */ -static char conn_reset__doc__[] = -"reset() -- reset connection with current parameters\n\n" -"All derived queries and large objects derived from this connection\n" -"will not be usable after this call.\n"; - -static PyObject * -conn_reset(connObject *self, PyObject *noargs) -{ - if (!self->cnx) { - PyErr_SetString(PyExc_TypeError, "Connection is not valid"); - return NULL; - } - - /* resets the connection */ - PQreset(self->cnx); - Py_INCREF(Py_None); - return Py_None; -} - -/* Cancel current command. */ -static char conn_cancel__doc__[] = -"cancel() -- abandon processing of the current command"; - -static PyObject * -conn_cancel(connObject *self, PyObject *noargs) -{ - if (!self->cnx) { - PyErr_SetString(PyExc_TypeError, "Connection is not valid"); - return NULL; - } - - /* request that the server abandon processing of the current command */ - return PyInt_FromLong((long) PQrequestCancel(self->cnx)); -} - -/* Get connection socket. */ -static char conn_fileno__doc__[] = -"fileno() -- return database connection socket file handle"; - -static PyObject * -conn_fileno(connObject *self, PyObject *noargs) -{ - if (!self->cnx) { - PyErr_SetString(PyExc_TypeError, "Connection is not valid"); - return NULL; - } - - return PyInt_FromLong((long) PQsocket(self->cnx)); -} - -/* Set external typecast callback function. 
*/ -static char conn_set_cast_hook__doc__[] = -"set_cast_hook(func) -- set a fallback typecast function"; - -static PyObject * -conn_set_cast_hook(connObject *self, PyObject *func) -{ - PyObject *ret = NULL; - - if (func == Py_None) { - Py_XDECREF(self->cast_hook); - self->cast_hook = NULL; - Py_INCREF(Py_None); ret = Py_None; - } - else if (PyCallable_Check(func)) { - Py_XINCREF(func); Py_XDECREF(self->cast_hook); - self->cast_hook = func; - Py_INCREF(Py_None); ret = Py_None; - } - else { - PyErr_SetString(PyExc_TypeError, - "Method set_cast_hook() expects" - " a callable or None as argument"); - } - - return ret; -} - -/* Get notice receiver callback function. */ -static char conn_get_cast_hook__doc__[] = -"get_cast_hook() -- get the fallback typecast function"; - -static PyObject * -conn_get_cast_hook(connObject *self, PyObject *noargs) -{ - PyObject *ret = self->cast_hook;; - - if (!ret) - ret = Py_None; - Py_INCREF(ret); - - return ret; -} - -/* Set notice receiver callback function. */ -static char conn_set_notice_receiver__doc__[] = -"set_notice_receiver(func) -- set the current notice receiver"; - -static PyObject * -conn_set_notice_receiver(connObject *self, PyObject *func) -{ - PyObject *ret = NULL; - - if (func == Py_None) { - Py_XDECREF(self->notice_receiver); - self->notice_receiver = NULL; - Py_INCREF(Py_None); ret = Py_None; - } - else if (PyCallable_Check(func)) { - Py_XINCREF(func); Py_XDECREF(self->notice_receiver); - self->notice_receiver = func; - PQsetNoticeReceiver(self->cnx, notice_receiver, self); - Py_INCREF(Py_None); ret = Py_None; - } - else { - PyErr_SetString(PyExc_TypeError, - "Method set_notice_receiver() expects" - " a callable or None as argument"); - } - - return ret; -} - -/* Get notice receiver callback function. */ -static char conn_get_notice_receiver__doc__[] = -"get_notice_receiver() -- get the current notice receiver"; - -static PyObject * -conn_get_notice_receiver(connObject *self, PyObject *noargs) -{ - PyObject *ret = self->notice_receiver; - - if (!ret) - ret = Py_None; - Py_INCREF(ret); - - return ret; -} - -/* Close without deleting. */ -static char conn_close__doc__[] = -"close() -- close connection\n\n" -"All instances of the connection object and derived objects\n" -"(queries and large objects) can no longer be used after this call.\n"; - -static PyObject * -conn_close(connObject *self, PyObject *noargs) -{ - /* connection object cannot already be closed */ - if (!self->cnx) { - set_error_msg(InternalError, "Connection already closed"); - return NULL; - } - - Py_BEGIN_ALLOW_THREADS - PQfinish(self->cnx); - Py_END_ALLOW_THREADS - - self->cnx = NULL; - Py_INCREF(Py_None); - return Py_None; -} - -/* Get asynchronous notify. 
*/ -static char conn_get_notify__doc__[] = -"getnotify() -- get database notify for this connection"; - -static PyObject * -conn_get_notify(connObject *self, PyObject *noargs) -{ - PGnotify *notify; - - if (!self->cnx) { - PyErr_SetString(PyExc_TypeError, "Connection is not valid"); - return NULL; - } - - /* checks for NOTIFY messages */ - PQconsumeInput(self->cnx); - - if (!(notify = PQnotifies(self->cnx))) { - Py_INCREF(Py_None); - return Py_None; - } - else { - PyObject *notify_result, *tmp; - - if (!(tmp = PyStr_FromString(notify->relname))) { - return NULL; - } - - if (!(notify_result = PyTuple_New(3))) { - return NULL; - } - - PyTuple_SET_ITEM(notify_result, 0, tmp); - - if (!(tmp = PyInt_FromLong(notify->be_pid))) { - Py_DECREF(notify_result); - return NULL; - } - - PyTuple_SET_ITEM(notify_result, 1, tmp); - - /* extra exists even in old versions that did not support it */ - if (!(tmp = PyStr_FromString(notify->extra))) { - Py_DECREF(notify_result); - return NULL; - } - - PyTuple_SET_ITEM(notify_result, 2, tmp); - - PQfreemem(notify); - - return notify_result; - } -} - -/* Get the list of connection attributes. */ -static PyObject * -conn_dir(connObject *self, PyObject *noargs) -{ - PyObject *attrs; - - attrs = PyObject_Dir(PyObject_Type((PyObject *) self)); - PyObject_CallMethod( - attrs, "extend", "[sssssssssssss]", - "host", "port", "db", "options", "error", "status", "user", - "protocol_version", "server_version", "socket", "backend_pid", - "ssl_in_use", "ssl_attributes"); - - return attrs; -} - -/* Connection object methods */ -static struct PyMethodDef conn_methods[] = { - {"__dir__", (PyCFunction) conn_dir, METH_NOARGS, NULL}, - - {"source", (PyCFunction) conn_source, - METH_NOARGS, conn_source__doc__}, - {"query", (PyCFunction) conn_query, - METH_VARARGS, conn_query__doc__}, - {"query_prepared", (PyCFunction) conn_query_prepared, - METH_VARARGS, conn_query_prepared__doc__}, - {"prepare", (PyCFunction) conn_prepare, - METH_VARARGS, conn_prepare__doc__}, - {"describe_prepared", (PyCFunction) conn_describe_prepared, - METH_VARARGS, conn_describe_prepared__doc__}, - {"reset", (PyCFunction) conn_reset, - METH_NOARGS, conn_reset__doc__}, - {"cancel", (PyCFunction) conn_cancel, - METH_NOARGS, conn_cancel__doc__}, - {"close", (PyCFunction) conn_close, - METH_NOARGS, conn_close__doc__}, - {"fileno", (PyCFunction) conn_fileno, - METH_NOARGS, conn_fileno__doc__}, - {"get_cast_hook", (PyCFunction) conn_get_cast_hook, - METH_NOARGS, conn_get_cast_hook__doc__}, - {"set_cast_hook", (PyCFunction) conn_set_cast_hook, - METH_O, conn_set_cast_hook__doc__}, - {"get_notice_receiver", (PyCFunction) conn_get_notice_receiver, - METH_NOARGS, conn_get_notice_receiver__doc__}, - {"set_notice_receiver", (PyCFunction) conn_set_notice_receiver, - METH_O, conn_set_notice_receiver__doc__}, - {"getnotify", (PyCFunction) conn_get_notify, - METH_NOARGS, conn_get_notify__doc__}, - {"inserttable", (PyCFunction) conn_inserttable, - METH_VARARGS, conn_inserttable__doc__}, - {"transaction", (PyCFunction) conn_transaction, - METH_NOARGS, conn_transaction__doc__}, - {"parameter", (PyCFunction) conn_parameter, - METH_VARARGS, conn_parameter__doc__}, - {"date_format", (PyCFunction) conn_date_format, - METH_NOARGS, conn_date_format__doc__}, - -#ifdef ESCAPING_FUNCS - {"escape_literal", (PyCFunction) conn_escape_literal, - METH_O, conn_escape_literal__doc__}, - {"escape_identifier", (PyCFunction) conn_escape_identifier, - METH_O, conn_escape_identifier__doc__}, -#endif /* ESCAPING_FUNCS */ - {"escape_string", 
(PyCFunction) conn_escape_string, - METH_O, conn_escape_string__doc__}, - {"escape_bytea", (PyCFunction) conn_escape_bytea, - METH_O, conn_escape_bytea__doc__}, - -#ifdef DIRECT_ACCESS - {"putline", (PyCFunction) conn_putline, - METH_VARARGS, conn_putline__doc__}, - {"getline", (PyCFunction) conn_getline, - METH_NOARGS, conn_getline__doc__}, - {"endcopy", (PyCFunction) conn_endcopy, - METH_NOARGS, conn_endcopy__doc__}, -#endif /* DIRECT_ACCESS */ - -#ifdef LARGE_OBJECTS - {"locreate", (PyCFunction) conn_locreate, - METH_VARARGS, conn_locreate__doc__}, - {"getlo", (PyCFunction) conn_getlo, - METH_VARARGS, conn_getlo__doc__}, - {"loimport", (PyCFunction) conn_loimport, - METH_VARARGS, conn_loimport__doc__}, -#endif /* LARGE_OBJECTS */ - - {NULL, NULL} /* sentinel */ -}; - -static char conn__doc__[] = "PostgreSQL connection object"; - -/* Connection type definition */ -static PyTypeObject connType = { - PyVarObject_HEAD_INIT(NULL, 0) - "pg.Connection", /* tp_name */ - sizeof(connObject), /* tp_basicsize */ - 0, /* tp_itemsize */ - (destructor) conn_dealloc, /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_reserved */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - (getattrofunc) conn_getattr, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - conn__doc__, /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - conn_methods, /* tp_methods */ -}; diff --git a/pgdb.py b/pgdb.py deleted file mode 100644 index 021ba444..00000000 --- a/pgdb.py +++ /dev/null @@ -1,1879 +0,0 @@ -#!/usr/bin/python -# -# PyGreSQL - a Python interface for the PostgreSQL database. -# -# This file contains the DB-API 2 compatible pgdb module. -# -# Copyright (c) 2020 by the PyGreSQL Development Team -# -# Please see the LICENSE.TXT file for specific restrictions. - -"""pgdb - DB-API 2.0 compliant module for PyGreSQL. - -(c) 1999, Pascal Andre . -See package documentation for further information on copyright. - -Inline documentation is sparse. -See DB-API 2.0 specification for usage information: -http://www.python.org/peps/pep-0249.html - -Basic usage: - - pgdb.connect(connect_string) # open a connection - # connect_string = 'host:database:user:password:opt' - # All parts are optional. You may also pass host through - # password as keyword arguments. To pass a port, - # pass it in the host keyword parameter: - connection = pgdb.connect(host='localhost:5432') - - cursor = connection.cursor() # open a cursor - - cursor.execute(query[, params]) - # Execute a query, binding params (a dictionary) if they are - # passed. The binding syntax is the same as the % operator - # for dictionaries, and no quoting is done. - - cursor.executemany(query, list of params) - # Execute a query many times, binding each param dictionary - # from the list. - - cursor.fetchone() # fetch one row, [value, value, ...] - - cursor.fetchall() # fetch all rows, [[value, value, ...], ...] - - cursor.fetchmany([size]) - # returns size or cursor.arraysize number of rows, - # [[value, value, ...], ...] from result set. - # Default cursor.arraysize is 1. - - cursor.description # returns information about the columns - # [(column_name, type_name, display_size, - # internal_size, precision, scale, null_ok), ...] 
- # Note that display_size, precision, scale and null_ok - # are not implemented. - - cursor.rowcount # number of rows available in the result set - # Available after a call to execute. - - connection.commit() # commit transaction - - connection.rollback() # or rollback transaction - - cursor.close() # close the cursor - - connection.close() # close the connection -""" - -from __future__ import print_function, division - -try: - from _pg import * -except ImportError: - import os - import sys - # see https://docs.python.org/3/whatsnew/3.8.html#ctypes - if os.name == 'nt' and sys.version_info >= (3, 8): - for path in os.environ["PATH"].split(os.pathsep): - if os.path.exists(os.path.join(path, 'libpq.dll')): - with os.add_dll_directory(os.path.abspath(path)): - from _pg import * - break - else: - raise - else: - raise - -__version__ = version - -__all__ = [ - 'Connection', 'Cursor', - 'Date', 'Time', 'Timestamp', - 'DateFromTicks', 'TimeFromTicks', 'TimestampFromTicks', - 'Binary', 'Interval', 'Uuid', - 'Hstore', 'Json', 'Literal', 'Type', - 'STRING', 'BINARY', 'NUMBER', 'DATETIME', 'ROWID', 'BOOL', - 'SMALLINT', 'INTEGER', 'LONG', 'FLOAT', 'NUMERIC', 'MONEY', - 'DATE', 'TIME', 'TIMESTAMP', 'INTERVAL', - 'UUID', 'HSTORE', 'JSON', 'ARRAY', 'RECORD', - 'Error', 'Warning', - 'InterfaceError', 'DatabaseError', 'DataError', 'OperationalError', - 'IntegrityError', 'ProgrammingError', 'NotSupportedError', - 'apilevel', 'connect', 'paramstyle', 'threadsafety', - 'get_typecast', 'set_typecast', 'reset_typecast', - 'version', '__version__'] - -from datetime import date, time, datetime, timedelta, tzinfo -from time import localtime -from decimal import Decimal -from uuid import UUID as Uuid -from math import isnan, isinf -try: - from collections.abc import Iterable -except ImportError: # Python < 3.3 - from collections import Iterable -from collections import namedtuple -from keyword import iskeyword -from functools import partial -from re import compile as regex -from json import loads as jsondecode, dumps as jsonencode - -try: # noinspection PyUnresolvedReferences - long -except NameError: # Python >= 3.0 - long = int - -try: # noinspection PyUnresolvedReferences - unicode -except NameError: # Python >= 3.0 - unicode = str - -try: # noinspection PyUnresolvedReferences - basestring -except NameError: # Python >= 3.0 - basestring = (str, bytes) - -try: - from functools import lru_cache -except ImportError: # Python < 3.2 - from functools import update_wrapper - try: - from _thread import RLock - except ImportError: - class RLock: # for builds without threads - def __enter__(self): pass - - def __exit__(self, exctype, excinst, exctb): pass - - def lru_cache(maxsize=128): - """Simplified functools.lru_cache decorator for one argument.""" - - def decorator(function): - sentinel = object() - cache = {} - get = cache.get - lock = RLock() - root = [] - root_full = [root, False] - root[:] = [root, root, None, None] - - if maxsize == 0: - - def wrapper(arg): - res = function(arg) - return res - - elif maxsize is None: - - def wrapper(arg): - res = get(arg, sentinel) - if res is not sentinel: - return res - res = function(arg) - cache[arg] = res - return res - - else: - - def wrapper(arg): - with lock: - link = get(arg) - if link is not None: - root = root_full[0] - prev, next, _arg, res = link - prev[1] = next - next[0] = prev - last = root[0] - last[1] = root[0] = link - link[0] = last - link[1] = root - return res - res = function(arg) - with lock: - root, full = root_full - if arg in cache: - pass - elif 
full: - oldroot = root - oldroot[2] = arg - oldroot[3] = res - root = root_full[0] = oldroot[1] - oldarg = root[2] - oldres = root[3] # keep reference - root[2] = root[3] = None - del cache[oldarg] - cache[arg] = oldroot - else: - last = root[0] - link = [last, root, arg, res] - last[1] = root[0] = cache[arg] = link - if len(cache) >= maxsize: - root_full[1] = True - return res - - wrapper.__wrapped__ = function - return update_wrapper(wrapper, function) - - return decorator - - -### Module Constants - -# compliant with DB API 2.0 -apilevel = '2.0' - -# module may be shared, but not connections -threadsafety = 1 - -# this module use extended python format codes -paramstyle = 'pyformat' - -# shortcut methods have been excluded from DB API 2 and -# are not recommended by the DB SIG, but they can be handy -shortcutmethods = 1 - - -### Internal Type Handling - -try: - from inspect import signature -except ImportError: # Python < 3.3 - from inspect import getargspec - - def get_args(func): - return getargspec(func).args -else: - - def get_args(func): - return list(signature(func).parameters) - -try: - from datetime import timezone -except ImportError: # Python < 3.2 - - class timezone(tzinfo): - """Simple timezone implementation.""" - - def __init__(self, offset, name=None): - self.offset = offset - if not name: - minutes = self.offset.days * 1440 + self.offset.seconds // 60 - if minutes < 0: - hours, minutes = divmod(-minutes, 60) - hours = -hours - else: - hours, minutes = divmod(minutes, 60) - name = 'UTC%+03d:%02d' % (hours, minutes) - self.name = name - - def utcoffset(self, dt): - return self.offset - - def tzname(self, dt): - return self.name - - def dst(self, dt): - return None - - timezone.utc = timezone(timedelta(0), 'UTC') - - _has_timezone = False -else: - _has_timezone = True - -# time zones used in Postgres timestamptz output -_timezones = dict(CET='+0100', EET='+0200', EST='-0500', - GMT='+0000', HST='-1000', MET='+0100', MST='-0700', - UCT='+0000', UTC='+0000', WET='+0000') - - -def _timezone_as_offset(tz): - if tz.startswith(('+', '-')): - if len(tz) < 5: - return tz + '00' - return tz.replace(':', '') - return _timezones.get(tz, '+0000') - - -def _get_timezone(tz): - tz = _timezone_as_offset(tz) - minutes = 60 * int(tz[1:3]) + int(tz[3:5]) - if tz[0] == '-': - minutes = -minutes - return timezone(timedelta(minutes=minutes), tz) - - -def decimal_type(decimal_type=None): - """Get or set global type to be used for decimal values. - - Note that connections cache cast functions. To be sure a global change - is picked up by a running connection, call con.type_cache.reset_typecast(). - """ - global Decimal - if decimal_type is not None: - Decimal = decimal_type - set_typecast('numeric', decimal_type) - return Decimal - - -def cast_bool(value): - """Cast boolean value in database format to bool.""" - if value: - return value[0] in ('t', 'T') - - -def cast_money(value): - """Cast money value in database format to Decimal.""" - if value: - value = value.replace('(', '-') - return Decimal(''.join(c for c in value if c.isdigit() or c in '.-')) - - -def cast_int2vector(value): - """Cast an int2vector value.""" - return [int(v) for v in value.split()] - - -def cast_date(value, connection): - """Cast a date value.""" - # The output format depends on the server setting DateStyle. The default - # setting ISO and the setting for German are actually unambiguous. 
The - # order of days and months in the other two settings is however ambiguous, - # so at least here we need to consult the setting to properly parse values. - if value == '-infinity': - return date.min - if value == 'infinity': - return date.max - value = value.split() - if value[-1] == 'BC': - return date.min - value = value[0] - if len(value) > 10: - return date.max - fmt = connection.date_format() - return datetime.strptime(value, fmt).date() - - -def cast_time(value): - """Cast a time value.""" - fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' - return datetime.strptime(value, fmt).time() - - -_re_timezone = regex('(.*)([+-].*)') - - -def cast_timetz(value): - """Cast a timetz value.""" - tz = _re_timezone.match(value) - if tz: - value, tz = tz.groups() - else: - tz = '+0000' - fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' - if _has_timezone: - value += _timezone_as_offset(tz) - fmt += '%z' - return datetime.strptime(value, fmt).timetz() - return datetime.strptime(value, fmt).timetz().replace( - tzinfo=_get_timezone(tz)) - - -def cast_timestamp(value, connection): - """Cast a timestamp value.""" - if value == '-infinity': - return datetime.min - if value == 'infinity': - return datetime.max - value = value.split() - if value[-1] == 'BC': - return datetime.min - fmt = connection.date_format() - if fmt.endswith('-%Y') and len(value) > 2: - value = value[1:5] - if len(value[3]) > 4: - return datetime.max - fmt = ['%d %b' if fmt.startswith('%d') else '%b %d', - '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y'] - else: - if len(value[0]) > 10: - return datetime.max - fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S'] - return datetime.strptime(' '.join(value), ' '.join(fmt)) - - -def cast_timestamptz(value, connection): - """Cast a timestamptz value.""" - if value == '-infinity': - return datetime.min - if value == 'infinity': - return datetime.max - value = value.split() - if value[-1] == 'BC': - return datetime.min - fmt = connection.date_format() - if fmt.endswith('-%Y') and len(value) > 2: - value = value[1:] - if len(value[3]) > 4: - return datetime.max - fmt = ['%d %b' if fmt.startswith('%d') else '%b %d', - '%H:%M:%S.%f' if len(value[2]) > 8 else '%H:%M:%S', '%Y'] - value, tz = value[:-1], value[-1] - else: - if fmt.startswith('%Y-'): - tz = _re_timezone.match(value[1]) - if tz: - value[1], tz = tz.groups() - else: - tz = '+0000' - else: - value, tz = value[:-1], value[-1] - if len(value[0]) > 10: - return datetime.max - fmt = [fmt, '%H:%M:%S.%f' if len(value[1]) > 8 else '%H:%M:%S'] - if _has_timezone: - value.append(_timezone_as_offset(tz)) - fmt.append('%z') - return datetime.strptime(' '.join(value), ' '.join(fmt)) - return datetime.strptime(' '.join(value), ' '.join(fmt)).replace( - tzinfo=_get_timezone(tz)) - - -_re_interval_sql_standard = regex( - '(?:([+-])?([0-9]+)-([0-9]+) ?)?' - '(?:([+-]?[0-9]+)(?!:) ?)?' - '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?') - -_re_interval_postgres = regex( - '(?:([+-]?[0-9]+) ?years? ?)?' - '(?:([+-]?[0-9]+) ?mons? ?)?' - '(?:([+-]?[0-9]+) ?days? ?)?' - '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?') - -_re_interval_postgres_verbose = regex( - '@ ?(?:([+-]?[0-9]+) ?years? ?)?' - '(?:([+-]?[0-9]+) ?mons? ?)?' - '(?:([+-]?[0-9]+) ?days? ?)?' - '(?:([+-]?[0-9]+) ?hours? ?)?' - '(?:([+-]?[0-9]+) ?mins? ?)?' - '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?') - -_re_interval_iso_8601 = regex( - 'P(?:([+-]?[0-9]+)Y)?' - '(?:([+-]?[0-9]+)M)?' - '(?:([+-]?[0-9]+)D)?' 
- '(?:T(?:([+-]?[0-9]+)H)?' - '(?:([+-]?[0-9]+)M)?' - '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?') - - -def cast_interval(value): - """Cast an interval value.""" - # The output format depends on the server setting IntervalStyle, but it's - # not necessary to consult this setting to parse it. It's faster to just - # check all possible formats, and there is no ambiguity here. - m = _re_interval_iso_8601.match(value) - if m: - m = [d or '0' for d in m.groups()] - secs_ago = m.pop(5) == '-' - m = [int(d) for d in m] - years, mons, days, hours, mins, secs, usecs = m - if secs_ago: - secs = -secs - usecs = -usecs - else: - m = _re_interval_postgres_verbose.match(value) - if m: - m, ago = [d or '0' for d in m.groups()[:8]], m.group(9) - secs_ago = m.pop(5) == '-' - m = [-int(d) for d in m] if ago else [int(d) for d in m] - years, mons, days, hours, mins, secs, usecs = m - if secs_ago: - secs = - secs - usecs = -usecs - else: - m = _re_interval_postgres.match(value) - if m and any(m.groups()): - m = [d or '0' for d in m.groups()] - hours_ago = m.pop(3) == '-' - m = [int(d) for d in m] - years, mons, days, hours, mins, secs, usecs = m - if hours_ago: - hours = -hours - mins = -mins - secs = -secs - usecs = -usecs - else: - m = _re_interval_sql_standard.match(value) - if m and any(m.groups()): - m = [d or '0' for d in m.groups()] - years_ago = m.pop(0) == '-' - hours_ago = m.pop(3) == '-' - m = [int(d) for d in m] - years, mons, days, hours, mins, secs, usecs = m - if years_ago: - years = -years - mons = -mons - if hours_ago: - hours = -hours - mins = -mins - secs = -secs - usecs = -usecs - else: - raise ValueError('Cannot parse interval: %s' % value) - days += 365 * years + 30 * mons - return timedelta(days=days, hours=hours, minutes=mins, - seconds=secs, microseconds=usecs) - - -class Typecasts(dict): - """Dictionary mapping database types to typecast functions. - - The cast functions get passed the string representation of a value in - the database which they need to convert to a Python object. The - passed string will never be None since NULL values are already - handled before the cast function is called. - """ - - # the default cast functions - # (str functions are ignored but have been added for faster access) - defaults = {'char': str, 'bpchar': str, 'name': str, - 'text': str, 'varchar': str, - 'bool': cast_bool, 'bytea': unescape_bytea, - 'int2': int, 'int4': int, 'serial': int, 'int8': long, 'oid': int, - 'hstore': cast_hstore, 'json': jsondecode, 'jsonb': jsondecode, - 'float4': float, 'float8': float, - 'numeric': Decimal, 'money': cast_money, - 'date': cast_date, 'interval': cast_interval, - 'time': cast_time, 'timetz': cast_timetz, - 'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz, - 'int2vector': cast_int2vector, 'uuid': Uuid, - 'anyarray': cast_array, 'record': cast_record} - - connection = None # will be set in local connection specific instances - - def __missing__(self, typ): - """Create a cast function if it is not cached. - - Note that this class never raises a KeyError, - but returns None when no special cast function exists. 
- """ - if not isinstance(typ, str): - raise TypeError('Invalid type: %s' % typ) - cast = self.defaults.get(typ) - if cast: - # store default for faster access - cast = self._add_connection(cast) - self[typ] = cast - elif typ.startswith('_'): - # create array cast - base_cast = self[typ[1:]] - cast = self.create_array_cast(base_cast) - if base_cast: - # store only if base type exists - self[typ] = cast - return cast - - @staticmethod - def _needs_connection(func): - """Check if a typecast function needs a connection argument.""" - try: - args = get_args(func) - except (TypeError, ValueError): - return False - else: - return 'connection' in args[1:] - - def _add_connection(self, cast): - """Add a connection argument to the typecast function if necessary.""" - if not self.connection or not self._needs_connection(cast): - return cast - return partial(cast, connection=self.connection) - - def get(self, typ, default=None): - """Get the typecast function for the given database type.""" - return self[typ] or default - - def set(self, typ, cast): - """Set a typecast function for the specified database type(s).""" - if isinstance(typ, basestring): - typ = [typ] - if cast is None: - for t in typ: - self.pop(t, None) - self.pop('_%s' % t, None) - else: - if not callable(cast): - raise TypeError("Cast parameter must be callable") - for t in typ: - self[t] = self._add_connection(cast) - self.pop('_%s' % t, None) - - def reset(self, typ=None): - """Reset the typecasts for the specified type(s) to their defaults. - - When no type is specified, all typecasts will be reset. - """ - defaults = self.defaults - if typ is None: - self.clear() - self.update(defaults) - else: - if isinstance(typ, basestring): - typ = [typ] - for t in typ: - cast = defaults.get(t) - if cast: - self[t] = self._add_connection(cast) - t = '_%s' % t - cast = defaults.get(t) - if cast: - self[t] = self._add_connection(cast) - else: - self.pop(t, None) - else: - self.pop(t, None) - self.pop('_%s' % t, None) - - def create_array_cast(self, basecast): - """Create an array typecast for the given base cast.""" - cast_array = self['anyarray'] - def cast(v): - return cast_array(v, basecast) - return cast - - def create_record_cast(self, name, fields, casts): - """Create a named record typecast for the given fields and casts.""" - cast_record = self['record'] - record = namedtuple(name, fields) - def cast(v): - return record(*cast_record(v, casts)) - return cast - - -_typecasts = Typecasts() # this is the global typecast dictionary - - -def get_typecast(typ): - """Get the global typecast function for the given database type(s).""" - return _typecasts.get(typ) - - -def set_typecast(typ, cast): - """Set a global typecast function for the given database type(s). - - Note that connections cache cast functions. To be sure a global change - is picked up by a running connection, call con.type_cache.reset_typecast(). - """ - _typecasts.set(typ, cast) - - -def reset_typecast(typ=None): - """Reset the global typecasts for the given type(s) to their default. - - When no type is specified, all typecasts will be reset. - - Note that connections cache cast functions. To be sure a global change - is picked up by a running connection, call con.type_cache.reset_typecast(). 
- """ - _typecasts.reset(typ) - - -class LocalTypecasts(Typecasts): - """Map typecasts, including local composite types, to cast functions.""" - - defaults = _typecasts - - connection = None # will be set in a connection specific instance - - def __missing__(self, typ): - """Create a cast function if it is not cached.""" - if typ.startswith('_'): - base_cast = self[typ[1:]] - cast = self.create_array_cast(base_cast) - if base_cast: - self[typ] = cast - else: - cast = self.defaults.get(typ) - if cast: - cast = self._add_connection(cast) - self[typ] = cast - else: - fields = self.get_fields(typ) - if fields: - casts = [self[field.type] for field in fields] - fields = [field.name for field in fields] - cast = self.create_record_cast(typ, fields, casts) - self[typ] = cast - return cast - - def get_fields(self, typ): - """Return the fields for the given record type. - - This method will be replaced with a method that looks up the fields - using the type cache of the connection. - """ - return [] - - -class TypeCode(str): - """Class representing the type_code used by the DB-API 2.0. - - TypeCode objects are strings equal to the PostgreSQL type name, - but carry some additional information. - """ - - @classmethod - def create(cls, oid, name, len, type, category, delim, relid): - """Create a type code for a PostgreSQL data type.""" - self = cls(name) - self.oid = oid - self.len = len - self.type = type - self.category = category - self.delim = delim - self.relid = relid - return self - -FieldInfo = namedtuple('FieldInfo', ['name', 'type']) - - -class TypeCache(dict): - """Cache for database types. - - This cache maps type OIDs and names to TypeCode strings containing - important information on the associated database type. - """ - - def __init__(self, cnx): - """Initialize type cache for connection.""" - super(TypeCache, self).__init__() - self._escape_string = cnx.escape_string - self._src = cnx.source() - self._typecasts = LocalTypecasts() - self._typecasts.get_fields = self.get_fields - self._typecasts.connection = cnx - if cnx.server_version < 80400: - # older remote databases (not officially supported) - self._query_pg_type = ("SELECT oid, typname," - " typlen, typtype, null as typcategory, typdelim, typrelid" - " FROM pg_catalog.pg_type WHERE oid OPERATOR(pg_catalog.=) %s") - else: - self._query_pg_type = ("SELECT oid, typname," - " typlen, typtype, typcategory, typdelim, typrelid" - " FROM pg_catalog.pg_type WHERE oid OPERATOR(pg_catalog.=) %s") - - def __missing__(self, key): - """Get the type info from the database if it is not cached.""" - if isinstance(key, int): - oid = key - else: - if '.' 
not in key and '"' not in key: - key = '"%s"' % (key,) - oid = "'%s'::regtype" % (self._escape_string(key),) - try: - self._src.execute(self._query_pg_type % (oid,)) - except ProgrammingError: - res = None - else: - res = self._src.fetch(1) - if not res: - raise KeyError('Type %s could not be found' % (key,)) - res = res[0] - type_code = TypeCode.create(int(res[0]), res[1], - int(res[2]), res[3], res[4], res[5], int(res[6])) - self[type_code.oid] = self[str(type_code)] = type_code - return type_code - - def get(self, key, default=None): - """Get the type even if it is not cached.""" - try: - return self[key] - except KeyError: - return default - - def get_fields(self, typ): - """Get the names and types of the fields of composite types.""" - if not isinstance(typ, TypeCode): - typ = self.get(typ) - if not typ: - return None - if not typ.relid: - return None # this type is not composite - self._src.execute("SELECT attname, atttypid" - " FROM pg_catalog.pg_attribute" - " WHERE attrelid OPERATOR(pg_catalog.=) %s" - " AND attnum OPERATOR(pg_catalog.>) 0" - " AND NOT attisdropped ORDER BY attnum" % (typ.relid,)) - return [FieldInfo(name, self.get(int(oid))) - for name, oid in self._src.fetch(-1)] - - def get_typecast(self, typ): - """Get the typecast function for the given database type.""" - return self._typecasts.get(typ) - - def set_typecast(self, typ, cast): - """Set a typecast function for the specified database type(s).""" - self._typecasts.set(typ, cast) - - def reset_typecast(self, typ=None): - """Reset the typecast function for the specified database type(s).""" - self._typecasts.reset(typ) - - def typecast(self, value, typ): - """Cast the given value according to the given database type.""" - if value is None: - # for NULL values, no typecast is necessary - return None - cast = self.get_typecast(typ) - if not cast or cast is str: - # no typecast is necessary - return value - return cast(value) - - -class _quotedict(dict): - """Dictionary with auto quoting of its items. - - The quote attribute must be set to the desired quote function. - """ - - def __getitem__(self, key): - return self.quote(super(_quotedict, self).__getitem__(key)) - - -### Error Messages - -def _db_error(msg, cls=DatabaseError): - """Return DatabaseError with empty sqlstate attribute.""" - error = cls(msg) - error.sqlstate = None - return error - - -def _op_error(msg): - """Return OperationalError.""" - return _db_error(msg, OperationalError) - - -### Row Tuples - -_re_fieldname = regex('^[A-Za-z][_a-zA-Z0-9]*$') - -# The result rows for database operations are returned as named tuples -# by default. Since creating namedtuple classes is a somewhat expensive -# operation, we cache up to 1024 of these classes by default. - -@lru_cache(maxsize=1024) -def _row_factory(names): - """Get a namedtuple factory for row results with the given names.""" - try: - try: - return namedtuple('Row', names, rename=True)._make - except TypeError: # Python 2.6 and 3.0 do not support rename - names = [v if _re_fieldname.match(v) and not iskeyword(v) - else 'column_%d' % (n,) - for n, v in enumerate(names)] - return namedtuple('Row', names)._make - except ValueError: # there is still a problem with the field names - names = ['column_%d' % (n,) for n in range(len(names))] - return namedtuple('Row', names)._make - - -def set_row_factory_size(maxsize): - """Change the size of the namedtuple factory cache. - - If maxsize is set to None, the cache can grow without bound. 
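How the cached factory behaves, as a sketch: names that are not valid identifiers (or are duplicates) get renamed, so rows remain constructible:

    make = _row_factory(('id', 'from'))   # 'from' is a keyword and gets renamed
    row = make((1, 'somewhere'))
    assert row.id == 1 and row[1] == 'somewhere'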
- """ - global _row_factory - _row_factory = lru_cache(maxsize)(_row_factory.__wrapped__) - - -### Cursor Object - -class Cursor(object): - """Cursor object.""" - - def __init__(self, dbcnx): - """Create a cursor object for the database connection.""" - self.connection = self._dbcnx = dbcnx - self._cnx = dbcnx._cnx - self.type_cache = dbcnx.type_cache - self._src = self._cnx.source() - # the official attribute for describing the result columns - self._description = None - if self.row_factory is Cursor.row_factory: - # the row factory needs to be determined dynamically - self.row_factory = None - else: - self.build_row_factory = None - self.rowcount = -1 - self.arraysize = 1 - self.lastrowid = None - - def __iter__(self): - """Make cursor compatible to the iteration protocol.""" - return self - - def __enter__(self): - """Enter the runtime context for the cursor object.""" - return self - - def __exit__(self, et, ev, tb): - """Exit the runtime context for the cursor object.""" - self.close() - - def _quote(self, value): - """Quote value depending on its type.""" - if value is None: - return 'NULL' - if isinstance(value, (Hstore, Json)): - value = str(value) - if isinstance(value, basestring): - if isinstance(value, Binary): - value = self._cnx.escape_bytea(value) - if bytes is not str: # Python >= 3.0 - value = value.decode('ascii') - else: - value = self._cnx.escape_string(value) - return "'%s'" % (value,) - if isinstance(value, float): - if isinf(value): - return "'-Infinity'" if value < 0 else "'Infinity'" - if isnan(value): - return "'NaN'" - return value - if isinstance(value, (int, long, Decimal, Literal)): - return value - if isinstance(value, datetime): - if value.tzinfo: - return "'%s'::timestamptz" % (value,) - return "'%s'::timestamp" % (value,) - if isinstance(value, date): - return "'%s'::date" % (value,) - if isinstance(value, time): - if value.tzinfo: - return "'%s'::timetz" % (value,) - return "'%s'::time" % value - if isinstance(value, timedelta): - return "'%s'::interval" % (value,) - if isinstance(value, Uuid): - return "'%s'::uuid" % (value,) - if isinstance(value, list): - # Quote value as an ARRAY constructor. This is better than using - # an array literal because it carries the information that this is - # an array and not a string. One issue with this syntax is that - # you need to add an explicit typecast when passing empty arrays. - # The ARRAY keyword is actually only necessary at the top level. - if not value: # exception for empty array - return "'{}'" - q = self._quote - try: - return 'ARRAY[%s]' % (','.join(str(q(v)) for v in value),) - except UnicodeEncodeError: # Python 2 with non-ascii values - return u'ARRAY[%s]' % (','.join(unicode(q(v)) for v in value),) - if isinstance(value, tuple): - # Quote as a ROW constructor. This is better than using a record - # literal because it carries the information that this is a record - # and not a string. We don't use the keyword ROW in order to make - # this usable with the IN syntax as well. It is only necessary - # when the records has a single column which is not really useful. 
- q = self._quote - try: - return '(%s)' % (','.join(str(q(v)) for v in value),) - except UnicodeEncodeError: # Python 2 with non-ascii values - return u'(%s)' % (','.join(unicode(q(v)) for v in value),) - try: - value = value.__pg_repr__() - except AttributeError: - raise InterfaceError( - 'Do not know how to adapt type %s' % (type(value),)) - if isinstance(value, (tuple, list)): - value = self._quote(value) - return value - - def _quoteparams(self, string, parameters): - """Quote parameters. - - This function works for both mappings and sequences. - - The function should be used even when there are no parameters, - so that we have a consistent behavior regarding percent signs. - """ - if not parameters: - try: - return string % () # unescape literal quotes if possible - except (TypeError, ValueError): - return string # silently accept unescaped quotes - if isinstance(parameters, dict): - parameters = _quotedict(parameters) - parameters.quote = self._quote - else: - parameters = tuple(map(self._quote, parameters)) - return string % parameters - - def _make_description(self, info): - """Make the description tuple for the given field info.""" - name, typ, size, mod = info[1:] - type_code = self.type_cache[typ] - if mod > 0: - mod -= 4 - if type_code == 'numeric': - precision, scale = mod >> 16, mod & 0xffff - size = precision - else: - if not size: - size = type_code.size - if size == -1: - size = mod - precision = scale = None - return CursorDescription(name, type_code, - None, size, precision, scale, None) - - @property - def description(self): - """Read-only attribute describing the result columns.""" - descr = self._description - if self._description is True: - make = self._make_description - descr = [make(info) for info in self._src.listinfo()] - self._description = descr - return descr - - @property - def colnames(self): - """Unofficial convenience method for getting the column names.""" - return [d[0] for d in self.description] - - @property - def coltypes(self): - """Unofficial convenience method for getting the column types.""" - return [d[1] for d in self.description] - - def close(self): - """Close the cursor object.""" - self._src.close() - - def execute(self, operation, parameters=None): - """Prepare and execute a database operation (query or command).""" - # The parameters may also be specified as list of tuples to e.g. - # insert multiple rows in a single operation, but this kind of - # usage is deprecated. We make several plausibility checks because - # tuples can also be passed with the meaning of ROW constructors. 
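    # Binding works like the Python %-operator, e.g. (illustrative):
    #     cur.execute("select * from t where id=%(id)s", {'id': 42})
    #     cur.execute("select * from t where id=%s and name=%s", (42, 'x'))
    # A list of equally sized tuples is re-dispatched to executemany() below.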
- if (parameters and isinstance(parameters, list) - and len(parameters) > 1 - and all(isinstance(p, tuple) for p in parameters) - and all(len(p) == len(parameters[0]) for p in parameters[1:])): - return self.executemany(operation, parameters) - else: - # not a list of tuples - return self.executemany(operation, [parameters]) - - def executemany(self, operation, seq_of_parameters): - """Prepare operation and execute it against a parameter sequence.""" - if not seq_of_parameters: - # don't do anything without parameters - return - self._description = None - self.rowcount = -1 - # first try to execute all queries - rowcount = 0 - sql = "BEGIN" - try: - if not self._dbcnx._tnx and not self._dbcnx.autocommit: - try: - self._src.execute(sql) - except DatabaseError: - raise # database provides error message - except Exception: - raise _op_error("Can't start transaction") - else: - self._dbcnx._tnx = True - for parameters in seq_of_parameters: - sql = operation - sql = self._quoteparams(sql, parameters) - rows = self._src.execute(sql) - if rows: # true if not DML - rowcount += rows - else: - self.rowcount = -1 - except DatabaseError: - raise # database provides error message - except Error as err: - raise _db_error( - "Error in '%s': '%s' " % (sql, err), InterfaceError) - except Exception as err: - raise _op_error("Internal error in '%s': %s" % (sql, err)) - # then initialize result raw count and description - if self._src.resulttype == RESULT_DQL: - self._description = True # fetch on demand - self.rowcount = self._src.ntuples - self.lastrowid = None - if self.build_row_factory: - self.row_factory = self.build_row_factory() - else: - self.rowcount = rowcount - self.lastrowid = self._src.oidstatus() - # return the cursor object, so you can write statements such as - # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)" - return self - - def fetchone(self): - """Fetch the next row of a query result set.""" - res = self.fetchmany(1, False) - try: - return res[0] - except IndexError: - return None - - def fetchall(self): - """Fetch all (remaining) rows of a query result.""" - return self.fetchmany(-1, False) - - def fetchmany(self, size=None, keep=False): - """Fetch the next set of rows of a query result. - - The number of rows to fetch per call is specified by the - size parameter. If it is not given, the cursor's arraysize - determines the number of rows to be fetched. If you set - the keep parameter to true, this is kept as new arraysize. - """ - if size is None: - size = self.arraysize - if keep: - self.arraysize = size - try: - result = self._src.fetch(size) - except DatabaseError: - raise - except Error as err: - raise _db_error(str(err)) - typecast = self.type_cache.typecast - return [self.row_factory([typecast(value, typ) - for typ, value in zip(self.coltypes, row)]) for row in result] - - def callproc(self, procname, parameters=None): - """Call a stored database procedure with the given name. - - The sequence of parameters must contain one entry for each input - argument that the procedure expects. The result of the call is the - same as this input sequence; replacement of output and input/output - parameters in the return value is currently not supported. - - The procedure may also provide a result set as output. These can be - requested through the standard fetch methods of the cursor. 
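A minimal usage sketch, using the built-in lower() function so no setup is needed:

    cur.callproc('lower', ('ABC',))   # runs: select * from "lower"('ABC')
    assert cur.fetchone()[0] == 'abc'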
- """ - n = parameters and len(parameters) or 0 - query = 'select * from "%s"(%s)' % (procname, ','.join(n * ['%s'])) - self.execute(query, parameters) - return parameters - - def copy_from(self, stream, table, - format=None, sep=None, null=None, size=None, columns=None): - """Copy data from an input stream to the specified table. - - The input stream can be a file-like object with a read() method or - it can also be an iterable returning a row or multiple rows of input - on each iteration. - - The format must be text, csv or binary. The sep option sets the - column separator (delimiter) used in the non binary formats. - The null option sets the textual representation of NULL in the input. - - The size option sets the size of the buffer used when reading data - from file-like objects. - - The copy operation can be restricted to a subset of columns. If no - columns are specified, all of them will be copied. - """ - binary_format = format == 'binary' - try: - read = stream.read - except AttributeError: - if size: - raise ValueError("Size must only be set for file-like objects") - if binary_format: - input_type = bytes - type_name = 'byte strings' - else: - input_type = basestring - type_name = 'strings' - - if isinstance(stream, basestring): - if not isinstance(stream, input_type): - raise ValueError("The input must be %s" % (type_name,)) - if not binary_format: - if isinstance(stream, str): - if not stream.endswith('\n'): - stream += '\n' - else: - if not stream.endswith(b'\n'): - stream += b'\n' - - def chunks(): - yield stream - - elif isinstance(stream, Iterable): - - def chunks(): - for chunk in stream: - if not isinstance(chunk, input_type): - raise ValueError( - "Input stream must consist of %s" - % (type_name,)) - if isinstance(chunk, str): - if not chunk.endswith('\n'): - chunk += '\n' - else: - if not chunk.endswith(b'\n'): - chunk += b'\n' - yield chunk - - else: - raise TypeError("Need an input stream to copy from") - else: - if size is None: - size = 8192 - elif not isinstance(size, int): - raise TypeError("The size option must be an integer") - if size > 0: - - def chunks(): - while True: - buffer = read(size) - yield buffer - if not buffer or len(buffer) < size: - break - - else: - - def chunks(): - yield read() - - if not table or not isinstance(table, basestring): - raise TypeError("Need a table to copy to") - if table.lower().startswith('select'): - raise ValueError("Must specify a table, not a query") - else: - table = '"%s"' % (table,) - operation = ['copy %s' % (table,)] - options = [] - params = [] - if format is not None: - if not isinstance(format, basestring): - raise TypeError("The format option must be be a string") - if format not in ('text', 'csv', 'binary'): - raise ValueError("Invalid format") - options.append('format %s' % (format,)) - if sep is not None: - if not isinstance(sep, basestring): - raise TypeError("The sep option must be a string") - if format == 'binary': - raise ValueError( - "The sep option is not allowed with binary format") - if len(sep) != 1: - raise ValueError( - "The sep option must be a single one-byte character") - options.append('delimiter %s') - params.append(sep) - if null is not None: - if not isinstance(null, basestring): - raise TypeError("The null option must be a string") - options.append('null %s') - params.append(null) - if columns: - if not isinstance(columns, basestring): - columns = ','.join('"%s"' % (col,) for col in columns) - operation.append('(%s)' % (columns,)) - operation.append("from stdin") - if options: - 
operation.append('(%s)' % (','.join(options),)) - operation = ' '.join(operation) - - putdata = self._src.putdata - self.execute(operation, params) - - try: - for chunk in chunks(): - putdata(chunk) - except BaseException as error: - self.rowcount = -1 - # the following call will re-raise the error - putdata(error) - else: - self.rowcount = putdata(None) - - # return the cursor object, so you can chain operations - return self - - def copy_to(self, stream, table, - format=None, sep=None, null=None, decode=None, columns=None): - """Copy data from the specified table to an output stream. - - The output stream can be a file-like object with a write() method or - it can also be None, in which case the method will return a generator - yielding a row on each iteration. - - Output will be returned as byte strings unless you set decode to true. - - Note that you can also use a select query instead of the table name. - - The format must be text, csv or binary. The sep option sets the - column separator (delimiter) used in the non binary formats. - The null option sets the textual representation of NULL in the output. - - The copy operation can be restricted to a subset of columns. If no - columns are specified, all of them will be copied. - """ - binary_format = format == 'binary' - if stream is not None: - try: - write = stream.write - except AttributeError: - raise TypeError("Need an output stream to copy to") - if not table or not isinstance(table, basestring): - raise TypeError("Need a table to copy to") - if table.lower().startswith('select'): - if columns: - raise ValueError("Columns must be specified in the query") - table = '(%s)' % (table,) - else: - table = '"%s"' % (table,) - operation = ['copy %s' % (table,)] - options = [] - params = [] - if format is not None: - if not isinstance(format, basestring): - raise TypeError("The format option must be a string") - if format not in ('text', 'csv', 'binary'): - raise ValueError("Invalid format") - options.append('format %s' % (format,)) - if sep is not None: - if not isinstance(sep, basestring): - raise TypeError("The sep option must be a string") - if binary_format: - raise ValueError( - "The sep option is not allowed with binary format") - if len(sep) != 1: - raise ValueError( - "The sep option must be a single one-byte character") - options.append('delimiter %s') - params.append(sep) - if null is not None: - if not isinstance(null, basestring): - raise TypeError("The null option must be a string") - options.append('null %s') - params.append(null) - if decode is None: - if format == 'binary': - decode = False - else: - decode = str is unicode - else: - if not isinstance(decode, (int, bool)): - raise TypeError("The decode option must be a boolean") - if decode and binary_format: - raise ValueError( - "The decode option is not allowed with binary format") - if columns: - if not isinstance(columns, basestring): - columns = ','.join('"%s"' % (col,) for col in columns) - operation.append('(%s)' % (columns,)) - - operation.append("to stdout") - if options: - operation.append('(%s)' % (','.join(options),)) - operation = ' '.join(operation) - - getdata = self._src.getdata - self.execute(operation, params) - - def copy(): - self.rowcount = 0 - while True: - row = getdata(decode) - if isinstance(row, int): - if self.rowcount != row: - self.rowcount = row - break - self.rowcount += 1 - yield row - - if stream is None: - # no input stream, return the generator - return copy() - - # write the rows to the file-like input stream - for row in copy(): - 
write(row) - - # return the cursor object, so you can chain operations - return self - - def __next__(self): - """Return the next row (support for the iteration protocol).""" - res = self.fetchone() - if res is None: - raise StopIteration - return res - - # Note that since Python 2.6 the iterator protocol uses __next()__ - # instead of next(), we keep it only for backward compatibility of pgdb. - next = __next__ - - @staticmethod - def nextset(): - """Not supported.""" - raise NotSupportedError("The nextset() method is not supported") - - @staticmethod - def setinputsizes(sizes): - """Not supported.""" - pass # unsupported, but silently passed - - @staticmethod - def setoutputsize(size, column=0): - """Not supported.""" - pass # unsupported, but silently passed - - @staticmethod - def row_factory(row): - """Process rows before they are returned. - - You can overwrite this statically with a custom row factory, or - you can build a row factory dynamically with build_row_factory(). - - For example, you can create a Cursor class that returns rows as - Python dictionaries like this: - - class DictCursor(pgdb.Cursor): - - def row_factory(self, row): - return {desc[0]: value - for desc, value in zip(self.description, row)} - - cur = DictCursor(con) # get one DictCursor instance or - con.cursor_type = DictCursor # always use DictCursor instances - """ - raise NotImplementedError - - def build_row_factory(self): - """Build a row factory based on the current description. - - This implementation builds a row factory for creating named tuples. - You can overwrite this method if you want to dynamically create - different row factories whenever the column description changes. - """ - names = self.colnames - if names: - return _row_factory(tuple(names)) - - -CursorDescription = namedtuple('CursorDescription', - ['name', 'type_code', 'display_size', 'internal_size', - 'precision', 'scale', 'null_ok']) - - -### Connection Objects - -class Connection(object): - """Connection object.""" - - # expose the exceptions as attributes on the connection object - Error = Error - Warning = Warning - InterfaceError = InterfaceError - DatabaseError = DatabaseError - InternalError = InternalError - OperationalError = OperationalError - ProgrammingError = ProgrammingError - IntegrityError = IntegrityError - DataError = DataError - NotSupportedError = NotSupportedError - - def __init__(self, cnx): - """Create a database connection object.""" - self._cnx = cnx # connection - self._tnx = False # transaction state - self.type_cache = TypeCache(cnx) - self.cursor_type = Cursor - self.autocommit = False - try: - self._cnx.source() - except Exception: - raise _op_error("Invalid connection") - - def __enter__(self): - """Enter the runtime context for the connection object. - - The runtime context can be used for running transactions. - - This also starts a transaction in autocommit mode. - """ - if self.autocommit: - try: - self._cnx.source().execute("BEGIN") - except DatabaseError: - raise # database provides error message - except Exception: - raise _op_error("Can't start transaction") - else: - self._tnx = True - return self - - def __exit__(self, et, ev, tb): - """Exit the runtime context for the connection object. - - This does not close the connection, but it ends a transaction. 
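A sketch of the transaction demarcation described above; the connection parameters and table name are illustrative:

    con = pgdb.connect(database='test')
    with con:                    # commits on clean exit, rolls back on error,
        cur = con.cursor()       # but does not close the connection
        cur.execute("insert into mytable values (%s)", (1,))
    con.close()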
- """ - if et is None and ev is None and tb is None: - self.commit() - else: - self.rollback() - - def close(self): - """Close the connection object.""" - if self._cnx: - if self._tnx: - try: - self.rollback() - except DatabaseError: - pass - self._cnx.close() - self._cnx = None - else: - raise _op_error("Connection has been closed") - - @property - def closed(self): - """Check whether the connection has been closed or is broken.""" - try: - return not self._cnx or self._cnx.status != 1 - except TypeError: - return True - - def commit(self): - """Commit any pending transaction to the database.""" - if self._cnx: - if self._tnx: - self._tnx = False - try: - self._cnx.source().execute("COMMIT") - except DatabaseError: - raise # database provides error message - except Exception: - raise _op_error("Can't commit transaction") - else: - raise _op_error("Connection has been closed") - - def rollback(self): - """Roll back to the start of any pending transaction.""" - if self._cnx: - if self._tnx: - self._tnx = False - try: - self._cnx.source().execute("ROLLBACK") - except DatabaseError: - raise # database provides error message - except Exception: - raise _op_error("Can't rollback transaction") - else: - raise _op_error("Connection has been closed") - - def cursor(self): - """Return a new cursor object using the connection.""" - if self._cnx: - try: - return self.cursor_type(self) - except Exception: - raise _op_error("Invalid connection") - else: - raise _op_error("Connection has been closed") - - if shortcutmethods: # otherwise do not implement and document this - - def execute(self, operation, params=None): - """Shortcut method to run an operation on an implicit cursor.""" - cursor = self.cursor() - cursor.execute(operation, params) - return cursor - - def executemany(self, operation, param_seq): - """Shortcut method to run an operation against a sequence.""" - cursor = self.cursor() - cursor.executemany(operation, param_seq) - return cursor - - -### Module Interface - -_connect = connect - -def connect(dsn=None, - user=None, password=None, - host=None, database=None, **kwargs): - """Connect to a database.""" - # first get params from DSN - dbport = -1 - dbhost = "" - dbname = "" - dbuser = "" - dbpasswd = "" - dbopt = "" - try: - params = dsn.split(":") - dbhost = params[0] - dbname = params[1] - dbuser = params[2] - dbpasswd = params[3] - dbopt = params[4] - except (AttributeError, IndexError, TypeError): - pass - - # override if necessary - if user is not None: - dbuser = user - if password is not None: - dbpasswd = password - if database is not None: - dbname = database - if host is not None: - try: - params = host.split(":") - dbhost = params[0] - dbport = int(params[1]) - except (AttributeError, IndexError, TypeError, ValueError): - pass - - # empty host is localhost - if dbhost == "": - dbhost = None - if dbuser == "": - dbuser = None - - # pass keyword arguments as connection info string - if kwargs: - kwargs = list(kwargs.items()) - if '=' in dbname: - dbname = [dbname] - else: - kwargs.insert(0, ('dbname', dbname)) - dbname = [] - for kw, value in kwargs: - value = str(value) - if not value or ' ' in value: - value = "'%s'" % (value.replace( - "'", "\\'").replace('\\', '\\\\'),) - dbname.append('%s=%s' % (kw, value)) - dbname = ' '.join(dbname) - - # open the connection - cnx = _connect(dbname, dbhost, dbport, dbopt, dbuser, dbpasswd) - return Connection(cnx) - - -### Types Handling - -class Type(frozenset): - """Type class for a couple of PostgreSQL data types. 
- - PostgreSQL is object-oriented: types are dynamic. - We must thus use type names as internal type codes. - """ - - def __new__(cls, values): - if isinstance(values, basestring): - values = values.split() - return super(Type, cls).__new__(cls, values) - - def __eq__(self, other): - if isinstance(other, basestring): - if other.startswith('_'): - other = other[1:] - return other in self - else: - return super(Type, self).__eq__(other) - - def __ne__(self, other): - if isinstance(other, basestring): - if other.startswith('_'): - other = other[1:] - return other not in self - else: - return super(Type, self).__ne__(other) - - -class ArrayType: - """Type class for PostgreSQL array types.""" - - def __eq__(self, other): - if isinstance(other, basestring): - return other.startswith('_') - else: - return isinstance(other, ArrayType) - - def __ne__(self, other): - if isinstance(other, basestring): - return not other.startswith('_') - else: - return not isinstance(other, ArrayType) - - -class RecordType: - """Type class for PostgreSQL record types.""" - - def __eq__(self, other): - if isinstance(other, TypeCode): - return other.type == 'c' - elif isinstance(other, basestring): - return other == 'record' - else: - return isinstance(other, RecordType) - - def __ne__(self, other): - if isinstance(other, TypeCode): - return other.type != 'c' - elif isinstance(other, basestring): - return other != 'record' - else: - return not isinstance(other, RecordType) - - -# Mandatory type objects defined by DB-API 2 specs: - -STRING = Type('char bpchar name text varchar') -BINARY = Type('bytea') -NUMBER = Type('int2 int4 serial int8 float4 float8 numeric money') -DATETIME = Type('date time timetz timestamp timestamptz interval' - ' abstime reltime') # these are very old -ROWID = Type('oid') - - -# Additional type objects (more specific): - -BOOL = Type('bool') -SMALLINT = Type('int2') -INTEGER = Type('int2 int4 int8 serial') -LONG = Type('int8') -FLOAT = Type('float4 float8') -NUMERIC = Type('numeric') -MONEY = Type('money') -DATE = Type('date') -TIME = Type('time timetz') -TIMESTAMP = Type('timestamp timestamptz') -INTERVAL = Type('interval') -UUID = Type('uuid') -HSTORE = Type('hstore') -JSON = Type('json jsonb') - -# Type object for arrays (also equate to their base types): - -ARRAY = ArrayType() - -# Type object for records (encompassing all composite types): - -RECORD = RecordType() - - -# Mandatory type helpers defined by DB-API 2 specs: - -def Date(year, month, day): - """Construct an object holding a date value.""" - return date(year, month, day) - - -def Time(hour, minute=0, second=0, microsecond=0, tzinfo=None): - """Construct an object holding a time value.""" - return time(hour, minute, second, microsecond, tzinfo) - - -def Timestamp(year, month, day, hour=0, minute=0, second=0, microsecond=0, - tzinfo=None): - """Construct an object holding a time stamp value.""" - return datetime(year, month, day, hour, minute, second, microsecond, tzinfo) - - -def DateFromTicks(ticks): - """Construct an object holding a date value from the given ticks value.""" - return Date(*localtime(ticks)[:3]) - - -def TimeFromTicks(ticks): - """Construct an object holding a time value from the given ticks value.""" - return Time(*localtime(ticks)[3:6]) - - -def TimestampFromTicks(ticks): - """Construct an object holding a time stamp from the given ticks value.""" - return Timestamp(*localtime(ticks)[:6]) - - -class Binary(bytes): - """Construct an object capable of holding a binary (long) string value.""" - - -# Additional 
type helpers for PyGreSQL: - -def Interval(days, hours=0, minutes=0, seconds=0, microseconds=0): - """Construct an object holding a time interval value.""" - return timedelta(days, hours=hours, minutes=minutes, seconds=seconds, - microseconds=microseconds) - - -Uuid = Uuid # Construct an object holding a UUID value - - -class Hstore(dict): - """Wrapper class for marking hstore values.""" - - _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]') - _re_escape = regex(r'(["\\])') - - @classmethod - def _quote(cls, s): - if s is None: - return 'NULL' - if not s: - return '""' - quote = cls._re_quote.search(s) - s = cls._re_escape.sub(r'\\\1', s) - if quote: - s = '"%s"' % (s,) - return s - - def __str__(self): - q = self._quote - return ','.join('%s=>%s' % (q(k), q(v)) for k, v in self.items()) - - -class Json: - """Construct a wrapper for holding an object serializable to JSON.""" - - def __init__(self, obj, encode=None): - self.obj = obj - self.encode = encode or jsonencode - - def __str__(self): - obj = self.obj - if isinstance(obj, basestring): - return obj - return self.encode(obj) - - -class Literal: - """Construct a wrapper for holding a literal SQL string.""" - - def __init__(self, sql): - self.sql = sql - - def __str__(self): - return self.sql - - __pg_repr__ = __str__ - - -# If run as script, print some information: - -if __name__ == '__main__': - print('PyGreSQL version', version) - print('') - print(__doc__) diff --git a/pgdb/__init__.py b/pgdb/__init__.py new file mode 100644 index 00000000..132ce292 --- /dev/null +++ b/pgdb/__init__.py @@ -0,0 +1,182 @@ +#!/usr/bin/python +# +# PyGreSQL - a Python interface for the PostgreSQL database. +# +# This file contains the DB-API 2 compatible pgdb module. +# +# Copyright (c) 2025 by the PyGreSQL Development Team +# +# Please see the LICENSE.TXT file for specific restrictions. + +"""pgdb - DB-API 2.0 compliant module for PyGreSQL. + +(c) 1999, Pascal Andre . +See package documentation for further information on copyright. + +Inline documentation is sparse. +See DB-API 2.0 specification for usage information: +http://www.python.org/peps/pep-0249.html + +Basic usage: + + pgdb.connect(connect_string) # open a connection + # connect_string = 'host:database:user:password:opt' + # All parts are optional. You may also pass host through + # password as keyword arguments. To pass a port, + # pass it in the host keyword parameter: + connection = pgdb.connect(host='localhost:5432') + + cursor = connection.cursor() # open a cursor + + cursor.execute(query[, params]) + # Execute a query, binding params (a dictionary) if they are + # passed. The binding syntax is the same as the % operator + # for dictionaries, and no quoting is done. + + cursor.executemany(query, list of params) + # Execute a query many times, binding each param dictionary + # from the list. + + cursor.fetchone() # fetch one row, [value, value, ...] + + cursor.fetchall() # fetch all rows, [[value, value, ...], ...] + + cursor.fetchmany([size]) + # returns size or cursor.arraysize number of rows, + # [[value, value, ...], ...] from result set. + # Default cursor.arraysize is 1. + + cursor.description # returns information about the columns + # [(column_name, type_name, display_size, + # internal_size, precision, scale, null_ok), ...] + # Note that display_size, precision, scale and null_ok + # are not implemented. + + cursor.rowcount # number of rows available in the result set + # Available after a call to execute. 
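A minimal end-to-end sketch of the interface summarized above (connection parameters and values are illustrative):

    import pgdb

    con = pgdb.connect(database='test')
    cur = con.cursor()
    cur.execute("select %(n)s + 1 as answer", {'n': 41})
    print(cur.fetchone().answer)   # rows are named tuples by default -> 42
    con.commit()
    con.close()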
+ + connection.commit() # commit transaction + + connection.rollback() # or rollback transaction + + cursor.close() # close the cursor + + connection.close() # close the connection +""" + +from pg.core import ( + DatabaseError, + DataError, + Error, + IntegrityError, + InterfaceError, + InternalError, + NotSupportedError, + OperationalError, + ProgrammingError, + Warning, + version, +) + +from .adapt import ( + ARRAY, + BINARY, + BOOL, + DATE, + DATETIME, + FLOAT, + HSTORE, + INTEGER, + INTERVAL, + JSON, + LONG, + MONEY, + NUMBER, + NUMERIC, + RECORD, + ROWID, + SMALLINT, + STRING, + TIME, + TIMESTAMP, + UUID, + Binary, + Date, + DateFromTicks, + DbType, + Hstore, + Interval, + Json, + Literal, + Time, + TimeFromTicks, + Timestamp, + TimestampFromTicks, + Uuid, +) +from .cast import get_typecast, reset_typecast, set_typecast +from .connect import connect +from .connection import Connection +from .constants import apilevel, paramstyle, shortcutmethods, threadsafety +from .cursor import Cursor + +__all__ = [ + 'ARRAY', + 'BINARY', + 'BOOL', + 'DATE', + 'DATETIME', + 'FLOAT', + 'HSTORE', + 'INTEGER', + 'INTERVAL', + 'JSON', + 'LONG', + 'MONEY', + 'NUMBER', + 'NUMERIC', + 'RECORD', + 'ROWID', + 'SMALLINT', + 'STRING', + 'TIME', + 'TIMESTAMP', + 'UUID', + 'Binary', + 'Connection', + 'Cursor', + 'DataError', + 'DatabaseError', + 'Date', + 'DateFromTicks', + 'DbType', + 'Error', + 'Hstore', + 'IntegrityError', + 'InterfaceError', + 'InternalError', + 'Interval', + 'Json', + 'Literal', + 'NotSupportedError', + 'OperationalError', + 'ProgrammingError', + 'Time', + 'TimeFromTicks', + 'Timestamp', + 'TimestampFromTicks', + 'Uuid', + 'Warning', + '__version__', + 'apilevel', + 'connect', + 'get_typecast', + 'paramstyle', + 'reset_typecast', + 'set_typecast', + 'shortcutmethods', + 'threadsafety', + 'version', +] + +__version__ = version diff --git a/pgdb/adapt.py b/pgdb/adapt.py new file mode 100644 index 00000000..f657b190 --- /dev/null +++ b/pgdb/adapt.py @@ -0,0 +1,261 @@ +"""Type helpers for adaptation of parameters.""" + +from __future__ import annotations + +from datetime import date, datetime, time, timedelta, tzinfo +from json import dumps as jsonencode +from re import compile as regex +from time import localtime +from typing import Any, Callable, Iterable +from uuid import UUID as Uuid # noqa: N811 + +from .typecode import TypeCode + +__all__ = [ + 'ARRAY', + 'BINARY', + 'BOOL', + 'DATE', + 'DATETIME', + 'FLOAT', + 'HSTORE', + 'INTEGER', + 'INTERVAL', + 'JSON', + 'LONG', + 'MONEY', + 'NUMBER', + 'NUMERIC', + 'RECORD', + 'ROWID', + 'SMALLINT', + 'STRING', + 'TIME', + 'TIMESTAMP', + 'UUID', + 'ArrayType', + 'Date', + 'DateFromTicks', + 'DbType', + 'RecordType', + 'Time', + 'TimeFromTicks', + 'Timestamp', + 'TimestampFromTicks' + +] + + +class DbType(frozenset): + """Type class for a couple of PostgreSQL data types. + + PostgreSQL is object-oriented: types are dynamic. + We must thus use type names as internal type codes. 
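What this buys in practice, sketched against the type objects defined later in this module: a bare type_code string compares equal to the DB-API type object it belongs to, and array types equate to their base types:

    assert 'int4' == NUMBER       # membership test, not string equality
    assert '_int4' == NUMBER      # leading '_' (array of int4) is stripped
    assert '_text' == ARRAY       # any '_'-prefixed name is an array
    assert 'record' == RECORD     # composite types all compare equal to RECORD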
+ """ + + def __new__(cls, values: str | Iterable[str]) -> DbType: + """Create new type object.""" + if isinstance(values, str): + values = values.split() + return super().__new__(cls, values) + + def __eq__(self, other: Any) -> bool: + """Check whether types are considered equal.""" + if isinstance(other, str): + if other.startswith('_'): + other = other[1:] + return other in self + return super().__eq__(other) + + def __ne__(self, other: Any) -> bool: + """Check whether types are not considered equal.""" + if isinstance(other, str): + if other.startswith('_'): + other = other[1:] + return other not in self + return super().__ne__(other) + + +class ArrayType: + """Type class for PostgreSQL array types.""" + + def __eq__(self, other: Any) -> bool: + """Check whether arrays are equal.""" + if isinstance(other, str): + return other.startswith('_') + return isinstance(other, ArrayType) + + def __ne__(self, other: Any) -> bool: + """Check whether arrays are different.""" + if isinstance(other, str): + return not other.startswith('_') + return not isinstance(other, ArrayType) + + +class RecordType: + """Type class for PostgreSQL record types.""" + + def __eq__(self, other: Any) -> bool: + """Check whether records are equal.""" + if isinstance(other, TypeCode): + return other.type == 'c' + if isinstance(other, str): + return other == 'record' + return isinstance(other, RecordType) + + def __ne__(self, other: Any) -> bool: + """Check whether records are different.""" + if isinstance(other, TypeCode): + return other.type != 'c' + if isinstance(other, str): + return other != 'record' + return not isinstance(other, RecordType) + + +# Mandatory type objects defined by DB-API 2 specs: + +STRING = DbType('char bpchar name text varchar') +BINARY = DbType('bytea') +NUMBER = DbType('int2 int4 serial int8 float4 float8 numeric money') +DATETIME = DbType('date time timetz timestamp timestamptz interval' + ' abstime reltime') # these are very old +ROWID = DbType('oid') + + +# Additional type objects (more specific): + +BOOL = DbType('bool') +SMALLINT = DbType('int2') +INTEGER = DbType('int2 int4 int8 serial') +LONG = DbType('int8') +FLOAT = DbType('float4 float8') +NUMERIC = DbType('numeric') +MONEY = DbType('money') +DATE = DbType('date') +TIME = DbType('time timetz') +TIMESTAMP = DbType('timestamp timestamptz') +INTERVAL = DbType('interval') +UUID = DbType('uuid') +HSTORE = DbType('hstore') +JSON = DbType('json jsonb') + +# Type object for arrays (also equate to their base types): + +ARRAY = ArrayType() + +# Type object for records (encompassing all composite types): + +RECORD = RecordType() + + +# Mandatory type helpers defined by DB-API 2 specs: + +def Date(year: int, month: int, day: int) -> date: # noqa: N802 + """Construct an object holding a date value.""" + return date(year, month, day) + + +def Time(hour: int, minute: int = 0, # noqa: N802 + second: int = 0, microsecond: int = 0, + tzinfo: tzinfo | None = None) -> time: + """Construct an object holding a time value.""" + return time(hour, minute, second, microsecond, tzinfo) + + +def Timestamp(year: int, month: int, day: int, # noqa: N802 + hour: int = 0, minute: int = 0, + second: int = 0, microsecond: int = 0, + tzinfo: tzinfo | None = None) -> datetime: + """Construct an object holding a time stamp value.""" + return datetime(year, month, day, hour, minute, + second, microsecond, tzinfo) + + +def DateFromTicks(ticks: float | None) -> date: # noqa: N802 + """Construct an object holding a date value from the given ticks value.""" + return 
Date(*localtime(ticks)[:3]) + + +def TimeFromTicks(ticks: float | None) -> time: # noqa: N802 + """Construct an object holding a time value from the given ticks value.""" + return Time(*localtime(ticks)[3:6]) + + +def TimestampFromTicks(ticks: float | None) -> datetime: # noqa: N802 + """Construct an object holding a time stamp from the given ticks value.""" + return Timestamp(*localtime(ticks)[:6]) + + +class Binary(bytes): + """Construct an object capable of holding a binary (long) string value.""" + + +# Additional type helpers for PyGreSQL: + +def Interval(days: int | float, # noqa: N802 + hours: int | float = 0, minutes: int | float = 0, + seconds: int | float = 0, microseconds: int | float = 0 + ) -> timedelta: + """Construct an object holding a time interval value.""" + return timedelta(days, hours=hours, minutes=minutes, + seconds=seconds, microseconds=microseconds) + + +Uuid = Uuid # Construct an object holding a UUID value + + +class Hstore(dict): + """Wrapper class for marking hstore values.""" + + _re_quote = regex('^[Nn][Uu][Ll][Ll]$|[ ,=>]') + _re_escape = regex(r'(["\\])') + + @classmethod + def _quote(cls, s: Any) -> Any: + if s is None: + return 'NULL' + if not isinstance(s, str): + s = str(s) + if not s: + return '""' + quote = cls._re_quote.search(s) + s = cls._re_escape.sub(r'\\\1', s) + if quote: + s = f'"{s}"' + return s + + def __str__(self) -> str: + """Create a printable representation of the hstore value.""" + q = self._quote + return ','.join(f'{q(k)}=>{q(v)}' for k, v in self.items()) + + +class Json: + """Construct a wrapper for holding an object serializable to JSON.""" + + def __init__(self, obj: Any, + encode: Callable[[Any], str] | None = None) -> None: + """Initialize the JSON object.""" + self.obj = obj + self.encode = encode or jsonencode + + def __str__(self) -> str: + """Create a printable representation of the JSON object.""" + obj = self.obj + if isinstance(obj, str): + return obj + return self.encode(obj) + + +class Literal: + """Construct a wrapper for holding a literal SQL string.""" + + def __init__(self, sql: str) -> None: + """Initialize literal SQL string.""" + self.sql = sql + + def __str__(self) -> str: + """Return a printable representation of the SQL string.""" + return self.sql + + __pg_repr__ = __str__ \ No newline at end of file diff --git a/pgdb/cast.py b/pgdb/cast.py new file mode 100644 index 00000000..49b4bd84 --- /dev/null +++ b/pgdb/cast.py @@ -0,0 +1,594 @@ +"""Internal type handling.""" + +from __future__ import annotations + +from collections import namedtuple +from datetime import date, datetime, time, timedelta +from decimal import Decimal as _Decimal +from functools import partial +from inspect import signature +from json import loads as jsondecode +from re import compile as regex +from typing import Any, Callable, ClassVar, Sequence +from uuid import UUID as Uuid # noqa: N811 + +from pg.core import Connection as Cnx +from pg.core import ( + ProgrammingError, + cast_array, + cast_hstore, + cast_record, + unescape_bytea, +) + +from .typecode import TypeCode + +__all__ = [ + 'Decimal', + 'FieldInfo', + 'LocalTypecasts', + 'TypeCache', + 'Typecasts', + 'cast_bool', + 'cast_date', + 'cast_int2vector', + 'cast_interval', + 'cast_money', + 'cast_time', + 'cast_timestamp', + 'cast_timestamptz', + 'cast_timetz', + 'decimal_type', + 'get_typecast', + 'reset_typecast', + 'set_typecast' +] + + +Decimal: type = _Decimal + + +def get_args(func: Callable) -> list: + return list(signature(func).parameters) + + +# time zones used in 
Postgres timestamptz output +_timezones: dict[str, str] = { + 'CET': '+0100', 'EET': '+0200', 'EST': '-0500', + 'GMT': '+0000', 'HST': '-1000', 'MET': '+0100', 'MST': '-0700', + 'UCT': '+0000', 'UTC': '+0000', 'WET': '+0000' +} + + +def _timezone_as_offset(tz: str) -> str: + if tz.startswith(('+', '-')): + if len(tz) < 5: + return tz + '00' + return tz.replace(':', '') + return _timezones.get(tz, '+0000') + + +def decimal_type(decimal_type: type | None = None) -> type: + """Get or set global type to be used for decimal values. + + Note that connections cache cast functions. To be sure a global change + is picked up by a running connection, call con.type_cache.reset_typecast(). + """ + global Decimal + if decimal_type is not None: + Decimal = decimal_type + set_typecast('numeric', decimal_type) + return Decimal + + +def cast_bool(value: str) -> bool | None: + """Cast boolean value in database format to bool.""" + return value[0] in ('t', 'T') if value else None + + +def cast_money(value: str) -> _Decimal | None: + """Cast money value in database format to Decimal.""" + if not value: + return None + value = value.replace('(', '-') + return Decimal(''.join(c for c in value if c.isdigit() or c in '.-')) + + +def cast_int2vector(value: str) -> list[int]: + """Cast an int2vector value.""" + return [int(v) for v in value.split()] + + +def cast_date(value: str, cnx: Cnx) -> date: + """Cast a date value.""" + # The output format depends on the server setting DateStyle. The default + # setting ISO and the setting for German are actually unambiguous. The + # order of days and months in the other two settings is however ambiguous, + # so at least here we need to consult the setting to properly parse values. + if value == '-infinity': + return date.min + if value == 'infinity': + return date.max + values = value.split() + if values[-1] == 'BC': + return date.min + value = values[0] + if len(value) > 10: + return date.max + format = cnx.date_format() + return datetime.strptime(value, format).date() + + +def cast_time(value: str) -> time: + """Cast a time value.""" + fmt = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' + return datetime.strptime(value, fmt).time() + + +_re_timezone = regex('(.*)([+-].*)') + + +def cast_timetz(value: str) -> time: + """Cast a timetz value.""" + m = _re_timezone.match(value) + if m: + value, tz = m.groups() + else: + tz = '+0000' + format = '%H:%M:%S.%f' if len(value) > 8 else '%H:%M:%S' + value += _timezone_as_offset(tz) + format += '%z' + return datetime.strptime(value, format).timetz() + + +def cast_timestamp(value: str, cnx: Cnx) -> datetime: + """Cast a timestamp value.""" + if value == '-infinity': + return datetime.min + if value == 'infinity': + return datetime.max + values = value.split() + if values[-1] == 'BC': + return datetime.min + format = cnx.date_format() + if format.endswith('-%Y') and len(values) > 2: + values = values[1:5] + if len(values[3]) > 4: + return datetime.max + formats = ['%d %b' if format.startswith('%d') else '%b %d', + '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] + else: + if len(values[0]) > 10: + return datetime.max + formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] + return datetime.strptime(' '.join(values), ' '.join(formats)) + + +def cast_timestamptz(value: str, cnx: Cnx) -> datetime: + """Cast a timestamptz value.""" + if value == '-infinity': + return datetime.min + if value == 'infinity': + return datetime.max + values = value.split() + if values[-1] == 'BC': + return datetime.min + format = 
cnx.date_format() + if format.endswith('-%Y') and len(values) > 2: + values = values[1:] + if len(values[3]) > 4: + return datetime.max + formats = ['%d %b' if format.startswith('%d') else '%b %d', + '%H:%M:%S.%f' if len(values[2]) > 8 else '%H:%M:%S', '%Y'] + values, tz = values[:-1], values[-1] + else: + if format.startswith('%Y-'): + m = _re_timezone.match(values[1]) + if m: + values[1], tz = m.groups() + else: + tz = '+0000' + else: + values, tz = values[:-1], values[-1] + if len(values[0]) > 10: + return datetime.max + formats = [format, '%H:%M:%S.%f' if len(values[1]) > 8 else '%H:%M:%S'] + values.append(_timezone_as_offset(tz)) + formats.append('%z') + return datetime.strptime(' '.join(values), ' '.join(formats)) + + +_re_interval_sql_standard = regex( + '(?:([+-])?([0-9]+)-([0-9]+) ?)?' + '(?:([+-]?[0-9]+)(?!:) ?)?' + '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?') + +_re_interval_postgres = regex( + '(?:([+-]?[0-9]+) ?years? ?)?' + '(?:([+-]?[0-9]+) ?mons? ?)?' + '(?:([+-]?[0-9]+) ?days? ?)?' + '(?:([+-])?([0-9]+):([0-9]+):([0-9]+)(?:\\.([0-9]+))?)?') + +_re_interval_postgres_verbose = regex( + '@ ?(?:([+-]?[0-9]+) ?years? ?)?' + '(?:([+-]?[0-9]+) ?mons? ?)?' + '(?:([+-]?[0-9]+) ?days? ?)?' + '(?:([+-]?[0-9]+) ?hours? ?)?' + '(?:([+-]?[0-9]+) ?mins? ?)?' + '(?:([+-])?([0-9]+)(?:\\.([0-9]+))? ?secs?)? ?(ago)?') + +_re_interval_iso_8601 = regex( + 'P(?:([+-]?[0-9]+)Y)?' + '(?:([+-]?[0-9]+)M)?' + '(?:([+-]?[0-9]+)D)?' + '(?:T(?:([+-]?[0-9]+)H)?' + '(?:([+-]?[0-9]+)M)?' + '(?:([+-])?([0-9]+)(?:\\.([0-9]+))?S)?)?') + + +def cast_interval(value: str) -> timedelta: + """Cast an interval value.""" + # The output format depends on the server setting IntervalStyle, but it's + # not necessary to consult this setting to parse it. It's faster to just + # check all possible formats, and there is no ambiguity here. + m = _re_interval_iso_8601.match(value) + if m: + s = [v or '0' for v in m.groups()] + secs_ago = s.pop(5) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d + if secs_ago: + secs = -secs + usecs = -usecs + else: + m = _re_interval_postgres_verbose.match(value) + if m: + s, ago = [v or '0' for v in m.groups()[:8]], m.group(9) + secs_ago = s.pop(5) == '-' + d = [-int(v) for v in s] if ago else [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d + if secs_ago: + secs = - secs + usecs = -usecs + else: + m = _re_interval_postgres.match(value) + if m and any(m.groups()): + s = [v or '0' for v in m.groups()] + hours_ago = s.pop(3) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d + if hours_ago: + hours = -hours + mins = -mins + secs = -secs + usecs = -usecs + else: + m = _re_interval_sql_standard.match(value) + if m and any(m.groups()): + s = [v or '0' for v in m.groups()] + years_ago = s.pop(0) == '-' + hours_ago = s.pop(3) == '-' + d = [int(v) for v in s] + years, mons, days, hours, mins, secs, usecs = d + if years_ago: + years = -years + mons = -mons + if hours_ago: + hours = -hours + mins = -mins + secs = -secs + usecs = -usecs + else: + raise ValueError(f'Cannot parse interval: {value}') + days += 365 * years + 30 * mons + return timedelta(days=days, hours=hours, minutes=mins, + seconds=secs, microseconds=usecs) + + +class Typecasts(dict): + """Dictionary mapping database types to typecast functions. + + The cast functions get passed the string representation of a value in + the database which they need to convert to a Python object. 
The + passed string will never be None since NULL values are already + handled before the cast function is called. + """ + + # the default cast functions + # (str functions are ignored but have been added for faster access) + defaults: ClassVar[dict[str, Callable]] = { + 'char': str, 'bpchar': str, 'name': str, + 'text': str, 'varchar': str, 'sql_identifier': str, + 'bool': cast_bool, 'bytea': unescape_bytea, + 'int2': int, 'int4': int, 'serial': int, 'int8': int, 'oid': int, + 'hstore': cast_hstore, 'json': jsondecode, 'jsonb': jsondecode, + 'float4': float, 'float8': float, + 'numeric': Decimal, 'money': cast_money, + 'date': cast_date, 'interval': cast_interval, + 'time': cast_time, 'timetz': cast_timetz, + 'timestamp': cast_timestamp, 'timestamptz': cast_timestamptz, + 'int2vector': cast_int2vector, 'uuid': Uuid, + 'anyarray': cast_array, 'record': cast_record} + + cnx: Cnx | None = None # for local connection specific instances + + def __missing__(self, typ: str) -> Callable | None: + """Create a cast function if it is not cached. + + Note that this class never raises a KeyError, + but returns None when no special cast function exists. + """ + if not isinstance(typ, str): + raise TypeError(f'Invalid type: {typ}') + cast = self.defaults.get(typ) + if cast: + # store default for faster access + cast = self._add_connection(cast) + self[typ] = cast + elif typ.startswith('_'): + # create array cast + base_cast = self[typ[1:]] + cast = self.create_array_cast(base_cast) + if base_cast: + # store only if base type exists + self[typ] = cast + return cast + + @staticmethod + def _needs_connection(func: Callable) -> bool: + """Check if a typecast function needs a connection argument.""" + try: + args = get_args(func) + except (TypeError, ValueError): + return False + return 'cnx' in args[1:] + + def _add_connection(self, cast: Callable) -> Callable: + """Add a connection argument to the typecast function if necessary.""" + if not self.cnx or not self._needs_connection(cast): + return cast + return partial(cast, cnx=self.cnx) + + def get(self, typ: str, default: Callable | None = None # type: ignore + ) -> Callable | None: + """Get the typecast function for the given database type.""" + return self[typ] or default + + def set(self, typ: str | Sequence[str], cast: Callable | None) -> None: + """Set a typecast function for the specified database type(s).""" + if isinstance(typ, str): + typ = [typ] + if cast is None: + for t in typ: + self.pop(t, None) + self.pop(f'_{t}', None) + else: + if not callable(cast): + raise TypeError("Cast parameter must be callable") + for t in typ: + self[t] = self._add_connection(cast) + self.pop(f'_{t}', None) + + def reset(self, typ: str | Sequence[str] | None = None) -> None: + """Reset the typecasts for the specified type(s) to their defaults. + + When no type is specified, all typecasts will be reset. 
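A sketch of the on-demand behavior of __missing__, using a standalone, connection-less instance:

    tc = Typecasts()
    assert tc['int4'] is int          # default looked up and cached
    cast_int_array = tc['_int4']      # array cast built from the base cast
    assert cast_int_array('{1,2,3}') == [1, 2, 3]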
+ """ + defaults = self.defaults + if typ is None: + self.clear() + self.update(defaults) + else: + if isinstance(typ, str): + typ = [typ] + for t in typ: + cast = defaults.get(t) + if cast: + self[t] = self._add_connection(cast) + t = f'_{t}' + cast = defaults.get(t) + if cast: + self[t] = self._add_connection(cast) + else: + self.pop(t, None) + else: + self.pop(t, None) + self.pop(f'_{t}', None) + + def create_array_cast(self, basecast: Callable) -> Callable: + """Create an array typecast for the given base cast.""" + cast_array = self['anyarray'] + + def cast(v: Any) -> list: + return cast_array(v, basecast) + return cast + + def create_record_cast(self, name: str, fields: Sequence[str], + casts: Sequence[str]) -> Callable: + """Create a named record typecast for the given fields and casts.""" + cast_record = self['record'] + record = namedtuple(name, fields) # type: ignore + + def cast(v: Any) -> record: + # noinspection PyArgumentList + return record(*cast_record(v, casts)) + return cast + + +_typecasts = Typecasts() # this is the global typecast dictionary + + +def get_typecast(typ: str) -> Callable | None: + """Get the global typecast function for the given database type.""" + return _typecasts.get(typ) + + +def set_typecast(typ: str | Sequence[str], cast: Callable | None) -> None: + """Set a global typecast function for the given database type(s). + + Note that connections cache cast functions. To be sure a global change + is picked up by a running connection, call con.type_cache.reset_typecast(). + """ + _typecasts.set(typ, cast) + + +def reset_typecast(typ: str | Sequence[str] | None = None) -> None: + """Reset the global typecasts for the given type(s) to their default. + + When no type is specified, all typecasts will be reset. + + Note that connections cache cast functions. To be sure a global change + is picked up by a running connection, call con.type_cache.reset_typecast(). + """ + _typecasts.reset(typ) + + +class LocalTypecasts(Typecasts): + """Map typecasts, including local composite types, to cast functions.""" + + defaults = _typecasts + + cnx: Cnx | None = None # set in connection specific instances + + def __missing__(self, typ: str) -> Callable | None: + """Create a cast function if it is not cached.""" + cast: Callable | None + if typ.startswith('_'): + base_cast = self[typ[1:]] + cast = self.create_array_cast(base_cast) + if base_cast: + self[typ] = cast + else: + cast = self.defaults.get(typ) + if cast: + cast = self._add_connection(cast) + self[typ] = cast + else: + fields = self.get_fields(typ) + if fields: + casts = [self[field.type] for field in fields] + field_names = [field.name for field in fields] + cast = self.create_record_cast(typ, field_names, casts) + self[typ] = cast + return cast + + # noinspection PyMethodMayBeStatic,PyUnusedLocal + def get_fields(self, typ: str) -> list[FieldInfo]: + """Return the fields for the given record type. + + This method will be replaced with a method that looks up the fields + using the type cache of the connection. + """ + return [] + + +FieldInfo = namedtuple('FieldInfo', ('name', 'type')) + + +class TypeCache(dict): + """Cache for database types. + + This cache maps type OIDs and names to TypeCode strings containing + important information on the associated database type. 
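+
+    A hedged usage sketch (assuming ``cnx`` is an open low-level
+    connection; the looked-up values depend on the database):
+
+        cache = TypeCache(cnx)
+        code = cache['numeric']            # look up by type name or OID
+        value = cache.typecast('3.14', 'numeric')  # -> Decimal('3.14')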
+ """ + + def __init__(self, cnx: Cnx) -> None: + """Initialize type cache for connection.""" + super().__init__() + self._escape_string = cnx.escape_string + self._src = cnx.source() + self._typecasts = LocalTypecasts() + self._typecasts.get_fields = self.get_fields # type: ignore + self._typecasts.cnx = cnx + self._query_pg_type = ( + "SELECT oid, typname," + " typlen, typtype, typcategory, typdelim, typrelid" + " FROM pg_catalog.pg_type WHERE oid OPERATOR(pg_catalog.=) {}") + + def __missing__(self, key: int | str) -> TypeCode: + """Get the type info from the database if it is not cached.""" + oid: int | str + if isinstance(key, int): + oid = key + else: + if '.' not in key and '"' not in key: + key = f'"{key}"' + oid = f"'{self._escape_string(key)}'::pg_catalog.regtype" + try: + self._src.execute(self._query_pg_type.format(oid)) + except ProgrammingError: + res = None + else: + res = self._src.fetch(1) + if not res: + raise KeyError(f'Type {key} could not be found') + r = res[0] + type_code = TypeCode.create( + int(r[0]), r[1], int(r[2]), r[3], r[4], r[5], int(r[6])) + # noinspection PyUnresolvedReferences + self[type_code.oid] = self[str(type_code)] = type_code + return type_code + + def get(self, key: int | str, # type: ignore + default: TypeCode | None = None) -> TypeCode | None: + """Get the type even if it is not cached.""" + try: + return self[key] + except KeyError: + return default + + def get_fields(self, typ: int | str | TypeCode) -> list[FieldInfo] | None: + """Get the names and types of the fields of composite types.""" + if isinstance(typ, TypeCode): + relid = typ.relid + else: + type_code = self.get(typ) + if not type_code: + return None + relid = type_code.relid + if not relid: + return None # this type is not composite + self._src.execute( + "SELECT attname, atttypid" # noqa: S608 + " FROM pg_catalog.pg_attribute" + f" WHERE attrelid OPERATOR(pg_catalog.=) {relid}" + " AND attnum OPERATOR(pg_catalog.>) 0" + " AND NOT attisdropped ORDER BY attnum") + return [FieldInfo(name, self.get(int(oid))) + for name, oid in self._src.fetch(-1)] + + def get_typecast(self, typ: str) -> Callable | None: + """Get the typecast function for the given database type.""" + return self._typecasts[typ] + + def set_typecast(self, typ: str | Sequence[str], + cast: Callable | None) -> None: + """Set a typecast function for the specified database type(s).""" + self._typecasts.set(typ, cast) + + def reset_typecast(self, typ: str | Sequence[str] | None = None) -> None: + """Reset the typecast function for the specified database type(s).""" + self._typecasts.reset(typ) + + def typecast(self, value: Any, typ: str) -> Any: + """Cast the given value according to the given database type.""" + if value is None: + # for NULL values, no typecast is necessary + return None + cast = self._typecasts[typ] + if cast is None or cast is str: + # no typecast is necessary + return value + return cast(value) + + def get_row_caster(self, types: Sequence[str]) -> Callable: + """Get a typecast function for a complete row of values.""" + typecasts = self._typecasts + casts = [typecasts[typ] for typ in types] + casts = [cast if cast is not str else None for cast in casts] + + def row_caster(row: Sequence) -> Sequence: + return [value if cast is None or value is None else cast(value) + for cast, value in zip(casts, row)] + + return row_caster \ No newline at end of file diff --git a/pgdb/connect.py b/pgdb/connect.py new file mode 100644 index 00000000..73b96a36 --- /dev/null +++ b/pgdb/connect.py @@ -0,0 +1,74 @@ +"""The 
DB API 2 connect function.""" + +from __future__ import annotations + +from typing import Any + +from pg.core import connect as get_cnx + +from .connection import Connection + +__all__ = ['connect'] + +def connect(dsn: str | None = None, + user: str | None = None, password: str | None = None, + host: str | None = None, database: str | None = None, + **kwargs: Any) -> Connection: + """Connect to a database.""" + # first get params from DSN + dbport = -1 + dbhost: str | None = "" + dbname: str | None = "" + dbuser: str | None = "" + dbpasswd: str | None = "" + dbopt: str | None = "" + if dsn: + try: + params = dsn.split(":", 4) + dbhost = params[0] + dbname = params[1] + dbuser = params[2] + dbpasswd = params[3] + dbopt = params[4] + except (AttributeError, IndexError, TypeError): + pass + + # override if necessary + if user is not None: + dbuser = user + if password is not None: + dbpasswd = password + if database is not None: + dbname = database + if host: + try: + params = host.split(":", 1) + dbhost = params[0] + dbport = int(params[1]) + except (AttributeError, IndexError, TypeError, ValueError): + pass + + # empty host is localhost + if dbhost == "": + dbhost = None + if dbuser == "": + dbuser = None + + # pass keyword arguments as connection info string + if kwargs: + kwarg_list = list(kwargs.items()) + kw_parts = [] + if dbname and '=' in dbname: + kw_parts.append(dbname) + else: + kwarg_list.insert(0, ('dbname', dbname)) + for kw, value in kwarg_list: + value = str(value) + if not value or ' ' in value: + value = value.replace('\\', '\\\\').replace("'", "\\'") + value = f"'{value}'" + kw_parts.append(f'{kw}={value}') + dbname = ' '.join(kw_parts) + # open the connection + cnx = get_cnx(dbname, dbhost, dbport, dbopt, dbuser, dbpasswd) + return Connection(cnx) diff --git a/pgdb/connection.py b/pgdb/connection.py new file mode 100644 index 00000000..17d32bcc --- /dev/null +++ b/pgdb/connection.py @@ -0,0 +1,156 @@ +"""The DB API 2 Connection objects.""" + +from __future__ import annotations + +from contextlib import suppress +from typing import Any, Sequence + +from pg.core import Connection as Cnx +from pg.core import ( + DatabaseError, + DataError, + Error, + IntegrityError, + InterfaceError, + InternalError, + NotSupportedError, + OperationalError, + ProgrammingError, + Warning, +) +from pg.error import op_error + +from .cast import TypeCache +from .constants import shortcutmethods +from .cursor import Cursor + +__all__ = ['Connection'] + +class Connection: + """Connection object.""" + + # expose the exceptions as attributes on the connection object + Error = Error + Warning = Warning + InterfaceError = InterfaceError + DatabaseError = DatabaseError + InternalError = InternalError + OperationalError = OperationalError + ProgrammingError = ProgrammingError + IntegrityError = IntegrityError + DataError = DataError + NotSupportedError = NotSupportedError + + def __init__(self, cnx: Cnx) -> None: + """Create a database connection object.""" + self._cnx: Cnx | None = cnx # connection + self._tnx = False # transaction state + self.type_cache = TypeCache(cnx) + self.cursor_type = Cursor + self.autocommit = False + try: + self._cnx.source() + except Exception as e: + raise op_error("Invalid connection") from e + + def __enter__(self) -> Connection: + """Enter the runtime context for the connection object. + + The runtime context can be used for running transactions. + + This also starts a transaction in autocommit mode. 
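+
+        A hedged usage sketch (the table is made up):
+
+            with con:
+                con.cursor().execute("insert into t values (1)")
+            # committed on success, rolled back on any exception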
+ """ + if self.autocommit: + cnx = self._cnx + if not cnx: + raise op_error("Connection has been closed") + try: + cnx.source().execute("BEGIN") + except DatabaseError: + raise # database provides error message + except Exception as e: + raise op_error("Can't start transaction") from e + else: + self._tnx = True + return self + + def __exit__(self, et: type[BaseException] | None, + ev: BaseException | None, tb: Any) -> None: + """Exit the runtime context for the connection object. + + This does not close the connection, but it ends a transaction. + """ + if et is None and ev is None and tb is None: + self.commit() + else: + self.rollback() + + def close(self) -> None: + """Close the connection object.""" + if not self._cnx: + raise op_error("Connection has been closed") + if self._tnx: + with suppress(DatabaseError): + self.rollback() + self._cnx.close() + self._cnx = None + + @property + def closed(self) -> bool: + """Check whether the connection has been closed or is broken.""" + try: + return not self._cnx or self._cnx.status != 1 + except TypeError: + return True + + def commit(self) -> None: + """Commit any pending transaction to the database.""" + if not self._cnx: + raise op_error("Connection has been closed") + if self._tnx: + self._tnx = False + try: + self._cnx.source().execute("COMMIT") + except DatabaseError: + raise # database provides error message + except Exception as e: + raise op_error("Can't commit transaction") from e + + def rollback(self) -> None: + """Roll back to the start of any pending transaction.""" + if not self._cnx: + raise op_error("Connection has been closed") + if self._tnx: + self._tnx = False + try: + self._cnx.source().execute("ROLLBACK") + except DatabaseError: + raise # database provides error message + except Exception as e: + raise op_error("Can't rollback transaction") from e + + def cursor(self) -> Cursor: + """Return a new cursor object using the connection.""" + if not self._cnx: + raise op_error("Connection has been closed") + try: + return self.cursor_type(self) + except Exception as e: + raise op_error("Invalid connection") from e + + if shortcutmethods: # otherwise do not implement and document this + + def execute(self, operation: str, + parameters: Sequence | None = None) -> Cursor: + """Shortcut method to run an operation on an implicit cursor.""" + cursor = self.cursor() + cursor.execute(operation, parameters) + return cursor + + def executemany(self, operation: str, + seq_of_parameters: Sequence[Sequence | None] + ) -> Cursor: + """Shortcut method to run an operation against a sequence.""" + cursor = self.cursor() + cursor.executemany(operation, seq_of_parameters) + return cursor \ No newline at end of file diff --git a/pgdb/constants.py b/pgdb/constants.py new file mode 100644 index 00000000..e6547f9c --- /dev/null +++ b/pgdb/constants.py @@ -0,0 +1,14 @@ +"""The DB API 2 module constants.""" + +# compliant with DB API 2.0 +apilevel = '2.0' + +# module may be shared, but not connections +threadsafety = 1 + +# this module use extended python format codes +paramstyle = 'pyformat' + +# shortcut methods have been excluded from DB API 2 and +# are not recommended by the DB SIG, but they can be handy +shortcutmethods = 1 diff --git a/pgdb/cursor.py b/pgdb/cursor.py new file mode 100644 index 00000000..753f4691 --- /dev/null +++ b/pgdb/cursor.py @@ -0,0 +1,645 @@ +"""The DB API 2 Cursor object.""" + +from __future__ import annotations + +from collections import namedtuple +from collections.abc import Iterable +from datetime import date, 
datetime, time, timedelta +from decimal import Decimal +from math import isinf, isnan +from typing import TYPE_CHECKING, Any, Callable, Generator, Mapping, Sequence +from uuid import UUID as Uuid # noqa: N811 + +from pg.core import ( + RESULT_DQL, + DatabaseError, + Error, + InterfaceError, + NotSupportedError, +) +from pg.core import Connection as Cnx +from pg.error import db_error, if_error, op_error +from pg.helpers import QuoteDict, RowCache + +from .adapt import Binary, Hstore, Json, Literal +from .cast import TypeCache +from .typecode import TypeCode + +if TYPE_CHECKING: + from .connection import Connection + +__all__ = ['Cursor', 'CursorDescription'] + + +class Cursor: + """Cursor object.""" + + def __init__(self, connection: Connection) -> None: + """Create a cursor object for the database connection.""" + self.connection = self._connection = connection + cnx = connection._cnx + if not cnx: + raise op_error("Connection has been closed") + self._cnx: Cnx = cnx + self.type_cache: TypeCache = connection.type_cache + self._src = self._cnx.source() + # the official attribute for describing the result columns + self._description: list[CursorDescription] | bool | None = None + if self.row_factory is Cursor.row_factory: + # the row factory needs to be determined dynamically + self.row_factory = None # type: ignore + else: + self.build_row_factory = None # type: ignore + self.rowcount: int | None = -1 + self.arraysize: int = 1 + self.lastrowid: int | None = None + + def __iter__(self) -> Cursor: + """Make cursor compatible to the iteration protocol.""" + return self + + def __enter__(self) -> Cursor: + """Enter the runtime context for the cursor object.""" + return self + + def __exit__(self, et: type[BaseException] | None, + ev: BaseException | None, tb: Any) -> None: + """Exit the runtime context for the cursor object.""" + self.close() + + def _quote(self, value: Any) -> Any: + """Quote value depending on its type.""" + if value is None: + return 'NULL' + if isinstance(value, (Hstore, Json)): + value = str(value) + if isinstance(value, (bytes, str)): + cnx = self._cnx + if isinstance(value, Binary): + value = cnx.escape_bytea(value).decode('ascii') + else: + value = cnx.escape_string(value) + return f"'{value}'" + if isinstance(value, float): + if isinf(value): + return "'-Infinity'" if value < 0 else "'Infinity'" + if isnan(value): + return "'NaN'" + return value + if isinstance(value, (int, Decimal, Literal)): + return value + if isinstance(value, datetime): + if value.tzinfo: + return f"'{value}'::timestamptz" + return f"'{value}'::timestamp" + if isinstance(value, date): + return f"'{value}'::date" + if isinstance(value, time): + if value.tzinfo: + return f"'{value}'::timetz" + return f"'{value}'::time" + if isinstance(value, timedelta): + return f"'{value}'::interval" + if isinstance(value, Uuid): + return f"'{value}'::uuid" + if isinstance(value, list): + # Quote value as an ARRAY constructor. This is better than using + # an array literal because it carries the information that this is + # an array and not a string. One issue with this syntax is that + # you need to add an explicit typecast when passing empty arrays. + # The ARRAY keyword is actually only necessary at the top level. + if not value: # exception for empty array + return "'{}'" + q = self._quote + v = ','.join(str(q(v)) for v in value) + return f'ARRAY[{v}]' + if isinstance(value, tuple): + # Quote as a ROW constructor. 
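+            # (For example, the tuple (1, 'abc') is rendered as the SQL
+            # expression (1,'abc'), with each element quoted recursively.)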
This is better than using a record + # literal because it carries the information that this is a record + # and not a string. We don't use the keyword ROW in order to make + # this usable with the IN syntax as well. It is only necessary + # when the records has a single column which is not really useful. + q = self._quote + v = ','.join(str(q(v)) for v in value) + return f'({v})' + try: # noinspection PyUnresolvedReferences + value = value.__pg_repr__() + except AttributeError as e: + raise InterfaceError( + f'Do not know how to adapt type {type(value)}') from e + if isinstance(value, (tuple, list)): + value = self._quote(value) + return value + + def _quoteparams(self, string: str, + parameters: Mapping | Sequence | None) -> str: + """Quote parameters. + + This function works for both mappings and sequences. + + The function should be used even when there are no parameters, + so that we have a consistent behavior regarding percent signs. + """ + if not parameters: + try: + return string % () # unescape literal quotes if possible + except (TypeError, ValueError): + return string # silently accept unescaped quotes + if isinstance(parameters, dict): + parameters = QuoteDict(parameters) + parameters.quote = self._quote + else: + parameters = tuple(map(self._quote, parameters)) + return string % parameters + + def _make_description(self, info: tuple[int, str, int, int, int] + ) -> CursorDescription: + """Make the description tuple for the given field info.""" + name, typ, size, mod = info[1:] + type_code = self.type_cache[typ] + if mod > 0: + mod -= 4 + precision: int | None + scale: int | None + if type_code == 'numeric': + precision, scale = mod >> 16, mod & 0xffff + size = precision + else: + if not size: + size = type_code.size + if size == -1: + size = mod + precision = scale = None + return CursorDescription( + name, type_code, None, size, precision, scale, None) + + @property + def description(self) -> list[CursorDescription] | None: + """Read-only attribute describing the result columns.""" + description = self._description + if description is None: + return None + if not isinstance(description, list): + make = self._make_description + description = [make(info) for info in self._src.listinfo()] + self._description = description + return description + + @property + def colnames(self) -> Sequence[str] | None: + """Unofficial convenience method for getting the column names.""" + description = self.description + return None if description is None else [d[0] for d in description] + + @property + def coltypes(self) -> Sequence[TypeCode] | None: + """Unofficial convenience method for getting the column types.""" + description = self.description + return None if description is None else [d[1] for d in description] + + def close(self) -> None: + """Close the cursor object.""" + self._src.close() + + def execute(self, operation: str, parameters: Sequence | None = None + ) -> Cursor: + """Prepare and execute a database operation (query or command).""" + # The parameters may also be specified as list of tuples to e.g. + # insert multiple rows in a single operation, but this kind of + # usage is deprecated. We make several plausibility checks because + # tuples can also be passed with the meaning of ROW constructors. 
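+        # For example, [(1, 'a'), (2, 'b')] is treated as two parameter
+        # rows (deprecated executemany-style usage), while a single tuple
+        # such as (1, 'a') is passed on as one parameter row.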
+ if (parameters and isinstance(parameters, list) + and len(parameters) > 1 + and all(isinstance(p, tuple) for p in parameters) + and all(len(p) == len(parameters[0]) for p in parameters[1:])): + return self.executemany(operation, parameters) + # not a list of tuples + return self.executemany(operation, [parameters]) + + def executemany(self, operation: str, + seq_of_parameters: Sequence[Sequence | None]) -> Cursor: + """Prepare operation and execute it against a parameter sequence.""" + if not seq_of_parameters: + # don't do anything without parameters + return self + self._description = None + self.rowcount = -1 + # first try to execute all queries + rowcount = 0 + sql = "BEGIN" + try: + if not self._connection._tnx and not self._connection.autocommit: + try: + self._src.execute(sql) + except DatabaseError: + raise # database provides error message + except Exception as e: + raise op_error("Can't start transaction") from e + else: + self._connection._tnx = True + for parameters in seq_of_parameters: + sql = operation + sql = self._quoteparams(sql, parameters) + rows = self._src.execute(sql) + if rows: # true if not DML + rowcount += rows + else: + self.rowcount = -1 + except DatabaseError: + raise # database provides error message + except Error as err: + # noinspection PyTypeChecker + raise if_error(f"Error in '{sql}': '{err}'") from err + except Exception as err: + raise op_error(f"Internal error in '{sql}': {err}") from err + # then initialize result raw count and description + if self._src.resulttype == RESULT_DQL: + self._description = True # fetch on demand + self.rowcount = self._src.ntuples + self.lastrowid = None + build_row_factory = self.build_row_factory + if build_row_factory: # type: ignore + self.row_factory = build_row_factory() # type: ignore + else: + self.rowcount = rowcount + self.lastrowid = self._src.oidstatus() + # return the cursor object, so you can write statements such as + # "cursor.execute(...).fetchall()" or "for row in cursor.execute(...)" + return self + + def fetchone(self) -> Sequence | None: + """Fetch the next row of a query result set.""" + res = self.fetchmany(1, False) + try: + return res[0] + except IndexError: + return None + + def fetchall(self) -> Sequence[Sequence]: + """Fetch all (remaining) rows of a query result.""" + return self.fetchmany(-1, False) + + def fetchmany(self, size: int | None = None, keep: bool = False + ) -> Sequence[Sequence]: + """Fetch the next set of rows of a query result. + + The number of rows to fetch per call is specified by the + size parameter. If it is not given, the cursor's arraysize + determines the number of rows to be fetched. If you set + the keep parameter to true, this is kept as new arraysize. 
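+
+        A hedged usage sketch:
+
+            rows = cur.fetchmany(100)       # fetch up to 100 rows
+            rows = cur.fetchmany(50, True)  # fetch and set arraysize to 50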
+ """ + if size is None: + size = self.arraysize + if keep: + self.arraysize = size + try: + result = self._src.fetch(size) + except DatabaseError: + raise + except Error as err: + raise db_error(str(err)) from err + row_factory = self.row_factory + coltypes = self.coltypes + if coltypes is None: + # cannot determine column types, return raw result + return [row_factory(row) for row in result] + if len(result) > 5: + # optimize the case where we really fetch many values + # by looking up all type casting functions upfront + cast_row = self.type_cache.get_row_caster(coltypes) + return [row_factory(cast_row(row)) for row in result] + cast_value = self.type_cache.typecast + return [row_factory([cast_value(value, typ) + for typ, value in zip(coltypes, row)]) for row in result] + + def callproc(self, procname: str, parameters: Sequence | None = None + ) -> Sequence | None: + """Call a stored database procedure with the given name. + + The sequence of parameters must contain one entry for each input + argument that the procedure expects. The result of the call is the + same as this input sequence; replacement of output and input/output + parameters in the return value is currently not supported. + + The procedure may also provide a result set as output. These can be + requested through the standard fetch methods of the cursor. + """ + n = len(parameters) if parameters else 0 + s = ','.join(n * ['%s']) + query = f'select * from "{procname}"({s})' # noqa: S608 + self.execute(query, parameters) + return parameters + + # noinspection PyShadowingBuiltins + def copy_from(self, stream: Any, table: str, + format: str | None = None, sep: str | None = None, + null: str | None = None, size: int | None = None, + columns: Sequence[str] | None = None) -> Cursor: + """Copy data from an input stream to the specified table. + + The input stream can be a file-like object with a read() method or + it can also be an iterable returning a row or multiple rows of input + on each iteration. + + The format must be 'text', 'csv' or 'binary'. The sep option sets the + column separator (delimiter) used in the non binary formats. + The null option sets the textual representation of NULL in the input. + + The size option sets the size of the buffer used when reading data + from file-like objects. + + The copy operation can be restricted to a subset of columns. If no + columns are specified, all of them will be copied. + """ + binary_format = format == 'binary' + try: + read = stream.read + except AttributeError as e: + if size: + raise ValueError( + "Size must only be set for file-like objects") from e + input_type: type | tuple[type, ...] 
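+            # note: binary copy requires raw bytes chunks, while text
+            # and csv input may be given as str or bytes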
+ type_name: str + if binary_format: + input_type = bytes + type_name = 'byte strings' + else: + input_type = (bytes, str) + type_name = 'strings' + + if isinstance(stream, (bytes, str)): + if not isinstance(stream, input_type): + raise ValueError(f"The input must be {type_name}") from e + if not binary_format: + if isinstance(stream, str): + if not stream.endswith('\n'): + stream += '\n' + else: + if not stream.endswith(b'\n'): + stream += b'\n' + + def chunks() -> Generator: + yield stream + + elif isinstance(stream, Iterable): + + def chunks() -> Generator: + for chunk in stream: + if not isinstance(chunk, input_type): + raise ValueError( + f"Input stream must consist of {type_name}") + if isinstance(chunk, str): + if not chunk.endswith('\n'): + chunk += '\n' + else: + if not chunk.endswith(b'\n'): + chunk += b'\n' + yield chunk + + else: + raise TypeError("Need an input stream to copy from") from e + else: + if size is None: + size = 8192 + elif not isinstance(size, int): + raise TypeError("The size option must be an integer") + if size > 0: + + def chunks() -> Generator: + while True: + buffer = read(size) + yield buffer + if not buffer or len(buffer) < size: + break + + else: + + def chunks() -> Generator: + yield read() + + if not table or not isinstance(table, str): + raise TypeError("Need a table to copy to") + if table.lower().startswith('select '): + raise ValueError("Must specify a table, not a query") + cnx = self._cnx + table = '.'.join(map(cnx.escape_identifier, table.split('.', 1))) + operation_parts = [f'copy {table}'] + options = [] + parameters = [] + if format is not None: + if not isinstance(format, str): + raise TypeError("The format option must be be a string") + if format not in ('text', 'csv', 'binary'): + raise ValueError("Invalid format") + options.append(f'format {format}') + if sep is not None: + if not isinstance(sep, str): + raise TypeError("The sep option must be a string") + if format == 'binary': + raise ValueError( + "The sep option is not allowed with binary format") + if len(sep) != 1: + raise ValueError( + "The sep option must be a single one-byte character") + options.append('delimiter %s') + parameters.append(sep) + if null is not None: + if not isinstance(null, str): + raise TypeError("The null option must be a string") + options.append('null %s') + parameters.append(null) + if columns: + if not isinstance(columns, str): + columns = ','.join(map(cnx.escape_identifier, columns)) + operation_parts.append(f'({columns})') + operation_parts.append("from stdin") + if options: + operation_parts.append(f"({','.join(options)})") + operation = ' '.join(operation_parts) + + putdata = self._src.putdata + self.execute(operation, parameters) + + try: + for chunk in chunks(): + putdata(chunk) + except BaseException as error: + self.rowcount = -1 + # the following call will re-raise the error + putdata(error) + else: + rowcount = putdata(None) + self.rowcount = -1 if rowcount is None else rowcount + + # return the cursor object, so you can chain operations + return self + + # noinspection PyShadowingBuiltins + def copy_to(self, stream: Any, table: str, + format: str | None = None, sep: str | None = None, + null: str | None = None, decode: bool | None = None, + columns: Sequence[str] | None = None) -> Cursor | Generator: + """Copy data from the specified table to an output stream. + + The output stream can be a file-like object with a write() method or + it can also be None, in which case the method will return a generator + yielding a row on each iteration. 
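+
+        A hedged usage sketch (file and table names are made up):
+
+            with open('out.csv', 'w') as f:
+                cur.copy_to(f, 'mytable', format='csv')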
+ + Output will be returned as byte strings unless you set decode to true. + + Note that you can also use a select query instead of the table name. + + The format must be 'text', 'csv' or 'binary'. The sep option sets the + column separator (delimiter) used in the non binary formats. + The null option sets the textual representation of NULL in the output. + + The copy operation can be restricted to a subset of columns. If no + columns are specified, all of them will be copied. + """ + binary_format = format == 'binary' + if stream is None: + write = None + else: + try: + write = stream.write + except AttributeError as e: + raise TypeError("Need an output stream to copy to") from e + if not table or not isinstance(table, str): + raise TypeError("Need a table to copy to") + cnx = self._cnx + if table.lower().startswith('select '): + if columns: + raise ValueError("Columns must be specified in the query") + table = f'({table})' + else: + table = '.'.join(map(cnx.escape_identifier, table.split('.', 1))) + operation_parts = [f'copy {table}'] + options = [] + parameters = [] + if format is not None: + if not isinstance(format, str): + raise TypeError("The format option must be a string") + if format not in ('text', 'csv', 'binary'): + raise ValueError("Invalid format") + options.append(f'format {format}') + if sep is not None: + if not isinstance(sep, str): + raise TypeError("The sep option must be a string") + if binary_format: + raise ValueError( + "The sep option is not allowed with binary format") + if len(sep) != 1: + raise ValueError( + "The sep option must be a single one-byte character") + options.append('delimiter %s') + parameters.append(sep) + if null is not None: + if not isinstance(null, str): + raise TypeError("The null option must be a string") + options.append('null %s') + parameters.append(null) + if decode is None: + decode = format != 'binary' + else: + if not isinstance(decode, (int, bool)): + raise TypeError("The decode option must be a boolean") + if decode and binary_format: + raise ValueError( + "The decode option is not allowed with binary format") + if columns: + if not isinstance(columns, str): + columns = ','.join(map(cnx.escape_identifier, columns)) + operation_parts.append(f'({columns})') + + operation_parts.append("to stdout") + if options: + operation_parts.append(f"({','.join(options)})") + operation = ' '.join(operation_parts) + + getdata = self._src.getdata + self.execute(operation, parameters) + + def copy() -> Generator: + self.rowcount = 0 + while True: + row = getdata(decode) + if isinstance(row, int): + if self.rowcount != row: + self.rowcount = row + break + self.rowcount += 1 + yield row + + if write is None: + # no input stream, return the generator + return copy() + + # write the rows to the file-like input stream + for row in copy(): + # noinspection PyUnboundLocalVariable + write(row) + + # return the cursor object, so you can chain operations + return self + + def __next__(self) -> Sequence: + """Return the next row (support for the iteration protocol).""" + res = self.fetchone() + if res is None: + raise StopIteration + return res + + # Note that the iterator protocol now uses __next()__ instead of next(), + # but we keep it for backward compatibility of pgdb. 
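+    # (thus both ``for row in cur`` and the legacy ``cur.next()`` work)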
+ next = __next__ + + @staticmethod + def nextset() -> bool | None: + """Not supported.""" + raise NotSupportedError("The nextset() method is not supported") + + @staticmethod + def setinputsizes(sizes: Sequence[int]) -> None: + """Not supported.""" + pass # unsupported, but silently passed + + @staticmethod + def setoutputsize(size: int, column: int = 0) -> None: + """Not supported.""" + pass # unsupported, but silently passed + + @staticmethod + def row_factory(row: Sequence) -> Sequence: + """Process rows before they are returned. + + You can overwrite this statically with a custom row factory, or + you can build a row factory dynamically with build_row_factory(). + + For example, you can create a Cursor class that returns rows as + Python dictionaries like this: + + class DictCursor(pgdb.Cursor): + + def row_factory(self, row): + return {desc[0]: value + for desc, value in zip(self.description, row)} + + cur = DictCursor(con) # get one DictCursor instance or + con.cursor_type = DictCursor # always use DictCursor instances + """ + raise NotImplementedError + + def build_row_factory(self) -> Callable[[Sequence], Sequence] | None: + """Build a row factory based on the current description. + + This implementation builds a row factory for creating named tuples. + You can overwrite this method if you want to dynamically create + different row factories whenever the column description changes. + """ + names = self.colnames + return RowCache.row_factory(tuple(names)) if names else None + + +CursorDescription = namedtuple('CursorDescription', ( + 'name', 'type_code', 'display_size', 'internal_size', + 'precision', 'scale', 'null_ok')) diff --git a/pgdb/py.typed b/pgdb/py.typed new file mode 100644 index 00000000..ead52d46 --- /dev/null +++ b/pgdb/py.typed @@ -0,0 +1 @@ +# Marker file for PEP 561. The pgdb package uses inline types. diff --git a/pgdb/typecode.py b/pgdb/typecode.py new file mode 100644 index 00000000..fcfb4620 --- /dev/null +++ b/pgdb/typecode.py @@ -0,0 +1,34 @@ +"""Support for DB API 2 type codes.""" + +from __future__ import annotations + +__all__ = ['TypeCode'] + + +class TypeCode(str): + """Class representing the type_code used by the DB-API 2.0. + + TypeCode objects are strings equal to the PostgreSQL type name, + but carry some additional information. + """ + + oid: int + len: int + type: str + category: str + delim: str + relid: int + + # noinspection PyShadowingBuiltins + @classmethod + def create(cls, oid: int, name: str, len: int, type: str, category: str, + delim: str, relid: int) -> TypeCode: + """Create a type code for a PostgreSQL data type.""" + self = cls(name) + self.oid = oid + self.len = len + self.type = type + self.category = category + self.delim = delim + self.relid = relid + return self \ No newline at end of file diff --git a/pgnotice.c b/pgnotice.c deleted file mode 100644 index 7f0c0cc4..00000000 --- a/pgnotice.c +++ /dev/null @@ -1,124 +0,0 @@ -/* - * PyGreSQL - a Python interface for the PostgreSQL database. - * - * The notice object - this file is part a of the C extension module. - * - * Copyright (c) 2020 by the PyGreSQL Development Team - * - * Please see the LICENSE.TXT file for specific restrictions. - */ - -/* Get notice object attributes. 
*/ -static PyObject * -notice_getattr(noticeObject *self, PyObject *nameobj) -{ - PGresult const *res = self->res; - const char *name = PyStr_AsString(nameobj); - int fieldcode; - - if (!res) { - PyErr_SetString(PyExc_TypeError, "Cannot get current notice"); - return NULL; - } - - /* pg connection object */ - if (!strcmp(name, "pgcnx")) { - if (self->pgcnx && _check_cnx_obj(self->pgcnx)) { - Py_INCREF(self->pgcnx); - return (PyObject *) self->pgcnx; - } - else { - Py_INCREF(Py_None); - return Py_None; - } - } - - /* full message */ - if (!strcmp(name, "message")) { - return PyStr_FromString(PQresultErrorMessage(res)); - } - - /* other possible fields */ - fieldcode = 0; - if (!strcmp(name, "severity")) - fieldcode = PG_DIAG_SEVERITY; - else if (!strcmp(name, "primary")) - fieldcode = PG_DIAG_MESSAGE_PRIMARY; - else if (!strcmp(name, "detail")) - fieldcode = PG_DIAG_MESSAGE_DETAIL; - else if (!strcmp(name, "hint")) - fieldcode = PG_DIAG_MESSAGE_HINT; - if (fieldcode) { - char *s = PQresultErrorField(res, fieldcode); - if (s) { - return PyStr_FromString(s); - } - else { - Py_INCREF(Py_None); return Py_None; - } - } - - return PyObject_GenericGetAttr((PyObject *) self, nameobj); -} - -/* Get the list of notice attributes. */ -static PyObject * -notice_dir(noticeObject *self, PyObject *noargs) -{ - PyObject *attrs; - - attrs = PyObject_Dir(PyObject_Type((PyObject *) self)); - PyObject_CallMethod( - attrs, "extend", "[ssssss]", - "pgcnx", "severity", "message", "primary", "detail", "hint"); - - return attrs; -} - -/* Return notice as string in human readable form. */ -static PyObject * -notice_str(noticeObject *self) -{ - return notice_getattr(self, PyBytes_FromString("message")); -} - -/* Notice object methods */ -static struct PyMethodDef notice_methods[] = { - {"__dir__", (PyCFunction) notice_dir, METH_NOARGS, NULL}, - {NULL, NULL} -}; - -static char notice__doc__[] = "PostgreSQL notice object"; - -/* Notice type definition */ -static PyTypeObject noticeType = { - PyVarObject_HEAD_INIT(NULL, 0) - "pg.Notice", /* tp_name */ - sizeof(noticeObject), /* tp_basicsize */ - 0, /* tp_itemsize */ - /* methods */ - 0, /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_compare */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - (reprfunc) notice_str, /* tp_str */ - (getattrofunc) notice_getattr, /* tp_getattro */ - PyObject_GenericSetAttr, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - notice__doc__, /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - notice_methods, /* tp_methods */ -}; diff --git a/pgquery.c b/pgquery.c deleted file mode 100644 index d90db5dc..00000000 --- a/pgquery.c +++ /dev/null @@ -1,746 +0,0 @@ -/* - * PyGreSQL - a Python interface for the PostgreSQL database. - * - * The query object - this file is part a of the C extension module. - * - * Copyright (c) 2020 by the PyGreSQL Development Team - * - * Please see the LICENSE.TXT file for specific restrictions. - */ - -/* Deallocate the query object. */ -static void -query_dealloc(queryObject *self) -{ - Py_XDECREF(self->pgcnx); - if (self->col_types) { - PyMem_Free(self->col_types); - } - if (self->result) { - PQclear(self->result); - } - - PyObject_Del(self); -} - -/* Return query as string in human readable form. 
*/ -static PyObject * -query_str(queryObject *self) -{ - return format_result(self->result); -} - -/* Return length of a query object. */ -static Py_ssize_t -query_len(PyObject *self) -{ - PyObject *tmp; - Py_ssize_t len; - - tmp = PyLong_FromLong(((queryObject*) self)->max_row); - len = PyLong_AsSsize_t(tmp); - Py_DECREF(tmp); - return len; -} - -/* Return the value in the given column of the current row. */ -static PyObject * -_query_value_in_column(queryObject *self, int column) -{ - char *s; - int type; - - if (PQgetisnull(self->result, self->current_row, column)) { - Py_INCREF(Py_None); - return Py_None; - } - - /* get the string representation of the value */ - /* note: this is always null-terminated text format */ - s = PQgetvalue(self->result, self->current_row, column); - /* get the PyGreSQL type of the column */ - type = self->col_types[column]; - - /* cast the string representation into a Python object */ - if (type & PYGRES_ARRAY) - return cast_array(s, - PQgetlength(self->result, self->current_row, column), - self->encoding, type, NULL, 0); - if (type == PYGRES_BYTEA) - return cast_bytea_text(s); - if (type == PYGRES_OTHER) - return cast_other(s, - PQgetlength(self->result, self->current_row, column), - self->encoding, - PQftype(self->result, column), self->pgcnx->cast_hook); - if (type & PYGRES_TEXT) - return cast_sized_text(s, - PQgetlength(self->result, self->current_row, column), - self->encoding, type); - return cast_unsized_simple(s, type); -} - -/* Return the current row as a tuple. */ -static PyObject * -_query_row_as_tuple(queryObject *self) -{ - PyObject *row_tuple = NULL; - int j; - - if (!(row_tuple = PyTuple_New(self->num_fields))) { - return NULL; - } - - for (j = 0; j < self->num_fields; ++j) { - PyObject *val = _query_value_in_column(self, j); - if (!val) { - Py_DECREF(row_tuple); return NULL; - } - PyTuple_SET_ITEM(row_tuple, j, val); - } - - return row_tuple; -} - -/* Return given item from a query object. */ -static PyObject * -query_getitem(PyObject *self, Py_ssize_t i) -{ - queryObject *q = (queryObject *) self; - PyObject *tmp; - long row; - - tmp = PyLong_FromSize_t((size_t) i); - row = PyLong_AsLong(tmp); - Py_DECREF(tmp); - - if (row < 0 || row >= q->max_row) { - PyErr_SetNone(PyExc_IndexError); - return NULL; - } - - q->current_row = (int) row; - return _query_row_as_tuple(q); -} - -/* __iter__() method of the queryObject: - Returns the default iterator yielding rows as tuples. */ -static PyObject* query_iter(queryObject *self) -{ - self->current_row = 0; - Py_INCREF(self); - return (PyObject*) self; -} - -/* __next__() method of the queryObject: - Returns the current current row as a tuple and moves to the next one. */ -static PyObject * -query_next(queryObject *self, PyObject *noargs) -{ - PyObject *row_tuple = NULL; - - if (self->current_row >= self->max_row) { - PyErr_SetNone(PyExc_StopIteration); - return NULL; - } - - row_tuple = _query_row_as_tuple(self); - if (row_tuple) ++self->current_row; - return row_tuple; -} - -/* Get number of rows. */ -static char query_ntuples__doc__[] = -"ntuples() -- return number of tuples returned by query"; - -static PyObject * -query_ntuples(queryObject *self, PyObject *noargs) -{ - return PyInt_FromLong(self->max_row); -} - -/* List field names from query result. 
*/ -static char query_listfields__doc__[] = -"listfields() -- List field names from result"; - -static PyObject * -query_listfields(queryObject *self, PyObject *noargs) -{ - int i; - char *name; - PyObject *fieldstuple, *str; - - /* builds tuple */ - fieldstuple = PyTuple_New(self->num_fields); - if (fieldstuple) { - for (i = 0; i < self->num_fields; ++i) { - name = PQfname(self->result, i); - str = PyStr_FromString(name); - PyTuple_SET_ITEM(fieldstuple, i, str); - } - } - return fieldstuple; -} - -/* Get field name from number in last result. */ -static char query_fieldname__doc__[] = -"fieldname(num) -- return name of field from result from its position"; - -static PyObject * -query_fieldname(queryObject *self, PyObject *args) -{ - int i; - char *name; - - /* gets args */ - if (!PyArg_ParseTuple(args, "i", &i)) { - PyErr_SetString(PyExc_TypeError, - "Method fieldname() takes an integer as argument"); - return NULL; - } - - /* checks number validity */ - if (i >= self->num_fields) { - PyErr_SetString(PyExc_ValueError, "Invalid field number"); - return NULL; - } - - /* gets fields name and builds object */ - name = PQfname(self->result, i); - return PyStr_FromString(name); -} - -/* Get field number from name in last result. */ -static char query_fieldnum__doc__[] = -"fieldnum(name) -- return position in query for field from its name"; - -static PyObject * -query_fieldnum(queryObject *self, PyObject *args) -{ - int num; - char *name; - - /* gets args */ - if (!PyArg_ParseTuple(args, "s", &name)) { - PyErr_SetString(PyExc_TypeError, - "Method fieldnum() takes a string as argument"); - return NULL; - } - - /* gets field number */ - if ((num = PQfnumber(self->result, name)) == -1) { - PyErr_SetString(PyExc_ValueError, "Unknown field"); - return NULL; - } - - return PyInt_FromLong(num); -} - -/* Retrieve one row from the result as a tuple. */ -static char query_one__doc__[] = -"one() -- Get one row from the result of a query\n\n" -"Only one row from the result is returned as a tuple of fields.\n" -"This method can be called multiple times to return more rows.\n" -"It returns None if the result does not contain one more row.\n"; - -static PyObject * -query_one(queryObject *self, PyObject *noargs) -{ - PyObject *row_tuple; - - if (self->current_row >= self->max_row) { - Py_INCREF(Py_None); return Py_None; - } - - row_tuple = _query_row_as_tuple(self); - if (row_tuple) ++self->current_row; - return row_tuple; -} - -/* Retrieve the single row from the result as a tuple. */ -static char query_single__doc__[] = -"single() -- Get the result of a query as single row\n\n" -"The single row from the query result is returned as a tuple of fields.\n" -"This method returns the same single row when called multiple times.\n" -"It raises an InvalidResultError if the result doesn't have exactly one row,\n" -"which will be of type NoResultError or MultipleResultsError specifically.\n"; - -static PyObject * -query_single(queryObject *self, PyObject *noargs) -{ - PyObject *row_tuple; - - if (self->max_row != 1) { - if (self->max_row) - set_error_msg(MultipleResultsError, "Multiple results found"); - else - set_error_msg(NoResultError, "No result found"); - return NULL; - } - - self->current_row = 0; - row_tuple = _query_row_as_tuple(self); - if (row_tuple) ++self->current_row; - return row_tuple; -} - -/* Retrieve the last query result as a list of tuples. 
*/ -static char query_getresult__doc__[] = -"getresult() -- Get the result of a query\n\n" -"The result is returned as a list of rows, each one a tuple of fields\n" -"in the order returned by the server.\n"; - -static PyObject * -query_getresult(queryObject *self, PyObject *noargs) -{ - PyObject *result_list; - int i; - - if (!(result_list = PyList_New(self->max_row))) { - return NULL; - } - - for (i = self->current_row = 0; i < self->max_row; ++i) { - PyObject *row_tuple = query_next(self, noargs); - - if (!row_tuple) { - Py_DECREF(result_list); return NULL; - } - PyList_SET_ITEM(result_list, i, row_tuple); - } - - return result_list; -} - -/* Return the current row as a dict. */ -static PyObject * -_query_row_as_dict(queryObject *self) -{ - PyObject *row_dict = NULL; - int j; - - if (!(row_dict = PyDict_New())) { - return NULL; - } - - for (j = 0; j < self->num_fields; ++j) { - PyObject *val = _query_value_in_column(self, j); - - if (!val) { - Py_DECREF(row_dict); return NULL; - } - PyDict_SetItemString(row_dict, PQfname(self->result, j), val); - Py_DECREF(val); - } - - return row_dict; -} - -/* Return the current current row as a dict and move to the next one. */ -static PyObject * -query_next_dict(queryObject *self, PyObject *noargs) -{ - PyObject *row_dict = NULL; - - if (self->current_row >= self->max_row) { - PyErr_SetNone(PyExc_StopIteration); - return NULL; - } - - row_dict = _query_row_as_dict(self); - if (row_dict) ++self->current_row; - return row_dict; -} - -/* Retrieve one row from the result as a dictionary. */ -static char query_onedict__doc__[] = -"onedict() -- Get one row from the result of a query\n\n" -"Only one row from the result is returned as a dictionary with\n" -"the field names used as the keys.\n" -"This method can be called multiple times to return more rows.\n" -"It returns None if the result does not contain one more row.\n"; - -static PyObject * -query_onedict(queryObject *self, PyObject *noargs) -{ - PyObject *row_dict; - - if (self->current_row >= self->max_row) { - Py_INCREF(Py_None); return Py_None; - } - - row_dict = _query_row_as_dict(self); - if (row_dict) ++self->current_row; - return row_dict; -} - -/* Retrieve the single row from the result as a dictionary. */ -static char query_singledict__doc__[] = -"singledict() -- Get the result of a query as single row\n\n" -"The single row from the query result is returned as a dictionary with\n" -"the field names used as the keys.\n" -"This method returns the same single row when called multiple times.\n" -"It raises an InvalidResultError if the result doesn't have exactly one row,\n" -"which will be of type NoResultError or MultipleResultsError specifically.\n"; - -static PyObject * -query_singledict(queryObject *self, PyObject *noargs) -{ - PyObject *row_dict; - - if (self->max_row != 1) { - if (self->max_row) - set_error_msg(MultipleResultsError, "Multiple results found"); - else - set_error_msg(NoResultError, "No result found"); - return NULL; - } - - self->current_row = 0; - row_dict = _query_row_as_dict(self); - if (row_dict) ++self->current_row; - return row_dict; -} - -/* Retrieve the last query result as a list of dictionaries. 
*/ -static char query_dictresult__doc__[] = -"dictresult() -- Get the result of a query\n\n" -"The result is returned as a list of rows, each one a dictionary with\n" -"the field names used as the keys.\n"; - -static PyObject * -query_dictresult(queryObject *self, PyObject *noargs) -{ - PyObject *result_list; - int i; - - if (!(result_list = PyList_New(self->max_row))) { - return NULL; - } - - for (i = self->current_row = 0; i < self->max_row; ++i) { - PyObject *row_dict = query_next_dict(self, noargs); - - if (!row_dict) { - Py_DECREF(result_list); return NULL; - } - PyList_SET_ITEM(result_list, i, row_dict); - } - - return result_list; -} - -/* Retrieve last result as iterator of dictionaries. */ -static char query_dictiter__doc__[] = -"dictiter() -- Get the result of a query\n\n" -"The result is returned as an iterator of rows, each one a a dictionary\n" -"with the field names used as the keys.\n"; - -static PyObject * -query_dictiter(queryObject *self, PyObject *noargs) -{ - if (!dictiter) { - return query_dictresult(self, noargs); - } - - return PyObject_CallFunction(dictiter, "(O)", self); -} - -/* Retrieve one row from the result as a named tuple. */ -static char query_onenamed__doc__[] = -"onenamed() -- Get one row from the result of a query\n\n" -"Only one row from the result is returned as a named tuple of fields.\n" -"This method can be called multiple times to return more rows.\n" -"It returns None if the result does not contain one more row.\n"; - -static PyObject * -query_onenamed(queryObject *self, PyObject *noargs) -{ - if (!namednext) { - return query_one(self, noargs); - } - - if (self->current_row >= self->max_row) { - Py_INCREF(Py_None); return Py_None; - } - - return PyObject_CallFunction(namednext, "(O)", self); -} - -/* Retrieve the single row from the result as a tuple. */ -static char query_singlenamed__doc__[] = -"singlenamed() -- Get the result of a query as single row\n\n" -"The single row from the query result is returned as named tuple of fields.\n" -"This method returns the same single row when called multiple times.\n" -"It raises an InvalidResultError if the result doesn't have exactly one row,\n" -"which will be of type NoResultError or MultipleResultsError specifically.\n"; - -static PyObject * -query_singlenamed(queryObject *self, PyObject *noargs) -{ - if (!namednext) { - return query_single(self, noargs); - } - - if (self->max_row != 1) { - if (self->max_row) - set_error_msg(MultipleResultsError, "Multiple results found"); - else - set_error_msg(NoResultError, "No result found"); - return NULL; - } - - self->current_row = 0; - return PyObject_CallFunction(namednext, "(O)", self); -} - -/* Retrieve last result as list of named tuples. */ -static char query_namedresult__doc__[] = -"namedresult() -- Get the result of a query\n\n" -"The result is returned as a list of rows, each one a named tuple of fields\n" -"in the order returned by the server.\n"; - -static PyObject * -query_namedresult(queryObject *self, PyObject *noargs) -{ - PyObject *res, *res_list; - - if (!namediter) { - return query_getresult(self, noargs); - } - - res = PyObject_CallFunction(namediter, "(O)", self); - if (!res) return NULL; - if (PyList_Check(res)) return res; - res_list = PySequence_List(res); - Py_DECREF(res); - return res_list; -} - -/* Retrieve last result as iterator of named tuples. 
*/ -static char query_namediter__doc__[] = -"namediter() -- Get the result of a query\n\n" -"The result is returned as an iterator of rows, each one a named tuple\n" -"of fields in the order returned by the server.\n"; - -static PyObject * -query_namediter(queryObject *self, PyObject *noargs) -{ - PyObject *res, *res_iter; - - if (!namediter) { - return query_iter(self); - } - - res = PyObject_CallFunction(namediter, "(O)", self); - if (!res) return NULL; - if (!PyList_Check(res)) return res; - res_iter = (Py_TYPE(res)->tp_iter)((PyObject *) self); - Py_DECREF(res); - return res_iter; -} - -/* Retrieve the last query result as a list of scalar values. */ -static char query_scalarresult__doc__[] = -"scalarresult() -- Get query result as scalars\n\n" -"The result is returned as a list of scalar values where the values\n" -"are the first fields of the rows in the order returned by the server.\n"; - -static PyObject * -query_scalarresult(queryObject *self, PyObject *noargs) -{ - PyObject *result_list; - - if (!self->num_fields) { - set_error_msg(ProgrammingError, "No fields in result"); - return NULL; - } - - if (!(result_list = PyList_New(self->max_row))) { - return NULL; - } - - for (self->current_row = 0; - self->current_row < self->max_row; - ++self->current_row) - { - PyObject *value = _query_value_in_column(self, 0); - - if (!value) { - Py_DECREF(result_list); return NULL; - } - PyList_SET_ITEM(result_list, self->current_row, value); - } - - return result_list; -} - -/* Retrieve the last query result as iterator of scalar values. */ -static char query_scalariter__doc__[] = -"scalariter() -- Get query result as scalars\n\n" -"The result is returned as an iterator of scalar values where the values\n" -"are the first fields of the rows in the order returned by the server.\n"; - -static PyObject * -query_scalariter(queryObject *self, PyObject *noargs) -{ - if (!scalariter) { - return query_scalarresult(self, noargs); - } - - if (!self->num_fields) { - set_error_msg(ProgrammingError, "No fields in result"); - return NULL; - } - - return PyObject_CallFunction(scalariter, "(O)", self); -} - -/* Retrieve one result as scalar value. */ -static char query_onescalar__doc__[] = -"onescalar() -- Get one scalar value from the result of a query\n\n" -"Returns the first field of the next row from the result as a scalar value.\n" -"This method can be called multiple times to return more rows as scalars.\n" -"It returns None if the result does not contain one more row.\n"; - -static PyObject * -query_onescalar(queryObject *self, PyObject *noargs) -{ - PyObject *value; - - if (!self->num_fields) { - set_error_msg(ProgrammingError, "No fields in result"); - return NULL; - } - - if (self->current_row >= self->max_row) { - Py_INCREF(Py_None); return Py_None; - } - - value = _query_value_in_column(self, 0); - if (value) ++self->current_row; - return value; -} - -/* Retrieves the single row from the result as a tuple. 
*/ -static char query_singlescalar__doc__[] = -"singlescalar() -- Get scalar value from single result of a query\n\n" -"Returns the first field of the next row from the result as a scalar value.\n" -"This method returns the same single row when called multiple times.\n" -"It raises an InvalidResultError if the result doesn't have exactly one row,\n" -"which will be of type NoResultError or MultipleResultsError specifically.\n"; - -static PyObject * -query_singlescalar(queryObject *self, PyObject *noargs) -{ - PyObject *value; - - if (!self->num_fields) { - set_error_msg(ProgrammingError, "No fields in result"); - return NULL; - } - - if (self->max_row != 1) { - if (self->max_row) - set_error_msg(MultipleResultsError, "Multiple results found"); - else - set_error_msg(NoResultError, "No result found"); - return NULL; - } - - self->current_row = 0; - value = _query_value_in_column(self, 0); - if (value) ++self->current_row; - return value; -} - -/* Query sequence protocol methods */ -static PySequenceMethods query_sequence_methods = { - (lenfunc) query_len, /* sq_length */ - 0, /* sq_concat */ - 0, /* sq_repeat */ - (ssizeargfunc) query_getitem, /* sq_item */ - 0, /* sq_ass_item */ - 0, /* sq_contains */ - 0, /* sq_inplace_concat */ - 0, /* sq_inplace_repeat */ -}; - -/* Query object methods */ -static struct PyMethodDef query_methods[] = { - {"getresult", (PyCFunction) query_getresult, - METH_NOARGS, query_getresult__doc__}, - {"dictresult", (PyCFunction) query_dictresult, - METH_NOARGS, query_dictresult__doc__}, - {"dictiter", (PyCFunction) query_dictiter, - METH_NOARGS, query_dictiter__doc__}, - {"namedresult", (PyCFunction) query_namedresult, - METH_NOARGS, query_namedresult__doc__}, - {"namediter", (PyCFunction) query_namediter, - METH_NOARGS, query_namediter__doc__}, - {"one", (PyCFunction) query_one, - METH_NOARGS, query_one__doc__}, - {"single", (PyCFunction) query_single, - METH_NOARGS, query_single__doc__}, - {"onedict", (PyCFunction) query_onedict, - METH_NOARGS, query_onedict__doc__}, - {"singledict", (PyCFunction) query_singledict, - METH_NOARGS, query_singledict__doc__}, - {"onenamed", (PyCFunction) query_onenamed, - METH_NOARGS, query_onenamed__doc__}, - {"singlenamed", (PyCFunction) query_singlenamed, - METH_NOARGS, query_singlenamed__doc__}, - {"scalarresult", (PyCFunction) query_scalarresult, - METH_NOARGS, query_scalarresult__doc__}, - {"scalariter", (PyCFunction) query_scalariter, - METH_NOARGS, query_scalariter__doc__}, - {"onescalar", (PyCFunction) query_onescalar, - METH_NOARGS, query_onescalar__doc__}, - {"singlescalar", (PyCFunction) query_singlescalar, - METH_NOARGS, query_singlescalar__doc__}, - {"fieldname", (PyCFunction) query_fieldname, - METH_VARARGS, query_fieldname__doc__}, - {"fieldnum", (PyCFunction) query_fieldnum, - METH_VARARGS, query_fieldnum__doc__}, - {"listfields", (PyCFunction) query_listfields, - METH_NOARGS, query_listfields__doc__}, - {"ntuples", (PyCFunction) query_ntuples, - METH_NOARGS, query_ntuples__doc__}, - {NULL, NULL} -}; - -static char query__doc__[] = "PyGreSQL query object"; - -/* Query type definition */ -static PyTypeObject queryType = { - PyVarObject_HEAD_INIT(NULL, 0) - "pg.Query", /* tp_name */ - sizeof(queryObject), /* tp_basicsize */ - 0, /* tp_itemsize */ - /* methods */ - (destructor) query_dealloc, /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_compare */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - &query_sequence_methods, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* 
tp_hash */ - 0, /* tp_call */ - (reprfunc) query_str, /* tp_str */ - PyObject_GenericGetAttr, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT - |Py_TPFLAGS_HAVE_ITER, /* tp_flags */ - query__doc__, /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - (getiterfunc) query_iter, /* tp_iter */ - (iternextfunc) query_next, /* tp_iternext */ - query_methods, /* tp_methods */ -}; diff --git a/py3c.h b/py3c.h deleted file mode 100644 index 63a3222a..00000000 --- a/py3c.h +++ /dev/null @@ -1,141 +0,0 @@ -/* Copyright (c) 2015, Red Hat, Inc. and/or its affiliates - * Licensed under the MIT license; see py3c.h - */ - -#ifndef _PY3C_COMPAT_H_ -#define _PY3C_COMPAT_H_ -#define PY_SSIZE_T_CLEAN -#include - -#if PY_MAJOR_VERSION >= 3 - -/***** Python 3 *****/ - -#define IS_PY3 1 - -/* Strings */ - -#define PyStr_Type PyUnicode_Type -#define PyStr_Check PyUnicode_Check -#define PyStr_CheckExact PyUnicode_CheckExact -#define PyStr_FromString PyUnicode_FromString -#define PyStr_FromStringAndSize PyUnicode_FromStringAndSize -#define PyStr_FromFormat PyUnicode_FromFormat -#define PyStr_FromFormatV PyUnicode_FromFormatV -#define PyStr_AsString PyUnicode_AsUTF8 -#define PyStr_Concat PyUnicode_Concat -#define PyStr_Format PyUnicode_Format -#define PyStr_InternInPlace PyUnicode_InternInPlace -#define PyStr_InternFromString PyUnicode_InternFromString -#define PyStr_Decode PyUnicode_Decode - -#define PyStr_AsUTF8String PyUnicode_AsUTF8String // returns PyBytes -#define PyStr_AsUTF8 PyUnicode_AsUTF8 -#define PyStr_AsUTF8AndSize PyUnicode_AsUTF8AndSize - -/* Ints */ - -#define PyInt_Type PyLong_Type -#define PyInt_Check PyLong_Check -#define PyInt_CheckExact PyLong_CheckExact -#define PyInt_FromString PyLong_FromString -#define PyInt_FromLong PyLong_FromLong -#define PyInt_FromSsize_t PyLong_FromSsize_t -#define PyInt_FromSize_t PyLong_FromSize_t -#define PyInt_AsLong PyLong_AsLong -#define PyInt_AS_LONG PyLong_AS_LONG -#define PyInt_AsUnsignedLongLongMask PyLong_AsUnsignedLongLongMask -#define PyInt_AsSsize_t PyLong_AsSsize_t - -/* Module init */ - -#define MODULE_INIT_FUNC(name) \ - PyMODINIT_FUNC PyInit_ ## name(void); \ - PyMODINIT_FUNC PyInit_ ## name(void) - -/* Other */ - -#define Py_TPFLAGS_HAVE_ITER 0 // not needed in Python 3 - -#else - -/***** Python 2 *****/ - -#define IS_PY3 0 - -/* Strings */ - -#define PyStr_Type PyString_Type -#define PyStr_Check PyString_Check -#define PyStr_CheckExact PyString_CheckExact -#define PyStr_FromString PyString_FromString -#define PyStr_FromStringAndSize PyString_FromStringAndSize -#define PyStr_FromFormat PyString_FromFormat -#define PyStr_FromFormatV PyString_FromFormatV -#define PyStr_AsString PyString_AsString -#define PyStr_Format PyString_Format -#define PyStr_InternInPlace PyString_InternInPlace -#define PyStr_InternFromString PyString_InternFromString -#define PyStr_Decode PyString_Decode - -static inline PyObject *PyStr_Concat(PyObject *left, PyObject *right) { - PyObject *str = left; - Py_INCREF(left); // reference to old left will be stolen - PyString_Concat(&str, right); - if (str) { - return str; - } else { - return NULL; - } -} - -#define PyStr_AsUTF8String(str) (Py_INCREF(str), (str)) -#define PyStr_AsUTF8 PyString_AsString -#define PyStr_AsUTF8AndSize(pystr, sizeptr) \ - ((*sizeptr=PyString_Size(pystr)), PyString_AsString(pystr)) - -#define PyBytes_Type PyString_Type -#define PyBytes_Check PyString_Check -#define PyBytes_CheckExact PyString_CheckExact -#define 
PyBytes_FromString PyString_FromString -#define PyBytes_FromStringAndSize PyString_FromStringAndSize -#define PyBytes_FromFormat PyString_FromFormat -#define PyBytes_FromFormatV PyString_FromFormatV -#define PyBytes_Size PyString_Size -#define PyBytes_GET_SIZE PyString_GET_SIZE -#define PyBytes_AsString PyString_AsString -#define PyBytes_AS_STRING PyString_AS_STRING -#define PyBytes_AsStringAndSize PyString_AsStringAndSize -#define PyBytes_Concat PyString_Concat -#define PyBytes_ConcatAndDel PyString_ConcatAndDel -#define _PyBytes_Resize _PyString_Resize - -/* Floats */ - -#define PyFloat_FromString(str) PyFloat_FromString(str, NULL) - -/* Module init */ - -#define PyModuleDef_HEAD_INIT 0 - -typedef struct PyModuleDef { - int m_base; - const char* m_name; - const char* m_doc; - Py_ssize_t m_size; - PyMethodDef *m_methods; -} PyModuleDef; - -#define PyModule_Create(def) \ - Py_InitModule3((def)->m_name, (def)->m_methods, (def)->m_doc) - -#define MODULE_INIT_FUNC(name) \ - static PyObject *PyInit_ ## name(void); \ - void init ## name(void); \ - void init ## name(void) { PyInit_ ## name(); } \ - static PyObject *PyInit_ ## name(void) - - -#endif - -#endif diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..01b5086f --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,106 @@ +[project] +name = "PyGreSQL" +version = "6.1.0" +requires-python = ">=3.7" +authors = [ + {name = "D'Arcy J. M. Cain", email = "darcy@pygresql.org"}, + {name = "Christoph Zwerschke", email = "cito@online.de"}, +] +description = "Python PostgreSQL interfaces" +readme = "README.rst" +keywords = ["pygresql", "postgresql", "database", "api", "dbapi"] +classifiers = [ + "Development Status :: 6 - Mature", + "Intended Audience :: Developers", + "License :: OSI Approved :: PostgreSQL License", + "Operating System :: OS Independent", + "Programming Language :: C", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: SQL", + "Topic :: Database", + "Topic :: Database :: Front-Ends", + "Topic :: Software Development :: Libraries :: Python Modules", +] + +[project.license] +file = "LICENSE.txt" + +[project.urls] +Homepage = "https://pygresql.github.io/" +Documentation = "https://pygresql.github.io/contents/" +"Source Code" = "https://github.com/PyGreSQL/PyGreSQL" +"Issue Tracker" = "https://github.com/PyGreSQL/PyGreSQL/issues/" +Changelog = "https://pygresql.github.io/contents/changelog.html" +Download = "https://pygresql.github.io/download/" +"Mailing List" = "https://mail.vex.net/mailman/listinfo/pygresql" + +[tool.ruff] +target-version = "py37" +line-length = 79 +exclude = [ + "__pycache__", + "__pypackages__", + ".git", + ".tox", + ".venv", + ".devcontainer", + ".vscode", + "docs", + "build", + "dist", + "local", + "venv", +] + +[tool.ruff.lint] +select = [ + "E", # pycodestyle + "F", # pyflakes + "I", # isort + "N", # pep8-naming + "UP", # pyupgrade + "D", # pydocstyle + "B", # bugbear + "S", # bandit + "SIM", # simplify + "RUF", # ruff +] +ignore = ["D203", "D213"] + +[tool.ruff.lint.per-file-ignores] +"tests/*.py" = ["D100", "D101", "D102", "D103", "D105", "D107", "S"] + +[tool.mypy] +python_version = "3.13" +check_untyped_defs = true 
+no_implicit_optional = true +strict_optional = true +warn_redundant_casts = true +warn_unused_ignores = true +disallow_untyped_defs = true + +[[tool.mypy.overrides]] +module = [ + "tests.*" +] +disallow_untyped_defs = false + +[tool.setuptools] +packages = ["pg", "pgdb"] +license-files = ["LICENSE.txt"] + +[tool.setuptools.package-data] +pg = ["py.typed"] +pgdb = ["py.typed"] + +[build-system] +requires = ["setuptools>=68", "wheel>=0.42"] +build-backend = "setuptools.build_meta" diff --git a/setup.py b/setup.py index 3e8d9c9e..bf652276 100755 --- a/setup.py +++ b/setup.py @@ -1,41 +1,11 @@ #!/usr/bin/python -# -# PyGreSQL - a Python interface for the PostgreSQL database. -# -# Copyright (c) 2020 by the PyGreSQL Development Team -# -# Please see the LICENSE.TXT file for specific restrictions. - -"""Setup script for PyGreSQL version 5.1.2 - -PyGreSQL is an open-source Python module that interfaces to a -PostgreSQL database. It embeds the PostgreSQL query library to allow -easy use of the powerful PostgreSQL features from a Python script. - -Authors and history: -* PyGreSQL written 1997 by D'Arcy J.M. Cain -* based on code written 1995 by Pascal Andre -* setup script created 2000 by Mark Alexander -* improved 2000 by Jeremy Hylton -* improved 2001 by Gerhard Haering -* improved 2006 to 2018 by Christoph Zwerschke - -Prerequisites to be installed: -* Python including devel package (header files and distutils) -* PostgreSQL libs and devel packages (header file of the libpq client) -* PostgreSQL pg_config tool (usually included in the devel package) - (the Windows installer has it as part of the database server feature) - -PyGreSQL currently supports Python versions 2.6, 2.7 and 3.3 to 3.8, -and PostgreSQL versions 9.0 to 9.6 and 10 to 12. - -Use as follows: -python setup.py build_ext # to build the module -python setup.py install # to install it - -See docs.python.org/doc/install/ for more information on -using distutils to install Python programs. +"""Driver script for building PyGreSQL using setuptools.
+ +You can build the PyGreSQL distribution like this: + + pip install build + python -m build -C strict -C memory-size """ import os @@ -43,39 +13,51 @@ import re import sys import warnings -try: - from setuptools import setup -except ImportError: - from distutils.core import setup -from distutils.extension import Extension -from distutils.command.build_ext import build_ext from distutils.ccompiler import get_default_compiler from distutils.sysconfig import get_python_inc, get_python_lib -version = '5.1.2' +from setuptools import Extension, setup +from setuptools.command.build_ext import build_ext + + +def project_version(): + """Read the PyGreSQL version from the pyproject.toml file.""" + with open('pyproject.toml') as f: + for d in f: + if d.startswith("version ="): + version = d.split("=")[1].strip().strip('"') + return version + raise Exception("Cannot determine PyGreSQL version") + + +def project_readme(): + """Get the content of the README file.""" + with open('README.rst') as f: + return f.read() + + +version = project_version() -if (not (2, 6) <= sys.version_info[:2] < (3, 0) - and not (3, 3) <= sys.version_info[:2] < (4, 0)): +if not (3, 7) <= sys.version_info[:2] < (4, 0): raise Exception( - "Sorry, PyGreSQL %s does not support this Python version" % version) + f"Sorry, PyGreSQL {version} does not support this Python version") + +long_description = project_readme() + # For historical reasons, PyGreSQL does not install itself as a single # "pygresql" package, but as two top-level modules "pg", providing the # classic interface, and "pgdb" for the modern DB-API 2.0 interface. # These two top-level Python modules share the same C extension "_pg". -py_modules = ['pg', 'pgdb'] -c_sources = ['pgmodule.c'] - - def pg_config(s): """Retrieve information about installed version of PostgreSQL.""" - f = os.popen('pg_config --%s' % s) + f = os.popen(f'pg_config --{s}') # noqa: S605 d = f.readline().strip() if f.close() is not None: raise Exception("pg_config tool is not available.") if not d: - raise Exception("Could not get %s information." 
% s) + raise Exception(f"Could not get {s} information.") return d @@ -84,7 +66,7 @@ def pg_version(): match = re.search(r'(\d+)\.(\d+)', pg_config('version')) if match: return tuple(map(int, match.groups())) - return 9, 0 + return 10, 0 pg_version = pg_version() @@ -98,151 +80,110 @@ extra_compile_args = ['-O2', '-funsigned-char', '-Wall', '-Wconversion'] -class build_pg_ext(build_ext): +class build_pg_ext(build_ext): # noqa: N801 """Customized build_ext command for PyGreSQL.""" description = "build the PyGreSQL C extension" - user_options = build_ext.user_options + [ + user_options = [*build_ext.user_options, # noqa: RUF012 ('strict', None, "count all compiler warnings as errors"), - ('direct-access', None, "enable direct access functions"), - ('no-direct-access', None, "disable direct access functions"), - ('large-objects', None, "enable large object support"), - ('no-large-objects', None, "disable large object support"), - ('default-vars', None, "enable default variables use"), - ('no-default-vars', None, "disable default variables use"), - ('escaping-funcs', None, "enable string escaping functions"), - ('no-escaping-funcs', None, "disable string escaping functions"), - ('ssl-info', None, "use new ssl info functions"), - ('no-ssl-info', None, "do not use new ssl info functions")] - - boolean_options = build_ext.boolean_options + [ - 'strict', 'direct-access', 'large-objects', 'default-vars', - 'escaping-funcs', 'ssl-info'] - - negative_opt = { - 'no-direct-access': 'direct-access', - 'no-large-objects': 'large-objects', - 'no-default-vars': 'default-vars', - 'no-escaping-funcs': 'escaping-funcs', - 'no-ssl-info': 'ssl-info'} + ('memory-size', None, "enable memory size function"), + ('no-memory-size', None, "disable memory size function")] + + boolean_options = [*build_ext.boolean_options, # noqa: RUF012 + 'strict', 'memory-size'] + + negative_opt = { # noqa: RUF012 + 'no-memory-size': 'memory-size'} def get_compiler(self): """Return the C compiler used for building the extension.""" return self.compiler or get_default_compiler() def initialize_options(self): + """Initialize the supported options with default values.""" build_ext.initialize_options(self) self.strict = False - self.direct_access = None - self.large_objects = None - self.default_vars = None - self.escaping_funcs = None - self.ssl_info = None - if pg_version < (9, 0): + self.memory_size = None + supported = pg_version >= (10, 0) + if not supported: warnings.warn( - "PyGreSQL does not support the installed PostgreSQL version.") + "PyGreSQL does not support the installed PostgreSQL version.", + stacklevel=2) def finalize_options(self): """Set final values for all build_pg options.""" build_ext.finalize_options(self) if self.strict: extra_compile_args.append('-Werror') - if self.direct_access is None or self.direct_access: - define_macros.append(('DIRECT_ACCESS', None)) - if self.large_objects is None or self.large_objects: - define_macros.append(('LARGE_OBJECTS', None)) - if self.default_vars is None or self.default_vars: - define_macros.append(('DEFAULT_VARS', None)) - if self.escaping_funcs is None or self.escaping_funcs: - if pg_version >= (9, 0): - define_macros.append(('ESCAPING_FUNCS', None)) - else: - (warnings.warn if self.escaping_funcs is None else sys.exit)( - "The installed PostgreSQL version" - " does not support the newer string escaping functions.") - if self.ssl_info
is None or self.ssl_info: - if pg_version >= (9, 5): - define_macros.append(('SSL_INFO', None)) - else: - (warnings.warn if self.ssl_info is None else sys.exit)( + wanted = self.memory_size + supported = pg_version >= (12, 0) + if (wanted is None and supported) or wanted: + define_macros.append(('MEMORY_SIZE', None)) + if not supported: + warnings.warn( "The installed PostgreSQL version" - " does not support ssl info functions.") + " does not support the memory size function.", + stacklevel=2) if sys.platform == 'win32': - bits = platform.architecture()[0] - if bits == '64bit': # we need to find libpq64 - for path in os.environ['PATH'].split(os.pathsep) + [ - r'C:\Program Files\PostgreSQL\libpq64']: - library_dir = os.path.join(path, 'lib') - if not os.path.isdir(library_dir): - continue - lib = os.path.join(library_dir, 'libpqdll.') - if not (os.path.exists(lib + 'lib') - or os.path.exists(lib + 'a')): - continue - include_dir = os.path.join(path, 'include') - if not os.path.isdir(include_dir): - continue - if library_dir not in library_dirs: - library_dirs.insert(1, library_dir) - if include_dir not in include_dirs: - include_dirs.insert(1, include_dir) - libraries[0] += 'dll' # libpqdll instead of libpq - break + libraries[0] = 'lib' + libraries[0] + if os.path.exists(os.path.join( + library_dirs[1], libraries[0] + 'dll.lib')): + libraries[0] += 'dll' compiler = self.get_compiler() if compiler == 'mingw32': # MinGW - if bits == '64bit': # needs MinGW-w64 + if platform.architecture()[0] == '64bit': # needs MinGW-w64 define_macros.append(('MS_WIN64', None)) elif compiler == 'msvc': # Microsoft Visual C++ - libraries[0] = 'lib' + libraries[0] extra_compile_args[1:] = [ - '-J', '-W3', '-WX', + '-J', '-W3', '-WX', '-wd4391', '-Dinline=__inline'] # needed for MSVC 9 setup( - name="PyGreSQL", + name='PyGreSQL', version=version, - description="Python PostgreSQL Interfaces", - long_description=__doc__.split('\n\n', 2)[1], # first passage - long_description_content_type='text/plain', - keywords="pygresql postgresql database api dbapi", + description='Python PostgreSQL Interfaces', + long_description=long_description, + long_description_content_type='text/x-rst', + keywords='pygresql postgresql database api dbapi', author="D'Arcy J. M. 
Cain", author_email="darcy@PyGreSQL.org", - url="http://www.pygresql.org", - download_url="http://www.pygresql.org/download/", - platforms=["any"], - license="PostgreSQL", - py_modules=py_modules, - ext_modules=[Extension( - '_pg', c_sources, - include_dirs=include_dirs, library_dirs=library_dirs, - define_macros=define_macros, undef_macros=undef_macros, - libraries=libraries, extra_compile_args=extra_compile_args)], - zip_safe=False, - cmdclass=dict(build_ext=build_pg_ext), - test_suite='tests.discover', + url='https://pygresql.github.io/', + download_url='https://pygresql.github.io/download/', + project_urls={ + 'Documentation': 'https://pygresql.github.io/contents/', + 'Issue Tracker': 'https://github.com/PyGreSQL/PyGreSQL/issues/', + 'Mailing List': 'https://mail.vex.net/mailman/listinfo/pygresql', + 'Source Code': 'https://github.com/PyGreSQL/PyGreSQL'}, classifiers=[ - "Development Status :: 6 - Mature", - "Intended Audience :: Developers", - "License :: OSI Approved :: PostgreSQL License", - "Operating System :: OS Independent", - "Programming Language :: C", + 'Development Status :: 6 - Mature', + 'Intended Audience :: Developers', + 'License :: OSI Approved :: PostgreSQL License', + 'Operating System :: OS Independent', + 'Programming Language :: C', 'Programming Language :: Python', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.6', - 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', - "Programming Language :: SQL", - "Topic :: Database", - "Topic :: Database :: Front-Ends", - "Topic :: Software Development :: Libraries :: Python Modules"] + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', + 'Programming Language :: Python :: 3.13', + 'Programming Language :: SQL', + 'Topic :: Database', + 'Topic :: Database :: Front-Ends', + 'Topic :: Software Development :: Libraries :: Python Modules'], + license='PostgreSQL', + test_suite='tests.discover', + zip_safe=False, + packages=["pg", "pgdb"], + package_data={"pg": ["py.typed"], "pgdb": ["py.typed"]}, + ext_modules=[Extension( + 'pg._pg', ["ext/pgmodule.c"], + include_dirs=include_dirs, library_dirs=library_dirs, + define_macros=define_macros, undef_macros=undef_macros, + libraries=libraries, extra_compile_args=extra_compile_args)], + cmdclass=dict(build_ext=build_pg_ext), ) diff --git a/tests/__init__.py b/tests/__init__.py index 38f807e0..f3070dd1 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -3,10 +3,7 @@ You can specify your local database settings in LOCAL_PyGreSQL.py. """ -try: - import unittest2 as unittest # for Python < 2.7 -except ImportError: - import unittest +import unittest if not (hasattr(unittest, 'skip') and hasattr(unittest.TestCase, 'setUpClass') @@ -18,4 +15,4 @@ def discover(): loader = unittest.TestLoader() suite = loader.discover('.') - return suite \ No newline at end of file + return suite diff --git a/tests/config.py b/tests/config.py new file mode 100644 index 00000000..4e27c3ae --- /dev/null +++ b/tests/config.py @@ -0,0 +1,31 @@ +#!/usr/bin/python + +from os import environ + +# We need a database to test against. 
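+ +# For example, a LOCAL_PyGreSQL.py module placed beside the tests can +# override the defaults read below; a minimal sketch, where the values +# are placeholders only: +# +# dbname = 'my_test_db' +# dbhost = 'localhost' +# dbport = 5433 +# dbuser = 'me' +# dbpasswd = 'secret'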
 + +# The connection parameters are taken from the usual PG* environment +# variables and can be overridden with PYGRESQL_* environment variables +# or values specified in the file .LOCAL_PyGreSQL or LOCAL_PyGreSQL.py. + +# The tests should be run with various PostgreSQL versions and databases +# created with different encodings and locales. Particularly, make sure the +# tests are running against databases created with both SQL_ASCII and UTF8. + +# The current user must have create schema privilege on the database. + +get = environ.get + +dbname = get('PYGRESQL_DB', get('PGDATABASE', 'test')) +dbhost = get('PYGRESQL_HOST', get('PGHOST', 'localhost')) +dbport = int(get('PYGRESQL_PORT', get('PGPORT', 5432))) +dbuser = get('PYGRESQL_USER', get('PGUSER')) +dbpasswd = get('PYGRESQL_PASSWD', get('PGPASSWORD')) + +try: + from .LOCAL_PyGreSQL import * # type: ignore # noqa +except (ImportError, ValueError): + try: # noqa + from LOCAL_PyGreSQL import * # type: ignore # noqa + except ImportError: + pass diff --git a/tests/dbapi20.py b/tests/dbapi20.py index 0656cddf..bf3c5718 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -1,95 +1,109 @@ #!/usr/bin/python -'''Python DB API 2.0 driver compliance unit test suite. + +"""Python DB API 2.0 driver compliance unit test suite. This software is Public Domain and may be used without restrictions. -''' -__version__ = '1.5' +Some modernization of the code has been done by the PyGreSQL team. +""" -try: - import unittest2 as unittest # for Python < 2.7 -except ImportError: - import unittest +from __future__ import annotations import time +import unittest +from contextlib import suppress +from typing import Any, ClassVar +__version__ = '1.15.0' class DatabaseAPI20Test(unittest.TestCase): - ''' Test a database self.driver for DB API 2.0 compatibility. - This implementation tests Gadfly, but the TestCase - is structured so that other self.drivers can subclass this - test case to ensure compiliance with the DB-API. It is - expected that this TestCase may be expanded in the future - if ambiguities or edge conditions are discovered. + """Test a database self.driver for DB API 2.0 compatibility. + + This implementation tests Gadfly, but the TestCase + is structured so that other self.drivers can subclass this + test case to ensure compliance with the DB-API. It is + expected that this TestCase may be expanded in the future + if ambiguities or edge conditions are discovered. - The 'Optional Extensions' are not yet being tested. + The 'Optional Extensions' are not yet being tested. - self.drivers should subclass this test, overriding setUp, tearDown, - self.driver, connect_args and connect_kw_args. Class specification - should be as follows: + self.drivers should subclass this test, overriding setUp, tearDown, + self.driver, connect_args and connect_kw_args. Class specification + should be as follows: - import dbapi20 - class mytest(dbapi20.DatabaseAPI20Test): - [...] + import dbapi20 + class mytest(dbapi20.DatabaseAPI20Test): + [...] - Don't 'import DatabaseAPI20Test from dbapi20', or you will - confuse the unit tester - just 'import dbapi20'. - ''' + Don't 'import DatabaseAPI20Test from dbapi20', or you will + confuse the unit tester - just 'import dbapi20'. + """ # The self.driver module.
This should be the module where the 'connect' # method is to be found - driver = None - connect_args = () # List of arguments to pass to connect - connect_kw_args = {} # Keyword arguments for connect - table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables + driver: Any = None + connect_args: tuple = () # List of arguments to pass to connect + connect_kw_args: ClassVar[dict[str, Any]] = {} # Keyword arguments + table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables - ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix - ddl2 = 'create table %sbarflys (name varchar(20))' % table_prefix - xddl1 = 'drop table %sbooze' % table_prefix - xddl2 = 'drop table %sbarflys' % table_prefix + ddl1 = f'create table {table_prefix}booze (name varchar(20))' + ddl2 = (f'create table {table_prefix}barflys (name varchar(20),' + ' drink varchar(30))') + xddl1 = f'drop table {table_prefix}booze' + xddl2 = f'drop table {table_prefix}barflys' + insert = 'insert' - lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase + lowerfunc = 'lower' # Name of stored procedure to convert str to lowercase # Some drivers may need to override these helpers, for example adding # a 'commit' after the execute. - def executeDDL1(self,cursor): + def execute_ddl1(self, cursor): cursor.execute(self.ddl1) - def executeDDL2(self,cursor): + def execute_ddl2(self, cursor): cursor.execute(self.ddl2) def setUp(self): - """self.drivers should override this method to perform required setup - if any is necessary, such as creating the database. + """Set up test fixture. + + self.drivers should override this method to perform required setup + if any is necessary, such as creating the database. """ pass def tearDown(self): - """self.drivers should override this method to perform required cleanup - if any is necessary, such as deleting the test database. - The default drops the tables that may be created. + """Tear down test fixture. + + self.drivers should override this method to perform required cleanup + if any is necessary, such as deleting the test database. + The default drops the tables that may be created. """ - con = self._connect() try: - cur = con.cursor() - for ddl in (self.xddl1,self.xddl2): - try: - cur.execute(ddl) - con.commit() - except self.driver.Error: - # Assume table didn't exist. Other tests will check if - # execute is busted. - pass - finally: - con.close() + con = self._connect() + try: + cur = con.cursor() + for ddl in (self.xddl1, self.xddl2): + try: + cur.execute(ddl) + con.commit() + except self.driver.Error: + # Assume table didn't exist. Other tests will check if + # execute is busted. 
+ pass + finally: + con.close() + except Exception: + pass def _connect(self): try: - return self.driver.connect( - *self.connect_args,**self.connect_kw_args - ) + con = self.driver.connect( + *self.connect_args, **self.connect_kw_args) except AttributeError: self.fail("No connect method found in self.driver module") + if not isinstance(con, self.driver.Connection): + self.fail("The connect method does not return a Connection") + return con def test_connect(self): con = self._connect() @@ -100,7 +114,7 @@ def test_apilevel(self): # Must exist apilevel = self.driver.apilevel # Must equal 2.0 - self.assertEqual(apilevel,'2.0') + self.assertEqual(apilevel, '2.0') except AttributeError: self.fail("Driver doesn't define apilevel") @@ -109,7 +123,7 @@ def test_threadsafety(self): # Must exist threadsafety = self.driver.threadsafety # Must be a valid value - self.assertTrue(threadsafety in (0,1,2,3)) + self.assertIn(threadsafety, (0, 1, 2, 3)) except AttributeError: self.fail("Driver doesn't define threadsafety") @@ -118,60 +132,44 @@ def test_paramstyle(self): # Must exist paramstyle = self.driver.paramstyle # Must be a valid value - self.assertTrue(paramstyle in ( - 'qmark','numeric','named','format','pyformat' - )) + self.assertIn(paramstyle, ( + 'qmark', 'numeric', 'named', 'format', 'pyformat')) except AttributeError: self.fail("Driver doesn't define paramstyle") - def test_Exceptions(self): - """Make sure required exceptions exist, and are in the - defined hierarchy. - """ - self.assertTrue(issubclass(self.driver.Warning,Exception)) - self.assertTrue(issubclass(self.driver.Error,Exception)) - self.assertTrue( - issubclass(self.driver.InterfaceError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.DatabaseError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.OperationalError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.IntegrityError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.InternalError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.ProgrammingError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.NotSupportedError,self.driver.Error) - ) - - def test_ExceptionsAsConnectionAttributes(self): - """Optional extension - - Test for the optional DB API 2.0 extension, where the exceptions - are exposed as attributes on the Connection object - I figure this optional extension will be implemented by any - driver author who is using this test suite, so it is enabled - by default. - """ + def test_exceptions(self): + # Make sure required exceptions exist, and are in the + # defined hierarchy. 
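+ # For reference, PEP 249 defines this hierarchy (shown as a sketch; + # DataError also belongs to it, although it is not asserted below): + # Exception -> Warning, Error; Error -> InterfaceError, DatabaseError; + # DatabaseError -> DataError, OperationalError, IntegrityError, + # InternalError, ProgrammingError, NotSupportedError.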
+ sub = issubclass + self.assertTrue(sub(self.driver.Warning, Exception)) + self.assertTrue(sub(self.driver.Error, Exception)) + + self.assertTrue(sub(self.driver.InterfaceError, self.driver.Error)) + self.assertTrue(sub(self.driver.DatabaseError, self.driver.Error)) + self.assertTrue(sub(self.driver.OperationalError, self.driver.Error)) + self.assertTrue(sub(self.driver.IntegrityError, self.driver.Error)) + self.assertTrue(sub(self.driver.InternalError, self.driver.Error)) + self.assertTrue(sub(self.driver.ProgrammingError, self.driver.Error)) + self.assertTrue(sub(self.driver.NotSupportedError, self.driver.Error)) + + def test_exceptions_as_connection_attributes(self): + # OPTIONAL EXTENSION + # Test for the optional DB API 2.0 extension, where the exceptions + # are exposed as attributes on the Connection object + # I figure this optional extension will be implemented by any + # driver author who is using this test suite, so it is enabled + # by default. con = self._connect() drv = self.driver - self.assertTrue(con.Warning is drv.Warning) - self.assertTrue(con.Error is drv.Error) - self.assertTrue(con.InterfaceError is drv.InterfaceError) - self.assertTrue(con.DatabaseError is drv.DatabaseError) - self.assertTrue(con.OperationalError is drv.OperationalError) - self.assertTrue(con.IntegrityError is drv.IntegrityError) - self.assertTrue(con.InternalError is drv.InternalError) - self.assertTrue(con.ProgrammingError is drv.ProgrammingError) - self.assertTrue(con.NotSupportedError is drv.NotSupportedError) + self.assertIs(con.Warning, drv.Warning) + self.assertIs(con.Error, drv.Error) + self.assertIs(con.InterfaceError, drv.InterfaceError) + self.assertIs(con.DatabaseError, drv.DatabaseError) + self.assertIs(con.OperationalError, drv.OperationalError) + self.assertIs(con.IntegrityError, drv.IntegrityError) + self.assertIs(con.InternalError, drv.InternalError) + self.assertIs(con.ProgrammingError, drv.ProgrammingError) + self.assertIs(con.NotSupportedError, drv.NotSupportedError) def test_commit(self): con = self._connect() @@ -185,16 +183,16 @@ def test_rollback(self): con = self._connect() # If rollback is defined, it should either work or throw # the documented exception - if hasattr(con,'rollback'): - try: + if hasattr(con, 'rollback'): + with suppress(self.driver.NotSupportedError): + # noinspection PyCallingNonCallable con.rollback() - except self.driver.NotSupportedError: - pass def test_cursor(self): con = self._connect() try: cur = con.cursor() + self.assertIsNotNone(cur) finally: con.close() @@ -205,15 +203,14 @@ def test_cursor_isolation(self): # the documented transaction isolation level cur1 = con.cursor() cur2 = con.cursor() - self.executeDDL1(cur1) - cur1.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) - cur2.execute("select name from %sbooze" % self.table_prefix) + self.execute_ddl1(cur1) + cur1.execute(f"{self.insert} into {self.table_prefix}booze" + " values ('Victoria Bitter')") + cur2.execute(f"select name from {self.table_prefix}booze") booze = cur2.fetchall() - self.assertEqual(len(booze),1) - self.assertEqual(len(booze[0]),1) - self.assertEqual(booze[0][0],'Victoria Bitter') + self.assertEqual(len(booze), 1) + self.assertEqual(len(booze[0]), 1) + self.assertEqual(booze[0][0], 'Victoria Bitter') finally: con.close() @@ -221,32 +218,32 @@ def test_description(self): con = self._connect() try: cur = con.cursor() - self.executeDDL1(cur) - self.assertEqual(cur.description,None, - 'cursor.description should be none after executing a ' - 
'statement that can return no rows (such as DDL)' - ) - cur.execute('select name from %sbooze' % self.table_prefix) - self.assertEqual(len(cur.description),1, - 'cursor.description describes too many columns' - ) - self.assertEqual(len(cur.description[0]),7, - 'cursor.description[x] tuples must have 7 elements' - ) - self.assertEqual(cur.description[0][0].lower(),'name', - 'cursor.description[x][0] must return column name' - ) - self.assertEqual(cur.description[0][1],self.driver.STRING, - 'cursor.description[x][1] must return column type. Got %r' - % cur.description[0][1] - ) + self.execute_ddl1(cur) + self.assertIsNone( + cur.description, + 'cursor.description should be none after executing a' + ' statement that can return no rows (such as DDL)') + cur.execute(f'select name from {self.table_prefix}booze') + self.assertEqual( + len(cur.description), 1, + 'cursor.description describes too many columns') + self.assertEqual( + len(cur.description[0]), 7, + 'cursor.description[x] tuples must have 7 elements') + self.assertEqual( + cur.description[0][0].lower(), 'name', + 'cursor.description[x][0] must return column name') + self.assertEqual( + cur.description[0][1], self.driver.STRING, + 'cursor.description[x][1] must return column type.' + f' Got: {cur.description[0][1]!r}') # Make sure self.description gets reset - self.executeDDL2(cur) - self.assertEqual(cur.description,None, - 'cursor.description not being set to None when executing ' - 'no-result statements (eg. DDL)' - ) + self.execute_ddl2(cur) + self.assertIsNone( + cur.description, + 'cursor.description not being set to None when executing' + ' no-result statements (eg. DDL)') finally: con.close() @@ -254,48 +251,47 @@ def test_rowcount(self): con = self._connect() try: cur = con.cursor() - self.executeDDL1(cur) - self.assertEqual(cur.rowcount,-1, - 'cursor.rowcount should be -1 after executing no-result ' - 'statements' - ) - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) - self.assertTrue(cur.rowcount in (-1,1), - 'cursor.rowcount should == number or rows inserted, or ' - 'set to -1 after executing an insert statement' - ) - cur.execute("select name from %sbooze" % self.table_prefix) - self.assertTrue(cur.rowcount in (-1,1), - 'cursor.rowcount should == number of rows returned, or ' - 'set to -1 after executing a select statement' - ) - self.executeDDL2(cur) - self.assertEqual(cur.rowcount,-1, - 'cursor.rowcount not being reset to -1 after executing ' - 'no-result statements' - ) + self.execute_ddl1(cur) + self.assertIn( + cur.rowcount, (-1, 0), # Bug #543885 + 'cursor.rowcount should be -1 or 0 after executing no-result' + ' statements') + cur.execute(f"{self.insert} into {self.table_prefix}booze" + " values ('Victoria Bitter')") + self.assertIn( + cur.rowcount, (-1, 1), + 'cursor.rowcount should == number of rows inserted, or' + ' set to -1 after executing an insert statement') + cur.execute(f"select name from {self.table_prefix}booze") + self.assertIn( + cur.rowcount, (-1, 1), + 'cursor.rowcount should == number of rows returned, or' + ' set to -1 after executing a select statement') + self.execute_ddl2(cur) + self.assertIn( + cur.rowcount, (-1, 0), # Bug #543885 + 'cursor.rowcount should be -1 or 0 after executing no-result' + ' statements') finally: con.close() lower_func = 'lower' + def test_callproc(self): con = self._connect() try: cur = con.cursor() - if self.lower_func and hasattr(cur,'callproc'): - r = cur.callproc(self.lower_func,('FOO',)) - self.assertEqual(len(r),1) -
self.assertEqual(r[0],'FOO') + if self.lower_func and hasattr(cur, 'callproc'): + # noinspection PyCallingNonCallable + r = cur.callproc(self.lower_func, ('FOO',)) + self.assertEqual(len(r), 1) + self.assertEqual(r[0], 'FOO') r = cur.fetchall() - self.assertEqual(len(r),1,'callproc produced no result set') - self.assertEqual(len(r[0]),1, - 'callproc produced invalid result set' - ) - self.assertEqual(r[0][0],'foo', - 'callproc produced invalid results' - ) + self.assertEqual(len(r), 1, 'callproc produced no result set') + self.assertEqual( + len(r[0]), 1, 'callproc produced invalid result set') + self.assertEqual( + r[0][0], 'foo', 'callproc produced invalid results') finally: con.close() @@ -308,14 +304,18 @@ def test_close(self): # cursor.execute should raise an Error if called after connection # closed - self.assertRaises(self.driver.Error,self.executeDDL1,cur) + self.assertRaises(self.driver.Error, self.execute_ddl1, cur) # connection.commit should raise an Error if called after connection' # closed.' - self.assertRaises(self.driver.Error,con.commit) + self.assertRaises(self.driver.Error, con.commit) + def test_non_idempotent_close(self): + con = self._connect() + con.close() # connection.close should raise an Error if called more than once - self.assertRaises(self.driver.Error,con.close) + # (the usefulness of this test and this feature is questionable) + self.assertRaises(self.driver.Error, con.close) def test_execute(self): con = self._connect() @@ -325,105 +325,96 @@ def test_execute(self): finally: con.close() - def _paraminsert(self,cur): - self.executeDDL1(cur) - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) - self.assertTrue(cur.rowcount in (-1,1)) + def _paraminsert(self, cur): + self.execute_ddl2(cur) + table_prefix = self.table_prefix + insert = f"{self.insert} into {table_prefix}barflys values" + cur.execute( + f"{insert} ('Victoria Bitter'," + " 'thi%s :may ca%(u)se? troub:1e')") + self.assertIn(cur.rowcount, (-1, 1)) if self.driver.paramstyle == 'qmark': cur.execute( - 'insert into %sbooze values (?)' % self.table_prefix, - ("Cooper's",) - ) + f"{insert} (?, 'thi%s :may ca%(u)se? troub:1e')", + ("Cooper's",)) elif self.driver.paramstyle == 'numeric': cur.execute( - 'insert into %sbooze values (:1)' % self.table_prefix, - ("Cooper's",) - ) + f"{insert} (:1, 'thi%s :may ca%(u)se? troub:1e')", + ("Cooper's",)) elif self.driver.paramstyle == 'named': cur.execute( - 'insert into %sbooze values (:beer)' % self.table_prefix, - {'beer':"Cooper's"} - ) + f"{insert} (:beer, 'thi%s :may ca%(u)se? troub:1e')", + {'beer': "Cooper's"}) elif self.driver.paramstyle == 'format': cur.execute( - 'insert into %sbooze values (%%s)' % self.table_prefix, - ("Cooper's",) - ) + f"{insert} (%s, 'thi%%s :may ca%%(u)se? troub:1e')", + ("Cooper's",)) elif self.driver.paramstyle == 'pyformat': cur.execute( - 'insert into %sbooze values (%%(beer)s)' % self.table_prefix, - {'beer':"Cooper's"} - ) + f"{insert} (%(beer)s, 'thi%%s :may ca%%(u)se? 
troub:1e')", + {'beer': "Cooper's"}) else: self.fail('Invalid paramstyle') - self.assertTrue(cur.rowcount in (-1,1)) + self.assertIn(cur.rowcount, (-1, 1)) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute(f'select name, drink from {table_prefix}barflys') res = cur.fetchall() - self.assertEqual(len(res),2,'cursor.fetchall returned too few rows') - beers = [res[0][0],res[1][0]] + self.assertEqual(len(res), 2, 'cursor.fetchall returned too few rows') + beers = [res[0][0], res[1][0]] beers.sort() - self.assertEqual(beers[0],"Cooper's", - 'cursor.fetchall retrieved incorrect data, or data inserted ' - 'incorrectly' - ) - self.assertEqual(beers[1],"Victoria Bitter", - 'cursor.fetchall retrieved incorrect data, or data inserted ' - 'incorrectly' - ) + self.assertEqual( + beers[0], "Cooper's", + 'cursor.fetchall retrieved incorrect data, or data inserted' + ' incorrectly') + self.assertEqual( + beers[1], "Victoria Bitter", + 'cursor.fetchall retrieved incorrect data, or data inserted' + ' incorrectly') + trouble = "thi%s :may ca%(u)se? troub:1e" + self.assertEqual( + res[0][1], trouble, + 'cursor.fetchall retrieved incorrect data, or data inserted' + f' incorrectly. Got: {res[0][1]!r}, Expected: {trouble!r}') + self.assertEqual( + res[1][1], trouble, + 'cursor.fetchall retrieved incorrect data, or data inserted' + f' incorrectly. Got: {res[1][1]!r}, Expected: {trouble!r}') def test_executemany(self): con = self._connect() try: cur = con.cursor() - self.executeDDL1(cur) - largs = [ ("Cooper's",) , ("Boag's",) ] - margs = [ {'beer': "Cooper's"}, {'beer': "Boag's"} ] + self.execute_ddl1(cur) + table_prefix = self.table_prefix + insert = f'{self.insert} into {table_prefix}booze values' + largs = [("Cooper's",), ("Boag's",)] + margs = [{'beer': "Cooper's"}, {'beer': "Boag's"}] if self.driver.paramstyle == 'qmark': - cur.executemany( - 'insert into %sbooze values (?)' % self.table_prefix, - largs - ) + cur.executemany(f'{insert} (?)', largs) elif self.driver.paramstyle == 'numeric': - cur.executemany( - 'insert into %sbooze values (:1)' % self.table_prefix, - largs - ) + cur.executemany(f'{insert} (:1)', largs) elif self.driver.paramstyle == 'named': - cur.executemany( - 'insert into %sbooze values (:beer)' % self.table_prefix, - margs - ) + cur.executemany(f'{insert} (:beer)', margs) elif self.driver.paramstyle == 'format': - cur.executemany( - 'insert into %sbooze values (%%s)' % self.table_prefix, - largs - ) + cur.executemany(f'{insert} (%s)', largs) elif self.driver.paramstyle == 'pyformat': - cur.executemany( - 'insert into %sbooze values (%%(beer)s)' % ( - self.table_prefix - ), - margs - ) + cur.executemany(f'{insert} (%(beer)s)', margs) else: self.fail('Unknown paramstyle') - self.assertTrue(cur.rowcount in (-1,2), - 'insert using cursor.executemany set cursor.rowcount to ' - 'incorrect value %r' % cur.rowcount - ) - cur.execute('select name from %sbooze' % self.table_prefix) + self.assertIn( + cur.rowcount, (-1, 2), + 'insert using cursor.executemany set cursor.rowcount to' + f' incorrect value {cur.rowcount!r}') + cur.execute(f'select name from {table_prefix}booze') res = cur.fetchall() - self.assertEqual(len(res),2, - 'cursor.fetchall retrieved incorrect number of rows' - ) - beers = [res[0][0],res[1][0]] + self.assertEqual( + len(res), 2, + 'cursor.fetchall retrieved incorrect number of rows') + beers = [res[0][0], res[1][0]] beers.sort() - self.assertEqual(beers[0],"Boag's",'incorrect data retrieved') - self.assertEqual(beers[1],"Cooper's",'incorrect data 
retrieved') + self.assertEqual(beers[0], "Boag's", 'incorrect data retrieved') + self.assertEqual(beers[1], "Cooper's", 'incorrect data retrieved') finally: con.close() @@ -434,59 +425,98 @@ def test_fetchone(self): # cursor.fetchone should raise an Error if called before # executing a select-type query - self.assertRaises(self.driver.Error,cur.fetchone) + self.assertRaises(self.driver.Error, cur.fetchone) # cursor.fetchone should raise an Error if called after - # executing a query that cannnot return rows - self.executeDDL1(cur) - self.assertRaises(self.driver.Error,cur.fetchone) + # executing a query that cannot return rows + self.execute_ddl1(cur) + self.assertRaises(self.driver.Error, cur.fetchone) - cur.execute('select name from %sbooze' % self.table_prefix) - self.assertEqual(cur.fetchone(),None, - 'cursor.fetchone should return None if a query retrieves ' - 'no rows' - ) - self.assertTrue(cur.rowcount in (-1,0)) + cur.execute(f'select name from {self.table_prefix}booze') + self.assertIsNone( + cur.fetchone(), + 'cursor.fetchone should return None if a query retrieves' + ' no rows') + self.assertIn(cur.rowcount, (-1, 0)) # cursor.fetchone should raise an Error if called after - # executing a query that cannnot return rows - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) - self.assertRaises(self.driver.Error,cur.fetchone) + # executing a query that cannot return rows + cur.execute( + f"{self.insert} into {self.table_prefix}booze" + " values ('Victoria Bitter')") + self.assertRaises(self.driver.Error, cur.fetchone) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute(f'select name from {self.table_prefix}booze') r = cur.fetchone() - self.assertEqual(len(r),1, - 'cursor.fetchone should have retrieved a single row' - ) - self.assertEqual(r[0],'Victoria Bitter', - 'cursor.fetchone retrieved incorrect data' - ) - self.assertEqual(cur.fetchone(),None, - 'cursor.fetchone should return None if no more rows available' - ) - self.assertTrue(cur.rowcount in (-1,1)) + self.assertEqual( + len(r), 1, + 'cursor.fetchone should have retrieved a single row') + self.assertEqual( + r[0], 'Victoria Bitter', + 'cursor.fetchone retrieved incorrect data') + self.assertIsNone( + cur.fetchone(), + 'cursor.fetchone should return None if no more rows available') + self.assertIn(cur.rowcount, (-1, 1)) + finally: + con.close() + + def test_next(self): + """Test extension for getting the next row.""" + con = self._connect() + try: + cur = con.cursor() + if not hasattr(cur, 'next'): + return + + # cursor.next should raise an Error if called before + # executing a select-type query + self.assertRaises(self.driver.Error, cur.next) + + # cursor.next should raise an Error if called after + # executing a query that cannot return rows + self.execute_ddl1(cur) + self.assertRaises(self.driver.Error, cur.next) + + # cursor.next should raise StopIteration if a query retrieves no rows + cur.execute(f'select name from {self.table_prefix}booze') + self.assertRaises(StopIteration, cur.next) + self.assertIn(cur.rowcount, (-1, 0)) + + # cursor.next should raise an Error if called after + # executing a query that cannot return rows + cur.execute(f"{self.insert} into {self.table_prefix}booze" + " values ('Victoria Bitter')") + self.assertRaises(self.driver.Error, cur.next) + + cur.execute(f'select name from {self.table_prefix}booze') + r = cur.next() + self.assertEqual( + len(r), 1, + 'cursor.next should have retrieved a single row') + self.assertEqual( + r[0], 'Victoria
Bitter', + 'cursor.next retrieved incorrect data') + # cursor.next should raise StopIteration if no more rows available + self.assertRaises(StopIteration, cur.next) + self.assertIn(cur.rowcount, (-1, 1)) finally: con.close() - samples = [ + samples = ( 'Carlton Cold', 'Carlton Draft', 'Mountain Goat', 'Redback', 'Victoria Bitter', 'XXXX' - ] + ) def _populate(self): - """Return a list of sql commands to setup the DB for the fetch - tests. - """ + """Return a list of SQL commands to set up the DB for fetching tests.""" populate = [ - "insert into %sbooze values ('%s')" % (self.table_prefix,s) - for s in self.samples - ] + f"{self.insert} into {self.table_prefix}booze values ('{s}')" + for s in self.samples] return populate def test_fetchmany(self): @@ -495,78 +525,78 @@ cur = con.cursor() # cursor.fetchmany should raise an Error if called without - #issuing a query - self.assertRaises(self.driver.Error,cur.fetchmany,4) + # issuing a query + self.assertRaises(self.driver.Error, cur.fetchmany, 4) - self.executeDDL1(cur) + self.execute_ddl1(cur) for sql in self._populate(): cur.execute(sql) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute(f'select name from {self.table_prefix}booze') r = cur.fetchmany() - self.assertEqual(len(r),1, - 'cursor.fetchmany retrieved incorrect number of rows, ' - 'default of arraysize is one.' - ) - cur.arraysize=10 - r = cur.fetchmany(3) # Should get 3 rows - self.assertEqual(len(r),3, - 'cursor.fetchmany retrieved incorrect number of rows' - ) - r = cur.fetchmany(4) # Should get 2 more - self.assertEqual(len(r),2, - 'cursor.fetchmany retrieved incorrect number of rows' - ) - r = cur.fetchmany(4) # Should be an empty sequence - self.assertEqual(len(r),0, - 'cursor.fetchmany should return an empty sequence after ' - 'results are exhausted' - ) - self.assertTrue(cur.rowcount in (-1,6)) + self.assertEqual( + len(r), 1, + 'cursor.fetchmany retrieved incorrect number of rows,' + ' default of arraysize is one.') + cur.arraysize = 10 + r = cur.fetchmany(3) # Should get 3 rows + self.assertEqual( + len(r), 3, + 'cursor.fetchmany retrieved incorrect number of rows') + r = cur.fetchmany(4) # Should get 2 more + self.assertEqual( + len(r), 2, + 'cursor.fetchmany retrieved incorrect number of rows') + r = cur.fetchmany(4) # Should be an empty sequence + self.assertEqual( + len(r), 0, + 'cursor.fetchmany should return an empty sequence after' + ' results are exhausted') + self.assertIn(cur.rowcount, (-1, 6)) # Same as above, using cursor.arraysize - cur.arraysize=4 - cur.execute('select name from %sbooze' % self.table_prefix) - r = cur.fetchmany() # Should get 4 rows - self.assertEqual(len(r),4, - 'cursor.arraysize not being honoured by fetchmany' - ) - r = cur.fetchmany() # Should get 2 more - self.assertEqual(len(r),2) - r = cur.fetchmany() # Should be an empty sequence - self.assertEqual(len(r),0) - self.assertTrue(cur.rowcount in (-1,6)) - - cur.arraysize=6 - cur.execute('select name from %sbooze' % self.table_prefix) - rows = cur.fetchmany() # Should get all rows - self.assertTrue(cur.rowcount in (-1,6)) - self.assertEqual(len(rows),6) - self.assertEqual(len(rows),6) + cur.arraysize = 4 + cur.execute(f'select name from {self.table_prefix}booze') + r = cur.fetchmany() # Should get 4 rows + self.assertEqual( + len(r), 4, + 'cursor.arraysize not being honoured by fetchmany') + r = cur.fetchmany() # Should get 2 more + self.assertEqual(len(r), 2) + r = cur.fetchmany() # Should be an empty sequence + self.assertEqual(len(r),
0) + self.assertIn(cur.rowcount, (-1, 6)) + + cur.arraysize = 6 + cur.execute(f'select name from {self.table_prefix}booze') + rows = cur.fetchmany() # Should get all rows + self.assertIn(cur.rowcount, (-1, 6)) + self.assertEqual(len(rows), 6) + self.assertEqual(len(rows), 6) rows = [r[0] for r in rows] rows.sort() # Make sure we get the right data back out - for i in range(0,6): - self.assertEqual(rows[i],self.samples[i], - 'incorrect data retrieved by cursor.fetchmany' - ) - - rows = cur.fetchmany() # Should return an empty list - self.assertEqual(len(rows),0, - 'cursor.fetchmany should return an empty sequence if ' - 'called after the whole result set has been fetched' - ) - self.assertTrue(cur.rowcount in (-1,6)) - - self.executeDDL2(cur) - cur.execute('select name from %sbarflys' % self.table_prefix) - r = cur.fetchmany() # Should get empty sequence - self.assertEqual(len(r),0, - 'cursor.fetchmany should return an empty sequence if ' - 'query retrieved no rows' - ) - self.assertTrue(cur.rowcount in (-1,0)) + for i in range(0, 6): + self.assertEqual( + rows[i], self.samples[i], + 'incorrect data retrieved by cursor.fetchmany') + + rows = cur.fetchmany() # Should return an empty list + self.assertEqual( + len(rows), 0, + 'cursor.fetchmany should return an empty sequence if' + ' called after the whole result set has been fetched') + self.assertIn(cur.rowcount, (-1, 6)) + + self.execute_ddl2(cur) + cur.execute(f'select name from {self.table_prefix}barflys') + r = cur.fetchmany() # Should get empty sequence + self.assertEqual( + len(r), 0, + 'cursor.fetchmany should return an empty sequence if' + ' query retrieved no rows') + self.assertIn(cur.rowcount, (-1, 0)) finally: con.close() @@ -580,42 +610,40 @@ def test_fetchall(self): # as a select) self.assertRaises(self.driver.Error, cur.fetchall) - self.executeDDL1(cur) + self.execute_ddl1(cur) for sql in self._populate(): cur.execute(sql) # cursor.fetchall should raise an Error if called # after executing a a statement that cannot return rows - self.assertRaises(self.driver.Error,cur.fetchall) + self.assertRaises(self.driver.Error, cur.fetchall) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute(f'select name from {self.table_prefix}booze') rows = cur.fetchall() - self.assertTrue(cur.rowcount in (-1,len(self.samples))) - self.assertEqual(len(rows),len(self.samples), - 'cursor.fetchall did not retrieve all rows' - ) - rows = [r[0] for r in rows] - rows.sort() - for i in range(0,len(self.samples)): - self.assertEqual(rows[i],self.samples[i], - 'cursor.fetchall retrieved incorrect rows' - ) + self.assertIn(cur.rowcount, (-1, len(self.samples))) + self.assertEqual( + len(rows), len(self.samples), + 'cursor.fetchall did not retrieve all rows') + rows = sorted(r[0] for r in rows) + for i in range(0, len(self.samples)): + self.assertEqual( + rows[i], self.samples[i], + 'cursor.fetchall retrieved incorrect rows') rows = cur.fetchall() self.assertEqual( - len(rows),0, - 'cursor.fetchall should return an empty list if called ' - 'after the whole result set has been fetched' - ) - self.assertTrue(cur.rowcount in (-1,len(self.samples))) - - self.executeDDL2(cur) - cur.execute('select name from %sbarflys' % self.table_prefix) + len(rows), 0, + 'cursor.fetchall should return an empty list if called' + ' after the whole result set has been fetched') + self.assertIn(cur.rowcount, (-1, len(self.samples))) + + self.execute_ddl2(cur) + cur.execute(f'select name from {self.table_prefix}barflys') rows = cur.fetchall() - 
self.assertTrue(cur.rowcount in (-1,0)) - self.assertEqual(len(rows),0, - 'cursor.fetchall should return an empty list if ' - 'a select query returns no rows' - ) + self.assertIn(cur.rowcount, (-1, 0)) + self.assertEqual( + len(rows), 0, + 'cursor.fetchall should return an empty list if' + ' a select query returns no rows') finally: con.close() @@ -624,97 +652,93 @@ def test_mixedfetch(self): con = self._connect() try: cur = con.cursor() - self.executeDDL1(cur) + self.execute_ddl1(cur) for sql in self._populate(): cur.execute(sql) - cur.execute('select name from %sbooze' % self.table_prefix) - rows1 = cur.fetchone() + cur.execute(f'select name from {self.table_prefix}booze') + rows1 = cur.fetchone() rows23 = cur.fetchmany(2) - rows4 = cur.fetchone() + rows4 = cur.fetchone() rows56 = cur.fetchall() - self.assertTrue(cur.rowcount in (-1,6)) - self.assertEqual(len(rows23),2, - 'fetchmany returned incorrect number of rows' - ) - self.assertEqual(len(rows56),2, - 'fetchall returned incorrect number of rows' - ) + self.assertIn(cur.rowcount, (-1, 6)) + self.assertEqual( + len(rows23), 2, + 'fetchmany returned incorrect number of rows') + self.assertEqual( + len(rows56), 2, + 'fetchall returned incorrect number of rows') rows = [rows1[0]] - rows.extend([rows23[0][0],rows23[1][0]]) rows.append(rows4[0]) - rows.extend([rows56[0][0],rows56[1][0]]) rows.sort() - for i in range(0,len(self.samples)): - self.assertEqual(rows[i],self.samples[i], - 'incorrect data retrieved or inserted' - ) + rows.extend([rows23[0][0], rows23[1][0]]) rows.append(rows4[0]) + rows.extend([rows56[0][0], rows56[1][0]]) rows.sort() + for i in range(0, len(self.samples)): + self.assertEqual( + rows[i], self.samples[i], + 'incorrect data retrieved or inserted') finally: con.close() - def help_nextset_setUp(self, cur): - """Should create a procedure called deleteme - that returns two result sets, first the - number of rows in booze then "name from booze" - """ - if False: - sql = """ - create procedure deleteme as - begin - select count(*) from booze - select name from booze - end - """ - cur.execute(sql) - else: - raise NotImplementedError('Helper not implemented') + def help_nextset_setup(self, cur): + """Set up nextset test. - def help_nextset_tearDown(self, cur): - """If cleaning up is needed after nextSetTest""" - if False: - cur.execute("drop procedure deleteme") - else: - - raise NotImplementedError('Helper not implemented') + Should create a procedure called deleteme that returns two result sets, + first the number of rows in booze, then "name from booze". + """ + raise NotImplementedError('Helper not implemented') + # sql = """ + # create procedure deleteme as + # begin + # select count(*) from booze + # select name from booze + # end + # """ + # cur.execute(sql) + + def help_nextset_teardown(self, cur): + """Clean up after nextset test. + + Override this helper if cleanup is needed after test_nextset.
+ """ + raise NotImplementedError('Helper not implemented') + # cur.execute("drop procedure deleteme") def test_nextset(self): - con = self._connect() - try: - cur = con.cursor() - if not hasattr(cur,'nextset'): - return - - try: - self.executeDDL1(cur) - sql=self._populate() - for sql in self._populate(): - cur.execute(sql) - - self.help_nextset_setUp(cur) - - cur.callproc('deleteme') - numberofrows=cur.fetchone() - assert numberofrows[0]== len(self.samples) - assert cur.nextset() - names=cur.fetchall() - assert len(names) == len(self.samples) - s=cur.nextset() - assert s == None,'No more return sets, should return None' - finally: - self.help_nextset_tearDown(cur) - - finally: - con.close() + """Test the nextset functionality.""" + raise NotImplementedError('Drivers need to override this test') + # example test implementation only: + # con = self._connect() + # try: + # cur = con.cursor() + # if not hasattr(cur, 'nextset'): + # return + # try: + # self.executeDDL1(cur) + # for sql in self._populate(): + # cur.execute(sql) + # self.help_nextset_setup(cur) + # cur.callproc('deleteme') + # number_of_rows = cur.fetchone() + # self.assertEqual(number_of_rows[0], len(self.samples)) + # self.assertTrue(cur.nextset()) + # names = cur.fetchall() + # self.assertEqual(len(names), len(self.samples)) + # self.assertIsNone( + # cur.nextset(), 'No more return sets, should return None') + # finally: + # self.help_nextset_teardown(cur) + # finally: + # con.close() def test_arraysize(self): - """Not much here - rest of the tests for this are in test_fetchmany""" + # Not much here - rest of the tests for this are in test_fetchmany con = self._connect() try: cur = con.cursor() - self.assertTrue(hasattr(cur,'arraysize'), - 'cursor.arraysize must be defined' - ) + self.assertTrue(hasattr(cur, 'arraysize'), + 'cursor.arraysize must be defined') finally: con.close() @@ -722,86 +746,86 @@ def test_setinputsizes(self): con = self._connect() try: cur = con.cursor() - cur.setinputsizes( (25,) ) - self._paraminsert(cur) # Make sure cursor still works + cur.setinputsizes((25,)) + self._paraminsert(cur) # Make sure cursor still works finally: con.close() def test_setoutputsize_basic(self): - """Basic test is to make sure setoutputsize doesn't blow up""" + # Basic test is to make sure setoutputsize doesn't blow up con = self._connect() try: cur = con.cursor() cur.setoutputsize(1000) - cur.setoutputsize(2000,0) - self._paraminsert(cur) # Make sure the cursor still works + cur.setoutputsize(2000, 0) + self._paraminsert(cur) # Make sure the cursor still works finally: con.close() def test_setoutputsize(self): - """Real test for setoutputsize is driver dependant""" - raise NotImplementedError('Driver needs to override this test') + # Real test for setoutputsize is driver dependent + raise NotImplementedError('Driver needs to override this test') - def test_None(self): + def test_none(self): con = self._connect() try: cur = con.cursor() - self.executeDDL1(cur) - cur.execute('insert into %sbooze values (NULL)' % self.table_prefix) - cur.execute('select name from %sbooze' % self.table_prefix) + self.execute_ddl2(cur) + # inserting NULL to the second column, because some drivers might + # need the first one to be primary key, which means it needs + # to have a non-NULL value + cur.execute(f"{self.insert} into {self.table_prefix}barflys" + " values ('a', NULL)") + cur.execute(f'select drink from {self.table_prefix}barflys') r = cur.fetchall() - self.assertEqual(len(r),1) - self.assertEqual(len(r[0]),1) -
self.assertEqual(r[0][0],None,'NULL value not returned as None') + self.assertEqual(len(r), 1) + self.assertEqual(len(r[0]), 1) + self.assertIsNone(r[0][0], 'NULL value not returned as None') finally: con.close() - def test_Date(self): - d1 = self.driver.Date(2002,12,25) - d2 = self.driver.DateFromTicks(time.mktime((2002,12,25,0,0,0,0,0,0))) + def test_date(self): + d1 = self.driver.Date(2002, 12, 25) + d2 = self.driver.DateFromTicks( + time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0))) # Can we assume this? API doesn't specify, but it seems implied - # self.assertEqual(str(d1),str(d2)) + self.assertEqual(str(d1), str(d2)) - def test_Time(self): - t1 = self.driver.Time(13,45,30) - t2 = self.driver.TimeFromTicks(time.mktime((2001,1,1,13,45,30,0,0,0))) + def test_time(self): + t1 = self.driver.Time(13, 45, 30) + t2 = self.driver.TimeFromTicks( + time.mktime((2001, 1, 1, 13, 45, 30, 0, 0, 0))) # Can we assume this? API doesn't specify, but it seems implied - # self.assertEqual(str(t1),str(t2)) + self.assertEqual(str(t1), str(t2)) - def test_Timestamp(self): - t1 = self.driver.Timestamp(2002,12,25,13,45,30) + def test_timestamp(self): + t1 = self.driver.Timestamp(2002, 12, 25, 13, 45, 30) t2 = self.driver.TimestampFromTicks( - time.mktime((2002,12,25,13,45,30,0,0,0)) - ) + time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0)) + ) # Can we assume this? API doesn't specify, but it seems implied - # self.assertEqual(str(t1),str(t2)) - - def test_Binary(self): - b = self.driver.Binary(b'Something') - b = self.driver.Binary(b'') + self.assertEqual(str(t1), str(t2)) - def test_STRING(self): - self.assertTrue(hasattr(self.driver,'STRING'), - 'module.STRING must be defined' - ) + def test_binary_string(self): + self.driver.Binary(b'Something') + self.driver.Binary(b'') - def test_BINARY(self): - self.assertTrue(hasattr(self.driver,'BINARY'), - 'module.BINARY must be defined.' - ) + def test_string_type(self): + self.assertTrue(hasattr(self.driver, 'STRING'), + 'module.STRING must be defined') - def test_NUMBER(self): - self.assertTrue(hasattr(self.driver,'NUMBER'), - 'module.NUMBER must be defined.' - ) + def test_binary_type(self): + self.assertTrue(hasattr(self.driver, 'BINARY'), + 'module.BINARY must be defined.') - def test_DATETIME(self): - self.assertTrue(hasattr(self.driver,'DATETIME'), - 'module.DATETIME must be defined.' - ) + def test_number_type(self): + self.assertTrue(hasattr(self.driver, 'NUMBER'), + 'module.NUMBER must be defined.') - def test_ROWID(self): - self.assertTrue(hasattr(self.driver,'ROWID'), - 'module.ROWID must be defined.' - ) + def test_datetime_type(self): + self.assertTrue(hasattr(self.driver, 'DATETIME'), + 'module.DATETIME must be defined.') + def test_rowid_type(self): + self.assertTrue(hasattr(self.driver, 'ROWID'), + 'module.ROWID must be defined.') diff --git a/tests/test_classic.py b/tests/test_classic.py index bb5133ee..3bf0fe5c 100755 --- a/tests/test_classic.py +++ b/tests/test_classic.py @@ -1,37 +1,26 @@ #!/usr/bin/python -# -*- coding: utf-8 -*- -from __future__ import print_function - -try: - import unittest2 as unittest # for Python < 2.7 -except ImportError: - import unittest - -import sys +import unittest +from contextlib import suppress from functools import partial -from time import sleep from threading import Thread +from time import sleep -from pg import * - -# We need a database to test against. If LOCAL_PyGreSQL.py exists we will -# get our information from that. Otherwise we use the defaults. 
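# The `from .config import ...` lines introduced in this patch replace the
# optional LOCAL_PyGreSQL mechanism removed here. A minimal sketch of what
# such a shared tests/config.py helper could look like; the environment
# variable names and defaults are illustrative assumptions, not part of
# this patch:

import os

dbname = os.environ.get('PYGRESQL_DB', 'test')
dbhost = os.environ.get('PYGRESQL_HOST')       # None means local socket
dbport = int(os.environ.get('PYGRESQL_PORT', '5432'))
dbuser = os.environ.get('PYGRESQL_USER')
dbpasswd = os.environ.get('PYGRESQL_PASSWD')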
-dbname = 'unittest' -dbhost = None -dbport = 5432 +from pg import ( + DB, + DatabaseError, + Error, + IntegrityError, + NotificationHandler, + NotSupportedError, + ProgrammingError, +) -try: - from .LOCAL_PyGreSQL import * -except (ImportError, ValueError): - try: - from LOCAL_PyGreSQL import * - except ImportError: - pass +from .config import dbhost, dbname, dbpasswd, dbport, dbuser -def opendb(): - db = DB(dbname, dbhost, dbport) +def open_db(): + db = DB(dbname, dbhost, dbport, user=dbuser, passwd=dbpasswd) db.query("SET DATESTYLE TO 'ISO'") db.query("SET TIME ZONE 'EST5EDT'") db.query("SET DEFAULT_WITH_OIDS=FALSE") @@ -39,55 +28,46 @@ def opendb(): db.query("SET STANDARD_CONFORMING_STRINGS=FALSE") return db -db = opendb() -for q in ( - "DROP TABLE _test1._test_schema", - "DROP TABLE _test2._test_schema", - "DROP SCHEMA _test1", - "DROP SCHEMA _test2", -): - try: - db.query(q) - except Exception: - pass -db.close() - class UtilityTest(unittest.TestCase): - def setUp(self): - """Setup test tables or empty them if they already exist.""" - db = opendb() - + @classmethod + def setUpClass(cls): + """Recreate test tables and schemas.""" + db = open_db() + with suppress(Exception): + db.query("DROP VIEW _test_vschema") + with suppress(Exception): + db.query("DROP TABLE _test_schema") + db.query("CREATE TABLE _test_schema" + " (_test int PRIMARY KEY, _i interval, dvar int DEFAULT 999)") + db.query("CREATE VIEW _test_vschema AS" + " SELECT _test, 'abc'::text AS _test2 FROM _test_schema") for t in ('_test1', '_test2'): - try: + with suppress(Exception): db.query("CREATE SCHEMA " + t) - except Error: - pass - try: - db.query("CREATE TABLE %s._test_schema " - "(%s int PRIMARY KEY)" % (t, t)) - except Error: - db.query("DELETE FROM %s._test_schema" % t) - try: - db.query("CREATE TABLE _test_schema " - "(_test int PRIMARY KEY, _i interval, dvar int DEFAULT 999)") - except Error: - db.query("DELETE FROM _test_schema") - try: - db.query("CREATE VIEW _test_vschema AS " - "SELECT _test, 'abc'::text AS _test2 FROM _test_schema") - except Error: - pass + with suppress(Exception): + db.query(f"DROP TABLE {t}._test_schema") + db.query(f"CREATE TABLE {t}._test_schema" + f" ({t} int PRIMARY KEY)") + db.close() + + def setUp(self): + """Set up test tables or empty them if they already exist.""" + db = open_db() + db.query("TRUNCATE TABLE _test_schema") + for t in ('_test1', '_test2'): + db.query(f"TRUNCATE TABLE {t}._test_schema") + db.close() - def test_invalidname(self): - """Make sure that invalid table names are caught""" - db = opendb() + def test_invalid_name(self): + """Make sure that invalid table names are caught.""" + db = open_db() self.assertRaises(NotSupportedError, db.get_attnames, 'x.y.z') def test_schema(self): - """Does it differentiate the same table name in different schemas""" - db = opendb() + """Check differentiation of same table name in different schemas.""" + db = open_db() # see if they differentiate the table names properly self.assertEqual( db.get_attnames('_test_schema'), @@ -107,7 +87,7 @@ def test_schema(self): ) def test_pkey(self): - db = opendb() + db = open_db() self.assertEqual(db.pkey('_test_schema'), '_test') self.assertEqual(db.pkey('public._test_schema'), '_test') self.assertEqual(db.pkey('_test1._test_schema'), '_test1') @@ -115,7 +95,7 @@ def test_pkey(self): self.assertRaises(KeyError, db.pkey, '_test_vschema') def test_get(self): - db = opendb() + db = open_db() db.query("INSERT INTO _test_schema VALUES (1234)") db.get('_test_schema', 1234) db.get('_test_schema', 
1234, keyname='_test') @@ -123,13 +103,13 @@ def test_get(self): db.get('_test_vschema', 1234, keyname='_test') def test_params(self): - db = opendb() + db = open_db() db.query("INSERT INTO _test_schema VALUES ($1, $2, $3)", 12, None, 34) d = db.get('_test_schema', 12) self.assertEqual(d['dvar'], 34) def test_insert(self): - db = opendb() + db = open_db() d = dict(_test=1234) db.insert('_test_schema', d) self.assertEqual(d['dvar'], 999) @@ -137,7 +117,7 @@ def test_insert(self): self.assertEqual(d['dvar'], 999) def test_context_manager(self): - db = opendb() + db = open_db() t = '_test_schema' d = dict(_test=1235) with db: @@ -163,26 +143,28 @@ def test_context_manager(self): self.assertTrue(db.get(t, 1239)) def test_sqlstate(self): - db = opendb() + db = open_db() db.query("INSERT INTO _test_schema VALUES (1234)") try: db.query("INSERT INTO _test_schema VALUES (1234)") except DatabaseError as error: - self.assertTrue(isinstance(error, IntegrityError)) + self.assertIsInstance(error, IntegrityError) # the SQLSTATE error code for unique violation is 23505 + # noinspection PyUnresolvedReferences self.assertEqual(error.sqlstate, '23505') def test_mixed_case(self): - db = opendb() + db = open_db() try: db.query('CREATE TABLE _test_mc ("_Test" int PRIMARY KEY)') except Error: - db.query("DELETE FROM _test_mc") + db.query("TRUNCATE TABLE _test_mc") d = dict(_Test=1234) - db.insert('_test_mc', d) + r = db.insert('_test_mc', d) + self.assertEqual(r, d) def test_update(self): - db = opendb() + db = open_db() db.query("INSERT INTO _test_schema VALUES (1234)") r = db.get('_test_schema', 1234) @@ -214,7 +196,7 @@ def test_notify(self, options=None): run_as_method = options.get('run_as_method') call_notify = options.get('call_notify') two_payloads = options.get('two_payloads') - db = opendb() + db = open_db() # Get function under test, can be standalone or DB method. fut = db.notification_handler if run_as_method else partial( NotificationHandler, db) @@ -226,14 +208,14 @@ def test_notify(self, options=None): thread.start() try: # Wait until the thread has started. - for n in range(500): + for _n in range(500): if target.listening: break sleep(0.01) self.assertTrue(target.listening) self.assertTrue(thread.is_alive()) # Open another connection for sending notifications. - db2 = opendb() + db2 = open_db() # Generate notification from the other connection. if two_payloads: db2.begin() @@ -248,7 +230,7 @@ def test_notify(self, options=None): if two_payloads: db2.commit() # Wait until the notification has been caught. - for n in range(500): + for _n in range(500): if arg_dict['called'] or self.notify_timeout: break sleep(0.01) @@ -256,7 +238,7 @@ def test_notify(self, options=None): self.assertTrue(arg_dict['called']) self.assertEqual(arg_dict['event'], 'event_1') self.assertEqual(arg_dict['extra'], 'payload 1') - self.assertTrue(isinstance(arg_dict['pid'], int)) + self.assertIsInstance(arg_dict['pid'], int) self.assertFalse(self.notify_timeout) arg_dict['called'] = False self.assertTrue(thread.is_alive()) @@ -267,7 +249,7 @@ def test_notify(self, options=None): db2.query("notify stop_event_1, 'payload 2'") db2.close() # Wait until the notification has been caught. 
- for n in range(500): + for _n in range(500): if arg_dict['called'] or self.notify_timeout: break sleep(0.01) @@ -275,7 +257,7 @@ def test_notify(self, options=None): self.assertTrue(arg_dict['called']) self.assertEqual(arg_dict['event'], 'stop_event_1') self.assertEqual(arg_dict['extra'], 'payload 2') - self.assertTrue(isinstance(arg_dict['pid'], int)) + self.assertIsInstance(arg_dict['pid'], int) self.assertFalse(self.notify_timeout) thread.join(5) self.assertFalse(thread.is_alive()) @@ -299,18 +281,18 @@ def test_notify_other_options(self): def test_notify_timeout(self): for run_as_method in False, True: - db = opendb() + db = open_db() # Get function under test, can be standalone or DB method. fut = db.notification_handler if run_as_method else partial( NotificationHandler, db) arg_dict = dict(event=None, called=False) self.notify_timeout = False - # Listen for 'event_1' with timeout of 10ms. - target = fut('event_1', self.notify_callback, arg_dict, 0.01) + # Listen for 'event_1' with timeout of 50ms. + target = fut('event_1', self.notify_callback, arg_dict, 0.05) thread = Thread(None, target) thread.start() - # Sleep 20ms, long enough to time out. - sleep(0.02) + # Sleep 250ms, long enough to time out. + sleep(0.25) # Verify that we've indeed timed out. self.assertFalse(arg_dict.get('called')) self.assertTrue(self.notify_timeout) @@ -320,24 +302,4 @@ def test_notify_timeout(self): if __name__ == '__main__': - if len(sys.argv) == 2 and sys.argv[1] == '-l': - print('\n'.join(unittest.getTestCaseNames(UtilityTest, 'test_'))) - sys.exit(0) - - test_list = [name for name in sys.argv[1:] if not name.startswith('-')] - if not test_list: - test_list = unittest.getTestCaseNames(UtilityTest, 'test_') - - suite = unittest.TestSuite() - for test_name in test_list: - try: - suite.addTest(UtilityTest(test_name)) - except Exception: - print("\n ERROR: %s.\n" % sys.exc_value) - sys.exit(1) - - verbosity = '-v' in sys.argv[1:] and 2 or 1 - failfast = '-l' in sys.argv[1:] - runner = unittest.TextTestRunner(verbosity=verbosity, failfast=failfast) - rc = runner.run(suite) - sys.exit(1 if rc.errors or rc.failures else 0) + unittest.main() diff --git a/tests/test_classic_attrdict.py b/tests/test_classic_attrdict.py new file mode 100644 index 00000000..8eef00df --- /dev/null +++ b/tests/test_classic_attrdict.py @@ -0,0 +1,100 @@ +#!/usr/bin/python + +"""Test the classic PyGreSQL interface. + +Sub-tests for the DB wrapper object. + +Contributed by Christoph Zwerschke. + +These tests need a database to test against. 
+""" + +import unittest + +import pg.attrs # the module under test + + +class TestAttrDict(unittest.TestCase): + """Test the simple ordered dictionary for attribute names.""" + + cls = pg.attrs.AttrDict + + def test_init(self): + a = self.cls() + self.assertIsInstance(a, dict) + self.assertEqual(a, {}) + items = [('id', 'int'), ('name', 'text')] + a = self.cls(items) + self.assertIsInstance(a, dict) + self.assertEqual(a, dict(items)) + iteritems = iter(items) + a = self.cls(iteritems) + self.assertIsInstance(a, dict) + self.assertEqual(a, dict(items)) + + def test_iter(self): + a = self.cls() + self.assertEqual(list(a), []) + keys = ['id', 'name', 'age'] + items = [(key, None) for key in keys] + a = self.cls(items) + self.assertEqual(list(a), keys) + + def test_keys(self): + a = self.cls() + self.assertEqual(list(a.keys()), []) + keys = ['id', 'name', 'age'] + items = [(key, None) for key in keys] + a = self.cls(items) + self.assertEqual(list(a.keys()), keys) + + def test_values(self): + a = self.cls() + self.assertEqual(list(a.values()), []) + items = [('id', 'int'), ('name', 'text')] + values = [item[1] for item in items] + a = self.cls(items) + self.assertEqual(list(a.values()), values) + + def test_items(self): + a = self.cls() + self.assertEqual(list(a.items()), []) + items = [('id', 'int'), ('name', 'text')] + a = self.cls(items) + self.assertEqual(list(a.items()), items) + + def test_get(self): + a = self.cls([('id', 1)]) + try: + self.assertEqual(a['id'], 1) + except KeyError: + self.fail('AttrDict should be readable') + + def test_set(self): + a = self.cls() + try: + a['id'] = 1 + except TypeError: + pass + else: + self.fail('AttrDict should be read-only') + + def test_del(self): + a = self.cls([('id', 1)]) + try: + del a['id'] + except TypeError: + pass + else: + self.fail('AttrDict should be read-only') + + def test_write_methods(self): + a = self.cls([('id', 1)]) + self.assertEqual(a['id'], 1) + for method in 'clear', 'update', 'pop', 'setdefault', 'popitem': + method = getattr(a, method) + self.assertRaises(TypeError, method, a) # type: ignore + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_classic_connection.py b/tests/test_classic_connection.py index 3f0e6bbf..90d69a59 100755 --- a/tests/test_classic_connection.py +++ b/tests/test_classic_connection.py @@ -1,5 +1,4 @@ #!/usr/bin/python -# -*- coding: utf-8 -*- """Test the classic PyGreSQL interface. @@ -10,52 +9,21 @@ These tests need a database to test against. """ -try: - import unittest2 as unittest # for Python < 2.7 -except ImportError: - import unittest +from __future__ import annotations + +import os import threading import time -import os - +import unittest from collections import namedtuple -try: - from collections.abc import Iterable -except ImportError: - from collections import Iterable - +from collections.abc import Iterable +from contextlib import suppress from decimal import Decimal +from typing import Any, Sequence import pg # the module under test -# We need a database to test against. If LOCAL_PyGreSQL.py exists we will -# get our information from that. Otherwise we use the defaults. -# These tests should be run with various PostgreSQL versions and databases -# created with different encodings and locales. Particularly, make sure the -# tests are running against databases created with both SQL_ASCII and UTF8. 
-dbname = 'unittest'
-dbhost = None
-dbport = 5432
-
-try:
-    from .LOCAL_PyGreSQL import *
-except (ImportError, ValueError):
-    try:
-        from LOCAL_PyGreSQL import *
-    except ImportError:
-        pass
-
-try:  # noinspection PyUnresolvedReferences
-    long
-except NameError:  # Python >= 3.0
-    long = int
-
-try:  # noinspection PyUnresolvedReferences
-    unicode
-except NameError:  # Python >= 3.0
-    unicode = str
-
-unicode_strings = str is not bytes
+from .config import dbhost, dbname, dbpasswd, dbport, dbuser

 windows = os.name == 'nt'

@@ -67,19 +35,55 @@

 def connect():
     """Create a basic pg connection to the test database."""
-    connection = pg.connect(dbname, dbhost, dbport)
+    # noinspection PyArgumentList
+    connection = pg.connect(dbname, dbhost, dbport,
+                            user=dbuser, passwd=dbpasswd)
     connection.query("set client_min_messages=warning")
     return connection


+def connect_nowait():
+    """Start a basic pg connection in a non-blocking manner."""
+    # noinspection PyArgumentList
+    return pg.connect(dbname, dbhost, dbport,
+                      user=dbuser, passwd=dbpasswd, nowait=True)
+
+
 class TestCanConnect(unittest.TestCase):
     """Test whether a basic connection to PostgreSQL is possible."""

-    def testCanConnect(self):
+    def test_can_connect(self):
         try:
             connection = connect()
+            rc = connection.poll()
+        except pg.Error as error:
+            self.fail(f'Cannot connect to database {dbname}:\n{error}')
+        self.assertEqual(rc, pg.POLLING_OK)
+        self.assertIs(connection.is_non_blocking(), False)
+        connection.set_non_blocking(True)
+        self.assertIs(connection.is_non_blocking(), True)
+        connection.set_non_blocking(False)
+        self.assertIs(connection.is_non_blocking(), False)
+        try:
+            connection.close()
+        except pg.Error:
+            self.fail('Cannot close the database connection')
+
+    def test_can_connect_no_wait(self):
+        try:
+            connection = connect_nowait()
+            rc = connection.poll()
+            self.assertIn(rc, (pg.POLLING_READING, pg.POLLING_WRITING))
+            while rc not in (pg.POLLING_OK, pg.POLLING_FAILED):
+                rc = connection.poll()
         except pg.Error as error:
-            self.fail('Cannot connect to database %s:\n%s' % (dbname, error))
+            self.fail(f'Cannot connect to database {dbname}:\n{error}')
+        self.assertEqual(rc, pg.POLLING_OK)
+        self.assertIs(connection.is_non_blocking(), False)
+        connection.set_non_blocking(True)
+        self.assertIs(connection.is_non_blocking(), True)
+        connection.set_non_blocking(False)
+        self.assertIs(connection.is_non_blocking(), False)
         try:
             connection.close()
         except pg.Error:
@@ -93,10 +97,8 @@
     def setUp(self):
         self.connection = connect()

     def tearDown(self):
-        try:
+        with suppress(pg.InternalError):
             self.connection.close()
-        except pg.InternalError:
-            pass

     def is_method(self, attribute):
         """Check if given attribute on the connection is a method."""
@@ -104,139 +106,186 @@
             return False
         return callable(getattr(self.connection, attribute))

-    def testClassName(self):
+    def test_class_name(self):
         self.assertEqual(self.connection.__class__.__name__, 'Connection')

-    def testModuleName(self):
+    def test_module_name(self):
         self.assertEqual(self.connection.__class__.__module__, 'pg')

-    def testStr(self):
+    def test_str(self):
         r = str(self.connection)
         self.assertTrue(r.startswith('<pg.Connection'))

-    def testAttributeServerVersion(self):
+    def test_attribute_server_version(self):
         server_version = self.connection.server_version
         self.assertIsInstance(server_version, int)
+        self.assertGreaterEqual(server_version, 100000)  # >= 10.0
+        self.assertLess(server_version, 190000)  # < 20.0

-    def testAttributeSocket(self):
+    def test_attribute_socket(self):
         socket = self.connection.socket
         self.assertIsInstance(socket, int)
         self.assertGreaterEqual(socket, 0)

-    def testAttributeBackendPid(self):
+    def test_attribute_backend_pid(self):
         backend_pid = self.connection.backend_pid
self.assertIsInstance(backend_pid, int) self.assertGreaterEqual(backend_pid, 1) - def testAttributeSslInUse(self): + def test_attribute_ssl_in_use(self): ssl_in_use = self.connection.ssl_in_use self.assertIsInstance(ssl_in_use, bool) self.assertFalse(ssl_in_use) - def testAttributeSslAttributes(self): + def test_attribute_ssl_attributes(self): ssl_attributes = self.connection.ssl_attributes self.assertIsInstance(ssl_attributes, dict) - self.assertEqual(ssl_attributes, { - 'cipher': None, 'compression': None, 'key_bits': None, - 'library': None, 'protocol': None}) + if ssl_attributes: + self.assertEqual(ssl_attributes, { + 'cipher': None, 'compression': None, 'key_bits': None, + 'library': None, 'protocol': None}) - def testAttributeStatus(self): + def test_attribute_status(self): status_ok = 1 self.assertIsInstance(self.connection.status, int) self.assertEqual(self.connection.status, status_ok) - def testAttributeUser(self): + def test_attribute_user(self): no_user = 'Deprecated facility' user = self.connection.user self.assertTrue(user) self.assertIsInstance(user, str) self.assertNotEqual(user, no_user) - def testMethodQuery(self): + def test_method_query(self): query = self.connection.query query("select 1+1") query("select 1+$1", (1,)) query("select 1+$1+$2", (2, 3)) query("select 1+$1+$2", [2, 3]) - def testMethodQueryEmpty(self): + def test_method_query_empty(self): self.assertRaises(ValueError, self.connection.query, '') - def testAllQueryMembers(self): + def test_method_send_query_single(self): + query = self.connection.send_query + for q, args, result in ( + ("select 1+1 as a", (), 2), + ("select 1+$1 as a", ((1,),), 2), + ("select 1+$1+$2 as a", ((2, 3),), 6)): + pgq = query(q, *args) + self.assertEqual(self.connection.transaction(), pg.TRANS_ACTIVE) + self.assertEqual(pgq.getresult()[0][0], result) + self.assertEqual(self.connection.transaction(), pg.TRANS_ACTIVE) + self.assertIsNone(pgq.getresult()) + self.assertEqual(self.connection.transaction(), pg.TRANS_IDLE) + + pgq = query(q, *args) + self.assertEqual(pgq.namedresult()[0].a, result) + self.assertIsNone(pgq.namedresult()) + + pgq = query(q, *args) + self.assertEqual(pgq.dictresult()[0]['a'], result) + self.assertIsNone(pgq.dictresult()) + + def test_method_send_query_multiple(self): + query = self.connection.send_query + + pgq = query("select 1+1; select 'pg';") + self.assertEqual(pgq.getresult()[0][0], 2) + self.assertEqual(pgq.getresult()[0][0], 'pg') + self.assertIsNone(pgq.getresult()) + + pgq = query("select 1+1 as a; select 'pg' as a;") + self.assertEqual(pgq.namedresult()[0].a, 2) + self.assertEqual(pgq.namedresult()[0].a, 'pg') + self.assertIsNone(pgq.namedresult()) + + pgq = query("select 1+1 as a; select 'pg' as a;") + self.assertEqual(pgq.dictresult()[0]['a'], 2) + self.assertEqual(pgq.dictresult()[0]['a'], 'pg') + self.assertIsNone(pgq.dictresult()) + + def test_method_send_query_empty(self): + query = self.connection.send_query('') + self.assertRaises(ValueError, query.getresult) + + def test_all_query_members(self): query = self.connection.query("select true where false") members = ''' - dictiter dictresult fieldname fieldnum getresult listfields - namediter namedresult ntuples one onedict onenamed onescalar - scalariter scalarresult single singledict singlenamed singlescalar + dictiter dictresult fieldinfo fieldname fieldnum getresult + listfields memsize namediter namedresult + one onedict onenamed onescalar scalariter scalarresult + single singledict singlenamed singlescalar '''.split() - query_members = [a 
for a in dir(query)
-                         if not a.startswith('__')
-                         and a != 'next']  # this is only needed in Python 2
+        # noinspection PyUnresolvedReferences
+        if pg.get_pqlib_version() < 120000:
+            members.remove('memsize')
+        query_members = [
+            a for a in dir(query)
+            if not a.startswith('__')]
         self.assertEqual(members, query_members)

-    def testMethodEndcopy(self):
-        try:
+    def test_method_endcopy(self):
+        with suppress(OSError):
             self.connection.endcopy()
-        except IOError:
-            pass

-    def testMethodClose(self):
+    def test_method_close(self):
         self.connection.close()
         try:
             self.connection.reset()
@@ -253,13 +302,13 @@
             self.fail('Query should give an error for a closed connection')
         self.connection = connect()

-    def testMethodReset(self):
+    def test_method_reset(self):
         query = self.connection.query
         # check that client encoding gets reset
         encoding = query('show client_encoding').getresult()[0][0].upper()
         changed_encoding = 'LATIN1' if encoding == 'UTF8' else 'UTF8'
         self.assertNotEqual(encoding, changed_encoding)
-        self.connection.query("set client_encoding=%s" % changed_encoding)
+        self.connection.query(f"set client_encoding={changed_encoding}")
         new_encoding = query('show client_encoding').getresult()[0][0].upper()
         self.assertEqual(new_encoding, changed_encoding)
         self.connection.reset()
@@ -267,12 +316,12 @@
         self.assertNotEqual(new_encoding, changed_encoding)
         self.assertEqual(new_encoding, encoding)

-    def testMethodCancel(self):
+    def test_method_cancel(self):
         r = self.connection.cancel()
         self.assertIsInstance(r, int)
         self.assertEqual(r, 1)

-    def testCancelLongRunningThread(self):
+    def test_cancel_long_running_thread(self):
         errors = []

         def sleep():
@@ -297,12 +346,12 @@ def sleep():
         self.assertLessEqual(t2 - t1, 3)  # time should be under 3 seconds
         self.assertTrue(errors)

-    def testMethodFileNo(self):
+    def test_method_file_no(self):
         r = self.connection.fileno()
         self.assertIsInstance(r, int)
         self.assertGreaterEqual(r, 0)

-    def testMethodTransaction(self):
+    def test_method_transaction(self):
         transaction = self.connection.transaction
         self.assertRaises(TypeError, transaction, None)
         self.assertEqual(transaction(), pg.TRANS_IDLE)
@@ -311,7 +360,7 @@
         self.connection.query('rollback')
         self.assertEqual(transaction(), pg.TRANS_IDLE)

-    def testMethodParameter(self):
+    def test_method_parameter(self):
         parameter = self.connection.parameter
         query = self.connection.query
         self.assertRaises(TypeError, parameter)
@@ -345,42 +394,43 @@
     def tearDown(self):
         self.doCleanups()
         self.c.close()

-    def testClassName(self):
+    def test_class_name(self):
         r = self.c.query("select 1")
         self.assertEqual(r.__class__.__name__, 'Query')

-    def testModuleName(self):
+    def test_module_name(self):
         r = self.c.query("select 1")
         self.assertEqual(r.__class__.__module__, 'pg')

-    def testStr(self):
+    def test_str(self):
         q = ("select 1 as a, 'hello' as h, 'w' as world"
-            " union select 2, 'xyz', 'uvw'")
+             " union select 2, 'xyz', 'uvw'")
         r = self.c.query(q)
-        self.assertEqual(str(r),
+        self.assertEqual(
+            str(r),
             'a|  h  |world\n'
             '-+-----+-----\n'
             '1|hello|w    \n'
             '2|xyz  |uvw  \n'
             '(2 rows)')

-    def testRepr(self):
+    def test_repr(self):
         r = repr(self.c.query("select 1"))
         self.assertTrue(r.startswith('<pg.Query'))

+    def test_fieldinfo(self):
+        result = (('a', 'int4', 4, -1), ('FooBar', 'text', -1, -1),
+                  ('b', 'float8', 8, -1), ('Foo Bar', 'bool', 1, -1))
+        q = ('select 1 as a, \'x\'::text as "FooBar",'
+             ' 2.5::float8 as b, true as "Foo Bar"')
+        f = self.c.query(q).fieldinfo
+        for field_num, info in enumerate(result):
+            field_name = info[0]
+            if field_num % 2 > 0:
+                field_name = f'"{field_name}"'
+            r = f(field_name)
+            self.assertIsInstance(r, tuple)
+            self.assertEqual(len(r), 4)
+            self.assertEqual(r, info)
+
self.assertRaises(IndexError, f, 'foobaz') + self.assertRaises(IndexError, f, '"Foobar"') + self.assertRaises(IndexError, f, -1) + self.assertRaises(IndexError, f, 4) + + def test_len(self): q = "select 1 where false" self.assertEqual(len(self.c.query(q)), 0) q = ("select 1 as a, 2 as b, 3 as c, 4 as d" - " union select 5 as a, 6 as b, 7 as c, 8 as d") + " union select 5 as a, 6 as b, 7 as c, 8 as d") self.assertEqual(len(self.c.query(q)), 2) q = ("select 1 union select 2 union select 3" - " union select 4 union select 5 union select 6") + " union select 4 union select 5 union select 6") self.assertEqual(len(self.c.query(q)), 6) - def testQuery(self): + def test_query(self): query = self.c.query query("drop table if exists test_table") self.addCleanup(query, "drop table test_table") @@ -648,12 +752,13 @@ def testQuery(self): r = query(q) self.assertIsInstance(r, str) self.assertEqual(r, '4') + # noinspection SqlWithoutWhere q = "delete from test_table" r = query(q) self.assertIsInstance(r, str) self.assertEqual(r, '5') - def testQueryWithOids(self): + def test_query_with_oids(self): if self.c.server_version >= 120000: self.skipTest("database does not support tables with oids") query = self.c.query @@ -685,11 +790,27 @@ def testQueryWithOids(self): r = query(q) self.assertIsInstance(r, str) self.assertEqual(r, '4') + # noinspection SqlWithoutWhere q = "delete from test_table" r = query(q) self.assertIsInstance(r, str) self.assertEqual(r, '5') + def test_mem_size(self): + # noinspection PyUnresolvedReferences + if pg.get_pqlib_version() < 120000: + self.skipTest("pqlib does not support memsize()") + query = self.c.query + q = query("select repeat('foo!', 8)") + size = q.memsize() + self.assertIsInstance(size, int) + self.assertGreaterEqual(size, 32) + self.assertLess(size, 8000) + q = query("select repeat('foo!', 2000)") + size = q.memsize() + self.assertGreaterEqual(size, 8000) + self.assertLess(size, 16000) + class TestUnicodeQueries(unittest.TestCase): """Test unicode strings as queries via a basic pg connection.""" @@ -701,153 +822,136 @@ def setUp(self): def tearDown(self): self.c.close() - def testGetresulAscii(self): - result = u'Hello, world!' - q = u"select '%s'" % result - v = self.c.query(q).getresult()[0][0] + def test_getresul_ascii(self): + result = 'Hello, world!' + cmd = f"select '{result}'" + v = self.c.query(cmd).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) - def testDictresulAscii(self): - result = u'Hello, world!' - q = u"select '%s' as greeting" % result - v = self.c.query(q).dictresult()[0]['greeting'] + def test_dictresul_ascii(self): + result = 'Hello, world!' + cmd = f"select '{result}' as greeting" + v = self.c.query(cmd).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) - def testGetresultUtf8(self): - result = u'Hello, wörld & мир!' - q = u"select '%s'" % result - if not unicode_strings: - result = result.encode('utf8') + def test_getresult_utf8(self): + result = 'Hello, wörld & мир!' 
+ cmd = f"select '{result}'" # pass the query as unicode try: - v = self.c.query(q).getresult()[0][0] - except(pg.DataError, pg.NotSupportedError): + v = self.c.query(cmd).getresult()[0][0] + except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support utf8") self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode('utf8') - # pass the query as bytes - v = self.c.query(q).getresult()[0][0] + cmd_bytes = cmd.encode() + v = self.c.query(cmd_bytes).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) - def testDictresultUtf8(self): - result = u'Hello, wörld & мир!' - q = u"select '%s' as greeting" % result - if not unicode_strings: - result = result.encode('utf8') + def test_dictresult_utf8(self): + result = 'Hello, wörld & мир!' + cmd = f"select '{result}' as greeting" try: - v = self.c.query(q).dictresult()[0]['greeting'] + v = self.c.query(cmd).dictresult()[0]['greeting'] except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support utf8") self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode('utf8') - v = self.c.query(q).dictresult()[0]['greeting'] + cmd_bytes = cmd.encode() + v = self.c.query(cmd_bytes).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) - def testDictresultLatin1(self): + def test_getresult_latin1(self): try: self.c.query('set client_encoding=latin1') except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin1") - result = u'Hello, wörld!' - q = u"select '%s'" % result - if not unicode_strings: - result = result.encode('latin1') - v = self.c.query(q).getresult()[0][0] + result = 'Hello, wörld!' + cmd = f"select '{result}'" + v = self.c.query(cmd).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode('latin1') - v = self.c.query(q).getresult()[0][0] + cmd_bytes = cmd.encode('latin1') + v = self.c.query(cmd_bytes).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) - def testDictresultLatin1(self): + def test_dictresult_latin1(self): try: self.c.query('set client_encoding=latin1') except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin1") - result = u'Hello, wörld!' - q = u"select '%s' as greeting" % result - if not unicode_strings: - result = result.encode('latin1') - v = self.c.query(q).dictresult()[0]['greeting'] + result = 'Hello, wörld!' + cmd = f"select '{result}' as greeting" + v = self.c.query(cmd).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode('latin1') - v = self.c.query(q).dictresult()[0]['greeting'] + cmd_bytes = cmd.encode('latin1') + v = self.c.query(cmd_bytes).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) - def testGetresultCyrillic(self): + def test_getresult_cyrillic(self): try: self.c.query('set client_encoding=iso_8859_5') except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support cyrillic") - result = u'Hello, мир!' - q = u"select '%s'" % result - if not unicode_strings: - result = result.encode('cyrillic') - v = self.c.query(q).getresult()[0][0] + result = 'Hello, мир!' 
+ cmd = f"select '{result}'" + v = self.c.query(cmd).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode('cyrillic') - v = self.c.query(q).getresult()[0][0] + cmd_bytes = cmd.encode('cyrillic') + v = self.c.query(cmd_bytes).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) - def testDictresultCyrillic(self): + def test_dictresult_cyrillic(self): try: self.c.query('set client_encoding=iso_8859_5') except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support cyrillic") - result = u'Hello, мир!' - q = u"select '%s' as greeting" % result - if not unicode_strings: - result = result.encode('cyrillic') - v = self.c.query(q).dictresult()[0]['greeting'] + result = 'Hello, мир!' + cmd = f"select '{result}' as greeting" + v = self.c.query(cmd).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode('cyrillic') - v = self.c.query(q).dictresult()[0]['greeting'] + cmd_bytes = cmd.encode('cyrillic') + v = self.c.query(cmd_bytes).dictresult()[0]['greeting'] self.assertIsInstance(v, str) self.assertEqual(v, result) - def testGetresultLatin9(self): + def test_getresult_latin9(self): try: self.c.query('set client_encoding=latin9') except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin9") - result = u'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)' - q = u"select '%s'" % result - if not unicode_strings: - result = result.encode('latin9') - v = self.c.query(q).getresult()[0][0] + result = 'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)' + cmd = f"select '{result}'" + v = self.c.query(cmd).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode('latin9') - v = self.c.query(q).getresult()[0][0] + cmd_bytes = cmd.encode('latin9') + v = self.c.query(cmd_bytes).getresult()[0][0] self.assertIsInstance(v, str) self.assertEqual(v, result) - def testDictresultLatin9(self): + def test_dictresult_latin9(self): try: self.c.query('set client_encoding=latin9') except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin9") - result = u'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)' - q = u"select '%s' as menu" % result - if not unicode_strings: - result = result.encode('latin9') - v = self.c.query(q).dictresult()[0]['menu'] + result = 'smœrebrœd with pražská šunka (pay in ¢, £, €, or ¥)' + cmd = f"select '{result}' as menu" + v = self.c.query(cmd).dictresult()[0]['menu'] self.assertIsInstance(v, str) self.assertEqual(v, result) - q = q.encode('latin9') - v = self.c.query(q).dictresult()[0]['menu'] + cmd_bytes = cmd.encode('latin9') + v = self.c.query(cmd_bytes).dictresult()[0]['menu'] self.assertIsInstance(v, str) self.assertEqual(v, result) @@ -862,24 +966,25 @@ def setUp(self): def tearDown(self): self.c.close() - def testQueryWithNoneParam(self): + def test_query_with_none_param(self): self.assertRaises(TypeError, self.c.query, "select $1", None) self.assertRaises(TypeError, self.c.query, "select $1+$2", None, None) - self.assertEqual(self.c.query("select $1::integer", (None,) - ).getresult(), [(None,)]) - self.assertEqual(self.c.query("select $1::text", [None] - ).getresult(), [(None,)]) - self.assertEqual(self.c.query("select $1::text", [[None]] - ).getresult(), [(None,)]) - - def testQueryWithBoolParams(self, bool_enabled=None): + self.assertEqual( + self.c.query("select $1::integer", (None,)).getresult(), [(None,)]) + self.assertEqual( + self.c.query("select 
$1::text", [None]).getresult(), [(None,)]) + self.assertEqual( + self.c.query("select $1::text", [[None]]).getresult(), [(None,)]) + + def test_query_with_bool_params(self, bool_enabled=None): query = self.c.query + bool_enabled_default = None if bool_enabled is not None: bool_enabled_default = pg.get_bool() pg.set_bool(bool_enabled) try: bool_on = bool_enabled or bool_enabled is None - v_false, v_true = (False, True) if bool_on else 'ft' + v_false, v_true = (False, True) if bool_on else ('f', 't') r_false, r_true = [(v_false,)], [(v_true,)] self.assertEqual(query("select false").getresult(), r_false) self.assertEqual(query("select true").getresult(), r_true) @@ -896,131 +1001,150 @@ def testQueryWithBoolParams(self, bool_enabled=None): self.assertEqual(query(q, (False,)).getresult(), r_false) self.assertEqual(query(q, (True,)).getresult(), r_true) finally: - if bool_enabled is not None: + if bool_enabled_default is not None: pg.set_bool(bool_enabled_default) - def testQueryWithBoolParamsNotDefault(self): - self.testQueryWithBoolParams(bool_enabled=not pg.get_bool()) + def test_query_with_bool_params_not_default(self): + self.test_query_with_bool_params(bool_enabled=not pg.get_bool()) - def testQueryWithIntParams(self): + def test_query_with_int_params(self): query = self.c.query self.assertEqual(query("select 1+1").getresult(), [(2,)]) self.assertEqual(query("select 1+$1", (1,)).getresult(), [(2,)]) self.assertEqual(query("select 1+$1", [1]).getresult(), [(2,)]) self.assertEqual(query("select $1::integer", (2,)).getresult(), [(2,)]) self.assertEqual(query("select $1::text", (2,)).getresult(), [('2',)]) - self.assertEqual(query("select 1+$1::numeric", [1]).getresult(), - [(Decimal('2'),)]) - self.assertEqual(query("select 1, $1::integer", (2,) - ).getresult(), [(1, 2)]) - self.assertEqual(query("select 1 union select $1::integer", (2,) - ).getresult(), [(1,), (2,)]) - self.assertEqual(query("select $1::integer+$2", (1, 2) - ).getresult(), [(3,)]) - self.assertEqual(query("select $1::integer+$2", [1, 2] - ).getresult(), [(3,)]) - self.assertEqual(query("select 0+$1+$2+$3+$4+$5+$6", list(range(6)) - ).getresult(), [(15,)]) - - def testQueryWithStrParams(self): + self.assertEqual( + query("select 1+$1::numeric", [1]).getresult(), [(Decimal('2'),)]) + self.assertEqual( + query("select 1, $1::integer", (2,)).getresult(), [(1, 2)]) + self.assertEqual( + query("select 1 union select $1::integer", (2,)).getresult(), + [(1,), (2,)]) + self.assertEqual( + query("select $1::integer+$2", (1, 2)).getresult(), [(3,)]) + self.assertEqual( + query("select $1::integer+$2", [1, 2]).getresult(), [(3,)]) + self.assertEqual( + query("select 0+$1+$2+$3+$4+$5+$6", list(range(6))).getresult(), + [(15,)]) + + def test_query_with_str_params(self): query = self.c.query - self.assertEqual(query("select $1||', world!'", ('Hello',) - ).getresult(), [('Hello, world!',)]) - self.assertEqual(query("select $1||', world!'", ['Hello'] - ).getresult(), [('Hello, world!',)]) - self.assertEqual(query("select $1||', '||$2||'!'", ('Hello', 'world'), - ).getresult(), [('Hello, world!',)]) - self.assertEqual(query("select $1::text", ('Hello, world!',) - ).getresult(), [('Hello, world!',)]) - self.assertEqual(query("select $1::text,$2::text", ('Hello', 'world') - ).getresult(), [('Hello', 'world')]) - self.assertEqual(query("select $1::text,$2::text", ['Hello', 'world'] - ).getresult(), [('Hello', 'world')]) - self.assertEqual(query("select $1::text union select $2::text", - ('Hello', 'world')).getresult(), [('Hello',), ('world',)]) 
+ self.assertEqual( + query("select $1||', world!'", ('Hello',)).getresult(), + [('Hello, world!',)]) + self.assertEqual( + query("select $1||', world!'", ['Hello']).getresult(), + [('Hello, world!',)]) + self.assertEqual( + query("select $1||', '||$2||'!'", ('Hello', 'world')).getresult(), + [('Hello, world!',)]) + self.assertEqual( + query("select $1::text", ('Hello, world!',)).getresult(), + [('Hello, world!',)]) + self.assertEqual( + query("select $1::text,$2::text", ('Hello', 'world')).getresult(), + [('Hello', 'world')]) + self.assertEqual( + query("select $1::text,$2::text", ['Hello', 'world']).getresult(), + [('Hello', 'world')]) + self.assertEqual( + query("select $1::text union select $2::text", + ('Hello', 'world')).getresult(), + [('Hello',), ('world',)]) try: query("select 'wörld'") except (pg.DataError, pg.NotSupportedError): self.skipTest('database does not support utf8') - self.assertEqual(query("select $1||', '||$2||'!'", ('Hello', - 'w\xc3\xb6rld')).getresult(), [('Hello, w\xc3\xb6rld!',)]) + self.assertEqual( + query("select $1||', '||$2||'!'", + ('Hello', 'w\xc3\xb6rld')).getresult(), + [('Hello, w\xc3\xb6rld!',)]) - def testQueryWithUnicodeParams(self): + def test_query_with_unicode_params(self): query = self.c.query try: query('set client_encoding=utf8') - query("select 'wörld'").getresult()[0][0] == 'wörld' + self.assertEqual( + query("select 'wörld'").getresult()[0][0], 'wörld') except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support utf8") - self.assertEqual(query("select $1||', '||$2||'!'", - ('Hello', u'wörld')).getresult(), [('Hello, wörld!',)]) + self.assertEqual( + query("select $1||', '||$2||'!'", ('Hello', 'wörld')).getresult(), + [('Hello, wörld!',)]) - def testQueryWithUnicodeParamsLatin1(self): + def test_query_with_unicode_params_latin1(self): query = self.c.query try: query('set client_encoding=latin1') - query("select 'wörld'").getresult()[0][0] == 'wörld' + self.assertEqual( + query("select 'wörld'").getresult()[0][0], 'wörld') except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin1") - r = query("select $1||', '||$2||'!'", ('Hello', u'wörld')).getresult() - if unicode_strings: - self.assertEqual(r, [('Hello, wörld!',)]) - else: - self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)]) - self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'", - ('Hello', u'мир')) + r = query("select $1||', '||$2||'!'", ('Hello', 'wörld')).getresult() + self.assertEqual(r, [('Hello, wörld!',)]) + self.assertRaises( + UnicodeError, query, "select $1||', '||$2||'!'", + ('Hello', 'мир')) query('set client_encoding=iso_8859_1') - r = query("select $1||', '||$2||'!'", - ('Hello', u'wörld')).getresult() - if unicode_strings: - self.assertEqual(r, [('Hello, wörld!',)]) - else: - self.assertEqual(r, [(u'Hello, wörld!'.encode('latin1'),)]) - self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'", - ('Hello', u'мир')) + r = query( + "select $1||', '||$2||'!'", ('Hello', 'wörld')).getresult() + self.assertEqual(r, [('Hello, wörld!',)]) + self.assertRaises( + UnicodeError, query, "select $1||', '||$2||'!'", + ('Hello', 'мир')) query('set client_encoding=sql_ascii') - self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'", - ('Hello', u'wörld')) + self.assertRaises( + UnicodeError, query, "select $1||', '||$2||'!'", + ('Hello', 'wörld')) - def testQueryWithUnicodeParamsCyrillic(self): + def test_query_with_unicode_params_cyrillic(self): query = self.c.query try: query('set 
client_encoding=iso_8859_5') - query("select 'мир'").getresult()[0][0] == 'мир' + self.assertEqual( + query("select 'мир'").getresult()[0][0], 'мир') except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support cyrillic") - self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'", - ('Hello', u'wörld')) - r = query("select $1||', '||$2||'!'", - ('Hello', u'мир')).getresult() - if unicode_strings: - self.assertEqual(r, [('Hello, мир!',)]) - else: - self.assertEqual(r, [(u'Hello, мир!'.encode('cyrillic'),)]) + self.assertRaises( + UnicodeError, query, "select $1||', '||$2||'!'", + ('Hello', 'wörld')) + r = query( + "select $1||', '||$2||'!'", ('Hello', 'мир')).getresult() + self.assertEqual(r, [('Hello, мир!',)]) query('set client_encoding=sql_ascii') - self.assertRaises(UnicodeError, query, "select $1||', '||$2||'!'", - ('Hello', u'мир!')) - - def testQueryWithMixedParams(self): - self.assertEqual(self.c.query("select $1+2,$2||', world!'", - (1, 'Hello'),).getresult(), [(3, 'Hello, world!')]) - self.assertEqual(self.c.query("select $1::integer,$2::date,$3::text", - (4711, None, 'Hello!'),).getresult(), [(4711, None, 'Hello!')]) - - def testQueryWithDuplicateParams(self): - self.assertRaises(pg.ProgrammingError, - self.c.query, "select $1+$1", (1,)) - self.assertRaises(pg.ProgrammingError, - self.c.query, "select $1+$1", (1, 2)) - - def testQueryWithZeroParams(self): - self.assertEqual(self.c.query("select 1+1", [] - ).getresult(), [(2,)]) - - def testQueryWithGarbage(self): + self.assertRaises( + UnicodeError, query, "select $1||', '||$2||'!'", + ('Hello', 'мир!')) + + def test_query_with_mixed_params(self): + self.assertEqual( + self.c.query( + "select $1+2,$2||', world!'", (1, 'Hello')).getresult(), + [(3, 'Hello, world!')]) + self.assertEqual( + self.c.query( + "select $1::integer,$2::date,$3::text", + (4711, None, 'Hello!')).getresult(), + [(4711, None, 'Hello!')]) + + def test_query_with_duplicate_params(self): + self.assertRaises( + pg.ProgrammingError, self.c.query, "select $1+$1", (1,)) + self.assertRaises( + pg.ProgrammingError, self.c.query, "select $1+$1", (1, 2)) + + def test_query_with_zero_params(self): + self.assertEqual( + self.c.query("select 1+1", []).getresult(), [(2,)]) + + def test_query_with_garbage(self): garbage = r"'\{}+()-#[]oo324" - self.assertEqual(self.c.query("select $1::text AS garbage", (garbage,) - ).dictresult(), [{'garbage': garbage}]) + self.assertEqual( + self.c.query("select $1::text AS garbage", + (garbage,)).dictresult(), + [{'garbage': garbage}]) class TestPreparedQueries(unittest.TestCase): @@ -1033,38 +1157,38 @@ def setUp(self): def tearDown(self): self.c.close() - def testEmptyPreparedStatement(self): + def test_empty_prepared_statement(self): self.c.prepare('', '') self.assertRaises(ValueError, self.c.query_prepared, '') - def testInvalidPreparedStatement(self): + def test_invalid_prepared_statement(self): self.assertRaises(pg.ProgrammingError, self.c.prepare, '', 'bad') - def testDuplicatePreparedStatement(self): + def test_duplicate_prepared_statement(self): self.assertIsNone(self.c.prepare('q', 'select 1')) self.assertRaises(pg.ProgrammingError, self.c.prepare, 'q', 'select 2') - def testNonExistentPreparedStatement(self): - self.assertRaises(pg.OperationalError, - self.c.query_prepared, 'does-not-exist') + def test_non_existent_prepared_statement(self): + self.assertRaises( + pg.OperationalError, self.c.query_prepared, 'does-not-exist') - def testUnnamedQueryWithoutParams(self): + def 
test_unnamed_query_without_params(self): self.assertIsNone(self.c.prepare('', "select 'anon'")) self.assertEqual(self.c.query_prepared('').getresult(), [('anon',)]) self.assertEqual(self.c.query_prepared('').getresult(), [('anon',)]) - def testNamedQueryWithoutParams(self): + def test_named_query_without_params(self): self.assertIsNone(self.c.prepare('hello', "select 'world'")) - self.assertEqual(self.c.query_prepared('hello').getresult(), - [('world',)]) + self.assertEqual( + self.c.query_prepared('hello').getresult(), [('world',)]) - def testMultipleNamedQueriesWithoutParams(self): + def test_multiple_named_queries_without_params(self): self.assertIsNone(self.c.prepare('query17', "select 17")) self.assertIsNone(self.c.prepare('query42', "select 42")) self.assertEqual(self.c.query_prepared('query17').getresult(), [(17,)]) self.assertEqual(self.c.query_prepared('query42').getresult(), [(42,)]) - def testUnnamedQueryWithParams(self): + def test_unnamed_query_with_params(self): self.assertIsNone(self.c.prepare('', "select $1 || ', ' || $2")) self.assertEqual( self.c.query_prepared('', ['hello', 'world']).getresult(), @@ -1073,29 +1197,31 @@ def testUnnamedQueryWithParams(self): self.assertEqual( self.c.query_prepared('', [17, -5, 29]).getresult(), [(42,)]) - def testMultipleNamedQueriesWithParams(self): + def test_multiple_named_queries_with_params(self): self.assertIsNone(self.c.prepare('q1', "select $1 || '!'")) self.assertIsNone(self.c.prepare('q2', "select $1 || '-' || $2")) - self.assertEqual(self.c.query_prepared('q1', ['hello']).getresult(), + self.assertEqual( + self.c.query_prepared('q1', ['hello']).getresult(), [('hello!',)]) - self.assertEqual(self.c.query_prepared('q2', ['he', 'lo']).getresult(), + self.assertEqual( + self.c.query_prepared('q2', ['he', 'lo']).getresult(), [('he-lo',)]) - def testDescribeNonExistentQuery(self): - self.assertRaises(pg.OperationalError, - self.c.describe_prepared, 'does-not-exist') + def test_describe_non_existent_query(self): + self.assertRaises( + pg.OperationalError, self.c.describe_prepared, 'does-not-exist') - def testDescribeUnnamedQuery(self): + def test_describe_unnamed_query(self): self.c.prepare('', "select 1::int, 'a'::char") r = self.c.describe_prepared('') self.assertEqual(r.listfields(), ('int4', 'bpchar')) - def testDescribeNamedQuery(self): + def test_describe_named_query(self): self.c.prepare('myquery', "select 1 as first, 2 as second") r = self.c.describe_prepared('myquery') self.assertEqual(r.listfields(), ('first', 'second')) - def testDescribeMultipleNamedQueries(self): + def test_describe_multiple_named_queries(self): self.c.prepare('query1', "select 1::int") self.c.prepare('query2', "select 1::int, 2::int") r = self.c.describe_prepared('query1') @@ -1117,17 +1243,19 @@ def tearDown(self): self.c.close() def assert_proper_cast(self, value, pgtype, pytype): - q = 'select $1::%s' % (pgtype,) + q = f'select $1::{pgtype}' try: r = self.c.query(q, (value,)).getresult()[0][0] - except pg.ProgrammingError: + except pg.ProgrammingError as e: if pgtype in ('json', 'jsonb'): self.skipTest('database does not support json') + self.fail(str(e)) + # noinspection PyUnboundLocalVariable self.assertIsInstance(r, pytype) - if isinstance(value, str): - if not value or ' ' in value or '{' in value: - value = '"%s"' % value - value = '{%s}' % value + if isinstance(value, str) and ( + not value or ' ' in value or '{' in value): + value = f'"{value}"' + value = f'{{{value}}}' r = self.c.query(q + '[]', (value,)).getresult()[0][0] if 
pgtype.startswith(('date', 'time', 'interval')): # arrays of these are casted by the DB wrapper only @@ -1137,37 +1265,36 @@ def assert_proper_cast(self, value, pgtype, pytype): self.assertEqual(len(r), 1) self.assertIsInstance(r[0], pytype) - def testInt(self): + def test_int(self): self.assert_proper_cast(0, 'int', int) self.assert_proper_cast(0, 'smallint', int) self.assert_proper_cast(0, 'oid', int) self.assert_proper_cast(0, 'cid', int) self.assert_proper_cast(0, 'xid', int) - def testLong(self): - self.assert_proper_cast(0, 'bigint', long) + def test_long(self): + self.assert_proper_cast(0, 'bigint', int) - def testFloat(self): + def test_float(self): self.assert_proper_cast(0, 'float', float) self.assert_proper_cast(0, 'real', float) - self.assert_proper_cast(0, 'double', float) self.assert_proper_cast(0, 'double precision', float) self.assert_proper_cast('infinity', 'float', float) - def testFloat(self): + def test_numeric(self): decimal = pg.get_decimal() self.assert_proper_cast(decimal(0), 'numeric', decimal) self.assert_proper_cast(decimal(0), 'decimal', decimal) - def testMoney(self): + def test_money(self): decimal = pg.get_decimal() self.assert_proper_cast(decimal('0'), 'money', decimal) - def testBool(self): + def test_bool(self): bool_type = bool if pg.get_bool() else str self.assert_proper_cast('f', 'bool', bool_type) - def testDate(self): + def test_date(self): self.assert_proper_cast('1956-01-31', 'date', str) self.assert_proper_cast('10:20:30', 'interval', str) self.assert_proper_cast('08:42:15', 'time', str) @@ -1175,16 +1302,16 @@ def testDate(self): self.assert_proper_cast('1956-01-31 08:42:15', 'timestamp', str) self.assert_proper_cast('1956-01-31 08:42:15+00', 'timestamptz', str) - def testText(self): + def test_text(self): self.assert_proper_cast('', 'text', str) self.assert_proper_cast('', 'char', str) self.assert_proper_cast('', 'bpchar', str) self.assert_proper_cast('', 'varchar', str) - def testBytea(self): + def test_bytea(self): self.assert_proper_cast('', 'bytea', bytes) - def testJson(self): + def test_json(self): self.assert_proper_cast('{}', 'json', dict) @@ -1197,56 +1324,57 @@ def setUp(self): def tearDown(self): self.c.close() - def testLen(self): + def test_len(self): r = self.c.query("select generate_series(3,7)") self.assertEqual(len(r), 5) - def testGetItem(self): + def test_get_item(self): r = self.c.query("select generate_series(7,9)") self.assertEqual(r[0], (7,)) self.assertEqual(r[1], (8,)) self.assertEqual(r[2], (9,)) - def testGetItemWithNegativeIndex(self): + def test_get_item_with_negative_index(self): r = self.c.query("select generate_series(7,9)") self.assertEqual(r[-1], (9,)) self.assertEqual(r[-2], (8,)) self.assertEqual(r[-3], (7,)) - def testGetItemOutOfRange(self): + def test_get_item_out_of_range(self): r = self.c.query("select generate_series(7,9)") self.assertRaises(IndexError, r.__getitem__, 3) - def testIterate(self): + def test_iterate(self): r = self.c.query("select generate_series(3,5)") self.assertNotIsInstance(r, (list, tuple)) self.assertIsInstance(r, Iterable) self.assertEqual(list(r), [(3,), (4,), (5,)]) + # noinspection PyUnresolvedReferences self.assertIsInstance(r[1], tuple) - def testIterateTwice(self): + def test_iterate_twice(self): r = self.c.query("select generate_series(3,5)") - for i in range(2): + for _i in range(2): self.assertEqual(list(r), [(3,), (4,), (5,)]) - def testIterateTwoColumns(self): + def test_iterate_two_columns(self): r = self.c.query("select 1,2 union select 3,4") self.assertIsInstance(r, 
Iterable) self.assertEqual(list(r), [(1, 2), (3, 4)]) - def testNext(self): + def test_next(self): r = self.c.query("select generate_series(7,9)") self.assertEqual(next(r), (7,)) self.assertEqual(next(r), (8,)) self.assertEqual(next(r), (9,)) self.assertRaises(StopIteration, next, r) - def testContains(self): + def test_contains(self): r = self.c.query("select generate_series(7,9)") self.assertIn((8,), r) self.assertNotIn((5,), r) - def testDictIterate(self): + def test_dict_iterate(self): r = self.c.query("select generate_series(3,5) as n").dictiter() self.assertNotIsInstance(r, (list, tuple)) self.assertIsInstance(r, Iterable) @@ -1254,26 +1382,27 @@ def testDictIterate(self): self.assertEqual(r, [dict(n=3), dict(n=4), dict(n=5)]) self.assertIsInstance(r[1], dict) - def testDictIterateTwoColumns(self): - r = self.c.query("select 1 as one, 2 as two" + def test_dict_iterate_two_columns(self): + r = self.c.query( + "select 1 as one, 2 as two" " union select 3 as one, 4 as two").dictiter() self.assertIsInstance(r, Iterable) r = list(r) self.assertEqual(r, [dict(one=1, two=2), dict(one=3, two=4)]) - def testDictNext(self): + def test_dict_next(self): r = self.c.query("select generate_series(7,9) as n").dictiter() self.assertEqual(next(r), dict(n=7)) self.assertEqual(next(r), dict(n=8)) self.assertEqual(next(r), dict(n=9)) self.assertRaises(StopIteration, next, r) - def testDictContains(self): + def test_dict_contains(self): r = self.c.query("select generate_series(7,9) as n").dictiter() self.assertIn(dict(n=8), r) self.assertNotIn(dict(n=5), r) - def testNamedIterate(self): + def test_named_iterate(self): r = self.c.query("select generate_series(3,5) as number").namediter() self.assertNotIsInstance(r, (list, tuple)) self.assertIsInstance(r, Iterable) @@ -1283,8 +1412,9 @@ def testNamedIterate(self): self.assertEqual(r[1]._fields, ('number',)) self.assertEqual(r[1].number, 4) - def testNamedIterateTwoColumns(self): - r = self.c.query("select 1 as one, 2 as two" + def test_named_iterate_two_columns(self): + r = self.c.query( + "select 1 as one, 2 as two" " union select 3 as one, 4 as two").namediter() self.assertIsInstance(r, Iterable) r = list(r) @@ -1294,7 +1424,7 @@ def testNamedIterateTwoColumns(self): self.assertEqual(r[1]._fields, ('one', 'two')) self.assertEqual(r[1].two, 4) - def testNamedNext(self): + def test_named_next(self): r = self.c.query("select generate_series(7,9) as number").namediter() self.assertEqual(next(r), (7,)) self.assertEqual(next(r), (8,)) @@ -1303,12 +1433,12 @@ def testNamedNext(self): self.assertEqual(n.number, 9) self.assertRaises(StopIteration, next, r) - def testNamedContains(self): + def test_named_contains(self): r = self.c.query("select generate_series(7,9)").namediter() self.assertIn((8,), r) self.assertNotIn((5,), r) - def testScalarIterate(self): + def test_scalar_iterate(self): r = self.c.query("select generate_series(3,5)").scalariter() self.assertNotIsInstance(r, (list, tuple)) self.assertIsInstance(r, Iterable) @@ -1316,20 +1446,20 @@ def testScalarIterate(self): self.assertEqual(r, [3, 4, 5]) self.assertIsInstance(r[1], int) - def testScalarIterateTwoColumns(self): + def test_scalar_iterate_two_columns(self): r = self.c.query("select 1, 2 union select 3, 4").scalariter() self.assertIsInstance(r, Iterable) r = list(r) self.assertEqual(r, [1, 3]) - def testScalarNext(self): + def test_scalar_next(self): r = self.c.query("select generate_series(7,9)").scalariter() self.assertEqual(next(r), 7) self.assertEqual(next(r), 8) self.assertEqual(next(r), 9) 
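# Aside: taken together, the iterator methods covered in this class differ
# only in the shape of each returned row. A compact sketch of the mapping,
# assuming a reachable database (the name 'test' is illustrative; the row
# order of this simple union is assumed, as in the tests themselves):

import pg

con = pg.connect('test')
q = "select 1 as n union select 2 as n"
assert list(con.query(q)) == [(1,), (2,)]                     # plain tuples
assert list(con.query(q).dictiter()) == [{'n': 1}, {'n': 2}]  # dictionaries
assert [r.n for r in con.query(q).namediter()] == [1, 2]      # named tuples
assert list(con.query(q).scalariter()) == [1, 2]              # first column
con.close()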
self.assertRaises(StopIteration, next, r) - def testScalarContains(self): + def test_scalar_contains(self): r = self.c.query("select generate_series(7,9)").scalariter() self.assertIn(8, r) self.assertNotIn(5, r) @@ -1344,46 +1474,46 @@ def setUp(self): def tearDown(self): self.c.close() - def testOneWithEmptyQuery(self): + def test_one_with_empty_query(self): q = self.c.query("select 0 where false") self.assertIsNone(q.one()) - def testOneWithSingleRow(self): + def test_one_with_single_row(self): q = self.c.query("select 1, 2") r = q.one() self.assertIsInstance(r, tuple) self.assertEqual(r, (1, 2)) self.assertEqual(q.one(), None) - def testOneWithTwoRows(self): + def test_one_with_two_rows(self): q = self.c.query("select 1, 2 union select 3, 4") self.assertEqual(q.one(), (1, 2)) self.assertEqual(q.one(), (3, 4)) self.assertEqual(q.one(), None) - def testOneDictWithEmptyQuery(self): + def test_one_dict_with_empty_query(self): q = self.c.query("select 0 where false") self.assertIsNone(q.onedict()) - def testOneDictWithSingleRow(self): + def test_one_dict_with_single_row(self): q = self.c.query("select 1 as one, 2 as two") r = q.onedict() self.assertIsInstance(r, dict) self.assertEqual(r, dict(one=1, two=2)) self.assertEqual(q.onedict(), None) - def testOneDictWithTwoRows(self): + def test_one_dict_with_two_rows(self): q = self.c.query( "select 1 as one, 2 as two union select 3 as one, 4 as two") self.assertEqual(q.onedict(), dict(one=1, two=2)) self.assertEqual(q.onedict(), dict(one=3, two=4)) self.assertEqual(q.onedict(), None) - def testOneNamedWithEmptyQuery(self): + def test_one_named_with_empty_query(self): q = self.c.query("select 0 where false") self.assertIsNone(q.onenamed()) - def testOneNamedWithSingleRow(self): + def test_one_named_with_single_row(self): q = self.c.query("select 1 as one, 2 as two") r = q.onenamed() self.assertEqual(r._fields, ('one', 'two')) @@ -1392,7 +1522,7 @@ def testOneNamedWithSingleRow(self): self.assertEqual(r, (1, 2)) self.assertEqual(q.onenamed(), None) - def testOneNamedWithTwoRows(self): + def test_one_named_with_two_rows(self): q = self.c.query( "select 1 as one, 2 as two union select 3 as one, 4 as two") r = q.onenamed() @@ -1407,35 +1537,35 @@ def testOneNamedWithTwoRows(self): self.assertEqual(r, (3, 4)) self.assertEqual(q.onenamed(), None) - def testOneScalarWithEmptyQuery(self): + def test_one_scalar_with_empty_query(self): q = self.c.query("select 0 where false") self.assertIsNone(q.onescalar()) - def testOneScalarWithSingleRow(self): + def test_one_scalar_with_single_row(self): q = self.c.query("select 1, 2") r = q.onescalar() self.assertIsInstance(r, int) self.assertEqual(r, 1) self.assertEqual(q.onescalar(), None) - def testOneScalarWithTwoRows(self): + def test_one_scalar_with_two_rows(self): q = self.c.query("select 1, 2 union select 3, 4") self.assertEqual(q.onescalar(), 1) self.assertEqual(q.onescalar(), 3) self.assertEqual(q.onescalar(), None) - def testSingleWithEmptyQuery(self): + def test_single_with_empty_query(self): q = self.c.query("select 0 where false") try: q.single() except pg.InvalidResultError as e: - r = e + r: Any = e else: r = None self.assertIsInstance(r, pg.NoResultError) self.assertEqual(str(r), 'No result found') - def testSingleWithSingleRow(self): + def test_single_with_single_row(self): q = self.c.query("select 1, 2") r = q.single() self.assertIsInstance(r, tuple) @@ -1444,29 +1574,29 @@ def testSingleWithSingleRow(self): self.assertIsInstance(r, tuple) self.assertEqual(r, (1, 2)) - def 
testSingleWithTwoRows(self): + def test_single_with_two_rows(self): q = self.c.query("select 1, 2 union select 3, 4") try: q.single() except pg.InvalidResultError as e: - r = e + r: Any = e else: r = None self.assertIsInstance(r, pg.MultipleResultsError) self.assertEqual(str(r), 'Multiple results found') - def testSingleDictWithEmptyQuery(self): + def test_single_dict_with_empty_query(self): q = self.c.query("select 0 where false") try: q.singledict() except pg.InvalidResultError as e: - r = e + r: Any = e else: r = None self.assertIsInstance(r, pg.NoResultError) self.assertEqual(str(r), 'No result found') - def testSingleDictWithSingleRow(self): + def test_single_dict_with_single_row(self): q = self.c.query("select 1 as one, 2 as two") r = q.singledict() self.assertIsInstance(r, dict) @@ -1475,31 +1605,31 @@ def testSingleDictWithSingleRow(self): self.assertIsInstance(r, dict) self.assertEqual(r, dict(one=1, two=2)) - def testSingleDictWithTwoRows(self): + def test_single_dict_with_two_rows(self): q = self.c.query("select 1, 2 union select 3, 4") try: q.singledict() except pg.InvalidResultError as e: - r = e + r: Any = e else: r = None self.assertIsInstance(r, pg.MultipleResultsError) self.assertEqual(str(r), 'Multiple results found') - def testSingleNamedWithEmptyQuery(self): + def test_single_named_with_empty_query(self): q = self.c.query("select 0 where false") try: q.singlenamed() except pg.InvalidResultError as e: - r = e + r: Any = e else: r = None self.assertIsInstance(r, pg.NoResultError) self.assertEqual(str(r), 'No result found') - def testSingleNamedWithSingleRow(self): + def test_single_named_with_single_row(self): q = self.c.query("select 1 as one, 2 as two") - r = q.singlenamed() + r: Any = q.singlenamed() self.assertEqual(r._fields, ('one', 'two')) self.assertEqual(r.one, 1) self.assertEqual(r.two, 2) @@ -1510,29 +1640,29 @@ def testSingleNamedWithSingleRow(self): self.assertEqual(r.two, 2) self.assertEqual(r, (1, 2)) - def testSingleNamedWithTwoRows(self): + def test_single_named_with_two_rows(self): q = self.c.query("select 1, 2 union select 3, 4") try: q.singlenamed() except pg.InvalidResultError as e: - r = e + r: Any = e else: r = None self.assertIsInstance(r, pg.MultipleResultsError) self.assertEqual(str(r), 'Multiple results found') - def testSingleScalarWithEmptyQuery(self): + def test_single_scalar_with_empty_query(self): q = self.c.query("select 0 where false") try: q.singlescalar() except pg.InvalidResultError as e: - r = e + r: Any = e else: r = None self.assertIsInstance(r, pg.NoResultError) self.assertEqual(str(r), 'No result found') - def testSingleScalarWithSingleRow(self): + def test_single_scalar_with_single_row(self): q = self.c.query("select 1, 2") r = q.singlescalar() self.assertIsInstance(r, int) @@ -1541,24 +1671,24 @@ def testSingleScalarWithSingleRow(self): self.assertIsInstance(r, int) self.assertEqual(r, 1) - def testSingleWithTwoRows(self): + def test_single_scalar_with_two_rows(self): q = self.c.query("select 1, 2 union select 3, 4") try: q.singlescalar() except pg.InvalidResultError as e: - r = e + r: Any = e else: r = None self.assertIsInstance(r, pg.MultipleResultsError) self.assertEqual(str(r), 'Multiple results found') - def testScalarResult(self): + def test_scalar_result(self): q = self.c.query("select 1, 2 union select 3, 4") r = q.scalarresult() self.assertIsInstance(r, list) self.assertEqual(r, [1, 3]) - def testScalarIter(self): + def test_scalar_iter(self): q = self.c.query("select 1, 2 union select 3, 4") r = q.scalariter() 
self.assertNotIsInstance(r, (list, tuple)) @@ -1571,15 +1701,17 @@ class TestInserttable(unittest.TestCase): """Test inserttable method.""" cls_set_up = False + has_encoding = False @classmethod def setUpClass(cls): c = connect() c.query("drop table if exists test cascade") c.query("create table test (" - "i2 smallint, i4 integer, i8 bigint, b boolean, dt date, ti time," - "d numeric, f4 real, f8 double precision, m money," - "c char(1), v4 varchar(4), c4 char(4), t text)") + "i2 smallint, i4 integer, i8 bigint," + "b boolean, dt date, ti time," + "d numeric, f4 real, f8 double precision, m money," + "c char(1), v4 varchar(4), c4 char(4), t text)") # Check whether the test database uses SQL_ASCII - this means # that it does not consider encoding when calculating lengths. c.query("set client_encoding=utf8") @@ -1610,22 +1742,23 @@ def tearDown(self): self.c.query("truncate table test") self.c.close() - data = [ - (-1, -1, long(-1), True, '1492-10-12', '08:30:00', - -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'), - (0, 0, long(0), False, '1607-04-14', '09:00:00', - 0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'), - (1, 1, long(1), True, '1801-03-04', '03:45:00', - 1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'), - (2, 2, long(2), False, '1903-12-17', '11:22:00', - 2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')] + data: Sequence[tuple] = [ + (-1, -1, -1, True, '1492-10-12', '08:30:00', + -1.2345, -1.75, -1.875, '-1.25', '-', 'r?', '!u', 'xyz'), + (0, 0, 0, False, '1607-04-14', '09:00:00', + 0.0, 0.0, 0.0, '0.0', ' ', '0123', '4567', '890'), + (1, 1, 1, True, '1801-03-04', '03:45:00', + 1.23456, 1.75, 1.875, '1.25', 'x', 'bc', 'cdef', 'g'), + (2, 2, 2, False, '1903-12-17', '11:22:00', + 2.345678, 2.25, 2.125, '2.75', 'y', 'q', 'ijk', 'mnop\nstux!')] @classmethod def db_len(cls, s, encoding): + # noinspection PyUnresolvedReferences if cls.has_encoding: - s = s if isinstance(s, unicode) else s.decode(encoding) + s = s if isinstance(s, str) else s.decode(encoding) else: - s = s.encode(encoding) if isinstance(s, unicode) else s + s = s.encode(encoding) if isinstance(s, str) else s return len(s) def get_back(self, encoding='utf-8'): @@ -1639,7 +1772,7 @@ def get_back(self, encoding='utf-8'): if row[1] is not None: # integer self.assertIsInstance(row[1], int) if row[2] is not None: # bigint - self.assertIsInstance(row[2], long) + self.assertIsInstance(row[2], int) if row[3] is not None: # boolean self.assertIsInstance(row[3], bool) if row[4] is not None: # date @@ -1675,49 +1808,76 @@ def get_back(self, encoding='utf-8'): data.append(row) return data - def testInserttable1Row(self): + def test_inserttable_1_row(self): data = self.data[2:3] self.c.inserttable('test', data) self.assertEqual(self.get_back(), data) - def testInserttable4Rows(self): + def test_inserttable_4_rows(self): data = self.data self.c.inserttable('test', data) self.assertEqual(self.get_back(), data) - def testInserttableFromTupleOfLists(self): + def test_inserttable_from_tuple_of_lists(self): data = tuple(list(row) for row in self.data) self.c.inserttable('test', data) self.assertEqual(self.get_back(), self.data) - def testInserttableFromSetofTuples(self): - data = set(row for row in self.data) + def test_inserttable_with_different_row_sizes(self): + data = [*self.data[:-1], (self.data[-1][:-1],)] try: self.c.inserttable('test', data) except TypeError as e: - r = str(e) + self.assertIn( + 'second arg must contain sequences of the same size', str(e)) else: - r = 'this is fine' - self.assertIn('list 
or a tuple as second argument', r) + self.assertFalse('expected an error') + + def test_inserttable_from_set_of_tuples(self): + data = {row for row in self.data} + self.c.inserttable('test', data) + self.assertEqual(self.get_back(), self.data) - def testInserttableFromListOfSets(self): + def test_inserttable_from_dict_as_iterable(self): + data = {row: None for row in self.data} + self.c.inserttable('test', data) + self.assertEqual(self.get_back(), self.data) + + def test_inserttable_from_dict_keys(self): + data = {row: None for row in self.data} + keys = data.keys() + self.c.inserttable('test', keys) + self.assertEqual(self.get_back(), self.data) + + def test_inserttable_from_dict_values(self): + data = {i: row for i, row in enumerate(self.data)} + values = data.values() + self.c.inserttable('test', values) + self.assertEqual(self.get_back(), self.data) + + def test_inserttable_from_generator_of_tuples(self): + data = (row for row in self.data) + self.c.inserttable('test', data) + self.assertEqual(self.get_back(), self.data) + + def test_inserttable_from_list_of_sets(self): data = [set(row) for row in self.data] try: self.c.inserttable('test', data) except TypeError as e: - r = str(e) + self.assertIn( + 'second argument must contain tuples or lists', str(e)) else: - r = 'this is fine' - self.assertIn('second argument must contain a tuple or a list', r) + self.assertFalse('expected an error') - def testInserttableMultipleRows(self): + def test_inserttable_multiple_rows(self): num_rows = 100 - data = self.data[2:3] * num_rows + data = list(self.data[2:3]) * num_rows self.c.inserttable('test', data) r = self.c.query("select count(*) from test").getresult()[0][0] self.assertEqual(r, num_rows) - def testInserttableMultipleCalls(self): + def test_inserttable_multiple_calls(self): num_rows = 10 data = self.data[2:3] for _i in range(num_rows): @@ -1725,80 +1885,161 @@ def testInserttableMultipleCalls(self): self.c.inserttable('test', data) r = self.c.query("select count(*) from test").getresult()[0][0] self.assertEqual(r, num_rows) - def testInserttableNullValues(self): + def test_inserttable_null_values(self): data = [(None,) * 14] * 100 self.c.inserttable('test', data) self.assertEqual(self.get_back(), data) - def testInserttableMaxValues(self): - data = [(2 ** 15 - 1, int(2 ** 31 - 1), long(2 ** 31 - 1), - True, '2999-12-31', '11:59:59', 1e99, - 1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None, - "1", "1234", "1234", "1234" * 100)] + def test_inserttable_no_column(self): + data = [()] * 10 + self.c.inserttable('test', data, []) + self.assertEqual(self.get_back(), []) + + def test_inserttable_only_one_column(self): + data: list[tuple] = [(42,)] * 50 + self.c.inserttable('test', data, ['i4']) + data = [tuple([42 if i == 1 else None for i in range(14)])] * 50 + self.assertEqual(self.get_back(), data) + + def test_inserttable_only_two_columns(self): + data: list[tuple] = [(bool(i % 2), i * .5) for i in range(20)] + self.c.inserttable('test', data, ('b', 'f4')) + # noinspection PyTypeChecker + data = [(None,) * 3 + (bool(i % 2),) + (None,) * 3 + (i * .5,) + + (None,) * 6 for i in range(20)] + self.assertEqual(self.get_back(), data) + + def test_inserttable_with_dotted_table_name(self): + data = self.data + self.c.inserttable('public.test', data) + self.assertEqual(self.get_back(), data) + + def test_inserttable_with_invalid_table_name(self): + data = [(42,)] + # check that the table name is not inserted unescaped + # (this would pass otherwise since there is a column named i4) + try: + self.c.inserttable('test (i4)', data) + except ValueError 
as e: + self.assertIn('relation "test (i4)" does not exist', str(e)) + else: + self.assertFalse('expected an error') + # make sure that it works if parameters are passed properly + self.c.inserttable('test', data, ['i4']) + + def test_inserttable_with_invalid_data_type(self): + try: + self.c.inserttable('test', 42) + except TypeError as e: + self.assertIn('expects an iterable as second argument', str(e)) + else: + self.assertFalse('expected an error') + + def test_inserttable_with_invalid_column_name(self): + data = [(2, 4)] + # check that the column names are not inserted unescaped + # (this would pass otherwise since there are columns i2 and i4) + try: + self.c.inserttable('test', data, ['i2,i4']) + except ValueError as e: + self.assertIn( + 'column "i2,i4" of relation "test" does not exist', str(e)) + else: + self.assertFalse('expected an error') + # make sure that it works if parameters are passed properly + self.c.inserttable('test', data, ['i2', 'i4']) + + def test_inserttable_with_invalid_column_list(self): + data = self.data + try: + self.c.inserttable('test', data, 'invalid') + except TypeError as e: + self.assertIn( + 'expects a tuple or a list as third argument', str(e)) + else: + self.assertFalse('expected an error') + + def test_inserttable_with_huge_list_of_column_names(self): + data = self.data + # try inserting data with a huge list of column names + cols = ['very_long_column_name'] * 2000 + # Should raise a value error because the column does not exist + self.assertRaises(ValueError, self.c.inserttable, 'test', data, cols) + # double the size, should catch buffer overflow and raise memory error + cols *= 2 + self.assertRaises(MemoryError, self.c.inserttable, 'test', data, cols) + + def test_inserttable_with_out_of_range_data(self): + # try inserting data out of range for the column type + # Should raise a value error because of smallint out of range + self.assertRaises( + ValueError, self.c.inserttable, 'test', [[33000]], ['i2']) + + def test_inserttable_max_values(self): + data = [(2 ** 15 - 1, 2 ** 31 - 1, 2 ** 31 - 1, + True, '2999-12-31', '11:59:59', 1e99, + 1.0 + 1.0 / 32, 1.0 + 1.0 / 32, None, + "1", "1234", "1234", "1234" * 100)] self.c.inserttable('test', data) self.assertEqual(self.get_back(), data) - def testInserttableByteValues(self): + def test_inserttable_byte_values(self): try: self.c.query("select '€', 'käse', 'сыр', 'pont-l''évêque'") except pg.DataError: self.skipTest("database does not support utf8") # non-ascii chars do not fit in char(1) when there is no encoding - c = u'€' if self.has_encoding else u'$' - row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00', - 0.0, 0.0, 0.0, u'0.0', - c, u'bäd', u'bäd', u"käse сыр pont-l'évêque") - row_bytes = tuple(s.encode('utf-8') - if isinstance(s, unicode) else s for s in row_unicode) - data = [row_bytes] * 2 - self.c.inserttable('test', data) - if unicode_strings: - data = [row_unicode] * 2 - self.assertEqual(self.get_back(), data) - - def testInserttableUnicodeUtf8(self): + c = '€' if self.has_encoding else '$' + row_unicode = ( + 0, 0, 0, False, '1970-01-01', '00:00:00', + 0.0, 0.0, 0.0, '0.0', + c, 'bäd', 'bäd', "käse сыр pont-l'évêque") + row_bytes = tuple( + s.encode() if isinstance(s, str) else s + for s in row_unicode) + data_bytes = [row_bytes] * 2 + self.c.inserttable('test', data_bytes) + data_unicode = [row_unicode] * 2 + self.assertEqual(self.get_back(), data_unicode) + + def test_inserttable_unicode_utf8(self): try: self.c.query("select '€', 'käse', 'сыр', 'pont-l''évêque'") except 
pg.DataError: self.skipTest("database does not support utf8") # non-ascii chars do not fit in char(1) when there is no encoding - c = u'€' if self.has_encoding else u'$' - row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00', - 0.0, 0.0, 0.0, u'0.0', - c, u'bäd', u'bäd', u"käse сыр pont-l'évêque") + c = '€' if self.has_encoding else '$' + row_unicode = ( + 0, 0, 0, False, '1970-01-01', '00:00:00', + 0.0, 0.0, 0.0, '0.0', + c, 'bäd', 'bäd', "käse сыр pont-l'évêque") data = [row_unicode] * 2 self.c.inserttable('test', data) - if not unicode_strings: - row_bytes = tuple(s.encode('utf-8') - if isinstance(s, unicode) else s for s in row_unicode) - data = [row_bytes] * 2 self.assertEqual(self.get_back(), data) - def testInserttableUnicodeLatin1(self): + def test_inserttable_unicode_latin1(self): try: self.c.query("set client_encoding=latin1") self.c.query("select '¥'") except (pg.DataError, pg.NotSupportedError): self.skipTest("database does not support latin1") # non-ascii chars do not fit in char(1) when there is no encoding - c = u'€' if self.has_encoding else u'$' - row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00', - 0.0, 0.0, 0.0, u'0.0', - c, u'bäd', u'bäd', u"for käse and pont-l'évêque pay in €") + c = '€' if self.has_encoding else '$' + row_unicode: tuple = ( + 0, 0, 0, False, '1970-01-01', '00:00:00', + 0.0, 0.0, 0.0, '0.0', + c, 'bäd', 'bäd', "for käse and pont-l'évêque pay in €") data = [row_unicode] # cannot encode € sign with latin1 encoding self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data) - row_unicode = tuple(s.replace(u'€', u'¥') - if isinstance(s, unicode) else s for s in row_unicode) + row_unicode = tuple( + s.replace('€', '¥') if isinstance(s, str) else s + for s in row_unicode) data = [row_unicode] * 2 self.c.inserttable('test', data) - if not unicode_strings: - row_bytes = tuple(s.encode('latin1') - if isinstance(s, unicode) else s for s in row_unicode) - data = [row_bytes] * 2 self.assertEqual(self.get_back('latin1'), data) - def testInserttableUnicodeLatin9(self): + def test_inserttable_unicode_latin9(self): try: self.c.query("set client_encoding=latin9") self.c.query("select '€'") @@ -1806,29 +2047,77 @@ def testInserttableUnicodeLatin9(self): self.skipTest("database does not support latin9") return # non-ascii chars do not fit in char(1) when there is no encoding - c = u'€' if self.has_encoding else u'$' - row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00', - 0.0, 0.0, 0.0, u'0.0', - c, u'bäd', u'bäd', u"for käse and pont-l'évêque pay in €") + c = '€' if self.has_encoding else '$' + row_unicode = ( + 0, 0, 0, False, '1970-01-01', '00:00:00', + 0.0, 0.0, 0.0, '0.0', + c, 'bäd', 'bäd', "for käse and pont-l'évêque pay in €") data = [row_unicode] * 2 self.c.inserttable('test', data) - if not unicode_strings: - row_bytes = tuple(s.encode('latin9') - if isinstance(s, unicode) else s for s in row_unicode) - data = [row_bytes] * 2 self.assertEqual(self.get_back('latin9'), data) - def testInserttableNoEncoding(self): + def test_inserttable_no_encoding(self): self.c.query("set client_encoding=sql_ascii") # non-ascii chars do not fit in char(1) when there is no encoding - c = u'€' if self.has_encoding else u'$' - row_unicode = (0, 0, long(0), False, u'1970-01-01', u'00:00:00', - 0.0, 0.0, 0.0, u'0.0', - c, u'bäd', u'bäd', u"for käse and pont-l'évêque pay in €") + c = '€' if self.has_encoding else '$' + row_unicode = ( + 0, 0, 0, False, '1970-01-01', '00:00:00', + 0.0, 0.0, 0.0, '0.0', + c, 'bäd', 'bäd', "for käse and 
pont-l'évêque pay in €") data = [row_unicode] # cannot encode non-ascii unicode without a specific encoding self.assertRaises(UnicodeEncodeError, self.c.inserttable, 'test', data) + + def test_inserttable_from_query(self): + data = self.c.query( + "select 2::int2 as i2, 4::int4 as i4, 8::int8 as i8, true as b," + "null as dt, null as ti, null as d," + "4.5::float as float4, 8.5::float8 as f8," + "null as m, 'c' as c, 'v4' as v4, null as c4, 'text' as text") + self.c.inserttable('test', data) + self.assertEqual(self.get_back(), [ + (2, 4, 8, True, None, None, None, 4.5, 8.5, + None, 'c', 'v4', None, 'text')]) + + def test_inserttable_special_chars(self): + class S: + def __repr__(self): + return s + + s = '1\'2"3\b4\f5\n6\r7\t8\b9\\0' + s1 = s.encode('ascii') + s2 = S() + data = [(t,) for t in (s, s1, s2)] + self.c.inserttable('test', data, ['t']) + self.assertEqual( + self.c.query('select t from test').getresult(), [(s,)] * 3) + + def test_inserttable_big_row_size(self): + # inserting rows with a size of up to 64k bytes should work + t = '*' * 50000 + data = [(t,)] + self.c.inserttable('test', data, ['t']) + self.assertEqual( + self.c.query('select t from test').getresult(), data) + # double the size, should catch buffer overflow and raise memory error + t *= 2 + data = [(t,)] + self.assertRaises(MemoryError, self.c.inserttable, 'test', data, ['t']) + + def test_inserttable_small_int_overflow(self): + rest_row = self.data[2][1:] + data = [(32000, *rest_row)] + self.c.inserttable('test', data) + self.assertEqual(self.get_back(), data) + data = [(33000, *rest_row)] + try: + self.c.inserttable('test', data) + except ValueError as e: + self.assertIn( + 'value "33000" is out of range for type smallint', str(e)) + else: + self.assertFalse('expected an error') + class TestDirectSocketAccess(unittest.TestCase): """Test copy command with direct socket access.""" @@ -1858,21 +2147,20 @@ def tearDown(self): self.c.query("truncate table test") self.c.close() - def testPutline(self): + def test_putline(self): putline = self.c.putline query = self.c.query data = list(enumerate("apple pear plum cherry banana".split())) query("copy test from stdin") try: for i, v in data: - putline("%d\t%s\n" % (i, v)) - putline("\\.\n") + putline(f"{i}\t{v}\n") finally: self.c.endcopy() r = query("select * from test").getresult() self.assertEqual(r, data) - def testPutlineBytesAndUnicode(self): + def test_putline_bytes_and_unicode(self): putline = self.c.putline query = self.c.query try: @@ -1881,15 +2169,14 @@ def testPutlineBytesAndUnicode(self): self.skipTest('database does not support utf8') query("copy test from stdin") try: - putline(u"47\tkäse\n".encode('utf8')) + putline("47\tkäse\n".encode()) putline("35\twürstel\n") - putline(b"\\.\n") finally: self.c.endcopy() r = query("select * from test").getresult() self.assertEqual(r, [(47, 'käse'), (35, 'würstel')]) - def testGetline(self): + def test_getline(self): getline = self.c.getline query = self.c.query data = list(enumerate("apple banana pear plum strawberry".split())) @@ -1897,28 +2184,25 @@ def testGetline(self): self.c.inserttable('test', data) query("copy test to stdout") try: - for i in range(n + 2): + for i in range(n + 1): v = getline() if i < n: - self.assertEqual(v, '%d\t%s' % data[i]) + # noinspection PyStringFormat + self.assertEqual(v, '{}\t{}'.format(*data[i])) elif i == n: - self.assertEqual(v, '\\.') - else: self.assertIsNone(v) finally: - try: + with suppress(OSError): self.c.endcopy() - except IOError: - pass - def 
testGetlineBytesAndUnicode(self): + def test_getline_bytes_and_unicode(self): getline = self.c.getline query = self.c.query try: query("select 'käse+würstel'") except (pg.DataError, pg.NotSupportedError): self.skipTest('database does not support utf8') - data = [(54, u'käse'.encode('utf8')), (73, u'würstel')] + data = [(54, 'käse'.encode()), (73, 'würstel')] self.c.inserttable('test', data) query("copy test to stdout") try: @@ -1928,15 +2212,12 @@ def testGetlineBytesAndUnicode(self): v = getline() self.assertIsInstance(v, str) self.assertEqual(v, '73\twürstel') - self.assertEqual(getline(), '\\.') self.assertIsNone(getline()) finally: - try: + with suppress(OSError): self.c.endcopy() - except IOError: - pass - def testParameterChecks(self): + def test_parameter_checks(self): self.assertRaises(TypeError, self.c.putline) self.assertRaises(TypeError, self.c.getline, 'invalid') self.assertRaises(TypeError, self.c.endcopy, 'invalid') @@ -1952,7 +2233,7 @@ def tearDown(self): self.doCleanups() self.c.close() - def testGetNotify(self): + def test_get_notify(self): getnotify = self.c.getnotify query = self.c.query self.assertIsNone(getnotify()) @@ -1971,7 +2252,7 @@ def testGetNotify(self): self.assertIsNone(self.c.getnotify()) query("notify test_notify, 'test_payload'") r = getnotify() - self.assertTrue(isinstance(r, tuple)) + self.assertIsInstance(r, tuple) self.assertEqual(len(r), 3) self.assertIsInstance(r[0], str) self.assertIsInstance(r[1], int) @@ -1982,23 +2263,23 @@ def testGetNotify(self): finally: query('unlisten test_notify') - def testGetNoticeReceiver(self): + def test_get_notice_receiver(self): self.assertIsNone(self.c.get_notice_receiver()) - def testSetNoticeReceiver(self): + def test_set_notice_receiver(self): self.assertRaises(TypeError, self.c.set_notice_receiver, 42) self.assertRaises(TypeError, self.c.set_notice_receiver, 'invalid') self.assertIsNone(self.c.set_notice_receiver(lambda notice: None)) self.assertIsNone(self.c.set_notice_receiver(None)) - def testSetAndGetNoticeReceiver(self): - r = lambda notice: None + def test_set_and_get_notice_receiver(self): + r = lambda notice: None # noqa: E731 self.assertIsNone(self.c.set_notice_receiver(r)) self.assertIs(self.c.get_notice_receiver(), r) self.assertIsNone(self.c.set_notice_receiver(None)) self.assertIsNone(self.c.get_notice_receiver()) - def testNoticeReceiver(self): + def test_notice_receiver(self): self.addCleanup(self.c.query, 'drop function bilbo_notice();') self.c.query('''create function bilbo_notice() returns void AS $$ begin @@ -2040,7 +2321,7 @@ def setUp(self): def tearDown(self): self.c.close() - def testGetDecimalPoint(self): + def test_get_decimal_point(self): point = pg.get_decimal_point() # error if a parameter is passed self.assertRaises(TypeError, pg.get_decimal_point, point) @@ -2073,8 +2354,8 @@ def testGetDecimalPoint(self): pg.set_decimal_point(point) self.assertIsNone(r) - def testSetDecimalPoint(self): - d = pg.Decimal + def test_set_decimal_point(self): + d = Decimal point = pg.get_decimal_point() self.assertRaises(TypeError, pg.set_decimal_point) # error if decimal point is not a string @@ -2094,12 +2375,13 @@ def testSetDecimalPoint(self): en_locales = 'en', 'en_US', 'en_US.utf8', 'en_US.UTF-8' en_money = '$34.25', '$ 34.25', '34.25$', '34.25 $', '34.25 Dollar' de_locales = 'de', 'de_DE', 'de_DE.utf8', 'de_DE.UTF-8' - de_money = ('34,25€', '34,25 €', '€34,25', '€ 34,25', + de_money = ( + '34,25€', '34,25 €', '€34,25', '€ 34,25', 'EUR34,25', 'EUR 34,25', '34,25 EUR', '34,25 Euro', '34,25 DM') # 
first try with English localization (using the point) for lc in en_locales: try: - query("set lc_monetary='%s'" % lc) + query(f"set lc_monetary='{lc}'") except pg.DataError: pass else: @@ -2150,7 +2432,7 @@ def testSetDecimalPoint(self): # then try with German localization (using the comma) for lc in de_locales: try: - query("set lc_monetary='%s'" % lc) + query(f"set lc_monetary='{lc}'") except pg.DataError: pass else: @@ -2196,11 +2478,11 @@ def testSetDecimalPoint(self): pg.set_decimal_point(point) self.assertEqual(r, bad_money) - def testGetDecimal(self): + def test_get_decimal(self): decimal_class = pg.get_decimal() # error if a parameter is passed self.assertRaises(TypeError, pg.get_decimal, decimal_class) - self.assertIs(decimal_class, pg.Decimal) # the default setting + self.assertIs(decimal_class, Decimal) # the default setting pg.set_decimal(int) try: r = pg.get_decimal() @@ -2210,7 +2492,7 @@ def testGetDecimal(self): r = pg.get_decimal() self.assertIs(r, decimal_class) - def testSetDecimal(self): + def test_set_decimal(self): decimal_class = pg.get_decimal() # error if no parameter is passed self.assertRaises(TypeError, pg.set_decimal) @@ -2230,9 +2512,9 @@ def testSetDecimal(self): pg.set_decimal(decimal_class) self.assertNotIsInstance(r, decimal_class) self.assertIsInstance(r, int) - self.assertEqual(r, int(3425)) + self.assertEqual(r, 3425) - def testGetBool(self): + def test_get_bool(self): use_bool = pg.get_bool() # error if a parameter is passed self.assertRaises(TypeError, pg.get_bool, use_bool) @@ -2267,7 +2549,7 @@ def testGetBool(self): self.assertIsInstance(r, bool) self.assertIs(r, True) - def testSetBool(self): + def test_set_bool(self): use_bool = pg.get_bool() # error if no parameter is passed self.assertRaises(TypeError, pg.set_bool) @@ -2285,7 +2567,7 @@ def testSetBool(self): finally: pg.set_bool(use_bool) self.assertIsInstance(r, str) - self.assertIs(r, 't') + self.assertEqual(r, 't') pg.set_bool(True) try: r = query("select true::bool").getresult()[0][0] @@ -2294,7 +2576,7 @@ def testSetBool(self): self.assertIsInstance(r, bool) self.assertIs(r, True) - def testGetByteEscaped(self): + def test_get_byte_escaped(self): bytea_escaped = pg.get_bytea_escaped() # error if a parameter is passed self.assertRaises(TypeError, pg.get_bytea_escaped, bytea_escaped) @@ -2329,7 +2611,7 @@ def testGetByteEscaped(self): self.assertIsInstance(r, bool) self.assertIs(r, False) - def testSetByteaEscaped(self): + def test_set_bytea_escaped(self): bytea_escaped = pg.get_bytea_escaped() # error if no parameter is passed self.assertRaises(TypeError, pg.set_bytea_escaped) @@ -2356,16 +2638,13 @@ def testSetByteaEscaped(self): self.assertIsInstance(r, bytes) self.assertEqual(r, b'data') - def testSetRowFactorySize(self): - try: - from functools import lru_cache - except ImportError: # Python < 3.2 - lru_cache = None + def test_change_row_factory_cache_size(self): + cache = pg.RowCache queries = ['select 1 as a, 2 as b, 3 as c', 'select 123 as abc'] query = self.c.query for maxsize in (None, 0, 1, 2, 3, 10, 1024): - pg.set_row_factory_size(maxsize) - for i in range(3): + cache.change_size(maxsize) + for _i in range(3): for q in queries: r = query(q).namedresult()[0] if q.endswith('abc'): @@ -2374,12 +2653,11 @@ def testSetRowFactorySize(self): else: self.assertEqual(r, (1, 2, 3)) self.assertEqual(r._fields, ('a', 'b', 'c')) - if lru_cache: - info = pg._row_factory.cache_info() - self.assertEqual(info.maxsize, maxsize) - self.assertEqual(info.hits + info.misses, 6) - 
self.assertEqual(info.hits, - 0 if maxsize is not None and maxsize < 2 else 4) + info = cache.row_factory.cache_info() + self.assertEqual(info.maxsize, maxsize) + self.assertEqual(info.hits + info.misses, 6) + self.assertEqual(info.hits, + 0 if maxsize is not None and maxsize < 2 else 4) class TestStandaloneEscapeFunctions(unittest.TestCase): @@ -2400,49 +2678,45 @@ def setUpClass(cls): query = db.query query('set client_encoding=sql_ascii') query('set standard_conforming_strings=off') - try: - query('set bytea_output=escape') - except pg.ProgrammingError: - if db.server_version >= 90000: - raise # ignore for older server versions + query('set bytea_output=escape') db.close() cls.cls_set_up = True - def testEscapeString(self): + def test_escape_string(self): self.assertTrue(self.cls_set_up) f = pg.escape_string - r = f(b'plain') - self.assertIsInstance(r, bytes) - self.assertEqual(r, b'plain') - r = f(u'plain') - self.assertIsInstance(r, unicode) - self.assertEqual(r, u'plain') - r = f(u"das is' käse".encode('utf-8')) - self.assertIsInstance(r, bytes) - self.assertEqual(r, u"das is'' käse".encode('utf-8')) - r = f(u"that's cheesy") - self.assertIsInstance(r, unicode) - self.assertEqual(r, u"that''s cheesy") - r = f(r"It's bad to have a \ inside.") - self.assertEqual(r, r"It''s bad to have a \\ inside.") - - def testEscapeBytea(self): + b = f(b'plain') + self.assertIsInstance(b, bytes) + self.assertEqual(b, b'plain') + s = f('plain') + self.assertIsInstance(s, str) + self.assertEqual(s, 'plain') + b = f("das is' käse".encode()) + self.assertIsInstance(b, bytes) + self.assertEqual(b, "das is'' käse".encode()) + s = f("that's cheesy") + self.assertIsInstance(s, str) + self.assertEqual(s, "that''s cheesy") + s = f(r"It's bad to have a \ inside.") + self.assertEqual(s, r"It''s bad to have a \\ inside.") + + def test_escape_bytea(self): self.assertTrue(self.cls_set_up) f = pg.escape_bytea - r = f(b'plain') - self.assertIsInstance(r, bytes) - self.assertEqual(r, b'plain') - r = f(u'plain') - self.assertIsInstance(r, unicode) - self.assertEqual(r, u'plain') - r = f(u"das is' käse".encode('utf-8')) - self.assertIsInstance(r, bytes) - self.assertEqual(r, b"das is'' k\\\\303\\\\244se") - r = f(u"that's cheesy") - self.assertIsInstance(r, unicode) - self.assertEqual(r, u"that''s cheesy") - r = f(b'O\x00ps\xff!') - self.assertEqual(r, b'O\\\\000ps\\\\377!') + b = f(b'plain') + self.assertIsInstance(b, bytes) + self.assertEqual(b, b'plain') + s = f('plain') + self.assertIsInstance(s, str) + self.assertEqual(s, 'plain') + b = f("das is' käse".encode()) + self.assertIsInstance(b, bytes) + self.assertEqual(b, b"das is'' k\\\\303\\\\244se") + s = f("that's cheesy") + self.assertIsInstance(s, str) + self.assertEqual(s, "that''s cheesy") + b = f(b'O\x00ps\xff!') + self.assertEqual(b, b'O\\\\000ps\\\\377!') if __name__ == '__main__': diff --git a/tests/test_classic_dbwrapper.py b/tests/test_classic_dbwrapper.py index d9222560..1d64c754 100755 --- a/tests/test_classic_dbwrapper.py +++ b/tests/test_classic_dbwrapper.py @@ -1,5 +1,4 @@ #!/usr/bin/python -# -*- coding: utf-8 -*- """Test the classic PyGreSQL interface. @@ -10,61 +9,28 @@ These tests need a database to test against. 
""" -try: - import unittest2 as unittest # for Python < 2.7 -except ImportError: - import unittest +from __future__ import annotations -import os -import sys import gc import json +import os +import sys import tempfile - -import pg # the module under test - +import unittest +from contextlib import suppress +from datetime import date, datetime, time, timedelta from decimal import Decimal -from datetime import date, time, datetime, timedelta -from uuid import UUID -from time import strftime +from io import StringIO from operator import itemgetter +from time import strftime +from typing import Any, Callable, ClassVar +from uuid import UUID -# We need a database to test against. If LOCAL_PyGreSQL.py exists we will -# get our information from that. Otherwise we use the defaults. -# The current user must have create schema privilege on the database. -dbname = 'unittest' -dbhost = None -dbport = 5432 +import pg # the module under test -debug = False # let DB wrapper print debugging output +from .config import dbhost, dbname, dbpasswd, dbport, dbuser -try: - from .LOCAL_PyGreSQL import * -except (ImportError, ValueError): - try: - from LOCAL_PyGreSQL import * - except ImportError: - pass - -try: # noinspection PyUnresolvedReferences - long -except NameError: # Python >= 3.0 - long = int - -try: # noinspection PyUnresolvedReferences - unicode -except NameError: # Python >= 3.0 - unicode = str - -try: - from collections import OrderedDict -except ImportError: # Python 2.6 or 3.0 - OrderedDict = dict - -if str is bytes: # noinspection PyUnresolvedReferences - from StringIO import StringIO -else: - from io import StringIO +debug = False # let DB wrapper print debugging output windows = os.name == 'nt' @@ -74,110 +40,45 @@ do_not_ask_for_host_reason = 'libpq issue on Windows' -def DB(): +def DB(): # noqa: N802 """Create a DB wrapper object connecting to the test database.""" - db = pg.DB(dbname, dbhost, dbport) + db = pg.DB(dbname, dbhost, dbport, user=dbuser, passwd=dbpasswd) if debug: db.debug = debug db.query("set client_min_messages=warning") return db -class TestAttrDict(unittest.TestCase): - """Test the simple ordered dictionary for attribute names.""" - - cls = pg.AttrDict - base = OrderedDict - - def testInit(self): - a = self.cls() - self.assertIsInstance(a, self.base) - self.assertEqual(a, self.base()) - items = [('id', 'int'), ('name', 'text')] - a = self.cls(items) - self.assertIsInstance(a, self.base) - self.assertEqual(a, self.base(items)) - iteritems = iter(items) - a = self.cls(iteritems) - self.assertIsInstance(a, self.base) - self.assertEqual(a, self.base(items)) - - def testIter(self): - a = self.cls() - self.assertEqual(list(a), []) - keys = ['id', 'name', 'age'] - items = [(key, None) for key in keys] - a = self.cls(items) - self.assertEqual(list(a), keys) - - def testKeys(self): - a = self.cls() - self.assertEqual(list(a.keys()), []) - keys = ['id', 'name', 'age'] - items = [(key, None) for key in keys] - a = self.cls(items) - self.assertEqual(list(a.keys()), keys) - - def testValues(self): - a = self.cls() - self.assertEqual(list(a.values()), []) - items = [('id', 'int'), ('name', 'text')] - values = [item[1] for item in items] - a = self.cls(items) - self.assertEqual(list(a.values()), values) - - def testItems(self): - a = self.cls() - self.assertEqual(list(a.items()), []) - items = [('id', 'int'), ('name', 'text')] - a = self.cls(items) - self.assertEqual(list(a.items()), items) - - def testGet(self): - a = self.cls([('id', 1)]) - try: - self.assertEqual(a['id'], 1) - except 
KeyError: - self.fail('AttrDict should be readable') - - def testSet(self): - a = self.cls() - try: - a['id'] = 1 - except TypeError: - pass - else: - self.fail('AttrDict should be read-only') - - def testDel(self): - a = self.cls([('id', 1)]) - try: - del a['id'] - except TypeError: - pass - else: - self.fail('AttrDict should be read-only') - - def testWriteMethods(self): - a = self.cls([('id', 1)]) - self.assertEqual(a['id'], 1) - for method in 'clear', 'update', 'pop', 'setdefault', 'popitem': - method = getattr(a, method) - self.assertRaises(TypeError, method, a) - - class TestDBClassInit(unittest.TestCase): """Test proper handling of errors when creating DB instances.""" - def testBadParams(self): + def test_bad_params(self): self.assertRaises(TypeError, pg.DB, invalid=True) - def testDeleteDb(self): + # noinspection PyUnboundLocalVariable + def test_delete_db(self): db = DB() del db.db self.assertRaises(pg.InternalError, db.close) del db + def test_async_query_before_deletion(self): + db = DB() + query = db.send_query('select 1') + self.assertEqual(query.getresult(), [(1,)]) + self.assertIsNone(query.getresult()) + self.assertIsNone(query.getresult()) + del db + gc.collect() + + def test_async_query_after_deletion(self): + db = DB() + query = db.send_query('select 1') + del db + gc.collect() + self.assertIsNone(query.getresult()) + self.assertIsNone(query.getresult()) + class TestDBClassBasic(unittest.TestCase): """Test existence of the DB class wrapped pg connection methods.""" @@ -186,12 +87,10 @@ def setUp(self): self.db = DB() def tearDown(self): - try: + with suppress(pg.InternalError): self.db.close() - except pg.InternalError: - pass - def testAllDBAttributes(self): + def test_all_db_attributes(self): attributes = [ 'abort', 'adapter', 'backend_pid', 'begin', @@ -204,21 +103,21 @@ def testAllDBAttributes(self): 'escape_literal', 'escape_string', 'fileno', 'get', 'get_as_dict', 'get_as_list', - 'get_attnames', 'get_cast_hook', - 'get_databases', 'get_notice_receiver', + 'get_attnames', 'get_cast_hook', 'get_databases', + 'get_generated', 'get_notice_receiver', 'get_parameter', 'get_relations', 'get_tables', 'getline', 'getlo', 'getnotify', 'has_table_privilege', 'host', - 'insert', 'inserttable', + 'insert', 'inserttable', 'is_non_blocking', 'locreate', 'loimport', 'notification_handler', 'options', - 'parameter', 'pkey', 'port', + 'parameter', 'pkey', 'pkeys', 'poll', 'port', 'prepare', 'protocol_version', 'putline', 'query', 'query_formatted', 'query_prepared', 'release', 'reopen', 'reset', 'rollback', - 'savepoint', 'server_version', - 'set_cast_hook', 'set_notice_receiver', + 'savepoint', 'send_query', 'server_version', + 'set_cast_hook', 'set_non_blocking', 'set_notice_receiver', 'set_parameter', 'socket', 'source', 'ssl_attributes', 'ssl_in_use', 'start', 'status', @@ -226,89 +125,84 @@ def testAllDBAttributes(self): 'unescape_bytea', 'update', 'upsert', 'use_regtypes', 'user', ] - # __dir__ is not called in Python 2.6 for old-style classes - db_attributes = dir(self.db) if hasattr( - self.db.__class__, '__class__') else self.db.__dir__() - db_attributes = [a for a in db_attributes - if not a.startswith('_')] + db_attributes = [a for a in self.db.__dir__() if not a.startswith('_')] self.assertEqual(attributes, db_attributes) - def testAttributeDb(self): + def test_attribute_db(self): self.assertEqual(self.db.db.db, dbname) - def testAttributeDbname(self): + def test_attribute_dbname(self): self.assertEqual(self.db.dbname, dbname) - def testAttributeError(self): + def 
test_attribute_error(self): error = self.db.error self.assertTrue(not error or 'krb5_' in error) self.assertEqual(self.db.error, self.db.db.error) @unittest.skipIf(do_not_ask_for_host, do_not_ask_for_host_reason) - def testAttributeHost(self): - if dbhost and not dbhost.startswith('/'): - host = dbhost - else: - host = 'localhost' + def test_attribute_host(self): + host = dbhost if dbhost and not dbhost.startswith('/') else 'localhost' self.assertIsInstance(self.db.host, str) self.assertEqual(self.db.host, host) self.assertEqual(self.db.db.host, host) - def testAttributeOptions(self): + def test_attribute_options(self): no_options = '' options = self.db.options self.assertEqual(options, no_options) self.assertEqual(options, self.db.db.options) - def testAttributePort(self): + def test_attribute_port(self): def_port = 5432 port = self.db.port self.assertIsInstance(port, int) self.assertEqual(port, dbport or def_port) self.assertEqual(port, self.db.db.port) - def testAttributeProtocolVersion(self): + def test_attribute_protocol_version(self): protocol_version = self.db.protocol_version self.assertIsInstance(protocol_version, int) self.assertTrue(2 <= protocol_version < 4) self.assertEqual(protocol_version, self.db.db.protocol_version) - def testAttributeServerVersion(self): + def test_attribute_server_version(self): server_version = self.db.server_version self.assertIsInstance(server_version, int) - self.assertTrue(90000 <= server_version < 130000) + self.assertGreaterEqual(server_version, 100000) # >= 10.0 + self.assertLess(server_version, 200000) # < 20.0 self.assertEqual(server_version, self.db.db.server_version) - def testAttributeSocket(self): + def test_attribute_socket(self): socket = self.db.socket self.assertIsInstance(socket, int) self.assertGreaterEqual(socket, 0) - def testAttributeBackendPid(self): + def test_attribute_backend_pid(self): backend_pid = self.db.backend_pid self.assertIsInstance(backend_pid, int) self.assertGreaterEqual(backend_pid, 1) - def testAttributeSslInUse(self): + def test_attribute_ssl_in_use(self): ssl_in_use = self.db.ssl_in_use self.assertIsInstance(ssl_in_use, bool) self.assertFalse(ssl_in_use) - def testAttributeSslAttributes(self): + def test_attribute_ssl_attributes(self): ssl_attributes = self.db.ssl_attributes self.assertIsInstance(ssl_attributes, dict) - self.assertEqual(ssl_attributes, { - 'cipher': None, 'compression': None, 'key_bits': None, - 'library': None, 'protocol': None}) + if ssl_attributes: + self.assertEqual(ssl_attributes, { + 'cipher': None, 'compression': None, 'key_bits': None, + 'library': None, 'protocol': None}) - def testAttributeStatus(self): + def test_attribute_status(self): status_ok = 1 status = self.db.status self.assertIsInstance(status, int) self.assertEqual(status, status_ok) self.assertEqual(status, self.db.db.status) - def testAttributeUser(self): + def test_attribute_user(self): no_user = 'Deprecated facility' user = self.db.user self.assertTrue(user) @@ -316,29 +210,29 @@ def testAttributeUser(self): self.assertNotEqual(user, no_user) self.assertEqual(user, self.db.db.user) - def testMethodEscapeLiteral(self): + def test_method_escape_literal(self): self.assertEqual(self.db.escape_literal(''), "''") - def testMethodEscapeIdentifier(self): + def test_method_escape_identifier(self): self.assertEqual(self.db.escape_identifier(''), '""') - def testMethodEscapeString(self): + def test_method_escape_string(self): self.assertEqual(self.db.escape_string(''), '') - def testMethodEscapeBytea(self): + def 
test_method_escape_bytea(self): self.assertEqual(self.db.escape_bytea('').replace( '\\x', '').replace('\\', ''), '') - def testMethodUnescapeBytea(self): + def test_method_unescape_bytea(self): self.assertEqual(self.db.unescape_bytea(''), b'') - def testMethodDecodeJson(self): + def test_method_decode_json(self): self.assertEqual(self.db.decode_json('{}'), {}) - def testMethodEncodeJson(self): + def test_method_encode_json(self): self.assertEqual(self.db.encode_json({}), '{}') - def testMethodQuery(self): + def test_method_query(self): query = self.db.query query("select 1+1") query("select 1+$1+$2", 2, 3) @@ -346,22 +240,21 @@ def testMethodQuery(self): query("select 1+$1+$2", [2, 3]) query("select 1+$1", 1) - def testMethodQueryEmpty(self): + def test_method_query_empty(self): self.assertRaises(ValueError, self.db.query, '') - def testMethodQueryDataError(self): + def test_method_query_data_error(self): try: self.db.query("select 1/0") except pg.DataError as error: + # noinspection PyUnresolvedReferences self.assertEqual(error.sqlstate, '22012') - def testMethodEndcopy(self): - try: + def test_method_endcopy(self): + with suppress(OSError): self.db.endcopy() - except IOError: - pass - def testMethodClose(self): + def test_method_close(self): self.db.close() try: self.db.reset() @@ -376,7 +269,7 @@ def testMethodClose(self): self.assertRaises(pg.InternalError, getattr, self.db, 'error') self.assertRaises(pg.InternalError, getattr, self.db, 'absent') - def testMethodReset(self): + def test_method_reset(self): con = self.db.db self.db.reset() self.assertIs(self.db.db, con) @@ -384,7 +277,7 @@ def testMethodReset(self): self.db.close() self.assertRaises(pg.InternalError, self.db.reset) - def testMethodReopen(self): + def test_method_reopen(self): con = self.db.db self.db.reopen() self.assertIsNot(self.db.db, con) @@ -396,7 +289,7 @@ def testMethodReopen(self): self.db.query("select 1+1") self.db.close() - def testExistingConnection(self): + def test_existing_connection(self): db = pg.DB(self.db.db) self.assertIsNotNone(db.db) self.assertEqual(self.db.db, db.db) @@ -410,12 +303,13 @@ def testExistingConnection(self): self.assertIsNone(db.db) db = pg.DB(self.db) self.assertEqual(self.db.db, db.db) + assert self.db.db is not None db = pg.DB(db=self.db.db) self.assertEqual(self.db.db, db.db) - def testExistingDbApi2Connection(self): + def test_existing_db_api2_connection(self): - class DBApi2Con: + class FakeDbApi2Connection: def __init__(self, cnx): self._cnx = cnx @@ -423,8 +317,8 @@ def __init__(self, cnx): def close(self): self._cnx.close() - db2 = DBApi2Con(self.db.db) - db = pg.DB(db2) + db2 = FakeDbApi2Connection(self.db.db) + db = pg.DB(db2) # type: ignore self.assertEqual(self.db.db, db.db) db.close() self.assertIsNone(db.db) @@ -444,11 +338,12 @@ class TestDBClass(unittest.TestCase): cls_set_up = False regtypes = None + supports_oids = False @classmethod def setUpClass(cls): db = DB() - cls.oids = db.server_version < 120000 + cls.supports_oids = db.server_version < 120000 db.query("drop table if exists test cascade") db.query("create table test (" "i2 smallint, i4 integer, i8 bigint," @@ -477,147 +372,143 @@ def setUp(self): query("set lc_monetary='C'") query("set datestyle='ISO,YMD'") query('set standard_conforming_strings=on') - try: - query('set bytea_output=hex') - except pg.ProgrammingError: - if self.db.server_version >= 90000: - raise # ignore for older server versions + query('set bytea_output=hex') def tearDown(self): self.doCleanups() self.db.close() - def createTable(self, 
table, definition, + def create_table(self, table, definition, temporary=True, oids=None, values=None): query = self.db.query if '"' not in table or '.' in table: - table = '"%s"' % table + table = f'"{table}"' if not temporary: - q = 'drop table if exists %s cascade' % table + q = f'drop table if exists {table} cascade' query(q) self.addCleanup(query, q) temporary = 'temporary table' if temporary else 'table' as_query = definition.startswith(('as ', 'AS ')) if not as_query and not definition.startswith('('): - definition = '(%s)' % definition + definition = f'({definition})' with_oids = 'with oids' if oids else ( - 'without oids' if self.oids else '') - q = ['create', temporary, table] + 'without oids' if self.supports_oids else '') + cmd_parts = ['create', temporary, table] if as_query: - q.extend([with_oids, definition]) + cmd_parts.extend([with_oids, definition]) else: - q.extend([definition, with_oids]) - q = ' '.join(q) - query(q) + cmd_parts.extend([definition, with_oids]) + cmd = ' '.join(cmd_parts) + query(cmd) if values: for params in values: if not isinstance(params, (list, tuple)): params = [params] - values = ', '.join('$%d' % (n + 1) for n in range(len(params))) - q = "insert into %s values (%s)" % (table, values) - query(q, params) + values = ', '.join(f'${n + 1}' for n in range(len(params))) + cmd = f"insert into {table} values ({values})" + query(cmd, params) - def testClassName(self): + def test_class_name(self): self.assertEqual(self.db.__class__.__name__, 'DB') - def testModuleName(self): - self.assertEqual(self.db.__module__, 'pg') - self.assertEqual(self.db.__class__.__module__, 'pg') + def test_module_name(self): + self.assertEqual(self.db.__module__, 'pg.db') + self.assertEqual(self.db.__class__.__module__, 'pg.db') - def testEscapeLiteral(self): + def test_escape_literal(self): f = self.db.escape_literal - r = f(b"plain") + r: Any = f(b"plain") self.assertIsInstance(r, bytes) self.assertEqual(r, b"'plain'") - r = f(u"plain") - self.assertIsInstance(r, unicode) - self.assertEqual(r, u"'plain'") - r = f(u"that's käse".encode('utf-8')) + r = f("plain") + self.assertIsInstance(r, str) + self.assertEqual(r, "'plain'") + r = f("that's käse".encode()) self.assertIsInstance(r, bytes) - self.assertEqual(r, u"'that''s käse'".encode('utf-8')) - r = f(u"that's käse") - self.assertIsInstance(r, unicode) - self.assertEqual(r, u"'that''s käse'") + self.assertEqual(r, "'that''s käse'".encode()) + r = f("that's käse") + self.assertIsInstance(r, str) + self.assertEqual(r, "'that''s käse'") self.assertEqual(f(r"It's fine to have a \ inside."), r" E'It''s fine to have a \\ inside.'") self.assertEqual(f('No "quotes" must be escaped.'), "'No \"quotes\" must be escaped.'") - def testEscapeIdentifier(self): + def test_escape_identifier(self): f = self.db.escape_identifier r = f(b"plain") self.assertIsInstance(r, bytes) self.assertEqual(r, b'"plain"') - r = f(u"plain") - self.assertIsInstance(r, unicode) - self.assertEqual(r, u'"plain"') - r = f(u"that's käse".encode('utf-8')) + r = f("plain") + self.assertIsInstance(r, str) + self.assertEqual(r, '"plain"') + r = f("that's käse".encode()) self.assertIsInstance(r, bytes) - self.assertEqual(r, u'"that\'s käse"'.encode('utf-8')) - r = f(u"that's käse") - self.assertIsInstance(r, unicode) - self.assertEqual(r, u'"that\'s käse"') + self.assertEqual(r, '"that\'s käse"'.encode()) + r = f("that's käse") + self.assertIsInstance(r, str) + self.assertEqual(r, '"that\'s käse"') self.assertEqual(f(r"It's fine to have a \ inside."), '"It\'s fine to have 
a \\ inside."') self.assertEqual(f('All "quotes" must be escaped.'), '"All ""quotes"" must be escaped."') - def testEscapeString(self): + def test_escape_string(self): f = self.db.escape_string r = f(b"plain") self.assertIsInstance(r, bytes) self.assertEqual(r, b"plain") - r = f(u"plain") - self.assertIsInstance(r, unicode) - self.assertEqual(r, u"plain") - r = f(u"that's käse".encode('utf-8')) + r = f("plain") + self.assertIsInstance(r, str) + self.assertEqual(r, "plain") + r = f("that's käse".encode()) self.assertIsInstance(r, bytes) - self.assertEqual(r, u"that''s käse".encode('utf-8')) - r = f(u"that's käse") - self.assertIsInstance(r, unicode) - self.assertEqual(r, u"that''s käse") + self.assertEqual(r, "that''s käse".encode()) + r = f("that's käse") + self.assertIsInstance(r, str) + self.assertEqual(r, "that''s käse") self.assertEqual(f(r"It's fine to have a \ inside."), r"It''s fine to have a \ inside.") - def testEscapeBytea(self): + def test_escape_bytea(self): f = self.db.escape_bytea # note that escape_byte always returns hex output since Pg 9.0, # regardless of the bytea_output setting r = f(b'plain') self.assertIsInstance(r, bytes) self.assertEqual(r, b'\\x706c61696e') - r = f(u'plain') - self.assertIsInstance(r, unicode) - self.assertEqual(r, u'\\x706c61696e') - r = f(u"das is' käse".encode('utf-8')) + r = f('plain') + self.assertIsInstance(r, str) + self.assertEqual(r, '\\x706c61696e') + r = f("das is' käse".encode()) self.assertIsInstance(r, bytes) self.assertEqual(r, b'\\x64617320697327206bc3a47365') - r = f(u"das is' käse") - self.assertIsInstance(r, unicode) - self.assertEqual(r, u'\\x64617320697327206bc3a47365') + r = f("das is' käse") + self.assertIsInstance(r, str) + self.assertEqual(r, '\\x64617320697327206bc3a47365') self.assertEqual(f(b'O\x00ps\xff!'), b'\\x4f007073ff21') - def testUnescapeBytea(self): + def test_unescape_bytea(self): f = self.db.unescape_bytea r = f(b'plain') self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') - r = f(u'plain') + r = f('plain') self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') r = f(b"das is' k\\303\\244se") self.assertIsInstance(r, bytes) - self.assertEqual(r, u"das is' käse".encode('utf8')) - r = f(u"das is' k\\303\\244se") + self.assertEqual(r, "das is' käse".encode()) + r = f("das is' k\\303\\244se") self.assertIsInstance(r, bytes) - self.assertEqual(r, u"das is' käse".encode('utf8')) + self.assertEqual(r, "das is' käse".encode()) self.assertEqual(f(r'O\\000ps\\377!'), b'O\\000ps\\377!') self.assertEqual(f(r'\\x706c61696e'), b'\\x706c61696e') self.assertEqual(f(r'\\x746861742773206be47365'), b'\\x746861742773206be47365') self.assertEqual(f(r'\\x4f007073ff21'), b'\\x4f007073ff21') - def testDecodeJson(self): + def test_decode_json(self): f = self.db.decode_json self.assertIsNone(f('null')) data = { @@ -630,13 +521,13 @@ def testDecodeJson(self): self.assertIsInstance(r, dict) self.assertEqual(r, data) self.assertIsInstance(r['id'], int) - self.assertIsInstance(r['name'], unicode) + self.assertIsInstance(r['name'], str) self.assertIsInstance(r['price'], float) self.assertIsInstance(r['new'], bool) self.assertIsInstance(r['tags'], list) self.assertIsInstance(r['stock'], dict) - def testEncodeJson(self): + def test_encode_json(self): f = self.db.encode_json self.assertEqual(f(None), 'null') data = { @@ -649,7 +540,7 @@ def testEncodeJson(self): self.assertIsInstance(r, str) self.assertEqual(r, text) - def testGetParameter(self): + def test_get_parameter(self): f = self.db.get_parameter 
self.assertRaises(TypeError, f) self.assertRaises(TypeError, f, None) @@ -671,10 +562,10 @@ def testGetParameter(self): self.assertEqual(r, ['hex', 'C']) r = f(('standard_conforming_strings', 'datestyle', 'bytea_output')) self.assertEqual(r, ['on', 'ISO, YMD', 'hex']) - r = f(set(['bytea_output', 'lc_monetary'])) + r = f({'bytea_output', 'lc_monetary'}) self.assertIsInstance(r, dict) self.assertEqual(r, {'bytea_output': 'hex', 'lc_monetary': 'C'}) - r = f(set(['Bytea_Output', ' LC_Monetary '])) + r = f({'Bytea_Output', ' LC_Monetary '}) self.assertIsInstance(r, dict) self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'}) s = dict.fromkeys(('bytea_output', 'lc_monetary')) @@ -686,14 +577,14 @@ def testGetParameter(self): self.assertIs(r, s) self.assertEqual(r, {'Bytea_Output': 'hex', ' LC_Monetary ': 'C'}) - def testGetParameterServerVersion(self): + def test_get_parameter_server_version(self): r = self.db.get_parameter('server_version_num') self.assertIsInstance(r, str) s = self.db.server_version self.assertIsInstance(s, int) self.assertEqual(r, str(s)) - def testGetParameterAll(self): + def test_get_parameter_all(self): f = self.db.get_parameter r = f('all') self.assertIsInstance(r, dict) @@ -702,7 +593,7 @@ def testGetParameterAll(self): self.assertEqual(r['DateStyle'], 'ISO, YMD') self.assertEqual(r['bytea_output'], 'hex') - def testSetParameter(self): + def test_set_parameter(self): f = self.db.set_parameter g = self.db.get_parameter self.assertRaises(TypeError, f) @@ -731,20 +622,22 @@ def testSetParameter(self): f(('escape_string_warning', 'standard_conforming_strings'), 'off') self.assertEqual(g('escape_string_warning'), 'off') self.assertEqual(g('standard_conforming_strings'), 'off') - f(set(['escape_string_warning', 'standard_conforming_strings']), 'on') + f({'escape_string_warning', 'standard_conforming_strings'}, 'on') self.assertEqual(g('escape_string_warning'), 'on') self.assertEqual(g('standard_conforming_strings'), 'on') - self.assertRaises(ValueError, f, set(['escape_string_warning', - 'standard_conforming_strings']), ['off', 'on']) - f(set(['escape_string_warning', 'standard_conforming_strings']), - ['off', 'off']) + self.assertRaises( + ValueError, f, + {'escape_string_warning', 'standard_conforming_strings'}, + ['off', 'on']) + f({'escape_string_warning', 'standard_conforming_strings'}, + ['off', 'off']) self.assertEqual(g('escape_string_warning'), 'off') self.assertEqual(g('standard_conforming_strings'), 'off') f({'standard_conforming_strings': 'on', 'datestyle': 'ISO, YMD'}) self.assertEqual(g('standard_conforming_strings'), 'on') self.assertEqual(g('datestyle'), 'ISO, YMD') - def testResetParameter(self): + def test_reset_parameter(self): db = DB() f = db.set_parameter g = db.get_parameter @@ -780,12 +673,12 @@ def testResetParameter(self): f('standard_conforming_strings', not_scs) self.assertEqual(g('escape_string_warning'), not_esw) self.assertEqual(g('standard_conforming_strings'), not_scs) - f(set(['escape_string_warning', 'standard_conforming_strings'])) + f({'escape_string_warning', 'standard_conforming_strings'}) self.assertEqual(g('escape_string_warning'), esw) self.assertEqual(g('standard_conforming_strings'), scs) db.close() - def testResetParameterAll(self): + def test_reset_parameter_all(self): db = DB() f = db.set_parameter self.assertRaises(ValueError, f, 'all', 0) @@ -806,7 +699,7 @@ def testResetParameterAll(self): self.assertEqual(g('standard_conforming_strings'), scs) db.close() - def testSetParameterLocal(self): + def 
test_set_parameter_local(self): f = self.db.set_parameter g = self.db.get_parameter self.assertEqual(g('standard_conforming_strings'), 'on') @@ -816,7 +709,7 @@ def testSetParameterLocal(self): self.db.end() self.assertEqual(g('standard_conforming_strings'), 'on') - def testSetParameterSession(self): + def test_set_parameter_session(self): f = self.db.set_parameter g = self.db.get_parameter self.assertEqual(g('standard_conforming_strings'), 'on') @@ -826,12 +719,12 @@ def testSetParameterSession(self): self.db.end() self.assertEqual(g('standard_conforming_strings'), 'off') - def testReset(self): + def test_reset(self): db = DB() default_datestyle = db.get_parameter('datestyle') changed_datestyle = 'ISO, DMY' if changed_datestyle == default_datestyle: - changed_datestyle == 'ISO, YMD' + changed_datestyle = 'ISO, YMD' self.db.set_parameter('datestyle', changed_datestyle) r = self.db.get_parameter('datestyle') self.assertEqual(r, changed_datestyle) @@ -847,12 +740,12 @@ def testReset(self): self.assertEqual(r, default_datestyle) db.close() - def testReopen(self): + def test_reopen(self): db = DB() default_datestyle = db.get_parameter('datestyle') changed_datestyle = 'ISO, DMY' if changed_datestyle == default_datestyle: - changed_datestyle == 'ISO, YMD' + changed_datestyle = 'ISO, YMD' self.db.set_parameter('datestyle', changed_datestyle) r = self.db.get_parameter('datestyle') self.assertEqual(r, changed_datestyle) @@ -866,32 +759,32 @@ def testReopen(self): self.assertEqual(r, default_datestyle) db.close() - def testCreateTable(self): + def test_create_table(self): table = 'test hello world' values = [(2, "World!"), (1, "Hello")] - self.createTable(table, "n smallint, t varchar", + self.create_table(table, "n smallint, t varchar", temporary=True, oids=False, values=values) - r = self.db.query('select t from "%s" order by n' % table).getresult() + r = self.db.query(f'select t from "{table}" order by n').getresult() r = ', '.join(row[0] for row in r) self.assertEqual(r, "Hello, World!") - def testCreateTableWithOids(self): - if not self.oids: + def test_create_table_with_oids(self): + if not self.supports_oids: self.skipTest("database does not support tables with oids") table = 'test hello world' values = [(2, "World!"), (1, "Hello")] - self.createTable(table, "n smallint, t varchar", + self.create_table(table, "n smallint, t varchar", temporary=True, oids=True, values=values) - r = self.db.query('select t from "%s" order by n' % table).getresult() + r = self.db.query(f'select t from "{table}" order by n').getresult() r = ', '.join(row[0] for row in r) self.assertEqual(r, "Hello, World!") - r = self.db.query('select oid from "%s" limit 1' % table).getresult() + r = self.db.query(f'select oid from "{table}" limit 1').getresult() self.assertIsInstance(r[0][0], int) - def testQuery(self): + def test_query(self): query = self.db.query table = 'test_table' - self.createTable(table, "n integer", oids=False) + self.create_table(table, "n integer", oids=False) q = "insert into test_table values (1)" r = query(q) self.assertIsInstance(r, str) @@ -916,17 +809,18 @@ def testQuery(self): r = query(q) self.assertIsInstance(r, str) self.assertEqual(r, '4') + # noinspection SqlWithoutWhere q = "delete from test_table" r = query(q) self.assertIsInstance(r, str) self.assertEqual(r, '5') - def testQueryWithOids(self): - if not self.oids: + def test_query_with_oids(self): + if not self.supports_oids: self.skipTest("database does not support tables with oids") query = self.db.query table = 'test_table' - 
self.createTable(table, "n integer", oids=True) + self.create_table(table, "n integer", oids=True) q = "insert into test_table values (1)" r = query(q) self.assertIsInstance(r, int) @@ -949,20 +843,21 @@ def testQueryWithOids(self): r = query(q) self.assertIsInstance(r, str) self.assertEqual(r, '4') + # noinspection SqlWithoutWhere q = "delete from test_table" r = query(q) self.assertIsInstance(r, str) self.assertEqual(r, '5') - def testMultipleQueries(self): + def test_multiple_queries(self): self.assertEqual(self.db.query( "create temporary table test_multi (n integer);" "insert into test_multi values (4711);" "select n from test_multi").getresult()[0][0], 4711) - def testQueryWithParams(self): + def test_query_with_params(self): query = self.db.query - self.createTable('test_table', 'n1 integer, n2 integer', oids=False) + self.create_table('test_table', 'n1 integer, n2 integer', oids=False) q = "insert into test_table values ($1, $2)" r = query(q, (1, 2)) self.assertEqual(r, '1') @@ -985,16 +880,17 @@ def testQueryWithParams(self): r = query(q, 4) self.assertEqual(r, '3') - def testEmptyQuery(self): + def test_empty_query(self): self.assertRaises(ValueError, self.db.query, '') - def testQueryDataError(self): + def test_query_data_error(self): try: self.db.query("select 1/0") except pg.DataError as error: + # noinspection PyUnresolvedReferences self.assertEqual(error.sqlstate, '22012') - def testQueryFormatted(self): + def test_query_formatted(self): f = self.db.query_formatted t = True if pg.get_bool() else 't' # test with tuple @@ -1005,11 +901,6 @@ def testQueryFormatted(self): # test with tuple, inline q = f("select %s, %s, %s, %s", (3, 2.5, 'hello', True), inline=True) r = q.getresult()[0] - if isinstance(r[1], Decimal): - # Python 2.6 cannot compare float and Decimal - r = list(r) - r[1] = float(r[1]) - r = tuple(r) self.assertEqual(r, (3, 2.5, 'hello', t)) # test with dict q = f("select %(a)s::int, %(b)s::real, %(c)s::text, %(d)s::bool", @@ -1020,11 +911,6 @@ def testQueryFormatted(self): q = f("select %(a)s, %(b)s, %(c)s, %(d)s", dict(a=3, b=2.5, c='hello', d=True), inline=True) r = q.getresult()[0] - if isinstance(r[1], Decimal): - # Python 2.6 cannot compare float and Decimal - r = list(r) - r[1] = float(r[1]) - r = tuple(r) self.assertEqual(r, (3, 2.5, 'hello', t)) # test with dict and extra values q = f("select %(a)s||%(b)s||%(c)s||%(d)s||'epsilon'", @@ -1032,7 +918,7 @@ def testQueryFormatted(self): r = q.getresult()[0][0] self.assertEqual(r, 'alphabetagammadeltaepsilon') - def testQueryFormattedWithAny(self): + def test_query_formatted_with_any(self): f = self.db.query_formatted q = "select 2 = any(%s)" r = f(q, [[1, 3]]).getresult()[0][0] @@ -1044,7 +930,7 @@ def testQueryFormattedWithAny(self): r = f(q, [[None]]).getresult()[0][0] self.assertIsNone(r) - def testQueryFormattedWithoutParams(self): + def test_query_formatted_without_params(self): f = self.db.query_formatted q = "select 42" r = f(q).getresult()[0][0] @@ -1056,19 +942,19 @@ def testQueryFormattedWithoutParams(self): r = f(q, {}).getresult()[0][0] self.assertEqual(r, 42) - def testPrepare(self): + def test_prepare(self): p = self.db.prepare self.assertIsNone(p('my query', "select 'hello'")) self.assertIsNone(p('my other query', "select 'world'")) - self.assertRaises(pg.ProgrammingError, - p, 'my query', "select 'hello, too'") + self.assertRaises( + pg.ProgrammingError, p, 'my query', "select 'hello, too'") - def testPrepareUnnamed(self): + def test_prepare_unnamed(self): p = self.db.prepare 
self.assertIsNone(p('', "select null")) self.assertIsNone(p(None, "select null")) - def testQueryPreparedWithoutParams(self): + def test_query_prepared_without_params(self): f = self.db.query_prepared self.assertRaises(pg.OperationalError, f, 'q') p = self.db.prepare @@ -1079,7 +965,7 @@ def testQueryPreparedWithoutParams(self): r = f('q2').getresult()[0][0] self.assertEqual(r, 42) - def testQueryPreparedWithParams(self): + def test_query_prepared_with_params(self): p = self.db.prepare p('sum', "select 1 + $1 + $2 + $3") p('cat', "select initcap($1) || ', ' || $2 || '!'") @@ -1089,7 +975,7 @@ def testQueryPreparedWithParams(self): r = f('cat', 'hello', 'world').getresult()[0][0] self.assertEqual(r, 'Hello, world!') - def testQueryPreparedUnnamedWithOutParams(self): + def test_query_prepared_unnamed_with_out_params(self): f = self.db.query_prepared self.assertRaises(pg.OperationalError, f, None) self.assertRaises(pg.OperationalError, f, '') @@ -1107,7 +993,7 @@ def testQueryPreparedUnnamedWithOutParams(self): r = f('').getresult()[0][0] self.assertEqual(r, 'none') - def testQueryPreparedUnnamedWithParams(self): + def test_query_prepared_unnamed_with_params(self): p = self.db.prepare p('', "select 1 + $1 + $2") f = self.db.query_prepared @@ -1122,13 +1008,13 @@ def testQueryPreparedUnnamedWithParams(self): r = f(None, 3, 4).getresult()[0][0] self.assertEqual(r, 9) - def testDescribePrepared(self): + def test_describe_prepared(self): self.db.prepare('count', "select 1 as first, 2 as second") f = self.db.describe_prepared r = f('count').listfields() self.assertEqual(r, ('first', 'second')) - def testDescribePreparedUnnamed(self): + def test_describe_prepared_unnamed(self): self.db.prepare('', "select null as anon") f = self.db.describe_prepared r = f().listfields() @@ -1138,7 +1024,7 @@ def testDescribePreparedUnnamed(self): r = f('').listfields() self.assertEqual(r, ('anon',)) - def testDeletePrepared(self): + def test_delete_prepared(self): f = self.db.delete_prepared f() e = pg.OperationalError @@ -1156,61 +1042,73 @@ def testDeletePrepared(self): self.assertRaises(e, f, 'q1') self.assertRaises(e, f, 'q2') - def testPkey(self): + def test_pkey(self): query = self.db.query pkey = self.db.pkey self.assertRaises(KeyError, pkey, 'test') for t in ('pkeytest', 'primary key test'): - self.createTable('%s0' % t, 'a smallint') - self.createTable('%s1' % t, 'b smallint primary key') - self.createTable('%s2' % t, - 'c smallint, d smallint primary key') - self.createTable('%s3' % t, + self.create_table(f'{t}0', 'a smallint') + self.create_table(f'{t}1', 'b smallint primary key') + self.create_table(f'{t}2', 'c smallint, d smallint primary key') + self.create_table( + f'{t}3', 'e smallint, f smallint, g smallint, h smallint, i smallint,' ' primary key (f, h)') - self.createTable('%s4' % t, + self.create_table( + f'{t}4', 'e smallint, f smallint, g smallint, h smallint, i smallint,' ' primary key (h, f)') - self.createTable('%s5' % t, - 'more_than_one_letter varchar primary key') - self.createTable('%s6' % t, - '"with space" date primary key') - self.createTable('%s7' % t, + self.create_table( + f'{t}5', 'more_than_one_letter varchar primary key') + self.create_table( + f'{t}6', '"with space" date primary key') + self.create_table( + f'{t}7', 'a_very_long_column_name varchar, "with space" date, "42" int,' ' primary key (a_very_long_column_name, "with space", "42")') - self.assertRaises(KeyError, pkey, '%s0' % t) - self.assertEqual(pkey('%s1' % t), 'b') - self.assertEqual(pkey('%s1' % t, True), ('b',)) - 
self.assertEqual(pkey('%s1' % t, composite=False), 'b') - self.assertEqual(pkey('%s1' % t, composite=True), ('b',)) - self.assertEqual(pkey('%s2' % t), 'd') - self.assertEqual(pkey('%s2' % t, composite=True), ('d',)) - r = pkey('%s3' % t) + self.assertRaises(KeyError, pkey, f'{t}0') + self.assertEqual(pkey(f'{t}1'), 'b') + self.assertEqual(pkey(f'{t}1', True), ('b',)) + self.assertEqual(pkey(f'{t}1', composite=False), 'b') + self.assertEqual(pkey(f'{t}1', composite=True), ('b',)) + self.assertEqual(pkey(f'{t}2'), 'd') + self.assertEqual(pkey(f'{t}2', composite=True), ('d',)) + r = pkey(f'{t}3') self.assertIsInstance(r, tuple) self.assertEqual(r, ('f', 'h')) - r = pkey('%s3' % t, composite=False) + r = pkey(f'{t}3', composite=False) self.assertIsInstance(r, tuple) self.assertEqual(r, ('f', 'h')) - r = pkey('%s4' % t) + r = pkey(f'{t}4') self.assertIsInstance(r, tuple) self.assertEqual(r, ('h', 'f')) - self.assertEqual(pkey('%s5' % t), 'more_than_one_letter') - self.assertEqual(pkey('%s6' % t), 'with space') - r = pkey('%s7' % t) + self.assertEqual(pkey(f'{t}5'), 'more_than_one_letter') + self.assertEqual(pkey(f'{t}6'), 'with space') + r = pkey(f'{t}7') self.assertIsInstance(r, tuple) self.assertEqual(r, ( 'a_very_long_column_name', 'with space', '42')) # a newly added primary key will be detected - query('alter table "%s0" add primary key (a)' % t) - self.assertEqual(pkey('%s0' % t), 'a') + query(f'alter table "{t}0" add primary key (a)') + self.assertEqual(pkey(f'{t}0'), 'a') # a changed primary key will not be detected, # indicating that the internal cache is operating - query('alter table "%s1" rename column b to x' % t) - self.assertEqual(pkey('%s1' % t), 'b') + query(f'alter table "{t}1" rename column b to x') + self.assertEqual(pkey(f'{t}1'), 'b') # we get the changed primary key when the cache is flushed - self.assertEqual(pkey('%s1' % t, flush=True), 'x') - - def testGetDatabases(self): + self.assertEqual(pkey(f'{t}1', flush=True), 'x') + + def test_pkeys(self): + pkeys = self.db.pkeys + t = 'pkeys_test_' + self.create_table(f'{t}0', 'a int') + self.create_table(f'{t}1', 'a int primary key, b int') + self.create_table(f'{t}2', 'a int, b int, c int, primary key (a, c)') + self.assertRaises(KeyError, pkeys, f'{t}0') + self.assertEqual(pkeys(f'{t}1'), ('a',)) + self.assertEqual(pkeys(f'{t}2'), ('a', 'c')) + + def test_get_databases(self): databases = self.db.get_databases() self.assertIn('template0', databases) self.assertIn('template1', databases) @@ -1218,36 +1116,36 @@ def testGetDatabases(self): self.assertIn('postgres', databases) self.assertIn(dbname, databases) - def testGetTables(self): + def test_get_tables(self): get_tables = self.db.get_tables tables = ('A very Special Name', 'A_MiXeD_quoted_NaMe', 'Hello, Test World!', 'Zoro', 'a1', 'a2', 'a321', 'averyveryveryveryveryveryveryreallyreallylongtablename', 'b0', 'b3', 'x', 'xXx', 'xx', 'y', 'z') for t in tables: - self.db.query('drop table if exists "%s" cascade' % t) + self.db.query(f'drop table if exists "{t}" cascade') before_tables = get_tables() self.assertIsInstance(before_tables, list) for t in before_tables: - t = t.split('.', 1) - self.assertGreaterEqual(len(t), 2) - if len(t) > 2: - self.assertTrue(t[1].startswith('"')) - t = t[0] + s = t.split('.', 1) + self.assertGreaterEqual(len(s), 2) + if len(s) > 2: + self.assertTrue(s[1].startswith('"')) + t = s[0] self.assertNotEqual(t, 'information_schema') self.assertFalse(t.startswith('pg_')) for t in tables: - self.createTable(t, 'as select 0', temporary=False) + 
self.create_table(t, 'as select 0', temporary=False) current_tables = get_tables() new_tables = [t for t in current_tables if t not in before_tables] - expected_new_tables = ['public.%s' % ( - '"%s"' % t if ' ' in t or t != t.lower() else t) for t in tables] + expected_new_tables = ['public.' + ( + f'"{t}"' if ' ' in t or t != t.lower() else t) for t in tables] self.assertEqual(new_tables, expected_new_tables) self.doCleanups() after_tables = get_tables() self.assertEqual(after_tables, before_tables) - def testGetSystemTables(self): + def test_get_system_tables(self): get_tables = self.db.get_tables result = get_tables() self.assertNotIn('pg_catalog.pg_class', result) @@ -1259,7 +1157,7 @@ def testGetSystemTables(self): self.assertIn('pg_catalog.pg_class', result) self.assertNotIn('information_schema.tables', result) - def testGetRelations(self): + def test_get_relations(self): get_relations = self.db.get_relations result = get_relations() self.assertIn('public.test', result) @@ -1277,7 +1175,7 @@ def testGetRelations(self): self.assertNotIn('public.test', result) self.assertNotIn('public.test_view', result) - def testGetSystemRelations(self): + def test_get_system_relations(self): get_relations = self.db.get_relations result = get_relations() self.assertNotIn('pg_catalog.pg_class', result) @@ -1289,7 +1187,7 @@ def testGetSystemRelations(self): self.assertIn('pg_catalog.pg_class', result) self.assertIn('information_schema.tables', result) - def testGetAttnames(self): + def test_get_attnames(self): get_attnames = self.db.get_attnames self.assertRaises(pg.ProgrammingError, self.db.get_attnames, 'does_not_exist') @@ -1307,7 +1205,7 @@ def testGetAttnames(self): i2='int', i4='int', i8='int', d='num', f4='float', f8='float', m='money', v4='text', c4='text', t='text')) - self.createTable('test_table', + self.create_table('test_table', 'n int, alpha smallint, beta bool,' ' gamma char(5), tau text, v varchar(3)') r = get_attnames('test_table') @@ -1321,10 +1219,11 @@ def testGetAttnames(self): n='int', alpha='int', beta='bool', gamma='text', tau='text', v='text')) - def testGetAttnamesWithQuotes(self): + def test_get_attnames_with_quotes(self): get_attnames = self.db.get_attnames table = 'test table for get_attnames()' - self.createTable(table, + self.create_table( + table, '"Prime!" smallint, "much space" integer, "Questions?" 
text') r = get_attnames(table) self.assertIsInstance(r, dict) @@ -1336,7 +1235,7 @@ def testGetAttnamesWithQuotes(self): self.assertEqual(r, { 'Prime!': 'int', 'much space': 'int', 'Questions?': 'text'}) table = 'yet another test table for get_attnames()' - self.createTable(table, + self.create_table(table, 'a smallint, b integer, c bigint,' ' e numeric, f real, f2 double precision, m money,' ' x smallint, y smallint, z smallint,' @@ -1354,16 +1253,18 @@ def testGetAttnamesWithQuotes(self): 't': 'text', 'v': 'character varying', 'y': 'smallint', 'x': 'smallint', 'z': 'smallint'}) else: - self.assertEqual(r, {'a': 'int', 'b': 'int', 'c': 'int', - 'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money', - 'normal_name': 'int', 'Special Name': 'int', - 'u': 'text', 't': 'text', 'v': 'text', - 'y': 'int', 'x': 'int', 'z': 'int'}) + self.assertEqual(r, { + 'a': 'int', 'b': 'int', 'c': 'int', + 'e': 'num', 'f': 'float', 'f2': 'float', 'm': 'money', + 'normal_name': 'int', 'Special Name': 'int', + 'u': 'text', 't': 'text', 'v': 'text', + 'y': 'int', 'x': 'int', 'z': 'int'}) - def testGetAttnamesWithRegtypes(self): + def test_get_attnames_with_regtypes(self): get_attnames = self.db.get_attnames - self.createTable('test_table', 'n int, alpha smallint, beta bool,' - ' gamma char(5), tau text, v varchar(3)') + self.create_table( + 'test_table', 'n int, alpha smallint, beta bool,' + ' gamma char(5), tau text, v varchar(3)') use_regtypes = self.db.use_regtypes regtypes = use_regtypes() self.assertEqual(regtypes, self.regtypes) @@ -1377,10 +1278,11 @@ def testGetAttnamesWithRegtypes(self): n='integer', alpha='smallint', beta='boolean', gamma='character', tau='text', v='character varying')) - def testGetAttnamesWithoutRegtypes(self): + def test_get_attnames_without_regtypes(self): get_attnames = self.db.get_attnames - self.createTable('test_table', 'n int, alpha smallint, beta bool,' - ' gamma char(5), tau text, v varchar(3)') + self.create_table( + 'test_table', 'n int, alpha smallint, beta bool,' + ' gamma char(5), tau text, v varchar(3)') use_regtypes = self.db.use_regtypes regtypes = use_regtypes() self.assertEqual(regtypes, self.regtypes) @@ -1394,12 +1296,12 @@ def testGetAttnamesWithoutRegtypes(self): n='int', alpha='int', beta='bool', gamma='text', tau='text', v='text')) - def testGetAttnamesIsCached(self): + def test_get_attnames_is_cached(self): get_attnames = self.db.get_attnames int_type = 'integer' if self.regtypes else 'int' text_type = 'text' query = self.db.query - self.createTable('test_table', 'col int') + self.create_table('test_table', 'col int') r = get_attnames("test_table") self.assertIsInstance(r, dict) self.assertEqual(r, dict(col=int_type)) @@ -1420,80 +1322,125 @@ def testGetAttnamesIsCached(self): r = get_attnames("test_table", flush=True) self.assertEqual(r, dict()) - def testGetAttnamesIsOrdered(self): + def test_get_attnames_is_ordered(self): get_attnames = self.db.get_attnames r = get_attnames('test', flush=True) - self.assertIsInstance(r, OrderedDict) + self.assertIsInstance(r, dict) if self.regtypes: - self.assertEqual(r, OrderedDict([ - ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'), - ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'), - ('m', 'money'), ('v4', 'character varying'), - ('c4', 'character'), ('t', 'text')])) + self.assertEqual(r, { + 'i2': 'smallint', 'i4': 'integer', 'i8': 'bigint', + 'd': 'numeric', 'f4': 'real', 'f8': 'double precision', + 'm': 'money', 'v4': 'character varying', + 'c4': 'character', 't': 'text'}) else: - 
self.assertEqual(r, OrderedDict([ - ('i2', 'int'), ('i4', 'int'), ('i8', 'int'), - ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'), - ('v4', 'text'), ('c4', 'text'), ('t', 'text')])) - if OrderedDict is not dict: - r = ' '.join(list(r.keys())) - self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t') + self.assertEqual(r, { + 'i2': 'int', 'i4': 'int', 'i8': 'int', + 'd': 'num', 'f4': 'float', 'f8': 'float', 'm': 'money', + 'v4': 'text', 'c4': 'text', 't': 'text'}) + r = ' '.join(list(r.keys())) + self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t') table = 'test table for get_attnames' - self.createTable(table, 'n int, alpha smallint, v varchar(3),' - ' gamma char(5), tau text, beta bool') + self.create_table( + table, 'n int, alpha smallint, v varchar(3),' + ' gamma char(5), tau text, beta bool') r = get_attnames(table) - self.assertIsInstance(r, OrderedDict) + self.assertIsInstance(r, dict) if self.regtypes: - self.assertEqual(r, OrderedDict([ - ('n', 'integer'), ('alpha', 'smallint'), - ('v', 'character varying'), ('gamma', 'character'), - ('tau', 'text'), ('beta', 'boolean')])) - else: - self.assertEqual(r, OrderedDict([ - ('n', 'int'), ('alpha', 'int'), ('v', 'text'), - ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')])) - if OrderedDict is not dict: - r = ' '.join(list(r.keys())) - self.assertEqual(r, 'n alpha v gamma tau beta') + self.assertEqual(r, { + 'n': 'integer', 'alpha': 'smallint', + 'v': 'character varying', 'gamma': 'character', + 'tau': 'text', 'beta': 'boolean'}) else: - self.skipTest('OrderedDict is not supported') + self.assertEqual(r, { + 'n': 'int', 'alpha': 'int', 'v': 'text', + 'gamma': 'text', 'tau': 'text', 'beta': 'bool'}) + r = ' '.join(list(r.keys())) + self.assertEqual(r, 'n alpha v gamma tau beta') - def testGetAttnamesIsAttrDict(self): - AttrDict = pg.AttrDict + def test_get_attnames_is_attr_dict(self): + from pg.attrs import AttrDict get_attnames = self.db.get_attnames r = get_attnames('test', flush=True) self.assertIsInstance(r, AttrDict) if self.regtypes: - self.assertEqual(r, AttrDict([ - ('i2', 'smallint'), ('i4', 'integer'), ('i8', 'bigint'), - ('d', 'numeric'), ('f4', 'real'), ('f8', 'double precision'), - ('m', 'money'), ('v4', 'character varying'), - ('c4', 'character'), ('t', 'text')])) + self.assertEqual(r, AttrDict( + i2='smallint', i4='integer', i8='bigint', + d='numeric', f4='real', f8='double precision', + m='money', v4='character varying', + c4='character', t='text')) else: - self.assertEqual(r, AttrDict([ - ('i2', 'int'), ('i4', 'int'), ('i8', 'int'), - ('d', 'num'), ('f4', 'float'), ('f8', 'float'), ('m', 'money'), - ('v4', 'text'), ('c4', 'text'), ('t', 'text')])) + self.assertEqual(r, AttrDict( + i2='int', i4='int', i8='int', + d='num', f4='float', f8='float', m='money', + v4='text', c4='text', t='text')) r = ' '.join(list(r.keys())) self.assertEqual(r, 'i2 i4 i8 d f4 f8 m v4 c4 t') table = 'test table for get_attnames' - self.createTable(table, 'n int, alpha smallint, v varchar(3),' - ' gamma char(5), tau text, beta bool') + self.create_table( + table, 'n int, alpha smallint, v varchar(3),' + ' gamma char(5), tau text, beta bool') r = get_attnames(table) self.assertIsInstance(r, AttrDict) if self.regtypes: - self.assertEqual(r, AttrDict([ - ('n', 'integer'), ('alpha', 'smallint'), - ('v', 'character varying'), ('gamma', 'character'), - ('tau', 'text'), ('beta', 'boolean')])) + self.assertEqual(r, AttrDict( + n='integer', alpha='smallint', + v='character varying', gamma='character', + tau='text', beta='boolean')) else: - 
self.assertEqual(r, AttrDict([ - ('n', 'int'), ('alpha', 'int'), ('v', 'text'), - ('gamma', 'text'), ('tau', 'text'), ('beta', 'bool')])) + self.assertEqual(r, AttrDict( + n='int', alpha='int', v='text', + gamma='text', tau='text', beta='bool')) r = ' '.join(list(r.keys())) self.assertEqual(r, 'n alpha v gamma tau beta') - def testHasTablePrivilege(self): + def test_get_generated(self): + get_generated = self.db.get_generated + server_version = self.db.server_version + if server_version >= 100000: + self.assertRaises(pg.ProgrammingError, + self.db.get_generated, 'does_not_exist') + self.assertRaises(pg.ProgrammingError, + self.db.get_generated, 'has.too.many.dots') + r = get_generated('test') + self.assertIsInstance(r, frozenset) + self.assertFalse(r) + if server_version >= 100000: + table = 'test_get_generated_1' + self.create_table( + table, + 'i int generated always as identity primary key,' + ' j int generated always as identity,' + ' k int generated by default as identity,' + ' n serial, m int') + r = get_generated(table) + self.assertIsInstance(r, frozenset) + self.assertEqual(r, {'i', 'j'}) + if server_version >= 120000: + table = 'test_get_generated_2' + self.create_table( + table, + 'n int, m int generated always as (n + 3) stored,' + ' i int generated always as identity,' + ' j int generated by default as identity') + r = get_generated(table) + self.assertIsInstance(r, frozenset) + self.assertEqual(r, {'m', 'i'}) + + def test_get_generated_is_cached(self): + server_version = self.db.server_version + if server_version < 100000: + self.skipTest("database does not support generated columns") + get_generated = self.db.get_generated + query = self.db.query + table = 'test_get_generated_2' + self.create_table(table, 'i int primary key') + self.assertFalse(get_generated(table)) + query(f'alter table {table} alter column i' + ' add generated always as identity') + self.assertFalse(get_generated(table)) + self.assertEqual(get_generated(table, flush=True), {'i'}) + + def test_has_table_privilege(self): can = self.db.has_table_privilege self.assertEqual(can('test'), True) self.assertEqual(can('test', 'select'), True) @@ -1504,7 +1451,8 @@ def testHasTablePrivilege(self): self.assertEqual(can('test', 'delete'), True) self.assertRaises(pg.DataError, can, 'test', 'foobar') self.assertRaises(pg.ProgrammingError, can, 'table_does_not_exist') - r = self.db.query('select rolsuper FROM pg_roles' + r = self.db.query( + 'select rolsuper FROM pg_roles' ' where rolname=current_user').getresult()[0][0] if not pg.get_bool(): r = r == 't' @@ -1513,16 +1461,16 @@ def testHasTablePrivilege(self): self.assertEqual(can('pg_views', 'select'), True) self.assertEqual(can('pg_views', 'delete'), False) - def testGet(self): + def test_get(self): get = self.db.get query = self.db.query table = 'get_test_table' self.assertRaises(TypeError, get) self.assertRaises(TypeError, get, table) - self.createTable(table, 'n integer, t text', + self.create_table(table, 'n integer, t text', values=enumerate('xyz', start=1)) self.assertRaises(pg.ProgrammingError, get, table, 2) - r = get(table, 2, 'n') + r: Any = get(table, 2, 'n') self.assertIsInstance(r, dict) self.assertEqual(r, dict(n=2, t='y')) r = get(table, 1, 'n') @@ -1535,7 +1483,7 @@ def testGet(self): self.assertRaises(pg.DatabaseError, get, table, 4, 'n') self.assertRaises(pg.DatabaseError, get, table, 'y') self.assertRaises(pg.DatabaseError, get, table, 2, 't') - s = dict(n=3) + s: dict = dict(n=3) self.assertRaises(pg.ProgrammingError, get, table, s) r = get(table, 
s, 'n') self.assertIs(r, s) @@ -1547,8 +1495,8 @@ def testGet(self): r = get(table, s, ('n', 't')) self.assertIs(r, s) self.assertEqual(r, dict(n=1, t='x')) - query('alter table "%s" alter n set not null' % table) - query('alter table "%s" add primary key (n)' % table) + query(f'alter table "{table}" alter n set not null') + query(f'alter table "{table}" add primary key (n)') r = get(table, 2) self.assertIsInstance(r, dict) self.assertEqual(r, dict(n=2, t='y')) @@ -1568,18 +1516,18 @@ def testGet(self): s.pop('n') self.assertRaises(KeyError, get, table, s) - def testGetWithOids(self): - if not self.oids: + def test_get_with_oids(self): + if not self.supports_oids: self.skipTest("database does not support tables with oids") get = self.db.get query = self.db.query table = 'get_with_oid_test_table' - self.createTable(table, 'n integer, t text', oids=True, + self.create_table(table, 'n integer, t text', oids=True, values=enumerate('xyz', start=1)) self.assertRaises(pg.ProgrammingError, get, table, 2) self.assertRaises(KeyError, get, table, {}, 'oid') r = get(table, 2, 'n') - qoid = 'oid(%s)' % table + qoid = f'oid({table})' self.assertIn(qoid, r) oid = r[qoid] self.assertIsInstance(oid, int) @@ -1606,8 +1554,8 @@ def testGetWithOids(self): self.assertEqual(get(table, r, 'n')['t'], 'z') self.assertEqual(get(table, 1, 'n')['t'], 'x') self.assertEqual(get(table, r, 'oid')['t'], 'z') - query('alter table "%s" alter n set not null' % table) - query('alter table "%s" add primary key (n)' % table) + query(f'alter table "{table}" alter n set not null') + query(f'alter table "{table}" add primary key (n)') self.assertEqual(get(table, 3)['t'], 'z') self.assertEqual(get(table, 1)['t'], 'x') self.assertEqual(get(table, 2)['t'], 'y') @@ -1634,12 +1582,12 @@ def testGetWithOids(self): self.assertEqual(r['n'], 3) self.assertNotEqual(r[qoid], oid) - def testGetWithCompositeKey(self): + def test_get_with_composite_key(self): get = self.db.get - query = self.db.query table = 'get_test_table_1' - self.createTable(table, 'n integer primary key, t text', - values=enumerate('abc', start=1)) + self.create_table( + table, 'n integer primary key, t text', + values=enumerate('abc', start=1)) self.assertEqual(get(table, 2)['t'], 'b') self.assertEqual(get(table, 1, 'n')['t'], 'a') self.assertEqual(get(table, 2, ('n',))['t'], 'b') @@ -1649,10 +1597,10 @@ def testGetWithCompositeKey(self): self.assertEqual(get(table, ('a',), ('t',))['n'], 1) self.assertEqual(get(table, ['c'], ['t'])['n'], 3) table = 'get_test_table_2' - self.createTable(table, - 'n integer, m integer, t text, primary key (n, m)', - values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m)) - for n in range(3) for m in range(2)]) + self.create_table( + table, 'n integer, m integer, t text, primary key (n, m)', + values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m)) + for n in range(3) for m in range(2)]) self.assertRaises(KeyError, get, table, 2) self.assertEqual(get(table, (1, 1))['t'], 'a') self.assertEqual(get(table, (1, 2))['t'], 'b') @@ -1666,20 +1614,20 @@ def testGetWithCompositeKey(self): self.assertEqual(get(table, dict(n=2, m=1), ['n', 'm'])['t'], 'c') self.assertEqual(get(table, dict(n=3, m=2), ('m', 'n'))['t'], 'f') - def testGetWithQuotedNames(self): + def test_get_with_quoted_names(self): get = self.db.get - query = self.db.query table = 'test table for get()' - self.createTable(table, '"Prime!" smallint primary key,' - ' "much space" integer, "Questions?" text', - values=[(17, 1001, 'No!')]) + self.create_table( + table, '"Prime!" 
smallint primary key,' + ' "much space" integer, "Questions?" text', + values=[(17, 1001, 'No!')]) r = get(table, 17) self.assertIsInstance(r, dict) self.assertEqual(r['Prime!'], 17) self.assertEqual(r['much space'], 1001) self.assertEqual(r['Questions?'], 'No!') - def testGetFromView(self): + def test_get_from_view(self): self.db.query('delete from test where i4=14') self.db.query('insert into test (i4, v4) values(' "14, 'abc4')") @@ -1687,10 +1635,11 @@ def testGetFromView(self): self.assertIn('v4', r) self.assertEqual(r['v4'], 'abc4') - def testGetLittleBobbyTables(self): + def test_get_little_bobby_tables(self): get = self.db.get query = self.db.query - self.createTable('test_students', + self.create_table( + 'test_students', 'firstname varchar primary key, nickname varchar, grade char(2)', values=[("D'Arcy", 'Darcey', 'A+'), ('Sheldon', 'Moonpie', 'A+'), ('Robert', 'Little Bobby Tables', 'D-')]) @@ -1706,13 +1655,15 @@ def testGetLittleBobbyTables(self): try: get('test_students', "D' Arcy") except pg.DatabaseError as error: - self.assertEqual(str(error), + self.assertEqual( + str(error), 'No such record in test_students\nwhere "firstname" = $1\n' 'with $1="D\' Arcy"') try: get('test_students', "Robert'); TRUNCATE TABLE test_students;--") except pg.DatabaseError as error: - self.assertEqual(str(error), + self.assertEqual( + str(error), 'No such record in test_students\nwhere "firstname" = $1\n' 'with $1="Robert\'); TRUNCATE TABLE test_students;--"') q = "select * from test_students order by 1 limit 4" @@ -1720,64 +1671,64 @@ def testGetLittleBobbyTables(self): self.assertEqual(len(r), 3) self.assertEqual(r[1][2], 'D-') - def testInsert(self): + def test_insert(self): insert = self.db.insert query = self.db.query bool_on = pg.get_bool() decimal = pg.get_decimal() table = 'insert_test_table' - self.createTable(table, - 'i2 smallint, i4 integer, i8 bigint,' - ' d numeric, f4 real, f8 double precision, m money,' - ' v4 varchar(4), c4 char(4), t text,' - ' b boolean, ts timestamp') - tests = [dict(i2=None, i4=None, i8=None), - (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)), - (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)), - dict(i2=42, i4=123456, i8=9876543210), - dict(i2=2 ** 15 - 1, - i4=int(2 ** 31 - 1), i8=long(2 ** 63 - 1)), - dict(d=None), (dict(d=''), dict(d=None)), - dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))), - dict(f4=None, f8=None), dict(f4=0, f8=0), - (dict(f4='', f8=''), dict(f4=None, f8=None)), - (dict(d=1234.5, f4=1234.5, f8=1234.5), - dict(d=Decimal('1234.5'))), - dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875), - dict(d=Decimal('123456789.9876543212345678987654321')), - dict(m=None), (dict(m=''), dict(m=None)), - dict(m=Decimal('-1234.56')), - (dict(m='-1234.56'), dict(m=Decimal('-1234.56'))), - dict(m=Decimal('1234.56')), dict(m=Decimal('123456')), - (dict(m='1234.56'), dict(m=Decimal('1234.56'))), - (dict(m=1234.5), dict(m=Decimal('1234.5'))), - (dict(m=-1234.5), dict(m=Decimal('-1234.5'))), - (dict(m=123456), dict(m=Decimal('123456'))), - (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))), - dict(b=None), (dict(b=''), dict(b=None)), - dict(b='f'), dict(b='t'), - (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')), - (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')), - (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')), - (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')), - (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')), - (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')), - dict(v4=None, c4=None, 
t=None), - (dict(v4='', c4='', t=''), dict(c4=' ' * 4)), - dict(v4='1234', c4='1234', t='1234' * 10), - dict(v4='abcd', c4='abcd', t='abcdefg'), - (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')), - dict(ts=None), (dict(ts=''), dict(ts=None)), - (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)), - dict(ts='2012-12-21 00:00:00'), - (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')), - dict(ts='2012-12-21 12:21:12'), - dict(ts='2013-01-05 12:13:14'), - dict(ts='current_timestamp')] + self.create_table( + table, 'i2 smallint, i4 integer, i8 bigint,' + ' d numeric, f4 real, f8 double precision, m money,' + ' v4 varchar(4), c4 char(4), t text,' + ' b boolean, ts timestamp') + tests: list[dict | tuple[dict, dict]] = [ + dict(i2=None, i4=None, i8=None), + (dict(i2='', i4='', i8=''), dict(i2=None, i4=None, i8=None)), + (dict(i2=0, i4=0, i8=0), dict(i2=0, i4=0, i8=0)), + dict(i2=42, i4=123456, i8=9876543210), + dict(i2=2 ** 15 - 1, i4=2 ** 31 - 1, i8=2 ** 63 - 1), + dict(d=None), (dict(d=''), dict(d=None)), + dict(d=Decimal(0)), (dict(d=0), dict(d=Decimal(0))), + dict(f4=None, f8=None), dict(f4=0, f8=0), + (dict(f4='', f8=''), dict(f4=None, f8=None)), + (dict(d=1234.5, f4=1234.5, f8=1234.5), + dict(d=Decimal('1234.5'))), + dict(d=Decimal('123.456789'), f4=12.375, f8=123.4921875), + dict(d=Decimal('123456789.9876543212345678987654321')), + dict(m=None), (dict(m=''), dict(m=None)), + dict(m=Decimal('-1234.56')), + (dict(m='-1234.56'), dict(m=Decimal('-1234.56'))), + dict(m=Decimal('1234.56')), dict(m=Decimal('123456')), + (dict(m='1234.56'), dict(m=Decimal('1234.56'))), + (dict(m=1234.5), dict(m=Decimal('1234.5'))), + (dict(m=-1234.5), dict(m=Decimal('-1234.5'))), + (dict(m=123456), dict(m=Decimal('123456'))), + (dict(m='1234567.89'), dict(m=Decimal('1234567.89'))), + dict(b=None), (dict(b=''), dict(b=None)), + dict(b='f'), dict(b='t'), + (dict(b=0), dict(b='f')), (dict(b=1), dict(b='t')), + (dict(b=False), dict(b='f')), (dict(b=True), dict(b='t')), + (dict(b='0'), dict(b='f')), (dict(b='1'), dict(b='t')), + (dict(b='n'), dict(b='f')), (dict(b='y'), dict(b='t')), + (dict(b='no'), dict(b='f')), (dict(b='yes'), dict(b='t')), + (dict(b='off'), dict(b='f')), (dict(b='on'), dict(b='t')), + dict(v4=None, c4=None, t=None), + (dict(v4='', c4='', t=''), dict(c4=' ' * 4)), + dict(v4='1234', c4='1234', t='1234' * 10), + dict(v4='abcd', c4='abcd', t='abcdefg'), + (dict(v4='abc', c4='abc', t='abc'), dict(c4='abc ')), + dict(ts=None), (dict(ts=''), dict(ts=None)), + (dict(ts=0), dict(ts=None)), (dict(ts=False), dict(ts=None)), + dict(ts='2012-12-21 00:00:00'), + (dict(ts='2012-12-21'), dict(ts='2012-12-21 00:00:00')), + dict(ts='2012-12-21 12:21:12'), + dict(ts='2013-01-05 12:13:14'), + dict(ts='current_timestamp')] for test in tests: if isinstance(test, dict): - data = test - change = {} + data: dict = test + change: dict = {} else: data, change = test expect = data.copy() @@ -1801,23 +1752,23 @@ def testInsert(self): if ts == 'current_timestamp': ts = data['ts'] self.assertIsInstance(ts, datetime) - self.assertEqual(ts.strftime('%Y-%m-%d'), - strftime('%Y-%m-%d')) + self.assertEqual( + ts.strftime('%Y-%m-%d'), strftime('%Y-%m-%d')) else: ts = datetime.strptime(ts, '%Y-%m-%d %H:%M:%S') expect['ts'] = ts self.assertEqual(data, expect) - data = query('select * from "%s"' % table).dictresult()[0] + data = query(f'select * from "{table}"').dictresult()[0] data = dict(item for item in data.items() if item[0] in expect) self.assertEqual(data, expect) - query('delete from "%s"' % table) + 
query(f'truncate table "{table}"') - def testInsertWithOids(self): - if not self.oids: + def test_insert_with_oids(self): + if not self.supports_oids: self.skipTest("database does not support tables with oids") insert = self.db.insert query = self.db.query - self.createTable('test_table', 'n int', oids=True) + self.create_table('test_table', 'n int', oids=True) self.assertRaises(pg.ProgrammingError, insert, 'test_table', m=1) r = insert('test_table', n=1) self.assertIsInstance(r, dict) @@ -1865,7 +1816,7 @@ def testInsertWithOids(self): q = 'select n from test_table order by 1 limit 9' r = ' '.join(str(row[0]) for row in query(q).getresult()) self.assertEqual(r, '1 2 3 3 3 4 5 6') - query("truncate test_table") + query("truncate table test_table") query("alter table test_table add unique (n)") r = insert('test_table', dict(n=7)) self.assertIsInstance(r, dict) @@ -1882,31 +1833,31 @@ def testInsertWithOids(self): r = ' '.join(str(row[0]) for row in query(q).getresult()) self.assertEqual(r, '6 7') - def testInsertWithQuotedNames(self): + def test_insert_with_quoted_names(self): insert = self.db.insert query = self.db.query table = 'test table for insert()' - self.createTable(table, '"Prime!" smallint primary key,' + self.create_table(table, '"Prime!" smallint primary key,' ' "much space" integer, "Questions?" text') - r = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'} + r: Any = {'Prime!': 11, 'much space': 2002, 'Questions?': 'What?'} r = insert(table, r) self.assertIsInstance(r, dict) self.assertEqual(r['Prime!'], 11) self.assertEqual(r['much space'], 2002) self.assertEqual(r['Questions?'], 'What?') - r = query('select * from "%s" limit 2' % table).dictresult() + r = query(f'select * from "{table}" limit 2').dictresult() self.assertEqual(len(r), 1) r = r[0] self.assertEqual(r['Prime!'], 11) self.assertEqual(r['much space'], 2002) self.assertEqual(r['Questions?'], 'What?') - def testInsertIntoView(self): + def test_insert_into_view(self): insert = self.db.insert query = self.db.query - query("truncate test") + query("truncate table test") q = 'select * from test_view order by i4 limit 3' - r = query(q).getresult() + r: Any = query(q).getresult() self.assertEqual(r, []) r = dict(i4=1234, v4='abcd') insert('test', r) @@ -1918,13 +1869,7 @@ def testInsertIntoView(self): r = query(q).getresult() self.assertEqual(r, [(1234, 'abcd')]) r = dict(i4=5678, v4='efgh') - try: - insert('test_view', r) - except (pg.OperationalError, pg.NotSupportedError) as error: - if self.db.server_version < 90300: - # must setup rules in older PostgreSQL versions - self.skipTest('database cannot insert into view') - self.fail(str(error)) + insert('test_view', r) self.assertNotIn('i2', r) self.assertEqual(r['i4'], 5678) self.assertNotIn('i8', r) @@ -1933,30 +1878,56 @@ def testInsertIntoView(self): r = query(q).getresult() self.assertEqual(r, [(1234, 'abcd'), (5678, 'efgh')]) - def testUpdate(self): + def test_insert_with_generated_columns(self): + insert = self.db.insert + get = self.db.get + server_version = self.db.server_version + table = 'insert_test_table_2' + table_def = 'i int not null' + if server_version >= 100000: + table_def += ( + ', a int generated always as identity' + ', d int generated by default as identity primary key') + else: + table_def += ', a int not null default 1, d int primary key' + if server_version >= 120000: + table_def += ', j int generated always as (i + 7) stored' + else: + table_def += ', j int not null default 42' + self.create_table(table, table_def) + i, d = 35, 1001 + j 
= i + 7 + r = insert(table, {'i': i, 'd': d, 'a': 1, 'j': j}) + self.assertIsInstance(r, dict) + self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) + r = get(table, d) + self.assertIsInstance(r, dict) + self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) + + def test_update(self): update = self.db.update query = self.db.query self.assertRaises(pg.ProgrammingError, update, 'test', i2=2, i4=4, i8=8) table = 'update_test_table' - self.createTable(table, 'n integer primary key, t text', + self.create_table(table, 'n integer primary key, t text', values=enumerate('xyz', start=1)) self.assertRaises(pg.DatabaseError, self.db.get, table, 4) r = self.db.get(table, 2) r['t'] = 'u' s = update(table, r) self.assertEqual(s, r) - q = 'select t from "%s" where n=2' % table + q = f'select t from "{table}" where n=2' r = query(q).getresult()[0][0] self.assertEqual(r, 'u') - def testUpdateWithOids(self): - if not self.oids: + def test_update_with_oids(self): + if not self.supports_oids: self.skipTest("database does not support tables with oids") update = self.db.update get = self.db.get query = self.db.query - self.createTable('test_table', 'n int', oids=True, values=[1]) + self.create_table('test_table', 'n int', oids=True, values=[1]) s = get('test_table', 1, 'n') self.assertIsInstance(s, dict) self.assertEqual(s['n'], 1) @@ -2030,27 +2001,27 @@ def testUpdateWithOids(self): r = query(q).getresult() self.assertEqual(r, [(1, 3), (4, 7)]) - def testUpdateWithoutOid(self): + def test_update_without_oid(self): update = self.db.update query = self.db.query self.assertRaises(pg.ProgrammingError, update, 'test', i2=2, i4=4, i8=8) table = 'update_test_table' - self.createTable(table, 'n integer primary key, t text', oids=False, + self.create_table(table, 'n integer primary key, t text', oids=False, values=enumerate('xyz', start=1)) r = self.db.get(table, 2) r['t'] = 'u' s = update(table, r) self.assertEqual(s, r) - q = 'select t from "%s" where n=2' % table + q = f'select t from "{table}" where n=2' r = query(q).getresult()[0][0] self.assertEqual(r, 'u') - def testUpdateWithCompositeKey(self): + def test_update_with_composite_key(self): update = self.db.update query = self.db.query table = 'update_test_table_1' - self.createTable(table, 'n integer primary key, t text', + self.create_table(table, 'n integer primary key, t text', values=enumerate('abc', start=1)) self.assertRaises(KeyError, update, table, dict(t='b')) s = dict(n=2, t='d') @@ -2058,66 +2029,93 @@ def testUpdateWithCompositeKey(self): self.assertIs(r, s) self.assertEqual(r['n'], 2) self.assertEqual(r['t'], 'd') - q = 'select t from "%s" where n=2' % table + q = f'select t from "{table}" where n=2' r = query(q).getresult()[0][0] self.assertEqual(r, 'd') s.update(dict(n=4, t='e')) r = update(table, s) self.assertEqual(r['n'], 4) self.assertEqual(r['t'], 'e') - q = 'select t from "%s" where n=2' % table + q = f'select t from "{table}" where n=2' r = query(q).getresult()[0][0] self.assertEqual(r, 'd') - q = 'select t from "%s" where n=4' % table + q = f'select t from "{table}" where n=4' r = query(q).getresult() self.assertEqual(len(r), 0) - query('drop table "%s"' % table) + query(f'drop table "{table}"') table = 'update_test_table_2' - self.createTable(table, + self.create_table(table, 'n integer, m integer, t text, primary key (n, m)', values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m)) for n in range(3) for m in range(2)]) self.assertRaises(KeyError, update, table, dict(n=2, t='b')) self.assertEqual(update(table, dict(n=2, m=2, t='x'))['t'], 'x') 
- q = 'select t from "%s" where n=2 order by m' % table + q = f'select t from "{table}" where n=2 order by m' r = [r[0] for r in query(q).getresult()] self.assertEqual(r, ['c', 'x']) - def testUpdateWithQuotedNames(self): + def test_update_with_quoted_names(self): update = self.db.update query = self.db.query table = 'test table for update()' - self.createTable(table, '"Prime!" smallint primary key,' + self.create_table(table, '"Prime!" smallint primary key,' ' "much space" integer, "Questions?" text', values=[(13, 3003, 'Why!')]) - r = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'} + r: Any = {'Prime!': 13, 'much space': 7007, 'Questions?': 'When?'} r = update(table, r) self.assertIsInstance(r, dict) self.assertEqual(r['Prime!'], 13) self.assertEqual(r['much space'], 7007) self.assertEqual(r['Questions?'], 'When?') - r = query('select * from "%s" limit 2' % table).dictresult() + r = query(f'select * from "{table}" limit 2').dictresult() self.assertEqual(len(r), 1) r = r[0] self.assertEqual(r['Prime!'], 13) self.assertEqual(r['much space'], 7007) self.assertEqual(r['Questions?'], 'When?') - def testUpsert(self): + def test_update_with_generated_columns(self): + update = self.db.update + get = self.db.get + query = self.db.query + server_version = self.db.server_version + table = 'update_test_table_2' + table_def = 'i int not null' + if server_version >= 100000: + table_def += ( + ', a int generated always as identity' + ', d int generated by default as identity primary key') + else: + table_def += ', a int not null default 1, d int primary key' + if server_version >= 120000: + table_def += ', j int generated always as (i + 7) stored' + else: + table_def += ', j int not null default 42' + self.create_table(table, table_def) + i, d = 35, 1001 + j = i + 7 + r: Any = query(f'insert into {table} (i, d) values ({i}, {d})') + self.assertEqual(r, '1') + r = get(table, d) + self.assertIsInstance(r, dict) + self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) + r['i'] += 1 + r = update(table, r) + i += 1 + if server_version >= 120000: + j += 1 + self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) + + def test_upsert(self): upsert = self.db.upsert query = self.db.query self.assertRaises(pg.ProgrammingError, upsert, 'test', i2=2, i4=4, i8=8) table = 'upsert_test_table' - self.createTable(table, 'n integer primary key, t text') - s = dict(n=1, t='x') - try: - r = upsert(table, s) - except pg.ProgrammingError as error: - if self.db.server_version < 90500: - self.skipTest('database does not support upsert') - self.fail(str(error)) + self.create_table(table, 'n integer primary key, t text') + s: dict = dict(n=1, t='x') + r: Any = upsert(table, s) self.assertIs(r, s) self.assertEqual(r['n'], 1) self.assertEqual(r['t'], 'x') @@ -2126,7 +2124,7 @@ def testUpsert(self): self.assertIs(r, s) self.assertEqual(r['n'], 2) self.assertEqual(r['t'], 'y') - q = 'select n, t from "%s" order by n limit 3' % table + q = f'select n, t from "{table}" order by n limit 3' r = query(q).getresult() self.assertEqual(r, [(1, 'x'), (2, 'y')]) s.update(t='z') @@ -2175,17 +2173,23 @@ def testUpsert(self): s = dict(m=3, u='z') r = upsert(table, s, oid='invalid') self.assertIs(r, s) + s = dict(n=2) + # do not modify columns missing in the dict + r = upsert(table, s) + self.assertIs(r, s) + r = query(q).getresult() + self.assertEqual(r, [(1, 'x2'), (2, 'y3')]) - def testUpsertWithOids(self): - if not self.oids: + def test_upsert_with_oids(self): + if not self.supports_oids: self.skipTest("database does not support tables 
with oids") upsert = self.db.upsert get = self.db.get query = self.db.query - self.createTable('test_table', 'n int', oids=True, values=[1]) + self.create_table('test_table', 'n int', oids=True, values=[1]) self.assertRaises(pg.ProgrammingError, upsert, 'test_table', dict(n=2)) - r = get('test_table', 1, 'n') + r: Any = get('test_table', 1, 'n') self.assertIsInstance(r, dict) self.assertEqual(r['n'], 1) qoid = 'oid(test_table)' @@ -2199,12 +2203,7 @@ def testUpsertWithOids(self): self.assertIn('m', self.db.get_attnames('test_table', flush=True)) self.assertEqual('n', self.db.pkey('test_table', flush=True)) s = dict(n=2) - try: - r = upsert('test_table', s) - except pg.ProgrammingError as error: - if self.db.server_version < 90500: - self.skipTest('database does not support upsert') - self.fail(str(error)) + r = upsert('test_table', s) self.assertIs(r, s) self.assertEqual(r['n'], 2) self.assertIsNone(r['m']) @@ -2262,19 +2261,14 @@ def testUpsertWithOids(self): q = query("select n, m from test_table order by n limit 3") self.assertEqual(q.getresult(), [(1, 5), (2, 10)]) - def testUpsertWithCompositeKey(self): + def test_upsert_with_composite_key(self): upsert = self.db.upsert query = self.db.query table = 'upsert_test_table_2' - self.createTable( + self.create_table( table, 'n integer, m integer, t text, primary key (n, m)') - s = dict(n=1, m=2, t='x') - try: - r = upsert(table, s) - except pg.ProgrammingError as error: - if self.db.server_version < 90500: - self.skipTest('database does not support upsert') - self.fail(str(error)) + s: dict = dict(n=1, m=2, t='x') + r: Any = upsert(table, s) self.assertIs(r, s) self.assertEqual(r['n'], 1) self.assertEqual(r['m'], 2) @@ -2285,7 +2279,7 @@ def testUpsertWithCompositeKey(self): self.assertEqual(r['n'], 1) self.assertEqual(r['m'], 3) self.assertEqual(r['t'], 'y') - q = 'select n, m, t from "%s" order by n, m limit 3' % table + q = f'select n, m, t from "{table}" order by n, m limit 3' r = query(q).getresult() self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'y')]) s.update(t='z') @@ -2329,24 +2323,19 @@ def testUpsertWithCompositeKey(self): r = query(q).getresult() self.assertEqual(r, [(1, 2, 'x'), (1, 3, 'nm'), (2, 3, 'y')]) - def testUpsertWithQuotedNames(self): + def test_upsert_with_quoted_names(self): upsert = self.db.upsert query = self.db.query table = 'test table for upsert()' - self.createTable(table, '"Prime!" smallint primary key,' + self.create_table(table, '"Prime!" smallint primary key,' ' "much space" integer, "Questions?" 
text') - s = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'} - try: - r = upsert(table, s) - except pg.ProgrammingError as error: - if self.db.server_version < 90500: - self.skipTest('database does not support upsert') - self.fail(str(error)) + s: dict = {'Prime!': 31, 'much space': 9009, 'Questions?': 'Yes.'} + r: Any = upsert(table, s) self.assertIs(r, s) self.assertEqual(r['Prime!'], 31) self.assertEqual(r['much space'], 9009) self.assertEqual(r['Questions?'], 'Yes.') - q = 'select * from "%s" limit 2' % table + q = f'select * from "{table}" limit 2' r = query(q).getresult() self.assertEqual(r, [(31, 9009, 'Yes.')]) s.update({'Questions?': 'No.'}) @@ -2358,56 +2347,88 @@ def testUpsertWithQuotedNames(self): r = query(q).getresult() self.assertEqual(r, [(31, 9009, 'No.')]) - def testClear(self): + def test_upsert_with_generated_columns(self): + upsert = self.db.upsert + get = self.db.get + server_version = self.db.server_version + table = 'upsert_test_table_2' + table_def = 'i int not null' + if server_version >= 100000: + table_def += ( + ', a int generated always as identity' + ', d int generated by default as identity primary key') + else: + table_def += ', a int not null default 1, d int primary key' + if server_version >= 120000: + table_def += ', j int generated always as (i + 7) stored' + else: + table_def += ', j int not null default 42' + self.create_table(table, table_def) + i, d = 35, 1001 + j = i + 7 + r: Any = upsert(table, {'i': i, 'd': d, 'a': 1, 'j': j}) + self.assertIsInstance(r, dict) + self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) + r['i'] += 1 + r = upsert(table, r) + i += 1 + if server_version >= 120000: + j += 1 + self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) + r = get(table, d) + self.assertEqual(r, {'a': 1, 'd': d, 'i': i, 'j': j}) + + def test_clear(self): clear = self.db.clear f = False if pg.get_bool() else 'f' - r = clear('test') + r: Any = clear('test') result = dict( i2=0, i4=0, i8=0, d=0, f4=0, f8=0, m=0, v4='', c4='', t='') self.assertEqual(r, result) table = 'clear_test_table' - self.createTable(table, - 'n integer, f float, b boolean, d date, t text') + self.create_table( + table, 'n integer, f float, b boolean, d date, t text') r = clear(table) result = dict(n=0, f=0, b=f, d='', t='') self.assertEqual(r, result) r['a'] = r['f'] = r['n'] = 1 r['d'] = r['t'] = 'x' r['b'] = 't' - r['oid'] = long(1) + r['oid'] = 1 r = clear(table, r) - result = dict(a=1, n=0, f=0, b=f, d='', t='', oid=long(1)) + result = dict(a=1, n=0, f=0, b=f, d='', t='', oid=1) self.assertEqual(r, result) - def testClearWithQuotedNames(self): + def test_clear_with_quoted_names(self): clear = self.db.clear table = 'test table for clear()' - self.createTable(table, '"Prime!" smallint primary key,' - ' "much space" integer, "Questions?" text') + self.create_table( + table, '"Prime!" smallint primary key,' + ' "much space" integer, "Questions?" 
text') r = clear(table) self.assertIsInstance(r, dict) self.assertEqual(r['Prime!'], 0) self.assertEqual(r['much space'], 0) self.assertEqual(r['Questions?'], '') - def testDelete(self): + def test_delete(self): delete = self.db.delete query = self.db.query self.assertRaises(pg.ProgrammingError, delete, 'test', dict(i2=2, i4=4, i8=8)) table = 'delete_test_table' - self.createTable(table, 'n integer primary key, t text', + self.create_table(table, 'n integer primary key, t text', oids=False, values=enumerate('xyz', start=1)) self.assertRaises(pg.DatabaseError, self.db.get, table, 4) - r = self.db.get(table, 1) - s = delete(table, r) + r: Any = self.db.get(table, 1) + s: Any = delete(table, r) self.assertEqual(s, 1) r = self.db.get(table, 3) s = delete(table, r) self.assertEqual(s, 1) s = delete(table, r) self.assertEqual(s, 0) - r = query('select * from "%s"' % table).dictresult() + r = query(f'select * from "{table}"').dictresult() self.assertEqual(len(r), 1) r = r[0] result = {'n': 2, 't': 'y'} @@ -2423,16 +2444,16 @@ def testDelete(self): s = delete(table, r) self.assertEqual(s, 0) - def testDeleteWithOids(self): - if not self.oids: + def test_delete_with_oids(self): + if not self.supports_oids: self.skipTest("database does not support tables with oids") delete = self.db.delete get = self.db.get query = self.db.query - self.createTable('test_table', 'n int', oids=True, values=range(1, 7)) - r = dict(n=3) + self.create_table('test_table', 'n int', oids=True, values=range(1, 7)) + r: Any = dict(n=3) self.assertRaises(pg.ProgrammingError, delete, 'test_table', r) - s = get('test_table', 1, 'n') + s: Any = get('test_table', 1, 'n') qoid = 'oid(test_table)' self.assertIn(qoid, s) r = delete('test_table', s) @@ -2475,7 +2496,7 @@ def testDeleteWithOids(self): self.assertIn('m', self.db.get_attnames('test_table', flush=True)) self.assertEqual('n', self.db.pkey('test_table', flush=True)) for i in range(5): - query("insert into test_table values (%d, %d)" % (i + 1, i + 2)) + query(f"insert into test_table values ({i + 1}, {i + 2})") s = dict(m=2) self.assertRaises(KeyError, delete, 'test_table', s) s = dict(m=2, oid=oid) @@ -2519,62 +2540,64 @@ def testDeleteWithOids(self): self.assertEqual(r, 1) self.assertEqual(query(q).getresult()[0], (None, 0)) - def testDeleteWithCompositeKey(self): + def test_delete_with_composite_key(self): query = self.db.query table = 'delete_test_table_1' - self.createTable(table, 'n integer primary key, t text', + self.create_table(table, 'n integer primary key, t text', values=enumerate('abc', start=1)) self.assertRaises(KeyError, self.db.delete, table, dict(t='b')) self.assertEqual(self.db.delete(table, dict(n=2)), 1) - r = query('select t from "%s" where n=2' % table).getresult() + r: Any = query(f'select t from "{table}" where n=2').getresult() self.assertEqual(r, []) self.assertEqual(self.db.delete(table, dict(n=2)), 0) - r = query('select t from "%s" where n=3' % table).getresult()[0][0] + r = query(f'select t from "{table}" where n=3').getresult()[0][0] self.assertEqual(r, 'c') table = 'delete_test_table_2' - self.createTable(table, - 'n integer, m integer, t text, primary key (n, m)', - values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m)) - for n in range(3) for m in range(2)]) + self.create_table( + table, 'n integer, m integer, t text, primary key (n, m)', + values=[(n + 1, m + 1, chr(ord('a') + 2 * n + m)) + for n in range(3) for m in range(2)]) self.assertRaises(KeyError, self.db.delete, table, dict(n=2, t='b')) self.assertEqual(self.db.delete(table, dict(n=2, 
m=2)), 1) - r = [r[0] for r in query('select t from "%s" where n=2' - ' order by m' % table).getresult()] + r = [r[0] for r in query(f'select t from "{table}" where n=2' + ' order by m').getresult()] self.assertEqual(r, ['c']) self.assertEqual(self.db.delete(table, dict(n=2, m=2)), 0) - r = [r[0] for r in query('select t from "%s" where n=3' - ' order by m' % table).getresult()] + r = [r[0] for r in query(f'select t from "{table}" where n=3' + ' order by m').getresult()] self.assertEqual(r, ['e', 'f']) self.assertEqual(self.db.delete(table, dict(n=3, m=1)), 1) - r = [r[0] for r in query('select t from "%s" where n=3' - ' order by m' % table).getresult()] + r = [r[0] for r in query(f'select t from "{table}" where n=3' + ' order by m').getresult()] self.assertEqual(r, ['f']) - def testDeleteWithQuotedNames(self): + def test_delete_with_quoted_names(self): delete = self.db.delete query = self.db.query table = 'test table for delete()' - self.createTable(table, '"Prime!" smallint primary key,' - ' "much space" integer, "Questions?" text', + self.create_table( + table, '"Prime!" smallint primary key,' + ' "much space" integer, "Questions?" text', values=[(19, 5005, 'Yes!')]) - r = {'Prime!': 17} + r: Any = {'Prime!': 17} r = delete(table, r) self.assertEqual(r, 0) - r = query('select count(*) from "%s"' % table).getresult() + r = query(f'select count(*) from "{table}"').getresult() self.assertEqual(r[0][0], 1) r = {'Prime!': 19} r = delete(table, r) self.assertEqual(r, 1) - r = query('select count(*) from "%s"' % table).getresult() + r = query(f'select count(*) from "{table}"').getresult() self.assertEqual(r[0][0], 0) - def testDeleteReferenced(self): + def test_delete_referenced(self): delete = self.db.delete query = self.db.query - self.createTable('test_parent', - 'n smallint primary key', values=range(3)) - self.createTable('test_child', - 'n smallint primary key references test_parent', values=range(3)) + self.create_table( + 'test_parent', 'n smallint primary key', values=range(3)) + self.create_table( + 'test_child', 'n smallint primary key references test_parent', + values=range(3)) q = ("select (select count(*) from test_parent)," " (select count(*) from test_child)") self.assertEqual(query(q).getresult()[0], (3, 3)) @@ -2582,7 +2605,7 @@ def testDeleteReferenced(self): delete, 'test_parent', None, n=2) self.assertRaises(pg.IntegrityError, delete, 'test_parent *', None, n=2) - r = delete('test_child', None, n=2) + r: Any = delete('test_child', None, n=2) self.assertEqual(r, 1) self.assertEqual(query(q).getresult()[0], (3, 2)) r = delete('test_parent', None, n=2) @@ -2605,45 +2628,46 @@ def testDeleteReferenced(self): q = "select n from test_parent natural join test_child limit 2" self.assertEqual(query(q).getresult(), [(1,)]) - def testTempCrud(self): + def test_temp_crud(self): table = 'test_temp_table' - self.createTable(table, "n int primary key, t varchar", temporary=True) + self.create_table(table, "n int primary key, t varchar", + temporary=True) self.db.insert(table, dict(n=1, t='one')) self.db.insert(table, dict(n=2, t='too')) self.db.insert(table, dict(n=3, t='three')) - r = self.db.get(table, 2) + r: Any = self.db.get(table, 2) self.assertEqual(r['t'], 'too') self.db.update(table, dict(n=2, t='two')) r = self.db.get(table, 2) self.assertEqual(r['t'], 'two') self.db.delete(table, r) - r = self.db.query('select n, t from %s order by 1' % table).getresult() + r = self.db.query(f'select n, t from {table} order by 1').getresult() self.assertEqual(r, [(1, 'one'), (3, 'three')]) -
def testTruncate(self): + def test_truncate(self): truncate = self.db.truncate self.assertRaises(TypeError, truncate, None) self.assertRaises(TypeError, truncate, 42) self.assertRaises(TypeError, truncate, dict(test_table=None)) query = self.db.query - self.createTable('test_table', 'n smallint', - temporary=False, values=[1] * 3) + self.create_table('test_table', 'n smallint', + temporary=False, values=[1] * 3) q = "select count(*) from test_table" - r = query(q).getresult()[0][0] + r: Any = query(q).getresult()[0][0] self.assertEqual(r, 3) truncate('test_table') r = query(q).getresult()[0][0] self.assertEqual(r, 0) - for i in range(3): + for _i in range(3): query("insert into test_table values (1)") r = query(q).getresult()[0][0] self.assertEqual(r, 3) truncate('public.test_table') r = query(q).getresult()[0][0] self.assertEqual(r, 0) - self.createTable('test_table_2', 'n smallint', temporary=True) + self.create_table('test_table_2', 'n smallint', temporary=True) for t in (list, tuple, set): - for i in range(3): + for _i in range(3): query("insert into test_table values (1)") query("insert into test_table_2 values (2)") q = ("select (select count(*) from test_table)," @@ -2654,59 +2678,59 @@ def testTruncate(self): r = query(q).getresult()[0] self.assertEqual(r, (0, 0)) - def testTruncateRestart(self): + def test_truncate_restart(self): truncate = self.db.truncate self.assertRaises(TypeError, truncate, 'test_table', restart='invalid') query = self.db.query - self.createTable('test_table', 'n serial, t text') - for n in range(3): + self.create_table('test_table', 'n serial, t text') + for _n in range(3): query("insert into test_table (t) values ('test')") q = "select count(n), min(n), max(n) from test_table" - r = query(q).getresult()[0] + r: Any = query(q).getresult()[0] self.assertEqual(r, (3, 1, 3)) truncate('test_table') r = query(q).getresult()[0] self.assertEqual(r, (0, None, None)) - for n in range(3): + for _n in range(3): query("insert into test_table (t) values ('test')") r = query(q).getresult()[0] self.assertEqual(r, (3, 4, 6)) truncate('test_table', restart=True) r = query(q).getresult()[0] self.assertEqual(r, (0, None, None)) - for n in range(3): + for _n in range(3): query("insert into test_table (t) values ('test')") r = query(q).getresult()[0] self.assertEqual(r, (3, 1, 3)) - def testTruncateCascade(self): + def test_truncate_cascade(self): truncate = self.db.truncate self.assertRaises(TypeError, truncate, 'test_table', cascade='invalid') query = self.db.query - self.createTable('test_parent', 'n smallint primary key', + self.create_table('test_parent', 'n smallint primary key', values=range(3)) - self.createTable('test_child', + self.create_table('test_child', 'n smallint primary key references test_parent (n)', values=range(3)) q = ("select (select count(*) from test_parent)," " (select count(*) from test_child)") - r = query(q).getresult()[0] + r: Any = query(q).getresult()[0] self.assertEqual(r, (3, 3)) self.assertRaises(pg.NotSupportedError, truncate, 'test_parent') truncate(['test_parent', 'test_child']) r = query(q).getresult()[0] self.assertEqual(r, (0, 0)) for n in range(3): - query("insert into test_parent (n) values (%d)" % n) - query("insert into test_child (n) values (%d)" % n) + query(f"insert into test_parent (n) values ({n})") + query(f"insert into test_child (n) values ({n})") r = query(q).getresult()[0] self.assertEqual(r, (3, 3)) truncate('test_parent', cascade=True) r = query(q).getresult()[0] self.assertEqual(r, (0, 0)) for n in range(3): - 
query("insert into test_parent (n) values (%d)" % n) - query("insert into test_child (n) values (%d)" % n) + query(f"insert into test_parent (n) values ({n})") + query(f"insert into test_child (n) values ({n})") r = query(q).getresult()[0] self.assertEqual(r, (3, 3)) truncate('test_child') @@ -2717,13 +2741,13 @@ def testTruncateCascade(self): r = query(q).getresult()[0] self.assertEqual(r, (0, 0)) - def testTruncateOnly(self): + def test_truncate_only(self): truncate = self.db.truncate self.assertRaises(TypeError, truncate, 'test_table', only='invalid') query = self.db.query - self.createTable('test_parent', 'n smallint') - self.createTable('test_child', 'm smallint) inherits (test_parent') - for n in range(3): + self.create_table('test_parent', 'n smallint') + self.create_table('test_child', 'm smallint) inherits (test_parent') + for _n in range(3): query("insert into test_parent (n) values (1)") query("insert into test_child (n, m) values (2, 3)") q = ("select (select count(*) from test_parent)," @@ -2733,7 +2757,7 @@ def testTruncateOnly(self): truncate('test_parent') r = query(q).getresult()[0] self.assertEqual(r, (0, 0)) - for n in range(3): + for _n in range(3): query("insert into test_parent (n) values (1)") query("insert into test_child (n, m) values (2, 3)") r = query(q).getresult()[0] @@ -2741,7 +2765,7 @@ def testTruncateOnly(self): truncate('test_parent*') r = query(q).getresult()[0] self.assertEqual(r, (0, 0)) - for n in range(3): + for _n in range(3): query("insert into test_parent (n) values (1)") query("insert into test_child (n, m) values (2, 3)") r = query(q).getresult()[0] @@ -2754,12 +2778,13 @@ def testTruncateOnly(self): self.assertEqual(r, (0, 0)) self.assertRaises(ValueError, truncate, 'test_parent*', only=True) truncate('test_parent*', only=False) - self.createTable('test_parent_2', 'n smallint') - self.createTable('test_child_2', 'm smallint) inherits (test_parent_2') + self.create_table('test_parent_2', 'n smallint') + self.create_table('test_child_2', + 'm smallint) inherits (test_parent_2') for t in '', '_2': - for n in range(3): - query("insert into test_parent%s (n) values (1)" % t) - query("insert into test_child%s (n, m) values (2, 3)" % t) + for _n in range(3): + query(f"insert into test_parent{t} (n) values (1)") + query(f"insert into test_child{t} (n, m) values (2, 3)") q = ("select (select count(*) from test_parent)," " (select count(*) from test_child)," " (select count(*) from test_parent_2)," @@ -2772,42 +2797,44 @@ def testTruncateOnly(self): truncate(['test_parent', 'test_parent_2'], only=False) r = query(q).getresult()[0] self.assertEqual(r, (0, 0, 0, 0)) - self.assertRaises(ValueError, truncate, + self.assertRaises( + ValueError, truncate, ['test_parent*', 'test_child'], only=[True, False]) truncate(['test_parent*', 'test_child'], only=[False, True]) - def testTruncateQuoted(self): + def test_truncate_quoted(self): truncate = self.db.truncate query = self.db.query table = "test table for truncate()" - self.createTable(table, 'n smallint', temporary=False, values=[1] * 3) - q = 'select count(*) from "%s"' % table + self.create_table(table, 'n smallint', temporary=False, values=[1] * 3) + q = f'select count(*) from "{table}"' r = query(q).getresult()[0][0] self.assertEqual(r, 3) truncate(table) r = query(q).getresult()[0][0] self.assertEqual(r, 0) - for i in range(3): - query('insert into "%s" values (1)' % table) + for _i in range(3): + query(f'insert into "{table}" values (1)') r = query(q).getresult()[0][0] self.assertEqual(r, 3) - 
truncate('public."%s"' % table) + truncate(f'public."{table}"') r = query(q).getresult()[0][0] self.assertEqual(r, 0) - def testGetAsList(self): + # noinspection PyUnresolvedReferences + def test_get_as_list(self): get_as_list = self.db.get_as_list self.assertRaises(TypeError, get_as_list) self.assertRaises(TypeError, get_as_list, None) query = self.db.query table = 'test_aslist' - r = query('select 1 as colname').namedresult()[0] + r: Any = query('select 1 as colname').namedresult()[0] self.assertIsInstance(r, tuple) named = hasattr(r, 'colname') names = [(1, 'Homer'), (2, 'Marge'), (3, 'Bart'), (4, 'Lisa'), (5, 'Maggie')] - self.createTable(table, - 'id smallint primary key, name varchar', values=names) + self.create_table( + table, 'id smallint primary key, name varchar', values=names) r = get_as_list(table) self.assertIsInstance(r, list) self.assertEqual(r, names) @@ -2820,7 +2847,7 @@ def testGetAsList(self): self.assertEqual(t._asdict(), dict(id=n[0], name=n[1])) r = get_as_list(table, what='name') self.assertIsInstance(r, list) - expected = sorted((row[1],) for row in names) + expected: Any = sorted((row[1],) for row in names) self.assertEqual(r, expected) r = get_as_list(table, what='name, id') self.assertIsInstance(r, list) @@ -2835,8 +2862,8 @@ def testGetAsList(self): r = get_as_list(table, what='name', where="name like 'Ma%'") self.assertIsInstance(r, list) self.assertEqual(r, [('Maggie',), ('Marge',)]) - r = get_as_list(table, what='name', - where=["name like 'Ma%'", "name like '%r%'"]) + r = get_as_list( + table, what='name', where=["name like 'Ma%'", "name like '%r%'"]) self.assertIsInstance(r, list) self.assertEqual(r, [('Marge',)]) r = get_as_list(table, what='name', order='id') @@ -2872,10 +2899,10 @@ def testGetAsList(self): r = get_as_list(table, what='name', limit=1, scalar=True) self.assertIsInstance(r, list) self.assertEqual(r, expected[:1]) - query('alter table "%s" drop constraint "%s_pkey"' % (table, table)) + query(f'alter table "{table}" drop constraint "{table}_pkey"') self.assertRaises(KeyError, self.db.pkey, table, flush=True) names.insert(1, (1, 'Snowball')) - query('insert into "%s" values ($1, $2)' % table, (1, 'Snowball')) + query(f'insert into "{table}" values ($1, $2)', (1, 'Snowball')) r = get_as_list(table) self.assertIsInstance(r, list) self.assertEqual(r, names) @@ -2887,11 +2914,11 @@ def testGetAsList(self): self.assertIsInstance(r, list) self.assertEqual(set(r), set(names)) # test with arbitrary from clause - from_table = '(select lower(name) as n2 from "%s") as t2' % table + from_table = f'(select lower(name) as n2 from "{table}") as t2' r = get_as_list(from_table) self.assertIsInstance(r, list) - r = set(row[0] for row in r) - expected = set(row[1].lower() for row in names) + r = {row[0] for row in r} + expected = {row[1].lower() for row in names} self.assertEqual(r, expected) r = get_as_list(from_table, order='n2', scalar=True) self.assertIsInstance(r, list) @@ -2907,7 +2934,8 @@ def testGetAsList(self): else: self.assertEqual(t, ('bart',)) - def testGetAsDict(self): + # noinspection PyUnresolvedReferences + def test_get_as_dict(self): get_as_dict = self.db.get_as_dict self.assertRaises(TypeError, get_as_dict) self.assertRaises(TypeError, get_as_dict, None) @@ -2920,8 +2948,8 @@ def testGetAsDict(self): named = hasattr(r, 'colname') colors = [(1, '#7cb9e8', 'Aero'), (2, '#b5a642', 'Brass'), (3, '#b2ffff', 'Celeste'), (4, '#c19a6b', 'Desert')] - self.createTable(table, - 'id smallint primary key, rgb char(7), name varchar', + self.create_table( 
+ table, 'id smallint primary key, rgb char(7), name varchar', values=colors) # keyname must be string, list or tuple self.assertRaises(KeyError, get_as_dict, table, 3) @@ -2930,8 +2958,8 @@ def testGetAsDict(self): self.assertRaises(KeyError, get_as_dict, table, keyname='rgb', what='name') r = get_as_dict(table) - self.assertIsInstance(r, OrderedDict) - expected = OrderedDict((row[0], row[1:]) for row in colors) + self.assertIsInstance(r, dict) + expected: Any = {row[0]: row[1:] for row in colors} self.assertEqual(r, expected) for key in r: self.assertIsInstance(key, int) @@ -2944,29 +2972,28 @@ def testGetAsDict(self): self.assertEqual(row.rgb, t[0]) self.assertEqual(row.name, t[1]) self.assertEqual(row._asdict(), dict(rgb=t[0], name=t[1])) - if OrderedDict is not dict: # Python > 2.6 - self.assertEqual(r.keys(), expected.keys()) + self.assertEqual(r.keys(), expected.keys()) r = get_as_dict(table, keyname='rgb') - self.assertIsInstance(r, OrderedDict) - expected = OrderedDict((row[1], (row[0], row[2])) - for row in sorted(colors, key=itemgetter(1))) + self.assertIsInstance(r, dict) + expected = {row[1]: (row[0], row[2]) + for row in sorted(colors, key=itemgetter(1))} self.assertEqual(r, expected) for key in r: self.assertIsInstance(key, str) self.assertIn(key, expected) row = r[key] self.assertIsInstance(row, tuple) + # noinspection PyTypeChecker t = expected[key] self.assertEqual(row, t) if named: self.assertEqual(row.id, t[0]) self.assertEqual(row.name, t[1]) self.assertEqual(row._asdict(), dict(id=t[0], name=t[1])) - if OrderedDict is not dict: # Python > 2.6 - self.assertEqual(r.keys(), expected.keys()) + self.assertEqual(r.keys(), expected.keys()) r = get_as_dict(table, keyname=['id', 'rgb']) - self.assertIsInstance(r, OrderedDict) - expected = OrderedDict((row[:2], row[2:]) for row in colors) + self.assertIsInstance(r, dict) + expected = {row[:2]: row[2:] for row in colors} self.assertEqual(r, expected) for key in r: self.assertIsInstance(key, tuple) @@ -2978,42 +3005,43 @@ def testGetAsDict(self): row = r[key] self.assertIsInstance(row, tuple) self.assertIsInstance(row[0], str) + # noinspection PyTypeChecker t = expected[key] self.assertEqual(row, t) if named: self.assertEqual(row.name, t[0]) self.assertEqual(row._asdict(), dict(name=t[0])) - if OrderedDict is not dict: # Python > 2.6 - self.assertEqual(r.keys(), expected.keys()) + self.assertEqual(r.keys(), expected.keys()) r = get_as_dict(table, keyname=['id', 'rgb'], scalar=True) - self.assertIsInstance(r, OrderedDict) - expected = OrderedDict((row[:2], row[2]) for row in colors) + self.assertIsInstance(r, dict) + expected = {row[:2]: row[2] for row in colors} self.assertEqual(r, expected) for key in r: self.assertIsInstance(key, tuple) row = r[key] self.assertIsInstance(row, str) + # noinspection PyTypeChecker t = expected[key] self.assertEqual(row, t) - if OrderedDict is not dict: # Python > 2.6 - self.assertEqual(r.keys(), expected.keys()) - r = get_as_dict(table, keyname='rgb', what=['rgb', 'name'], scalar=True) - self.assertIsInstance(r, OrderedDict) - expected = OrderedDict((row[1], row[2]) - for row in sorted(colors, key=itemgetter(1))) + self.assertEqual(r.keys(), expected.keys()) + r = get_as_dict(table, keyname='rgb', what=['rgb', 'name'], + scalar=True) + self.assertIsInstance(r, dict) + expected = {row[1]: row[2] + for row in sorted(colors, key=itemgetter(1))} self.assertEqual(r, expected) for key in r: self.assertIsInstance(key, str) row = r[key] self.assertIsInstance(row, str) + # noinspection PyTypeChecker t = 
expected[key] self.assertEqual(row, t) - if OrderedDict is not dict: # Python > 2.6 - self.assertEqual(r.keys(), expected.keys()) - r = get_as_dict(table, what='id, name', - where="rgb like '#b%'", scalar=True) - self.assertIsInstance(r, OrderedDict) - expected = OrderedDict((row[0], row[2]) for row in colors[1:3]) + self.assertEqual(r.keys(), expected.keys()) + r = get_as_dict( + table, what='id, name', where="rgb like '#b%'", scalar=True) + self.assertIsInstance(r, dict) + expected = {row[0]: row[2] for row in colors[1:3]} self.assertEqual(r, expected) for key in r: self.assertIsInstance(key, int) @@ -3021,16 +3049,17 @@ def testGetAsDict(self): self.assertIsInstance(row, str) t = expected[key] self.assertEqual(row, t) - if OrderedDict is not dict: # Python > 2.6 - self.assertEqual(r.keys(), expected.keys()) + self.assertEqual(r.keys(), expected.keys()) expected = r - r = get_as_dict(table, what=['name', 'id'], + r = get_as_dict( + table, what=['name', 'id'], where=['id > 1', 'id < 4', "rgb like '#b%'", "name not like 'A%'", "name not like '%t'"], scalar=True) self.assertEqual(r, expected) r = get_as_dict(table, what='name, id', limit=2, offset=1, scalar=True) self.assertEqual(r, expected) - r = get_as_dict(table, keyname=('id',), what=('name', 'id'), + r = get_as_dict( + table, keyname=('id',), what=('name', 'id'), where=('id > 1', 'id < 4'), order=('id',), scalar=True) self.assertEqual(r, expected) r = get_as_dict(table, limit=1) @@ -3040,44 +3069,43 @@ def testGetAsDict(self): self.assertEqual(len(r), 1) self.assertEqual(r[4][1], 'Desert') r = get_as_dict(table, order='id desc') - expected = OrderedDict((row[0], row[1:]) for row in reversed(colors)) + expected = {row[0]: row[1:] for row in reversed(colors)} self.assertEqual(r, expected) r = get_as_dict(table, where='id > 5') - self.assertIsInstance(r, OrderedDict) + self.assertIsInstance(r, dict) self.assertEqual(len(r), 0) # test with unordered query - expected = dict((row[0], row[1:]) for row in colors) + expected = {row[0]: row[1:] for row in colors} r = get_as_dict(table, order=False) self.assertIsInstance(r, dict) self.assertEqual(r, expected) - if dict is not OrderedDict: # Python > 2.6 - self.assertNotIsInstance(self, OrderedDict) # test with arbitrary from clause - from_table = '(select id, lower(name) as n2 from "%s") as t2' % table + from_table = f'(select id, lower(name) as n2 from "{table}") as t2' # primary key must be passed explicitly in this case self.assertRaises(pg.ProgrammingError, get_as_dict, from_table) r = get_as_dict(from_table, 'id') - self.assertIsInstance(r, OrderedDict) - expected = OrderedDict((row[0], (row[2].lower(),)) for row in colors) + self.assertIsInstance(r, dict) + expected = {row[0]: (row[2].lower(),) for row in colors} self.assertEqual(r, expected) # test without a primary key - query('alter table "%s" drop constraint "%s_pkey"' % (table, table)) + query(f'alter table "{table}" drop constraint "{table}_pkey"') self.assertRaises(KeyError, self.db.pkey, table, flush=True) self.assertRaises(pg.ProgrammingError, get_as_dict, table) r = get_as_dict(table, keyname='id') - expected = OrderedDict((row[0], row[1:]) for row in colors) + expected = {row[0]: row[1:] for row in colors} self.assertIsInstance(r, dict) self.assertEqual(r, expected) r = (1, '#007fff', 'Azure') - query('insert into "%s" values ($1, $2, $3)' % table, r) + query(f'insert into "{table}" values ($1, $2, $3)', r) # the last entry will win expected[1] = r[1:] r = get_as_dict(table, keyname='id')
self.assertEqual(r, expected) - def testTransaction(self): + def test_transaction(self): query = self.db.query - self.createTable('test_table', 'n integer', temporary=False) + self.create_table('test_table', 'n integer', temporary=False) self.db.begin() query("insert into test_table values (1)") query("insert into test_table values (2)") @@ -3114,14 +3142,14 @@ def testTransaction(self): query, "insert into test_table values (0)") self.db.abort() - def testTransactionAliases(self): + def test_transaction_aliases(self): self.assertEqual(self.db.begin, self.db.start) self.assertEqual(self.db.commit, self.db.end) self.assertEqual(self.db.rollback, self.db.abort) - def testContextManager(self): + def test_context_manager(self): query = self.db.query - self.createTable('test_table', 'n integer check(n>0)') + self.create_table('test_table', 'n integer check(n>0)') with self.db: query("insert into test_table values (1)") query("insert into test_table values (2)") @@ -3139,16 +3167,16 @@ def testContextManager(self): query("insert into test_table values (6)") query("insert into test_table values (-1)") except pg.IntegrityError as error: - self.assertTrue('check' in str(error)) + self.assertIn('check', str(error)) with self.db: query("insert into test_table values (7)") r = [r[0] for r in query( "select * from test_table order by 1").getresult()] self.assertEqual(r, [1, 2, 5, 7]) - def testBytea(self): + def test_bytea(self): query = self.db.query - self.createTable('bytea_test', 'n smallint primary key, data bytea') + self.create_table('bytea_test', 'n smallint primary key, data bytea') s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n" r = self.db.escape_bytea(s) query('insert into bytea_test values(3, $1)', (r,)) @@ -3164,10 +3192,10 @@ def testBytea(self): self.assertIsInstance(r, bytes) self.assertEqual(r, s) - def testInsertUpdateGetBytea(self): + def test_insert_update_get_bytea(self): query = self.db.query unescape = pg.unescape_bytea if pg.get_bytea_escaped() else None - self.createTable('bytea_test', 'n smallint primary key, data bytea') + self.create_table('bytea_test', 'n smallint primary key, data bytea') # insert null value r = self.db.insert('bytea_test', n=0, data=None) self.assertIsInstance(r, dict) @@ -3238,40 +3266,33 @@ def testInsertUpdateGetBytea(self): self.assertIsInstance(r, bytes) self.assertEqual(r, s) - def testUpsertBytea(self): - self.createTable('bytea_test', 'n smallint primary key, data bytea') + def test_upsert_bytea(self): + self.create_table('bytea_test', 'n smallint primary key, data bytea') s = b"It's all \\ kinds \x00 of\r nasty \xff stuff!\n" - r = dict(n=7, data=s) - try: - r = self.db.upsert('bytea_test', r) - except pg.ProgrammingError as error: - if self.db.server_version < 90500: - self.skipTest('database does not support upsert') - self.fail(str(error)) - self.assertIsInstance(r, dict) - self.assertIn('n', r) - self.assertEqual(r['n'], 7) - self.assertIn('data', r) + d = dict(n=7, data=s) + d = self.db.upsert('bytea_test', d) + self.assertIsInstance(d, dict) + self.assertIn('n', d) + self.assertEqual(d['n'], 7) + self.assertIn('data', d) + data = d['data'] if pg.get_bytea_escaped(): - self.assertNotEqual(r['data'], s) - r['data'] = pg.unescape_bytea(r['data']) - self.assertIsInstance(r['data'], bytes) - self.assertEqual(r['data'], s) - r['data'] = None - r = self.db.upsert('bytea_test', r) - self.assertIsInstance(r, dict) - self.assertIn('n', r) - self.assertEqual(r['n'], 7) - self.assertIn('data', r) - self.assertIsNone(r['data']) - - def 
testInsertGetJson(self): - try: - self.createTable('json_test', 'n smallint primary key, data json') - except pg.ProgrammingError as error: - if self.db.server_version < 90200: - self.skipTest('database does not support json') - self.fail(str(error)) + self.assertNotEqual(data, s) + self.assertIsInstance(data, str) + assert isinstance(data, str) # type guard + data = pg.unescape_bytea(data) + self.assertIsInstance(data, bytes) + self.assertEqual(data, s) + d['data'] = None + d = self.db.upsert('bytea_test', d) + self.assertIsInstance(d, dict) + self.assertIn('n', d) + self.assertEqual(d['n'], 7) + self.assertIn('data', d) + self.assertIsNone(d['data']) + + def test_insert_get_json(self): + self.create_table('json_test', 'n smallint primary key, data json') jsondecode = pg.get_jsondecode() # insert null value r = self.db.insert('json_test', n=0, data=None) @@ -3304,7 +3325,7 @@ def testInsertGetJson(self): self.assertIsInstance(r, dict) self.assertEqual(r, data) self.assertIsInstance(r['id'], int) - self.assertIsInstance(r['name'], unicode) + self.assertIsInstance(r['name'], str) self.assertIsInstance(r['price'], float) self.assertIsInstance(r['new'], bool) self.assertIsInstance(r['tags'], list) @@ -3321,7 +3342,7 @@ def testInsertGetJson(self): self.assertIsInstance(r, dict) self.assertEqual(r, data) self.assertIsInstance(r['id'], int) - self.assertIsInstance(r['name'], unicode) + self.assertIsInstance(r['name'], str) self.assertIsInstance(r['price'], float) self.assertIsInstance(r['new'], bool) self.assertIsInstance(r['tags'], list) @@ -3334,14 +3355,9 @@ def testInsertGetJson(self): self.assertIsInstance(r[0][0], str if jsondecode is None else dict) self.assertEqual(r[0][0], r[1][0]) - def testInsertGetJsonb(self): - try: - self.createTable('jsonb_test', - 'n smallint primary key, data jsonb') - except pg.ProgrammingError as error: - if self.db.server_version < 90400: - self.skipTest('database does not support jsonb') - self.fail(str(error)) + def test_insert_get_jsonb(self): + self.create_table('jsonb_test', + 'n smallint primary key, data jsonb') jsondecode = pg.get_jsondecode() # insert null value r = self.db.insert('jsonb_test', n=0, data=None) @@ -3374,7 +3390,7 @@ def testInsertGetJsonb(self): self.assertIsInstance(r, dict) self.assertEqual(r, data) self.assertIsInstance(r['id'], int) - self.assertIsInstance(r['name'], unicode) + self.assertIsInstance(r['name'], str) self.assertIsInstance(r['price'], float) self.assertIsInstance(r['new'], bool) self.assertIsInstance(r['tags'], list) @@ -3391,15 +3407,16 @@ def testInsertGetJsonb(self): self.assertIsInstance(r, dict) self.assertEqual(r, data) self.assertIsInstance(r['id'], int) - self.assertIsInstance(r['name'], unicode) + self.assertIsInstance(r['name'], str) self.assertIsInstance(r['price'], float) self.assertIsInstance(r['new'], bool) self.assertIsInstance(r['tags'], list) self.assertIsInstance(r['stock'], dict) - def testArray(self): + def test_array(self): returns_arrays = pg.get_array() - self.createTable('arraytest', + self.create_table( + 'arraytest', 'id smallint, i2 smallint[], i4 integer[], i8 bigint[],' ' d numeric[], f4 real[], f8 double precision[], m money[],' ' b bool[], v4 varchar(4)[], c4 char(4)[], t text[]') @@ -3423,10 +3440,10 @@ def testArray(self): long_decimal = decimal('12345671234.5') odd_money = decimal('1234567123.25') t, f = (True, False) if pg.get_bool() else ('t', 'f') - data = dict(id=42, i2=[42, 1234, None, 0, -1], + data = dict( + id=42, i2=[42, 1234, None, 0, -1], i4=[42, 123456789, None, 0, 1, 
-1], - i8=[long(42), long(123456789123456789), None, - long(0), long(1), long(-1)], + i8=[42, 123456789123456789, None, 0, 1, -1], d=[decimal(42), long_decimal, None, decimal(0), decimal(1), decimal(-1), -long_decimal], f4=[42.0, 1234.5, None, 0.0, 1.0, -1.0, @@ -3456,10 +3473,10 @@ def testArray(self): else: self.assertEqual(r['i4'], '{42,123456789,NULL,0,1,-1}') - def testArrayLiteral(self): + def test_array_literal(self): insert = self.db.insert returns_arrays = pg.get_array() - self.createTable('arraytest', 'i int[], t text[]') + self.create_table('arraytest', 'i int[], t text[]') r = dict(i=[1, 2, 3], t=['a', 'b', 'c']) insert('arraytest', r) if returns_arrays: @@ -3476,8 +3493,8 @@ def testArrayLiteral(self): else: self.assertEqual(r['i'], '{1,2,3}') self.assertEqual(r['t'], '{a,b,c}') - L = pg.Literal - r = dict(i=L("ARRAY[1, 2, 3]"), t=L("ARRAY['a', 'b', 'c']")) + Lit = pg.Literal # noqa: N806 + r = dict(i=Lit("ARRAY[1, 2, 3]"), t=Lit("ARRAY['a', 'b', 'c']")) self.db.insert('arraytest', r) if returns_arrays: self.assertEqual(r['i'], [1, 2, 3]) @@ -3488,14 +3505,14 @@ def testArrayLiteral(self): r = dict(i="1, 2, 3", t="'a', 'b', 'c'") self.assertRaises(pg.DataError, self.db.insert, 'arraytest', r) - def testArrayOfIds(self): + def test_array_of_ids(self): array_on = pg.get_array() - self.createTable( + self.create_table( 'arraytest', 'i serial primary key, c cid[], o oid[], x xid[]') r = self.db.get_attnames('arraytest') if self.regtypes: self.assertEqual(r, dict( - i='integer', c='cid[]', o='oid[]', x='xid[]')) + i='integer', c='cid[]', o='oid[]', x='xid[]')) else: self.assertEqual(r, dict( i='int', c='int[]', o='int[]', x='int[]')) @@ -3512,9 +3529,9 @@ def testArrayOfIds(self): else: self.assertEqual(r['o'], '{21,22,23}') - def testArrayOfText(self): + def test_array_of_text(self): array_on = pg.get_array() - self.createTable('arraytest', 'id serial primary key, data text[]') + self.create_table('arraytest', 'id serial primary key, data text[]') r = self.db.get_attnames('arraytest') self.assertEqual(r['data'], 'text[]') data = ['Hello, World!', '', None, '{a,b,c}', '"Hi!"', @@ -3535,10 +3552,11 @@ def testArrayOfText(self): self.assertIsInstance(r['data'][1], str) self.assertIsNone(r['data'][2]) - def testArrayOfBytea(self): + # noinspection PyUnresolvedReferences + def test_array_of_bytea(self): array_on = pg.get_array() bytea_escaped = pg.get_bytea_escaped() - self.createTable('arraytest', 'id serial primary key, data bytea[]') + self.create_table('arraytest', 'id serial primary key, data bytea[]') r = self.db.get_attnames('arraytest') self.assertEqual(r['data'], 'bytea[]') data = [b'Hello, World!', b'', None, b'{a,b,c}', b'"Hi!"', @@ -3564,14 +3582,8 @@ def testArrayOfBytea(self): else: self.assertNotEqual(r['data'], data) - def testArrayOfJson(self): - try: - self.createTable( - 'arraytest', 'id serial primary key, data json[]') - except pg.ProgrammingError as error: - if self.db.server_version < 90200: - self.skipTest('database does not support json') - self.fail(str(error)) + def test_array_of_json(self): + self.create_table('arraytest', 'id serial primary key, data json[]') r = self.db.get_attnames('arraytest') self.assertEqual(r['data'], 'json[]') data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')] @@ -3612,14 +3624,8 @@ def testArrayOfJson(self): else: self.assertEqual(r, '{NULL,NULL}') - def testArrayOfJsonb(self): - try: - self.createTable( - 'arraytest', 'id serial primary key, data jsonb[]') - except pg.ProgrammingError as error: - if 
self.db.server_version < 90400: - self.skipTest('database does not support jsonb') - self.fail(str(error)) + def test_array_of_jsonb(self): + self.create_table('arraytest', 'id serial primary key, data jsonb[]') r = self.db.get_attnames('arraytest') self.assertEqual(r['data'], 'jsonb[]' if self.regtypes else 'json[]') data = [dict(id=815, name='John Doe'), dict(id=816, name='Jane Roe')] @@ -3660,9 +3666,10 @@ def testArrayOfJsonb(self): else: self.assertEqual(r, '{NULL,NULL}') - def testDeepArray(self): + # noinspection PyUnresolvedReferences + def test_deep_array(self): array_on = pg.get_array() - self.createTable( + self.create_table( 'arraytest', 'id serial primary key, data text[][][]') r = self.db.get_attnames('arraytest') self.assertEqual(r['data'], 'text[]') @@ -3680,13 +3687,14 @@ def testDeepArray(self): else: self.assertTrue(r['data'].startswith('{{{"Hello,')) - def testInsertUpdateGetRecord(self): + # noinspection PyUnresolvedReferences + def test_insert_update_get_record(self): query = self.db.query query('create type test_person_type as' ' (name varchar, age smallint, married bool,' ' weight real, salary money)') self.addCleanup(query, 'drop type test_person_type') - self.createTable('test_person', + self.create_table('test_person', 'id serial primary key, person test_person_type', oids=False, temporary=False) attnames = self.db.get_attnames('test_person') @@ -3699,22 +3707,25 @@ def testInsertUpdateGetRecord(self): else: self.assertEqual(person_typ, 'record') if self.regtypes: - self.assertEqual(person_typ.attnames, - dict(name='character varying', age='smallint', - married='boolean', weight='real', salary='money')) + self.assertEqual(person_typ.attnames, dict( + name='character varying', age='smallint', + married='boolean', weight='real', salary='money')) else: - self.assertEqual(person_typ.attnames, - dict(name='text', age='int', married='bool', - weight='float', salary='money')) + self.assertEqual(person_typ.attnames, dict( + name='text', age='int', married='bool', + weight='float', salary='money')) decimal = pg.get_decimal() + bool_class: type + t: bool | str + f: bool | str if pg.get_bool(): bool_class = bool t, f = True, False else: bool_class = str t, f = 't', 'f' - person = ('John Doe', 61, t, 99.5, decimal('93456.75')) - r = self.db.insert('test_person', None, person=person) + person: tuple = ('John Doe', 61, t, 99.5, decimal('93456.75')) + r: Any = self.db.insert('test_person', None, person=person) self.assertEqual(r['id'], 1) p = r['person'] self.assertIsInstance(p, tuple) @@ -3778,12 +3789,13 @@ def testInsertUpdateGetRecord(self): self.assertEqual(r['id'], 3) self.assertIsNone(r['person']) - def testRecordInsertBytea(self): + # noinspection PyUnresolvedReferences + def test_record_insert_bytea(self): query = self.db.query query('create type test_person_type as' ' (name text, picture bytea)') self.addCleanup(query, 'drop type test_person_type') - self.createTable('test_person', 'person test_person_type', + self.create_table('test_person', 'person test_person_type', temporary=False) person_typ = self.db.get_attnames('test_person')['person'] self.assertEqual(person_typ.attnames, @@ -3798,17 +3810,11 @@ def testRecordInsertBytea(self): self.assertEqual(p.picture, person[1]) self.assertIsInstance(p.picture, bytes) - def testRecordInsertJson(self): + def test_record_insert_json(self): query = self.db.query - try: - query('create type test_person_type as' - ' (name text, data json)') - except pg.ProgrammingError as error: - if self.db.server_version < 90200: - 
self.skipTest('database does not support json') - self.fail(str(error)) + query('create type test_person_type as (name text, data json)') self.addCleanup(query, 'drop type test_person_type') - self.createTable('test_person', 'person test_person_type', + self.create_table('test_person', 'person test_person_type', temporary=False) person_typ = self.db.get_attnames('test_person')['person'] self.assertEqual(person_typ.attnames, @@ -3818,6 +3824,7 @@ def testRecordInsertJson(self): p = r['person'] self.assertIsInstance(p, tuple) if pg.get_jsondecode() is None: + # noinspection PyUnresolvedReferences p = p._replace(data=json.loads(p.data)) self.assertEqual(p, person) self.assertEqual(p.name, 'John Doe') @@ -3825,12 +3832,13 @@ def testRecordInsertJson(self): self.assertEqual(p.data, person[1]) self.assertIsInstance(p.data, dict) - def testRecordLiteral(self): + # noinspection PyUnresolvedReferences + def test_record_literal(self): query = self.db.query query('create type test_person_type as' ' (name varchar, age smallint)') self.addCleanup(query, 'drop type test_person_type') - self.createTable('test_person', 'person test_person_type', + self.create_table('test_person', 'person test_person_type', temporary=False) person_typ = self.db.get_attnames('test_person')['person'] if self.regtypes: @@ -3852,9 +3860,10 @@ def testRecordLiteral(self): self.assertEqual(p.age, 61) self.assertIsInstance(p.age, int) - def testDate(self): + def test_date(self): query = self.db.query - for datestyle in ('ISO', 'Postgres, MDY', 'Postgres, DMY', + for datestyle in ( + 'ISO', 'Postgres, MDY', 'Postgres, DMY', 'SQL, MDY', 'SQL, DMY', 'German'): self.db.set_parameter('datestyle', datestyle) d = date(2016, 3, 14) @@ -3875,7 +3884,7 @@ def testDate(self): self.assertEqual(r[0], date.max) self.assertEqual(r[1], date.min) - def testTime(self): + def test_time(self): query = self.db.query d = time(15, 9, 26) q = "select $1::time" @@ -3888,15 +3897,12 @@ def testTime(self): self.assertIsInstance(r, time) self.assertEqual(r, d) - def testTimetz(self): + def test_timetz(self): query = self.db.query timezones = dict(CET=1, EET=2, EST=-5, UTC=0) for timezone in sorted(timezones): - tz = '%+03d00' % timezones[timezone] - try: - tzinfo = datetime.strptime(tz, '%z').tzinfo - except ValueError: # Python < 3.2 - tzinfo = pg._get_timezone(tz) + tz = f'{timezones[timezone]:+03d}00' + tzinfo = datetime.strptime(tz, '%z').tzinfo self.db.set_parameter('timezone', timezone) d = time(15, 9, 26, tzinfo=tzinfo) q = "select $1::timetz" @@ -3909,10 +3915,10 @@ def testTimetz(self): self.assertIsInstance(r, time) self.assertEqual(r, d) - def testTimestamp(self): + def test_timestamp(self): query = self.db.query for datestyle in ('ISO', 'Postgres, MDY', 'Postgres, DMY', - 'SQL, MDY', 'SQL, DMY', 'German'): + 'SQL, MDY', 'SQL, DMY', 'German'): self.db.set_parameter('datestyle', datestyle) d = datetime(2016, 3, 14) q = "select $1::timestamp" @@ -3930,7 +3936,7 @@ def testTimestamp(self): self.assertIsInstance(r, datetime) self.assertEqual(r, d) q = ("select '10000-08-01 AD'::timestamp," - " '0099-01-08 BC'::timestamp") + " '0099-01-08 BC'::timestamp") r = query(q).getresult()[0] self.assertIsInstance(r[0], datetime) self.assertIsInstance(r[1], datetime) @@ -3943,18 +3949,15 @@ def testTimestamp(self): self.assertEqual(r[0], datetime.max) self.assertEqual(r[1], datetime.min) - def testTimestamptz(self): + def test_timestamptz(self): query = self.db.query timezones = dict(CET=1, EET=2, EST=-5, UTC=0) for timezone in sorted(timezones): - tz = 
'%+03d00' % timezones[timezone] - try: - tzinfo = datetime.strptime(tz, '%z').tzinfo - except ValueError: # Python < 3.2 - tzinfo = pg._get_timezone(tz) + tz = f'{timezones[timezone]:+03d}00' + tzinfo = datetime.strptime(tz, '%z').tzinfo self.db.set_parameter('timezone', timezone) for datestyle in ('ISO', 'Postgres, MDY', 'Postgres, DMY', - 'SQL, MDY', 'SQL, DMY', 'German'): + 'SQL, MDY', 'SQL, DMY', 'German'): self.db.set_parameter('datestyle', datestyle) d = datetime(2016, 3, 14, tzinfo=tzinfo) q = "select $1::timestamptz" @@ -3972,7 +3975,7 @@ def testTimestamptz(self): self.assertIsInstance(r, datetime) self.assertEqual(r, d) q = ("select '10000-08-01 AD'::timestamptz," - " '0099-01-08 BC'::timestamptz") + " '0099-01-08 BC'::timestamptz") r = query(q).getresult()[0] self.assertIsInstance(r[0], datetime) self.assertIsInstance(r[1], datetime) @@ -3985,7 +3988,7 @@ def testTimestamptz(self): self.assertEqual(r[0], datetime.max) self.assertEqual(r[1], datetime.min) - def testInterval(self): + def test_interval(self): query = self.db.query for intervalstyle in ( 'sql_standard', 'postgres', 'postgres_verbose', 'iso_8601'): @@ -4005,7 +4008,7 @@ def testInterval(self): self.assertIsInstance(r, timedelta) self.assertEqual(r, d) - def testDateAndTimeArrays(self): + def test_date_and_time_arrays(self): dt = (date(2016, 3, 14), time(15, 9, 26)) q = "select ARRAY[$1::date], ARRAY[$2::time]" r = self.db.query(q, dt).getresult()[0] @@ -4014,7 +4017,7 @@ def testDateAndTimeArrays(self): self.assertIsInstance(r[1], list) self.assertEqual(r[1][0], dt[1]) - def testHstore(self): + def test_hstore(self): try: self.db.query("select 'k=>v'::hstore") except pg.DatabaseError: @@ -4023,22 +4026,22 @@ def testHstore(self): except pg.DatabaseError: self.skipTest("hstore extension not enabled") d = {'k': 'v', 'foo': 'bar', 'baz': 'whatever', - '1a': 'anything at all', '2=b': 'value = 2', '3>c': 'value > 3', - '4"c': 'value " 4', "5'c": "value ' 5", 'hello, world': '"hi!"', - 'None': None, 'NULL': 'NULL', 'empty': ''} + '1a': 'anything at all', '2=b': 'value = 2', '3>c': 'value > 3', + '4"c': 'value " 4', "5'c": "value ' 5", 'hello, world': '"hi!"', + 'None': None, 'NULL': 'NULL', 'empty': ''} q = "select $1::hstore" r = self.db.query(q, (pg.Hstore(d),)).getresult()[0][0] self.assertIsInstance(r, dict) self.assertEqual(r, d) - def testUuid(self): + def test_uuid(self): d = UUID('{12345678-1234-5678-1234-567812345678}') q = 'select $1::uuid' r = self.db.query(q, (d,)).getresult()[0][0] self.assertIsInstance(r, UUID) self.assertEqual(r, d) - def testDbTypesInfo(self): + def test_db_types_info(self): dbtypes = self.db.dbtypes self.assertIsInstance(dbtypes, dict) self.assertNotIn('numeric', dbtypes) @@ -4049,6 +4052,7 @@ def testDbTypesInfo(self): self.assertEqual(typ.pgtype, 'numeric') self.assertEqual(typ.regtype, 'numeric') self.assertEqual(typ.simple, 'num') + self.assertEqual(typ.typlen, -1) self.assertEqual(typ.typtype, 'b') self.assertEqual(typ.category, 'N') self.assertEqual(typ.delim, ',') @@ -4062,25 +4066,30 @@ def testDbTypesInfo(self): self.assertEqual(typ.pgtype, 'pg_type') self.assertEqual(typ.regtype, 'pg_type') self.assertEqual(typ.simple, 'record') + self.assertEqual(typ.typlen, -1) self.assertEqual(typ.typtype, 'c') self.assertEqual(typ.category, 'C') self.assertEqual(typ.delim, ',') self.assertNotEqual(typ.relid, 0) attnames = typ.attnames self.assertIsInstance(attnames, dict) + # noinspection PyUnresolvedReferences self.assertIs(attnames, dbtypes.get_attnames('pg_type')) self.assertIn('typname', 
attnames) typname = attnames['typname'] self.assertEqual(typname, 'name' if self.regtypes else 'text') + self.assertEqual(typname.typlen, 64) # base self.assertEqual(typname.typtype, 'b') # base self.assertEqual(typname.category, 'S') # string self.assertIn('typlen', attnames) typlen = attnames['typlen'] self.assertEqual(typlen, 'smallint' if self.regtypes else 'int') + self.assertEqual(typlen.typlen, 2) # base self.assertEqual(typlen.typtype, 'b') # base self.assertEqual(typlen.category, 'N') # numeric - def testDbTypesTypecast(self): + # noinspection PyUnresolvedReferences + def test_db_types_typecast(self): dbtypes = self.db.dbtypes self.assertIsInstance(dbtypes, dict) self.assertNotIn('int4', dbtypes) @@ -4095,18 +4104,19 @@ def testDbTypesTypecast(self): self.assertIs(dbtypes.get_typecast('int4'), int) self.assertNotIn('circle', dbtypes) self.assertIsNone(dbtypes.get_typecast('circle')) - squared_circle = lambda v: 'Squared Circle: %s' % v + squared_circle = lambda v: f'Squared Circle: {v}' # noqa: E731 dbtypes.set_typecast('circle', squared_circle) self.assertIs(dbtypes.get_typecast('circle'), squared_circle) r = self.db.query("select '0,0,1'::circle").getresult()[0][0] self.assertIn('circle', dbtypes) self.assertEqual(r, 'Squared Circle: <(0,0),1>') - self.assertEqual(dbtypes.typecast('Impossible', 'circle'), + self.assertEqual( + dbtypes.typecast('Impossible', 'circle'), 'Squared Circle: Impossible') dbtypes.reset_typecast('circle') self.assertIsNone(dbtypes.get_typecast('circle')) - def testGetSetTypeCast(self): + def test_get_set_type_cast(self): get_typecast = pg.get_typecast set_typecast = pg.set_typecast dbtypes = self.db.dbtypes @@ -4116,10 +4126,11 @@ def testGetSetTypeCast(self): self.assertNotIn('bool', dbtypes) self.assertIs(get_typecast('int4'), int) self.assertIs(get_typecast('float4'), float) - self.assertIs(get_typecast('bool'), pg.cast_bool) + from pg.cast import cast_bool + self.assertIs(get_typecast('bool'), cast_bool) cast_circle = get_typecast('circle') self.addCleanup(set_typecast, 'circle', cast_circle) - squared_circle = lambda v: 'Squared Circle: %s' % v + squared_circle = lambda v: f'Squared Circle: {v}' # noqa: E731 self.assertNotIn('circle', dbtypes) set_typecast('circle', squared_circle) self.assertNotIn('circle', dbtypes) @@ -4130,10 +4141,10 @@ def testGetSetTypeCast(self): set_typecast('circle', cast_circle) self.assertIs(get_typecast('circle'), cast_circle) - def testNotificationHandler(self): + def test_notification_handler(self): # the notification handler itself is tested separately f = self.db.notification_handler - callback = lambda arg_dict: None + callback = lambda arg_dict: None # noqa: E731 handler = f('test', callback) self.assertIsInstance(handler, pg.NotificationHandler) self.assertIs(handler.db, self.db) @@ -4207,13 +4218,30 @@ def testNotificationHandler(self): self.db.reopen() self.assertIsNone(handler.db) + def test_inserttable_from_query(self): + # use inserttable() to copy from one table to another + query = self.db.query + self.create_table('test_table_from', 'n integer, t timestamp') + self.create_table('test_table_to', 'n integer, t timestamp') + for i in range(1, 4): + query("insert into test_table_from values ($1, now())", i) + n = self.db.inserttable( + 'test_table_to', query("select n, t::text from test_table_from")) + data_from = query("select * from test_table_from").getresult() + data_to = query("select * from test_table_to").getresult() + self.assertEqual(n, 3) + self.assertEqual([row[0] for row in data_from], [1, 2, 3]) + 
self.assertEqual(data_from, data_to) + class TestDBClassNonStdOpts(TestDBClass): """Test the methods of the DB class with non-standard global options.""" + saved_options: ClassVar[dict[str, Any]] = {} + @classmethod def setUpClass(cls): - cls.saved_options = {} + cls.saved_options.clear() cls.set_option('decimal', float) not_bool = not pg.get_bool() cls.set_option('bool', not_bool) @@ -4225,11 +4253,11 @@ def setUpClass(cls): db = DB() cls.regtypes = not db.use_regtypes() db.close() - super(TestDBClassNonStdOpts, cls).setUpClass() + super().setUpClass() @classmethod def tearDownClass(cls): - super(TestDBClassNonStdOpts, cls).tearDownClass() + super().tearDownClass() cls.reset_option('jsondecode') cls.reset_option('bool') cls.reset_option('array') @@ -4238,11 +4266,13 @@ def tearDownClass(cls): @classmethod def set_option(cls, option, value): + # noinspection PyUnresolvedReferences cls.saved_options[option] = getattr(pg, 'get_' + option)() return getattr(pg, 'set_' + option)(value) @classmethod def reset_option(cls, option): + # noinspection PyUnresolvedReferences return getattr(pg, 'set_' + option)(cls.saved_options[option]) @@ -4254,12 +4284,10 @@ def setUp(self): self.adapter = self.db.adapter def tearDown(self): - try: + with suppress(pg.InternalError): self.db.close() - except pg.InternalError: - pass - def testGuessSimpleType(self): + def test_guess_simple_type(self): f = self.adapter.guess_simple_type self.assertEqual(f(pg.Bytea(b'test')), 'bytea') self.assertEqual(f('string'), 'text') @@ -4277,17 +4305,16 @@ def testGuessSimpleType(self): self.assertEqual(f([[[False]]]), 'bool[]') r = f(('string', True, 3, 2.75, [1], [False])) self.assertEqual(r, 'record') - self.assertEqual(list(r.attnames.values()), - ['text', 'bool', 'int', 'float', 'int[]', 'bool[]']) + self.assertEqual(list(r.attnames.values()), [ + 'text', 'bool', 'int', 'float', 'int[]', 'bool[]']) - def testAdaptQueryTypedList(self): + def test_adapt_query_typed_list(self): format_query = self.adapter.format_query - self.assertRaises(TypeError, format_query, - '%s,%s', (1, 2), ('int2',)) - self.assertRaises(TypeError, format_query, - '%s,%s', (1,), ('int2', 'int2')) - values = (3, 7.5, 'hello', True) - types = ('int4', 'float4', 'text', 'bool') + self.assertRaises(TypeError, format_query, '%s,%s', (1, 2), ('int2',)) + self.assertRaises( + TypeError, format_query, '%s,%s', (1,), ('int2', 'int2')) + values: list | tuple = (3, 7.5, 'hello', True) + types: list | tuple = ('int4', 'float4', 'text', 'bool') sql, params = format_query("select %s,%s,%s,%s", values, types) self.assertEqual(sql, 'select $1,$2,$3,$4') self.assertEqual(params, [3, 7.5, 'hello', 't']) @@ -4312,21 +4339,102 @@ def testAdaptQueryTypedList(self): values = [(3, 7.5, 'hello', True, [123], ['abc'])] t = self.adapter.simple_type typ = t('record') - typ._get_attnames = lambda _self: pg.AttrDict([ - ('i', t('int')), ('f', t('float')), - ('t', t('text')), ('b', t('bool')), - ('i3', t('int[]')), ('t3', t('text[]'))]) + from pg.attrs import AttrDict + typ._get_attnames = lambda _self: AttrDict( + i=t('int'), f=t('float'), + t=t('text'), b=t('bool'), + i3=t('int[]'), t3=t('text[]')) types = [typ] sql, params = format_query('select %s', values, types) self.assertEqual(sql, 'select $1') self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})']) + values = [(0, -3.25, '', False, [0], [''])] + sql, params = format_query('select %s', values, types) + self.assertEqual(sql, 'select $1') + self.assertEqual(params, ['(0,-3.25,"",f,{0},"{\\"\\"}")']) + + def 
test_adapt_query_typed_list_with_types_as_string(self): + format_query = self.adapter.format_query + self.assertRaises(TypeError, format_query, '%s,%s', (1, 2), 'int2') + self.assertRaises( + TypeError, format_query, '%s,%s', (1,), 'int2 int2') + values = (3, 7.5, 'hello', True) + types = 'int4 float4 text bool' # pass types as string + sql, params = format_query("select %s,%s,%s,%s", values, types) + self.assertEqual(sql, 'select $1,$2,$3,$4') + self.assertEqual(params, [3, 7.5, 'hello', 't']) + + def test_adapt_query_typed_list_with_types_as_classes(self): + format_query = self.adapter.format_query + self.assertRaises(TypeError, format_query, '%s,%s', (1, 2), (int,)) + self.assertRaises( + TypeError, format_query, '%s,%s', (1,), (int, int)) + values = (3, 7.5, 'hello', True) + types = (int, float, str, bool) # pass types as classes + sql, params = format_query("select %s,%s,%s,%s", values, types) + self.assertEqual(sql, 'select $1,$2,$3,$4') + self.assertEqual(params, [3, 7.5, 'hello', 't']) + + def test_adapt_query_typed_list_with_json(self): + format_query = self.adapter.format_query + value: Any = {'test': [1, "it's fine", 3]} + sql, params = format_query("select %s", (value,), 'json') + self.assertEqual(sql, 'select $1') + self.assertEqual(params, ['{"test": [1, "it\'s fine", 3]}']) + value = pg.Json({'test': [1, "it's fine", 3]}) + sql, params = format_query("select %s", (value,), 'json') + self.assertEqual(sql, 'select $1') + self.assertEqual(params, ['{"test": [1, "it\'s fine", 3]}']) + value = {'test': [1, "it's fine", 3]} + sql, params = format_query("select %s", [value], [pg.Json]) + self.assertEqual(sql, 'select $1') + self.assertEqual(params, ['{"test": [1, "it\'s fine", 3]}']) - def testAdaptQueryTypedDict(self): + def test_adapt_query_typed_list_with_empty_json(self): + format_query = self.adapter.format_query + values: Any = [None, 0, False, '', [], {}] + types = ('json',) * 6 + sql, params = format_query("select %s,%s,%s,%s,%s,%s", values, types) + self.assertEqual(sql, 'select $1,$2,$3,$4,$5,$6') + self.assertEqual(params, [None, '0', 'false', '', '[]', '{}']) + + def test_adapt_query_typed_with_hstore(self): + format_query = self.adapter.format_query + value: Any = {'one': "it's fine", 'two': 2} + sql, params = format_query("select %s", (value,), 'hstore') + self.assertEqual(sql, "select $1") + self.assertEqual(params, ['one=>"it\'s fine\",two=>2']) + value = pg.Hstore({'one': "it's fine", 'two': 2}) + sql, params = format_query("select %s", (value,), 'hstore') + self.assertEqual(sql, "select $1") + self.assertEqual(params, ['one=>"it\'s fine\",two=>2']) + value = pg.Hstore({'one': "it's fine", 'two': 2}) + sql, params = format_query("select %s", [value], [pg.Hstore]) + self.assertEqual(sql, "select $1") + self.assertEqual(params, ['one=>"it\'s fine\",two=>2']) + + def test_adapt_query_typed_with_uuid(self): format_query = self.adapter.format_query - self.assertRaises(TypeError, format_query, + value: Any = '12345678-1234-5678-1234-567812345678' + sql, params = format_query("select %s", (value,), 'uuid') + self.assertEqual(sql, "select $1") + self.assertEqual(params, ['12345678-1234-5678-1234-567812345678']) + value = UUID('{12345678-1234-5678-1234-567812345678}') + sql, params = format_query("select %s", (value,), 'uuid') + self.assertEqual(sql, "select $1") + self.assertEqual(params, ['12345678-1234-5678-1234-567812345678']) + value = UUID('{12345678-1234-5678-1234-567812345678}') + sql, params = format_query("select %s", (value,)) + self.assertEqual(sql, "select $1") 
+ self.assertEqual(params, ['12345678-1234-5678-1234-567812345678']) + + def test_adapt_query_typed_dict(self): + format_query = self.adapter.format_query + self.assertRaises( + TypeError, format_query, '%s,%s', dict(i1=1, i2=2), dict(i1='int2')) - values = dict(i=3, f=7.5, t='hello', b=True) - types = dict(i='int4', f='float4', t='text', b='bool') + values: dict = dict(i=3, f=7.5, t='hello', b=True) + types: dict = dict(i='int4', f='float4', t='text', b='bool') sql, params = format_query( "select %(i)s,%(f)s,%(t)s,%(b)s", values, types) self.assertEqual(sql, 'select $3,$2,$4,$1') @@ -4355,18 +4463,23 @@ def testAdaptQueryTypedDict(self): values = dict(record=(3, 7.5, 'hello', True, [123], ['abc'])) t = self.adapter.simple_type typ = t('record') - typ._get_attnames = lambda _self: pg.AttrDict([ - ('i', t('int')), ('f', t('float')), - ('t', t('text')), ('b', t('bool')), - ('i3', t('int[]')), ('t3', t('text[]'))]) + from pg.attrs import AttrDict + typ._get_attnames = lambda _self: AttrDict( + i=t('int'), f=t('float'), + t=t('text'), b=t('bool'), + i3=t('int[]'), t3=t('text[]')) types = dict(record=typ) sql, params = format_query('select %(record)s', values, types) self.assertEqual(sql, 'select $1') self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})']) + values = dict(record=(0, -3.25, '', False, [0], [''])) + sql, params = format_query('select %(record)s', values, types) + self.assertEqual(sql, 'select $1') + self.assertEqual(params, ['(0,-3.25,"",f,{0},"{\\"\\"}")']) - def testAdaptQueryUntypedList(self): + def test_adapt_query_untyped_list(self): format_query = self.adapter.format_query - values = (3, 7.5, 'hello', True) + values: list | tuple = (3, 7.5, 'hello', True) sql, params = format_query("select %s,%s,%s,%s", values) self.assertEqual(sql, 'select $1,$2,$3,$4') self.assertEqual(params, [3, 7.5, 'hello', 't']) @@ -4379,7 +4492,7 @@ def testAdaptQueryUntypedList(self): self.assertEqual(sql, "$1,$2,$3") self.assertEqual(params, ['{1,2,3}', '{a,b,c}', '{t,f,t}']) values = ([[1, 2], [3, 4]], [['a', 'b'], ['c', 'd']], - [[True, False], [False, True]]) + [[True, False], [False, True]]) sql, params = format_query("%s,%s,%s", values) self.assertEqual(sql, "$1,$2,$3") self.assertEqual(params, [ @@ -4388,10 +4501,28 @@ def testAdaptQueryUntypedList(self): sql, params = format_query('select %s', values) self.assertEqual(sql, 'select $1') self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})']) + values = [(0, -3.25, '', False, [0], [''])] + sql, params = format_query('select %s', values) + self.assertEqual(sql, 'select $1') + self.assertEqual(params, ['(0,-3.25,"",f,{0},"{\\"\\"}")']) - def testAdaptQueryUntypedDict(self): + def test_adapt_query_untyped_list_with_json(self): format_query = self.adapter.format_query - values = dict(i=3, f=7.5, t='hello', b=True) + value = pg.Json({'test': [1, "it's fine", 3]}) + sql, params = format_query("select %s", (value,)) + self.assertEqual(sql, 'select $1') + self.assertEqual(params, ['{"test": [1, "it\'s fine", 3]}']) + + def test_adapt_query_untyped_with_hstore(self): + format_query = self.adapter.format_query + value = pg.Hstore({'one': "it's fine", 'two': 2}) + sql, params = format_query("select %s", (value,)) + self.assertEqual(sql, "select $1") + self.assertEqual(params, ['one=>"it\'s fine\",two=>2']) + + def test_adapt_query_untyped_dict(self): + format_query = self.adapter.format_query + values: dict = dict(i=3, f=7.5, t='hello', b=True) sql, params = format_query( "select %(i)s,%(f)s,%(t)s,%(b)s", values) self.assertEqual(sql, 'select 
$3,$2,$4,$1') @@ -4404,7 +4535,8 @@ def testAdaptQueryUntypedDict(self): sql, params = format_query("%(i)s,%(t)s,%(b)s", values) self.assertEqual(sql, "$2,$3,$1") self.assertEqual(params, ['{t,f,t}', '{1,2,3}', '{a,b,c}']) - values = dict(i=[[1, 2], [3, 4]], t=[['a', 'b'], ['c', 'd']], + values = dict( + i=[[1, 2], [3, 4]], t=[['a', 'b'], ['c', 'd']], b=[[True, False], [False, True]]) sql, params = format_query("%(i)s,%(t)s,%(b)s", values) self.assertEqual(sql, "$2,$3,$1") @@ -4414,10 +4546,14 @@ def testAdaptQueryUntypedDict(self): sql, params = format_query('select %(record)s', values) self.assertEqual(sql, 'select $1') self.assertEqual(params, ['(3,7.5,hello,t,{123},{abc})']) + values = dict(record=(0, -3.25, '', False, [0], [''])) + sql, params = format_query('select %(record)s', values) + self.assertEqual(sql, 'select $1') + self.assertEqual(params, ['(0,-3.25,"",f,{0},"{\\"\\"}")']) - def testAdaptQueryInlineList(self): + def test_adapt_query_inline_list(self): format_query = self.adapter.format_query - values = (3, 7.5, 'hello', True) + values: list | tuple = (3, 7.5, 'hello', True) sql, params = format_query("select %s,%s,%s,%s", values, inline=True) self.assertEqual(sql, "select 3,7.5,'hello',true") self.assertEqual(params, []) @@ -4427,24 +4563,46 @@ def testAdaptQueryInlineList(self): self.assertEqual(params, []) values = ([1, 2, 3], ['a', 'b', 'c'], [True, False, True]) sql, params = format_query("%s,%s,%s", values, inline=True) - self.assertEqual(sql, - "ARRAY[1,2,3],ARRAY['a','b','c'],ARRAY[true,false,true]") + self.assertEqual( + sql, "ARRAY[1,2,3],ARRAY['a','b','c'],ARRAY[true,false,true]") self.assertEqual(params, []) values = ([[1, 2], [3, 4]], [['a', 'b'], ['c', 'd']], - [[True, False], [False, True]]) + [[True, False], [False, True]]) sql, params = format_query("%s,%s,%s", values, inline=True) - self.assertEqual(sql, "ARRAY[[1,2],[3,4]],ARRAY[['a','b'],['c','d']]," - "ARRAY[[true,false],[false,true]]") + self.assertEqual( + sql, "ARRAY[[1,2],[3,4]],ARRAY[['a','b'],['c','d']]," + "ARRAY[[true,false],[false,true]]") self.assertEqual(params, []) values = [(3, 7.5, 'hello', True, [123], ['abc'])] sql, params = format_query('select %s', values, inline=True) - self.assertEqual(sql, - "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])") + self.assertEqual( + sql, "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])") + self.assertEqual(params, []) + values = [(0, -3.25, '', False, [0], [''])] + sql, params = format_query('select %s', values, inline=True) + self.assertEqual( + sql, "select (0,-3.25,'',false,ARRAY[0],ARRAY[''])") self.assertEqual(params, []) - def testAdaptQueryInlineDict(self): + def test_adapt_query_inline_list_with_json(self): format_query = self.adapter.format_query - values = dict(i=3, f=7.5, t='hello', b=True) + value = pg.Json({'test': [1, "it's fine", 3]}) + sql, params = format_query("select %s", (value,), inline=True) + self.assertEqual( + sql, "select '{\"test\": [1, \"it''s fine\", 3]}'::json") + self.assertEqual(params, []) + + def test_adapt_query_inline_list_with_hstore(self): + format_query = self.adapter.format_query + value = pg.Hstore({'one': "it's fine", 'two': 2}) + sql, params = format_query("select %s", (value,), inline=True) + self.assertEqual( + sql, "select 'one=>\"it''s fine\",two=>2'::hstore") + self.assertEqual(params, []) + + def test_adapt_query_inline_dict(self): + format_query = self.adapter.format_query + values: dict = dict(i=3, f=7.5, t='hello', b=True) sql, params = format_query( "select %(i)s,%(f)s,%(t)s,%(b)s", values, 
inline=True) self.assertEqual(sql, "select 3,7.5,'hello',true") @@ -4456,28 +4614,37 @@ def testAdaptQueryInlineDict(self): self.assertEqual(params, []) values = dict(i=[1, 2, 3], t=['a', 'b', 'c'], b=[True, False, True]) sql, params = format_query("%(i)s,%(t)s,%(b)s", values, inline=True) - self.assertEqual(sql, - "ARRAY[1,2,3],ARRAY['a','b','c'],ARRAY[true,false,true]") + self.assertEqual( + sql, "ARRAY[1,2,3],ARRAY['a','b','c'],ARRAY[true,false,true]") self.assertEqual(params, []) - values = dict(i=[[1, 2], [3, 4]], t=[['a', 'b'], ['c', 'd']], + values = dict( + i=[[1, 2], [3, 4]], t=[['a', 'b'], ['c', 'd']], b=[[True, False], [False, True]]) sql, params = format_query("%(i)s,%(t)s,%(b)s", values, inline=True) - self.assertEqual(sql, "ARRAY[[1,2],[3,4]],ARRAY[['a','b'],['c','d']]," - "ARRAY[[true,false],[false,true]]") + self.assertEqual( + sql, "ARRAY[[1,2],[3,4]],ARRAY[['a','b'],['c','d']]," + "ARRAY[[true,false],[false,true]]") self.assertEqual(params, []) values = dict(record=(3, 7.5, 'hello', True, [123], ['abc'])) sql, params = format_query('select %(record)s', values, inline=True) - self.assertEqual(sql, - "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])") + self.assertEqual( + sql, "select (3,7.5,'hello',true,ARRAY[123],ARRAY['abc'])") + self.assertEqual(params, []) + values = dict(record=(0, -3.25, '', False, [0], [''])) + sql, params = format_query('select %(record)s', values, inline=True) + self.assertEqual( + sql, "select (0,-3.25,'',false,ARRAY[0],ARRAY[''])") self.assertEqual(params, []) - def testAdaptQueryWithPgRepr(self): + def test_adapt_query_with_pg_repr(self): format_query = self.adapter.format_query - self.assertRaises(TypeError, format_query, - '%s', object(), inline=True) + self.assertRaises(TypeError, format_query, '%s', object(), inline=True) + class TestObject: + # noinspection PyMethodMayBeStatic def __pg_repr__(self): return "'adapted'" + sql, params = format_query('select %s', [TestObject()], inline=True) self.assertEqual(sql, "select 'adapted'") self.assertEqual(params, []) @@ -4490,6 +4657,7 @@ class TestSchemas(unittest.TestCase): """Test correct handling of schemas (namespaces).""" cls_set_up = False + with_oids = "" @classmethod def setUpClass(cls): @@ -4498,22 +4666,23 @@ def setUpClass(cls): query = db.query for num_schema in range(5): if num_schema: - schema = "s%d" % num_schema - query("drop schema if exists %s cascade" % (schema,)) + schema = f"s{num_schema}" + query(f"drop schema if exists {schema} cascade") try: - query("create schema %s" % (schema,)) - except pg.ProgrammingError: - raise RuntimeError("The test user cannot create schemas.\n" - "Grant create on database %s to the user" - " for running these tests." 
% dbname) + query(f"create schema {schema}") + except pg.ProgrammingError as e: + raise RuntimeError( + "The test user cannot create schemas.\n" + f"Grant create on database {dbname} to the user" + " for running these tests.") from e else: schema = "public" - query("drop table if exists %s.t" % (schema,)) - query("drop table if exists %s.t%d" % (schema, num_schema)) - query("create table %s.t %s as select 1 as n, %d as d" - % (schema, cls.with_oids, num_schema)) - query("create table %s.t%d %s as select 1 as n, %d as d" - % (schema, num_schema, cls.with_oids, num_schema)) + query(f"drop table if exists {schema}.t") + query(f"drop table if exists {schema}.t{num_schema}") + query(f"create table {schema}.t {cls.with_oids}" + f" as select 1 as n, {num_schema} as d") + query(f"create table {schema}.t{num_schema} {cls.with_oids}" + f" as select 1 as n, {num_schema} as d") db.close() cls.cls_set_up = True @@ -4523,12 +4692,12 @@ def tearDownClass(cls): query = db.query for num_schema in range(5): if num_schema: - schema = "s%d" % num_schema - query("drop schema %s cascade" % (schema,)) + schema = f"s{num_schema}" + query(f"drop schema {schema} cascade") else: schema = "public" - query("drop table %s.t" % (schema,)) - query("drop table %s.t%d" % (schema, num_schema)) + query(f"drop table {schema}.t") + query(f"drop table {schema}.t{num_schema}") db.close() def setUp(self): @@ -4539,18 +4708,15 @@ def tearDown(self): self.doCleanups() self.db.close() - def testGetTables(self): + def test_get_tables(self): tables = self.db.get_tables() for num_schema in range(5): - if num_schema: - schema = "s" + str(num_schema) - else: - schema = "public" - for t in (schema + ".t", - schema + ".t" + str(num_schema)): + schema = 's' + str(num_schema) if num_schema else 'public' + for t in (schema + '.t', + schema + '.t' + str(num_schema)): self.assertIn(t, tables) - def testGetAttnames(self): + def test_get_attnames(self): get_attnames = self.db.get_attnames query = self.db.query result = {'d': 'int', 'n': 'int'} @@ -4562,7 +4728,7 @@ def testGetAttnames(self): self.assertEqual(r, result) query("drop table if exists s3.t3m") self.addCleanup(query, "drop table s3.t3m") - query("create table s3.t3m %s as select 1 as m" % (self.with_oids,)) + query(f"create table s3.t3m {self.with_oids} as select 1 as m") result_m = {'m': 'int'} if self.with_oids: result_m['oid'] = 'int' @@ -4574,10 +4740,10 @@ def testGetAttnames(self): r = get_attnames("t3m") self.assertEqual(r, result_m) - def testGet(self): + def test_get(self): get = self.db.get query = self.db.query - PrgError = pg.ProgrammingError + PrgError = pg.ProgrammingError # noqa: N806 self.assertEqual(get("t", 1, 'n')['d'], 0) self.assertEqual(get("t0", 1, 'n')['d'], 0) self.assertEqual(get("public.t", 1, 'n')['d'], 0) @@ -4598,7 +4764,7 @@ def testGet(self): self.assertEqual(get("t", 1, 'n')['d'], 1) self.assertEqual(get("s4.t4", 1, 'n')['d'], 4) - def testMunging(self): + def test_munging(self): get = self.db.get query = self.db.query r = get("t", 1, 'n') @@ -4619,14 +4785,24 @@ def testMunging(self): else: self.assertNotIn('oid(t)', r) + def test_query_information_schema(self): + q = "column_name" + if self.db.server_version < 110000: + q += "::text" # old version does not have sql_identifier array + q = f"select array_agg({q}) from information_schema.columns" + q += " where table_schema in ('s1', 's2', 's3', 's4')" + r = self.db.query(q).onescalar() + self.assertIsInstance(r, list) + self.assertEqual(set(r), set(['d', 'n'] * 8)) + class TestDebug(unittest.TestCase): 
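# the debug attribute can be None/False (off), True (echo queries to
# stdout), a format string, a file-like object, or a callable; each of
# these settings is exercised by one of the tests below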
"""Test the debug attribute of the DB class.""" - + def setUp(self): self.db = DB() self.query = self.db.query - self.debug = self.db.debug + self.debug = self.db.debug # type: ignore self.output = StringIO() self.stdout, sys.stdout = sys.stdout, self.output @@ -4643,29 +4819,30 @@ def send_queries(self): self.db.query("select 1") self.db.query("select 2") - def testDebugDefault(self): + def test_debug_default(self): if debug: self.assertEqual(self.db.debug, debug) else: self.assertIsNone(self.db.debug) - def testDebugIsFalse(self): + def test_debug_is_false(self): self.db.debug = False self.send_queries() self.assertEqual(self.get_output(), "") - def testDebugIsTrue(self): + def test_debug_is_true(self): self.db.debug = True self.send_queries() self.assertEqual(self.get_output(), "select 1\nselect 2\n") - def testDebugIsString(self): + def test_debug_is_string(self): self.db.debug = "Test with string: %s." self.send_queries() - self.assertEqual(self.get_output(), + self.assertEqual( + self.get_output(), "Test with string: select 1.\nTest with string: select 2.\n") - def testDebugIsFileLike(self): + def test_debug_is_file_like(self): with tempfile.TemporaryFile('w+') as debug_file: self.db.debug = debug_file self.send_queries() @@ -4674,16 +4851,16 @@ def testDebugIsFileLike(self): self.assertEqual(output, "select 1\nselect 2\n") self.assertEqual(self.get_output(), "") - def testDebugIsCallable(self): - output = [] + def test_debug_is_callable(self): + output: list[str] = [] self.db.debug = output.append self.db.query("select 1") self.db.query("select 2") self.assertEqual(output, ["select 1", "select 2"]) self.assertEqual(self.get_output(), "") - def testDebugMultipleArgs(self): - output = [] + def test_debug_multiple_args(self): + output: list[str] = [] self.db.debug = output.append args = ['Error', 42, {1: 'a', 2: 'b'}, [3, 5, 7]] self.db._do_debug(*args) @@ -4694,9 +4871,9 @@ def testDebugMultipleArgs(self): class TestMemoryLeaks(unittest.TestCase): """Test that the DB class does not leak memory.""" - def getLeaks(self, fut): - ids = set() - objs = [] + def get_leaks(self, fut: Callable): + ids: set = set() + objs: list = [] add_ids = ids.update gc.collect() objs[:] = gc.get_objects() @@ -4705,23 +4882,22 @@ def getLeaks(self, fut): gc.collect() objs[:] = gc.get_objects() objs[:] = [obj for obj in objs if id(obj) not in ids] - if objs and sys.version_info[:3] in ((3, 5, 0), (3, 5, 1)): - # workaround for Python issue 26811 - objs[:] = [obj for obj in objs if repr(obj) != '(,)'] self.assertEqual(len(objs), 0) - def testLeaksWithClose(self): + def test_leaks_with_close(self): def fut(): db = DB() db.query("select $1::int as r", 42).dictresult() db.close() - self.getLeaks(fut) - def testLeaksWithoutClose(self): + self.get_leaks(fut) + + def test_leaks_without_close(self): def fut(): db = DB() db.query("select $1::int as r", 42).dictresult() - self.getLeaks(fut) + + self.get_leaks(fut) if __name__ == '__main__': diff --git a/tests/test_classic_functions.py b/tests/test_classic_functions.py index a7311391..d1bde01c 100755 --- a/tests/test_classic_functions.py +++ b/tests/test_classic_functions.py @@ -1,5 +1,4 @@ #!/usr/bin/python -# -*- coding: utf-8 -*- """Test the classic PyGreSQL interface. @@ -10,85 +9,82 @@ These tests do not need a database to test against. 
""" -try: - import unittest2 as unittest # for Python < 2.7 -except ImportError: - import unittest +from __future__ import annotations import json import re - -import pg # the module under test - +import unittest from datetime import timedelta +from decimal import Decimal +from typing import Any, Sequence -try: # noinspection PyUnresolvedReferences - long -except NameError: # Python >= 3.0 - long = int - -try: # noinspection PyUnresolvedReferences - unicode -except NameError: # Python >= 3.0 - unicode = str +import pg # the module under test class TestHasConnect(unittest.TestCase): """Test existence of basic pg module functions.""" - def testhasPgError(self): + def test_has_pg_error(self): self.assertTrue(issubclass(pg.Error, Exception)) - def testhasPgWarning(self): + def test_has_pg_warning(self): self.assertTrue(issubclass(pg.Warning, Exception)) - def testhasPgInterfaceError(self): + def test_has_pg_interface_error(self): self.assertTrue(issubclass(pg.InterfaceError, pg.Error)) - def testhasPgDatabaseError(self): + def test_has_pg_database_error(self): self.assertTrue(issubclass(pg.DatabaseError, pg.Error)) - def testhasPgInternalError(self): + def test_has_pg_internal_error(self): self.assertTrue(issubclass(pg.InternalError, pg.DatabaseError)) - def testhasPgOperationalError(self): + def test_has_pg_operational_error(self): self.assertTrue(issubclass(pg.OperationalError, pg.DatabaseError)) - def testhasPgProgrammingError(self): + def test_has_pg_programming_error(self): self.assertTrue(issubclass(pg.ProgrammingError, pg.DatabaseError)) - def testhasPgIntegrityError(self): + def test_has_pg_integrity_error(self): self.assertTrue(issubclass(pg.IntegrityError, pg.DatabaseError)) - def testhasPgDataError(self): + def test_has_pg_data_error(self): self.assertTrue(issubclass(pg.DataError, pg.DatabaseError)) - def testhasPgNotSupportedError(self): + def test_has_pg_not_supported_error(self): self.assertTrue(issubclass(pg.NotSupportedError, pg.DatabaseError)) - def testhasPgInvalidResultError(self): + def test_has_pg_invalid_result_error(self): self.assertTrue(issubclass(pg.InvalidResultError, pg.DataError)) - def testhasPgNoResultError(self): + def test_has_pg_no_result_error(self): self.assertTrue(issubclass(pg.NoResultError, pg.InvalidResultError)) - def testhasPgMultipleResultsError(self): + def test_has_pg_multiple_results_error(self): self.assertTrue( issubclass(pg.MultipleResultsError, pg.InvalidResultError)) - def testhasConnect(self): + def test_has_connection_type(self): + self.assertIsInstance(pg.Connection, type) + self.assertEqual(pg.Connection.__name__, 'Connection') + + def test_has_query_type(self): + self.assertIsInstance(pg.Query, type) + self.assertEqual(pg.Query.__name__, 'Query') + + def test_has_connect(self): self.assertTrue(callable(pg.connect)) - def testhasEscapeString(self): + def test_has_escape_string(self): self.assertTrue(callable(pg.escape_string)) - def testhasEscapeBytea(self): + def test_has_escape_bytea(self): self.assertTrue(callable(pg.escape_bytea)) - def testhasUnescapeBytea(self): + def test_has_unescape_bytea(self): self.assertTrue(callable(pg.unescape_bytea)) - def testDefHost(self): + def test_def_host(self): d0 = pg.get_defhost() d1 = 'pgtesthost' pg.set_defhost(d1) @@ -96,7 +92,7 @@ def testDefHost(self): pg.set_defhost(d0) self.assertEqual(pg.get_defhost(), d0) - def testDefPort(self): + def test_def_port(self): d0 = pg.get_defport() d1 = 1234 pg.set_defport(d1) @@ -108,7 +104,7 @@ def testDefPort(self): d0 = None 
self.assertEqual(pg.get_defport(), d0) - def testDefOpt(self): + def test_def_opt(self): d0 = pg.get_defopt() d1 = '-h pgtesthost -p 1234' pg.set_defopt(d1) @@ -116,7 +112,7 @@ def testDefOpt(self): pg.set_defopt(d0) self.assertEqual(pg.get_defopt(), d0) - def testDefBase(self): + def test_def_base(self): d0 = pg.get_defbase() d1 = 'pgtestdb' pg.set_defbase(d1) @@ -124,11 +120,18 @@ def testDefBase(self): pg.set_defbase(d0) self.assertEqual(pg.get_defbase(), d0) + def test_pqlib_version(self): + # noinspection PyUnresolvedReferences + v = pg.get_pqlib_version() + self.assertIsInstance(v, int) + self.assertGreater(v, 100000) # >= 10.0 + self.assertLess(v, 200000) # < 20.0 + class TestParseArray(unittest.TestCase): """Test the array parser.""" - test_strings = [ + test_strings: Sequence[tuple[str, type | None, Any]] = [ ('', str, ValueError), ('{}', None, []), ('{}', str, []), @@ -202,11 +205,11 @@ class TestParseArray(unittest.TestCase): ('{{{17,18,19},{14,15,16},{11,12,13}},' '{{27,28,29},{24,25,26},{21,22,23}},' '{{37,38,39},{34,35,36},{31,32,33}}}', int, - [[[17, 18, 19], [14, 15, 16], [11, 12, 13]], - [[27, 28, 29], [24, 25, 26], [21, 22, 23]], - [[37, 38, 39], [34, 35, 36], [31, 32, 33]]]), + [[[17, 18, 19], [14, 15, 16], [11, 12, 13]], + [[27, 28, 29], [24, 25, 26], [21, 22, 23]], + [[37, 38, 39], [34, 35, 36], [31, 32, 33]]]), ('{{"breakfast", "consulting"}, {"meeting", "lunch"}}', str, - [['breakfast', 'consulting'], ['meeting', 'lunch']]), + [['breakfast', 'consulting'], ['meeting', 'lunch']]), ('[1:3]={1,2,3}', int, [1, 2, 3]), ('[-1:1]={1,2,3}', int, [1, 2, 3]), ('[-1:+1]={1,2,3}', int, [1, 2, 3]), @@ -218,14 +221,14 @@ class TestParseArray(unittest.TestCase): ('[1:]={1,2,3}', int, ValueError), ('[:3]={1,2,3}', int, ValueError), ('[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}', - int, [[[1, 2, 3], [4, 5, 6]]]), + int, [[[1, 2, 3], [4, 5, 6]]]), (' [1:1] [-2:-1] [3:5] = { { { 1 , 2 , 3 }, {4 , 5 , 6 } } }', - int, [[[1, 2, 3], [4, 5, 6]]]), + int, [[[1, 2, 3], [4, 5, 6]]]), ('[1:1][3:5]={{1,2,3},{4,5,6}}', int, [[1, 2, 3], [4, 5, 6]]), ('[3:5]={{1,2,3},{4,5,6}}', int, ValueError), ('[1:1][-2:-1][3:5]={{1,2,3},{4,5,6}}', int, ValueError)] - def testParserParams(self): + def test_parser_params(self): f = pg.cast_array self.assertRaises(TypeError, f) self.assertRaises(TypeError, f, None) @@ -244,13 +247,13 @@ def testParserParams(self): self.assertEqual(f('{}', str), []) self.assertEqual(f('{}', str, b';'), []) - def testParserSimple(self): + def test_parser_simple(self): r = pg.cast_array('{a,b,c}') self.assertIsInstance(r, list) self.assertEqual(len(r), 3) self.assertEqual(r, ['a', 'b', 'c']) - def testParserNested(self): + def test_parser_nested(self): f = pg.cast_array r = f('{{a,b,c}}') self.assertIsInstance(r, list) @@ -275,28 +278,29 @@ def testParserNested(self): self.assertEqual(len(r), 1) self.assertEqual(r[0], 'b') r = f('{{{{{{{abc}}}}}}}') - for i in range(7): + for _i in range(7): self.assertIsInstance(r, list) self.assertEqual(len(r), 1) + # noinspection PyUnresolvedReferences r = r[0] self.assertEqual(r, 'abc') - def testParserTooDeeplyNested(self): + def test_parser_too_deeply_nested(self): f = pg.cast_array for n in 3, 5, 9, 12, 16, 32, 64, 256: - r = '%sa,b,c%s' % ('{' * n, '}' * n) + s = '{' * n + 'a,b,c' + '}' * n if n > 16: # hard coded maximum depth - self.assertRaises(ValueError, f, r) + self.assertRaises(ValueError, f, s) else: - r = f(r) - for i in range(n - 1): + r = f(s) + for _i in range(n - 1): self.assertIsInstance(r, list) self.assertEqual(len(r), 1) r = r[0] 
self.assertEqual(len(r), 3) self.assertEqual(r, ['a', 'b', 'c']) - def testParserCast(self): + def test_parser_cast(self): f = pg.cast_array self.assertEqual(f('{1}'), ['1']) self.assertEqual(f('{1}', None), ['1']) @@ -306,10 +310,12 @@ def testParserCast(self): self.assertEqual(f('{a}', None), ['a']) self.assertRaises(ValueError, f, '{a}', int) self.assertEqual(f('{a}', str), ['a']) - cast = lambda s: '%s is ok' % s + + def cast(s): + return f'{s} is ok' self.assertEqual(f('{a}', cast), ['a is ok']) - def testParserDelim(self): + def test_parser_delim(self): f = pg.cast_array self.assertEqual(f('{1,2}'), ['1', '2']) self.assertEqual(f('{1,2}', delim=b','), ['1', '2']) @@ -317,7 +323,7 @@ def testParserDelim(self): self.assertEqual(f('{1;2}', delim=b';'), ['1', '2']) self.assertEqual(f('{1,2}', delim=b';'), ['1,2']) - def testParserWithData(self): + def test_parser_with_data(self): f = pg.cast_array for string, cast, expected in self.test_strings: if expected is ValueError: @@ -325,7 +331,7 @@ def testParserWithData(self): else: self.assertEqual(f(string, cast), expected) - def testParserWithoutCast(self): + def test_parser_without_cast(self): f = pg.cast_array for string, cast, expected in self.test_strings: @@ -336,7 +342,7 @@ def testParserWithoutCast(self): else: self.assertEqual(f(string), expected) - def testParserWithDifferentDelimiter(self): + def test_parser_with_different_delimiter(self): f = pg.cast_array def replace_comma(value): @@ -359,7 +365,7 @@ def replace_comma(value): class TestParseRecord(unittest.TestCase): """Test the record parser.""" - test_strings = [ + test_strings: Sequence[tuple[str, type | tuple[type, ...] | None, Any]] = [ ('', None, ValueError), ('', str, ValueError), ('(', None, ValueError), @@ -497,7 +503,7 @@ class TestParseRecord(unittest.TestCase): ('(fuzzy dice,"42","1.9375")', (str, int, float), ('fuzzy dice', 42, 1.9375))] - def testParserParams(self): + def test_parser_params(self): f = pg.cast_record self.assertRaises(TypeError, f) self.assertRaises(TypeError, f, None) @@ -516,27 +522,28 @@ def testParserParams(self): self.assertEqual(f('()', str), (None,)) self.assertEqual(f('()', str, b';'), (None,)) - def testParserSimple(self): + def test_parser_simple(self): r = pg.cast_record('(a,b,c)') self.assertIsInstance(r, tuple) self.assertEqual(len(r), 3) self.assertEqual(r, ('a', 'b', 'c')) - def testParserNested(self): + def test_parser_nested(self): f = pg.cast_record self.assertRaises(ValueError, f, '((a,b,c))') self.assertRaises(ValueError, f, '((a,b),(c,d))') self.assertRaises(ValueError, f, '((a),(b),(c))') self.assertRaises(ValueError, f, '(((((((abc)))))))') - def testParserManyElements(self): + def test_parser_many_elements(self): f = pg.cast_record for n in 3, 5, 9, 12, 16, 32, 64, 256: - r = '(%s)' % ','.join(map(str, range(n))) - r = f(r, int) + s = ','.join(map(str, range(n))) + s = f'({s})' + r = f(s, int) self.assertEqual(r, tuple(range(n))) - def testParserCastUniform(self): + def test_parser_cast_uniform(self): f = pg.cast_record self.assertEqual(f('(1)'), ('1',)) self.assertEqual(f('(1)', None), ('1',)) @@ -546,10 +553,12 @@ def testParserCastUniform(self): self.assertEqual(f('(a)', None), ('a',)) self.assertRaises(ValueError, f, '(a)', int) self.assertEqual(f('(a)', str), ('a',)) - cast = lambda s: '%s is ok' % s + + def cast(s): + return f'{s} is ok' self.assertEqual(f('(a)', cast), ('a is ok',)) - def testParserCastNonUniform(self): + def test_parser_cast_non_uniform(self): f = pg.cast_record self.assertEqual(f('(1)', []), ('1',)) 
self.assertEqual(f('(1)', [None]), ('1',)) @@ -565,22 +574,28 @@ def testParserCastNonUniform(self): self.assertRaises(ValueError, f, '(1,a)', [str, int]) self.assertEqual(f('(a,1)', [str, int]), ('a', 1)) self.assertRaises(ValueError, f, '(a,1)', [int, str]) - self.assertEqual(f('(1,a,2,b,3,c)', - [int, str, int, str, int, str]), (1, 'a', 2, 'b', 3, 'c')) - self.assertEqual(f('(1,a,2,b,3,c)', - (int, str, int, str, int, str)), (1, 'a', 2, 'b', 3, 'c')) - cast1 = lambda s: '%s is ok' % s + self.assertEqual( + f('(1,a,2,b,3,c)', [int, str, int, str, int, str]), + (1, 'a', 2, 'b', 3, 'c')) + self.assertEqual( + f('(1,a,2,b,3,c)', (int, str, int, str, int, str)), + (1, 'a', 2, 'b', 3, 'c')) + + def cast1(s): + return f'{s} is ok' self.assertEqual(f('(a)', [cast1]), ('a is ok',)) - cast2 = lambda s: 'and %s is ok, too' % s - self.assertEqual(f('(a,b)', [cast1, cast2]), - ('a is ok', 'and b is ok, too')) + + def cast2(s): + return f'and {s} is ok, too' + self.assertEqual( + f('(a,b)', [cast1, cast2]), ('a is ok', 'and b is ok, too')) self.assertRaises(ValueError, f, '(a)', [cast1, cast2]) self.assertRaises(ValueError, f, '(a,b,c)', [cast1, cast2]) - self.assertEqual(f('(1,2,3,4,5,6)', - [int, float, str, None, cast1, cast2]), + self.assertEqual( + f('(1,2,3,4,5,6)', [int, float, str, None, cast1, cast2]), (1, 2.0, '3', '4', '5 is ok', 'and 6 is ok, too')) - def testParserDelim(self): + def test_parser_delim(self): f = pg.cast_record self.assertEqual(f('(1,2)'), ('1', '2')) self.assertEqual(f('(1,2)', delim=b','), ('1', '2')) @@ -588,7 +603,7 @@ def testParserDelim(self): self.assertEqual(f('(1;2)', delim=b';'), ('1', '2')) self.assertEqual(f('(1,2)', delim=b';'), ('1,2',)) - def testParserWithData(self): + def test_parser_with_data(self): f = pg.cast_record for string, cast, expected in self.test_strings: if expected is ValueError: @@ -596,7 +611,7 @@ def testParserWithData(self): else: self.assertEqual(f(string, cast), expected) - def testParserWithoutCast(self): + def test_parser_without_cast(self): f = pg.cast_record for string, cast, expected in self.test_strings: @@ -607,7 +622,7 @@ def testParserWithoutCast(self): else: self.assertEqual(f(string), expected) - def testParserWithDifferentDelimiter(self): + def test_parser_with_different_delimiter(self): f = pg.cast_record def replace_comma(value): @@ -631,7 +646,7 @@ def replace_comma(value): class TestParseHStore(unittest.TestCase): """Test the hstore parser.""" - test_strings = [ + test_strings: Sequence[tuple[str, Any]] = [ ('', {}), ('=>', ValueError), ('""=>', ValueError), @@ -653,17 +668,16 @@ class TestParseHStore(unittest.TestCase): ('"k=>v', ValueError), ('k=>"v', ValueError), ('"1-a" => "anything at all"', {'1-a': 'anything at all'}), - ('k => v, foo => bar, baz => whatever,' - ' "1-a" => "anything at all"', - {'k': 'v', 'foo': 'bar', 'baz': 'whatever', - '1-a': 'anything at all'}), + ('k => v, foo => bar, baz => whatever, "1-a" => "anything at all"', + {'k': 'v', 'foo': 'bar', 'baz': 'whatever', + '1-a': 'anything at all'}), ('"Hello, World!"=>"Hi!"', {'Hello, World!': 'Hi!'}), ('"Hi!"=>"Hello, World!"', {'Hi!': 'Hello, World!'}), (r'"k=>v"=>k\=\>v', {'k=>v': 'k=>v'}), (r'k\=\>v=>"k=>v"', {'k=>v': 'k=>v'}), ('a\\,b=>a,b=>a', {'a,b': 'a', 'b': 'a'})] - def testParser(self): + def test_parser(self): f = pg.cast_hstore self.assertRaises(TypeError, f) @@ -681,17 +695,17 @@ def testParser(self): class TestCastInterval(unittest.TestCase): """Test the interval typecast function.""" - intervals = [ + intervals: Sequence[tuple[tuple[int, 
...], tuple[str, ...]]] = [ ((0, 0, 0, 1, 0, 0, 0), ('1:00:00', '01:00:00', '@ 1 hour', 'PT1H')), ((0, 0, 0, -1, 0, 0, 0), ('-1:00:00', '-01:00:00', '@ -1 hour', 'PT-1H')), ((0, 0, 0, 1, 0, 0, 0), ('0-0 0 1:00:00', '0 years 0 mons 0 days 01:00:00', - '@ 0 years 0 mons 0 days 1 hour', 'P0Y0M0DT1H')), + '@ 0 years 0 mons 0 days 1 hour', 'P0Y0M0DT1H')), ((0, 0, 0, -1, 0, 0, 0), ('-0-0 -1:00:00', '0 years 0 mons 0 days -01:00:00', - '@ 0 years 0 mons 0 days -1 hour', 'P0Y0M0DT-1H')), + '@ 0 years 0 mons 0 days -1 hour', 'P0Y0M0DT-1H')), ((0, 0, 1, 0, 0, 0, 0), ('1 0:00:00', '1 day', '@ 1 day', 'P1D')), ((0, 0, -1, 0, 0, 0, 0), @@ -840,15 +854,16 @@ class TestCastInterval(unittest.TestCase): '@ 10 mons 3 days -3 hours -55 mins -5.999993 secs ago', 'P-10M-3DT3H55M5.999993S'))] - def testCastInterval(self): + def test_cast_interval(self): + from pg.cast import cast_interval for result, values in self.intervals: - f = pg.cast_interval years, mons, days, hours, mins, secs, usecs = result days += 365 * years + 30 * mons - interval = timedelta(days=days, hours=hours, minutes=mins, + interval = timedelta( + days=days, hours=hours, minutes=mins, seconds=secs, microseconds=usecs) for value in values: - self.assertEqual(f(value), interval) + self.assertEqual(cast_interval(value), interval) class TestEscapeFunctions(unittest.TestCase): @@ -861,47 +876,47 @@ class TestEscapeFunctions(unittest.TestCase): """ - def testEscapeString(self): + def test_escape_string(self): f = pg.escape_string - r = f(b'plain') - self.assertIsInstance(r, bytes) - self.assertEqual(r, b'plain') - r = f(u'plain') - self.assertIsInstance(r, unicode) - self.assertEqual(r, u'plain') - r = f("that's cheese") - self.assertIsInstance(r, str) - self.assertEqual(r, "that''s cheese") - - def testEscapeBytea(self): + b = f(b'plain') + self.assertIsInstance(b, bytes) + self.assertEqual(b, b'plain') + s = f('plain') + self.assertIsInstance(s, str) + self.assertEqual(s, 'plain') + s = f("that's cheese") + self.assertIsInstance(s, str) + self.assertEqual(s, "that''s cheese") + + def test_escape_bytea(self): f = pg.escape_bytea - r = f(b'plain') - self.assertIsInstance(r, bytes) - self.assertEqual(r, b'plain') - r = f(u'plain') - self.assertIsInstance(r, unicode) - self.assertEqual(r, u'plain') - r = f("that's cheese") - self.assertIsInstance(r, str) - self.assertEqual(r, "that''s cheese") - - def testUnescapeBytea(self): + b = f(b'plain') + self.assertIsInstance(b, bytes) + self.assertEqual(b, b'plain') + s = f('plain') + self.assertIsInstance(s, str) + self.assertEqual(s, 'plain') + s = f("that's cheese") + self.assertIsInstance(s, str) + self.assertEqual(s, "that''s cheese") + + def test_unescape_bytea(self): f = pg.unescape_bytea r = f(b'plain') self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') - r = f(u'plain') + r = f('plain') self.assertIsInstance(r, bytes) self.assertEqual(r, b'plain') r = f(b"das is' k\\303\\244se") self.assertIsInstance(r, bytes) - self.assertEqual(r, u"das is' käse".encode('utf-8')) - r = f(u"das is' k\\303\\244se") + self.assertEqual(r, "das is' käse".encode()) + r = f("das is' k\\303\\244se") self.assertIsInstance(r, bytes) - self.assertEqual(r, u"das is' käse".encode('utf-8')) + self.assertEqual(r, "das is' käse".encode()) r = f(b'O\\000ps\\377!') self.assertEqual(r, b'O\x00ps\xff!') - r = f(u'O\\000ps\\377!') + r = f('O\\000ps\\377!') self.assertEqual(r, b'O\x00ps\xff!') @@ -913,10 +928,10 @@ class TestConfigFunctions(unittest.TestCase): """ - def testGetDatestyle(self): + def test_get_datestyle(self): 
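# no datestyle is fixed by default, so this should return None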
self.assertIsNone(pg.get_datestyle()) - def testGetDatestyle(self): + def test_set_datestyle(self): datestyle = pg.get_datestyle() try: pg.set_datestyle('ISO, YMD') @@ -936,12 +951,12 @@ def testGetDatestyle(self): finally: pg.set_datestyle(datestyle) - def testGetDecimalPoint(self): + def test_get_decimal_point(self): r = pg.get_decimal_point() self.assertIsInstance(r, str) self.assertEqual(r, '.') - def testSetDecimalPoint(self): + def test_set_decimal_point(self): point = pg.get_decimal_point() try: pg.set_decimal_point('*') @@ -954,12 +969,12 @@ def testSetDecimalPoint(self): self.assertIsInstance(r, str) self.assertEqual(r, point) - def testGetDecimal(self): + def test_get_decimal(self): r = pg.get_decimal() - self.assertIs(r, pg.Decimal) + self.assertIs(r, Decimal) - def testSetDecimal(self): - decimal_class = pg.Decimal + def test_set_decimal(self): + decimal_class = Decimal try: pg.set_decimal(int) r = pg.get_decimal() @@ -969,12 +984,12 @@ def testSetDecimal(self): r = pg.get_decimal() self.assertIs(r, decimal_class) - def testGetBool(self): + def test_get_bool(self): r = pg.get_bool() self.assertIsInstance(r, bool) self.assertIs(r, True) - def testSetBool(self): + def test_set_bool(self): use_bool = pg.get_bool() try: pg.set_bool(False) @@ -992,12 +1007,12 @@ def testSetBool(self): self.assertIsInstance(r, bool) self.assertIs(r, use_bool) - def testGetByteaEscaped(self): + def test_get_bytea_escaped(self): r = pg.get_bytea_escaped() self.assertIsInstance(r, bool) self.assertIs(r, False) - def testSetByteaEscaped(self): + def test_set_bytea_escaped(self): bytea_escaped = pg.get_bytea_escaped() try: pg.set_bytea_escaped(True) @@ -1015,12 +1030,12 @@ def testSetByteaEscaped(self): self.assertIsInstance(r, bool) self.assertIs(r, bytea_escaped) - def testGetJsondecode(self): + def test_get_jsondecode(self): r = pg.get_jsondecode() self.assertTrue(callable(r)) self.assertIs(r, json.loads) - def testSetJsondecode(self): + def test_set_jsondecode(self): jsondecode = pg.get_jsondecode() try: pg.set_jsondecode(None) @@ -1039,7 +1054,7 @@ def testSetJsondecode(self): class TestModuleConstants(unittest.TestCase): """Test the existence of the documented module constants.""" - def testVersion(self): + def test_version(self): v = pg.version self.assertIsInstance(v, str) # make sure the version conforms to PEP440 diff --git a/tests/test_classic_largeobj.py b/tests/test_classic_largeobj.py index 38dc5e1f..7c53053d 100755 --- a/tests/test_classic_largeobj.py +++ b/tests/test_classic_largeobj.py @@ -1,5 +1,4 @@ #!/usr/bin/python -# -*- coding: utf-8 -*- """Test the classic PyGreSQL interface. @@ -10,35 +9,24 @@ These tests need a database to test against. """ -try: - import unittest2 as unittest # for Python < 2.7 -except ImportError: - import unittest -import tempfile import os +import tempfile +import unittest +from contextlib import suppress +from typing import Any import pg # the module under test -# We need a database to test against. If LOCAL_PyGreSQL.py exists we will -# get our information from that. Otherwise we use the defaults. 
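Note: the LOCAL_PyGreSQL mechanism removed here is superseded by a shared
tests/config.py module ("from .config import ..." below), which is not
itself part of this diff. As a rough sketch only, such a module could read
the connection settings from environment variables; the PYGRESQL_* names
and the defaults below are assumptions, not upstream code:

import os

# test database connection settings, overridable via environment variables
dbname = os.environ.get('PYGRESQL_DB', 'test')
dbhost = os.environ.get('PYGRESQL_HOST', 'localhost') or None
dbport = int(os.environ.get('PYGRESQL_PORT', '5432'))
dbuser = os.environ.get('PYGRESQL_USER') or None
dbpasswd = os.environ.get('PYGRESQL_PASSWD') or None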
-dbname = 'unittest'
-dbhost = None
-dbport = 5432
-
-try:
- from .LOCAL_PyGreSQL import *
-except (ImportError, ValueError):
- try:
- from LOCAL_PyGreSQL import *
- except ImportError:
- pass
+from .config import dbhost, dbname, dbpasswd, dbport, dbuser
windows = os.name == 'nt'
+# noinspection PyArgumentList
def connect():
"""Create a basic pg connection to the test database."""
- connection = pg.connect(dbname, dbhost, dbport)
+ connection = pg.connect(dbname, dbhost, dbport,
+ user=dbuser, passwd=dbpasswd)
connection.query("set client_min_messages=warning")
return connection
@@ -46,13 +34,13 @@ def connect():
class TestModuleConstants(unittest.TestCase):
"""Test the existence of the documented module constants."""
- def testLargeObjectIntConstants(self):
+ def test_large_object_int_constants(self):
names = 'INV_READ INV_WRITE SEEK_SET SEEK_CUR SEEK_END'.split()
for name in names:
try:
value = getattr(pg, name)
except AttributeError:
- self.fail('Module constant %s is missing' % name)
+ self.fail(f'Module constant {name} is missing')
self.assertIsInstance(value, int)
@@ -67,7 +55,7 @@ def tearDown(self):
self.c.query('rollback')
self.c.close()
- def assertIsLargeObject(self, obj):
+ def assertIsLargeObject(self, obj): # noqa: N802
self.assertIsNotNone(obj)
self.assertTrue(hasattr(obj, 'open'))
self.assertTrue(hasattr(obj, 'close'))
@@ -80,14 +68,14 @@ def assertIsLargeObject(self, obj):
self.assertIsInstance(obj.error, str)
self.assertFalse(obj.error)
- def testLoCreate(self):
+ def test_lo_create(self):
large_object = self.c.locreate(pg.INV_READ | pg.INV_WRITE)
try:
self.assertIsLargeObject(large_object)
finally:
del large_object
- def testGetLo(self):
+ def test_get_lo(self):
large_object = self.c.locreate(pg.INV_READ | pg.INV_WRITE)
try:
self.assertIsLargeObject(large_object)
@@ -117,19 +105,20 @@ def testGetLo(self):
self.assertIsInstance(r, bytes)
self.assertEqual(r, data)
- def testLoImport(self):
+ def test_lo_import(self):
+ f: Any
if windows:
# NamedTemporaryFiles don't work well here
fname = 'temp_test_pg_largeobj_import.txt'
- f = open(fname, 'wb')
+ f = open(fname, 'wb') # noqa: SIM115
else:
- f = tempfile.NamedTemporaryFile()
+ f = tempfile.NamedTemporaryFile() # noqa: SIM115
fname = f.name
data = b'some data to be imported'
f.write(data)
if windows:
f.close()
- f = open(fname, 'rb')
+ f = open(fname, 'rb') # noqa: SIM115
else:
f.flush()
f.seek(0)
@@ -163,54 +152,49 @@ def setUp(self):
def tearDown(self):
if self.obj.oid:
- try:
+ with suppress(SystemError, OSError):
self.obj.close()
- except (SystemError, IOError):
- pass
- try:
+ with suppress(SystemError, OSError):
self.obj.unlink()
- except (SystemError, IOError):
- pass
del self.obj
- try:
+ with suppress(SystemError):
self.pgcnx.query('rollback')
- except SystemError:
- pass
self.pgcnx.close()
- def testClassName(self):
+ def test_class_name(self):
self.assertEqual(self.obj.__class__.__name__, 'LargeObject')
- def testModuleName(self):
+ def test_module_name(self):
self.assertEqual(self.obj.__class__.__module__, 'pg')
- def testOid(self):
+ def test_oid(self):
self.assertIsInstance(self.obj.oid, int)
self.assertNotEqual(self.obj.oid, 0)
- def testPgcn(self):
+ def test_pgcn(self):
self.assertIs(self.obj.pgcnx, self.pgcnx)
- def testError(self):
+ def test_error(self):
self.assertIsInstance(self.obj.error, str)
self.assertEqual(self.obj.error, '')
- def testStr(self):
+ def test_str(self):
self.obj.open(pg.INV_WRITE)
data = b'some object to be printed'
self.obj.write(data)
oid = self.obj.oid
r = str(self.obj)
- self.assertEqual(r, 'Opened large object, oid %d' % oid)
+ self.assertEqual(r, f'Opened large object, oid {oid}')
self.obj.close()
r = str(self.obj)
- self.assertEqual(r, 'Closed large object, oid %d' % oid)
+ self.assertEqual(r, f'Closed large object, oid {oid}')
- def testRepr(self):
+ def test_repr(self):
r = repr(self.obj)
self.assertTrue(r.startswith('<pg.LargeObject'))
[... remaining large object tests lost in extraction ...]
diff --git a/tests/test_classic_notification.py b/tests/test_classic_notification.py
[... earlier hunks lost in extraction ...]
... >= len(self.sent):
@@ -321,94 +274,94 @@ def receive(self, stop=False):
self.sent = []
self.assertEqual(self.handler.listening, not self.stopped)
- def testNotifyHandlerEmpty(self):
+ def test_notify_handler_empty(self):
self.start_handler()
self.notify_handler(stop=True)
self.assertEqual(len(self.sent), 1)
self.receive()
- def testNotifyQueryEmpty(self):
+ def test_notify_query_empty(self):
self.start_handler()
self.notify_query(stop=True)
self.assertEqual(len(self.sent), 1)
self.receive()
- def testNotifyHandlerOnce(self):
+ def test_notify_handler_once(self):
self.start_handler()
self.notify_handler()
self.assertEqual(len(self.sent), 1)
self.receive()
self.receive(stop=True)
- def testNotifyQueryOnce(self):
+ def test_notify_query_once(self):
self.start_handler()
self.notify_query()
self.receive()
self.notify_query(stop=True)
self.receive()
- def testNotifyWithArgs(self):
+ def test_notify_with_args(self):
arg_dict = {'test': 42, 'more': 43, 'less': 41}
self.start_handler('test_args', arg_dict)
self.notify_query()
self.receive(stop=True)
- def testNotifySeveralTimes(self):
+ def test_notify_several_times(self):
arg_dict = {'test': 1}
self.start_handler(arg_dict=arg_dict)
- for count in range(3):
+ for _n in range(3):
self.notify_query()
self.receive()
arg_dict['test'] += 1
- for count in range(2):
+ for _n in range(2):
self.notify_handler()
self.receive()
arg_dict['test'] += 1
- for count in range(3):
+ for _n in range(3):
self.notify_query()
self.receive(stop=True)
- def testNotifyOnceWithPayload(self):
+ def test_notify_once_with_payload(self):
self.start_handler()
self.notify_query(payload='test_payload')
self.receive(stop=True)
- def testNotifyWithArgsAndPayload(self):
+ def test_notify_with_args_and_payload(self):
self.start_handler(arg_dict={'foo': 'bar'})
self.notify_query(payload='baz')
self.receive(stop=True)
- def testNotifyQuotedNames(self):
+ def test_notify_quoted_names(self):
self.start_handler('Hello, World!')
self.notify_query(payload='How do you do?')
self.receive(stop=True)
- def testNotifyWithFivePayloads(self):
+ def test_notify_with_five_payloads(self):
self.start_handler('gimme_5', {'test': 'Gimme 5'})
- for count in range(5):
- self.notify_query(payload="Round %d" % count)
+ for n in range(5):
+ self.notify_query(payload=f"Round {n}")
self.assertEqual(len(self.sent), 5)
self.receive(stop=True)
- def testReceiveImmediately(self):
+ def test_receive_immediately(self):
self.start_handler('immediate', {'test': 'immediate'})
- for count in range(3):
- self.notify_query(payload="Round %d" % count)
+ for n in range(3):
+ self.notify_query(payload=f"Round {n}")
self.receive()
self.receive(stop=True)
- def testNotifyDistinctInTransaction(self):
+ def test_notify_distinct_in_transaction(self):
self.start_handler('test_transaction', {'transaction': True})
self.db.begin()
- for count in range(3):
- self.notify_query(payload='Round %d' % count)
+ for n in range(3):
+ self.notify_query(payload=f'Round {n}')
self.db.commit()
self.receive(stop=True)
- def testNotifySameInTransaction(self):
+ def test_notify_same_in_transaction(self):
self.start_handler('test_transaction', {'transaction': True})
self.db.begin()
- for count in
range(3): + for _n in range(3): self.notify_query() self.db.commit() # these same notifications may be delivered as one, @@ -416,7 +369,8 @@ def testNotifySameInTransaction(self): self.sent = self.sent[:1] self.receive(stop=True) - def testNotifyNoTimeout(self): + def test_notify_no_timeout(self): + # noinspection PyTypeChecker self.start_handler(timeout=None) self.assertIsNone(self.handler.timeout) self.assertTrue(self.handler.listening) @@ -424,20 +378,21 @@ def testNotifyNoTimeout(self): self.assertFalse(self.timeout) self.receive(stop=True) - def testNotifyZeroTimeout(self): + def test_notify_zero_timeout(self): self.start_handler(timeout=0) self.assertEqual(self.handler.timeout, 0) self.assertTrue(self.handler.listening) self.assertFalse(self.timeout) - def testNotifyWithoutTimeout(self): + def test_notify_without_timeout(self): self.start_handler(timeout=1) self.assertEqual(self.handler.timeout, 1) sleep(0.02) self.assertFalse(self.timeout) self.receive(stop=True) - def testNotifyWithTimeout(self): + def test_notify_with_timeout(self): + # noinspection PyTypeChecker self.start_handler(timeout=0.01) sleep(0.02) self.assertTrue(self.timeout) diff --git a/tests/test_dbapi20.py b/tests/test_dbapi20.py index 81f5c73e..0e70e073 100755 --- a/tests/test_dbapi20.py +++ b/tests/test_dbapi20.py @@ -1,47 +1,18 @@ #!/usr/bin/python -# -*- coding: utf-8 -*- -try: - import unittest2 as unittest # for Python < 2.7 -except ImportError: - import unittest - -import pgdb - -try: - from . import dbapi20 -except (ImportError, ValueError, SystemError): - import dbapi20 - -# We need a database to test against. -# If LOCAL_PyGreSQL.py exists we will get our information from that. -# Otherwise we use the defaults. -dbname = 'dbapi20_test' -dbhost = '' -dbport = 5432 -try: - from .LOCAL_PyGreSQL import * -except (ImportError, ValueError): - try: - from LOCAL_PyGreSQL import * - except ImportError: - pass +from __future__ import annotations import gc -import sys +import unittest +from datetime import date, datetime, time, timedelta, timezone +from decimal import Decimal +from typing import Any, ClassVar +from uuid import UUID as Uuid # noqa: N811 -from datetime import date, time, datetime, timedelta -from uuid import UUID as Uuid - -try: # noinspection PyUnresolvedReferences - long -except NameError: # Python >= 3.0 - long = int +import pgdb -try: - from collections import OrderedDict -except ImportError: # Python 2.6 or 3.0 - OrderedDict = None +from . 
import dbapi20 +from .config import dbhost, dbname, dbpasswd, dbport, dbuser class PgBitString: @@ -51,21 +22,21 @@ def __init__(self, value): self.value = value def __pg_repr__(self): - return "B'{0:b}'".format(self.value) + return f"B'{self.value:b}'" -class test_PyGreSQL(dbapi20.DatabaseAPI20Test): +class TestPgDb(dbapi20.DatabaseAPI20Test): driver = pgdb connect_args = () - connect_kw_args = {'database': dbname, - 'host': '%s:%d' % (dbhost or '', dbport or -1)} + connect_kw_args: ClassVar[dict[str, Any]] = { + 'database': dbname, 'host': f"{dbhost or ''}:{dbport or -1}", + 'user': dbuser, 'password': dbpasswd} lower_func = 'lower' # For stored procedure test def setUp(self): - # Call superclass setUp in case this does something in the future - dbapi20.DatabaseAPI20Test.setUp(self) + super().setUp() try: con = self._connect() con.close() @@ -73,13 +44,13 @@ def setUp(self): import pg try: # first try to log in as superuser db = pg.DB('postgres', dbhost or None, dbport or -1, - user='postgres') + user='postgres') except Exception: # then try to log in as current user db = pg.DB('postgres', dbhost or None, dbport or -1) db.query('create database ' + dbname) def tearDown(self): - dbapi20.DatabaseAPI20Test.tearDown(self) + super().tearDown() def test_version(self): v = pgdb.version @@ -93,9 +64,18 @@ def test_connect_kwargs(self): con = self._connect() cur = con.cursor() cur.execute("select application_name from pg_stat_activity" - " where application_name = %s", (application_name,)) + " where application_name = %s", (application_name,)) self.assertEqual(cur.fetchone(), (application_name,)) + def test_connect_kwargs_with_special_chars(self): + special_name = 'Single \' and double " quote and \\ backslash!' + self.connect_kw_args['application_name'] = special_name + con = self._connect() + cur = con.cursor() + cur.execute("select application_name from pg_stat_activity" + " where application_name = %s", (special_name,)) + self.assertEqual(cur.fetchone(), (special_name,)) + def test_percent_sign(self): con = self._connect() cur = con.cursor() @@ -106,6 +86,19 @@ def test_percent_sign(self): cur.execute("select 'a %% sign'") self.assertEqual(cur.fetchone(), ('a % sign',)) + def test_paramstyles(self): + self.assertEqual(pgdb.paramstyle, 'pyformat') + con = self._connect() + cur = con.cursor() + # parameters can be passed as tuple + cur.execute("select %s, %s, %s", (123, 'abc', True)) + self.assertEqual(cur.fetchone(), (123, 'abc', True)) + # parameters can be passed as dict + cur.execute("select %(one)s, %(two)s, %(one)s, %(three)s", { + "one": 123, "two": "abc", "three": True + }) + self.assertEqual(cur.fetchone(), (123, 'abc', 123, True)) + def test_callproc_no_params(self): con = self._connect() cur = con.cursor() @@ -144,7 +137,9 @@ def test_callproc_two_params(self): def test_cursor_type(self): class TestCursor(pgdb.Cursor): - pass + @staticmethod + def row_factory(row): + return row # not used con = self._connect() self.assertIs(con.cursor_type, pgdb.Cursor) @@ -166,9 +161,11 @@ def test_row_factory(self): class TestCursor(pgdb.Cursor): - def row_factory(self, row): - return dict(('column %s' % desc[0], value) - for desc, value in zip(self.description, row)) + def row_factory(self, row): # type: ignore[override] + description = self.description + assert isinstance(description, list) + return {f'column {desc[0]}': value + for desc, value in zip(description, row)} con = self._connect() con.cursor_type = TestCursor @@ -190,12 +187,15 @@ def row_factory(self, row): def 
test_build_row_factory(self): + # noinspection PyAbstractClass class TestCursor(pgdb.Cursor): def build_row_factory(self): - keys = [desc[0] for desc in self.description] - return lambda row: dict((key, value) - for key, value in zip(keys, row)) + description = self.description + assert isinstance(description, list) + keys = [desc[0] for desc in description] + return lambda row: { + key: value for key, value in zip(keys, row)} con = self._connect() con.cursor_type = TestCursor @@ -214,6 +214,7 @@ def build_row_factory(self): self.assertIsInstance(res[1], dict) self.assertEqual(res[1], {'a': 3, 'b': 4}) + # noinspection PyUnresolvedReferences def test_cursor_with_named_columns(self): con = self._connect() cur = con.cursor() @@ -237,6 +238,7 @@ def test_cursor_with_named_columns(self): self.assertEqual(res[1], (3, 4)) self.assertEqual(res[1]._fields, ('one', 'two')) + # noinspection PyUnresolvedReferences def test_cursor_with_unnamed_columns(self): con = self._connect() cur = con.cursor() @@ -244,21 +246,14 @@ def test_cursor_with_unnamed_columns(self): res = cur.fetchone() self.assertIsInstance(res, tuple) self.assertEqual(res, (1, 2, 3)) - old_py = OrderedDict is None # Python 2.6 or 3.0 - # old Python versions cannot rename tuple fields with underscore - if old_py: - self.assertEqual(res._fields, ('column_0', 'column_1', 'column_2')) - else: - self.assertEqual(res._fields, ('_0', '_1', '_2')) + self.assertEqual(res._fields, ('_0', '_1', '_2')) cur.execute("select 1 as one, 2, 3 as three") res = cur.fetchone() self.assertIsInstance(res, tuple) self.assertEqual(res, (1, 2, 3)) - if old_py: # cannot auto rename with underscore - self.assertEqual(res._fields, ('one', 'column_1', 'three')) - else: - self.assertEqual(res._fields, ('one', '_1', 'three')) + self.assertEqual(res._fields, ('one', '_1', 'three')) + # noinspection PyUnresolvedReferences def test_cursor_with_badly_named_columns(self): con = self._connect() cur = con.cursor() @@ -266,21 +261,15 @@ def test_cursor_with_badly_named_columns(self): res = cur.fetchone() self.assertIsInstance(res, tuple) self.assertEqual(res, (1, 2)) - old_py = OrderedDict is None # Python 2.6 or 3.0 - if old_py: - self.assertEqual(res._fields, ('abc', 'column_1')) - else: - self.assertEqual(res._fields, ('abc', '_1')) - cur.execute('select 1 as snake_case, 2 as "CamelCase",' + self.assertEqual(res._fields, ('abc', '_1')) + cur.execute( + 'select 1 as snake_case, 2 as "CamelCase",' ' 3 as "kebap-case", 4 as "_bad", 5 as "0bad", 6 as "bad$"') res = cur.fetchone() self.assertIsInstance(res, tuple) self.assertEqual(res, (1, 2, 3, 4, 5, 6)) - # old Python versions cannot rename tuple fields with underscore self.assertEqual(res._fields[:2], ('snake_case', 'CamelCase')) fields = ('_2', '_3', '_4', '_5') - if old_py: - fields = tuple('column' + field for field in fields) self.assertEqual(res._fields[2:], fields) def test_colnames(self): @@ -303,12 +292,13 @@ def test_coltypes(self): self.assertIsInstance(types, list) self.assertEqual(types, ['int2', 'int4', 'int8']) + # noinspection PyUnresolvedReferences def test_description_fields(self): con = self._connect() cur = con.cursor() cur.execute("select 123456789::int8 col0," - " 123456.789::numeric(41, 13) as col1," - " 'foobar'::char(39) as col2") + " 123456.789::numeric(41, 13) as col1," + " 'foobar'::char(39) as col2") desc = cur.description self.assertIsInstance(desc, list) self.assertEqual(len(desc), 3) @@ -318,7 +308,7 @@ def test_description_fields(self): self.assertIsInstance(d, tuple) self.assertEqual(len(d), 
7) self.assertIsInstance(d.name, str) - self.assertEqual(d.name, 'col%d' % i) + self.assertEqual(d.name, f'col{i}') self.assertIsInstance(d.type_code, str) self.assertEqual(d.type_code, c[0]) self.assertIsNone(d.display_size) @@ -394,7 +384,7 @@ def test_type_cache_typecast(self): cur = con.cursor() type_cache = con.type_cache self.assertIs(type_cache.get_typecast('int4'), int) - cast_int = lambda v: 'int(%s)' % v + cast_int = lambda v: f'int({v})' # noqa: E731 type_cache.set_typecast('int4', cast_int) query = 'select 2::int2, 4::int4, 8::int8' cur.execute(query) @@ -454,18 +444,18 @@ def test_cursor_invalidation(self): self.assertRaises(pgdb.OperationalError, cur.fetchone) def test_fetch_2_rows(self): - Decimal = pgdb.decimal_type() values = ('test', pgdb.Binary(b'\xff\x52\xb2'), - True, 5, 6, 5.7, Decimal('234.234234'), Decimal('75.45'), - pgdb.Date(2011, 7, 17), pgdb.Time(15, 47, 42), - pgdb.Timestamp(2008, 10, 20, 15, 25, 35), - pgdb.Interval(15, 31, 5), 7897234) + True, 5, 6, 5.7, Decimal('234.234234'), Decimal('75.45'), + pgdb.Date(2011, 7, 17), pgdb.Time(15, 47, 42), + pgdb.Timestamp(2008, 10, 20, 15, 25, 35), + pgdb.Interval(15, 31, 5), 7897234) table = self.table_prefix + 'booze' con = self._connect() try: cur = con.cursor() cur.execute("set datestyle to iso") - cur.execute("create table %s (" + cur.execute( + f"create table {table} (" "stringtest varchar," "binarytest bytea," "booltest bool," @@ -478,15 +468,16 @@ def test_fetch_2_rows(self): "timetest time," "datetimetest timestamp," "intervaltest interval," - "rowidtest oid)" % table) + "rowidtest oid)") cur.execute("set standard_conforming_strings to on") for s in ('numeric', 'monetary', 'time'): - cur.execute("set lc_%s to 'C'" % s) + cur.execute(f"set lc_{s} to 'C'") for _i in range(2): - cur.execute("insert into %s values (" - "%%s,%%s,%%s,%%s,%%s,%%s,%%s," - "'%%s'::money,%%s,%%s,%%s,%%s,%%s)" % table, values) - cur.execute("select * from %s" % table) + cur.execute( + f"insert into {table} values (" + "%s,%s,%s,%s,%s,%s,%s," + "'%s'::money,%s,%s,%s,%s,%s)", values) + cur.execute(f"select * from {table}") rows = cur.fetchall() self.assertEqual(len(rows), 2) row0 = rows[0] @@ -496,7 +487,7 @@ def test_fetch_2_rows(self): self.assertIsInstance(row0[1], bytes) self.assertIsInstance(row0[2], bool) self.assertIsInstance(row0[3], int) - self.assertIsInstance(row0[4], long) + self.assertIsInstance(row0[4], int) self.assertIsInstance(row0[5], float) self.assertIsInstance(row0[6], Decimal) self.assertIsInstance(row0[7], Decimal) @@ -513,11 +504,12 @@ def test_integrity_error(self): try: cur = con.cursor() cur.execute("set client_min_messages = warning") - cur.execute("create table %s (i int primary key)" % table) - cur.execute("insert into %s values (1)" % table) - cur.execute("insert into %s values (2)" % table) - self.assertRaises(pgdb.IntegrityError, cur.execute, - "insert into %s values (1)" % table) + cur.execute(f"create table {table} (i int primary key)") + cur.execute(f"insert into {table} values (1)") + cur.execute(f"insert into {table} values (2)") + self.assertRaises( + pgdb.IntegrityError, cur.execute, + f"insert into {table} values (1)") finally: con.close() @@ -526,11 +518,11 @@ def test_update_rowcount(self): con = self._connect() try: cur = con.cursor() - cur.execute("create table %s (i int)" % table) - cur.execute("insert into %s values (1)" % table) - cur.execute("update %s set i=2 where i=2 returning i" % table) + cur.execute(f"create table {table} (i int)") + cur.execute(f"insert into {table} values (1)") + 
cur.execute(f"update {table} set i=2 where i=2 returning i") self.assertEqual(cur.rowcount, 0) - cur.execute("update %s set i=2 where i=1 returning i" % table) + cur.execute(f"update {table} set i=2 where i=1 returning i") self.assertEqual(cur.rowcount, 1) cur.close() # keep rowcount even if cursor is closed (needed by SQLAlchemy) @@ -544,26 +536,27 @@ def test_sqlstate(self): try: cur.execute("select 1/0") except pgdb.DatabaseError as error: - self.assertTrue(isinstance(error, pgdb.DataError)) + self.assertIsInstance(error, pgdb.DataError) # the SQLSTATE error code for division by zero is 22012 + # noinspection PyUnresolvedReferences self.assertEqual(error.sqlstate, '22012') def test_float(self): nan, inf = float('nan'), float('inf') - from math import isnan, isinf + from math import isinf, isnan self.assertTrue(isnan(nan) and not isinf(nan)) self.assertTrue(isinf(inf) and not isnan(inf)) values = [0, 1, 0.03125, -42.53125, nan, inf, -inf, - 'nan', 'inf', '-inf', 'NaN', 'Infinity', '-Infinity'] + 'nan', 'inf', '-inf', 'NaN', 'Infinity', '-Infinity'] table = self.table_prefix + 'booze' con = self._connect() try: cur = con.cursor() cur.execute( - "create table %s (n smallint, floattest float)" % table) + f"create table {table} (n smallint, floattest float)") params = enumerate(values) - cur.executemany("insert into %s values (%%d,%%s)" % table, params) - cur.execute("select floattest from %s order by n" % table) + cur.executemany(f"insert into {table} values (%d,%s)", params) + cur.execute(f"select floattest from {table} order by n") rows = cur.fetchall() self.assertEqual(cur.description[0].type_code, pgdb.FLOAT) self.assertNotEqual(cur.description[0].type_code, pgdb.ARRAY) @@ -579,59 +572,68 @@ def test_float(self): inval = -inf elif inval in ('nan', 'NaN'): inval = nan - if isinf(inval): + if isinf(inval): # type: ignore self.assertTrue(isinf(outval)) - if inval < 0: - self.assertTrue(outval < 0) + if inval < 0: # type: ignore + self.assertLess(outval, 0) else: - self.assertTrue(outval > 0) - elif isnan(inval): + self.assertGreater(outval, 0) + elif isnan(inval): # type: ignore self.assertTrue(isnan(outval)) else: self.assertEqual(inval, outval) def test_datetime(self): dt = datetime(2011, 7, 17, 15, 47, 42, 317509) + values = [dt.date(), dt.time(), dt, dt.time(), dt] + self.assertIsInstance(values[3], time) + assert isinstance(values[3], time) # type guard + values[3] = values[3].replace(tzinfo=timezone.utc) + self.assertIsInstance(values[4], datetime) + assert isinstance(values[4], datetime) # type guard + values[4] = values[4].replace(tzinfo=timezone.utc) + da = (dt.year, dt.month, dt.day) + ti = (dt.hour, dt.minute, dt.second, dt.microsecond) + tz = (timezone.utc,) + inputs = [ + # input as objects + values, + # input as text + [v.isoformat() for v in values], # type: ignore + # # input using type helpers + [pgdb.Date(*da), pgdb.Time(*ti), + pgdb.Timestamp(*(da + ti)), pgdb.Time(*(ti + tz)), + pgdb.Timestamp(*(da + ti + tz))] + ] table = self.table_prefix + 'booze' - con = self._connect() + con: pgdb.Connection = self._connect() try: cur = con.cursor() cur.execute("set timezone = UTC") - cur.execute("create table %s (" - "d date, t time, ts timestamp," - "tz timetz, tsz timestamptz)" % table) - for n in range(3): - values = [dt.date(), dt.time(), dt, - dt.time(), dt] - values[3] = values[3].replace(tzinfo=pgdb.timezone.utc) - values[4] = values[4].replace(tzinfo=pgdb.timezone.utc) - if n == 0: # input as objects - params = values - if n == 1: # input as text - params = [v.isoformat() 
for v in values] # as text - elif n == 2: # input using type helpers - d = (dt.year, dt.month, dt.day) - t = (dt.hour, dt.minute, dt.second, dt.microsecond) - z = (pgdb.timezone.utc,) - params = [pgdb.Date(*d), pgdb.Time(*t), - pgdb.Timestamp(*(d + t)), pgdb.Time(*(t + z)), - pgdb.Timestamp(*(d + t + z))] + cur.execute(f"create table {table} (" + "d date, t time, ts timestamp," + "tz timetz, tsz timestamptz)") + for params in inputs: for datestyle in ('iso', 'postgres, mdy', 'postgres, dmy', - 'sql, mdy', 'sql, dmy', 'german'): - cur.execute("set datestyle to %s" % datestyle) - if n != 1: + 'sql, mdy', 'sql, dmy', 'german'): + cur.execute(f"set datestyle to {datestyle}") + if not isinstance(params[0], str): cur.execute("select %s,%s,%s,%s,%s", params) row = cur.fetchone() self.assertEqual(row, tuple(values)) - cur.execute("insert into %s" - " values (%%s,%%s,%%s,%%s,%%s)" % table, params) - cur.execute("select * from %s" % table) + cur.execute( + f"insert into {table}" + " values (%s,%s,%s,%s,%s)", params) + cur.execute(f"select * from {table}") d = cur.description + self.assertIsInstance(d, list) + assert d is not None # type guard for i in range(5): - self.assertEqual(d[i].type_code, pgdb.DATETIME) - self.assertNotEqual(d[i].type_code, pgdb.STRING) - self.assertNotEqual(d[i].type_code, pgdb.ARRAY) - self.assertNotEqual(d[i].type_code, pgdb.RECORD) + tc = d[i].type_code + self.assertEqual(tc, pgdb.DATETIME) + self.assertNotEqual(tc, pgdb.STRING) + self.assertNotEqual(tc, pgdb.ARRAY) + self.assertNotEqual(tc, pgdb.RECORD) self.assertEqual(d[0].type_code, pgdb.DATE) self.assertEqual(d[1].type_code, pgdb.TIME) self.assertEqual(d[2].type_code, pgdb.TIMESTAMP) @@ -639,32 +641,32 @@ def test_datetime(self): self.assertEqual(d[4].type_code, pgdb.TIMESTAMP) row = cur.fetchone() self.assertEqual(row, tuple(values)) - cur.execute("delete from %s" % table) + cur.execute(f"truncate table {table}") finally: con.close() def test_interval(self): td = datetime(2011, 7, 17, 15, 47, 42, 317509) - datetime(1970, 1, 1) + inputs = [ + # input as objects + td, + # input as text + f'{td.days} days {td.seconds} seconds' + f' {td.microseconds} microseconds', + # input using type helpers + pgdb.Interval(td.days, 0, 0, td.seconds, td.microseconds)] table = self.table_prefix + 'booze' con = self._connect() try: cur = con.cursor() - cur.execute("create table %s (i interval)" % table) - for n in range(3): - if n == 0: # input as objects - param = td - if n == 1: # input as text - param = '%d days %d seconds %d microseconds ' % ( - td.days, td.seconds, td.microseconds) - elif n == 2: # input using type helpers - param = pgdb.Interval( - td.days, 0, 0, td.seconds, td.microseconds) + cur.execute(f"create table {table} (i interval)") + for param in inputs: for intervalstyle in ('sql_standard ', 'postgres', - 'postgres_verbose', 'iso_8601'): - cur.execute("set intervalstyle to %s" % intervalstyle) - cur.execute("insert into %s" - " values (%%s)" % table, [param]) - cur.execute("select * from %s" % table) + 'postgres_verbose', 'iso_8601'): + cur.execute(f"set intervalstyle to {intervalstyle}") + # noinspection PyUnboundLocalVariable + cur.execute(f"insert into {table} values (%s)", [param]) + cur.execute(f"select * from {table}") tc = cur.description[0].type_code self.assertEqual(tc, pgdb.DATETIME) self.assertNotEqual(tc, pgdb.STRING) @@ -673,14 +675,14 @@ def test_interval(self): self.assertEqual(tc, pgdb.INTERVAL) row = cur.fetchone() self.assertEqual(row, (td,)) - cur.execute("delete from %s" % table) + 
cur.execute(f"truncate table {table}") finally: con.close() def test_hstore(self): con = self._connect() + cur = con.cursor() try: - cur = con.cursor() cur.execute("select 'k=>v'::hstore") except pgdb.DatabaseError: try: @@ -690,9 +692,9 @@ def test_hstore(self): finally: con.close() d = {'k': 'v', 'foo': 'bar', 'baz': 'whatever', 'back\\': '\\slash', - '1a': 'anything at all', '2=b': 'value = 2', '3>c': 'value > 3', - '4"c': 'value " 4', "5'c": "value ' 5", 'hello, world': '"hi!"', - 'None': None, 'NULL': 'NULL', 'empty': ''} + '1a': 'anything at all', '2=b': 'value = 2', '3>c': 'value > 3', + '4"c': 'value " 4', "5'c": "value ' 5", 'hello, world': '"hi!"', + 'None': None, 'NULL': 'NULL', 'empty': ''} con = self._connect() try: cur = con.cursor() @@ -717,23 +719,25 @@ def test_uuid(self): self.assertEqual(result, d) def test_insert_array(self): - values = [(None, None), ([], []), ([None], [[None], ['null']]), + values: list[tuple[Any, Any]] = [ + (None, None), ([], []), ([None], [[None], ['null']]), ([1, 2, 3], [['a', 'b'], ['c', 'd']]), ([20000, 25000, 25000, 30000], - [['breakfast', 'consulting'], ['meeting', 'lunch']]), + [['breakfast', 'consulting'], ['meeting', 'lunch']]), ([0, 1, -1], [['Hello, World!', '"Hi!"'], ['{x,y}', ' x y ']])] table = self.table_prefix + 'booze' con = self._connect() try: cur = con.cursor() - cur.execute("create table %s" - " (n smallint, i int[], t text[][])" % table) + cur.execute( + f"create table {table} (n smallint, i int[], t text[][])") params = [(n, v[0], v[1]) for n, v in enumerate(values)] # Note that we must explicit casts because we are inserting # empty arrays. Otherwise this is not necessary. - cur.executemany("insert into %s values" - " (%%d,%%s::int[],%%s::text[][])" % table, params) - cur.execute("select i, t from %s order by n" % table) + cur.executemany( + f"insert into {table} values" + " (%d,%s::int[],%s::text[][])", params) + cur.execute(f"select i, t from {table} order by n") d = cur.description self.assertEqual(d[0].type_code, pgdb.ARRAY) self.assertNotEqual(d[0].type_code, pgdb.RECORD) @@ -759,7 +763,7 @@ def test_select_array(self): self.assertEqual(row, values) def test_unicode_list_and_tuple(self): - value = (u'Käse', u'Würstchen') + value = ('Käse', 'Würstchen') con = self._connect() try: cur = con.cursor() @@ -782,13 +786,13 @@ def test_insert_record(self): table = self.table_prefix + 'booze' record = self.table_prefix + 'munch' con = self._connect() + cur = con.cursor() try: - cur = con.cursor() - cur.execute("create type %s as (name varchar, age int)" % record) - cur.execute("create table %s (n smallint, r %s)" % (table, record)) + cur.execute(f"create type {record} as (name varchar, age int)") + cur.execute(f"create table {table} (n smallint, r {record})") params = enumerate(values) - cur.executemany("insert into %s values (%%d,%%s)" % table, params) - cur.execute("select r from %s order by n" % table) + cur.executemany(f"insert into {table} values (%d,%s)", params) + cur.execute(f"select r from {table} order by n") type_code = cur.description[0].type_code self.assertEqual(type_code, record) self.assertEqual(type_code, pgdb.RECORD) @@ -800,8 +804,8 @@ def test_insert_record(self): self.assertEqual(con.type_cache[columns[1].type], 'int4') rows = cur.fetchall() finally: - cur.execute('drop table %s' % table) - cur.execute('drop type %s' % record) + cur.execute(f'drop table {table}') + cur.execute(f'drop type {record}') con.close() self.assertEqual(len(rows), len(values)) rows = [row[0] for row in rows] @@ -811,7 +815,7 @@ def 
@@ -811,7 +815,7 @@ def test_insert_record(self):

     def test_select_record(self):
         value = (1, 25000, 2.5, 'hello', 'Hello World!', 'Hello, World!',
-                 '(test)', '(x,y)', ' x y ', 'null', None)
+                 '(test)', '(x,y)', ' x y ', 'null', None)
         con = self._connect()
         try:
             cur = con.cursor()
@@ -829,16 +833,16 @@ def test_select_record(self):

     def test_custom_type(self):
         values = [3, 5, 65]
-        values = list(map(PgBitString, values))
+        values = list(map(PgBitString, values))  # type: ignore
         table = self.table_prefix + 'booze'
         con = self._connect()
         try:
             cur = con.cursor()
-            params = enumerate(values)  # params have __pg_repr__ method
+            seq_params = enumerate(values)  # params have __pg_repr__ method
             cur.execute(
-                'create table "%s" (n smallint, b bit varying(7))' % table)
-            cur.executemany("insert into %s values (%%s,%%s)" % table, params)
-            cur.execute("select * from %s" % table)
+                f'create table "{table}" (n smallint, b bit varying(7))')
+            cur.executemany(f"insert into {table} values (%s,%s)", seq_params)
+            cur.execute(f"select * from {table}")
             rows = cur.fetchall()
         finally:
             con.close()
@@ -847,50 +851,60 @@ def test_custom_type(self):
         try:
             cur = con.cursor()
             params = (1, object())  # an object that cannot be handled
-            self.assertRaises(pgdb.InterfaceError, cur.execute,
-                              "insert into %s values (%%s,%%s)" % table, params)
+            self.assertRaises(
+                pgdb.InterfaceError, cur.execute,
+                f"insert into {table} values (%s,%s)", params)
         finally:
             con.close()

     def test_set_decimal_type(self):
-        decimal_type = pgdb.decimal_type()
-        self.assertTrue(decimal_type is not None and callable(decimal_type))
+        from pgdb.cast import decimal_type
+        self.assertIs(decimal_type(), Decimal)
         con = self._connect()
         try:
             cur = con.cursor()
             # change decimal type globally to int
-            int_type = lambda v: int(float(v))
-            self.assertTrue(pgdb.decimal_type(int_type) is int_type)
+
+            class CustomDecimal(str):
+
+                def __init__(self, value: Any) -> None:
+                    self.value = value
+
+                def __str__(self) -> str:
+                    return str(self.value).replace('.', ',')
+
+            self.assertIs(decimal_type(CustomDecimal), CustomDecimal)
             cur.execute('select 4.25')
             self.assertEqual(cur.description[0].type_code, pgdb.NUMBER)
             value = cur.fetchone()[0]
-            self.assertTrue(isinstance(value, int))
-            self.assertEqual(value, 4)
+            self.assertIsInstance(value, CustomDecimal)
+            self.assertEqual(str(value), '4,25')
             # change decimal type again to float
-            self.assertTrue(pgdb.decimal_type(float) is float)
+            self.assertIs(decimal_type(float), float)
             cur.execute('select 4.25')
             self.assertEqual(cur.description[0].type_code, pgdb.NUMBER)
             value = cur.fetchone()[0]
             # the connection still uses the old setting
-            self.assertTrue(isinstance(value, int))
+            self.assertIsInstance(value, str)
+            self.assertEqual(str(value), '4,25')
             # bust the cache for type functions for the connection
             con.type_cache.reset_typecast()
             cur.execute('select 4.25')
             self.assertEqual(cur.description[0].type_code, pgdb.NUMBER)
             value = cur.fetchone()[0]
             # now the connection uses the new setting
-            self.assertTrue(isinstance(value, float))
+            self.assertIsInstance(value, float)
             self.assertEqual(value, 4.25)
         finally:
             con.close()
-            pgdb.decimal_type(decimal_type)
-            self.assertTrue(pgdb.decimal_type() is decimal_type)
+            decimal_type(Decimal)
+            self.assertIs(decimal_type(), Decimal)

     def test_global_typecast(self):
         try:
             query = 'select 2::int2, 4::int4, 8::int8'
             self.assertIs(pgdb.get_typecast('int4'), int)
-            cast_int = lambda v: 'int(%s)' % v
+            cast_int = lambda v: f'int({v})'  # noqa: E731
             pgdb.set_typecast('int4', cast_int)
             con = self._connect()
             try:
@@ -977,32 +991,30 @@ def test_set_typecast_for_arrays(self):

     def test_unicode_with_utf8(self):
         table = self.table_prefix + 'booze'
-        input = u"He wes Leovenaðes sone — liðe him be Drihten"
+        s = "He wes Leovenaðes sone — liðe him be Drihten"
         con = self._connect()
+        cur = con.cursor()
         try:
-            cur = con.cursor()
-            cur.execute("create table %s (t text)" % table)
+            cur.execute(f"create table {table} (t text)")
             try:
                 cur.execute("set client_encoding=utf8")
-                cur.execute(u"select '%s'" % input)
+                cur.execute(f"select '{s}'")
             except Exception:
                 self.skipTest("database does not support utf8")
             output1 = cur.fetchone()[0]
-            cur.execute("insert into %s values (%%s)" % table, (input,))
-            cur.execute("select * from %s" % table)
+            cur.execute(f"insert into {table} values (%s)", (s,))
+            cur.execute(f"select * from {table}")
             output2 = cur.fetchone()[0]
-            cur.execute("select t = '%s' from %s" % (input, table))
+            cur.execute(f"select t = '{s}' from {table}")
             output3 = cur.fetchone()[0]
-            cur.execute("select t = %%s from %s" % table, (input,))
+            cur.execute(f"select t = %s from {table}", (s,))
             output4 = cur.fetchone()[0]
         finally:
             con.close()
-        if str is bytes:  # Python < 3.0
-            input = input.encode('utf8')
         self.assertIsInstance(output1, str)
-        self.assertEqual(output1, input)
+        self.assertEqual(output1, s)
         self.assertIsInstance(output2, str)
-        self.assertEqual(output2, input)
+        self.assertEqual(output2, s)
         self.assertIsInstance(output3, bool)
         self.assertTrue(output3)
         self.assertIsInstance(output4, bool)
@@ -1010,32 +1022,30 @@ def test_unicode_with_utf8(self):

     def test_unicode_with_latin1(self):
         table = self.table_prefix + 'booze'
-        input = u"Ehrt den König seine Würde, ehret uns der Hände Fleiß."
+        s = "Ehrt den König seine Würde, ehret uns der Hände Fleiß."
         con = self._connect()
         try:
             cur = con.cursor()
-            cur.execute("create table %s (t text)" % table)
+            cur.execute(f"create table {table} (t text)")
             try:
                 cur.execute("set client_encoding=latin1")
-                cur.execute(u"select '%s'" % input)
+                cur.execute(f"select '{s}'")
             except Exception:
                 self.skipTest("database does not support latin1")
             output1 = cur.fetchone()[0]
-            cur.execute("insert into %s values (%%s)" % table, (input,))
-            cur.execute("select * from %s" % table)
+            cur.execute(f"insert into {table} values (%s)", (s,))
+            cur.execute(f"select * from {table}")
             output2 = cur.fetchone()[0]
-            cur.execute("select t = '%s' from %s" % (input, table))
+            cur.execute(f"select t = '{s}' from {table}")
             output3 = cur.fetchone()[0]
-            cur.execute("select t = %%s from %s" % table, (input,))
+            cur.execute(f"select t = %s from {table}", (s,))
             output4 = cur.fetchone()[0]
         finally:
             con.close()
-        if str is bytes:  # Python < 3.0
-            input = input.encode('latin1')
         self.assertIsInstance(output1, str)
-        self.assertEqual(output1, input)
+        self.assertEqual(output1, s)
         self.assertIsInstance(output2, str)
-        self.assertEqual(output2, input)
+        self.assertEqual(output2, s)
         self.assertIsInstance(output3, bool)
         self.assertTrue(output3)
         self.assertIsInstance(output4, bool)
@@ -1047,11 +1057,10 @@ def test_bool(self):
         con = self._connect()
         try:
             cur = con.cursor()
-            cur.execute(
-                "create table %s (n smallint, booltest bool)" % table)
+            cur.execute(f"create table {table} (n smallint, booltest bool)")
             params = enumerate(values)
-            cur.executemany("insert into %s values (%%s,%%s)" % table, params)
-            cur.execute("select booltest from %s order by n" % table)
+            cur.executemany(f"insert into {table} values (%s,%s)", params)
+            cur.execute(f"select booltest from {table} order by n")
             rows = cur.fetchall()
             self.assertEqual(cur.description[0].type_code, pgdb.BOOL)
         finally:
@@ -1073,19 +1082,19 @@ def test_literal(self):
         self.assertEqual(row, (value, 'hello'))

     def test_json(self):
-        inval = {"employees":
-                 [{"firstName": "John", "lastName": "Doe", "age": 61}]}
+        inval = {"employees": [
+            {"firstName": "John", "lastName": "Doe", "age": 61}]}
         table = self.table_prefix + 'booze'
         con = self._connect()
         try:
             cur = con.cursor()
             try:
-                cur.execute("create table %s (jsontest json)" % table)
+                cur.execute(f"create table {table} (jsontest json)")
             except pgdb.ProgrammingError:
                 self.skipTest('database does not support json')
             params = (pgdb.Json(inval),)
-            cur.execute("insert into %s values (%%s)" % table, params)
-            cur.execute("select jsontest from %s" % table)
+            cur.execute(f"insert into {table} values (%s)", params)
+            cur.execute(f"select jsontest from {table}")
             outval = cur.fetchone()[0]
             self.assertEqual(cur.description[0].type_code, pgdb.JSON)
         finally:
@@ -1093,19 +1102,19 @@ def test_json(self):
         self.assertEqual(inval, outval)

     def test_jsonb(self):
-        inval = {"employees":
-                 [{"firstName": "John", "lastName": "Doe", "age": 61}]}
+        inval = {"employees": [
+            {"firstName": "John", "lastName": "Doe", "age": 61}]}
         table = self.table_prefix + 'booze'
         con = self._connect()
         try:
             cur = con.cursor()
             try:
-                cur.execute("create table %s (jsonbtest jsonb)" % table)
+                cur.execute(f"create table {table} (jsonbtest jsonb)")
             except pgdb.ProgrammingError:
                 self.skipTest('database does not support jsonb')
             params = (pgdb.Json(inval),)
-            cur.execute("insert into %s values (%%s)" % table, params)
-            cur.execute("select jsonbtest from %s" % table)
+            cur.execute(f"insert into {table} values (%s)", params)
+            cur.execute(f"select jsonbtest from {table}")
             outval = cur.fetchone()[0]
             self.assertEqual(cur.description[0].type_code, pgdb.JSON)
         finally:
@@ -1135,6 +1144,30 @@ def test_execute_edge_cases(self):
         sql = 'select 1'  # cannot be executed after connection is closed
         self.assertRaises(pgdb.OperationalError, cur.execute, sql)

+    def test_fetchall_with_various_sizes(self):
+        # we test this because there are optimizations based on result size
+        con = self._connect()
+        try:
+            for n in (1, 3, 5, 7, 10, 100, 1000):
+                cur = con.cursor()
+                try:
+                    cur.execute('select n, n::text as s, n % 2 = 1 as b'
+                                f' from generate_series(1, {n}) as s(n)')
+                    res = cur.fetchall()
+                    self.assertEqual(len(res), n, res)
+                    self.assertEqual(len(res[0]), 3)
+                    self.assertEqual(res[0].n, 1)
+                    self.assertEqual(res[0].s, '1')
+                    self.assertIs(res[0].b, True)
+                    self.assertEqual(len(res[-1]), 3)
+                    self.assertEqual(res[-1].n, n)
+                    self.assertEqual(res[-1].s, str(n))
+                    self.assertIs(res[-1].b, n % 2 == 1)
+                finally:
+                    cur.close()
+        finally:
+            con.close()
+
     def test_fetchmany_with_keep(self):
         con = self._connect()
         try:
@@ -1160,6 +1193,12 @@ def test_fetchmany_with_keep(self):
         finally:
             con.close()

+    def help_nextset_setup(self, _cur):
+        pass  # helper not needed
+
+    def help_nextset_teardown(self, _cur):
+        pass  # helper not needed
+
     def test_nextset(self):
         con = self._connect()
         cur = con.cursor()
@@ -1185,17 +1224,17 @@ def test_transaction(self):
         table = self.table_prefix + 'booze'
         con1 = self._connect()
         cur1 = con1.cursor()
-        self.executeDDL1(cur1)
+        self.execute_ddl1(cur1)
         con1.commit()
         con2 = self._connect()
         cur2 = con2.cursor()
-        cur2.execute("select name from %s" % table)
+        cur2.execute(f"select name from {table}")
         self.assertIsNone(cur2.fetchone())
-        cur1.execute("insert into %s values('Schlafly')" % table)
-        cur2.execute("select name from %s" % table)
cur1.execute(f"insert into {table} values('Schlafly')") + cur2.execute(f"select name from {table}") self.assertIsNone(cur2.fetchone()) con1.commit() - cur2.execute("select name from %s" % table) + cur2.execute(f"select name from {table}") self.assertEqual(cur2.fetchone(), ('Schlafly',)) con2.close() con1.close() @@ -1205,13 +1244,13 @@ def test_autocommit(self): con1 = self._connect() con1.autocommit = True cur1 = con1.cursor() - self.executeDDL1(cur1) + self.execute_ddl1(cur1) con2 = self._connect() cur2 = con2.cursor() - cur2.execute("select name from %s" % table) + cur2.execute(f"select name from {table}") self.assertIsNone(cur2.fetchone()) - cur1.execute("insert into %s values('Shmaltz Pastrami')" % table) - cur2.execute("select name from %s" % table) + cur1.execute(f"insert into {table} values('Shmaltz Pastrami')") + cur2.execute(f"select name from {table}") self.assertEqual(cur2.fetchone(), ('Shmaltz Pastrami',)) con2.close() con1.close() @@ -1224,32 +1263,32 @@ def test_connection_as_contextmanager(self): try: cur = con.cursor() if autocommit: - cur.execute("truncate %s" % table) + cur.execute(f"truncate table {table}") else: cur.execute( - "create table %s (n smallint check(n!=4))" % table) + f"create table {table} (n smallint check(n!=4))") with con: - cur.execute("insert into %s values (1)" % table) - cur.execute("insert into %s values (2)" % table) + cur.execute(f"insert into {table} values (1)") + cur.execute(f"insert into {table} values (2)") try: with con: - cur.execute("insert into %s values (3)" % table) - cur.execute("insert into %s values (4)" % table) + cur.execute(f"insert into {table} values (3)") + cur.execute(f"insert into {table} values (4)") except con.IntegrityError as error: - self.assertTrue('check' in str(error).lower()) + self.assertIn('check', str(error).lower()) with con: - cur.execute("insert into %s values (5)" % table) - cur.execute("insert into %s values (6)" % table) + cur.execute(f"insert into {table} values (5)") + cur.execute(f"insert into {table} values (6)") try: with con: - cur.execute("insert into %s values (7)" % table) - cur.execute("insert into %s values (8)" % table) + cur.execute(f"insert into {table} values (7)") + cur.execute(f"insert into {table} values (8)") raise ValueError('transaction should rollback') except ValueError as error: self.assertEqual(str(error), 'transaction should rollback') with con: - cur.execute("insert into %s values (9)" % table) - cur.execute("select * from %s order by 1" % table) + cur.execute(f"insert into {table} values (9)") + cur.execute(f"select * from {table} order by 1") rows = cur.fetchall() rows = [row[0] for row in rows] finally: @@ -1289,11 +1328,11 @@ def test_pgdb_type(self): self.assertEqual('int8', pgdb.INTEGER) self.assertNotEqual('int4', pgdb.LONG) self.assertEqual('int8', pgdb.LONG) - self.assertTrue('char' in pgdb.STRING) - self.assertTrue(pgdb.NUMERIC <= pgdb.NUMBER) - self.assertTrue(pgdb.NUMBER >= pgdb.INTEGER) - self.assertTrue(pgdb.TIME <= pgdb.DATETIME) - self.assertTrue(pgdb.DATETIME >= pgdb.DATE) + self.assertIn('char', pgdb.STRING) + self.assertLess(pgdb.NUMERIC, pgdb.NUMBER) + self.assertGreaterEqual(pgdb.NUMBER, pgdb.INTEGER) + self.assertLessEqual(pgdb.TIME, pgdb.DATETIME) + self.assertGreaterEqual(pgdb.DATETIME, pgdb.DATE) self.assertEqual(pgdb.ARRAY, pgdb.ARRAY) self.assertNotEqual(pgdb.ARRAY, pgdb.STRING) self.assertEqual('_char', pgdb.ARRAY) @@ -1313,17 +1352,14 @@ def test_no_close(self): row = cur.fetchone() self.assertEqual(row, data) - def test_set_row_factory_size(self): - 
-        try:
-            from functools import lru_cache
-        except ImportError:  # Python < 3.2
-            lru_cache = None
+    def test_change_row_factory_cache_size(self):
+        from pg import RowCache
         queries = ['select 1 as a, 2 as b, 3 as c', 'select 123 as abc']
         con = self._connect()
         cur = con.cursor()
         for maxsize in (None, 0, 1, 2, 3, 10, 1024):
-            pgdb.set_row_factory_size(maxsize)
-            for i in range(3):
+            RowCache.change_size(maxsize)
+            for _i in range(3):
                 for q in queries:
                     cur.execute(q)
                     r = cur.fetchone()
@@ -1333,16 +1369,15 @@ def test_set_row_factory_size(self):
                 else:
                     self.assertEqual(r, (1, 2, 3))
                     self.assertEqual(r._fields, ('a', 'b', 'c'))
-            if lru_cache:
-                info = pgdb._row_factory.cache_info()
-                self.assertEqual(info.maxsize, maxsize)
-                self.assertEqual(info.hits + info.misses, 6)
-                self.assertEqual(info.hits,
-                                 0 if maxsize is not None and maxsize < 2 else 4)
+            info = RowCache.row_factory.cache_info()
+            self.assertEqual(info.maxsize, maxsize)
+            self.assertEqual(info.hits + info.misses, 6)
+            self.assertEqual(info.hits,
+                             0 if maxsize is not None and maxsize < 2 else 4)

     def test_memory_leaks(self):
-        ids = set()
-        objs = []
+        ids: set = set()
+        objs: list = []
         add_ids = ids.update
         gc.collect()
         objs[:] = gc.get_objects()
@@ -1351,9 +1386,6 @@ def test_memory_leaks(self):
         gc.collect()
         objs[:] = gc.get_objects()
         objs[:] = [obj for obj in objs if id(obj) not in ids]
-        if objs and sys.version_info[:3] in ((3, 5, 0), (3, 5, 1)):
-            # workaround for Python issue 26811
-            objs[:] = [obj for obj in objs if repr(obj) != '(,)']
         self.assertEqual(len(objs), 0)

     def test_cve_2018_1058(self):
diff --git a/tests/test_dbapi20_copy.py b/tests/test_dbapi20_copy.py
index 7fdca2c0..02810ba6 100644
--- a/tests/test_dbapi20_copy.py
+++ b/tests/test_dbapi20_copy.py
@@ -1,5 +1,4 @@
 #!/usr/bin/python
-# -*- coding: utf-8 -*-

 """Test the modern PyGreSQL interface.
@@ -10,51 +9,28 @@
 These tests need a database to test against.
 """

-try:
-    import unittest2 as unittest  # for Python < 2.7
-except ImportError:
-    import unittest
+from __future__ import annotations

-try:
-    from collections.abc import Iterable
-except ImportError:  # Python < 3.3
-    from collections import Iterable
+import unittest
+from collections.abc import Iterable
+from contextlib import suppress
+from typing import ClassVar

 import pgdb  # the module under test

-# We need a database to test against. If LOCAL_PyGreSQL.py exists we will
-# get our information from that. Otherwise we use the defaults.
-# The current user must have create schema privilege on the database.
-dbname = 'unittest'
-dbhost = None
-dbport = 5432
-
-try:
-    from .LOCAL_PyGreSQL import *
-except (ImportError, ValueError):
-    try:
-        from LOCAL_PyGreSQL import *
-    except ImportError:
-        pass
-
-try:  # noinspection PyUnresolvedReferences
-    unicode
-except NameError:  # Python >= 3.0
-    unicode = str
+from .config import dbhost, dbname, dbpasswd, dbport, dbuser


 class InputStream:

     def __init__(self, data):
-        if isinstance(data, unicode):
-            data = data.encode('utf-8')
+        if isinstance(data, str):
+            data = data.encode()
         self.data = data or b''
         self.sizes = []

     def __str__(self):
-        data = self.data
-        if str is unicode:  # Python >= 3.0
-            data = data.decode('utf-8')
+        data = self.data.decode()
         return data

     def __len__(self):
@@ -77,17 +53,15 @@ def __init__(self):
         self.sizes = []

     def __str__(self):
-        data = self.data
-        if str is unicode:  # Python >= 3.0
-            data = data.decode('utf-8')
+        data = self.data.decode()
         return data

     def __len__(self):
         return len(self.data)

     def write(self, data):
-        if isinstance(data, unicode):
-            data = data.encode('utf-8')
+        if isinstance(data, str):
+            data = data.encode()
         self.data += data
         self.sizes.append(len(data))
@@ -127,10 +101,16 @@ class TestCopy(unittest.TestCase):

     cls_set_up = False

+    data: ClassVar[list[tuple[int, str]]] = [
+        (1935, 'Luciano Pavarotti'),
+        (1941, 'Plácido Domingo'),
+        (1946, 'José Carreras')]
+
     @staticmethod
     def connect():
-        return pgdb.connect(database=dbname,
-                            host='%s:%d' % (dbhost or '', dbport or -1))
+        host = f"{dbhost or ''}:{dbport or -1}"
+        return pgdb.connect(database=dbname, host=host,
+                            user=dbuser, password=dbpasswd)

     @classmethod
     def setUpClass(cls):
@@ -139,7 +119,7 @@ def setUpClass(cls):
         cur.execute("set client_min_messages=warning")
         cur.execute("drop table if exists copytest cascade")
         cur.execute("create table copytest ("
-                    "id smallint primary key, name varchar(64))")
+                    "id smallint primary key, name varchar(64))")
         cur.close()
         con.commit()
         cur = con.cursor()
@@ -147,8 +127,9 @@ def setUpClass(cls):
         try:
             cur.execute("set client_encoding=utf8")
             cur.execute("select 'Plácido and José'").fetchone()
         except (pgdb.DataError, pgdb.NotSupportedError):
-            cls.data[1] = (1941, 'Plaacido Domingo')
-            cls.data[2] = (1946, 'Josee Carreras')
+            cls.data[1:3] = [
+                (1941, 'Plaacido Domingo'),
+                (1946, 'Josee Carreras')]
             cls.can_encode = False
         cur.close()
         con.close()
@@ -170,32 +151,22 @@ def setUp(self):
         self.cursor.execute("set client_encoding=utf8")

     def tearDown(self):
-        try:
+        with suppress(Exception):
             self.cursor.close()
-        except Exception:
-            pass
-        try:
+        with suppress(Exception):
             self.con.rollback()
-        except Exception:
-            pass
-        try:
+        with suppress(Exception):
             self.con.close()
-        except Exception:
-            pass
-
-    data = [(1935, 'Luciano Pavarotti'),
-            (1941, 'Plácido Domingo'),
-            (1946, 'José Carreras')]

     can_encode = True

     @property
     def data_text(self):
-        return ''.join('%d\t%s\n' % row for row in self.data)
+        return ''.join('{}\t{}\n'.format(*row) for row in self.data)

     @property
     def data_csv(self):
-        return ''.join('%d,%s\n' % row for row in self.data)
+        return ''.join('{},{}\n'.format(*row) for row in self.data)

     def truncate_table(self):
         self.cursor.execute("truncate table copytest")
@@ -208,7 +179,7 @@ def table_data(self):

     def check_table(self):
         self.assertEqual(self.table_data, self.data)

-    def check_rowcount(self, number=len(data)):
+    def check_rowcount(self, number=len(data)):  # noqa: B008
         self.assertEqual(self.cursor.rowcount, number)
@@ -216,10 +187,10 @@ class TestCopyFrom(TestCopy):
     """Test the copy_from method."""

     def tearDown(self):
-        super(TestCopyFrom, self).tearDown()
+        super().tearDown()
         self.setUp()
         self.truncate_table()
-        super(TestCopyFrom, self).tearDown()
+        super().tearDown()

     def copy_from(self, stream, **options):
         return self.cursor.copy_from(stream, 'copytest', **options)
@@ -230,7 +201,7 @@ def data_file(self):

     def test_bad_params(self):
         call = self.cursor.copy_from
-        call('0\t', 'copytest'), self.cursor
+        call('0\t', 'copytest')
         call('1\t', 'copytest', format='text', sep='\t', null='',
              columns=['id', 'name'])
         self.assertRaises(TypeError, call)
@@ -247,8 +218,8 @@ def test_bad_params(self):
         self.assertRaises(TypeError, call, '0\t', 'copytest', null=42)
         self.assertRaises(ValueError, call, '0\t', 'copytest', size='bad')
         self.assertRaises(TypeError, call, '0\t', 'copytest', columns=42)
-        self.assertRaises(ValueError, call, b'', 'copytest',
-                          format='binary', sep=',')
+        self.assertRaises(
+            ValueError, call, b'', 'copytest', format='binary', sep=',')

     def test_input_string(self):
         ret = self.copy_from('42\tHello, world!')
@@ -256,6 +227,10 @@ def test_input_string(self):
         self.assertEqual(self.table_data, [(42, 'Hello, world!')])
         self.check_rowcount(1)

+    def test_input_string_with_schema_name(self):
+        self.cursor.copy_from('42\tHello, world!', 'public.copytest')
+        self.assertEqual(self.table_data, [(42, 'Hello, world!')])
+
     def test_input_string_with_newline(self):
         self.copy_from('42\tHello, world!\n')
         self.assertEqual(self.table_data, [(42, 'Hello, world!')])
@@ -267,25 +242,12 @@ def test_input_string_multiple_rows(self):
         self.check_table()
         self.check_rowcount()

-    if str is unicode:  # Python >= 3.0
-
-        def test_input_bytes(self):
-            self.copy_from(b'42\tHello, world!')
-            self.assertEqual(self.table_data, [(42, 'Hello, world!')])
-            self.truncate_table()
-            self.copy_from(self.data_text.encode('utf-8'))
-            self.check_table()
-
-    else:  # Python < 3.0
-
-        def test_input_unicode(self):
-            if not self.can_encode:
-                self.skipTest('database does not support utf8')
-            self.copy_from(u'43\tWürstel, Käse!')
-            self.assertEqual(self.table_data, [(43, 'Würstel, Käse!')])
-            self.truncate_table()
-            self.copy_from(self.data_text.decode('utf-8'))
-            self.check_table()
+    def test_input_bytes(self):
+        self.copy_from(b'42\tHello, world!')
+        self.assertEqual(self.table_data, [(42, 'Hello, world!')])
+        self.truncate_table()
+        self.copy_from(self.data_text.encode())
+        self.check_table()

     def test_input_iterable(self):
         self.copy_from(self.data_text.splitlines())
@@ -296,18 +258,16 @@ def test_input_iterable_invalid(self):
         self.assertRaises(IOError, self.copy_from, [None])

     def test_input_iterable_with_newlines(self):
-        self.copy_from('%s\n' % row for row in self.data_text.splitlines())
+        self.copy_from(f'{row}\n' for row in self.data_text.splitlines())
         self.check_table()

-    if str is unicode:  # Python >= 3.0
-
-        def test_input_iterable_bytes(self):
-            self.copy_from(row.encode('utf-8')
-                           for row in self.data_text.splitlines())
-            self.check_table()
+    def test_input_iterable_bytes(self):
+        self.copy_from(row.encode()
+                       for row in self.data_text.splitlines())
+        self.check_table()

     def test_sep(self):
-        stream = ('%d-%s' % row for row in self.data)
+        stream = ('{}-{}'.format(*row) for row in self.data)
         self.copy_from(stream, sep='-')
         self.check_table()
@@ -342,7 +302,7 @@ def test_columns(self):
             (1, None), (2, None), (3, 'Three'), (4, 'Four'), (5, 'Five')])
         self.check_rowcount(5)
         self.assertRaises(pgdb.ProgrammingError, self.copy_from,
-                          '6\t42', columns=['id', 'age'])
+                          '6\t42', columns=['id', 'age'])
         self.check_rowcount(-1)

     def test_csv(self):
@@ -350,22 +310,22 @@ def test_csv(self):
         self.check_table()

     def test_csv_with_sep(self):
-        stream = ('%d;"%s"\n' % row for row in self.data)
+        stream = ('{};"{}"\n'.format(*row) for row in self.data)
         self.copy_from(stream, format='csv', sep=';')
         self.check_table()
         self.check_rowcount()

     def test_binary(self):
-        self.assertRaises(IOError, self.copy_from,
-                          b'NOPGCOPY\n', format='binary')
+        self.assertRaises(
+            IOError, self.copy_from, b'NOPGCOPY\n', format='binary')
         self.check_rowcount(-1)

     def test_binary_with_sep(self):
-        self.assertRaises(ValueError, self.copy_from,
-                          '', format='binary', sep='\t')
+        self.assertRaises(
+            ValueError, self.copy_from, '', format='binary', sep='\t')

     def test_binary_with_unicode(self):
-        self.assertRaises(ValueError, self.copy_from, u'', format='binary')
+        self.assertRaises(ValueError, self.copy_from, '', format='binary')

     def test_query(self):
         self.assertRaises(ValueError, self.cursor.copy_from, '', "select null")
@@ -398,8 +358,8 @@ def test_size_negative(self):
         self.check_rowcount()

     def test_size_invalid(self):
-        self.assertRaises(TypeError,
-                          self.copy_from, self.data_file, size='invalid')
+        self.assertRaises(
+            TypeError, self.copy_from, self.data_file, size='invalid')


 class TestCopyTo(TestCopy):
@@ -407,7 +367,7 @@ class TestCopyTo(TestCopy):

     @classmethod
     def setUpClass(cls):
-        super(TestCopyTo, cls).setUpClass()
+        super().setUpClass()
         con = cls.connect()
         cur = con.cursor()
         cur.execute("set client_encoding=utf8")
@@ -446,67 +406,58 @@ def test_generator(self):
         self.assertIsInstance(ret, Iterable)
         rows = list(ret)
         self.assertEqual(len(rows), 3)
-        rows = ''.join(rows)
-        self.assertIsInstance(rows, str)
-        self.assertEqual(rows, self.data_text)
+        text = ''.join(rows)
+        self.assertIsInstance(text, str)
+        self.assertEqual(text, self.data_text)
         self.check_rowcount()

-    if str is unicode:  # Python >= 3.0
+    def test_generator_with_schema_name(self):
+        ret = self.cursor.copy_to(None, 'public.copytest')
+        self.assertEqual(''.join(ret), self.data_text)

-        def test_generator_bytes(self):
-            ret = self.copy_to(decode=False)
-            self.assertIsInstance(ret, Iterable)
-            rows = list(ret)
-            self.assertEqual(len(rows), 3)
-            rows = b''.join(rows)
-            self.assertIsInstance(rows, bytes)
-            self.assertEqual(rows, self.data_text.encode('utf-8'))
-
-    else:  # Python < 3.0
-
-        def test_generator_unicode(self):
-            ret = self.copy_to(decode=True)
-            self.assertIsInstance(ret, Iterable)
-            rows = list(ret)
-            self.assertEqual(len(rows), 3)
-            rows = ''.join(rows)
-            self.assertIsInstance(rows, unicode)
-            self.assertEqual(rows, self.data_text.decode('utf-8'))
+    def test_generator_bytes(self):
+        ret = self.copy_to(decode=False)
+        self.assertIsInstance(ret, Iterable)
+        rows = list(ret)
+        self.assertEqual(len(rows), 3)
+        byte_text = b''.join(rows)
+        self.assertIsInstance(byte_text, bytes)
+        self.assertEqual(byte_text, self.data_text.encode())

     def test_rowcount_increment(self):
         ret = self.copy_to()
         self.assertIsInstance(ret, Iterable)
-        for n, row in enumerate(ret):
+        for n, _row in enumerate(ret):
             self.check_rowcount(n + 1)

     def test_decode(self):
         ret_raw = b''.join(self.copy_to(decode=False))
         ret_decoded = ''.join(self.copy_to(decode=True))
         self.assertIsInstance(ret_raw, bytes)
-        self.assertIsInstance(ret_decoded, unicode)
-        self.assertEqual(ret_decoded, ret_raw.decode('utf-8'))
+        self.assertIsInstance(ret_decoded, str)
+        self.assertEqual(ret_decoded, ret_raw.decode())
         self.check_rowcount()

     def test_sep(self):
         ret = list(self.copy_to(sep='-'))
-        self.assertEqual(ret, ['%d-%s\n' % row for row in self.data])
+        self.assertEqual(ret, ['{}-{}\n'.format(*row) for row in self.data])

     def test_null(self):
-        data = ['%d\t%s\n' % row for row in self.data]
+        data = ['{}\t{}\n'.format(*row) for row in self.data]
         self.cursor.execute('insert into copytest values(4, null)')
         try:
             ret = list(self.copy_to())
-            self.assertEqual(ret, data + ['4\t\\N\n'])
+            self.assertEqual(ret, [*data, '4\t\\N\n'])
             ret = list(self.copy_to(null='Nix'))
-            self.assertEqual(ret, data + ['4\tNix\n'])
+            self.assertEqual(ret, [*data, '4\tNix\n'])
             ret = list(self.copy_to(null=''))
-            self.assertEqual(ret, data + ['4\t\n'])
+            self.assertEqual(ret, [*data, '4\t\n'])
         finally:
             self.cursor.execute('delete from copytest where id=4')

     def test_columns(self):
-        data_id = ''.join('%d\n' % row[0] for row in self.data)
-        data_name = ''.join('%s\n' % row[1] for row in self.data)
+        data_id = ''.join(f'{row[0]}\n' for row in self.data)
+        data_name = ''.join(f'{row[1]}\n' for row in self.data)
         ret = ''.join(self.copy_to(columns='id'))
         self.assertEqual(ret, data_id)
         ret = ''.join(self.copy_to(columns=['id']))
@@ -519,17 +470,17 @@ def test_columns(self):
         self.assertEqual(ret, self.data_text)
         ret = ''.join(self.copy_to(columns=['id', 'name']))
         self.assertEqual(ret, self.data_text)
-        self.assertRaises(pgdb.ProgrammingError, self.copy_to,
-                          columns=['id', 'age'])
+        self.assertRaises(
+            pgdb.ProgrammingError, self.copy_to, columns=['id', 'age'])

     def test_csv(self):
         ret = self.copy_to(format='csv')
         self.assertIsInstance(ret, Iterable)
         rows = list(ret)
         self.assertEqual(len(rows), 3)
-        rows = ''.join(rows)
-        self.assertIsInstance(rows, str)
-        self.assertEqual(rows, self.data_csv)
+        csv = ''.join(rows)
+        self.assertIsInstance(csv, str)
+        self.assertEqual(csv, self.data_csv)
         self.check_rowcount(3)

     def test_csv_with_sep(self):
@@ -548,19 +499,20 @@ def test_binary_with_sep(self):
         self.assertRaises(ValueError, self.copy_to, format='binary', sep='\t')

     def test_binary_with_unicode(self):
-        self.assertRaises(ValueError, self.copy_to,
-                          format='binary', decode=True)
+        self.assertRaises(
+            ValueError, self.copy_to, format='binary', decode=True)

     def test_query(self):
-        self.assertRaises(ValueError, self.cursor.copy_to, None,
+        self.assertRaises(
+            ValueError, self.cursor.copy_to, None,
             "select name from copytest", columns='noname')
-        ret = self.cursor.copy_to(None,
-            "select name||'!' from copytest where id=1941")
+        ret = self.cursor.copy_to(
from copytest where id=1941") self.assertIsInstance(ret, Iterable) rows = list(ret) self.assertEqual(len(rows), 1) self.assertIsInstance(rows[0], str) - self.assertEqual(rows[0], '%s!\n' % self.data[1][1]) + self.assertEqual(rows[0], f'{self.data[1][1]}!\n') self.check_rowcount(1) def test_file(self): @@ -568,9 +520,7 @@ def test_file(self): ret = self.copy_to(stream) self.assertIs(ret, self.cursor) self.assertEqual(str(stream), self.data_text) - data = self.data_text - if str is unicode: # Python >= 3.0 - data = data.encode('utf-8') + data = self.data_text.encode() sizes = [len(row) + 1 for row in data.splitlines()] self.assertEqual(stream.sizes, sizes) self.check_rowcount() diff --git a/tests/test_tutorial.py b/tests/test_tutorial.py index 10943359..c09d13b8 100644 --- a/tests/test_tutorial.py +++ b/tests/test_tutorial.py @@ -1,37 +1,20 @@ #!/usr/bin/python -# -*- coding: utf-8 -*- -from __future__ import print_function - -try: - import unittest2 as unittest # for Python < 2.7 -except ImportError: - import unittest +import unittest +from typing import Any from pg import DB from pgdb import connect -# We need a database to test against. If LOCAL_PyGreSQL.py exists we will -# get our information from that. Otherwise we use the defaults. -dbname = 'unittest' -dbhost = None -dbport = 5432 - -try: - from .LOCAL_PyGreSQL import * -except (ImportError, ValueError): - try: - from LOCAL_PyGreSQL import * - except ImportError: - pass +from .config import dbhost, dbname, dbpasswd, dbport, dbuser class TestClassicTutorial(unittest.TestCase): """Test the First Steps Tutorial for the classic interface.""" def setUp(self): - """Setup test tables or empty them if they already exist.""" - db = DB(dbname=dbname, host=dbhost, port=dbport) + """Set up test tables or empty them if they already exist.""" + db = DB(dbname, dbhost, dbport, user=dbuser, passwd=dbpasswd) db.query("set datestyle to 'iso'") db.query("set default_with_oids=false") db.query("set standard_conforming_strings=false") @@ -47,7 +30,7 @@ def tearDown(self): def test_all_steps(self): db = self.db - r = db.get_tables() + r: Any = db.get_tables() self.assertIsInstance(r, list) self.assertIn('public.fruits', r) r = db.get_attnames('fruits') @@ -63,7 +46,8 @@ def test_all_steps(self): self.assertEqual(r, {'name': 'banana', 'id': 2}) more_fruits = 'cherimaya durian eggfruit fig grapefruit'.split() data = list(enumerate(more_fruits, start=3)) - db.inserttable('fruits', data) + n = db.inserttable('fruits', data) + self.assertEqual(n, 5) q = db.query('select * from fruits') r = str(q).splitlines() self.assertEqual(r[0], 'id| name ') @@ -124,10 +108,10 @@ class TestDbApi20Tutorial(unittest.TestCase): """Test the First Steps Tutorial for the DB-API 2.0 interface.""" def setUp(self): - """Setup test tables or empty them if they already exist.""" - database = dbname - host = '%s:%d' % (dbhost or '', dbport or -1) - con = connect(database=database, host=host) + """Set up test tables or empty them if they already exist.""" + host = f"{dbhost or ''}:{dbport or -1}" + con = connect(database=dbname, host=host, + user=dbuser, password=dbpasswd) cur = con.cursor() cur.execute("set datestyle to 'iso'") cur.execute("set default_with_oids=false") @@ -155,7 +139,7 @@ def test_all_steps(self): cursor.executemany("insert into fruits (name) values (%s)", parameters) con.commit() cursor.execute('select * from fruits where id=1') - r = cursor.fetchone() + r: Any = cursor.fetchone() self.assertIsInstance(r, tuple) self.assertEqual(len(r), 2) r = str(r) diff --git 
diff --git a/tox.ini b/tox.ini
index c8be87b9..2359c8df 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,12 +1,60 @@
 # config file for tox

 [tox]
-envlist = py{26,27,33,34,35,36,37,38}
+envlist = py3{7,8,9,10,11,12,13},ruff,mypy,cformat,docs
+requires =  # this is needed for compatibility with Python 3.7
+    pip<24.1
+    virtualenv<20.27
+
+[testenv:ruff]
+basepython = python3.13
+deps = ruff>=0.8,<0.9
+commands =
+    ruff check setup.py pg pgdb tests
+
+[testenv:mypy]
+basepython = python3.13
+deps = mypy>=1.13,<1.14
+commands =
+    mypy pg pgdb tests
+
+[testenv:cformat]
+basepython = python3.13
+allowlist_externals =
+    sh
+commands =
+    sh -c "! (clang-format --style=file -n ext/*.c 2>&1 | tee /dev/tty | grep format-violations)"
+
+[testenv:docs]
+basepython = python3.13
+deps =
+    sphinx>=8,<9
+commands =
+    sphinx-build -b html -nEW docs docs/_build/html
+
+[testenv:build]
+basepython = python3.13
+deps =
+    setuptools>=68
+    wheel>=0.42,<1
+    build>=1,<2
+commands =
+    python -m build -s -n -C strict -C memory-size
+
+[testenv:coverage]
+basepython = python3.13
+deps =
+    coverage>=7,<8
+commands =
+    coverage run -m unittest discover -v
+    coverage html

 [testenv]
+passenv =
+    PG*
+    PYGRESQL_*
 deps =
-    py26: unittest2
+    setuptools>=68
 commands =
-    python setup.py clean --all build_ext --force --inplace --strict
-    py26: unit2 discover {posargs}
-    py{27,33,34,35,36,37,38}: python -m unittest discover {posargs}
+    python setup.py clean --all build_ext --force --inplace --strict --memory-size
+    python -m unittest {posargs:discover -v}
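
Note: the connect() helpers in the test modules above pack host and port
into a single "host:port" string, passing an empty host and -1 when a value
is unset. A minimal usage sketch of that convention (the literal settings
below are placeholders, not part of the patch):

    # Build a DB-API 2.0 connection the same way the test helpers do.
    import pgdb

    dbname, dbhost, dbport = 'unittest', None, 5432  # placeholder settings
    dbuser = dbpasswd = None

    # Mirrors the "dbhost or ''" / "dbport or -1" pattern used above;
    # an empty host and -1 port appear to make pgdb fall back to defaults.
    host = f"{dbhost or ''}:{dbport or -1}"
    con = pgdb.connect(database=dbname, host=host,
                       user=dbuser, password=dbpasswd)
    try:
        cur = con.cursor()
        cur.execute("select version()")
        print(cur.fetchone()[0])
    finally:
        con.close()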