diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml new file mode 100644 index 0000000..ad06a1f --- /dev/null +++ b/.github/workflows/build-and-test.yml @@ -0,0 +1,117 @@ +name: Build and Test +on: + push: + branches: + - testing + - main + release: + types: + - created + workflow_dispatch: +jobs: + build_wheels: + name: Build wheels on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + include: + - os: macos-12 + arch: x86_64 + deployment-target: '10.9' + - os: macos-latest + arch: arm64 + deployment-target: '11.0' + - os: ubuntu-latest + arch: x86_64 + deployment-target: '' + - os: windows-2022 + arch: AMD64 + deployment-target: '' + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 + if: matrix.os != 'ubuntu-latest' + with: + python-version: '3.10' + - uses: actions/setup-python@v4 + if: matrix.os == 'ubuntu-latest' + # for testing due to docker env issues + with: + python-version: '3.9' + - name: Install cibuildwheel + run: | + python -m pip install --upgrade pip + python -m pip install --upgrade cibuildwheel + - name: Restore postgres build from cache + if: ${{ matrix.os != 'ubuntu-latest' }} + id: restore-postgres + uses: actions/cache/restore@v3 + env: + cache-name: cache-postgres + with: + path: | + pgbuild + src/pgserver/pginstall + key: ${{ runner.os }}-${{ runner.arch }}-build-${{ env.cache-name }}-${{ + hashFiles('Makefile', 'pgbuild/Makefile', '.github/workflows/build-and-test.yml') }} + - name: Build postgres and pgvector + if: ${{ matrix.os != 'ubuntu-latest' && ! 
steps.restore-postgres.outputs.cache-hit }} + env: + MACOSX_DEPLOYMENT_TARGET: ${{ matrix.deployment-target }} + # this step is implied by Build wheels, but we do it here for caching before python tests run + # on ubuntu, cibuildwheel will run this step within a docker container, so it cannot use the cache this way + run: make + - name: Save postgres build + if: ${{ matrix.os != 'ubuntu-latest' && ! steps.restore-postgres.outputs.cache-hit }} + id: cache-postgres + uses: actions/cache/save@v3 + env: + cache-name: cache-postgres + with: + path: | + pgbuild + src/pgserver/pginstall + key: ${{ runner.os }}-${{ runner.arch }}-build-${{ env.cache-name }}-${{ + hashFiles('Makefile', 'pgbuild/Makefile', '.github/workflows/build-and-test.yml') }} + - name: Build wheels + env: + CIBW_ARCHS: ${{ matrix.arch }} + CIBW_SKIP: pp* cp38-* *-musllinux* + MACOSX_DEPLOYMENT_TARGET: ${{ matrix.deployment-target }} + run: python -m cibuildwheel --output-dir wheelhouse + - name: Save postgres build + if: ${{ matrix.os == 'ubuntu-latest' && ! 
steps.restore-postgres.outputs.cache-hit }} + id: cache-postgres2 + uses: actions/cache/save@v3 + env: + cache-name: cache-postgres + with: + path: | + pgbuild + src/pgserver/pginstall + key: ${{ runner.os }}-${{ runner.arch }}-build-${{ env.cache-name }}-${{ + hashFiles('Makefile', 'pgbuild/Makefile', '.github/workflows/build-and-test.yml') }} + - uses: actions/upload-artifact@v3 + with: + path: wheelhouse/*.whl + name: python-package-distributions + publish-to-pypi: + if: ${{ startsWith(github.ref, 'refs/tags/') }} + name: Publish Python dist to PyPI + needs: + - build_wheels + runs-on: ubuntu-latest + environment: + name: pypi + url: https://pypi.org/p/pgserver # Replace with your PyPI project name + permissions: + id-token: write # IMPORTANT: mandatory for trusted publishing + steps: + - name: Download all the dists + uses: actions/download-artifact@v3 + with: + name: python-package-distributions + path: dist/ + - name: Publish distribution 📦 to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml deleted file mode 100644 index 99323ab..0000000 --- a/.github/workflows/wheels.yml +++ /dev/null @@ -1,81 +0,0 @@ -name: Build Wheels -on: - release: - types: - - created -jobs: - build_wheels: - name: Build wheels on ${{ matrix.os }} - runs-on: ${{ matrix.os }} - strategy: - matrix: - include: - # using macos-latest-xlarge for arm64 apple build: - # macos-latest can cross compile for arm64, even though it is a x86_64 system - # (see pyav build workflow) - # however the build code for pg extensions themselves depends on some generated binaries - # (like pg_config), which would have been cross-compiled and cannot run on the host - # to keep things simple just build on the -xlarge instance, which is arm64.exclude: - # ubuntu arm64 may have similar issues (but useful for newer aws machines) - - os: macos-latest-xlarge - arch: arm64 - - os: macos-latest - arch: x86_64 - - os: ubuntu-20.04 # for 
testing - arch: x86_64 - # - os: windows-latest - # arch: AMD64 - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v4 - if: matrix.os != 'ubuntu-20.04' - with: - python-version: '3.10' - - uses: actions/setup-python@v4 - if: matrix.os == 'ubuntu-20.04' - # for testing due to docker env issues - with: - python-version: '3.9' - - name: Install cibuildwheel - run: | - python -m pip install --upgrade pip - python -m pip install --upgrade cibuildwheel - - name: Build wheels - env: - CIBW_ARCHS: ${{ matrix.arch }} - CIBW_SKIP: pp* cp38-* cp312-* *-musllinux* - run: python -m cibuildwheel --output-dir wheelhouse - - uses: actions/upload-artifact@v3 - with: - path: wheelhouse/*.whl - name: python-package-distributions - - uses: actions/checkout@v4 - - uses: actions/download-artifact@v3 - with: - name: python-package-distributions - path: dist - - name: Install and test ubuntu wheel - if: matrix.os == 'ubuntu-20.04' - # test ubuntu outside of cibuildwheel due to postgres root user issues - run: | - python -m pip install --upgrade pip pytest - python -m pip install --force-reinstall dist/*cp39*manylinux*x86*.whl - pytest ./tests - publish-to-pypi: - name: Publish Python dist to PyPI - needs: - - build_wheels - runs-on: ubuntu-latest - environment: - name: pypi - url: https://pypi.org/p/pgserver # Replace with your PyPI project name - permissions: - id-token: write # IMPORTANT: mandatory for trusted publishing - steps: - - name: Download all the dists - uses: actions/download-artifact@v3 - with: - name: python-package-distributions - path: dist/ - - name: Publish distribution 📦 to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.gitignore b/.gitignore index a268561..7967327 100644 --- a/.gitignore +++ b/.gitignore @@ -129,6 +129,9 @@ dmypy.json # Pyre type checker .pyre/ +# Pycharm +.idea + build/ pgbuild/ src/pgserver/pginstall diff --git a/Makefile b/Makefile index 5afd552..0ec57ca 100644 --- a/Makefile +++ b/Makefile @@ -2,11 +2,11 @@ 
.PHONY: build wheel install-wheel install-dev clean test build: - $(MAKE) -C pgbuild all + $(MAKE) -C pgbuild all wheel: build python setup.py bdist_wheel - + install-wheel: wheel python -m pip install --force-reinstall dist/*.whl @@ -14,8 +14,8 @@ install-dev: build python -m pip install --force-reinstall -e . clean: - rm -rf build/ wheelhouse/ dist/ + rm -rf build/ wheelhouse/ dist/ .eggs/ $(MAKE) -C pgbuild clean test: - python -m pytest tests/ \ No newline at end of file + python -m pytest tests/ diff --git a/README.md b/README.md index 6bd9bd4..bd2b921 100644 --- a/README.md +++ b/README.md @@ -1,39 +1,50 @@ +![Python Version](https://img.shields.io/badge/python-3.9%2C%203.10%2C%203.11%2C%203.12-blue) +![Postgres Version](https://img.shields.io/badge/PostgreSQL-16.2-blue) + +![Linux Support](https://img.shields.io/badge/Linux%20Support-manylinux-green) +![macOS Apple Silicon Support >=11](https://img.shields.io/badge/macOS%20Apple%20Silicon%20Support-%E2%89%A511(BigSur)-green) +![macOS Intel Support >= 10.9](https://img.shields.io/badge/macOS%20Intel%20Support-%E2%89%A510.9-green) +![Windows Support >= 2022](https://img.shields.io/badge/Windows%20AMD64%20Support-%E2%89%A52022-green) + +[![License](https://img.shields.io/badge/License-Apache%202.0-darkblue.svg)](https://opensource.org/licenses/Apache-2.0) +[![PyPI Package](https://img.shields.io/pypi/v/pgserver?color=darkorange)](https://pypi.org/project/pgserver) +![PyPI - Downloads](https://img.shields.io/pypi/dm/pgserver) +

- +

-# `pgserver`: pip-installable postgres + pgvector for your python app +# pgserver: pip-installable, embedded postgres server + pgvector extension for your python app -`pip install pgserver` +`pgserver` lets you build Postgres-backed python apps with the same convenience afforded by an embedded database (ie, alternatives such as sqlite). +If you build your app with pgserver, your app remains wholly pip-installable, saving your users from needing to understand how to set up a postgres server (they simply pip install your app, and postgres is brought in through dependencies), and letting you get started developing quickly: just `pip install pgserver` and `pgserver.get_server(...)`, as shown in this notebook: Open In Colab -`pgserver` lets you initialize (if needed) and run a Postgres server associated with a data dir within your Python app, with server binaries included. -Wheels are built for multiple platforms. +To achieve this, you need two things which `pgserver` provides + * python binary wheels for multiple platforms with postgres binaries + * convenience python methods that handle db initialization and server process management, which deal with things that would normally prevent you from running your python app seamlessly on environments like docker containers, a machine you have no root access in, machines with other running postgres servers, google colab, etc. One main goal of the project is robustness around this. -### Example use cases: -* The main motivation is to enable building Postgres-backed python apps that remain pip-installable, saving you and, more importantly, your users any need to install and setup postgres. -* The second advantage is avoid remembering how to set-up a local postgres server, instead you immediately get a sqlalchemy or psql usable URI string and get to work.
-* Also possible: developing and testing apps that depend on some external Postgres (as a dev dependency) +Additionally, this package includes the [pgvector](https://github.com/pgvector/pgvector) postgres extension, useful for storing associated vector data and for vector similarity queries. -### Basic summary: -* _Pip installable binaries_: tested on Ubuntu and MacOS (apple silicon + x86), including pgvector extension. -* _No sudo needed_: Does not require `root` or `sudo`. -* _Simpler initialization_: `pgserver.get_server(MY_DATA_DIR)` method to initialize data and server if needed, so you don't need to understand `initdb`, `pg_ctl`, port conflicts, and skip debugging why you still cannot connect to the server, just do `server.get_uri()` to connect. Uses unix domain sockets to avoid port conflicts. +## Basic summary: +* _Pip installable binaries_: built and tested on Manylinux, MacOS and Windows. +* _No sudo or admin rights needed_: Does not require `root` privileges or `sudo`. +* but... _can handle root_: in some environments your python app runs as root, eg docker, google colab, `pgserver` handles this case. +* _Simpler initialization_: `pgserver.get_server(MY_DATA_DIR)` method to initialize data and server if needed, so you don't need to understand `initdb`, `pg_ctl`, port conflicts. * _Convenient cleanup_: server process cleanup is done for you: when the process using pgserver ends, the server is shutdown, including when multiple independent processes call -`pgserver.get_server(MY_DATA_DIR)` on the same dir (wait for last one) - * includes context manager protocol to explicitly control cleanup timing in testing scenarios. +`pgserver.get_server(MY_DATA_DIR)` on the same dir (wait for last one). You can blow away your PGDATA dir and start again. * For lower-level control, wrappers to all binaries, such as `initdb`, `pg_ctl`, `psql`, `pg_config`. Includes header files in case you wish to build some other extension and use it against these binaries. 
```py # Example 1: postgres backed application import pgserver -pgdata = f'{MY_APP_DIR}/pgdata' -db = pgserver.get_server(pgdata) +db = pgserver.get_server(MYPGDATA) # server ready for connection. print(db.psql('create extension vector')) db_uri = db.get_uri() -# use uri with sqlalchemy / psycopg, etc +# use uri with sqlalchemy / psycopg, etc, see colab. # if no other process is using this server, it will be shutdown at exit, # if other process use same pgadata, server process will be shutdown when all stop. @@ -46,15 +57,16 @@ import pytest @pytest.fixture def tmp_postgres(): tmp_pg_data = tempfile.mkdtemp() - with pgserver.get_server(tmp_pg_data, cleanup_mode='delete') as pg: - yield pg + pg = pgserver.get_server(tmp_pg_data, cleanup_mode='stop') + yield pg + pg.cleanup() ``` Postgres binaries in the package can be found in the directory pointed -to by the `pgserver.pg_bin` global variable. +to by the `pgserver.POSTGRES_BIN_PATH` to be used directly. -Based on https://github.com/michelp/postgresql-wheel, but with the following differences: -1. Wheels for multiple platforms (ubuntu x86, +MacOS x86, +MacOS apple silicon), pull requests taken for ubuntu arm. -2. pgvector extension included -3. postgres Server management: cleanup via shared count when multiple processes use the same server. -4. no postGIS (need to build cross platform, pull requests taken) +This project was originally based on [](https://github.com/michelp/postgresql-wheel), which provides a linux wheel. +But adds the following differences: +1. binary wheels for multiple platforms (ubuntu x86, MacOS apple silicon, MacOS x86, Windows) +2. postgres python management: cross-platform startup and cleanup including many edge cases, runs on colab etc. +3. 
includes `pgvector` extension but currently excludes `postGIS` diff --git a/cibuildwheel_test.bash b/cibuildwheel_test.bash index 3dc148a..e318f1d 100644 --- a/cibuildwheel_test.bash +++ b/cibuildwheel_test.bash @@ -4,10 +4,10 @@ PROJECT=$1 echo "Running on OSTYPE=$OSTYPE with UID=$UID" case "$OSTYPE" in - linux*) - echo "Tests disabled on the manylinux docker container: still debugging test failures only in this environment" - ;; + # linux *) + # echo "Tests disabled on the manylinux docker container for now" + # ;; *) - pytest -v $PROJECT/tests + pytest -s -v --log-cli-level=INFO $PROJECT/tests ;; esac diff --git a/pgbuild/Makefile b/pgbuild/Makefile index 08e1662..d6f8b13 100644 --- a/pgbuild/Makefile +++ b/pgbuild/Makefile @@ -1,64 +1,72 @@ SHELL := /bin/bash -PREFIX := $(shell pwd)/../src/pgserver/pginstall/ +INSTALL_PREFIX := $(shell pwd)/../src/pgserver/pginstall/ BUILD := $(shell pwd)/pgbuild/ .PHONY: all all: pgvector postgres ### postgres -POSTGRES_VERSION := 15.5 +POSTGRES_VERSION := 16.2 POSTGRES_URL := https://ftp.postgresql.org/pub/source/v$(POSTGRES_VERSION)/postgresql-$(POSTGRES_VERSION).tar.gz -POSTGRES_DIR := postgresql-$(POSTGRES_VERSION) +POSTGRES_SRC := postgresql-$(POSTGRES_VERSION) +POSTGRES_BLD := $(POSTGRES_SRC) -$(POSTGRES_DIR).tar.gz: +$(POSTGRES_SRC).tar.gz: curl -L -O $(POSTGRES_URL) -$(POSTGRES_DIR): $(POSTGRES_DIR).tar.gz - tar xzf $(POSTGRES_DIR).tar.gz - touch $(POSTGRES_DIR) +## extract +$(POSTGRES_SRC)/configure: $(POSTGRES_SRC).tar.gz + tar xzf $(POSTGRES_SRC).tar.gz + touch $(POSTGRES_SRC)/configure -$(PREFIX): - mkdir -p $(PREFIX) +## configure +$(POSTGRES_BLD)/config.status: $(POSTGRES_SRC)/configure + mkdir -p $(POSTGRES_BLD) + cd $(POSTGRES_BLD) && ../$(POSTGRES_SRC)/configure --prefix=$(INSTALL_PREFIX) --without-readline --without-icu -#https://stackoverflow.com/questions/68379786/ -#for explanation of unsetting make env variables prior to calling postgres' own make -$(PREFIX)/bin/postgres: $(POSTGRES_DIR) $(PREFIX) - unset 
MAKELEVEL && unset MAKEFLAGS && unset MFLAGS && cd $(POSTGRES_DIR) \ - && ./configure --prefix=$(PREFIX) --without-readline \ - && $(MAKE) -j \ - && $(MAKE) install +## build +# https://stackoverflow.com/questions/68379786/ +# for explanation of unsetting make env variables prior to calling postgres' own make +$(POSTGRES_BLD)/src/bin/initdb/initdb: $(POSTGRES_BLD)/config.status + unset MAKELEVEL && unset MAKEFLAGS && unset MFLAGS && $(MAKE) -C $(POSTGRES_BLD) -j + +## install to INSTALL_PREFIX +$(INSTALL_PREFIX)/bin/postgres: $(POSTGRES_BLD)/config.status + mkdir -p $(INSTALL_PREFIX) + unset MAKELEVEL && unset MAKEFLAGS && unset MFLAGS && $(MAKE) -C $(POSTGRES_BLD) install .PHONY: postgres -postgres: $(PREFIX)/bin/postgres +postgres: $(INSTALL_PREFIX)/bin/postgres ### pgvector -PGVECTOR_TAG := v0.5.1 +PGVECTOR_TAG := v0.6.2 PGVECTOR_URL := https://github.com/pgvector/pgvector/archive/refs/tags/$(PGVECTOR_TAG).tar.gz PGVECTOR_DIR := pgvector-$(PGVECTOR_TAG) $(PGVECTOR_DIR).tar.gz: curl -L -o $(PGVECTOR_DIR).tar.gz $(PGVECTOR_URL) -$(PGVECTOR_DIR): $(PGVECTOR_DIR).tar.gz +$(PGVECTOR_DIR)/Makefile: $(PGVECTOR_DIR).tar.gz # tar extract into pgvector-$(PGVECTOR_TAG) mkdir -p $(PGVECTOR_DIR) tar xzf $(PGVECTOR_DIR).tar.gz -C $(PGVECTOR_DIR) --strip-components=1 - touch $(PGVECTOR_DIR) + touch $(PGVECTOR_DIR)/Makefile -$(PREFIX)/lib/vector.so: $(PGVECTOR_DIR) $(PREFIX)/bin/postgres - unset MAKELEVEL && unset MAKEFLAGS && unset MFLAGS && cd $(PGVECTOR_DIR) \ - && export PG_CONFIG=$(PREFIX)/bin/pg_config \ - && $(MAKE) -j \ - && $(MAKE) install +$(INSTALL_PREFIX)/lib/vector.so: $(PGVECTOR_DIR)/Makefile $(INSTALL_PREFIX)/bin/postgres + unset MAKELEVEL && unset MAKEFLAGS && unset MFLAGS \ + && export PG_CONFIG=$(INSTALL_PREFIX)/bin/pg_config \ + && $(MAKE) -C $(PGVECTOR_DIR) -j \ + && $(MAKE) -C $(PGVECTOR_DIR) install .PHONY: pgvector -pgvector: postgres $(PREFIX)/lib/vector.so +pgvector: postgres $(INSTALL_PREFIX)/lib/vector.so ### other .PHONY: clean clean-all clean: - rm 
-rf $(PREFIX) - rm -rf $(POSTGRES_DIR) + rm -rf $(INSTALL_PREFIX) + rm -rf $(POSTGRES_SRC) + rm -rf $(POSTGRES_BLD) rm -rf $(PGVECTOR_DIR) clean-all: clean diff --git a/pgserver-example.ipynb b/pgserver-example.ipynb new file mode 100644 index 0000000..e40faac --- /dev/null +++ b/pgserver-example.ipynb @@ -0,0 +1,134 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install pgserver sqlalchemy psycopg2-binary sqlalchemy_utils" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import pgserver\n", + "srv = pgserver.get_server('./mypgdata')" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " res \n", + "-----\n", + " 2\n", + "(1 row)\n", + "\n", + "\n" + ] + } + ], + "source": [ + "print(srv.psql('SELECT 1+1 as res;'))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from sqlalchemy_utils import create_database, database_exists\n", + "from sqlalchemy import create_engine\n", + "import sqlalchemy as sql" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'postgresql://postgres:@/mydb?host=/Users/orm/repos/pgserver/mypgdata'" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "dburi = srv.get_uri(database='mydb')\n", + "display(dburi)\n", + "if not database_exists(dburi):\n", + " create_database(dburi)\n", + "engine = create_engine(dburi)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "table_name = 'mytable'\n", + "with engine.connect() as conn:\n", + " conn.execute(sql.text(f\"create table {table_name} (id int);\"))\n", + " conn.execute(sql.text(f\"insert into {table_name} values 
(1);\"))\n", + " cur = conn.execute(sql.text(f\"select * from {table_name};\"))\n", + " result = cur.fetchone()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1,)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pixeltable_39", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pgserver_small.png b/pgserver_small.png deleted file mode 100644 index 3027179..0000000 Binary files a/pgserver_small.png and /dev/null differ diff --git a/pgserver_square_small.png b/pgserver_square_small.png index 9e0679a..9ae32b9 100644 Binary files a/pgserver_square_small.png and b/pgserver_square_small.png differ diff --git a/pyproject.toml b/pyproject.toml index 7a40493..4c1fe7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,16 +1,28 @@ [project] name = "pgserver" # Required -version = "0.0.7" # Required +version = "0.1.4" # Required description = "Self-contained postgres server for your python applications" # Required readme = "README.md" # Optional requires-python = ">=3.9" license = {file = "LICENSE.txt"} urls={repository="https://github.com/orm011/pgserver"} authors=[{ name="Oscar Moll", email="orm@csail.mit.edu" }] -keywords=["postgresql", "pgvector", "pgserver"] +keywords=["postgresql", "pgvector", "pgserver", "rag"] dependencies = [ "fasteners>=0.19", "platformdirs>=4.0.0", + "psutil>=5.9.0", +] + +[project.optional-dependencies] +dev = [ + "sysv_ipc", +] +test = [ + "pytest", + "psycopg2-binary", + "sqlalchemy>=2", + "sqlalchemy-utils" ] 
[tool.setuptools.packages.find] @@ -22,11 +34,9 @@ testpaths = ["tests"] [tool.cibuildwheel] before-all = "make" -test-requires = "pytest" -# cibw-skip = "pp* *-musllinux* pp** cp39-* cp310-* cp311-* cp312-*" +test-extras = "test" test-command = "bash -x {project}/cibuildwheel_test.bash {project}" - [build-system] # These are the assumed default build requirements from pip: # https://pip.pypa.io/en/stable/reference/pip/#pep-517-and-518-support diff --git a/src/pgserver/__init__.py b/src/pgserver/__init__.py index fe76267..e7e95b4 100644 --- a/src/pgserver/__init__.py +++ b/src/pgserver/__init__.py @@ -1,3 +1,2 @@ from ._commands import * -from ._utils import * -from ._utils import PostgresServer +from .postgres_server import PostgresServer, get_server diff --git a/src/pgserver/_commands.py b/src/pgserver/_commands.py index f4fcd88..4a2dfc3 100644 --- a/src/pgserver/_commands.py +++ b/src/pgserver/_commands.py @@ -2,55 +2,15 @@ import sys import subprocess from typing import Optional, List, Callable -import os -import pwd -import pathlib -import stat import logging +import tempfile POSTGRES_BIN_PATH = Path(__file__).parent / "pginstall" / "bin" -def ensure_prefix_permissions(path: pathlib.Path): - """ Ensure target user can traverse prefix to path - Permissions for everyone will be increased to ensure traversal. - """ - - # ensure path exists and user exists - assert path.exists() - - prefix = path.parent - - # chmod g+rx,o+rx: enable other users to traverse prefix folders - g_rx_o_rx = stat.S_IRGRP | stat.S_IROTH | stat.S_IXGRP | stat.S_IXOTH - while True: - curr_permissions = prefix.stat().st_mode - ensure_permissions = curr_permissions | g_rx_o_rx - # TODO: are symlinks handled ok here? - prefix.chmod(ensure_permissions) - - if prefix == prefix.parent: - # reached file system root - break - prefix = prefix.parent - -def ensure_user_exists(username : str) -> pwd.struct_passwd: - """ Ensure system user `username` exists. 
- Returns their pwentry if user exists, otherwise it creates a user through `useradd`. - Assume permissions to add users, eg run as root. - """ - try: - entry = pwd.getpwnam(username) - except KeyError as e: - entry = None - - if entry is None: - subprocess.run(["useradd", "-s", "/bin/bash", username], check=True, capture_output=True, text=True) - entry = pwd.getpwnam(username) - - return entry +_logger = logging.getLogger('pgserver') def create_command_function(pg_exe_name : str) -> Callable: - def command(args : List[str], pgdata : Optional[Path] = None, user : Optional[str] = None) -> str: + def command(args : List[str], pgdata : Optional[Path] = None, **kwargs) -> str: """ Run a command with the given command line arguments. Args: @@ -58,7 +18,7 @@ def command(args : List[str], pgdata : Optional[Path] = None, user : Optional[st a list of options as would be passed to `subprocess.run` pgdata: The path to the data directory to use for the command. If the command does not need a data directory, this should be None. - user: The user to run the command as. If None, the current user is used. + kwargs: Additional keyword arguments to pass to `subprocess.run`, eg user, timeout. Returns: The stdout of the command as a string. 
@@ -71,17 +31,30 @@ def command(args : List[str], pgdata : Optional[Path] = None, user : Optional[st full_command_line = [str(POSTGRES_BIN_PATH / pg_exe_name)] + args - try: - result = subprocess.run(full_command_line, check=True, capture_output=True, text=True, - user=user) - logging.info("Successful postgres command %s as user `%s`\nstdout:\n%s\n---\nstderr:\n%s\n---\n", - result.args, user, result.stdout, result.stderr) - except subprocess.CalledProcessError as err: - logging.error("Failed postgres command %s as user `%s`:\nerror:\n%s\nstdout:\n%s\n---\nstderr:\n%s\n---\n", - err.args, user, str(err), err.stdout, err.stderr) - raise err - - return result.stdout + with tempfile.TemporaryFile('w+') as stdout, tempfile.TemporaryFile('w+') as stderr: + try: + _logger.info("Running commandline:\n%s\nwith kwargs: `%s`", full_command_line, kwargs) + # NB: capture_output=True, as well as using stdout=subprocess.PIPE and stderr=subprocess.PIPE + # can cause this call to hang, even with a time-out depending on the command, (pg_ctl) + # so we use two temporary files instead + result = subprocess.run(full_command_line, check=True, stdout=stdout, stderr=stderr, text=True, + **kwargs) + stdout.seek(0) + stderr.seek(0) + output = stdout.read() + error = stderr.read() + _logger.info("Successful postgres command %s with kwargs: `%s`\nstdout:\n%s\n---\nstderr:\n%s\n---\n", + result.args, kwargs, output, error) + except subprocess.CalledProcessError as err: + stdout.seek(0) + stderr.seek(0) + output = stdout.read() + error = stderr.read() + _logger.error("Failed postgres command %s with kwargs: `%s`:\nerror:\n%s\nstdout:\n%s\n---\nstderr:\n%s\n---\n", + err.args, kwargs, str(err), output, error) + raise err + + return output return command @@ -95,5 +68,4 @@ def _init(): setattr(sys.modules[__name__], function_name, prog) __all__.append(function_name) - _init() \ No newline at end of file diff --git a/src/pgserver/_utils.py b/src/pgserver/_utils.py deleted file mode 100644 index 
08c91de..0000000 --- a/src/pgserver/_utils.py +++ /dev/null @@ -1,296 +0,0 @@ -from pathlib import Path -from typing import Optional, Dict, Union, List -import shutil -import atexit -import subprocess -import json -import os -import logging -import hashlib -import socket -import pwd - -from ._commands import POSTGRES_BIN_PATH, initdb, pg_ctl, ensure_prefix_permissions, ensure_user_exists -from .shared import PostmasterInfo, _process_is_running - - -__all__ = ['get_server'] - -class _DiskList: - """ A list of integers stored in a file on disk. - """ - def __init__(self, path : Path): - self.path = path - - def get_and_add(self, value : int) -> List[int]: - old_values = self.get() - values = old_values.copy() - if value not in values: - values.append(value) - self.put(values) - return old_values - - def get_and_remove(self, value : int) -> List[int]: - old_values = self.get() - values = old_values.copy() - if value in values: - values.remove(value) - self.put(values) - return old_values - - def get(self) -> List[int]: - if not self.path.exists(): - return [] - return json.loads(self.path.read_text()) - - def put(self, values : List[int]) -> None: - self.path.write_text(json.dumps(values)) - - -def socket_name_length_ok(socket_name : Path): - ''' checks whether a socket path is too long for domain sockets - on this system. Returns True if the socket path is ok, False if it is too long. - ''' - if socket_name.exists(): - return socket_name.is_socket() - - sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - try: - sock.bind(str(socket_name)) - return True - except OSError as err: - if 'AF_UNIX path too long' in str(err): - return False - raise err - finally: - sock.close() - socket_name.unlink(missing_ok=True) - - -class PostgresServer: - """ Provides a common interface for interacting with a server. 
- """ - import platformdirs - import fasteners - - _instances : Dict[Path, 'PostgresServer'] = {} - - # lockfile for whole class - # home dir does not always support locking (eg some clusters) - runtime_path : Path = platformdirs.user_runtime_path('python_PostgresServer') - lock_path = platformdirs.user_runtime_path('python_PostgresServer') / '.lockfile' - _lock = fasteners.InterProcessLock(lock_path) - - def __init__(self, pgdata : Path, *, cleanup_mode : Optional[str] = 'stop'): - """ Initializes the postgresql server instance. - Constructor is intended to be called directly, use get_server() instead. - """ - assert cleanup_mode in [None, 'stop', 'delete'] - - self.pgdata = pgdata - self.log = self.pgdata / 'log' - - # postgres user name, NB not the same as system user name - self.system_user = None - if os.geteuid() == 0: - # running as root - # need a different system user to run as - self.system_user = 'pgserver' - ensure_user_exists(self.system_user) - - self.postgres_user = "postgres" - list_path = self.pgdata / '.handle_pids.json' - self.global_process_id_list = _DiskList(list_path) - self.cleanup_mode = cleanup_mode - self._postmaster_info : Optional[PostmasterInfo] = None - self._count = 0 - - atexit.register(self._cleanup) - self._init_server() - - def _find_suitable_socket_dir(self) -> Path: - """ Assumes server is not running. Returns a suitable directory for used with pg_ctl. - Usually, this is the same directory as the pgdata directory. - However, if the pgdata directory exceeds the maximum length for domain sockets on this system, - a different directory will be used. - """ - # find a suitable directory for the domain socket - # 1. pgdata. simplest approach, but can be too long for unix socket depending on the path - # 2. runtime_path. This is a directory that is intended for storing runtime data. 
- - # for shared folders, use a hash of the path to avoid collisions of different folders - # use a hash of the pgdata path combined with inode number to avoid collisions - string_identifier = f'{self.pgdata}-{self.pgdata.stat().st_ino}' - path_hash = hashlib.sha256(string_identifier.encode()).hexdigest()[:10] - - candidate_socket_dir = [ - self.pgdata, - self.runtime_path / path_hash, - ] - - ok_path = None - for path in candidate_socket_dir: - path.mkdir(parents=True, exist_ok=True) - # name used by postgresql for domain socket is .s.PGSQL.5432 - if socket_name_length_ok(path / '.s.PGSQL.5432'): - ok_path = path - logging.info(f"Using socket path: {path}") - break - else: - logging.info(f"Socket path too long: {path}. Will try a different directory for socket.") - - if ok_path is None: - raise RuntimeError("Could not find a suitable socket path") - - return ok_path - - def get_postmaster_info(self) -> PostmasterInfo: - assert self._postmaster_info is not None - return self._postmaster_info - - def get_pid(self) -> int: - """ Returns the pid of the postgresql server process. - (First line of postmaster.pid file). - If the server is not running, returns None. - """ - return self.get_postmaster_info().pid - - def get_socket_dir(self) -> Path: - """ Returns the directory of the domain socket used by the server. - """ - return self.get_postmaster_info().socket_dir - - def get_uri(self, database : Optional[str] = None) -> str: - """ Returns a connection string for the postgresql server. - """ - if database is None: - database = self.postgres_user - - return f"postgresql://{self.postgres_user}:@/{database}?host={self.get_socket_dir()}" - - def _init_server(self) -> None: - """ Starts the postgresql server and registers the shutdown handler. - Effect: self._postmaster_info is set. 
- """ - with self._lock: - self._instances[self.pgdata] = self - - if self.system_user is not None: - ensure_prefix_permissions(self.pgdata) - os.chown(self.pgdata, pwd.getpwnam(self.system_user).pw_uid, - pwd.getpwnam(self.system_user).pw_gid) - - if not (self.pgdata / 'PG_VERSION').exists(): - initdb(['--auth=trust', '--auth-local=trust', '-U', self.postgres_user], pgdata=self.pgdata, - user=self.system_user) - - self._postmaster_info = PostmasterInfo.read_from_pgdata(self.pgdata) - if self._postmaster_info is None: - socket_dir = self._find_suitable_socket_dir() - if self.system_user is not None and socket_dir != self.pgdata: - ensure_prefix_permissions(socket_dir) - socket_dir.chmod(0o777) - - try: - # -o to pg_ctl are options to be passed directly to the postgres executable, be wary of quotes (man pg_ctl) - pg_ctl(['-w', # wait for server to start - '-o', f'-k {socket_dir}', # socket option (forwarded to postgres exec) see man postgres for -k - '-o', '-h ""', # no listening on any IP addresses (forwarded to postgres exec) see man postgres for -hj - '-l', str(self.log), # log location: set to pgdata dir also - 'start' # action - ], - pgdata=self.pgdata, user=self.system_user) - except subprocess.CalledProcessError as err: - logging.error(f"Failed to start server.\nShowing contents of postgres server log ({self.log.absolute()}) below:\n{self.log.read_text()}") - raise err - - self._postmaster_info = PostmasterInfo.read_from_pgdata(self.pgdata) - assert self._postmaster_info is not None - assert self._postmaster_info.pid is not None - assert self._postmaster_info.socket_dir is not None - - self.global_process_id_list.get_and_add(os.getpid()) - - def _cleanup(self) -> None: - with self._lock: - pids = self.global_process_id_list.get_and_remove(os.getpid()) - - if pids != [os.getpid()]: # includes case where already cleaned up - return - # last handle is being removed - del self._instances[self.pgdata] - if self.cleanup_mode is None: # done - return - - assert 
self.cleanup_mode in ['stop', 'delete'] - if _process_is_running(self._postmaster_info.pid): - try: - pg_ctl(['-w', 'stop'], pgdata=self.pgdata, user=self.system_user) - except subprocess.CalledProcessError: - pass # somehow the server is already stopped. - - if self.cleanup_mode == 'stop': - return - - assert self.cleanup_mode == 'delete' - shutil.rmtree(str(self.pgdata)) - atexit.unregister(self._cleanup) - - def psql(self, command : str) -> str: - """ Runs a psql command on this server. The command is passed to psql via stdin. - """ - executable = POSTGRES_BIN_PATH / 'psql' - stdout = subprocess.check_output(f'{executable} {self.get_uri()}', - input=command.encode(), shell=True) - return stdout.decode("utf-8") - - def __enter__(self): - self._count += 1 - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self._count -= 1 - if self._count <= 0: - self._cleanup() - - def cleanup(self) -> None: - """ Stops the postgresql server and removes the pgdata directory. - """ - self._cleanup() - - -def get_server(pgdata : Union[Path,str] , cleanup_mode : Optional[str] = 'stop' ) -> PostgresServer: - """ Returns handle to postgresql server instance for the given pgdata directory. - Args: - pgdata: pddata directory. If the pgdata directory does not exist, it will be created, but its - prefix must be a valid directory. - cleanup_mode: If 'stop', the server will be stopped when the last handle is closed (default) - If 'delete', the server will be stopped and the pgdata directory will be deleted. - If None, the server will not be stopped or deleted. - - To create a temporary server, use mkdtemp() to create a temporary directory and pass it as pg_data, - and set cleanup_mode to 'delete'. 
- """ - if isinstance(pgdata, str): - pgdata = Path(pgdata) - pgdata = pgdata.expanduser().resolve() - - if not pgdata.parent.exists(): - raise FileNotFoundError(f"Parent directory of pgdata does not exist: {pgdata.parent}") - - if not pgdata.exists(): - pgdata.mkdir(parents=False, exist_ok=False) - - if pgdata in PostgresServer._instances: - return PostgresServer._instances[pgdata] - - return PostgresServer(pgdata, cleanup_mode=cleanup_mode) - - - - - - - - - diff --git a/src/pgserver/postgres_server.py b/src/pgserver/postgres_server.py new file mode 100644 index 0000000..d6517d1 --- /dev/null +++ b/src/pgserver/postgres_server.py @@ -0,0 +1,291 @@ +from pathlib import Path +from typing import Optional, Dict, Union +import shutil +import atexit +import subprocess +import os +import logging +import platform +import psutil +import time + +from ._commands import POSTGRES_BIN_PATH, initdb, pg_ctl +from .utils import find_suitable_port, find_suitable_socket_dir, DiskList, PostmasterInfo, process_is_running + +if platform.system() != 'Windows': + from .utils import ensure_user_exists, ensure_prefix_permissions, ensure_folder_permissions + +_logger = logging.getLogger('pgserver') + +class PostgresServer: + """ Provides a common interface for interacting with a server. + """ + import platformdirs + import fasteners + + _instances : Dict[Path, 'PostgresServer'] = {} + + # NB home does not always support locking, eg NFS or LUSTRE (eg some clusters) + # so, use user_runtime_path instead, which seems to be in a local filesystem + runtime_path : Path = platformdirs.user_runtime_path('python_PostgresServer') + lock_path = platformdirs.user_runtime_path('python_PostgresServer') / '.lockfile' + _lock = fasteners.InterProcessLock(lock_path) + + def __init__(self, pgdata : Path, *, cleanup_mode : Optional[str] = 'stop'): + """ Initializes the postgresql server instance. + Constructor is intended to be called directly, use get_server() instead. 
+ """ + assert cleanup_mode in [None, 'stop', 'delete'] + + self.pgdata = pgdata + self.log = self.pgdata / 'log' + + # postgres user name, NB not the same as system user name + self.system_user = None + + # note os.geteuid() is not available on windows, so must go after + if platform.system() != 'Windows' and os.geteuid() == 0: + # running as root + # need a different system user to run as + self.system_user = 'pgserver' + ensure_user_exists(self.system_user) + + self.postgres_user = "postgres" + list_path = self.pgdata / '.handle_pids.json' + self.global_process_id_list = DiskList(list_path) + self.cleanup_mode = cleanup_mode + self._postmaster_info : Optional[PostmasterInfo] = None + self._count = 0 + + atexit.register(self._cleanup) + with self._lock: + self._instances[self.pgdata] = self + self.ensure_pgdata_inited() + self.ensure_postgres_running() + self.global_process_id_list.get_and_add(os.getpid()) + + def get_postmaster_info(self) -> PostmasterInfo: + assert self._postmaster_info is not None + return self._postmaster_info + + def get_pid(self) -> Optional[int]: + """ Returns the pid of the postgresql server process. + (First line of postmaster.pid file). + If the server is not running, returns None. + """ + return self.get_postmaster_info().pid + + def get_uri(self, database : Optional[str] = None) -> str: + """ Returns a connection string for the postgresql server. + """ + return self.get_postmaster_info().get_uri(database=database) + + def ensure_pgdata_inited(self) -> None: + """ Initializes the pgdata directory if it is not already initialized. 
+ """ + if platform.system() != 'Windows' and os.geteuid() == 0: + import pwd + import stat + assert self.system_user is not None + ensure_prefix_permissions(self.pgdata) + ensure_prefix_permissions(POSTGRES_BIN_PATH) + + read_perm = stat.S_IRGRP | stat.S_IROTH + execute_perm = stat.S_IXGRP | stat.S_IXOTH + # for envs like cibuildwheel docker, where the user is has no permission otherwise + ensure_folder_permissions(POSTGRES_BIN_PATH, execute_perm | read_perm) + ensure_folder_permissions(POSTGRES_BIN_PATH.parent / 'lib', read_perm) + + + os.chown(self.pgdata, pwd.getpwnam(self.system_user).pw_uid, + pwd.getpwnam(self.system_user).pw_gid) + + if not (self.pgdata / 'PG_VERSION').exists(): # making a new PGDATA + # First ensure there are no left-over servers on a previous version of the same pgdata path, + # which does happen on Mac/Linux if the previous pgdata was deleted without stopping the server process + # (the old server continues running for some time, sometimes indefinitely) + # + # It is likely the old server could also corrupt the data beyond the socket file, so it is best to kill it. + # This must be done before initdb to ensure no race conditions with the old server. + # + # Since we do not know PID information of the old server, we stop all servers with the same pgdata path. + # way to test this: python -c 'import pixeltable as pxt; pxt.Client()'; rm -rf ~/.pixeltable/; python -c 'import pixeltable as pxt; pxt.Client()' + _logger.info(f'no PG_VERSION file found within {self.pgdata}. 
Initializing pgdata') + for proc in psutil.process_iter(attrs=['name', 'cmdline']): + if proc.info['name'] == 'postgres': + if proc.info['cmdline'] is not None and str(self.pgdata) in proc.info['cmdline']: + _logger.info(f"Found a running postgres server with same pgdata: {proc.as_dict(attrs=['name', 'pid', 'cmdline'])=}.\ + Assuming it is a leftover from a previous run on a different version of the same pgdata path, killing it.") + proc.terminate() + try: + proc.wait(2) # wait at most a second + except psutil.TimeoutExpired: + pass + if proc.is_running(): + proc.kill() + assert not proc.is_running() + + initdb(['--auth=trust', '--auth-local=trust', '--encoding=utf8', '-U', self.postgres_user], pgdata=self.pgdata, + user=self.system_user) + else: + _logger.info('PG_VERSION file found, skipping initdb') + + def ensure_postgres_running(self) -> None: + """ pre condition: pgdata is initialized, being run with lock. + post condition: self._postmaster_info is set. + """ + + postmaster_info = PostmasterInfo.read_from_pgdata(self.pgdata) + if postmaster_info is not None and postmaster_info.is_running(): + _logger.info(f"a postgres server is already running: {postmaster_info=} {postmaster_info.process=}") + self._postmaster_info = postmaster_info + else: + if postmaster_info is not None and not postmaster_info.is_running(): + _logger.info(f"found a postmaster.pid file, but the server is not running: {postmaster_info=}") + if postmaster_info is None: + _logger.info(f"no postmaster.pid file found in {self.pgdata}") + + if platform.system() != 'Windows': + # use sockets to avoid any future conflict with port numbers + socket_dir = find_suitable_socket_dir(self.pgdata, self.runtime_path) + + if self.system_user is not None and socket_dir != self.pgdata: + ensure_prefix_permissions(socket_dir) + socket_dir.chmod(0o777) + + pg_ctl_args = ['-w', # wait for server to start + '-o', '-h ""', # no listening on any IP addresses (forwarded to postgres exec) see man postgres for -hj + 
'-o', f'-k {socket_dir}', # socket option (forwarded to postgres exec) see man postgres for -k + '-l', str(self.log), # log location: set to pgdata dir also + 'start' # action + ] + else: # Windows, + socket_dir = None + # socket.AF_UNIX is undefined when running on Windows, so default to a port + host = "127.0.0.1" + port = find_suitable_port(host) + pg_ctl_args = ['-w', # wait for server to start + '-o', f'-h "{host}"', + '-o', f'-p {port}', + '-l', str(self.log), # log location: set to pgdata dir also + 'start' # action + ] + + try: + _logger.info(f"running pg_ctl... {pg_ctl_args=}") + pg_ctl(pg_ctl_args,pgdata=self.pgdata, user=self.system_user, timeout=10) + except subprocess.CalledProcessError as err: + _logger.error(f"Failed to start server.\nShowing contents of postgres server log ({self.log.absolute()}) below:\n{self.log.read_text()}") + raise err + except subprocess.TimeoutExpired as err: + _logger.error(f"Timeout starting server.\nShowing contents of postgres server log ({self.log.absolute()}) below:\n{self.log.read_text()}") + raise err + + while True: + # in Windows, when there is a postmaster.pid, init_ctl seems to return + # but the file is not immediately updated, here we wait until the file shows + # a new running server. see test_stale_postmaster + _logger.info(f'waiting for postmaster info to show a running process') + pinfo = PostmasterInfo.read_from_pgdata(self.pgdata) + _logger.info(f'running... checking if ready {pinfo=}') + if pinfo is not None and pinfo.is_running() and pinfo.status == 'ready': + self._postmaster_info = pinfo + break + + _logger.info(f'not ready yet... waiting a bit more...') + time.sleep(1.) 
+ + _logger.info(f"Now asserting server is running {self._postmaster_info=}") + assert self._postmaster_info is not None + assert self._postmaster_info.is_running() + assert self._postmaster_info.status == 'ready' + + def _cleanup(self) -> None: + with self._lock: + pids = self.global_process_id_list.get_and_remove(os.getpid()) + _logger.info(f"exiting {os.getpid()} remaining {pids=}") + if pids != [os.getpid()]: # includes case where already cleaned up + return + + _logger.info(f"cleaning last handle for server: {self.pgdata}") + # last handle is being removed + del self._instances[self.pgdata] + if self.cleanup_mode is None: # done + return + + assert self.cleanup_mode in ['stop', 'delete'] + if self._postmaster_info is not None: + if self._postmaster_info.process.is_running(): + try: + pg_ctl(['-w', 'stop'], pgdata=self.pgdata, user=self.system_user) + stopped = True + except subprocess.CalledProcessError: + stopped = False + pass # somehow the server is already stopped. + + if not stopped: + _logger.warning(f"Failed to stop server, killing it instead.") + self._postmaster_info.process.terminate() + try: + self._postmaster_info.process.wait(2) + except psutil.TimeoutExpired: + pass + if self._postmaster_info.process.is_running(): + self._postmaster_info.process.kill() + + if self.cleanup_mode == 'stop': + return + + assert self.cleanup_mode == 'delete' + shutil.rmtree(str(self.pgdata)) + atexit.unregister(self._cleanup) + + def psql(self, command : str) -> str: + """ Runs a psql command on this server. The command is passed to psql via stdin. 
+ """ + executable = POSTGRES_BIN_PATH / 'psql' + stdout = subprocess.check_output(f'{executable} {self.get_uri()}', + input=command.encode(), shell=True) + return stdout.decode("utf-8") + + def __enter__(self): + self._count += 1 + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._count -= 1 + if self._count <= 0: + self._cleanup() + + def cleanup(self) -> None: + """ Stops the postgresql server and removes the pgdata directory. + """ + self._cleanup() + + +def get_server(pgdata : Union[Path,str] , cleanup_mode : Optional[str] = 'stop' ) -> PostgresServer: + """ Returns handle to postgresql server instance for the given pgdata directory. + Args: + pgdata: pddata directory. If the pgdata directory does not exist, it will be created, but its + parent must exists and be a valid directory. + cleanup_mode: If 'stop', the server will be stopped when the last handle is closed (default) + If 'delete', the server will be stopped and the pgdata directory will be deleted. + If None, the server will not be stopped or deleted. + + To create a temporary server, use mkdtemp() to create a temporary directory and pass it as pg_data, + and set cleanup_mode to 'delete'. 
+ """ + if isinstance(pgdata, str): + pgdata = Path(pgdata) + pgdata = pgdata.expanduser().resolve() + + if not pgdata.parent.exists(): + raise FileNotFoundError(f"Parent directory of pgdata does not exist: {pgdata.parent}") + + if not pgdata.exists(): + pgdata.mkdir(parents=False, exist_ok=False) + + if pgdata in PostgresServer._instances: + return PostgresServer._instances[pgdata] + + return PostgresServer(pgdata, cleanup_mode=cleanup_mode) diff --git a/src/pgserver/shared.py b/src/pgserver/shared.py deleted file mode 100644 index a7782c3..0000000 --- a/src/pgserver/shared.py +++ /dev/null @@ -1,34 +0,0 @@ -from pathlib import Path -import subprocess -from typing import Optional - -class PostmasterInfo: - def __init__(self, pgdata : Path, pid : int, socket_dir : Path): - self.pgdata = pgdata - self.pid = pid - self.socket_dir = socket_dir - - @classmethod - def read_from_pgdata(cls, pgdata : Path) -> Optional['PostmasterInfo']: - postmaster_pid = pgdata / 'postmaster.pid' - if not postmaster_pid.exists(): - return None - - lines = postmaster_pid.read_text().splitlines() - pid = int(lines[0]) - socket_dir = Path(lines[4]) - socket_path = socket_dir / '.s.PGSQL.5432' - assert socket_dir.exists() - assert socket_path.exists() - assert socket_path.is_socket() - - return cls(postmaster_pid.parent, pid, socket_dir) - -def _process_is_running(pid : int) -> bool: - assert pid is not None - try: - subprocess.run(["kill", "-0", str(pid)], check=True) - return True - except subprocess.CalledProcessError: - pass - return False diff --git a/src/pgserver/utils.py b/src/pgserver/utils.py new file mode 100644 index 0000000..0b89002 --- /dev/null +++ b/src/pgserver/utils.py @@ -0,0 +1,278 @@ +from pathlib import Path +import typing +from typing import Optional, List, Dict +import subprocess +import json +import logging +import hashlib +import socket +import platform +import stat +import psutil +import datetime +import shutil + +_logger = logging.getLogger('pgserver') + +class 
PostmasterInfo: + """Struct with contents of the PGDATA/postmaster.pid file, contains information about the running server. + Example of file contents: (comments added for clarity) + cat /Users/orm/Library/Application Support/Postgres/var-15/postmaster.pid + ``` + 3072 # pid + /Users/orm/Library/Application Support/Postgres/var-15 # pgdata + 1712346200 # start_time + 5432 # port + /tmp # socker_dir, where .s.PGSQL.5432 is located + localhost # listening on this hostname + 8826964 65536 # shared mem size?, shmget id (can deallocate with sysv_ipc.remove_shared_memory(shmget_id)) + ready # server status + ``` + """ + + def __init__(self, lines : List[str]): + _lines = ['pid', 'pgdata', 'start_time', 'port', 'socket_dir', 'hostname', 'shared_memory_info', 'status'] + assert len(lines) == len(_lines), f"_lines: {_lines=} lines: {lines=}" + clean_lines = [ line.strip() for line in lines ] + + raw : Dict[str,str] = dict(zip(_lines, clean_lines)) + + self.pid = int(raw['pid']) + self.pgdata = Path(raw['pgdata']) + self.start_time = datetime.datetime.fromtimestamp(int(raw['start_time'])) + + if raw['socket_dir']: + self.socket_dir = Path(raw['socket_dir']) + else: + self.socket_dir = None + + if raw['hostname']: + self.hostname = raw['hostname'] + else: + self.hostname = None + + if raw['port']: + self.port = int(raw['port']) + else: + self.port = None + + # not sure what this is in windows + self.shmem_info = raw['shared_memory_info'] + self.status = raw['status'] + + self.process = None # will be not None if process is running + self._init_process_meta() + + def _init_process_meta(self) -> Optional[psutil.Process]: + if self.pid is None: + return + try: + process = psutil.Process(self.pid) + except psutil.NoSuchProcess: + return + + self.process = process + # exact_create_time = datetime.datetime.fromtimestamp(process.create_time()) + # if abs(self.start_time - exact_create_time) <= datetime.timedelta(seconds=1): + + def is_running(self) -> bool: + return self.process is 
not None and self.process.is_running() + + @classmethod + def read_from_pgdata(cls, pgdata : Path) -> Optional['PostmasterInfo']: + postmaster_file = pgdata / 'postmaster.pid' + if not postmaster_file.exists(): + return None + + lines = postmaster_file.read_text().splitlines() + return cls(lines) + + def get_uri(self, user : str = 'postgres', database : Optional[str] = None) -> str: + """ Returns a connection uri string for the postgresql server using the information in postmaster.pid""" + if database is None: + database = user + + if self.socket_dir is not None: + return f"postgresql://{user}:@/{database}?host={self.socket_dir}" + elif self.port is not None: + assert self.hostname is not None + return f"postgresql://{user}:@{self.hostname}:{self.port}/{database}" + else: + raise RuntimeError("postmaster.pid does not contain port or socket information") + + @property + def shmget_id(self) -> Optional[int]: + if platform.system() == 'Windows': + return None + + if not self.shmem_info: + return None + raw_id = self.shmem_info.split()[-1] + return int(raw_id) + + @property + def socket_path(self) -> Optional[Path]: + if self.socket_dir is not None: + # TODO: is the port always 5432 for the socket? or does it depend on the port in postmaster.pid? + return self.socket_dir / f'.s.PGSQL.{self.port}' + return None + + def __repr__(self) -> str: + return f"PostmasterInfo(pid={self.pid}, pgdata={self.pgdata}, start_time={self.start_time}, hostname={self.hostname} port={self.port}, socket_dir={self.socket_dir} status={self.status}, process={self.process})" + + def __str__(self) -> str: + return self.__repr__() + +def process_is_running(pid : int) -> bool: + assert pid is not None + return psutil.pid_exists(pid) + +if platform.system() != 'Windows': + def ensure_user_exists(username : str) -> Optional['pwd.struct_passwd']: + """ Ensure system user `username` exists. + Returns their pwentry if user exists, otherwise it creates a user through `useradd`. 
+ Assume permissions to add users, eg run as root. + """ + import pwd + + try: + entry = pwd.getpwnam(username) + except KeyError: + entry = None + + if entry is None: + subprocess.run(["useradd", "-s", "/bin/bash", username], check=True, capture_output=True, text=True) + entry = pwd.getpwnam(username) + + return entry + + def ensure_prefix_permissions(path: Path): + """ Ensure target user can traverse prefix to path + Permissions for everyone will be increased to ensure traversal. + """ + # ensure path exists and user exists + assert path.exists() + prefix = path.parent + # chmod g+rx,o+rx: enable other users to traverse prefix folders + g_rx_o_rx = stat.S_IRGRP | stat.S_IROTH | stat.S_IXGRP | stat.S_IXOTH + while True: + curr_permissions = prefix.stat().st_mode + ensure_permissions = curr_permissions | g_rx_o_rx + # TODO: are symlinks handled ok here? + prefix.chmod(ensure_permissions) + if prefix == prefix.parent: # reached file system root + break + prefix = prefix.parent + + def ensure_folder_permissions(path: Path, flag : int): + """ Ensure target user can read, and execute the folder. + Permissions for everyone will be increased to ensure traversal. + """ + # read and traverse folder + g_rx_o_rx = stat.S_IRGRP | stat.S_IROTH | stat.S_IXGRP | stat.S_IXOTH + + def _helper(path: Path): + if path.is_dir(): + path.chmod(path.stat().st_mode | g_rx_o_rx ) + for child in path.iterdir(): + _helper(child) + else: + path.chmod(path.stat().st_mode | flag) + + _helper(path) + +class DiskList: + """ A list of integers stored in a file on disk. 
+ """ + def __init__(self, path : Path): + self.path = path + + def get_and_add(self, value : int) -> List[int]: + old_values = self.get() + values = old_values.copy() + if value not in values: + values.append(value) + self.put(values) + return old_values + + def get_and_remove(self, value : int) -> List[int]: + old_values = self.get() + values = old_values.copy() + if value in values: + values.remove(value) + self.put(values) + return old_values + + def get(self) -> List[int]: + if not self.path.exists(): + return [] + return json.loads(self.path.read_text()) + + def put(self, values : List[int]) -> None: + self.path.write_text(json.dumps(values)) + + +def socket_name_length_ok(socket_name : Path): + ''' checks whether a socket path is too long for domain sockets + on this system. Returns True if the socket path is ok, False if it is too long. + ''' + if socket_name.exists(): + return socket_name.is_socket() + + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + sock.bind(str(socket_name)) + return True + except OSError as err: + if 'AF_UNIX path too long' in str(err): + return False + raise err + finally: + sock.close() + socket_name.unlink(missing_ok=True) + +def find_suitable_socket_dir(pgdata, runtime_path) -> Path: + """ Assumes server is not running. Returns a suitable directory for used as pg_ctl -o '-k ' option. + Usually, this is the same directory as the pgdata directory. + However, if the pgdata directory exceeds the maximum length for domain sockets on this system, + a different directory will be used. + """ + # find a suitable directory for the domain socket + # 1. pgdata. simplest approach, but can be too long for unix socket depending on the path + # 2. runtime_path. This is a directory that is intended for storing runtime data. 
+ + # for shared folders, use a hash of the path to avoid collisions of different folders + # use a hash of the pgdata path combined with inode number to avoid collisions + string_identifier = f'{pgdata}-{pgdata.stat().st_ino}' + path_hash = hashlib.sha256(string_identifier.encode()).hexdigest()[:10] + + candidate_socket_dir = [ + pgdata, + runtime_path / path_hash, + ] + + ok_path = None + for path in candidate_socket_dir: + path.mkdir(parents=True, exist_ok=True) + # name used by postgresql for domain socket is .s.PGSQL.5432 + if socket_name_length_ok(path / '.s.PGSQL.5432'): + ok_path = path + _logger.info(f"Using socket path: {path}") + break + else: + _logger.info(f"Socket path too long: {path}. Will try a different directory for socket.") + + if ok_path is None: + raise RuntimeError("Could not find a suitable socket path") + + return ok_path + +def find_suitable_port(address : Optional[str] = None) -> int: + """Find an available TCP port.""" + if address is None: + address = '127.0.0.1' + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind((address, 0)) + port = sock.getsockname()[1] + sock.close() + return port diff --git a/tests/test_pgserver.py b/tests/test_pgserver.py index 8bb4c69..a862386 100644 --- a/tests/test_pgserver.py +++ b/tests/test_pgserver.py @@ -2,26 +2,105 @@ import pgserver import subprocess import tempfile -from typing import Optional +from typing import Optional, Union import multiprocessing as mp import shutil -import time from pathlib import Path -from pgserver.shared import _process_is_running - +import pgserver.utils +import socket +from pgserver.utils import find_suitable_port, process_is_running +import psutil +import platform +import sqlalchemy as sa +import datetime +from sqlalchemy_utils import database_exists, create_database +import logging +import os + +def _check_sqlalchemy_works(srv : pgserver.PostgresServer): + database_name = 'testdb' + uri = srv.get_uri(database_name) + + if not database_exists(uri): + 
create_database(uri) + + engine = sa.create_engine(uri) + conn = engine.connect() -def _check_server_works(pg : pgserver.PostgresServer) -> int: + table_name = 'table_foo' + with conn.begin(): + # if table exists already, drop it + if engine.dialect.has_table(conn, table_name): + conn.execute(sa.text(f"drop table {table_name};")) + conn.execute(sa.text(f"create table {table_name} (id int);")) + conn.execute(sa.text(f"insert into {table_name} values (1);")) + cur = conn.execute(sa.text(f"select * from {table_name};")) + result = cur.fetchone() + assert result + assert result[0] == 1 + +def _check_postmaster_info(pgdata : Path, postmaster_info : pgserver.utils.PostmasterInfo): + assert postmaster_info is not None + assert postmaster_info.pgdata is not None + assert postmaster_info.pgdata == pgdata + + assert postmaster_info.is_running() + + if postmaster_info.socket_dir is not None: + assert postmaster_info.socket_dir.exists() + assert postmaster_info.socket_path is not None + assert postmaster_info.socket_path.exists() + assert postmaster_info.socket_path.is_socket() + + +def _check_server(pg : pgserver.PostgresServer) -> int: assert pg.pgdata.exists() - pid = pg.get_pid() - assert pid is not None + postmaster_info = pgserver.utils.PostmasterInfo.read_from_pgdata(pg.pgdata) + assert postmaster_info is not None + assert postmaster_info.pid is not None + _check_postmaster_info(pg.pgdata, postmaster_info) + ret = pg.psql("show data_directory;") - assert str(pg.pgdata) in ret - return pid + # parse second row (first two are headers) + ret_path = Path(ret.splitlines()[2].strip()) + assert pg.pgdata == ret_path + _check_sqlalchemy_works(pg) + return postmaster_info.pid -def _kill_server(pid : Optional[int]) -> None: +def _kill_server(pid : Union[int,psutil.Process,None]) -> None: if pid is None: return - subprocess.run(["kill", "-9", str(pid)]) + elif isinstance(pid, psutil.Process): + proc = pid + else: + try: + proc = psutil.Process(pid) + except psutil.NoSuchProcess: + 
return + + if proc.is_running(): + proc.terminate() # attempt cleaner shutdown + try: + proc.wait(3) # wait at most a few seconds + except psutil.TimeoutExpired: + pass + + if proc.is_running(): + proc.kill() + +def test_get_port(): + address = '127.0.0.1' + port = find_suitable_port(address) + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + + try: + sock.bind((address, port)) + except OSError as err: + if 'Address already in use' in str(err): + raise RuntimeError(f"Port {port} is already in use.") + raise err + finally: + sock.close() def test_get_server(): with tempfile.TemporaryDirectory() as tmpdir: @@ -29,80 +108,50 @@ def test_get_server(): try: # check case when initializing the pgdata dir with pgserver.get_server(tmpdir) as pg: - pid = _check_server_works(pg) + pid = _check_server(pg) - assert not _process_is_running(pid) + assert not process_is_running(pid) assert pg.pgdata.exists() # check case when pgdata dir is already initialized with pgserver.get_server(tmpdir) as pg: - pid = _check_server_works(pg) + pid = _check_server(pg) - assert not _process_is_running(pid) + assert not process_is_running(pid) assert pg.pgdata.exists() finally: _kill_server(pid) - def test_reentrant(): with tempfile.TemporaryDirectory() as tmpdir: pid = None try: with pgserver.get_server(tmpdir) as pg: - pid = _check_server_works(pg) + pid = _check_server(pg) with pgserver.get_server(tmpdir) as pg2: assert pg2 is pg - _check_server_works(pg) + _check_server(pg) - _check_server_works(pg) + _check_server(pg) - assert not _process_is_running(pid) + assert not process_is_running(pid) assert pg.pgdata.exists() finally: _kill_server(pid) -def _start_and_wait(tmpdir, queue_in, queue_out): - with pgserver.get_server(tmpdir) as pg: - pid = _check_server_works(pg) +def _start_server_in_separate_process(pgdata, queue_in : Optional[mp.Queue], queue_out : mp.Queue, cleanup_mode : Optional[str]): + with pgserver.get_server(pgdata, cleanup_mode=cleanup_mode) as pg: + pid = 
_check_server(pg) queue_out.put(pid) - # now wait for parent to tell us to exit - _ = queue_in.get() + if queue_in is not None: + _ = queue_in.get() # wait for signal + return -def test_multiprocess_shared(): - """ Test that multiple processes can share the same server. +def test_unix_domain_socket(): + if platform.system() == 'Windows': + pytest.skip("This test is for unix domain sockets, which are not available on Windows.") - 1. get server in a child process, - 2. then, get server in the parent process - 3. then, exiting the child process - 4. checking the parent can still use the server. - """ - pid = None - try: - with tempfile.TemporaryDirectory() as tmpdir: - queue_to_child = mp.Queue() - queue_from_child = mp.Queue() - child = mp.Process(target=_start_and_wait, args=(tmpdir,queue_to_child,queue_from_child)) - child.start() - # wait for child to start server - server_pid_child = queue_from_child.get() - - with pgserver.get_server(tmpdir) as pg: - server_pid_parent = _check_server_works(pg) - assert server_pid_child == server_pid_parent - - # tell child to continue - queue_to_child.put(None) - child.join() - - # check server still works - _check_server_works(pg) - - assert not _process_is_running(server_pid_parent) - finally: - _kill_server(pid) - -def test_dir_length(): long_prefix = '_'.join(['long'] + ['1234567890']*12) assert len(long_prefix) > 120 prefixes = ['short', long_prefix] @@ -112,9 +161,9 @@ def test_dir_length(): pid = None try: with pgserver.get_server(tmpdir) as pg: - pid = _check_server_works(pg) + pid = _check_server(pg) - assert not _process_is_running(pid) + assert not process_is_running(pid) assert pg.pgdata.exists() if len(prefix) > 120: assert str(tmpdir) not in pg.get_uri() @@ -123,14 +172,64 @@ def test_dir_length(): finally: _kill_server(pid) +def test_pg_ctl(): + if platform.system() != 'Windows' and os.geteuid() == 0: + # on Linux root, this test would fail. 
+ # we'd need to create a user etc to run the command, which is not worth it + # pgserver does this internally, but not worth it for this test + pytest.skip("This test is not run as root on Linux.") + + with tempfile.TemporaryDirectory() as tmpdir: + pid = None + try: + with pgserver.get_server(tmpdir) as pg: + output = pgserver.pg_ctl(['status'], str(pg.pgdata)) + assert 'server is running' in output.splitlines()[0] + + finally: + _kill_server(pid) + +def test_stale_postmaster(): + """ To simulate a stale postmaster.pid file, we create a postmaster.pid file by starting a server, + back the file up, then restore the backup to the original location after killing the server. + ( our method to kill the server is graceful to avoid running out of shmem, but this seems to also + remove the postmaster.pid file, so we need to go to these lengths to simulate a stale postmaster.pid file ) + """ + if platform.system() != 'Windows' and os.geteuid() == 0: + # on Linux as root, this test fails bc of permissions for the postmaster.pid file + # we simply skip it in this case, as in practice, the permissions issue would not occur + pytest.skip("This test is not run as root on Linux.") + + with tempfile.TemporaryDirectory() as tmpdir: + pid = None + pid2 = None + + try: + with pgserver.get_server(tmpdir, cleanup_mode='stop') as pg: + pid = _check_server(pg) + pgdata = pg.pgdata + postmaster_pid = pgdata / 'postmaster.pid' + + ## make a backup of the postmaster.pid file + shutil.copy2(str(postmaster_pid), str(postmaster_pid) + '.bak') + + # restore the backup to gurantee a stale postmaster.pid file + shutil.copy2(str(postmaster_pid) + '.bak', str(postmaster_pid)) + with pgserver.get_server(tmpdir) as pg: + pid2 = _check_server(pg) + finally: + _kill_server(pid) + _kill_server(pid2) + + def test_cleanup_delete(): with tempfile.TemporaryDirectory() as tmpdir: pid = None try: with pgserver.get_server(tmpdir, cleanup_mode='delete') as pg: - pid = _check_server_works(pg) + pid = 
_check_server(pg) - assert not _process_is_running(pid) + assert not process_is_running(pid) assert not pg.pgdata.exists() finally: _kill_server(pid) @@ -140,9 +239,9 @@ def test_cleanup_none(): pid = None try: with pgserver.get_server(tmpdir, cleanup_mode=None) as pg: - pid = _check_server_works(pg) + pid = _check_server(pg) - assert _process_is_running(pid) + assert process_is_running(pid) assert pg.pgdata.exists() finally: _kill_server(pid) @@ -155,7 +254,7 @@ def tmp_postgres(): def test_pgvector(tmp_postgres): ret = tmp_postgres.psql("CREATE EXTENSION vector;") - assert ret == "CREATE EXTENSION\n" + assert ret.strip() == "CREATE EXTENSION" def test_start_failure_log(caplog): """ Test server log contents are shown in python log when failures @@ -174,41 +273,80 @@ def test_start_failure_log(caplog): assert 'postgres: could not access the server configuration file' in caplog.text -def _reuse_deleted_datadir(prefix): - """ test that new server starts normally on same datadir after datadir is deleted + +def test_no_conflict(): + """ test we can start pgservers on two different datadirs with no conflict (eg port conflict) """ - tmpdir = tempfile.mkdtemp(prefix=prefix) - orig_pid = None - new_pid = None + pid1 = None + pid2 = None try: - pgdata = Path(tmpdir) / 'pgdata' - with pgserver.get_server(pgdata, cleanup_mode=None) as pg: - orig_pid = _check_server_works(pg) - - shutil.rmtree(pgdata) - assert not pgdata.exists() - # # TODO: why does the test fail in some environments if I dont kill the old server here? 
- # # if the directory is new, why does it somehow conflict with the old server - # _kill_server(orig_pid) - - # starting the server on same dir should work - with pgserver.get_server(pgdata, cleanup_mode=None) as pg: - new_pid = _check_server_works(pg) - assert orig_pid != new_pid + with tempfile.TemporaryDirectory() as tmpdir1, tempfile.TemporaryDirectory() as tmpdir2: + with pgserver.get_server(tmpdir1) as pg1, pgserver.get_server(tmpdir2) as pg2: + pid1 = _check_server(pg1) + pid2 = _check_server(pg2) finally: - _kill_server(orig_pid) - _kill_server(new_pid) + _kill_server(pid1) + _kill_server(pid2) - shutil.rmtree(tmpdir) -def test_no_conflict(): - """ test we can start pgservers on two different datadirs with no conflict (eg port conflict) - """ - with tempfile.TemporaryDirectory() as tmpdir1, tempfile.TemporaryDirectory() as tmpdir2: - with pgserver.get_server(tmpdir1) as pg1, pgserver.get_server(tmpdir2) as pg2: - pid1 = _check_server_works(pg1) - pid2 = _check_server_works(pg2) +def _reuse_deleted_datadir(prefix: str): + """ test common scenario where we repeatedly delete the datadir and start a new server on it """ + """ NB: currently this test is not reproducing the problem """ + # one can reproduce the problem by running the following in a loop: + # python -c 'import pixeltable as pxt; pxt.Client()'; rm -rf ~/.pixeltable/; python -c 'import pixeltable as pxt; pxt.Client()' + # which creates a database with more contents etc + tmpdir = tempfile.mkdtemp(prefix=prefix) + pgdata = Path(tmpdir) / 'pgdata' + server_processes = [] + shmem_ids = [] + + num_tries = 3 + try: + for _ in range(num_tries): + assert not pgdata.exists() + + queue_from_child = mp.Queue() + child = mp.Process(target=_start_server_in_separate_process, args=(pgdata, None, queue_from_child, None)) + child.start() + # wait for child to start server + curr_pid = queue_from_child.get() + child.join() + server_proc = psutil.Process(curr_pid) + assert server_proc.is_running() + 
server_processes.append(server_proc) + postmaster = pgserver.utils.PostmasterInfo.read_from_pgdata(pgdata) + + if postmaster.shmget_id is not None: + shmem_ids.append(postmaster.shmget_id) + if platform.system() == 'Windows': + # windows will not allow deletion of the directory while the server is running + _kill_server(server_proc) + + shutil.rmtree(pgdata) + finally: + if platform.system() != 'Windows': + # if sysv_ipc is installed (eg locally), remove the shared memory segment + # done this way because of CI/CD issues with sysv_ipc + # this avoids having to restart the machine to clear the shared memory + try: + import sysv_ipc + do_shmem_cleanup = True + except ImportError: + do_shmem_cleanup = False + logging.warning("sysv_ipc not installed, skipping shared memory cleanup...") + + if do_shmem_cleanup: + for shmid in shmem_ids: + try: + sysv_ipc.remove_shared_memory(shmid) + except sysv_ipc.ExistentialError as e: + logging.info(f"shared memory already removed: {e}") + + for proc in server_processes: + _kill_server(proc) + + shutil.rmtree(tmpdir) def test_reuse_deleted_datadir_short(): """ test that new server starts normally on same datadir after datadir is deleted @@ -222,33 +360,35 @@ def test_reuse_deleted_datadir_long(): assert len(long_prefix) > 120 _reuse_deleted_datadir(long_prefix) -@pytest.mark.skip(reason="run locally only (needs dep)") -def test_uri_string(tmp_postgres): - import sqlalchemy as sa - engine = sa.create_engine(tmp_postgres.get_uri('mydb')) - conn = engine.connect() - with conn.begin(): - conn.execute(sa.text("create table foo (id int);")) - conn.execute(sa.text("insert into foo values (1);")) - cur = conn.execute(sa.text("select * from foo;")) - assert cur.fetchone()[0] == 1 - -@pytest.mark.skip(reason="not implemented") -def test_delete_pgdata_cleanup(tmp_postgres): - """ Test server process is stopped when pgdata is deleted. +def test_multiprocess_shared(): + """ Test that multiple processes can share the same server. + + 1. 
get server in a child process, + 2. then, get server in the parent process + 3. then, exiting the child process + 4. checking the parent can still use the server. """ - assert tmp_postgres.pgdata.exists() - pid = tmp_postgres.get_pid() - assert pid is not None - assert _process_is_running(pid) + pid = None + try: + with tempfile.TemporaryDirectory() as tmpdir: + queue_to_child = mp.Queue() + queue_from_child = mp.Queue() + child = mp.Process(target=_start_server_in_separate_process, args=(tmpdir,queue_to_child,queue_from_child, 'stop')) + child.start() + # wait for child to start server + server_pid_child = queue_from_child.get() - # external deletion of pgdata should stop server - shutil.rmtree(tmp_postgres.pgdata) + with pgserver.get_server(tmpdir) as pg: + server_pid_parent = _check_server(pg) + assert server_pid_child == server_pid_parent - # wait for server to stop - for _ in range(20): # wait at most 3 seconds. - time.sleep(.2) - if not _process_is_running(pid): - break + # tell child to continue + queue_to_child.put(None) + child.join() - assert not _process_is_running(pid) + # check server still works + _check_server(pg) + + assert not process_is_running(server_pid_parent) + finally: + _kill_server(pid) \ No newline at end of file