diff --git a/.dockerignore b/.dockerignore index 4698557..b1e9d5f 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,3 +1,2 @@ # preventing .git files from entering context -.git .github diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 0000000..cebc11a --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,12 @@ +# These are supported funding model platforms + +github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] +patreon: # Replace with a single Patreon username +open_collective: # Replace with a single Open Collective username +ko_fi: # Replace with a single Ko-fi username +tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel +community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry +liberapay: # Replace with a single Liberapay username +issuehunt: # Replace with a single IssueHunt username +otechie: # Replace with a single Otechie username +custom: ['https://paypal.me/huelse99'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000..2d66731 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,30 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: bug +assignees: '' + +--- + +[//]: # Before submitting this issue, make sure you have already searched and still have problems. +**Describe the bug** +A clear and concise description of what the bug is. + +**Device\Environment** +System: +Python: + +**To Reproduce** +Steps to reproduce the behavior: +1. +2. + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**Screenshots** +If applicable, add screenshots to help explain your problem. + +**Additional context** +Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000..11fc491 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,20 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '' +labels: enhancement +assignees: '' + +--- + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context or screenshots about the feature request here. diff --git a/.github/ISSUE_TEMPLATE/other-issue.md b/.github/ISSUE_TEMPLATE/other-issue.md new file mode 100644 index 0000000..e608631 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/other-issue.md @@ -0,0 +1,14 @@ +--- +name: Other issue +about: Other issue template +title: '' +labels: question +assignees: '' + +--- + +**Type** +illustrate your original intention. + +**Descripe** +write what you want. diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml new file mode 100644 index 0000000..43628a9 --- /dev/null +++ b/.github/workflows/python-package.yml @@ -0,0 +1,41 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: Python package + +on: + push: + branches: [ 3.3.2 ] + pull_request: + branches: [ 3.3.2 ] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.5, 3.6, 3.7, 3.8] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + sudo apt-get install python3 python3-pip build-essential cmake + python -m pip install --upgrade pip + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + - name: build seal + run: | + cd SEAL/native/src + cmake . + make + - name: build seal-python + run: | + python3 setup.py build_ext -i + python3 setup.py install + + diff --git a/.gitignore b/.gitignore index 727c18d..f4a04da 100644 --- a/.gitignore +++ b/.gitignore @@ -40,3 +40,5 @@ build *.pyd *.pyc temp +.idea +*.bin diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..c8a8ee7 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,7 @@ +[submodule "SEAL"] + path = SEAL + url = https://github.com/microsoft/SEAL.git +[submodule "pybind11"] + path = pybind11 + url = https://github.com/pybind/pybind11.git + branch = stable diff --git a/Dockerfile b/Dockerfile index c184931..456038d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,44 +1,36 @@ -FROM ubuntu:19.10 - -# Install binary dependencies -RUN apt-get update && \ - apt-get install -qqy \ - g++ \ - git \ - make \ - cmake \ - python3 \ - python3-dev \ - python3-pip \ - sudo \ - libdpkg-perl \ - --no-install-recommends - -# Copy all files to container -COPY ./ /app - -# Build SEAL -WORKDIR /app/SEAL/native/src -RUN cmake . && \ - make && \ - make install - -# Install requirements -WORKDIR /app/src -RUN pip3 install -r requirements.txt - -# Build pybind11 -WORKDIR /app/pybind11 -RUN mkdir build -WORKDIR /app/pybind11/build -RUN cmake .. && \ - make check -j 4 && \ - make install - -# Build wrapper -WORKDIR /app/src -RUN python3 setup.py build_ext -i && \ - python3 setup.py install - -# Clean-up -RUN apt-get clean && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* +FROM ubuntu:21.04 + +# define the folder where our src should exist/ be deposited +ARG SRC=/python-seal + +# prevents update and install asking for tz +ENV DEBIAN_FRONTEND=noninteractive + +# install dependencies +RUN apt update && \ + apt install -y git build-essential cmake python3 python3-dev python3-pip && \ + mkdir -p ${SRC} + +# copy into container requirements and install them before rest of code +COPY ./requirements.txt ${SRC}/. +RUN pip3 install -r ${SRC}/requirements.txt + +# copy everything into container now that requirements stage is complete +COPY . ${SRC} + +# setting our default directory to the one specified above +WORKDIR ${SRC} + +# update submodules +RUN cd ${SRC} && \ + git submodule update --init --recursive + # git submodule update --remote + +# build and install seal + bindings +RUN cd ${SRC}/SEAL && \ + cmake -S . -B build -DSEAL_USE_MSGSL=OFF -DSEAL_USE_ZLIB=OFF -DSEAL_USE_ZSTD=OFF && \ + cmake --build build && \ + cd ${SRC} && \ + python3 setup.py build_ext -i + +CMD ["/usr/bin/python3"] diff --git a/README.md b/README.md index 79b5adb..80d2383 100644 --- a/README.md +++ b/README.md @@ -10,104 +10,149 @@ This is a python binding for the Microsoft SEAL library. ## Contents -* [Build](https://github.com/Huelse/SEAL-Python#build) -* [Tests](https://github.com/Huelse/SEAL-Python#tests) -* [About](https://github.com/Huelse/SEAL-Python#about) -* [Contributing](https://github.com/Huelse/SEAL-Python#contributing) +* [Build](#build) +* [Note](#note) + * [Serialize](#serialize) + * [Other](#other) +* [FAQ](#faq) ## Build -### Linux -CMake (>= 3.10), GNU G++ (>= 6.0) or Clang++ (>= 5.0), Python (>=3.6.8) -`sudo apt-get update && sudo apt-get install g++ make cmake git python3 python3-dev python3.6-pip` +* ### Linux -`git clone https://github.com/Huelse/SEAL-Python.git` + Recommend: Clang++ (>= 10.0) or GNU G++ (>= 9.4), CMake (>= 3.16) -```shell -cd SEAL/native/src -cmake . -make + ```shell + # Optional + sudo apt-get install git build-essential cmake python3 python3-dev python3-pip -cd src -pip3 install -r requirements.txt + # Get the repository or download from the releases + git clone https://github.com/Huelse/SEAL-Python.git + cd SEAL-Python -# Check the path at first -# Setuptools (Recommend) -cd src -python3 setup.py build_ext -i -# or install -python3 setup.py install + # Install dependencies + pip3 install numpy pybind11 -# CMake (Optional) -mkdir build -cd build -cmake .. -make -``` + # Init the SEAL and pybind11 + git submodule update --init --recursive + # Get the newest repositories (dev only) + # git submodule update --remote -[setuptools docs](https://docs.python.org/3/distutils/configfile.html) + # Build the SEAL lib + cd SEAL + cmake -S . -B build -DSEAL_USE_MSGSL=OFF -DSEAL_USE_ZLIB=OFF + cmake --build build + cd .. -[pybind11 docs](https://pybind11.readthedocs.io/en/master/index.html) + # Run the setup.py + python3 setup.py build_ext -i -### Windows + # Test + cp seal.*.so examples + cd examples + python3 4_bgv_basics.py + ``` -Visual Studio 2017 version 15.3 or newer is required to build Microsoft SEAL. + Build examples (after `cmake -S . -B`): `-DSEAL_BUILD_EXAMPLES=ON` -Open the `SEAL/SEAL.sln` in VS, config in `x64, Release, WinSDK(17763, etc)` mode and generate it. + Zstandard compression off: `-DSEAL_USE_ZSTD=OFF` -```shell -cd src -python3 setup.py build_ext -i -# or install -python3 setup.py install -``` + [More cmake options](https://github.com/microsoft/SEAL#basic-cmake-options) -Microsoft official video [SEAL in windows](https://www.microsoft.com/en-us/research/video/installing-microsoft-seal-on-windows/). +* ### Windows + Visual Studio 2019 or newer is required. x64 support only! And use the **x64 Native Tools Command Prompt for VS** command prompt to configure and build the Microsoft SEAL library. It's usually can be found in your Start Menu. + ```shell + # Run in "x64 Native Tools Command Prompt for VS" command prompt + cmake -S . -B build -G Ninja -DSEAL_USE_MSGSL=OFF -DSEAL_USE_ZLIB=OFF + cmake --build build -## Tests + # Build + pip install numpy pybind11 + python setup.py build_ext -i -`cd tests` + # Test + cp seal.*.pyd examples + cd examples + python 4_bgv_basics.py + ``` -`python3 [example_name].py` + Microsoft SEAL official [docs](https://github.com/microsoft/SEAL#building-microsoft-seal-manually). -* The `.so` file must be in the same folder, or you had `install` it already. +* ### Docker + requires: [Docker](https://www.docker.com/) -## Getting Started + To build source code into a docker image (from this directory): + ```shell + docker build -t huelse/seal -f Dockerfile . + ``` -| C++ | Python | Description | Progress | -| ----------------- | ---------------- | ------------------------------------------------------------ | -------- | -| 1_bfv_basics.cpp | 1_bfv_basics.py | Encrypted modular arithmetic using the BFV scheme | Finished | -| 2_encoders.cpp | 2_encoders.py | Encoding more complex data into Microsoft SEAL plaintext objects | Finished | -| 3_levels.cpp | 3_levels.py | Introduces the concept of levels; prerequisite for using the CKKS scheme | Finished | -| 4_ckks_basics.cpp | 4_ckks_basics.py | Encrypted real number arithmetic using the CKKS scheme | Finished | -| 5_rotation.cpp | 5_rotation.py | Performing cyclic rotations on encrypted vectors in the BFV and CKKS schemes | Finished | -| 6_performance.cpp | 6_performance.py | Performance tests for Microsoft SEAL | Finished | + To use the image by running it as an interactive container: + ```shell + docker run -it huelse/seal + ``` -## Future +## Note -* SEAL 3.4 or higher support +* ### Serialize + See more in `examples/7_serialization.py`, here is a simple example: + ```python + cipher.save('cipher') + load_cipher = Ciphertext() + load_cipher.load(context, 'cipher') # work if the context is valid. + ``` -## About + Supported classes: `EncryptionParameters, Ciphertext, Plaintext, SecretKey, PublicKey, RelinKeys, GaloisKeys` -This project is still testing now, if any problems(bugs), [Issue](https://github.com/Huelse/SEAL-Python/issues) please. -Email: [huelse@oini.top](mailto:huelse@oini.top?subject=Github-SEAL-Python-Issues&cc=5956877@qq.com) +* ### Other + + There are a lot of changes in the latest SEAL lib, we try to make the API in python can be used easier, but it may remain some problems unknown, if any problems or bugs, report [issues](https://github.com/Huelse/SEAL-Python/issues). + + Email: [topmaxz@protonmail.com](mailto:topmaxz@protonmail.com?subject=Github-SEAL-Python-Issues) + + + +## FAQ + +1. ImportError: undefined symbol + + Build a shared SEAL library `cmake . -DBUILD_SHARED_LIBS=ON`, and get the `libseal.so`, + + then change the path in `setup.py`, and rebuild. + +2. ImportError: libseal.so... cannot find + + a. `sudo ln -s /path/to/libseal.so /usr/lib` + + b. add `/usr/local/lib` or the `SEAL/native/lib` to `/etc/ld.so.conf` and refresh it `sudo ldconfig` + + c. build in cmake. + +3. BuildError: C++17 at least + +4. ModuleNotFoundError: No module named 'seal' + + The `.so` or `.pyd` file must be in the current directory, or you have `install` it already. + +5. Windows Error LNK2001, RuntimeLibrary and MT_StaticRelease mismatch + + Only `x64` is supported, Choose `x64 Native Tools Command Prompt for VS`. ## Contributing + * Professor: [Dr. Chen](https://zhigang-chen.github.io/) * [Contributors](https://github.com/Huelse/SEAL-Python/graphs/contributors) - diff --git a/SEAL b/SEAL new file mode 160000 index 0000000..82b07db --- /dev/null +++ b/SEAL @@ -0,0 +1 @@ +Subproject commit 82b07db635132e297282649e2ab5908999089ad2 diff --git a/SEAL/.gitignore b/SEAL/.gitignore deleted file mode 100644 index fdcf33e..0000000 --- a/SEAL/.gitignore +++ /dev/null @@ -1,329 +0,0 @@ -# Other stuff -native/src/cmake/SEALConfig.cmake -native/src/cmake/SEALConfigVersion.cmake -native/src/cmake/SEALTargets.cmake -native/src/seal/util/config.h -**/CMakeCache.txt -**/CMakeFiles -**/Makefile -**/.config -**/autom4te.cache/* -**/cmake_install.cmake -**/install_manifest.txt -.ycm_extra_conf.py -.vimrc -.lvimrc -.local_vimrc -**/*.code-workspace -**/.vscode -*/build -**/*.build -**/compile_commands.json -**/.DS_Store - -## Ignore Visual Studio temporary files, build results, and -## files generated by popular Visual Studio add-ons. -## -## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore - -# User-specific files -*.suo -*.user -*.userosscache -*.sln.docstates - -# User-specific files (MonoDevelop/Xamarin Studio) -*.userprefs - -# Build results -[Dd]ebug/ -[Dd]ebugPublic/ -[Rr]elease/ -[Rr]eleases/ -x64/ -x86/ -bld/ -[Bb]in/ -[Oo]bj/ -[Ll]og/ -[Ll]ib/ - -# Visual Studio 2015 cache/options directory -.vs/ -# Uncomment if you have tasks that create the project's static files in wwwroot -#wwwroot/ - -# MSTest test Results -[Tt]est[Rr]esult*/ -[Bb]uild[Ll]og.* - -# NUNIT -*.VisualState.xml -TestResult.xml - -# Build Results of an ATL Project -[Dd]ebugPS/ -[Rr]eleasePS/ -dlldata.c - -# Benchmark Results -BenchmarkDotNet.Artifacts/ - -# .NET Core -project.lock.json -project.fragment.lock.json -artifacts/ -**/Properties/launchSettings.json - -*_i.c -*_p.c -*_i.h -*.ilk -*.meta -*.obj -*.pch -*.pdb -*.pgc -*.pgd -*.rsp -*.sbr -*.tlb -*.tli -*.tlh -*.tmp -*.tmp_proj -*.log -*.vspscc -*.vssscc -.builds -*.pidb -*.svclog -*.scc - -# Chutzpah Test files -_Chutzpah* - -# Visual C++ cache files -ipch/ -*.aps -*.ncb -*.opendb -*.opensdf -*.sdf -*.cachefile -*.VC.db -*.VC.VC.opendb - -# Visual Studio profiler -*.psess -*.vsp -*.vspx -*.sap - -# Visual Studio Trace Files -*.e2e - -# TFS 2012 Local Workspace -$tf/ - -# Guidance Automation Toolkit -*.gpState - -# ReSharper is a .NET coding add-in -_ReSharper*/ -*.[Rr]e[Ss]harper -*.DotSettings.user - -# JustCode is a .NET coding add-in -.JustCode - -# TeamCity is a build add-in -_TeamCity* - -# DotCover is a Code Coverage Tool -*.dotCover - -# AxoCover is a Code Coverage Tool -.axoCover/* -!.axoCover/settings.json - -# Visual Studio code coverage results -*.coverage -*.coveragexml - -# NCrunch -_NCrunch_* -.*crunch*.local.xml -nCrunchTemp_* - -# MightyMoose -*.mm.* -AutoTest.Net/ - -# Web workbench (sass) -.sass-cache/ - -# Installshield output folder -[Ee]xpress/ - -# DocProject is a documentation generator add-in -DocProject/buildhelp/ -DocProject/Help/*.HxT -DocProject/Help/*.HxC -DocProject/Help/*.hhc -DocProject/Help/*.hhk -DocProject/Help/*.hhp -DocProject/Help/Html2 -DocProject/Help/html - -# Click-Once directory -publish/ - -# Publish Web Output -*.[Pp]ublish.xml -*.azurePubxml -# Note: Comment the next line if you want to checkin your web deploy settings, -# but database connection strings (with potential passwords) will be unencrypted -*.pubxml -*.publishproj - -# Microsoft Azure Web App publish settings. Comment the next line if you want to -# checkin your Azure Web App publish settings, but sensitive information contained -# in these scripts will be unencrypted -PublishScripts/ - -# NuGet Packages -*.nupkg -# The packages folder can be ignored because of Package Restore -**/[Pp]ackages/* -# except build/, which is used as an MSBuild target. -!**/[Pp]ackages/build/ -# Uncomment if necessary however generally it will be regenerated when needed -#!**/[Pp]ackages/repositories.config -# NuGet v3's project.json files produces more ignorable files -*.nuget.props -*.nuget.targets - -# Microsoft Azure Build Output -csx/ -*.build.csdef - -# Microsoft Azure Emulator -ecf/ -rcf/ - -# Windows Store app package directories and files -AppPackages/ -BundleArtifacts/ -Package.StoreAssociation.xml -_pkginfo.txt -*.appx - -# Visual Studio cache files -# files ending in .cache can be ignored -*.[Cc]ache -# but keep track of directories ending in .cache -!*.[Cc]ache/ - -# Others -ClientBin/ -~$* -*~ -*.dbmdl -*.dbproj.schemaview -*.jfm -*.pfx -*.publishsettings -orleans.codegen.cs - -# Since there are multiple workflows, uncomment next line to ignore bower_components -# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) -#bower_components/ - -# RIA/Silverlight projects -Generated_Code/ - -# Backup & report files from converting an old project file -# to a newer Visual Studio version. Backup files are not needed, -# because we have git ;-) -_UpgradeReport_Files/ -Backup*/ -UpgradeLog*.XML -UpgradeLog*.htm - -# SQL Server files -*.mdf -*.ldf -*.ndf - -# Business Intelligence projects -*.rdl.data -*.bim.layout -*.bim_*.settings - -# Microsoft Fakes -FakesAssemblies/ - -# GhostDoc plugin setting file -*.GhostDoc.xml - -# Node.js Tools for Visual Studio -.ntvs_analysis.dat -node_modules/ - -# Typescript v1 declaration files -typings/ - -# Visual Studio 6 build log -*.plg - -# Visual Studio 6 workspace options file -*.opt - -# Visual Studio 6 auto-generated workspace file (contains which files were open etc.) -*.vbw - -# Visual Studio LightSwitch build output -**/*.HTMLClient/GeneratedArtifacts -**/*.DesktopClient/GeneratedArtifacts -**/*.DesktopClient/ModelManifest.xml -**/*.Server/GeneratedArtifacts -**/*.Server/ModelManifest.xml -_Pvt_Extensions - -# Paket dependency manager -.paket/paket.exe -paket-files/ - -# FAKE - F# Make -.fake/ - -# JetBrains Rider -.idea/ -*.sln.iml - -# CodeRush -.cr/ - -# Python Tools for Visual Studio (PTVS) -__pycache__/ -*.pyc - -# Cake - Uncomment if you are using it -# tools/** -# !tools/packages.config - -# Tabs Studio -*.tss - -# Telerik's JustMock configuration file -*.jmconfig - -# BizTalk build output -*.btp.cs -*.btm.cs -*.odx.cs -*.xsd.cs - -# OpenCover UI analysis results -OpenCover/ diff --git a/SEAL/.gitmodules b/SEAL/.gitmodules deleted file mode 100644 index 48f1a39..0000000 --- a/SEAL/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "native/tests/thirdparty/googletest"] - path = native/tests/thirdparty/googletest - url = https://github.com/google/googletest diff --git a/SEAL/LICENSE b/SEAL/LICENSE deleted file mode 100644 index 2107107..0000000 --- a/SEAL/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ - MIT License - - Copyright (c) Microsoft Corporation. All rights reserved. - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE diff --git a/SEAL/README.md b/SEAL/README.md deleted file mode 100644 index d970137..0000000 --- a/SEAL/README.md +++ /dev/null @@ -1,425 +0,0 @@ -# Microsoft SEAL - -Microsoft SEAL is an easy-to-use open-source ([MIT licensed](LICENSE)) homomorphic encryption library developed by the Cryptography Research group at Microsoft. Microsoft SEAL is written in modern standard C++ and has no external dependencies, making it easy to compile and run in many different environments. For more information about the Microsoft SEAL project, see [sealcrypto.org](https://www.microsoft.com/en-us/research/project/microsoft-seal). - -This document pertains to Microsoft SEAL version 3.3. Users of previous versions of the library should look at the [list of changes](Changes.md). - -# Contents - -- [Introduction](#introduction) - - [Core Concepts](#core-concepts) - - [Homomorphic Encryption](#homomorphic-encryption) - - [Microsoft SEAL](#microsoft-seal-1) -- [Installing Microsoft SEAL](#installing-microsoft-seal) - - [Windows](#windows) - - [Linux and macOS](#linux-and-macos) -- [Installing Microsoft SEAL for .NET](#installing-microsoft-seal-for-net) - - [Windows](#windows-1) - - [Linux and macOS](#linux-and-macos-1) -- [Getting Started](#getting-started) -- [Contributing](#contributing) -- [Citing Microsoft SEAL](#citing-microsoft-seal) - -# Introduction - -## Core Concepts - -Most encryption schemes consist of three functionalities: key generation, encryption, and decryption. Symmetric-key encryption schemes use the same secret key for both encryption and decryption; public-key encryption schemes use separately a public key for encryption and a secret key for decryption. Therefore, public-key encryption schemes allow anyone who knows the public key to encrypt data, but only those who know the secret key can decrypt and read the data. Symmetric-key encryption can be used for efficiently encrypting very large amounts of data, and enables secure outsourced cloud storage. Public-key encryption is a fundamental concept that enables secure online communication today, but is typically much less efficient than symmetric-key encryption. - -While traditional symmetric- and public-key encryption can be used for secure storage and communication, any outsourced computation will necessarily require such encryption layers to be removed before computation can take place. Therefore, cloud services providing outsourced computation capabilities must have access to the secret keys, and implement access policies to prevent unauthorized employees from getting access to these keys. - -## Homomorphic Encryption - -Homomorphic encryption refers to encryption schemes that allow the cloud to compute directly on the encrypted data, without requiring the data to be decrypted first. The results of such encrypted computations remain encrypted, and can be only decrypted with the secret key (by the data owner). Multiple homomorphic encryption schemes with different capabilities and trade-offs have been invented over the past decade; most of these are public-key encryption schemes, although the public-key functionality may not always be needed. - -Homomorphic encryption is not a generic technology: only some computations on encrypted data are possible. It also comes with a substantial performance overhead, so computations that are already very costly to perform on unencrypted data are likely to be infeasible on encrypted data. Moreover, data encrypted with homomorphic encryption is many times larger than unencrypted data, so it may not make sense to encrypt, e.g., entire large databases, with this technology. Instead, meaningful use-cases are in scenarios where strict privacy requirements prohibit unencrypted cloud computation altogether, but the computations themselves are fairly lightweight. - -Typically, homomorphic encryption schemes have a single secret key which is held by the data owner. For scenarios where multiple different private data owners wish to engage in collaborative computation, homomorphic encryption is probably not a reasonable solution. - -Homomorphic encryption cannot be used to enable data scientist to circumvent GDPR. For example, there is no way for a cloud service to use homomorphic encryption to draw insights from encrypted customer data. Instead, results of encrypted computations remain encrypted and can only be decrypted by the owner of the data, e.g., a cloud service customer. - -## Microsoft SEAL - -Microsoft SEAL is a homomorphic encryption library that allows additions and multiplications to be performed on encrypted integers or real numbers. Other operations, such as encrypted comparison, sorting, or regular expressions, are in most cases not feasible to evaluate on encrypted data using this technology. Therefore, only specific privacy-critical cloud computation parts of programs should be implemented with Microsoft SEAL. - -It is not always easy or straightfoward to translate an unencrypted computation into a computation on encrypted data, for example, it is not possible to branch on encrypted data. Microsoft SEAL itself has a steep learning curve and requires the user to understand many homomorphic encryption specific concepts, even though in the end the API is not too complicated. Even if a user is able to program and run a specific computation using Microsoft SEAL, the difference between efficient and inefficient implementations can be several orders of magnitude, and it can be hard for new users to know how to improve the performance of their computation. - -Microsoft SEAL comes with two different homomorphic encryption schemes with very different properties. The BFV scheme allows modular arithmetic to be performed on encrypted integers. The CKKS scheme allows additions and multiplications on encrypted real or complex numbers, but yields only approximate results. In applications such a summing up encrypted real numbers, evaluating machine learning models on encrypted data, or computing distances of encrypted locations CKKS is going to be by far the best choice. For applications where exact values are necessary, the BFV scheme is the only choice. - -# Installing Microsoft SEAL - -## Windows - -Microsoft SEAL comes with a Microsoft Visual Studio 2017 solution file `SEAL.sln` that can be -used to conveniently build the library, examples, and unit tests. - -#### Debug and Release builds - -You can easily switch from Visual Studio build configuration menu whether Microsoft SEAL should be -built in `Debug` mode (no optimizations) or in `Release` mode. Please note that `Debug` -mode should not be used except for debugging SEAL itself, as the performance will be -orders of magnitude worse than in `Release` mode. - -#### Library - -Build the SEAL project `native\src\SEAL.vcxproj` from `SEAL.sln`. This results -in the static library `seal.lib` to be created in `native\lib\$(Platform)\$(Configuration)`. When -linking with applications, you need to add `native\src\` (full path) as an include directory -for SEAL header files. - -#### Examples - -Build the SEALExamples project `native\examples\SEALExamples.vcxproj` from `SEAL.sln`. -This results in an executable `sealexamples.exe` to be created in `native\bin\$(Platform)\$(Configuration)`. - -#### Unit tests - -The unit tests require the Google Test framework to be installed. The appropriate -NuGet package is already listed in `native\tests\packages.config`, so once you attempt to build -the SEALTest project `native\tests\SEALTest.vcxproj` from `SEAL.sln` Visual Studio will -automatically download and install it for you. - -## Linux and macOS - -Microsoft SEAL is very easy to configure and build in Linux and macOS using CMake (>= 3.10). -A modern version of GNU G++ (>= 6.0) or Clang++ (>= 5.0) is needed. In macOS the -Xcode toolchain (>= 9.3) will work. - -In macOS you will need CMake with command line tools. For this, you can either -1. install the cmake package with [Homebrew](https://brew.sh), or -2. download CMake directly from [https://cmake.org/download](https://cmake.org/download) and [enable command line tools](https://stackoverflow.com/questions/30668601/installing-cmake-command-line-tools-on-a-mac). - -Below we give instructions for how to configure, build, and install SEAL either -system-wide (global install), or for a single user (local install). A system-wide -install requires elevated (root) privileges. - -#### Debug and Release builds - -You can easily switch from CMake configuration options whether Microsoft SEAL should be built in -`Debug` mode (no optimizations) or in `Release` mode. Please note that `Debug` mode should not -be used except for debugging Microsoft SEAL itself, as the performance will be orders of magnitude -worse than in `Release` mode. - -### Global install - -#### Library - -If you have root access to the system you can install Microsoft SEAL system-wide as follows: -```` -cd native/src -cmake . -make -sudo make install -cd ../.. -```` -#### Examples - -To build the examples do: -```` -cd native/examples -cmake . -make -cd ../.. -```` - -After completing the above steps the `sealexamples` executable can be found in `native/bin/`. -See `native/examples/CMakeLists.txt` for how to link Microsoft SEAL with your own project using CMake. - -#### Unit tests - -To build the unit tests you will need the [GoogleTest](https://github.com/google/googletest) framework, which is included in Microsoft SEAL as a git submodule. To download the GoogleTest source files, do: -```` -git submodule update --init -```` -This needs to be executed only once, and can be skipped if Microsoft SEAL was cloned with `git --recurse-submodules`. To build the tests, do: -```` -cd native/tests -cmake . -make -cd ../.. -```` - -After completing these steps the `sealtest` executable can be found in `native/bin/`. All unit -tests should pass successfully. - -### Local install - -#### Library - -To install Microsoft SEAL locally, e.g., to `~/mylibs/`, do the following: -```` -cd native/src -cmake -DCMAKE_INSTALL_PREFIX=~/mylibs . -make -make install -cd ../.. -```` - -#### Examples - -To build the examples do: -```` -cd native/examples -cmake -DCMAKE_PREFIX_PATH=~/mylibs . -make -cd ../.. -```` - -After completing the above steps the `sealexamples` executable can be found in `native/bin/`. -See `native/examples/CMakeLists.txt` for how to link Microsoft SEAL with your own project using CMake. - -#### Unit tests - -To build the unit tests you will need the [GoogleTest](https://github.com/google/googletest) framework, which is included in Microsoft SEAL as a git submodule. To download the GoogleTest source files, do: -```` -git submodule update --init -```` -This needs to be executed only one, and can be skipped if Microsoft SEAL was cloned with `git --recurse-submodules`. Then do: -```` -cd native/tests -cmake -DCMAKE_PREFIX_PATH=~/mylibs . -make -cd ../.. -```` - -After completing these steps the `sealtest` executable can be found in `native/bin/`. All unit -tests should pass successfully. - -# Installing Microsoft SEAL for .NET - -Microsoft SEAL provides a .NET Standard library that wraps the functionality in Microsoft SEAL -for use in .NET development. - -## Windows - -The Microsoft Visual Studio 2017 solution file `SEAL.sln` contains the projects necessary -to build the .NET assembly, a backing native shared library, .NET examples, and unit tests. - -#### Native library - -Microsoft SEAL for .NET requires a native library that is invoked by the managed .NET library. -Build the SEALNetNative project `dotnet\native\SEALNetNative.vcxproj` from `SEAL.sln`. -Building SEALNetNative results in the dynamic library `sealnetnative.dll` to be created -in `dotnet\lib\$(Platform)\$(Configuration)`. This library is meant to be used only by the -.NET library, not by end users, and needs to be present in the same directory as your -executable when developing a .NET application. - -#### .NET library - -Once you have built the shared native library (see above), build the SEALNet project -`dotnet\src\SEALNet.csproj` from `SEAL.sln`. Building SEALNet results in the assembly -`SEALNet.dll` to be created in `dotnet\lib\$(Configuration)\netstandard2.0`. This -is the assembly you can reference in your application. - -#### .NET examples - -Build the SEALNetExamples project `dotnet\examples\SEALNetExamples.csproj` from `SEAL.sln`. -This results in the assembly `SEALNetExamples.dll` to be created in -`dotnet\bin\$(Configuration)\netcoreapp2.1`. The project takes care of copying the -native SEALNetNative library to the output directory. - -#### .NET unit tests - -Build the SEALNet Test project `dotnet\tests\SEALNetTest.csproj` from `SEAL.sln`. This results -in the `SEALNetTest.dll` assembly to be created in `dotnet\lib\$(Configuration)\netcoreapp2.1`. -The project takes care of copying the native SEALNetNative library to the output directory. - -### Using Microsoft SEAL for .NET in your own application - -To use Microsoft SEAL for .NET in your own application you need to: -1. add a reference in your project to `SEALNet.dll`; -2. ensure `sealnetnative.dll` is available for your application when run. The easiest way to ensure - this is to copy `sealnetnative.dll` to the same directory where your application's executable - is located. - -Alternatively, you can build and use a NuGet package; see instructions in [NUGET.md](dotnet/nuget/NUGET.md). - -## Linux and macOS - -Microsoft SEAL for .NET relies on a native shared library that can be easily configured and built -using CMake (>= 3.10) and a modern version of GNU G++ (>= 6.0) or Clang++ (>= 5.0). In macOS -the Xcode toolchain (>= 9.3) will work. - -For compiling .NET code you will need to install a .NET Core SDK (>= 2.1). You can follow -these [instructions for installing in Linux](https://dotnet.microsoft.com/download?initial-os=linux), -or for [installing in macOS](https://dotnet.microsoft.com/download?initial-os=macos). - -### Local use of shared native library - -If you only intend to run the examples and unit tests provided with Microsoft SEAL, -you do not need to install the native shared library, you only need to compile it. -The SEALNetExamples and SEALNetTest projects take care of copying the native shared -library to the appropriate assembly output directory. - -To compile the native shared library you will need to: -1. Compile Microsoft SEAL as a static or shared library with Position-Independent Code (PIC); -2. Compile native shared library. - -The instructions for compiling Microsoft SEAL are similar to the instructions described -[above](#linux-and-macos) for a global or local install. Make sure the CMake configuration -option `SEAL_LIB_BUILD_TYPE` is set to either `Static_PIC` (default) or `Shared`. Assuming -Microsoft SEAL was built and installed globally using the default CMake configuration -options, we can immediately use it to compile the shared native library required for .NET: -```` -cd dotnet/native -cmake . -make -cd ../.. -```` -If Microsoft SEAL was installed locally instead, use: -```` -cd dotnet/native -cmake -DCMAKE_PREFIX_PATH=~/mylibs . -make -cd ../.. -```` - -#### .NET library - -To build the .NET Standard library, do the following: -```` -cd dotnet/src -dotnet build -cd ../.. -```` -You can use the `dotnet` parameter `--configuration ` to build either -a `Debug` or `Release` version of the assembly. This will result in a `SEALNet.dll` -assembly to be created in `dotnet/lib/$(Configuration)/netstandard2.0`. This assembly -is the one you will want to reference in your own projects. - -#### Examples - -To build and run the .NET examples, do: -```` -cd dotnet/examples -dotnet run -cd ../.. -```` -As mentioned before, the .NET project will copy the shared native library to the assembly -output directory. You can use the `dotnet` parameter `--configuration ` to -run either `Debug` or `Release` versions of the examples. - -#### Unit tests - -To build and run the .NET unit tests, do: -```` -cd dotnet/tests -dotnet test -cd ../.. -```` -All unit tests should pass. You can use the `dotnet` parameter `--configuration ` -to run `Debug` or `Relase` unit tests, and you can use `--verbosity detailed` to print the list -of unit tests that are being run. - -### Using Microsoft SEAL for .NET in your own application - -To use Microsoft SEAL for .NET in your own application you need to: -1. add a reference in your project to `SEALNet.dll`; -2. ensure the native shared library is available for your application when run. The easiest way to ensure this is to copy `libsealnetnative.so` to the same directory where your application's executable is located. - -In Linux or macOS, if you have root access to the system, you have the option to install the -native shared library globally. Then your application will always be able to find and load it. - -Assuming Microsoft SEAL is build and installed globally, you can install the shared native -library globally as follows: -```` -cd dotnet/native -cmake . -make -sudo make install -cd ../.. -```` - -# Getting Started -Using Microsoft SEAL will require the user to invest some time in learning fundamental -concepts in homomorphic encryption. The code comes with heavily commented examples that -are designed to gradually teach such concepts as well as to demonstrate much of the API. -The code examples are available (and identical) in C++ and C#, and are divided into -several source files in `native/examples/` (C++) and `dotnet/examples/` (C#), as follows: - -|C++ |C# |Description | -|-------------------|------------------|----------------------------------------------------------------------------| -|`examples.cpp` |`Examples.cs` |The example runner application | -|`1_bfv_basics.cpp` |`1_BFV_Basics.cs` |Encrypted modular arithmetic using the BFV scheme | -|`2_encoders.cpp` |`2_Encoders.cs` |Encoding more complex data into Microsoft SEAL plaintext objects | -|`3_levels.cpp` |`3_Levels.cs` |Introduces the concept of levels; prerequisite for using the CKKS scheme | -|`4_ckks_basics.cpp`|`4_CKKS_Basics.cs`|Encrypted real number arithmetic using the CKKS scheme | -|`5_rotation.cpp` |`5_Rotation.cs` |Performing cyclic rotations on encrypted vectors in the BFV and CKKS schemes| -|`6_performance.cpp`|`6_Performance.cs`|Performance tests for Microsoft SEAL | - -It is recommeded to read the comments and the code snippets along with command line printout -from running an example. For easier navigation, command line printout provides the line number -in the associated source file where the associated code snippets start. - -**WARNING: It is impossible to use Microsoft SEAL correctly without reading all examples -or by simply re-using the code from examples. Any developer attempting to do so -will inevitably produce code that is *vulnerable*, *malfunctioning*, or *extremely slow*.** - -# Contributing - -This project welcomes contributions and suggestions. Most contributions require you -to agree to a Contributor License Agreement (CLA) declaring that you have the right to, -and actually do, grant us the rights to use your contribution. For details, visit -https://cla.microsoft.com. - -When you submit a pull request, a CLA-bot will automatically determine whether you need -to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow -the instructions provided by the bot. You will only need to do this once across all -repos using our CLA. - -Pull requests must be submitted to the branch called `contrib`. - -This project has adopted the -[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). -For more information see the -[Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) -or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional -questions or comments. - -# Citing Microsoft SEAL - -To cite Microsoft SEAL in academic papers, please use the following BibTeX entries. - -### Version 3.3 - - @misc{sealcrypto, - title = {{M}icrosoft {SEAL} (release 3.3)}, - howpublished = {\url{https://github.com/Microsoft/SEAL}}, - month = june, - year = 2019, - note = {Microsoft Research, Redmond, WA.}, - key = {SEAL} - } - -### Version 3.2 - - @misc{sealcrypto, - title = {{M}icrosoft {SEAL} (release 3.2)}, - howpublished = {\url{https://github.com/Microsoft/SEAL}}, - month = feb, - year = 2019, - note = {Microsoft Research, Redmond, WA.}, - key = {SEAL} - } - -### Version 3.1 - - @misc{sealcrypto, - title = {{M}icrosoft {SEAL} (release 3.1)}, - howpublished = {\url{https://github.com/Microsoft/SEAL}}, - month = dec, - year = 2018, - note = {Microsoft Research, Redmond, WA.}, - key = {SEAL} - } - -### Version 3.0 - - @misc{sealcrypto, - title = {{M}icrosoft {SEAL} (release 3.0)}, - howpublished = {\url{http://sealcrypto.org}}, - month = oct, - year = 2018, - note = {Microsoft Research, Redmond, WA.}, - key = {SEAL} - } diff --git a/SEAL/SEAL.sln b/SEAL/SEAL.sln deleted file mode 100644 index 1de9a01..0000000 --- a/SEAL/SEAL.sln +++ /dev/null @@ -1,45 +0,0 @@ -Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio 15 -VisualStudioVersion = 15.0.26430.16 -MinimumVisualStudioVersion = 10.0.40219.1 -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "SEAL", "native\src\SEAL.vcxproj", "{7EA96C25-FC0D-485A-BB71-32B6DA55652A}" -EndProject -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "SEALExamples", "native\examples\SEALExamples.vcxproj", "{2B57D847-26DC-45FF-B9AF-EE33910B5093}" - ProjectSection(ProjectDependencies) = postProject - {7EA96C25-FC0D-485A-BB71-32B6DA55652A} = {7EA96C25-FC0D-485A-BB71-32B6DA55652A} - EndProjectSection -EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "native", "native", "{A5BADDF0-1F03-48FE-AAC0-3355614C9A8D}" -EndProject -Global - GlobalSection(SolutionConfigurationPlatforms) = preSolution - Debug|x64 = Debug|x64 - Release|x64 = Release|x64 - EndGlobalSection - GlobalSection(ProjectConfigurationPlatforms) = postSolution - {7EA96C25-FC0D-485A-BB71-32B6DA55652A}.Debug|x64.ActiveCfg = Debug|x64 - {7EA96C25-FC0D-485A-BB71-32B6DA55652A}.Debug|x64.Build.0 = Debug|x64 - {7EA96C25-FC0D-485A-BB71-32B6DA55652A}.Release|x64.ActiveCfg = Release|x64 - {7EA96C25-FC0D-485A-BB71-32B6DA55652A}.Release|x64.Build.0 = Release|x64 - {2B57D847-26DC-45FF-B9AF-EE33910B5093}.Debug|x64.ActiveCfg = Debug|x64 - {2B57D847-26DC-45FF-B9AF-EE33910B5093}.Debug|x64.Build.0 = Debug|x64 - {2B57D847-26DC-45FF-B9AF-EE33910B5093}.Release|x64.ActiveCfg = Release|x64 - {2B57D847-26DC-45FF-B9AF-EE33910B5093}.Release|x64.Build.0 = Release|x64 - EndGlobalSection - GlobalSection(SolutionProperties) = preSolution - HideSolutionNode = FALSE - EndGlobalSection - GlobalSection(NestedProjects) = preSolution - {7EA96C25-FC0D-485A-BB71-32B6DA55652A} = {A5BADDF0-1F03-48FE-AAC0-3355614C9A8D} - {2B57D847-26DC-45FF-B9AF-EE33910B5093} = {A5BADDF0-1F03-48FE-AAC0-3355614C9A8D} - EndGlobalSection - GlobalSection(ExtensibilityGlobals) = postSolution - SolutionGuid = {15A17F22-F747-4B82-BF5F-E0224AF4B3ED} - EndGlobalSection - GlobalSection(Performance) = preSolution - HasPerformanceSessions = true - EndGlobalSection - GlobalSection(Performance) = preSolution - HasPerformanceSessions = true - EndGlobalSection -EndGlobal diff --git a/SEAL/native/examples/1_bfv_basics.cpp b/SEAL/native/examples/1_bfv_basics.cpp deleted file mode 100644 index 8545c4d..0000000 --- a/SEAL/native/examples/1_bfv_basics.cpp +++ /dev/null @@ -1,409 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "examples.h" - -using namespace std; -using namespace seal; - -void example_bfv_basics() -{ - print_example_banner("Example: BFV Basics"); - - /* - In this example, we demonstrate performing simple computations (a polynomial - evaluation) on encrypted integers using the BFV encryption scheme. - - The first task is to set up an instance of the EncryptionParameters class. - It is critical to understand how the different parameters behave, how they - affect the encryption scheme, performance, and the security level. There are - three encryption parameters that are necessary to set: - - - poly_modulus_degree (degree of polynomial modulus); - - coeff_modulus ([ciphertext] coefficient modulus); - - plain_modulus (plaintext modulus; only for the BFV scheme). - - The BFV scheme cannot perform arbitrary computations on encrypted data. - Instead, each ciphertext has a specific quantity called the `invariant noise - budget' -- or `noise budget' for short -- measured in bits. The noise budget - in a freshly encrypted ciphertext (initial noise budget) is determined by - the encryption parameters. Homomorphic operations consume the noise budget - at a rate also determined by the encryption parameters. In BFV the two basic - operations allowed on encrypted data are additions and multiplications, of - which additions can generally be thought of as being nearly free in terms of - noise budget consumption compared to multiplications. Since noise budget - consumption compounds in sequential multiplications, the most significant - factor in choosing appropriate encryption parameters is the multiplicative - depth of the arithmetic circuit that the user wants to evaluate on encrypted - data. Once the noise budget of a ciphertext reaches zero it becomes too - corrupted to be decrypted. Thus, it is essential to choose the parameters to - be large enough to support the desired computation; otherwise the result is - impossible to make sense of even with the secret key. - */ - EncryptionParameters parms(scheme_type::BFV); - - /* - The first parameter we set is the degree of the `polynomial modulus'. This - must be a positive power of 2, representing the degree of a power-of-two - cyclotomic polynomial; it is not necessary to understand what this means. - - Larger poly_modulus_degree makes ciphertext sizes larger and all operations - slower, but enables more complicated encrypted computations. Recommended - values are 1024, 2048, 4096, 8192, 16384, 32768, but it is also possible - to go beyond this range. - - In this example we use a relatively small polynomial modulus. Anything - smaller than this will enable only very restricted encrypted computations. - */ - size_t poly_modulus_degree = 4096; - parms.set_poly_modulus_degree(poly_modulus_degree); - - /* - Next we set the [ciphertext] `coefficient modulus' (coeff_modulus). This - parameter is a large integer, which is a product of distinct prime numbers, - each up to 60 bits in size. It is represented as a vector of these prime - numbers, each represented by an instance of the SmallModulus class. The - bit-length of coeff_modulus means the sum of the bit-lengths of its prime - factors. - - A larger coeff_modulus implies a larger noise budget, hence more encrypted - computation capabilities. However, an upper bound for the total bit-length - of the coeff_modulus is determined by the poly_modulus_degree, as follows: - - +----------------------------------------------------+ - | poly_modulus_degree | max coeff_modulus bit-length | - +---------------------+------------------------------+ - | 1024 | 27 | - | 2048 | 54 | - | 4096 | 109 | - | 8192 | 218 | - | 16384 | 438 | - | 32768 | 881 | - +---------------------+------------------------------+ - - These numbers can also be found in native/src/seal/util/hestdparms.h encoded - in the function SEAL_HE_STD_PARMS_128_TC, and can also be obtained from the - function - - CoeffModulus::MaxBitCount(poly_modulus_degree). - - For example, if poly_modulus_degree is 4096, the coeff_modulus could consist - of three 36-bit primes (108 bits). - - Microsoft SEAL comes with helper functions for selecting the coeff_modulus. - For new users the easiest way is to simply use - - CoeffModulus::BFVDefault(poly_modulus_degree), - - which returns std::vector consisting of a generally good choice - for the given poly_modulus_degree. - */ - parms.set_coeff_modulus(CoeffModulus::BFVDefault(poly_modulus_degree)); - - /* - The plaintext modulus can be any positive integer, even though here we take - it to be a power of two. In fact, in many cases one might instead want it - to be a prime number; we will see this in later examples. The plaintext - modulus determines the size of the plaintext data type and the consumption - of noise budget in multiplications. Thus, it is essential to try to keep the - plaintext data type as small as possible for best performance. The noise - budget in a freshly encrypted ciphertext is - - ~ log2(coeff_modulus/plain_modulus) (bits) - - and the noise budget consumption in a homomorphic multiplication is of the - form log2(plain_modulus) + (other terms). - - The plaintext modulus is specific to the BFV scheme, and cannot be set when - using the CKKS scheme. - */ - parms.set_plain_modulus(256); - - /* - Now that all parameters are set, we are ready to construct a SEALContext - object. This is a heavy class that checks the validity and properties of the - parameters we just set. - */ - auto context = SEALContext::Create(parms); - - /* - Print the parameters that we have chosen. - */ - print_line(__LINE__); - cout << "Set encryption parameters and print" << endl; - print_parameters(context); - - cout << endl; - cout << "~~~~~~ A naive way to calculate 2(x^2+1)(x+1)^2. ~~~~~~" << endl; - - /* - The encryption schemes in Microsoft SEAL are public key encryption schemes. - For users unfamiliar with this terminology, a public key encryption scheme - has a separate public key for encrypting data, and a separate secret key for - decrypting data. This way multiple parties can encrypt data using the same - shared public key, but only the proper recipient of the data can decrypt it - with the secret key. - - We are now ready to generate the secret and public keys. For this purpose - we need an instance of the KeyGenerator class. Constructing a KeyGenerator - automatically generates the public and secret key, which can immediately be - read to local variables. - */ - KeyGenerator keygen(context); - PublicKey public_key = keygen.public_key(); - SecretKey secret_key = keygen.secret_key(); - - /* - To be able to encrypt we need to construct an instance of Encryptor. Note - that the Encryptor only requires the public key, as expected. - */ - Encryptor encryptor(context, public_key); - - /* - Computations on the ciphertexts are performed with the Evaluator class. In - a real use-case the Evaluator would not be constructed by the same party - that holds the secret key. - */ - Evaluator evaluator(context); - - /* - We will of course want to decrypt our results to verify that everything worked, - so we need to also construct an instance of Decryptor. Note that the Decryptor - requires the secret key. - */ - Decryptor decryptor(context, secret_key); - - /* - As an example, we evaluate the degree 4 polynomial - - 2x^4 + 4x^3 + 4x^2 + 4x + 2 - - over an encrypted x = 6. The coefficients of the polynomial can be considered - as plaintext inputs, as we will see below. The computation is done modulo the - plain_modulus 256. - - While this examples is simple and easy to understand, it does not have much - practical value. In later examples we will demonstrate how to compute more - efficiently on encrypted integers and real or complex numbers. - - Plaintexts in the BFV scheme are polynomials of degree less than the degree - of the polynomial modulus, and coefficients integers modulo the plaintext - modulus. For readers with background in ring theory, the plaintext space is - the polynomial quotient ring Z_T[X]/(X^N + 1), where N is poly_modulus_degree - and T is plain_modulus. - - To get started, we create a plaintext containing the constant 6. For the - plaintext element we use a constructor that takes the desired polynomial as - a string with coefficients represented as hexadecimal numbers. - */ - print_line(__LINE__); - int x = 6; - Plaintext x_plain(to_string(x)); - cout << "Express x = " + to_string(x) + - " as a plaintext polynomial 0x" + x_plain.to_string() + "." << endl; - - /* - We then encrypt the plaintext, producing a ciphertext. - */ - print_line(__LINE__); - Ciphertext x_encrypted; - cout << "Encrypt x_plain to x_encrypted." << endl; - encryptor.encrypt(x_plain, x_encrypted); - - /* - In Microsoft SEAL, a valid ciphertext consists of two or more polynomials - whose coefficients are integers modulo the product of the primes in the - coeff_modulus. The number of polynomials in a ciphertext is called its `size' - and is given by Ciphertext::size(). A freshly encrypted ciphertext always - has size 2. - */ - cout << " + size of freshly encrypted x: " << x_encrypted.size() << endl; - - /* - There is plenty of noise budget left in this freshly encrypted ciphertext. - */ - cout << " + noise budget in freshly encrypted x: " - << decryptor.invariant_noise_budget(x_encrypted) << " bits" << endl; - - /* - We decrypt the ciphertext and print the resulting plaintext in order to - demonstrate correctness of the encryption. - */ - Plaintext x_decrypted; - cout << " + decryption of x_encrypted: "; - decryptor.decrypt(x_encrypted, x_decrypted); - cout << "0x" << x_decrypted.to_string() << " ...... Correct." << endl; - - /* - When using Microsoft SEAL, it is typically advantageous to compute in a way - that minimizes the longest chain of sequential multiplications. In other - words, encrypted computations are best evaluated in a way that minimizes - the multiplicative depth of the computation, because the total noise budget - consumption is proportional to the multiplicative depth. For example, for - our example computation it is advantageous to factorize the polynomial as - - 2x^4 + 4x^3 + 4x^2 + 4x + 2 = 2(x + 1)^2 * (x^2 + 1) - - to obtain a simple depth 2 representation. Thus, we compute (x + 1)^2 and - (x^2 + 1) separately, before multiplying them, and multiplying by 2. - - First, we compute x^2 and add a plaintext "1". We can clearly see from the - print-out that multiplication has consumed a lot of noise budget. The user - can vary the plain_modulus parameter to see its effect on the rate of noise - budget consumption. - */ - print_line(__LINE__); - cout << "Compute x_sq_plus_one (x^2+1)." << endl; - Ciphertext x_sq_plus_one; - evaluator.square(x_encrypted, x_sq_plus_one); - Plaintext plain_one("1"); - evaluator.add_plain_inplace(x_sq_plus_one, plain_one); - - /* - Encrypted multiplication results in the output ciphertext growing in size. - More precisely, if the input ciphertexts have size M and N, then the output - ciphertext after homomorphic multiplication will have size M+N-1. In this - case we perform a squaring, and observe both size growth and noise budget - consumption. - */ - cout << " + size of x_sq_plus_one: " << x_sq_plus_one.size() << endl; - cout << " + noise budget in x_sq_plus_one: " - << decryptor.invariant_noise_budget(x_sq_plus_one) << " bits" << endl; - - /* - Even though the size has grown, decryption works as usual as long as noise - budget has not reached 0. - */ - Plaintext decrypted_result; - cout << " + decryption of x_sq_plus_one: "; - decryptor.decrypt(x_sq_plus_one, decrypted_result); - cout << "0x" << decrypted_result.to_string() << " ...... Correct." << endl; - - /* - Next, we compute (x + 1)^2. - */ - print_line(__LINE__); - cout << "Compute x_plus_one_sq ((x+1)^2)." << endl; - Ciphertext x_plus_one_sq; - evaluator.add_plain(x_encrypted, plain_one, x_plus_one_sq); - evaluator.square_inplace(x_plus_one_sq); - cout << " + size of x_plus_one_sq: " << x_plus_one_sq.size() << endl; - cout << " + noise budget in x_plus_one_sq: " - << decryptor.invariant_noise_budget(x_plus_one_sq) - << " bits" << endl; - cout << " + decryption of x_plus_one_sq: "; - decryptor.decrypt(x_plus_one_sq, decrypted_result); - cout << "0x" << decrypted_result.to_string() << " ...... Correct." << endl; - - /* - Finally, we multiply (x^2 + 1) * (x + 1)^2 * 2. - */ - print_line(__LINE__); - cout << "Compute encrypted_result (2(x^2+1)(x+1)^2)." << endl; - Ciphertext encrypted_result; - Plaintext plain_two("2"); - evaluator.multiply_plain_inplace(x_sq_plus_one, plain_two); - evaluator.multiply(x_sq_plus_one, x_plus_one_sq, encrypted_result); - cout << " + size of encrypted_result: " << encrypted_result.size() << endl; - cout << " + noise budget in encrypted_result: " - << decryptor.invariant_noise_budget(encrypted_result) << " bits" << endl; - cout << "NOTE: Decryption can be incorrect if noise budget is zero." << endl; - - cout << endl; - cout << "~~~~~~ A better way to calculate 2(x^2+1)(x+1)^2. ~~~~~~" << endl; - - /* - Noise budget has reached 0, which means that decryption cannot be expected - to give the correct result. This is because both ciphertexts x_sq_plus_one - and x_plus_one_sq consist of 3 polynomials due to the previous squaring - operations, and homomorphic operations on large ciphertexts consume much more - noise budget than computations on small ciphertexts. Computing on smaller - ciphertexts is also computationally significantly cheaper. - - `Relinearization' is an operation that reduces the size of a ciphertext after - multiplication back to the initial size, 2. Thus, relinearizing one or both - input ciphertexts before the next multiplication can have a huge positive - impact on both noise growth and performance, even though relinearization has - a significant computational cost itself. It is only possible to relinearize - size 3 ciphertexts down to size 2, so often the user would want to relinearize - after each multiplication to keep the ciphertext sizes at 2. - - Relinearization requires special `relinearization keys', which can be thought - of as a kind of public key. Relinearization keys can easily be created with - the KeyGenerator. - - Relinearization is used similarly in both the BFV and the CKKS schemes, but - in this example we continue using BFV. We repeat our computation from before, - but this time relinearize after every multiplication. - - We use KeyGenerator::relin_keys() to create relinearization keys. - */ - print_line(__LINE__); - cout << "Generate relinearization keys." << endl; - auto relin_keys = keygen.relin_keys(); - - /* - We now repeat the computation relinearizing after each multiplication. - */ - print_line(__LINE__); - cout << "Compute and relinearize x_squared (x^2)," << endl; - cout << string(13, ' ') << "then compute x_sq_plus_one (x^2+1)" << endl; - Ciphertext x_squared; - evaluator.square(x_encrypted, x_squared); - cout << " + size of x_squared: " << x_squared.size() << endl; - evaluator.relinearize_inplace(x_squared, relin_keys); - cout << " + size of x_squared (after relinearization): " - << x_squared.size() << endl; - evaluator.add_plain(x_squared, plain_one, x_sq_plus_one); - cout << " + noise budget in x_sq_plus_one: " - << decryptor.invariant_noise_budget(x_sq_plus_one) << " bits" << endl; - cout << " + decryption of x_sq_plus_one: "; - decryptor.decrypt(x_sq_plus_one, decrypted_result); - cout << "0x" << decrypted_result.to_string() << " ...... Correct." << endl; - - print_line(__LINE__); - Ciphertext x_plus_one; - cout << "Compute x_plus_one (x+1)," << endl; - cout << string(13, ' ') - << "then compute and relinearize x_plus_one_sq ((x+1)^2)." << endl; - evaluator.add_plain(x_encrypted, plain_one, x_plus_one); - evaluator.square(x_plus_one, x_plus_one_sq); - cout << " + size of x_plus_one_sq: " << x_plus_one_sq.size() << endl; - evaluator.relinearize_inplace(x_plus_one_sq, relin_keys); - cout << " + noise budget in x_plus_one_sq: " - << decryptor.invariant_noise_budget(x_plus_one_sq) << " bits" << endl; - cout << " + decryption of x_plus_one_sq: "; - decryptor.decrypt(x_plus_one_sq, decrypted_result); - cout << "0x" << decrypted_result.to_string() << " ...... Correct." << endl; - - print_line(__LINE__); - cout << "Compute and relinearize encrypted_result (2(x^2+1)(x+1)^2)." << endl; - evaluator.multiply_plain_inplace(x_sq_plus_one, plain_two); - evaluator.multiply(x_sq_plus_one, x_plus_one_sq, encrypted_result); - cout << " + size of encrypted_result: " << encrypted_result.size() << endl; - evaluator.relinearize_inplace(encrypted_result, relin_keys); - cout << " + size of encrypted_result (after relinearization): " - << encrypted_result.size() << endl; - cout << " + noise budget in encrypted_result: " - << decryptor.invariant_noise_budget(encrypted_result) << " bits" << endl; - - cout << endl; - cout << "NOTE: Notice the increase in remaining noise budget." << endl; - - /* - Relinearization clearly improved our noise consumption. We have still plenty - of noise budget left, so we can expect the correct answer when decrypting. - */ - print_line(__LINE__); - cout << "Decrypt encrypted_result (2(x^2+1)(x+1)^2)." << endl; - decryptor.decrypt(encrypted_result, decrypted_result); - cout << " + decryption of 2(x^2+1)(x+1)^2 = 0x" - << decrypted_result.to_string() << " ...... Correct." << endl; - cout << endl; - - /* - For x=6, 2(x^2+1)(x+1)^2 = 3626. Since the plaintext modulus is set to 256, - this result is computed in integers modulo 256. Therefore the expected output - should be 3626 % 256 == 42, or 0x2A in hexadecimal. - */ -} \ No newline at end of file diff --git a/SEAL/native/examples/2_encoders.cpp b/SEAL/native/examples/2_encoders.cpp deleted file mode 100644 index 9a5cb16..0000000 --- a/SEAL/native/examples/2_encoders.cpp +++ /dev/null @@ -1,481 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "examples.h" - -using namespace std; -using namespace seal; - -/* -In `1_bfv_basics.cpp' we showed how to perform a very simple computation using the -BFV scheme. The computation was performed modulo the plain_modulus parameter, and -utilized only one coefficient from a BFV plaintext polynomial. This approach has -two notable problems: - - (1) Practical applications typically use integer or real number arithmetic, - not modular arithmetic; - (2) We used only one coefficient of the plaintext polynomial. This is really - wasteful, as the plaintext polynomial is large and will in any case be - encrypted in its entirety. - -For (1), one may ask why not just increase the plain_modulus parameter until no -overflow occurs, and the computations behave as in integer arithmetic. The problem -is that increasing plain_modulus increases noise budget consumption, and decreases -the initial noise budget too. - -In these examples we will discuss other ways of laying out data into plaintext -elements (encoding) that allow more computations without data type overflow, and -can allow the full plaintext polynomial to be utilized. -*/ -void example_integer_encoder() -{ - print_example_banner("Example: Encoders / Integer Encoder"); - - /* - [IntegerEncoder] (For BFV scheme only) - - The IntegerEncoder encodes integers to BFV plaintext polynomials as follows. - First, a binary expansion of the integer is computed. Next, a polynomial is - created with the bits as coefficients. For example, the integer - - 26 = 2^4 + 2^3 + 2^1 - - is encoded as the polynomial 1x^4 + 1x^3 + 1x^1. Conversely, plaintext - polynomials are decoded by evaluating them at x=2. For negative numbers the - IntegerEncoder simply stores all coefficients as either 0 or -1, where -1 is - represented by the unsigned integer plain_modulus - 1 in memory. - - Since encrypted computations operate on the polynomials rather than on the - encoded integers themselves, the polynomial coefficients will grow in the - course of such computations. For example, computing the sum of the encrypted - encoded integer 26 with itself will result in an encrypted polynomial with - larger coefficients: 2x^4 + 2x^3 + 2x^1. Squaring the encrypted encoded - integer 26 results also in increased coefficients due to cross-terms, namely, - - (1x^4 + 1x^3 + 1x^1)^2 = 1x^8 + 2x^7 + 1x^6 + 2x^5 + 2x^4 + 1x^2; - - further computations will quickly increase the coefficients much more. - Decoding will still work correctly in this case (evaluating the polynomial - at x=2), but since the coefficients of plaintext polynomials are really - integers modulo plain_modulus, implicit reduction modulo plain_modulus may - yield unexpected results. For example, adding 1x^4 + 1x^3 + 1x^1 to itself - plain_modulus many times will result in the constant polynomial 0, which is - clearly not equal to 26 * plain_modulus. It can be difficult to predict when - such overflow will take place especially when computing several sequential - multiplications. - - The IntegerEncoder is easy to understand and use for simple computations, - and can be a good tool to experiment with for users new to Microsoft SEAL. - However, advanced users will probably prefer more efficient approaches, - such as the BatchEncoder or the CKKSEncoder. - */ - EncryptionParameters parms(scheme_type::BFV); - size_t poly_modulus_degree = 4096; - parms.set_poly_modulus_degree(poly_modulus_degree); - parms.set_coeff_modulus(CoeffModulus::BFVDefault(poly_modulus_degree)); - - /* - There is no hidden logic behind our choice of the plain_modulus. The only - thing that matters is that the plaintext polynomial coefficients will not - exceed this value at any point during our computation; otherwise the result - will be incorrect. - */ - parms.set_plain_modulus(512); - auto context = SEALContext::Create(parms); - print_parameters(context); - cout << endl; - - KeyGenerator keygen(context); - PublicKey public_key = keygen.public_key(); - SecretKey secret_key = keygen.secret_key(); - Encryptor encryptor(context, public_key); - Evaluator evaluator(context); - Decryptor decryptor(context, secret_key); - - /* - We create an IntegerEncoder. - */ - IntegerEncoder encoder(context); - - /* - First, we encode two integers as plaintext polynomials. Note that encoding - is not encryption: at this point nothing is encrypted. - */ - int value1 = 5; - Plaintext plain1 = encoder.encode(value1); - print_line(__LINE__); - cout << "Encode " << value1 << " as polynomial " << plain1.to_string() - << " (plain1)," << endl; - - int value2 = -7; - Plaintext plain2 = encoder.encode(value2); - cout << string(13, ' ') << "encode " << value2 << " as polynomial " << plain2.to_string() - << " (plain2)." << endl; - - /* - Now we can encrypt the plaintext polynomials. - */ - Ciphertext encrypted1, encrypted2; - print_line(__LINE__); - cout << "Encrypt plain1 to encrypted1 and plain2 to encrypted2." << endl; - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - cout << " + Noise budget in encrypted1: " - << decryptor.invariant_noise_budget(encrypted1) << " bits" << endl; - cout << " + Noise budget in encrypted2: " - << decryptor.invariant_noise_budget(encrypted2) << " bits" << endl; - - /* - As a simple example, we compute (-encrypted1 + encrypted2) * encrypted2. - */ - encryptor.encrypt(plain2, encrypted2); - Ciphertext encrypted_result; - print_line(__LINE__); - cout << "Compute encrypted_result = (-encrypted1 + encrypted2) * encrypted2." << endl; - evaluator.negate(encrypted1, encrypted_result); - evaluator.add_inplace(encrypted_result, encrypted2); - evaluator.multiply_inplace(encrypted_result, encrypted2); - cout << " + Noise budget in encrypted_result: " - << decryptor.invariant_noise_budget(encrypted_result) << " bits" << endl; - Plaintext plain_result; - print_line(__LINE__); - cout << "Decrypt encrypted_result to plain_result." << endl; - decryptor.decrypt(encrypted_result, plain_result); - - /* - Print the result plaintext polynomial. The coefficients are not even close - to exceeding our plain_modulus, 512. - */ - cout << " + Plaintext polynomial: " << plain_result.to_string() << endl; - - /* - Decode to obtain an integer result. - */ - print_line(__LINE__); - cout << "Decode plain_result." << endl; - cout << " + Decoded integer: " << encoder.decode_int32(plain_result); - cout << "...... Correct." << endl; -} - -void example_batch_encoder() -{ - print_example_banner("Example: Encoders / Batch Encoder"); - - /* - [BatchEncoder] (For BFV scheme only) - - Let N denote the poly_modulus_degree and T denote the plain_modulus. Batching - allows the BFV plaintext polynomials to be viewed as 2-by-(N/2) matrices, with - each element an integer modulo T. In the matrix view, encrypted operations act - element-wise on encrypted matrices, allowing the user to obtain speeds-ups of - several orders of magnitude in fully vectorizable computations. Thus, in all - but the simplest computations, batching should be the preferred method to use - with BFV, and when used properly will result in implementations outperforming - anything done with the IntegerEncoder. - */ - EncryptionParameters parms(scheme_type::BFV); - size_t poly_modulus_degree = 8192; - parms.set_poly_modulus_degree(poly_modulus_degree); - parms.set_coeff_modulus(CoeffModulus::BFVDefault(poly_modulus_degree)); - - /* - To enable batching, we need to set the plain_modulus to be a prime number - congruent to 1 modulo 2*poly_modulus_degree. Microsoft SEAL provides a helper - method for finding such a prime. In this example we create a 20-bit prime - that supports batching. - */ - parms.set_plain_modulus(PlainModulus::Batching(poly_modulus_degree, 20)); - - auto context = SEALContext::Create(parms); - print_parameters(context); - cout << endl; - - /* - We can verify that batching is indeed enabled by looking at the encryption - parameter qualifiers created by SEALContext. - */ - auto qualifiers = context->first_context_data()->qualifiers(); - cout << "Batching enabled: " << boolalpha << qualifiers.using_batching << endl; - - KeyGenerator keygen(context); - PublicKey public_key = keygen.public_key(); - SecretKey secret_key = keygen.secret_key(); - RelinKeys relin_keys = keygen.relin_keys(); - Encryptor encryptor(context, public_key); - Evaluator evaluator(context); - Decryptor decryptor(context, secret_key); - - /* - Batching is done through an instance of the BatchEncoder class. - */ - BatchEncoder batch_encoder(context); - - /* - The total number of batching `slots' equals the poly_modulus_degree, N, and - these slots are organized into 2-by-(N/2) matrices that can be encrypted and - computed on. Each slot contains an integer modulo plain_modulus. - */ - size_t slot_count = batch_encoder.slot_count(); - size_t row_size = slot_count / 2; - cout << "Plaintext matrix row size: " << row_size << endl; - - /* - The matrix plaintext is simply given to BatchEncoder as a flattened vector - of numbers. The first `row_size' many numbers form the first row, and the - rest form the second row. Here we create the following matrix: - - [ 0, 1, 2, 3, 0, 0, ..., 0 ] - [ 4, 5, 6, 7, 0, 0, ..., 0 ] - */ - vector pod_matrix(slot_count, 0ULL); - pod_matrix[0] = 0ULL; - pod_matrix[1] = 1ULL; - pod_matrix[2] = 2ULL; - pod_matrix[3] = 3ULL; - pod_matrix[row_size] = 4ULL; - pod_matrix[row_size + 1] = 5ULL; - pod_matrix[row_size + 2] = 6ULL; - pod_matrix[row_size + 3] = 7ULL; - - cout << "Input plaintext matrix:" << endl; - print_matrix(pod_matrix, row_size); - - /* - First we use BatchEncoder to encode the matrix into a plaintext polynomial. - */ - Plaintext plain_matrix; - print_line(__LINE__); - cout << "Encode plaintext matrix:" << endl; - batch_encoder.encode(pod_matrix, plain_matrix); - - /* - We can instantly decode to verify correctness of the encoding. Note that no - encryption or decryption has yet taken place. - */ - vector pod_result; - cout << " + Decode plaintext matrix ...... Correct." << endl; - batch_encoder.decode(plain_matrix, pod_result); - print_matrix(pod_result, row_size); - - /* - Next we encrypt the encoded plaintext. - */ - Ciphertext encrypted_matrix; - print_line(__LINE__); - cout << "Encrypt plain_matrix to encrypted_matrix." << endl; - encryptor.encrypt(plain_matrix, encrypted_matrix); - cout << " + Noise budget in encrypted_matrix: " - << decryptor.invariant_noise_budget(encrypted_matrix) << " bits" << endl; - - /* - Operating on the ciphertext results in homomorphic operations being performed - simultaneously in all 8192 slots (matrix elements). To illustrate this, we - form another plaintext matrix - - [ 1, 2, 1, 2, 1, 2, ..., 2 ] - [ 1, 2, 1, 2, 1, 2, ..., 2 ] - - and encode it into a plaintext. - */ - vector pod_matrix2; - for (size_t i = 0; i < slot_count; i++) - { - pod_matrix2.push_back((i % 2) + 1); - } - Plaintext plain_matrix2; - batch_encoder.encode(pod_matrix2, plain_matrix2); - cout << endl; - cout << "Second input plaintext matrix:" << endl; - print_matrix(pod_matrix2, row_size); - - /* - We now add the second (plaintext) matrix to the encrypted matrix, and square - the sum. - */ - print_line(__LINE__); - cout << "Sum, square, and relinearize." << endl; - evaluator.add_plain_inplace(encrypted_matrix, plain_matrix2); - evaluator.square_inplace(encrypted_matrix); - evaluator.relinearize_inplace(encrypted_matrix, relin_keys); - - /* - How much noise budget do we have left? - */ - cout << " + Noise budget in result: " - << decryptor.invariant_noise_budget(encrypted_matrix) << " bits" << endl; - - /* - We decrypt and decompose the plaintext to recover the result as a matrix. - */ - Plaintext plain_result; - print_line(__LINE__); - cout << "Decrypt and decode result." << endl; - decryptor.decrypt(encrypted_matrix, plain_result); - batch_encoder.decode(plain_result, pod_result); - cout << " + Result plaintext matrix ...... Correct." << endl; - print_matrix(pod_result, row_size); - - /* - Batching allows us to efficiently use the full plaintext polynomial when the - desired encrypted computation is highly parallelizable. However, it has not - solved the other problem mentioned in the beginning of this file: each slot - holds only an integer modulo plain_modulus, and unless plain_modulus is very - large, we can quickly encounter data type overflow and get unexpected results - when integer computations are desired. Note that overflow cannot be detected - in encrypted form. The CKKS scheme (and the CKKSEncoder) addresses the data - type overflow issue, but at the cost of yielding only approximate results. - */ -} - -void example_ckks_encoder() -{ - print_example_banner("Example: Encoders / CKKS Encoder"); - - /* - [CKKSEncoder] (For CKKS scheme only) - - In this example we demonstrate the Cheon-Kim-Kim-Song (CKKS) scheme for - computing on encrypted real or complex numbers. We start by creating - encryption parameters for the CKKS scheme. There are two important - differences compared to the BFV scheme: - - (1) CKKS does not use the plain_modulus encryption parameter; - (2) Selecting the coeff_modulus in a specific way can be very important - when using the CKKS scheme. We will explain this further in the file - `ckks_basics.cpp'. In this example we use CoeffModulus::Create to - generate 5 40-bit prime numbers. - */ - EncryptionParameters parms(scheme_type::CKKS); - - size_t poly_modulus_degree = 8192; - parms.set_poly_modulus_degree(poly_modulus_degree); - parms.set_coeff_modulus(CoeffModulus::Create( - poly_modulus_degree, { 40, 40, 40, 40, 40 })); - - /* - We create the SEALContext as usual and print the parameters. - */ - auto context = SEALContext::Create(parms); - print_parameters(context); - cout << endl; - - /* - Keys are created the same way as for the BFV scheme. - */ - KeyGenerator keygen(context); - auto public_key = keygen.public_key(); - auto secret_key = keygen.secret_key(); - auto relin_keys = keygen.relin_keys(); - - /* - We also set up an Encryptor, Evaluator, and Decryptor as usual. - */ - Encryptor encryptor(context, public_key); - Evaluator evaluator(context); - Decryptor decryptor(context, secret_key); - - /* - To create CKKS plaintexts we need a special encoder: there is no other way - to create them. The IntegerEncoder and BatchEncoder cannot be used with the - CKKS scheme. The CKKSEncoder encodes vectors of real or complex numbers into - Plaintext objects, which can subsequently be encrypted. At a high level this - looks a lot like what BatchEncoder does for the BFV scheme, but the theory - behind it is completely different. - */ - CKKSEncoder encoder(context); - - /* - In CKKS the number of slots is poly_modulus_degree / 2 and each slot encodes - one real or complex number. This should be contrasted with BatchEncoder in - the BFV scheme, where the number of slots is equal to poly_modulus_degree - and they are arranged into a matrix with two rows. - */ - size_t slot_count = encoder.slot_count(); - cout << "Number of slots: " << slot_count << endl; - - /* - We create a small vector to encode; the CKKSEncoder will implicitly pad it - with zeros to full size (poly_modulus_degree / 2) when encoding. - */ - vector input{ 0.0, 1.1, 2.2, 3.3 }; - cout << "Input vector: " << endl; - print_vector(input); - - /* - Now we encode it with CKKSEncoder. The floating-point coefficients of `input' - will be scaled up by the parameter `scale'. This is necessary since even in - the CKKS scheme the plaintext elements are fundamentally polynomials with - integer coefficients. It is instructive to think of the scale as determining - the bit-precision of the encoding; naturally it will affect the precision of - the result. - - In CKKS the message is stored modulo coeff_modulus (in BFV it is stored modulo - plain_modulus), so the scaled message must not get too close to the total size - of coeff_modulus. In this case our coeff_modulus is quite large (218 bits) so - we have little to worry about in this regard. For this simple example a 30-bit - scale is more than enough. - */ - Plaintext plain; - double scale = pow(2.0, 30); - print_line(__LINE__); - cout << "Encode input vector." << endl; - encoder.encode(input, scale, plain); - - /* - We can instantly decode to check the correctness of encoding. - */ - vector output; - cout << " + Decode input vector ...... Correct." << endl; - encoder.decode(plain, output); - print_vector(output); - - /* - The vector is encrypted the same was as in BFV. - */ - Ciphertext encrypted; - print_line(__LINE__); - cout << "Encrypt input vector, square, and relinearize." << endl; - encryptor.encrypt(plain, encrypted); - - /* - Basic operations on the ciphertexts are still easy to do. Here we square the - ciphertext, decrypt, decode, and print the result. We note also that decoding - returns a vector of full size (poly_modulus_degree / 2); this is because of - the implicit zero-padding mentioned above. - */ - evaluator.square_inplace(encrypted); - evaluator.relinearize_inplace(encrypted, relin_keys); - - /* - We notice that the scale in the result has increased. In fact, it is now the - square of the original scale: 2^60. - */ - cout << " + Scale in squared input: " << encrypted.scale() - << " (" << log2(encrypted.scale()) << " bits)" << endl; - - print_line(__LINE__); - cout << "Decrypt and decode." << endl; - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - cout << " + Result vector ...... Correct." << endl; - print_vector(output); - - /* - The CKKS scheme allows the scale to be reduced between encrypted computations. - This is a fundamental and critical feature that makes CKKS very powerful and - flexible. We will discuss it in great detail in `3_levels.cpp' and later in - `4_ckks_basics.cpp'. - */ -} - -void example_encoders() -{ - print_example_banner("Example: Encoders"); - - /* - Run all encoder examples. - */ - example_integer_encoder(); - example_batch_encoder(); - example_ckks_encoder(); -} diff --git a/SEAL/native/examples/3_levels.cpp b/SEAL/native/examples/3_levels.cpp deleted file mode 100644 index 0da9b21..0000000 --- a/SEAL/native/examples/3_levels.cpp +++ /dev/null @@ -1,331 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "examples.h" - -using namespace std; -using namespace seal; - -void example_levels() -{ - print_example_banner("Example: Levels"); - - /* - In this examples we describe the concept of `levels' in BFV and CKKS and the - related objects that represent them in Microsoft SEAL. - - In Microsoft SEAL a set of encryption parameters (excluding the random number - generator) is identified uniquely by a SHA-3 hash of the parameters. This - hash is called the `parms_id' and can be easily accessed and printed at any - time. The hash will change as soon as any of the parameters is changed. - - When a SEALContext is created from a given EncryptionParameters instance, - Microsoft SEAL automatically creates a so-called `modulus switching chain', - which is a chain of other encryption parameters derived from the original set. - The parameters in the modulus switching chain are the same as the original - parameters with the exception that size of the coefficient modulus is - decreasing going down the chain. More precisely, each parameter set in the - chain attempts to remove the last coefficient modulus prime from the - previous set; this continues until the parameter set is no longer valid - (e.g., plain_modulus is larger than the remaining coeff_modulus). It is easy - to walk through the chain and access all the parameter sets. Additionally, - each parameter set in the chain has a `chain index' that indicates its - position in the chain so that the last set has index 0. We say that a set - of encryption parameters, or an object carrying those encryption parameters, - is at a higher level in the chain than another set of parameters if its the - chain index is bigger, i.e., it is earlier in the chain. - - Each set of parameters in the chain involves unique pre-computations performed - when the SEALContext is created, and stored in a SEALContext::ContextData - object. The chain is basically a linked list of SEALContext::ContextData - objects, and can easily be accessed through the SEALContext at any time. Each - node can be identified by the parms_id of its specific encryption parameters - (poly_modulus_degree remains the same but coeff_modulus varies). - */ - EncryptionParameters parms(scheme_type::BFV); - - size_t poly_modulus_degree = 8192; - parms.set_poly_modulus_degree(poly_modulus_degree); - - /* - In this example we use a custom coeff_modulus, consisting of 5 primes of - sizes 50, 30, 30, 50, and 50 bits. Note that this is still OK according to - the explanation in `1_bfv_basics.cpp'. Indeed, - - CoeffModulus::MaxBitCount(poly_modulus_degree) - - returns 218 (less than 50+30+30+50+50=210). - - Due to the modulus switching chain, the order of the 5 primes is significant. - The last prime has a special meaning and we call it the `special prime'. Thus, - the first parameter set in the modulus switching chain is the only one that - involves the special prime. All key objects, such as SecretKey, are created - at this highest level. All data objects, such as Ciphertext, can be only at - lower levels. The special modulus should be as large as the largest of the - other primes in the coeff_modulus, although this is not a strict requirement. - - special prime +---------+ - | - v - coeff_modulus: { 50, 30, 30, 50, 50 } +---+ Level 4 (all keys; `key level') - | - | - coeff_modulus: { 50, 30, 30, 50 } +---+ Level 3 (highest `data level') - | - | - coeff_modulus: { 50, 30, 30 } +---+ Level 2 - | - | - coeff_modulus: { 50, 30 } +---+ Level 1 - | - | - coeff_modulus: { 50 } +---+ Level 0 (lowest level) - */ - parms.set_coeff_modulus(CoeffModulus::Create( - poly_modulus_degree, { 50, 30, 30, 50, 50 })); - - /* - In this example the plain_modulus does not play much of a role; we choose - some reasonable value. - */ - parms.set_plain_modulus(1 << 20); - - auto context = SEALContext::Create(parms); - print_parameters(context); - cout << endl; - - /* - There are convenience method for accessing the SEALContext::ContextData for - some of the most important levels: - - SEALContext::key_context_data(): access to key level ContextData - SEALContext::first_context_data(): access to highest data level ContextData - SEALContext::last_context_data(): access to lowest level ContextData - - We iterate over the chain and print the parms_id for each set of parameters. - */ - print_line(__LINE__); - cout << "Print the modulus switching chain." << endl; - - /* - First print the key level parameter information. - */ - auto context_data = context->key_context_data(); - cout << "----> Level (chain index): " << context_data->chain_index(); - cout << " ...... key_context_data()" << endl; - cout << " parms_id: " << context_data->parms_id() << endl; - cout << " coeff_modulus primes: "; - cout << hex; - for(const auto &prime : context_data->parms().coeff_modulus()) - { - cout << prime.value() << " "; - } - cout << dec << endl; - cout << "\\" << endl; - cout << " \\-->"; - - /* - Next iterate over the remaining (data) levels. - */ - context_data = context->first_context_data(); - while (context_data) - { - cout << " Level (chain index): " << context_data->chain_index(); - if (context_data->parms_id() == context->first_parms_id()) - { - cout << " ...... first_context_data()" << endl; - } - else if (context_data->parms_id() == context->last_parms_id()) - { - cout << " ...... last_context_data()" << endl; - } - else - { - cout << endl; - } - cout << " parms_id: " << context_data->parms_id() << endl; - cout << " coeff_modulus primes: "; - cout << hex; - for(const auto &prime : context_data->parms().coeff_modulus()) - { - cout << prime.value() << " "; - } - cout << dec << endl; - cout << "\\" << endl; - cout << " \\-->"; - - /* - Step forward in the chain. - */ - context_data = context_data->next_context_data(); - } - cout << " End of chain reached" << endl << endl; - - /* - We create some keys and check that indeed they appear at the highest level. - */ - KeyGenerator keygen(context); - auto public_key = keygen.public_key(); - auto secret_key = keygen.secret_key(); - auto relin_keys = keygen.relin_keys(); - auto galois_keys = keygen.galois_keys(); - print_line(__LINE__); - cout << "Print the parameter IDs of generated elements." << endl; - cout << " + public_key: " << public_key.parms_id() << endl; - cout << " + secret_key: " << secret_key.parms_id() << endl; - cout << " + relin_keys: " << relin_keys.parms_id() << endl; - cout << " + galois_keys: " << galois_keys.parms_id() << endl; - - Encryptor encryptor(context, public_key); - Evaluator evaluator(context); - Decryptor decryptor(context, secret_key); - - /* - In the BFV scheme plaintexts do not carry a parms_id, but ciphertexts do. Note - how the freshly encrypted ciphertext is at the highest data level. - */ - Plaintext plain("1x^3 + 2x^2 + 3x^1 + 4"); - Ciphertext encrypted; - encryptor.encrypt(plain, encrypted); - cout << " + plain: " << plain.parms_id() << " (not set in BFV)" << endl; - cout << " + encrypted: " << encrypted.parms_id() << endl << endl; - - /* - `Modulus switching' is a technique of changing the ciphertext parameters down - in the chain. The function Evaluator::mod_switch_to_next always switches to - the next level down the chain, whereas Evaluator::mod_switch_to switches to - a parameter set down the chain corresponding to a given parms_id. However, it - is impossible to switch up in the chain. - */ - print_line(__LINE__); - cout << "Perform modulus switching on encrypted and print." << endl; - context_data = context->first_context_data(); - cout << "---->"; - while(context_data->next_context_data()) - { - cout << " Level (chain index): " << context_data->chain_index() << endl; - cout << " parms_id of encrypted: " << encrypted.parms_id() << endl; - cout << " Noise budget at this level: " - << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; - cout << "\\" << endl; - cout << " \\-->"; - evaluator.mod_switch_to_next_inplace(encrypted); - context_data = context_data->next_context_data(); - } - cout << " Level (chain index): " << context_data->chain_index() << endl; - cout << " parms_id of encrypted: " << encrypted.parms_id() << endl; - cout << " Noise budget at this level: " - << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; - cout << "\\" << endl; - cout << " \\-->"; - cout << " End of chain reached" << endl << endl; - - /* - At this point it is hard to see any benefit in doing this: we lost a huge - amount of noise budget (i.e., computational power) at each switch and seemed - to get nothing in return. Decryption still works. - */ - print_line(__LINE__); - cout << "Decrypt still works after modulus switching." << endl; - decryptor.decrypt(encrypted, plain); - cout << " + Decryption of encrypted: " << plain.to_string(); - cout << " ...... Correct." << endl << endl; - - /* - However, there is a hidden benefit: the size of the ciphertext depends - linearly on the number of primes in the coefficient modulus. Thus, if there - is no need or intention to perform any further computations on a given - ciphertext, we might as well switch it down to the smallest (last) set of - parameters in the chain before sending it back to the secret key holder for - decryption. - - Also the lost noise budget is actually not as issue at all, if we do things - right, as we will see below. - - First we recreate the original ciphertext and perform some computations. - */ - cout << "Computation is more efficient with modulus switching." << endl; - print_line(__LINE__); - cout << "Compute the fourth power." << endl; - encryptor.encrypt(plain, encrypted); - cout << " + Noise budget before squaring: " - << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; - evaluator.square_inplace(encrypted); - evaluator.relinearize_inplace(encrypted, relin_keys); - cout << " + Noise budget after squaring: " - << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; - - /* - Surprisingly, in this case modulus switching has no effect at all on the - noise budget. - */ - evaluator.mod_switch_to_next_inplace(encrypted); - cout << " + Noise budget after modulus switching: " - << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; - - /* - This means that there is no harm at all in dropping some of the coefficient - modulus after doing enough computations. In some cases one might want to - switch to a lower level slightly earlier, actually sacrificing some of the - noise budget in the process, to gain computational performance from having - smaller parameters. We see from the print-out that the next modulus switch - should be done ideally when the noise budget is down to around 81 bits. - */ - evaluator.square_inplace(encrypted); - evaluator.relinearize_inplace(encrypted, relin_keys); - cout << " + Noise budget after squaring: " - << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; - evaluator.mod_switch_to_next_inplace(encrypted); - cout << " + Noise budget after modulus switching: " - << decryptor.invariant_noise_budget(encrypted) << " bits" << endl; - - /* - At this point the ciphertext still decrypts correctly, has very small size, - and the computation was as efficient as possible. Note that the decryptor - can be used to decrypt a ciphertext at any level in the modulus switching - chain. - */ - decryptor.decrypt(encrypted, plain); - cout << " + Decryption of fourth power (hexadecimal) ...... Correct." << endl; - cout << " " << plain.to_string() << endl << endl; - - /* - In BFV modulus switching is not necessary and in some cases the user might - not want to create the modulus switching chain, except for the highest two - levels. This can be done by passing a bool `false' to SEALContext::Create. - */ - context = SEALContext::Create(parms, false); - - /* - We can check that indeed the modulus switching chain has been created only - for the highest two levels (key level and highest data level). The following - loop should execute only once. - */ - cout << "Optionally disable modulus switching chain expansion." << endl; - print_line(__LINE__); - cout << "Print the modulus switching chain." << endl; - cout << "---->"; - for (context_data = context->key_context_data(); context_data; - context_data = context_data->next_context_data()) - { - cout << " Level (chain index): " << context_data->chain_index() << endl; - cout << " parms_id: " << context_data->parms_id() << endl; - cout << " coeff_modulus primes: "; - cout << hex; - for (const auto &prime : context_data->parms().coeff_modulus()) - { - cout << prime.value() << " "; - } - cout << dec << endl; - cout << "\\" << endl; - cout << " \\-->"; - } - cout << " End of chain reached" << endl << endl; - - /* - It is very important to understand how this example works since in the CKKS - scheme modulus switching has a much more fundamental purpose and the next - examples will be difficult to understand unless these basic properties are - totally clear. - */ -} diff --git a/SEAL/native/examples/4_ckks_basics.cpp b/SEAL/native/examples/4_ckks_basics.cpp deleted file mode 100644 index 17d8eee..0000000 --- a/SEAL/native/examples/4_ckks_basics.cpp +++ /dev/null @@ -1,321 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "examples.h" - -using namespace std; -using namespace seal; - -void example_ckks_basics() -{ - print_example_banner("Example: CKKS Basics"); - - /* - In this example we demonstrate evaluating a polynomial function - - PI*x^3 + 0.4*x + 1 - - on encrypted floating-point input data x for a set of 4096 equidistant points - in the interval [0, 1]. This example demonstrates many of the main features - of the CKKS scheme, but also the challenges in using it. - - We start by setting up the CKKS scheme. - */ - EncryptionParameters parms(scheme_type::CKKS); - - /* - We saw in `2_encoders.cpp' that multiplication in CKKS causes scales - in ciphertexts to grow. The scale of any ciphertext must not get too close - to the total size of coeff_modulus, or else the ciphertext simply runs out of - room to store the scaled-up plaintext. The CKKS scheme provides a `rescale' - functionality that can reduce the scale, and stabilize the scale expansion. - - Rescaling is a kind of modulus switch operation (recall `3_levels.cpp'). - As modulus switching, it removes the last of the primes from coeff_modulus, - but as a side-effect it scales down the ciphertext by the removed prime. - Usually we want to have perfect control over how the scales are changed, - which is why for the CKKS scheme it is more common to use carefully selected - primes for the coeff_modulus. - - More precisely, suppose that the scale in a CKKS ciphertext is S, and the - last prime in the current coeff_modulus (for the ciphertext) is P. Rescaling - to the next level changes the scale to S/P, and removes the prime P from the - coeff_modulus, as usual in modulus switching. The number of primes limits - how many rescalings can be done, and thus limits the multiplicative depth of - the computation. - - It is possible to choose the initial scale freely. One good strategy can be - to is to set the initial scale S and primes P_i in the coeff_modulus to be - very close to each other. If ciphertexts have scale S before multiplication, - they have scale S^2 after multiplication, and S^2/P_i after rescaling. If all - P_i are close to S, then S^2/P_i is close to S again. This way we stabilize the - scales to be close to S throughout the computation. Generally, for a circuit - of depth D, we need to rescale D times, i.e., we need to be able to remove D - primes from the coefficient modulus. Once we have only one prime left in the - coeff_modulus, the remaining prime must be larger than S by a few bits to - preserve the pre-decimal-point value of the plaintext. - - Therefore, a generally good strategy is to choose parameters for the CKKS - scheme as follows: - - (1) Choose a 60-bit prime as the first prime in coeff_modulus. This will - give the highest precision when decrypting; - (2) Choose another 60-bit prime as the last element of coeff_modulus, as - this will be used as the special prime and should be as large as the - largest of the other primes; - (3) Choose the intermediate primes to be close to each other. - - We use CoeffModulus::Create to generate primes of the appropriate size. Note - that our coeff_modulus is 200 bits total, which is below the bound for our - poly_modulus_degree: CoeffModulus::MaxBitCount(8192) returns 218. - */ - size_t poly_modulus_degree = 8192; - parms.set_poly_modulus_degree(poly_modulus_degree); - parms.set_coeff_modulus(CoeffModulus::Create( - poly_modulus_degree, { 60, 40, 40, 60 })); - - /* - We choose the initial scale to be 2^40. At the last level, this leaves us - 60-40=20 bits of precision before the decimal point, and enough (roughly - 10-20 bits) of precision after the decimal point. Since our intermediate - primes are 40 bits (in fact, they are very close to 2^40), we can achieve - scale stabilization as described above. - */ - double scale = pow(2.0, 40); - - auto context = SEALContext::Create(parms); - print_parameters(context); - cout << endl; - - KeyGenerator keygen(context); - auto public_key = keygen.public_key(); - auto secret_key = keygen.secret_key(); - auto relin_keys = keygen.relin_keys(); - Encryptor encryptor(context, public_key); - Evaluator evaluator(context); - Decryptor decryptor(context, secret_key); - - CKKSEncoder encoder(context); - size_t slot_count = encoder.slot_count(); - cout << "Number of slots: " << slot_count << endl; - - vector input; - input.reserve(slot_count); - double curr_point = 0; - double step_size = 1.0 / (static_cast(slot_count) - 1); - for (size_t i = 0; i < slot_count; i++, curr_point += step_size) - { - input.push_back(curr_point); - } - cout << "Input vector: " << endl; - print_vector(input, 3, 7); - - cout << "Evaluating polynomial PI*x^3 + 0.4x + 1 ..." << endl; - - /* - We create plaintexts for PI, 0.4, and 1 using an overload of CKKSEncoder::encode - that encodes the given floating-point value to every slot in the vector. - */ - Plaintext plain_coeff3, plain_coeff1, plain_coeff0; - encoder.encode(3.14159265, scale, plain_coeff3); - encoder.encode(0.4, scale, plain_coeff1); - encoder.encode(1.0, scale, plain_coeff0); - - Plaintext x_plain; - print_line(__LINE__); - cout << "Encode input vectors." << endl; - encoder.encode(input, scale, x_plain); - Ciphertext x1_encrypted; - encryptor.encrypt(x_plain, x1_encrypted); - - /* - To compute x^3 we first compute x^2 and relinearize. However, the scale has - now grown to 2^80. - */ - Ciphertext x3_encrypted; - print_line(__LINE__); - cout << "Compute x^2 and relinearize:" << endl; - evaluator.square(x1_encrypted, x3_encrypted); - evaluator.relinearize_inplace(x3_encrypted, relin_keys); - cout << " + Scale of x^2 before rescale: " << log2(x3_encrypted.scale()) - << " bits" << endl; - - /* - Now rescale; in addition to a modulus switch, the scale is reduced down by - a factor equal to the prime that was switched away (40-bit prime). Hence, the - new scale should be close to 2^40. Note, however, that the scale is not equal - to 2^40: this is because the 40-bit prime is only close to 2^40. - */ - print_line(__LINE__); - cout << "Rescale x^2." << endl; - evaluator.rescale_to_next_inplace(x3_encrypted); - cout << " + Scale of x^2 after rescale: " << log2(x3_encrypted.scale()) - << " bits" << endl; - - /* - Now x3_encrypted is at a different level than x1_encrypted, which prevents us - from multiplying them to compute x^3. We could simply switch x1_encrypted to - the next parameters in the modulus switching chain. However, since we still - need to multiply the x^3 term with PI (plain_coeff3), we instead compute PI*x - first and multiply that with x^2 to obtain PI*x^3. To this end, we compute - PI*x and rescale it back from scale 2^80 to something close to 2^40. - */ - print_line(__LINE__); - cout << "Compute and rescale PI*x." << endl; - Ciphertext x1_encrypted_coeff3; - evaluator.multiply_plain(x1_encrypted, plain_coeff3, x1_encrypted_coeff3); - cout << " + Scale of PI*x before rescale: " << log2(x1_encrypted_coeff3.scale()) - << " bits" << endl; - evaluator.rescale_to_next_inplace(x1_encrypted_coeff3); - cout << " + Scale of PI*x after rescale: " << log2(x1_encrypted_coeff3.scale()) - << " bits" << endl; - - /* - Since x3_encrypted and x1_encrypted_coeff3 have the same exact scale and use - the same encryption parameters, we can multiply them together. We write the - result to x3_encrypted, relinearize, and rescale. Note that again the scale - is something close to 2^40, but not exactly 2^40 due to yet another scaling - by a prime. We are down to the last level in the modulus switching chain. - */ - print_line(__LINE__); - cout << "Compute, relinearize, and rescale (PI*x)*x^2." << endl; - evaluator.multiply_inplace(x3_encrypted, x1_encrypted_coeff3); - evaluator.relinearize_inplace(x3_encrypted, relin_keys); - cout << " + Scale of PI*x^3 before rescale: " << log2(x3_encrypted.scale()) - << " bits" << endl; - evaluator.rescale_to_next_inplace(x3_encrypted); - cout << " + Scale of PI*x^3 after rescale: " << log2(x3_encrypted.scale()) - << " bits" << endl; - - /* - Next we compute the degree one term. All this requires is one multiply_plain - with plain_coeff1. We overwrite x1_encrypted with the result. - */ - print_line(__LINE__); - cout << "Compute and rescale 0.4*x." << endl; - evaluator.multiply_plain_inplace(x1_encrypted, plain_coeff1); - cout << " + Scale of 0.4*x before rescale: " << log2(x1_encrypted.scale()) - << " bits" << endl; - evaluator.rescale_to_next_inplace(x1_encrypted); - cout << " + Scale of 0.4*x after rescale: " << log2(x1_encrypted.scale()) - << " bits" << endl; - - /* - Now we would hope to compute the sum of all three terms. However, there is - a serious problem: the encryption parameters used by all three terms are - different due to modulus switching from rescaling. - - Encrypted addition and subtraction require that the scales of the inputs are - the same, and also that the encryption parameters (parms_id) match. If there - is a mismatch, Evaluator will throw an exception. - */ - cout << endl; - print_line(__LINE__); - cout << "Parameters used by all three terms are different." << endl; - cout << " + Modulus chain index for x3_encrypted: " - << context->get_context_data(x3_encrypted.parms_id())->chain_index() << endl; - cout << " + Modulus chain index for x1_encrypted: " - << context->get_context_data(x1_encrypted.parms_id())->chain_index() << endl; - cout << " + Modulus chain index for plain_coeff0: " - << context->get_context_data(plain_coeff0.parms_id())->chain_index() << endl; - cout << endl; - - /* - Let us carefully consider what the scales are at this point. We denote the - primes in coeff_modulus as P_0, P_1, P_2, P_3, in this order. P_3 is used as - the special modulus and is not involved in rescalings. After the computations - above the scales in ciphertexts are: - - - Product x^2 has scale 2^80 and is at level 2; - - Product PI*x has scale 2^80 and is at level 2; - - We rescaled both down to scale 2^80/P_2 and level 1; - - Product PI*x^3 has scale (2^80/P_2)^2; - - We rescaled it down to scale (2^80/P_2)^2/P_1 and level 0; - - Product 0.4*x has scale 2^80; - - We rescaled it down to scale 2^80/P_2 and level 1; - - The contant term 1 has scale 2^40 and is at level 2. - - Although the scales of all three terms are approximately 2^40, their exact - values are different, hence they cannot be added together. - */ - print_line(__LINE__); - cout << "The exact scales of all three terms are different:" << endl; - ios old_fmt(nullptr); - old_fmt.copyfmt(cout); - cout << fixed << setprecision(10); - cout << " + Exact scale in PI*x^3: " << x3_encrypted.scale() << endl; - cout << " + Exact scale in 0.4*x: " << x1_encrypted.scale() << endl; - cout << " + Exact scale in 1: " << plain_coeff0.scale() << endl; - cout << endl; - cout.copyfmt(old_fmt); - - /* - There are many ways to fix this problem. Since P_2 and P_1 are really close - to 2^40, we can simply "lie" to Microsoft SEAL and set the scales to be the - same. For example, changing the scale of PI*x^3 to 2^40 simply means that we - scale the value of PI*x^3 by 2^120/(P_2^2*P_1), which is very close to 1. - This should not result in any noticeable error. - - Another option would be to encode 1 with scale 2^80/P_2, do a multiply_plain - with 0.4*x, and finally rescale. In this case we would need to additionally - make sure to encode 1 with appropriate encryption parameters (parms_id). - - In this example we will use the first (simplest) approach and simply change - the scale of PI*x^3 and 0.4*x to 2^40. - */ - print_line(__LINE__); - cout << "Normalize scales to 2^40." << endl; - x3_encrypted.scale() = pow(2.0, 40); - x1_encrypted.scale() = pow(2.0, 40); - - /* - We still have a problem with mismatching encryption parameters. This is easy - to fix by using traditional modulus switching (no rescaling). CKKS supports - modulus switching just like the BFV scheme, allowing us to switch away parts - of the coefficient modulus when it is simply not needed. - */ - print_line(__LINE__); - cout << "Normalize encryption parameters to the lowest level." << endl; - parms_id_type last_parms_id = x3_encrypted.parms_id(); - evaluator.mod_switch_to_inplace(x1_encrypted, last_parms_id); - evaluator.mod_switch_to_inplace(plain_coeff0, last_parms_id); - - /* - All three ciphertexts are now compatible and can be added. - */ - print_line(__LINE__); - cout << "Compute PI*x^3 + 0.4*x + 1." << endl; - Ciphertext encrypted_result; - evaluator.add(x3_encrypted, x1_encrypted, encrypted_result); - evaluator.add_plain_inplace(encrypted_result, plain_coeff0); - - /* - First print the true result. - */ - Plaintext plain_result; - print_line(__LINE__); - cout << "Decrypt and decode PI*x^3 + 0.4x + 1." << endl; - cout << " + Expected result:" << endl; - vector true_result; - for (size_t i = 0; i < input.size(); i++) - { - double x = input[i]; - true_result.push_back((3.14159265 * x * x + 0.4)* x + 1); - } - print_vector(true_result, 3, 7); - - /* - Decrypt, decode, and print the result. - */ - decryptor.decrypt(encrypted_result, plain_result); - vector result; - encoder.decode(plain_result, result); - cout << " + Computed result ...... Correct." << endl; - print_vector(result, 3, 7); - - /* - While we did not show any computations on complex numbers in these examples, - the CKKSEncoder would allow us to have done that just as easily. Additions - and multiplications of complex numbers behave just as one would expect. - */ -} \ No newline at end of file diff --git a/SEAL/native/examples/5_rotation.cpp b/SEAL/native/examples/5_rotation.cpp deleted file mode 100644 index 487f156..0000000 --- a/SEAL/native/examples/5_rotation.cpp +++ /dev/null @@ -1,201 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "examples.h" - -using namespace std; -using namespace seal; - -/* -Both the BFV scheme (with BatchEncoder) as well as the CKKS scheme support native -vectorized computations on encrypted numbers. In addition to computing slot-wise, -it is possible to rotate the encrypted vectors cyclically. -*/ -void example_rotation_bfv() -{ - print_example_banner("Example: Rotation / Rotation in BFV"); - - EncryptionParameters parms(scheme_type::BFV); - - size_t poly_modulus_degree = 8192; - parms.set_poly_modulus_degree(poly_modulus_degree); - parms.set_coeff_modulus(CoeffModulus::BFVDefault(poly_modulus_degree)); - parms.set_plain_modulus(PlainModulus::Batching(poly_modulus_degree, 20)); - - auto context = SEALContext::Create(parms); - print_parameters(context); - cout << endl; - - KeyGenerator keygen(context); - PublicKey public_key = keygen.public_key(); - SecretKey secret_key = keygen.secret_key(); - RelinKeys relin_keys = keygen.relin_keys(); - Encryptor encryptor(context, public_key); - Evaluator evaluator(context); - Decryptor decryptor(context, secret_key); - - BatchEncoder batch_encoder(context); - size_t slot_count = batch_encoder.slot_count(); - size_t row_size = slot_count / 2; - cout << "Plaintext matrix row size: " << row_size << endl; - - vector pod_matrix(slot_count, 0ULL); - pod_matrix[0] = 0ULL; - pod_matrix[1] = 1ULL; - pod_matrix[2] = 2ULL; - pod_matrix[3] = 3ULL; - pod_matrix[row_size] = 4ULL; - pod_matrix[row_size + 1] = 5ULL; - pod_matrix[row_size + 2] = 6ULL; - pod_matrix[row_size + 3] = 7ULL; - - cout << "Input plaintext matrix:" << endl; - print_matrix(pod_matrix, row_size); - - /* - First we use BatchEncoder to encode the matrix into a plaintext. We encrypt - the plaintext as usual. - */ - Plaintext plain_matrix; - print_line(__LINE__); - cout << "Encode and encrypt." << endl; - batch_encoder.encode(pod_matrix, plain_matrix); - Ciphertext encrypted_matrix; - encryptor.encrypt(plain_matrix, encrypted_matrix); - cout << " + Noise budget in fresh encryption: " - << decryptor.invariant_noise_budget(encrypted_matrix) << " bits" << endl; - cout << endl; - - /* - Rotations require yet another type of special key called `Galois keys'. These - are easily obtained from the KeyGenerator. - */ - GaloisKeys gal_keys = keygen.galois_keys(); - - /* - Now rotate both matrix rows 3 steps to the left, decrypt, decode, and print. - */ - print_line(__LINE__); - cout << "Rotate rows 3 steps left." << endl; - evaluator.rotate_rows_inplace(encrypted_matrix, 3, gal_keys); - Plaintext plain_result; - cout << " + Noise budget after rotation: " - << decryptor.invariant_noise_budget(encrypted_matrix) << " bits" << endl; - cout << " + Decrypt and decode ...... Correct." << endl; - decryptor.decrypt(encrypted_matrix, plain_result); - batch_encoder.decode(plain_result, pod_matrix); - print_matrix(pod_matrix, row_size); - - /* - We can also rotate the columns, i.e., swap the rows. - */ - print_line(__LINE__); - cout << "Rotate columns." << endl; - evaluator.rotate_columns_inplace(encrypted_matrix, gal_keys); - cout << " + Noise budget after rotation: " - << decryptor.invariant_noise_budget(encrypted_matrix) << " bits" << endl; - cout << " + Decrypt and decode ...... Correct." << endl; - decryptor.decrypt(encrypted_matrix, plain_result); - batch_encoder.decode(plain_result, pod_matrix); - print_matrix(pod_matrix, row_size); - - /* - Finally, we rotate the rows 4 steps to the right, decrypt, decode, and print. - */ - print_line(__LINE__); - cout << "Rotate rows 4 steps right." << endl; - evaluator.rotate_rows_inplace(encrypted_matrix, -4, gal_keys); - cout << " + Noise budget after rotation: " - << decryptor.invariant_noise_budget(encrypted_matrix) << " bits" << endl; - cout << " + Decrypt and decode ...... Correct." << endl; - decryptor.decrypt(encrypted_matrix, plain_result); - batch_encoder.decode(plain_result, pod_matrix); - print_matrix(pod_matrix, row_size); - - /* - Note that rotations do not consume any noise budget. However, this is only - the case when the special prime is at least as large as the other primes. The - same holds for relinearization. Microsoft SEAL does not require that the - special prime is of any particular size, so ensuring this is the case is left - for the user to do. - */ -} - -void example_rotation_ckks() -{ - print_example_banner("Example: Rotation / Rotation in CKKS"); - - /* - Rotations in the CKKS scheme work very similarly to rotations in BFV. - */ - EncryptionParameters parms(scheme_type::CKKS); - - size_t poly_modulus_degree = 8192; - parms.set_poly_modulus_degree(poly_modulus_degree); - parms.set_coeff_modulus(CoeffModulus::Create( - poly_modulus_degree, { 40, 40, 40, 40, 40 })); - - auto context = SEALContext::Create(parms); - print_parameters(context); - cout << endl; - - KeyGenerator keygen(context); - PublicKey public_key = keygen.public_key(); - SecretKey secret_key = keygen.secret_key(); - RelinKeys relin_keys = keygen.relin_keys(); - GaloisKeys gal_keys = keygen.galois_keys(); - Encryptor encryptor(context, public_key); - Evaluator evaluator(context); - Decryptor decryptor(context, secret_key); - - CKKSEncoder ckks_encoder(context); - - size_t slot_count = ckks_encoder.slot_count(); - cout << "Number of slots: " << slot_count << endl; - vector input; - input.reserve(slot_count); - double curr_point = 0; - double step_size = 1.0 / (static_cast(slot_count) - 1); - for (size_t i = 0; i < slot_count; i++, curr_point += step_size) - { - input.push_back(curr_point); - } - cout << "Input vector:" << endl; - print_vector(input, 3, 7); - - auto scale = pow(2.0, 50); - - print_line(__LINE__); - cout << "Encode and encrypt." << endl; - Plaintext plain; - ckks_encoder.encode(input, scale, plain); - Ciphertext encrypted; - encryptor.encrypt(plain, encrypted); - - Ciphertext rotated; - print_line(__LINE__); - cout << "Rotate 2 steps left." << endl; - evaluator.rotate_vector(encrypted, 2, gal_keys, rotated); - cout << " + Decrypt and decode ...... Correct." << endl; - decryptor.decrypt(rotated, plain); - vector result; - ckks_encoder.decode(plain, result); - print_vector(result, 3, 7); - - /* - With the CKKS scheme it is also possible to evaluate a complex conjugation on - a vector of encrypted complex numbers, using Evaluator::complex_conjugate. - This is in fact a kind of rotation, and requires also Galois keys. - */ -} - -void example_rotation() -{ - print_example_banner("Example: Rotation"); - - /* - Run all rotation examples. - */ - example_rotation_bfv(); - example_rotation_ckks(); -} \ No newline at end of file diff --git a/SEAL/native/examples/6_performance.cpp b/SEAL/native/examples/6_performance.cpp deleted file mode 100644 index 4ba3a66..0000000 --- a/SEAL/native/examples/6_performance.cpp +++ /dev/null @@ -1,758 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "examples.h" - -using namespace std; -using namespace seal; - -void bfv_performance_test(shared_ptr context) -{ - chrono::high_resolution_clock::time_point time_start, time_end; - - print_parameters(context); - cout << endl; - - auto &parms = context->first_context_data()->parms(); - auto &plain_modulus = parms.plain_modulus(); - size_t poly_modulus_degree = parms.poly_modulus_degree(); - - cout << "Generating secret/public keys: "; - KeyGenerator keygen(context); - cout << "Done" << endl; - - auto secret_key = keygen.secret_key(); - auto public_key = keygen.public_key(); - - RelinKeys relin_keys; - GaloisKeys gal_keys; - chrono::microseconds time_diff; - if (context->using_keyswitching()) - { - /* - Generate relinearization keys. - */ - cout << "Generating relinearization keys: "; - time_start = chrono::high_resolution_clock::now(); - relin_keys = keygen.relin_keys(); - time_end = chrono::high_resolution_clock::now(); - time_diff = chrono::duration_cast(time_end - time_start); - cout << "Done [" << time_diff.count() << " microseconds]" << endl; - - if (!context->key_context_data()->qualifiers().using_batching) - { - cout << "Given encryption parameters do not support batching." << endl; - return; - } - - /* - Generate Galois keys. In larger examples the Galois keys can use a lot of - memory, which can be a problem in constrained systems. The user should - try some of the larger runs of the test and observe their effect on the - memory pool allocation size. The key generation can also take a long time, - as can be observed from the print-out. - */ - cout << "Generating Galois keys: "; - time_start = chrono::high_resolution_clock::now(); - gal_keys = keygen.galois_keys(); - time_end = chrono::high_resolution_clock::now(); - time_diff = chrono::duration_cast(time_end - time_start); - cout << "Done [" << time_diff.count() << " microseconds]" << endl; - } - - Encryptor encryptor(context, public_key); - Decryptor decryptor(context, secret_key); - Evaluator evaluator(context); - BatchEncoder batch_encoder(context); - IntegerEncoder encoder(context); - - /* - These will hold the total times used by each operation. - */ - chrono::microseconds time_batch_sum(0); - chrono::microseconds time_unbatch_sum(0); - chrono::microseconds time_encrypt_sum(0); - chrono::microseconds time_decrypt_sum(0); - chrono::microseconds time_add_sum(0); - chrono::microseconds time_multiply_sum(0); - chrono::microseconds time_multiply_plain_sum(0); - chrono::microseconds time_square_sum(0); - chrono::microseconds time_relinearize_sum(0); - chrono::microseconds time_rotate_rows_one_step_sum(0); - chrono::microseconds time_rotate_rows_random_sum(0); - chrono::microseconds time_rotate_columns_sum(0); - - /* - How many times to run the test? - */ - int count = 10; - - /* - Populate a vector of values to batch. - */ - size_t slot_count = batch_encoder.slot_count(); - vector pod_vector; - random_device rd; - for (size_t i = 0; i < slot_count; i++) - { - pod_vector.push_back(rd() % plain_modulus.value()); - } - - cout << "Running tests "; - for (int i = 0; i < count; i++) - { - /* - [Batching] - There is nothing unusual here. We batch our random plaintext matrix - into the polynomial. Note how the plaintext we create is of the exactly - right size so unnecessary reallocations are avoided. - */ - Plaintext plain(parms.poly_modulus_degree(), 0); - time_start = chrono::high_resolution_clock::now(); - batch_encoder.encode(pod_vector, plain); - time_end = chrono::high_resolution_clock::now(); - time_batch_sum += chrono::duration_cast< - chrono::microseconds>(time_end - time_start); - - /* - [Unbatching] - We unbatch what we just batched. - */ - vector pod_vector2(slot_count); - time_start = chrono::high_resolution_clock::now(); - batch_encoder.decode(plain, pod_vector2); - time_end = chrono::high_resolution_clock::now(); - time_unbatch_sum += chrono::duration_cast< - chrono::microseconds>(time_end - time_start); - if (pod_vector2 != pod_vector) - { - throw runtime_error("Batch/unbatch failed. Something is wrong."); - } - - /* - [Encryption] - We make sure our ciphertext is already allocated and large enough - to hold the encryption with these encryption parameters. We encrypt - our random batched matrix here. - */ - Ciphertext encrypted(context); - time_start = chrono::high_resolution_clock::now(); - encryptor.encrypt(plain, encrypted); - time_end = chrono::high_resolution_clock::now(); - time_encrypt_sum += chrono::duration_cast< - chrono::microseconds>(time_end - time_start); - - /* - [Decryption] - We decrypt what we just encrypted. - */ - Plaintext plain2(poly_modulus_degree, 0); - time_start = chrono::high_resolution_clock::now(); - decryptor.decrypt(encrypted, plain2); - time_end = chrono::high_resolution_clock::now(); - time_decrypt_sum += chrono::duration_cast< - chrono::microseconds>(time_end - time_start); - if (plain2 != plain) - { - throw runtime_error("Encrypt/decrypt failed. Something is wrong."); - } - - /* - [Add] - We create two ciphertexts and perform a few additions with them. - */ - Ciphertext encrypted1(context); - encryptor.encrypt(encoder.encode(i), encrypted1); - Ciphertext encrypted2(context); - encryptor.encrypt(encoder.encode(i + 1), encrypted2); - time_start = chrono::high_resolution_clock::now(); - evaluator.add_inplace(encrypted1, encrypted1); - evaluator.add_inplace(encrypted2, encrypted2); - evaluator.add_inplace(encrypted1, encrypted2); - time_end = chrono::high_resolution_clock::now(); - time_add_sum += chrono::duration_cast< - chrono::microseconds>(time_end - time_start); - - /* - [Multiply] - We multiply two ciphertexts. Since the size of the result will be 3, - and will overwrite the first argument, we reserve first enough memory - to avoid reallocating during multiplication. - */ - encrypted1.reserve(3); - time_start = chrono::high_resolution_clock::now(); - evaluator.multiply_inplace(encrypted1, encrypted2); - time_end = chrono::high_resolution_clock::now(); - time_multiply_sum += chrono::duration_cast< - chrono::microseconds>(time_end - time_start); - - /* - [Multiply Plain] - We multiply a ciphertext with a random plaintext. Recall that - multiply_plain does not change the size of the ciphertext so we use - encrypted2 here. - */ - time_start = chrono::high_resolution_clock::now(); - evaluator.multiply_plain_inplace(encrypted2, plain); - time_end = chrono::high_resolution_clock::now(); - time_multiply_plain_sum += chrono::duration_cast< - chrono::microseconds>(time_end - time_start); - - /* - [Square] - We continue to use encrypted2. Now we square it; this should be - faster than generic homomorphic multiplication. - */ - time_start = chrono::high_resolution_clock::now(); - evaluator.square_inplace(encrypted2); - time_end = chrono::high_resolution_clock::now(); - time_square_sum += chrono::duration_cast< - chrono::microseconds>(time_end - time_start); - - if (context->using_keyswitching()) - { - /* - [Relinearize] - Time to get back to encrypted1. We now relinearize it back - to size 2. Since the allocation is currently big enough to - contain a ciphertext of size 3, no costly reallocations are - needed in the process. - */ - time_start = chrono::high_resolution_clock::now(); - evaluator.relinearize_inplace(encrypted1, relin_keys); - time_end = chrono::high_resolution_clock::now(); - time_relinearize_sum += chrono::duration_cast< - chrono::microseconds>(time_end - time_start); - - /* - [Rotate Rows One Step] - We rotate matrix rows by one step left and measure the time. - */ - time_start = chrono::high_resolution_clock::now(); - evaluator.rotate_rows_inplace(encrypted, 1, gal_keys); - evaluator.rotate_rows_inplace(encrypted, -1, gal_keys); - time_end = chrono::high_resolution_clock::now(); - time_rotate_rows_one_step_sum += chrono::duration_cast< - chrono::microseconds>(time_end - time_start);; - - /* - [Rotate Rows Random] - We rotate matrix rows by a random number of steps. This is much more - expensive than rotating by just one step. - */ - size_t row_size = batch_encoder.slot_count() / 2; - int random_rotation = static_cast(rd() % row_size); - time_start = chrono::high_resolution_clock::now(); - evaluator.rotate_rows_inplace(encrypted, random_rotation, gal_keys); - time_end = chrono::high_resolution_clock::now(); - time_rotate_rows_random_sum += chrono::duration_cast< - chrono::microseconds>(time_end - time_start); - - /* - [Rotate Columns] - Nothing surprising here. - */ - time_start = chrono::high_resolution_clock::now(); - evaluator.rotate_columns_inplace(encrypted, gal_keys); - time_end = chrono::high_resolution_clock::now(); - time_rotate_columns_sum += chrono::duration_cast< - chrono::microseconds>(time_end - time_start); - } - - /* - Print a dot to indicate progress. - */ - cout << "."; - cout.flush(); - } - - cout << " Done" << endl << endl; - cout.flush(); - - auto avg_batch = time_batch_sum.count() / count; - auto avg_unbatch = time_unbatch_sum.count() / count; - auto avg_encrypt = time_encrypt_sum.count() / count; - auto avg_decrypt = time_decrypt_sum.count() / count; - auto avg_add = time_add_sum.count() / (3 * count); - auto avg_multiply = time_multiply_sum.count() / count; - auto avg_multiply_plain = time_multiply_plain_sum.count() / count; - auto avg_square = time_square_sum.count() / count; - auto avg_relinearize = time_relinearize_sum.count() / count; - auto avg_rotate_rows_one_step = time_rotate_rows_one_step_sum.count() / (2 * count); - auto avg_rotate_rows_random = time_rotate_rows_random_sum.count() / count; - auto avg_rotate_columns = time_rotate_columns_sum.count() / count; - - cout << "Average batch: " << avg_batch << " microseconds" << endl; - cout << "Average unbatch: " << avg_unbatch << " microseconds" << endl; - cout << "Average encrypt: " << avg_encrypt << " microseconds" << endl; - cout << "Average decrypt: " << avg_decrypt << " microseconds" << endl; - cout << "Average add: " << avg_add << " microseconds" << endl; - cout << "Average multiply: " << avg_multiply << " microseconds" << endl; - cout << "Average multiply plain: " << avg_multiply_plain << " microseconds" << endl; - cout << "Average square: " << avg_square << " microseconds" << endl; - if (context->using_keyswitching()) - { - cout << "Average relinearize: " << avg_relinearize << " microseconds" << endl; - cout << "Average rotate rows one step: " << avg_rotate_rows_one_step << - " microseconds" << endl; - cout << "Average rotate rows random: " << avg_rotate_rows_random << - " microseconds" << endl; - cout << "Average rotate columns: " << avg_rotate_columns << - " microseconds" << endl; - } - cout.flush(); -} - -void ckks_performance_test(shared_ptr context) -{ - chrono::high_resolution_clock::time_point time_start, time_end; - - print_parameters(context); - cout << endl; - - auto &parms = context->first_context_data()->parms(); - size_t poly_modulus_degree = parms.poly_modulus_degree(); - - cout << "Generating secret/public keys: "; - KeyGenerator keygen(context); - cout << "Done" << endl; - - auto secret_key = keygen.secret_key(); - auto public_key = keygen.public_key(); - - RelinKeys relin_keys; - GaloisKeys gal_keys; - chrono::microseconds time_diff; - if (context->using_keyswitching()) - { - cout << "Generating relinearization keys: "; - time_start = chrono::high_resolution_clock::now(); - relin_keys = keygen.relin_keys(); - time_end = chrono::high_resolution_clock::now(); - time_diff = chrono::duration_cast(time_end - time_start); - cout << "Done [" << time_diff.count() << " microseconds]" << endl; - - if (!context->first_context_data()->qualifiers().using_batching) - { - cout << "Given encryption parameters do not support batching." << endl; - return; - } - - cout << "Generating Galois keys: "; - time_start = chrono::high_resolution_clock::now(); - gal_keys = keygen.galois_keys(); - time_end = chrono::high_resolution_clock::now(); - time_diff = chrono::duration_cast(time_end - time_start); - cout << "Done [" << time_diff.count() << " microseconds]" << endl; - } - - Encryptor encryptor(context, public_key); - Decryptor decryptor(context, secret_key); - Evaluator evaluator(context); - CKKSEncoder ckks_encoder(context); - - chrono::microseconds time_encode_sum(0); - chrono::microseconds time_decode_sum(0); - chrono::microseconds time_encrypt_sum(0); - chrono::microseconds time_decrypt_sum(0); - chrono::microseconds time_add_sum(0); - chrono::microseconds time_multiply_sum(0); - chrono::microseconds time_multiply_plain_sum(0); - chrono::microseconds time_square_sum(0); - chrono::microseconds time_relinearize_sum(0); - chrono::microseconds time_rescale_sum(0); - chrono::microseconds time_rotate_one_step_sum(0); - chrono::microseconds time_rotate_random_sum(0); - chrono::microseconds time_conjugate_sum(0); - - /* - How many times to run the test? - */ - int count = 10; - - /* - Populate a vector of floating-point values to batch. - */ - vector pod_vector; - random_device rd; - for (size_t i = 0; i < ckks_encoder.slot_count(); i++) - { - pod_vector.push_back(1.001 * static_cast(i)); - } - - cout << "Running tests "; - for (int i = 0; i < count; i++) - { - /* - [Encoding] - For scale we use the square root of the last coeff_modulus prime - from parms. - */ - Plaintext plain(parms.poly_modulus_degree() * - parms.coeff_modulus().size(), 0); - /* - - */ - double scale = sqrt(static_cast( - parms.coeff_modulus().back().value())); - time_start = chrono::high_resolution_clock::now(); - ckks_encoder.encode(pod_vector, scale, plain); - time_end = chrono::high_resolution_clock::now(); - time_encode_sum += chrono::duration_cast< - chrono::microseconds>(time_end - time_start); - - /* - [Decoding] - */ - vector pod_vector2(ckks_encoder.slot_count()); - time_start = chrono::high_resolution_clock::now(); - ckks_encoder.decode(plain, pod_vector2); - time_end = chrono::high_resolution_clock::now(); - time_decode_sum += chrono::duration_cast< - chrono::microseconds>(time_end - time_start); - - /* - [Encryption] - */ - Ciphertext encrypted(context); - time_start = chrono::high_resolution_clock::now(); - encryptor.encrypt(plain, encrypted); - time_end = chrono::high_resolution_clock::now(); - time_encrypt_sum += chrono::duration_cast< - chrono::microseconds>(time_end - time_start); - - /* - [Decryption] - */ - Plaintext plain2(poly_modulus_degree, 0); - time_start = chrono::high_resolution_clock::now(); - decryptor.decrypt(encrypted, plain2); - time_end = chrono::high_resolution_clock::now(); - time_decrypt_sum += chrono::duration_cast< - chrono::microseconds>(time_end - time_start); - - /* - [Add] - */ - Ciphertext encrypted1(context); - ckks_encoder.encode(i + 1, plain); - encryptor.encrypt(plain, encrypted1); - Ciphertext encrypted2(context); - ckks_encoder.encode(i + 1, plain2); - encryptor.encrypt(plain2, encrypted2); - time_start = chrono::high_resolution_clock::now(); - evaluator.add_inplace(encrypted1, encrypted1); - evaluator.add_inplace(encrypted2, encrypted2); - evaluator.add_inplace(encrypted1, encrypted2); - time_end = chrono::high_resolution_clock::now(); - time_add_sum += chrono::duration_cast< - chrono::microseconds>(time_end - time_start); - - /* - [Multiply] - */ - encrypted1.reserve(3); - time_start = chrono::high_resolution_clock::now(); - evaluator.multiply_inplace(encrypted1, encrypted2); - time_end = chrono::high_resolution_clock::now(); - time_multiply_sum += chrono::duration_cast< - chrono::microseconds>(time_end - time_start); - - /* - [Multiply Plain] - */ - time_start = chrono::high_resolution_clock::now(); - evaluator.multiply_plain_inplace(encrypted2, plain); - time_end = chrono::high_resolution_clock::now(); - time_multiply_plain_sum += chrono::duration_cast< - chrono::microseconds>(time_end - time_start); - - /* - [Square] - */ - time_start = chrono::high_resolution_clock::now(); - evaluator.square_inplace(encrypted2); - time_end = chrono::high_resolution_clock::now(); - time_square_sum += chrono::duration_cast< - chrono::microseconds>(time_end - time_start); - - if (context->using_keyswitching()) - { - /* - [Relinearize] - */ - time_start = chrono::high_resolution_clock::now(); - evaluator.relinearize_inplace(encrypted1, relin_keys); - time_end = chrono::high_resolution_clock::now(); - time_relinearize_sum += chrono::duration_cast< - chrono::microseconds>(time_end - time_start); - - /* - [Rescale] - */ - time_start = chrono::high_resolution_clock::now(); - evaluator.rescale_to_next_inplace(encrypted1); - time_end = chrono::high_resolution_clock::now(); - time_rescale_sum += chrono::duration_cast< - chrono::microseconds>(time_end - time_start); - - /* - [Rotate Vector] - */ - time_start = chrono::high_resolution_clock::now(); - evaluator.rotate_vector_inplace(encrypted, 1, gal_keys); - evaluator.rotate_vector_inplace(encrypted, -1, gal_keys); - time_end = chrono::high_resolution_clock::now(); - time_rotate_one_step_sum += chrono::duration_cast< - chrono::microseconds>(time_end - time_start); - - /* - [Rotate Vector Random] - */ - int random_rotation = static_cast(rd() % ckks_encoder.slot_count()); - time_start = chrono::high_resolution_clock::now(); - evaluator.rotate_vector_inplace(encrypted, random_rotation, gal_keys); - time_end = chrono::high_resolution_clock::now(); - time_rotate_random_sum += chrono::duration_cast< - chrono::microseconds>(time_end - time_start); - - /* - [Complex Conjugate] - */ - time_start = chrono::high_resolution_clock::now(); - evaluator.complex_conjugate_inplace(encrypted, gal_keys); - time_end = chrono::high_resolution_clock::now(); - time_conjugate_sum += chrono::duration_cast< - chrono::microseconds>(time_end - time_start); - } - - /* - Print a dot to indicate progress. - */ - cout << "."; - cout.flush(); - } - - cout << " Done" << endl << endl; - cout.flush(); - - auto avg_encode = time_encode_sum.count() / count; - auto avg_decode = time_decode_sum.count() / count; - auto avg_encrypt = time_encrypt_sum.count() / count; - auto avg_decrypt = time_decrypt_sum.count() / count; - auto avg_add = time_add_sum.count() / (3 * count); - auto avg_multiply = time_multiply_sum.count() / count; - auto avg_multiply_plain = time_multiply_plain_sum.count() / count; - auto avg_square = time_square_sum.count() / count; - auto avg_relinearize = time_relinearize_sum.count() / count; - auto avg_rescale = time_rescale_sum.count() / count; - auto avg_rotate_one_step = time_rotate_one_step_sum.count() / (2 * count); - auto avg_rotate_random = time_rotate_random_sum.count() / count; - auto avg_conjugate = time_conjugate_sum.count() / count; - - cout << "Average encode: " << avg_encode << " microseconds" << endl; - cout << "Average decode: " << avg_decode << " microseconds" << endl; - cout << "Average encrypt: " << avg_encrypt << " microseconds" << endl; - cout << "Average decrypt: " << avg_decrypt << " microseconds" << endl; - cout << "Average add: " << avg_add << " microseconds" << endl; - cout << "Average multiply: " << avg_multiply << " microseconds" << endl; - cout << "Average multiply plain: " << avg_multiply_plain << " microseconds" << endl; - cout << "Average square: " << avg_square << " microseconds" << endl; - if (context->using_keyswitching()) - { - cout << "Average relinearize: " << avg_relinearize << " microseconds" << endl; - cout << "Average rescale: " << avg_rescale << " microseconds" << endl; - cout << "Average rotate vector one step: " << avg_rotate_one_step << - " microseconds" << endl; - cout << "Average rotate vector random: " << avg_rotate_random << " microseconds" << endl; - cout << "Average complex conjugate: " << avg_conjugate << " microseconds" << endl; - } - cout.flush(); -} - -void example_bfv_performance_default() -{ - print_example_banner("BFV Performance Test with Degrees: 4096, 8192, and 16384"); - - EncryptionParameters parms(scheme_type::BFV); - size_t poly_modulus_degree = 4096; - parms.set_poly_modulus_degree(poly_modulus_degree); - parms.set_coeff_modulus(CoeffModulus::BFVDefault(poly_modulus_degree)); - parms.set_plain_modulus(786433); - bfv_performance_test(SEALContext::Create(parms)); - - cout << endl; - poly_modulus_degree = 8192; - parms.set_poly_modulus_degree(poly_modulus_degree); - parms.set_coeff_modulus(CoeffModulus::BFVDefault(poly_modulus_degree)); - parms.set_plain_modulus(786433); - bfv_performance_test(SEALContext::Create(parms)); - - cout << endl; - poly_modulus_degree = 16384; - parms.set_poly_modulus_degree(poly_modulus_degree); - parms.set_coeff_modulus(CoeffModulus::BFVDefault(poly_modulus_degree)); - parms.set_plain_modulus(786433); - bfv_performance_test(SEALContext::Create(parms)); - - /* - Comment out the following to run the biggest example. - */ - // cout << endl; - // poly_modulus_degree = 32768; - // parms.set_poly_modulus_degree(poly_modulus_degree); - // parms.set_coeff_modulus(CoeffModulus::BFVDefault(poly_modulus_degree)); - // parms.set_plain_modulus(786433); - // bfv_performance_test(SEALContext::Create(parms)); -} - -void example_bfv_performance_custom() -{ - size_t poly_modulus_degree = 0; - cout << endl << "Set poly_modulus_degree (1024, 2048, 4096, 8192, 16384, or 32768): "; - if (!(cin >> poly_modulus_degree)) - { - cout << "Invalid option." << endl; - cin.clear(); - cin.ignore(numeric_limits::max(), '\n'); - return; - } - if (poly_modulus_degree < 1024 || poly_modulus_degree > 32768 || - (poly_modulus_degree & (poly_modulus_degree - 1)) != 0) - { - cout << "Invalid option." << endl; - return; - } - - string banner = "BFV Performance Test with Degree: "; - print_example_banner(banner + to_string(poly_modulus_degree)); - - EncryptionParameters parms(scheme_type::BFV); - parms.set_poly_modulus_degree(poly_modulus_degree); - parms.set_coeff_modulus(CoeffModulus::BFVDefault(poly_modulus_degree)); - if (poly_modulus_degree == 1024) - { - parms.set_plain_modulus(12289); - } - else - { - parms.set_plain_modulus(786433); - } - bfv_performance_test(SEALContext::Create(parms)); -} - -void example_ckks_performance_default() -{ - print_example_banner("CKKS Performance Test with Degrees: 4096, 8192, and 16384"); - - // It is not recommended to use BFVDefault primes in CKKS. However, for performance - // test, BFVDefault primes are good enough. - EncryptionParameters parms(scheme_type::CKKS); - size_t poly_modulus_degree = 4096; - parms.set_poly_modulus_degree(poly_modulus_degree); - parms.set_coeff_modulus(CoeffModulus::BFVDefault(poly_modulus_degree)); - ckks_performance_test(SEALContext::Create(parms)); - - cout << endl; - poly_modulus_degree = 8192; - parms.set_poly_modulus_degree(poly_modulus_degree); - parms.set_coeff_modulus(CoeffModulus::BFVDefault(poly_modulus_degree)); - ckks_performance_test(SEALContext::Create(parms)); - - cout << endl; - poly_modulus_degree = 16384; - parms.set_poly_modulus_degree(poly_modulus_degree); - parms.set_coeff_modulus(CoeffModulus::BFVDefault(poly_modulus_degree)); - ckks_performance_test(SEALContext::Create(parms)); - - /* - Comment out the following to run the biggest example. - */ - // cout << endl; - // poly_modulus_degree = 32768; - // parms.set_poly_modulus_degree(poly_modulus_degree); - // parms.set_coeff_modulus(CoeffModulus::BFVDefault(poly_modulus_degree)); - // ckks_performance_test(SEALContext::Create(parms)); -} - -void example_ckks_performance_custom() -{ - size_t poly_modulus_degree = 0; - cout << endl << "Set poly_modulus_degree (1024, 2048, 4096, 8192, 16384, or 32768): "; - if (!(cin >> poly_modulus_degree)) - { - cout << "Invalid option." << endl; - cin.clear(); - cin.ignore(numeric_limits::max(), '\n'); - return; - } - if (poly_modulus_degree < 1024 || poly_modulus_degree > 32768 || - (poly_modulus_degree & (poly_modulus_degree - 1)) != 0) - { - cout << "Invalid option." << endl; - return; - } - - string banner = "CKKS Performance Test with Degree: "; - print_example_banner(banner + to_string(poly_modulus_degree)); - - EncryptionParameters parms(scheme_type::CKKS); - parms.set_poly_modulus_degree(poly_modulus_degree); - parms.set_coeff_modulus(CoeffModulus::BFVDefault(poly_modulus_degree)); - ckks_performance_test(SEALContext::Create(parms)); -} - -/* -Prints a sub-menu to select the performance test. -*/ -void example_performance_test() -{ - print_example_banner("Example: Performance Test"); - - while (true) - { - cout << endl; - cout << "Select a scheme (and optionally poly_modulus_degree):" << endl; - cout << " 1. BFV with default degrees" << endl; - cout << " 2. BFV with a custom degree" << endl; - cout << " 3. CKKS with default degrees" << endl; - cout << " 4. CKKS with a custom degree" << endl; - cout << " 0. Back to main menu" << endl; - - int selection = 0; - cout << endl << "> Run performance test (1 ~ 4) or go back (0): "; - if (!(cin >> selection)) - { - cout << "Invalid option." << endl; - cin.clear(); - cin.ignore(numeric_limits::max(), '\n'); - continue; - } - - switch (selection) - { - case 1: - example_bfv_performance_default(); - break; - - case 2: - example_bfv_performance_custom(); - break; - - case 3: - example_ckks_performance_default(); - break; - - case 4: - example_ckks_performance_custom(); - break; - - case 0: - cout << endl; - return; - - default: - cout << "Invalid option." << endl; - } - } -} \ No newline at end of file diff --git a/SEAL/native/examples/CMakeLists.txt b/SEAL/native/examples/CMakeLists.txt deleted file mode 100644 index 4e8509e..0000000 --- a/SEAL/native/examples/CMakeLists.txt +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT license. - -cmake_minimum_required(VERSION 3.10) - -project(SEALExamples VERSION 3.3.2 LANGUAGES CXX) - -# Executable will be in ../bin -set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_SOURCE_DIR}/../bin) - -add_executable(sealexamples examples.cpp) -target_sources(sealexamples - PRIVATE - 1_bfv_basics.cpp - 2_encoders.cpp - 3_levels.cpp - 4_ckks_basics.cpp - 5_rotation.cpp - 6_performance.cpp -) - -# Import Microsoft SEAL -find_package(SEAL 3.3.2 EXACT REQUIRED) - -# Link Microsoft SEAL -target_link_libraries(sealexamples SEAL::seal) diff --git a/SEAL/native/examples/SEALExamples.vcxproj b/SEAL/native/examples/SEALExamples.vcxproj deleted file mode 100644 index b676cb8..0000000 --- a/SEAL/native/examples/SEALExamples.vcxproj +++ /dev/null @@ -1,123 +0,0 @@ - - - - - Debug - x64 - - - Release - x64 - - - - {2B57D847-26DC-45FF-B9AF-EE33910B5093} - Win32Proj - SEALExamples - 10.0.16299.0 - - - - Application - true - v141 - Unicode - - - Application - false - v141 - true - Unicode - - - - - - - - - - - - - - - true - $(ProjectDir)..\bin\$(Platform)\$(Configuration)\ - $(ProjectDir)obj\$(Platform)\$(Configuration)\ - sealexamples - - - false - $(ProjectDir)..\bin\$(Platform)\$(Configuration)\ - $(ProjectDir)obj\$(Platform)\$(Configuration)\ - sealexamples - - - - Level3 - NotUsing - Disabled - - - true - $(SolutionDir)native\src - stdcpp17 - /Zc:__cplusplus %(AdditionalOptions) - Guard - ProgramDatabase - true - - - Console - true - $(ProjectDir)..\lib\$(Platform)\$(Configuration);%(AdditionalLibraryDirectories) - seal.lib;%(AdditionalDependencies) - - - - - Level3 - NotUsing - MaxSpeed - true - true - - - true - $(SolutionDir)native\src - stdcpp17 - /Zc:__cplusplus %(AdditionalOptions) - Guard - true - - - Console - true - true - true - $(ProjectDir)..\lib\$(Platform)\$(Configuration);%(AdditionalLibraryDirectories) - seal.lib;%(AdditionalDependencies) - true - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/SEAL/native/examples/SEALExamples.vcxproj.filters b/SEAL/native/examples/SEALExamples.vcxproj.filters deleted file mode 100644 index 78180cd..0000000 --- a/SEAL/native/examples/SEALExamples.vcxproj.filters +++ /dev/null @@ -1,49 +0,0 @@ - - - - - {4FC737F1-C7A5-4376-A066-2A32D752A2FF} - cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx - - - {93995380-89BD-4b04-88EB-625FBE52EBFB} - h;hh;hpp;hxx;hm;inl;inc;xsd - - - {abd2e216-316f-4dad-a2a4-a72ffccfd92b} - - - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - - - Other - - - - - Header Files - - - \ No newline at end of file diff --git a/SEAL/native/examples/examples.cpp b/SEAL/native/examples/examples.cpp deleted file mode 100644 index 4e9ead7..0000000 --- a/SEAL/native/examples/examples.cpp +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "examples.h" - -using namespace std; -using namespace seal; - -int main() -{ -#ifdef SEAL_VERSION - cout << "Microsoft SEAL version: " << SEAL_VERSION << endl; -#endif - while (true) - { - cout << "+---------------------------------------------------------+" << endl; - cout << "| The following examples should be executed while reading |" << endl; - cout << "| comments in associated files in native/examples/. |" << endl; - cout << "+---------------------------------------------------------+" << endl; - cout << "| Examples | Source Files |" << endl; - cout << "+----------------------------+----------------------------+" << endl; - cout << "| 1. BFV Basics | 1_bfv_basics.cpp |" << endl; - cout << "| 2. Encoders | 2_encoders.cpp |" << endl; - cout << "| 3. Levels | 3_levels.cpp |" << endl; - cout << "| 4. CKKS Basics | 4_ckks_basics.cpp |" << endl; - cout << "| 5. Rotation | 5_rotation.cpp |" << endl; - cout << "| 6. Performance Test | 6_performance.cpp |" << endl; - cout << "+----------------------------+----------------------------+" << endl; - - /* - Print how much memory we have allocated from the current memory pool. - By default the memory pool will be a static global pool and the - MemoryManager class can be used to change it. Most users should have - little or no reason to touch the memory allocation system. - */ - size_t megabytes = MemoryManager::GetPool().alloc_byte_count() >> 20; - cout << "[" << setw(7) << right << megabytes << " MB] " - << "Total allocation from the memory pool" << endl; - - int selection = 0; - bool invalid = true; - do - { - cout << endl << "> Run example (1 ~ 6) or exit (0): "; - if (!(cin >> selection)) - { - invalid = false; - } - else if (selection < 0 || selection > 6) - { - invalid = false; - } - else - { - invalid = true; - } - if (!invalid) - { - cout << " [Beep~~] Invalid option: type 0 ~ 6" << endl; - cin.clear(); - cin.ignore(numeric_limits::max(), '\n'); - } - } while (!invalid); - - switch (selection) - { - case 1: - example_bfv_basics(); - break; - - case 2: - example_encoders(); - break; - - case 3: - example_levels(); - break; - - case 4: - example_ckks_basics(); - break; - - case 5: - example_rotation(); - break; - - case 6: - example_performance_test(); - break; - - case 0: - return 0; - } - } - - return 0; -} \ No newline at end of file diff --git a/SEAL/native/examples/examples.h b/SEAL/native/examples/examples.h deleted file mode 100644 index e7aeb89..0000000 --- a/SEAL/native/examples/examples.h +++ /dev/null @@ -1,232 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "seal/seal.h" - -void example_bfv_basics(); - -void example_encoders(); - -void example_levels(); - -void example_ckks_basics(); - -void example_rotation(); - -void example_performance_test(); - -/* -Helper function: Prints the name of the example in a fancy banner. -*/ -inline void print_example_banner(std::string title) -{ - if (!title.empty()) - { - std::size_t title_length = title.length(); - std::size_t banner_length = title_length + 2 * 10; - std::string banner_top = "+" + std::string(banner_length - 2, '-') + "+"; - std::string banner_middle = - "|" + std::string(9, ' ') + title + std::string(9, ' ') + "|"; - - std::cout << std::endl - << banner_top << std::endl - << banner_middle << std::endl - << banner_top << std::endl; - } -} - -/* -Helper function: Prints the parameters in a SEALContext. -*/ -inline void print_parameters(std::shared_ptr context) -{ - // Verify parameters - if (!context) - { - throw std::invalid_argument("context is not set"); - } - auto &context_data = *context->key_context_data(); - - /* - Which scheme are we using? - */ - std::string scheme_name; - switch (context_data.parms().scheme()) - { - case seal::scheme_type::BFV: - scheme_name = "BFV"; - break; - case seal::scheme_type::CKKS: - scheme_name = "CKKS"; - break; - default: - throw std::invalid_argument("unsupported scheme"); - } - std::cout << "/" << std::endl; - std::cout << "| Encryption parameters :" << std::endl; - std::cout << "| scheme: " << scheme_name << std::endl; - std::cout << "| poly_modulus_degree: " << - context_data.parms().poly_modulus_degree() << std::endl; - - /* - Print the size of the true (product) coefficient modulus. - */ - std::cout << "| coeff_modulus size: "; - std::cout << context_data.total_coeff_modulus_bit_count() << " ("; - auto coeff_modulus = context_data.parms().coeff_modulus(); - std::size_t coeff_mod_count = coeff_modulus.size(); - for (std::size_t i = 0; i < coeff_mod_count - 1; i++) - { - std::cout << coeff_modulus[i].bit_count() << " + "; - } - std::cout << coeff_modulus.back().bit_count(); - std::cout << ") bits" << std::endl; - - /* - For the BFV scheme print the plain_modulus parameter. - */ - if (context_data.parms().scheme() == seal::scheme_type::BFV) - { - std::cout << "| plain_modulus: " << context_data. - parms().plain_modulus().value() << std::endl; - } - - std::cout << "\\" << std::endl; -} - -/* -Helper function: Prints the `parms_id' to std::ostream. -*/ -inline std::ostream &operator <<(std::ostream &stream, seal::parms_id_type parms_id) -{ - /* - Save the formatting information for std::cout. - */ - std::ios old_fmt(nullptr); - old_fmt.copyfmt(std::cout); - - stream << std::hex << std::setfill('0') - << std::setw(16) << parms_id[0] << " " - << std::setw(16) << parms_id[1] << " " - << std::setw(16) << parms_id[2] << " " - << std::setw(16) << parms_id[3] << " "; - - /* - Restore the old std::cout formatting. - */ - std::cout.copyfmt(old_fmt); - - return stream; -} - -/* -Helper function: Prints a vector of floating-point values. -*/ -template -inline void print_vector(std::vector vec, std::size_t print_size = 4, int prec = 3) -{ - /* - Save the formatting information for std::cout. - */ - std::ios old_fmt(nullptr); - old_fmt.copyfmt(std::cout); - - std::size_t slot_count = vec.size(); - - std::cout << std::fixed << std::setprecision(prec); - std::cout << std::endl; - if(slot_count <= 2 * print_size) - { - std::cout << " ["; - for (std::size_t i = 0; i < slot_count; i++) - { - std::cout << " " << vec[i] << ((i != slot_count - 1) ? "," : " ]\n"); - } - } - else - { - vec.resize(std::max(vec.size(), 2 * print_size)); - std::cout << " ["; - for (std::size_t i = 0; i < print_size; i++) - { - std::cout << " " << vec[i] << ","; - } - if(vec.size() > 2 * print_size) - { - std::cout << " ...,"; - } - for (std::size_t i = slot_count - print_size; i < slot_count; i++) - { - std::cout << " " << vec[i] << ((i != slot_count - 1) ? "," : " ]\n"); - } - } - std::cout << std::endl; - - /* - Restore the old std::cout formatting. - */ - std::cout.copyfmt(old_fmt); -} - - -/* -Helper function: Prints a matrix of values. -*/ -template -inline void print_matrix(std::vector matrix, std::size_t row_size) -{ - /* - We're not going to print every column of the matrix (there are 2048). Instead - print this many slots from beginning and end of the matrix. - */ - std::size_t print_size = 5; - - std::cout << std::endl; - std::cout << " ["; - for (std::size_t i = 0; i < print_size; i++) - { - std::cout << std::setw(3) << std::right << matrix[i] << ","; - } - std::cout << std::setw(3) << " ...,"; - for (std::size_t i = row_size - print_size; i < row_size; i++) - { - std::cout << std::setw(3) << matrix[i] - << ((i != row_size - 1) ? "," : " ]\n"); - } - std::cout << " ["; - for (std::size_t i = row_size; i < row_size + print_size; i++) - { - std::cout << std::setw(3) << matrix[i] << ","; - } - std::cout << std::setw(3) << " ...,"; - for (std::size_t i = 2 * row_size - print_size; i < 2 * row_size; i++) - { - std::cout << std::setw(3) << matrix[i] - << ((i != 2 * row_size - 1) ? "," : " ]\n"); - } - std::cout << std::endl; -}; - -/* -Helper function: Print line number. -*/ -inline void print_line(int line_number) -{ - std::cout << "Line " << std::setw(3) << line_number << " --> "; -} \ No newline at end of file diff --git a/SEAL/native/src/CMakeConfig.cmd b/SEAL/native/src/CMakeConfig.cmd deleted file mode 100644 index b3e1b64..0000000 --- a/SEAL/native/src/CMakeConfig.cmd +++ /dev/null @@ -1,68 +0,0 @@ -@echo off - -rem Copyright (c) Microsoft Corporation. All rights reserved. -rem Licensed under the MIT license. - -setlocal - -rem The purpose of this script is to have CMake generate config.h for use by Microsoft SEAL. -rem We assume that CMake was installed with Visual Studio, which should be the default -rem when the user installs the "Desktop Development with C++" workload. - -set VSVERSION=%~1 -set PROJECTCONFIGURATION=%~2 -set VSDEVENVDIR=%~3 -set INCLUDEPATH=%~4 - -echo Configuring Microsoft SEAL through CMake - -if not exist "%VSDEVENVDIR%" ( - rem We may be running in the CI server. Try a standard VS path. - echo Did not find VS at provided location: "%VSDEVENVDIR%". - echo Trying standard location. - set VSDEVENVDIR="C:\Program Files (x86)\Microsoft Visual Studio\2017\Enterprise\Common7\IDE" -) - -set VSDEVENVDIR=%VSDEVENVDIR:"=% -set CMAKEPATH=%VSDEVENVDIR%\CommonExtensions\Microsoft\CMake\CMake\bin\cmake.exe - -if not exist "%CMAKEPATH%" ( - echo ****************************************************************************************************************** - echo ** Did not find CMake at "%CMAKEPATH%" - echo ** Please make sure "Visual C++ Tools for CMake" are enabled in the "Desktop development with C++" workload. - echo ****************************************************************************************************************** - exit 1 -) - -echo Found CMake at %CMAKEPATH% - -rem Identify Visual Studio version and set CMake generator accordingly. -set CMAKEGEN="" -if "%VSVERSION%"=="15.0" ( - set CMAKEGEN="Visual Studio 15 2017" -) else if "%VSVERSION%"=="16.0" ( - set CMAKEGEN="Visual Studio 16 2019" -) else ( - echo *************************************************** - echo ** Unsupported Visual Studio version "%VSVERSION%" - echo *************************************************** - exit 1 -) - -set CONFIGDIR=".config\%VSVERSION%" -cd %~dp0 -if not exist %CONFIGDIR% ( - mkdir %CONFIGDIR% -) -cd %CONFIGDIR% -echo Running CMake configuration in %cd% - -rem Call CMake. -"%CMAKEPATH%" ..\.. ^ - -G %CMAKEGEN% ^ - -A x64 ^ - -DALLOW_COMMAND_LINE_BUILD=1 ^ - -DCMAKE_BUILD_TYPE="%PROJECTCONFIGURATION%" ^ - -DSEAL_LIB_BUILD_TYPE="Static_PIC" ^ - -DMSGSL_INCLUDE_DIR="%INCLUDEPATH%" ^ - --no-warn-unused-cli diff --git a/SEAL/native/src/CMakeLists.txt b/SEAL/native/src/CMakeLists.txt deleted file mode 100644 index 7052163..0000000 --- a/SEAL/native/src/CMakeLists.txt +++ /dev/null @@ -1,421 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT license. - -cmake_minimum_required(VERSION 3.10) - -project(SEAL VERSION 3.3.2 LANGUAGES CXX C) - -if(DEFINED MSVC) - if(DEFINED ALLOW_COMMAND_LINE_BUILD) - message(STATUS "Configuring for Visual Studio") - else() - message(FATAL_ERROR "Please build Microsoft SEAL using the attached Visual Studio solution/project files") - endif() -endif() - -# Build in Release mode by default; otherwise use selected option -set(SEAL_DEFAULT_BUILD_TYPE "Release") -if(NOT CMAKE_BUILD_TYPE) - set(CMAKE_BUILD_TYPE ${SEAL_DEFAULT_BUILD_TYPE} CACHE - STRING "Build type" FORCE) - set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS - "Release" "Debug" "MinSizeRel" "RelWithDebInfo") -endif() -message(STATUS "Build type (CMAKE_BUILD_TYPE): ${CMAKE_BUILD_TYPE}") - -# In Debug mode enable also SEAL_DEBUG by default -if(CMAKE_BUILD_TYPE STREQUAL "Debug") - set(SEAL_DEBUG_DEFAULT ON) -else() - set(SEAL_DEBUG_DEFAULT OFF) -endif() - -# Required files and directories -set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${SEAL_SOURCE_DIR}/../lib) -set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${SEAL_SOURCE_DIR}/../lib) -set(CMAKE_LIBRARY_RUNTIME_DIRECTORY ${SEAL_SOURCE_DIR}/../bin) -set(SEAL_INCLUDES_INSTALL_DIR include) -set(SEAL_CONFIG_IN_FILENAME ${SEAL_SOURCE_DIR}/cmake/SEALConfig.cmake.in) -set(SEAL_CONFIG_FILENAME ${SEAL_SOURCE_DIR}/cmake/SEALConfig.cmake) -set(SEAL_TARGETS_FILENAME ${SEAL_SOURCE_DIR}/cmake/SEALTargets.cmake) -set(SEAL_CONFIG_VERSION_FILENAME ${SEAL_SOURCE_DIR}/cmake/SEALConfigVersion.cmake) -set(SEAL_CONFIG_INSTALL_DIR lib/cmake/SEAL) - -# For extra modules we might have -list(APPEND CMAKE_MODULE_PATH ${SEAL_SOURCE_DIR}/cmake) - -include(CMakePushCheckState) -include(CMakeDependentOption) -include(CheckIncludeFiles) -include(CheckCXXSourceRuns) -include(CheckTypeSize) - -# For easier adding of CXX compiler flags -include(CheckCXXCompilerFlag) - -function(enable_cxx_compiler_flag_if_supported flag) - string(FIND "${CMAKE_CXX_FLAGS}" "${flag}" flag_already_set) - if(flag_already_set EQUAL -1) - message(STATUS "Adding CXX compiler flag: ${flag} ...") - check_cxx_compiler_flag("${flag}" flag_supported) - if(flag_supported) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${flag}" PARENT_SCOPE) - endif() - unset(flag_supported CACHE) - endif() -endfunction() - -# enable_cxx_compiler_flag_if_supported("-Wall") -# enable_cxx_compiler_flag_if_supported("-Wextra") -# enable_cxx_compiler_flag_if_supported("-Wconversion") -# enable_cxx_compiler_flag_if_supported("-Wshadow") -# enable_cxx_compiler_flag_if_supported("-pedantic") - -# Are we using SEAL_DEBUG? -set(SEAL_DEBUG ${SEAL_DEBUG_DEFAULT}) -message(STATUS "Microsoft SEAL debug mode: ${SEAL_DEBUG}") - -# Should we use C++14 or C++17? -set(SEAL_USE_CXX17_OPTION_STR "Use C++17") -option(SEAL_USE_CXX17 ${SEAL_USE_CXX17_OPTION_STR} ON) - -# Conditionally enable features from C++17 -set(SEAL_USE_STD_BYTE OFF) -set(SEAL_USE_SHARED_MUTEX OFF) -set(SEAL_USE_IF_CONSTEXPR OFF) -set(SEAL_USE_MAYBE_UNUSED OFF) -set(SEAL_USE_NODISCARD OFF) -set(SEAL_LANG_FLAG "-std=c++14") -if(SEAL_USE_CXX17) - set(SEAL_USE_STD_BYTE ON) - set(SEAL_USE_SHARED_MUTEX ON) - set(SEAL_USE_IF_CONSTEXPR ON) - set(SEAL_USE_MAYBE_UNUSED ON) - set(SEAL_USE_NODISCARD ON) - set(SEAL_LANG_FLAG "-std=c++17") -endif() - -# Should we build a shared library? -set(SEAL_DEFAULT_LIB_BUILD_TYPE "Static_PIC") -if(NOT SEAL_LIB_BUILD_TYPE) - set(SEAL_LIB_BUILD_TYPE ${SEAL_DEFAULT_LIB_BUILD_TYPE} CACHE - STRING "Library build type" FORCE) - set_property(CACHE SEAL_LIB_BUILD_TYPE PROPERTY STRINGS - "Static" "Static_PIC" "Shared") -endif() -message(STATUS "Library build type (SEAL_LIB_BUILD_TYPE): ${SEAL_LIB_BUILD_TYPE}") - -# Optionally enable CMAKE_POSITION_INDEPENDENT_CODE -if(SEAL_LIB_BUILD_TYPE STREQUAL "Static") - set(CMAKE_POSITION_INDEPENDENT_CODE OFF) -else() - set(CMAKE_POSITION_INDEPENDENT_CODE ON) -endif() - -# Throw on multiply_plain by a zero plaintext -set(SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT_STR "Throw an exception when a member of Evaluator outputs a transparent ciphertext") -option(SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT ${SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT_STR} ON) -mark_as_advanced(FORCE SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT) - -# Use intrinsics if available -set(SEAL_USE_INTRIN_OPTION_STR "Use intrinsics") -option(SEAL_USE_INTRIN ${SEAL_USE_INTRIN_OPTION_STR} ON) - -# Use Microsoft GSL if available -set(SEAL_USE_MSGSL_OPTION_STR "Use Microsoft GSL") -option(SEAL_USE_MSGSL ${SEAL_USE_MSGSL_OPTION_STR} ON) - -# Check for intrin.h or x64intrin.h -if(SEAL_USE_INTRIN) - if(DEFINED MSVC) - set(SEAL_INTRIN_HEADER "intrin.h") - else() - set(SEAL_INTRIN_HEADER "x86intrin.h") - endif() - - check_include_file_cxx(${SEAL_INTRIN_HEADER} HAVE_INTRIN_HEADER) - - if(NOT HAVE_INTRIN_HEADER) - set(SEAL_USE_INTRIN OFF CACHE BOOL ${SEAL_USE_INTRIN_OPTION_STR} FORCE) - endif() -endif() - -# Specific intrinsics depending on SEAL_USE_INTRIN -if(DEFINED MSVC) - set(SEAL_USE__UMUL128_OPTION_STR "Use _umul128") - cmake_dependent_option(SEAL_USE__UMUL128 SEAL_USE__UMUL128_OPTION_STR ON "SEAL_USE_INTRIN" OFF) - - set(SEAL_USE__BITSCANREVERSE64_OPTION_STR "Use _BitScanReverse64") - cmake_dependent_option(SEAL_USE__BITSCANREVERSE64 SEAL_USE__BITSCANREVERSE64_OPTION_STR ON "SEAL_USE_INTRIN" OFF) -else() - set(SEAL_USE___INT128_OPTION_STR "Use __int128") - cmake_dependent_option(SEAL_USE___INT128 SEAL_USE___INT128_OPTION_STR ON "SEAL_USE_INTRIN" OFF) - - set(SEAL_USE___BUILTIN_CLZLL_OPTION_STR "Use __builtin_clzll") - cmake_dependent_option(SEAL_USE___BUILTIN_CLZLL SEAL_USE___BUILTIN_CLZLL_OPTION_STR ON "SEAL_USE_INTRIN" OFF) -endif() - -set(SEAL_USE__ADDCARRY_U64_OPTION_STR "Use _addcarry_u64") -cmake_dependent_option(SEAL_USE__ADDCARRY_U64 SEAL_USE__ADDCARRY_U64_OPTION_STR ON "SEAL_USE_INTRIN" OFF) - -set(SEAL_USE__SUBBORROW_U64_OPTION_STR "Use _subborrow_u64") -cmake_dependent_option(SEAL_USE__SUBBORROW_U64 SEAL_USE__SUBBORROW_U64_OPTION_STR ON "SEAL_USE_INTRIN" OFF) - -set(SEAL_USE_AES_NI_PRNG_OPTION_STR "Use fast AES-NI PRNG") -cmake_dependent_option(SEAL_USE_AES_NI_PRNG SEAL_USE_AES_NI_PRNG_OPTION_STR ON "SEAL_USE_INTRIN" OFF) - -if(SEAL_USE_INTRIN) - cmake_push_check_state(RESET) - set(CMAKE_REQUIRED_QUIET TRUE) - if(NOT DEFINED MSVC) - set(CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS} -O0 ${SEAL_LANG_FLAG}") - endif() - - if(DEFINED MSVC) - # Check for presence of _umul128 - if(SEAL_USE__UMUL128) - check_cxx_source_runs(" - #include <${SEAL_INTRIN_HEADER}> - int main() { - unsigned long long a = 0, b = 0; - unsigned long long c; - volatile unsigned long long d; - d = _umul128(a, b, &c); - return 0; - }" - USE_UMUL128 - ) - if(NOT USE_UMUL128 EQUAL 1) - set(SEAL_USE__UMUL128 OFF CACHE BOOL ${SEAL_USE__UMUL128_OPTION_STR} FORCE) - endif() - endif() - - # Check for _BitScanReverse64 - if(SEAL_USE__BITSCANREVERSE64) - check_cxx_source_runs(" - #include <${SEAL_INTRIN_HEADER}> - int main() { - unsigned long a = 0, b = 0; - volatile unsigned char res = _BitScanReverse64(&a, b); - return 0; - }" - USE_BITSCANREVERSE64 - ) - if(NOT USE_BITSCANREVERSE64 EQUAL 1) - set(SEAL_USE__BITSCANREVERSE64 OFF CACHE BOOL ${SEAL_USE__BITSCANREVERSE64_OPTION_STR} FORCE) - endif() - endif() - else() - # Check for presence of ___int128 - if(SEAL_USE___INT128) - check_type_size("__int128" INT128 LANGUAGE CXX) - if(NOT INT128 EQUAL 16) - set(SEAL_USE___INT128 OFF CACHE BOOL ${SEAL_USE___INT128_OPTION_STR} FORCE) - endif() - endif() - - # Check for __builtin_clzll - if(SEAL_USE___BUILTIN_CLZLL) - check_cxx_source_runs(" - int main() { - volatile auto res = __builtin_clzll(0); - return 0; - }" - USE_BUILTIN_CLZLL - ) - if(NOT USE_BUILTIN_CLZLL EQUAL 1) - set(SEAL_USE___BUILTIN_CLZLL OFF CACHE BOOL ${SEAL_USE___BUILTIN_CLZLL_OPTION_STR} FORCE) - endif() - endif() - endif() - - # Check for _addcarry_u64 - if(SEAL_USE__ADDCARRY_U64) - check_cxx_source_runs(" - #include <${SEAL_INTRIN_HEADER}> - int main() { - unsigned long long a; - volatile auto res = _addcarry_u64(0,0,0,&a); - return 0; - }" - USE_ADDCARRY_U64 - ) - if(NOT USE_ADDCARRY_U64 EQUAL 1) - set(SEAL_USE__ADDCARRY_U64 OFF CACHE BOOL ${SEAL_USE__ADDCARRY_U64_OPTION_STR} FORCE) - endif() - endif() - - # Check for _subborrow_u64 - if(SEAL_USE__SUBBORROW_U64) - check_cxx_source_runs(" - #include <${SEAL_INTRIN_HEADER}> - int main() { - unsigned long long a; - volatile auto res = _subborrow_u64(0,0,0,&a); - return 0; - }" - USE_SUBBORROW_U64 - ) - if(NOT USE_SUBBORROW_U64 EQUAL 1) - set(SEAL_USE__SUBBORROW_U64 OFF CACHE BOOL ${SEAL_USE__SUBBORROW_U64_OPTION_STR} FORCE) - endif() - endif() - - check_include_file_cxx("wmmintrin.h" HAVE_WMMINTRIN_HEADER) - if(NOT HAVE_WMMINTRIN_HEADER) - set(SEAL_USE_AES_NI_PRNG OFF CACHE BOOL ${SEAL_USE_AES_NI_PRNG_OPTION_STR} FORCE) - endif() - - # Check that AES-NI runs - if(SEAL_USE_AES_NI_PRNG) - if(NOT DEFINED MSVC) - set(CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS} -maes") - endif() - check_cxx_source_runs(" - #include - int main() { - __m128i a{ 0 }; - volatile auto b = _mm_aeskeygenassist_si128(a, 0); - return 0; - }" - USE_AES_KEYGEN_ASSIST - ) - if(NOT USE_AES_KEYGEN_ASSIST EQUAL 1) - set(SEAL_USE_AES_NI_PRNG OFF CACHE BOOL ${SEAL_USE_AES_NI_PRNG_OPTION_STR} FORCE) - endif() - endif() - - cmake_pop_check_state() -endif() - -# Try to find MSGSL if requested -if(SEAL_USE_MSGSL) - find_package(msgsl MODULE) - if(NOT msgsl_FOUND) - set(SEAL_USE_MSGSL OFF CACHE BOOL ${SEAL_USE_MSGSL_OPTION_STR} FORCE) - endif() -endif() - -# Specific options depending on SEAL_USE_MSGSL -set(SEAL_USE_MSGSL_SPAN_OPTION_STR "Use gsl::span") -cmake_dependent_option(SEAL_USE_MSGSL_SPAN ${SEAL_USE_MSGSL_SPAN_OPTION_STR} ON "SEAL_USE_MSGSL" OFF) - -set(SEAL_USE_MSGSL_MULTISPAN_OPTION_STR "Use gsl::multi_span") -cmake_dependent_option(SEAL_USE_MSGSL_MULTISPAN ${SEAL_USE_MSGSL_MULTISPAN_OPTION_STR} ON "SEAL_USE_MSGSL" OFF) - -if(SEAL_USE_MSGSL) - # Now check for individual classes - cmake_push_check_state(RESET) - set(CMAKE_REQUIRED_INCLUDES ${MSGSL_INCLUDE_DIR}) - set(CMAKE_EXTRA_INCLUDE_FILES gsl/gsl) - set(CMAKE_REQUIRED_FLAGS "${CMAKE_REQUIRED_FLAGS} -O0 ${SEAL_LANG_FLAG}") - set(CMAKE_REQUIRED_QUIET TRUE) - - # Detect gsl::span - if(SEAL_USE_MSGSL_SPAN) - check_type_size("gsl::span" MSGSL_SPAN LANGUAGE CXX) - if(NOT MSGSL_SPAN GREATER 0) - set(SEAL_USE_MSGSL_SPAN OFF CACHE BOOL ${SEAL_USE_MSGSL_SPAN_OPTION_STR} FORCE) - endif() - endif() - - # Detect gsl::multi_span - if(SEAL_USE_MSGSL_MULTISPAN) - check_type_size("gsl::multi_span" MSGSL_MULTISPAN LANGUAGE CXX) - if(NOT MSGSL_MULTISPAN GREATER 0) - set(SEAL_USE_MSGSL_MULTISPAN OFF CACHE BOOL ${SEAL_USE_MSGSL_MULTISPAN_OPTION_STR} FORCE) - endif() - endif() - - cmake_pop_check_state() -endif() - -# Create library but add no source files yet -if(SEAL_LIB_BUILD_TYPE STREQUAL "Shared") - add_library(seal SHARED "") - - # Set SOVERSION for shared library - set_target_properties(seal PROPERTIES - SOVERSION ${SEAL_VERSION_MAJOR}.${SEAL_VERSION_MINOR}) -else() - add_library(seal STATIC "") -endif() - -# Set VERSION for all library build types -set_target_properties(seal PROPERTIES VERSION ${SEAL_VERSION}) - -# Add source files to library and header files to install -add_subdirectory(seal) - -# Add local include directories for build -target_include_directories(seal PUBLIC $) - -# We require at least C++14 -if(SEAL_USE_CXX17) - target_compile_features(seal PUBLIC cxx_std_17) -else() - target_compile_features(seal PUBLIC cxx_std_14) -endif() - -# Add -maes flag if needed -if(SEAL_USE_AES_NI_PRNG) - target_compile_options(seal PUBLIC "-maes") -endif() - -# Require thread library -set(CMAKE_THREAD_PREFER_PTHREAD TRUE) -set(THREADS_PREFER_PTHREAD_FLAG TRUE) -find_package(Threads REQUIRED) - -# Link Threads with seal -target_link_libraries(seal PUBLIC Threads::Threads) - -# Create msgsl interface target -if(SEAL_USE_MSGSL) - # Create interface target - add_library(msgsl INTERFACE) - set_target_properties(msgsl PROPERTIES - INTERFACE_INCLUDE_DIRECTORIES ${MSGSL_INCLUDE_DIR}) - - # Associate msgsl with export seal_export - install(TARGETS msgsl EXPORT seal_export) - - # Link with seal - target_link_libraries(seal PUBLIC msgsl) -endif() - -# Associate seal to export seal_export -install(TARGETS seal EXPORT seal_export - ARCHIVE DESTINATION lib - LIBRARY DESTINATION lib - RUNTIME DESTINATION bin - INCLUDES DESTINATION ${SEAL_INCLUDES_INSTALL_DIR}) - -# Export the targets -export(EXPORT seal_export - FILE ${SEAL_TARGETS_FILENAME} - NAMESPACE SEAL::) - -# Create the CMake config file -configure_file(${SEAL_CONFIG_IN_FILENAME} ${SEAL_CONFIG_FILENAME} @ONLY) - -# Install the export -install( - EXPORT seal_export - FILE SEALTargets.cmake - NAMESPACE SEAL:: - DESTINATION ${SEAL_CONFIG_INSTALL_DIR}) - -# Version file; we require exact version match for downstream -include(CMakePackageConfigHelpers) -write_basic_package_version_file( - ${SEAL_CONFIG_VERSION_FILENAME} - VERSION ${SEAL_VERSION} - COMPATIBILITY ExactVersion) - -# Install other files -install( - FILES - ${SEAL_CONFIG_FILENAME} - ${SEAL_CONFIG_VERSION_FILENAME} - DESTINATION ${SEAL_CONFIG_INSTALL_DIR}) diff --git a/SEAL/native/src/SEAL.vcxproj b/SEAL/native/src/SEAL.vcxproj deleted file mode 100644 index 7a3295f..0000000 --- a/SEAL/native/src/SEAL.vcxproj +++ /dev/null @@ -1,210 +0,0 @@ - - - - - Debug - x64 - - - Release - x64 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - {7EA96C25-FC0D-485A-BB71-32B6DA55652A} - SEAL - 10.0.17763.0 - - - - StaticLibrary - true - v141 - Unicode - false - - - StaticLibrary - false - v141 - true - Unicode - false - - - - - - - - - - - - - - - $(ProjectDir)..\lib\$(Platform)\$(Configuration)\ - $(ProjectDir)obj\$(Platform)\$(Configuration)\ - .lib - seal - - - $(ProjectDir)..\lib\$(Platform)\$(Configuration)\ - $(ProjectDir)obj\$(Platform)\$(Configuration)\ - .lib - seal - - - - Level3 - Disabled - true - $(ProjectDir) - true - Neither - stdcpp17 - %(PreprocessorDefinitions); _ENABLE_EXTENDED_ALIGNED_STORAGE - /Zc:__cplusplus %(AdditionalOptions) - true - - - true - - - "$(ProjectDir)CMakeConfig.cmd" "$(VisualStudioVersion)" "$(Configuration)" "$(DevEnvDir)" "$(IncludePath)" - - - Configure Microsoft SEAL through CMake - - - - - Level3 - MaxSpeed - true - true - true - $(ProjectDir) - Speed - Default - stdcpp17 - %(PreprocessorDefinitions); _ENABLE_EXTENDED_ALIGNED_STORAGE - /Zc:__cplusplus %(AdditionalOptions) - true - - - true - true - true - - - "$(ProjectDir)CMakeConfig.cmd" "$(VisualStudioVersion)" "$(Configuration)" "$(DevEnvDir)" "$(IncludePath)" - - - Configure Microsoft SEAL through CMake - - - - - - \ No newline at end of file diff --git a/SEAL/native/src/SEAL.vcxproj.filters b/SEAL/native/src/SEAL.vcxproj.filters deleted file mode 100644 index 47e9427..0000000 --- a/SEAL/native/src/SEAL.vcxproj.filters +++ /dev/null @@ -1,319 +0,0 @@ - - - - - {4FC737F1-C7A5-4376-A066-2A32D752A2FF} - cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx - - - {93995380-89BD-4b04-88EB-625FBE52EBFB} - h;hh;hpp;hxx;hm;inl;inc;xsd - - - {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} - rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms - - - {a119ce23-aae9-4b06-be2c-1c8aada4ab20} - - - {8740bd83-253c-49f3-8f9a-3b9c526f67c2} - - - {8585bc5e-eaa9-481a-a6ee-c38be1319a32} - - - {aaf838b1-cab2-4ccc-a016-a81c7edf506e} - - - {31fb1149-1a6f-438b-a86a-744384986d21} - - - {497d5f96-98a3-44e9-8b38-a2ea4bbea366} - - - {87bf64a4-84f1-44c7-af13-5fce86d49abc} - - - - - Header Files - - - Header Files - - - Header Files - - - Header Files - - - Header Files - - - Header Files - - - Header Files - - - Header Files - - - Header Files - - - Header Files - - - Header Files - - - Header Files - - - Header Files - - - Header Files - - - Header Files - - - Header Files - - - Header Files - - - Header Files - - - Header Files - - - Header Files - - - Header Files - - - Header Files - - - Header Files - - - Header Files - - - Header Files\util - - - Header Files\util - - - Header Files\util - - - Header Files\util - - - Header Files\util - - - Header Files\util - - - Header Files\util - - - Header Files\util - - - Header Files\util - - - Header Files\util - - - Header Files\util - - - Header Files\util - - - Header Files\util - - - Header Files\util - - - Header Files\util - - - Header Files\util - - - Header Files\util - - - Header Files\util - - - Header Files\util - - - Header Files\util - - - Header Files\util - - - Header Files\util - - - Header Files\util - - - Header Files\util - - - Header Files\util - - - Header Files - - - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files - - - - - Other - - - Other\seal - - - Other\seal\util - - - - - Other\cmake - - - Other\seal\util - - - Other\cmake - - - Other - - - \ No newline at end of file diff --git a/SEAL/native/src/cmake/Findmsgsl.cmake b/SEAL/native/src/cmake/Findmsgsl.cmake deleted file mode 100644 index 65f4153..0000000 --- a/SEAL/native/src/cmake/Findmsgsl.cmake +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT license. - -# Simple attempt to locate Microsoft GSL -set(CURRENT_MSGSL_INCLUDE_DIR ${MSGSL_INCLUDE_DIR}) -unset(MSGSL_INCLUDE_DIR CACHE) -find_path(MSGSL_INCLUDE_DIR - NAMES gsl/gsl gsl/span gsl/multi_span - HINTS ${CMAKE_INCLUDE_PATH} ${CURRENT_MSGSL_INCLUDE_DIR}) - -# Determine whether found based on MSGSL_INCLUDE_DIR -find_package(PackageHandleStandardArgs) -find_package_handle_standard_args(msgsl - REQUIRED_VARS MSGSL_INCLUDE_DIR) diff --git a/SEAL/native/src/cmake/SEALConfig.cmake.in b/SEAL/native/src/cmake/SEALConfig.cmake.in deleted file mode 100644 index d9bb25f..0000000 --- a/SEAL/native/src/cmake/SEALConfig.cmake.in +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT license. - -# Exports target SEAL::seal -# -# Creates variables: -# SEAL_BUILD_TYPE : The build configuration used -# SEAL_DEBUG : Set to non-zero value if library is compiled with extra debugging code (very slow!) -# SEAL_LIB_BUILD_TYPE : Set to either "Static", "Static_PIC", or "Shared" depending on library build type -# SEAL_USE_CXX17 : Set to non-zero value if library is compiled as C++17 instead of C++14 -# SEAL_ENFORCE_HE_STD_SECURITY : Set to non-zero value if library is compiled to enforce at least -# a 128-bit security level based on HomomorphicEncryption.org security estimates -# SEAL_USE_MSGSL : Set to non-zero value if library is compiled with Microsoft GSL support -# MSGSL_INCLUDE_DIR : Holds the path to Microsoft GSL if library is compiled with Microsoft GSL support - -include(CMakeFindDependencyMacro) - -set(SEAL_BUILD_TYPE @CMAKE_BUILD_TYPE@) -set(SEAL_DEBUG @SEAL_DEBUG@) -set(SEAL_LIB_BUILD_TYPE @SEAL_LIB_BUILD_TYPE@) -set(SEAL_USE_CXX17 @SEAL_USE_CXX17@) -set(SEAL_ENFORCE_HE_STD_SECURITY @SEAL_ENFORCE_HE_STD_SECURITY@) -set(SEAL_USE_MSGSL @SEAL_USE_MSGSL@) -if(SEAL_USE_MSGSL) - set(MSGSL_INCLUDE_DIR @MSGSL_INCLUDE_DIR@) -endif() - -set(CMAKE_THREAD_PREFER_PTHREAD TRUE) -set(THREADS_PREFER_PTHREAD_FLAG TRUE) -find_dependency(Threads REQUIRED) - -include(${CMAKE_CURRENT_LIST_DIR}/SEALTargets.cmake) - -if(NOT SEAL_FIND_QUIETLY) - message(STATUS "Microsoft SEAL -> Version ${SEAL_VERSION} detected") - if(SEAL_DEBUG) - message(STATUS "Performance warning: Microsoft SEAL compiled in debug mode") - endif() -message(STATUS "Microsoft SEAL -> Library build type: ${SEAL_LIB_BUILD_TYPE}") -endif() diff --git a/SEAL/native/src/seal/CMakeLists.txt b/SEAL/native/src/seal/CMakeLists.txt deleted file mode 100644 index 0a7d14e..0000000 --- a/SEAL/native/src/seal/CMakeLists.txt +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT license. - -target_sources(seal - PRIVATE - ${CMAKE_CURRENT_LIST_DIR}/batchencoder.cpp - ${CMAKE_CURRENT_LIST_DIR}/biguint.cpp - ${CMAKE_CURRENT_LIST_DIR}/ciphertext.cpp - ${CMAKE_CURRENT_LIST_DIR}/ckks.cpp - ${CMAKE_CURRENT_LIST_DIR}/context.cpp - ${CMAKE_CURRENT_LIST_DIR}/decryptor.cpp - ${CMAKE_CURRENT_LIST_DIR}/intencoder.cpp - ${CMAKE_CURRENT_LIST_DIR}/encryptionparams.cpp - ${CMAKE_CURRENT_LIST_DIR}/encryptor.cpp - ${CMAKE_CURRENT_LIST_DIR}/evaluator.cpp - ${CMAKE_CURRENT_LIST_DIR}/keygenerator.cpp - ${CMAKE_CURRENT_LIST_DIR}/kswitchkeys.cpp - ${CMAKE_CURRENT_LIST_DIR}/memorymanager.cpp - ${CMAKE_CURRENT_LIST_DIR}/modulus.cpp - ${CMAKE_CURRENT_LIST_DIR}/plaintext.cpp - ${CMAKE_CURRENT_LIST_DIR}/randomgen.cpp - ${CMAKE_CURRENT_LIST_DIR}/smallmodulus.cpp - ${CMAKE_CURRENT_LIST_DIR}/valcheck.cpp -) - -install( - FILES - ${CMAKE_CURRENT_LIST_DIR}/batchencoder.h - ${CMAKE_CURRENT_LIST_DIR}/biguint.h - ${CMAKE_CURRENT_LIST_DIR}/ciphertext.h - ${CMAKE_CURRENT_LIST_DIR}/ckks.h - ${CMAKE_CURRENT_LIST_DIR}/modulus.h - ${CMAKE_CURRENT_LIST_DIR}/context.h - ${CMAKE_CURRENT_LIST_DIR}/decryptor.h - ${CMAKE_CURRENT_LIST_DIR}/intencoder.h - ${CMAKE_CURRENT_LIST_DIR}/encryptionparams.h - ${CMAKE_CURRENT_LIST_DIR}/encryptor.h - ${CMAKE_CURRENT_LIST_DIR}/evaluator.h - ${CMAKE_CURRENT_LIST_DIR}/galoiskeys.h - ${CMAKE_CURRENT_LIST_DIR}/intarray.h - ${CMAKE_CURRENT_LIST_DIR}/keygenerator.h - ${CMAKE_CURRENT_LIST_DIR}/kswitchkeys.h - ${CMAKE_CURRENT_LIST_DIR}/memorymanager.h - ${CMAKE_CURRENT_LIST_DIR}/modulus.h - ${CMAKE_CURRENT_LIST_DIR}/plaintext.h - ${CMAKE_CURRENT_LIST_DIR}/publickey.h - ${CMAKE_CURRENT_LIST_DIR}/randomgen.h - ${CMAKE_CURRENT_LIST_DIR}/randomtostd.h - ${CMAKE_CURRENT_LIST_DIR}/relinkeys.h - ${CMAKE_CURRENT_LIST_DIR}/seal.h - ${CMAKE_CURRENT_LIST_DIR}/secretkey.h - ${CMAKE_CURRENT_LIST_DIR}/smallmodulus.h - ${CMAKE_CURRENT_LIST_DIR}/valcheck.h - DESTINATION - ${SEAL_INCLUDES_INSTALL_DIR}/seal -) - -add_subdirectory(util) \ No newline at end of file diff --git a/SEAL/native/src/seal/batchencoder.cpp b/SEAL/native/src/seal/batchencoder.cpp deleted file mode 100644 index 1afb7b3..0000000 --- a/SEAL/native/src/seal/batchencoder.cpp +++ /dev/null @@ -1,564 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include -#include -#include -#include -#include "seal/batchencoder.h" -#include "seal/util/polycore.h" -#include "seal/valcheck.h" - -using namespace std; -using namespace seal::util; - -namespace seal -{ - BatchEncoder::BatchEncoder(std::shared_ptr context) : - context_(std::move(context)) - { - // Verify parameters - if (!context_) - { - throw invalid_argument("invalid context"); - } - if (!context_->parameters_set()) - { - throw invalid_argument("encryption parameters are not set correctly"); - } - - auto &context_data = *context_->first_context_data(); - if (context_data.parms().scheme() != scheme_type::BFV) - { - throw invalid_argument("unsupported scheme"); - } - if (!context_data.qualifiers().using_batching) - { - throw invalid_argument("encryption parameters are not valid for batching"); - } - - // Set the slot count - slots_ = context_data.parms().poly_modulus_degree(); - - // Reserve space for all of the primitive roots - roots_of_unity_ = allocate_uint(slots_, pool_); - - // Fill the vector of roots of unity with all distinct odd powers of generator. - // These are all the primitive (2*slots_)-th roots of unity in integers modulo - // parms.plain_modulus(). - populate_roots_of_unity_vector(context_data); - - // Populate matrix representation index map - populate_matrix_reps_index_map(); - } - - void BatchEncoder::populate_roots_of_unity_vector( - const SEALContext::ContextData &context_data) - { - uint64_t root = context_data.plain_ntt_tables()->get_root(); - auto &modulus = context_data.parms().plain_modulus(); - - uint64_t generator_sq = multiply_uint_uint_mod(root, root, modulus); - roots_of_unity_[0] = root; - - for (size_t i = 1; i < slots_; i++) - { - roots_of_unity_[i] = multiply_uint_uint_mod(roots_of_unity_[i - 1], - generator_sq, modulus); - } - } - - void BatchEncoder::populate_matrix_reps_index_map() - { - int logn = get_power_of_two(slots_); - matrix_reps_index_map_ = allocate_uint(slots_, pool_); - - // Copy from the matrix to the value vectors - size_t row_size = slots_ >> 1; - size_t m = slots_ << 1; - uint64_t gen = 3; - uint64_t pos = 1; - for (size_t i = 0; i < row_size; i++) - { - // Position in normal bit order - uint64_t index1 = (pos - 1) >> 1; - uint64_t index2 = (m - pos - 1) >> 1; - - // Set the bit-reversed locations - matrix_reps_index_map_[i] = util::reverse_bits(index1, logn); - matrix_reps_index_map_[row_size | i] = util::reverse_bits(index2, logn); - - // Next primitive root - pos *= gen; - pos &= (m - 1); - } - } - - void BatchEncoder::reverse_bits(uint64_t *input) - { -#ifdef SEAL_DEBUG - if (input == nullptr) - { - throw invalid_argument("input cannot be null"); - } -#endif - size_t coeff_count = context_->first_context_data()->parms().poly_modulus_degree(); - int logn = get_power_of_two(coeff_count); - for (size_t i = 0; i < coeff_count; i++) - { - uint64_t reversed_i = util::reverse_bits(i, logn); - if (i < reversed_i) - { - swap(input[i], input[reversed_i]); - } - } - } - - void BatchEncoder::encode(const vector &values_matrix, - Plaintext &destination) - { - auto &context_data = *context_->first_context_data(); - - // Validate input parameters - size_t values_matrix_size = values_matrix.size(); - if (values_matrix_size > slots_) - { - throw logic_error("values_matrix size is too large"); - } -#ifdef SEAL_DEBUG - uint64_t modulus = context_data.parms().plain_modulus().value(); - for (auto v : values_matrix) - { - // Validate the i-th input - if (v >= modulus) - { - throw invalid_argument("input value is larger than plain_modulus"); - } - } -#endif - // Set destination to full size - destination.resize(slots_); - destination.parms_id() = parms_id_zero; - - // First write the values to destination coefficients. - // Read in top row, then bottom row. - for (size_t i = 0; i < values_matrix_size; i++) - { - *(destination.data() + matrix_reps_index_map_[i]) = values_matrix[i]; - } - for (size_t i = values_matrix_size; i < slots_; i++) - { - *(destination.data() + matrix_reps_index_map_[i]) = 0; - } - - // Transform destination using inverse of negacyclic NTT - // Note: We already performed bit-reversal when reading in the matrix - inverse_ntt_negacyclic_harvey(destination.data(), *context_data.plain_ntt_tables()); - } - - void BatchEncoder::encode(const vector &values_matrix, - Plaintext &destination) - { - auto &context_data = *context_->first_context_data(); - uint64_t modulus = context_data.parms().plain_modulus().value(); - - // Validate input parameters - size_t values_matrix_size = values_matrix.size(); - if (values_matrix_size > slots_) - { - throw logic_error("values_matrix size is too large"); - } -#ifdef SEAL_DEBUG - uint64_t plain_modulus_div_two = modulus >> 1; - for (auto v : values_matrix) - { - // Validate the i-th input - if (unsigned_gt(llabs(v), plain_modulus_div_two)) - { - throw invalid_argument("input value is larger than plain_modulus"); - } - } -#endif - // Set destination to full size - destination.resize(slots_); - destination.parms_id() = parms_id_zero; - - // First write the values to destination coefficients. - // Read in top row, then bottom row. - for (size_t i = 0; i < values_matrix_size; i++) - { - *(destination.data() + matrix_reps_index_map_[i]) = - (values_matrix[i] < 0) ? (modulus + static_cast(values_matrix[i])) : - static_cast(values_matrix[i]); - } - for (size_t i = values_matrix_size; i < slots_; i++) - { - *(destination.data() + matrix_reps_index_map_[i]) = 0; - } - - // Transform destination using inverse of negacyclic NTT - // Note: We already performed bit-reversal when reading in the matrix - inverse_ntt_negacyclic_harvey(destination.data(), *context_data.plain_ntt_tables()); - } -#ifdef SEAL_USE_MSGSL_SPAN - void BatchEncoder::encode(gsl::span values_matrix, - Plaintext &destination) - { - auto &context_data = *context_->first_context_data(); - - // Validate input parameters - size_t values_matrix_size = static_cast(values_matrix.size()); - if (values_matrix_size > slots_) - { - throw logic_error("values_matrix size is too large"); - } -#ifdef SEAL_DEBUG - uint64_t modulus = context_data.parms().plain_modulus().value(); - for (auto v : values_matrix) - { - // Validate the i-th input - if (v >= modulus) - { - throw invalid_argument("input value is larger than plain_modulus"); - } - } -#endif - // Set destination to full size - destination.resize(slots_); - destination.parms_id() = parms_id_zero; - - // First write the values to destination coefficients. Read - // in top row, then bottom row. - using index_type = decltype(values_matrix)::index_type; - for (size_t i = 0; i < values_matrix_size; i++) - { - *(destination.data() + matrix_reps_index_map_[i]) = - values_matrix[static_cast(i)]; - } - for (size_t i = values_matrix_size; i < slots_; i++) - { - *(destination.data() + matrix_reps_index_map_[i]) = 0; - } - - // Transform destination using inverse of negacyclic NTT - // Note: We already performed bit-reversal when reading in the matrix - inverse_ntt_negacyclic_harvey(destination.data(), *context_data.plain_ntt_tables()); - } - - void BatchEncoder::encode(gsl::span values_matrix, - Plaintext &destination) - { - auto &context_data = *context_->first_context_data(); - uint64_t modulus = context_data.parms().plain_modulus().value(); - - // Validate input parameters - size_t values_matrix_size = static_cast(values_matrix.size()); - if (values_matrix_size > slots_) - { - throw logic_error("values_matrix size is too large"); - } -#ifdef SEAL_DEBUG - uint64_t plain_modulus_div_two = modulus >> 1; - for (auto v : values_matrix) - { - // Validate the i-th input - if (unsigned_gt(llabs(v), plain_modulus_div_two)) - { - throw invalid_argument("input value is larger than plain_modulus"); - } - } -#endif - // Set destination to full size - destination.resize(slots_); - destination.parms_id() = parms_id_zero; - - // First write the values to destination coefficients. Read - // in top row, then bottom row. - using index_type = decltype(values_matrix)::index_type; - for (size_t i = 0; i < values_matrix_size; i++) - { - *(destination.data() + matrix_reps_index_map_[i]) = - (values_matrix[static_cast(i)] < 0) ? - (modulus + static_cast(values_matrix[static_cast(i)])) : - static_cast(values_matrix[static_cast(i)]); - } - for (size_t i = values_matrix_size; i < slots_; i++) - { - *(destination.data() + matrix_reps_index_map_[i]) = 0; - } - - // Transform destination using inverse of negacyclic NTT - // Note: We already performed bit-reversal when reading in the matrix - inverse_ntt_negacyclic_harvey(destination.data(), *context_data.plain_ntt_tables()); - } -#endif - void BatchEncoder::encode(Plaintext &plain, MemoryPoolHandle pool) - { - if (plain.is_ntt_form()) - { - throw invalid_argument("plain cannot be in NTT form"); - } - if (!pool) - { - throw invalid_argument("pool is uninitialized"); - } - - auto &context_data = *context_->first_context_data(); - - // Validate input parameters - if (plain.coeff_count() > context_data.parms().poly_modulus_degree()) - { - throw invalid_argument("plain is not valid for encryption parameters"); - } -#ifdef SEAL_DEBUG - if (!are_poly_coefficients_less_than(plain.data(), - plain.coeff_count(), context_data.parms().plain_modulus().value())) - { - throw invalid_argument("plain is not valid for encryption parameters"); - } -#endif - // We need to permute the coefficients of plain. To do this, we allocate - // temporary space. - size_t input_plain_coeff_count = min(plain.coeff_count(), slots_); - auto temp(allocate_uint(input_plain_coeff_count, pool)); - set_uint_uint(plain.data(), input_plain_coeff_count, temp.get()); - - // Set plain to full slot count size. - plain.resize(slots_); - plain.parms_id() = parms_id_zero; - - // First write the values to destination coefficients. Read - // in top row, then bottom row. - for (size_t i = 0; i < input_plain_coeff_count; i++) - { - *(plain.data() + matrix_reps_index_map_[i]) = temp[i]; - } - for (size_t i = input_plain_coeff_count; i < slots_; i++) - { - *(plain.data() + matrix_reps_index_map_[i]) = 0; - } - - // Transform destination using inverse of negacyclic NTT - // Note: We already performed bit-reversal when reading in the matrix - inverse_ntt_negacyclic_harvey(plain.data(), *context_data.plain_ntt_tables()); - } - - void BatchEncoder::decode(const Plaintext &plain, vector &destination, - MemoryPoolHandle pool) - { - if (!is_valid_for(plain, context_)) - { - throw invalid_argument("plain is not valid for encryption parameters"); - } - if (plain.is_ntt_form()) - { - throw invalid_argument("plain cannot be in NTT form"); - } - if (!pool) - { - throw invalid_argument("pool is uninitialized"); - } - - auto &context_data = *context_->first_context_data(); - - // Set destination size - destination.resize(slots_); - - // Never include the leading zero coefficient (if present) - size_t plain_coeff_count = min(plain.coeff_count(), slots_); - - auto temp_dest(allocate_uint(slots_, pool)); - - // Make a copy of poly - set_uint_uint(plain.data(), plain_coeff_count, temp_dest.get()); - set_zero_uint(slots_ - plain_coeff_count, temp_dest.get() + plain_coeff_count); - - // Transform destination using negacyclic NTT. - ntt_negacyclic_harvey(temp_dest.get(), *context_data.plain_ntt_tables()); - - // Read top row - for (size_t i = 0; i < slots_; i++) - { - destination[i] = temp_dest[matrix_reps_index_map_[i]]; - } - } - - void BatchEncoder::decode(const Plaintext &plain, vector &destination, - MemoryPoolHandle pool) - { - if (!is_valid_for(plain, context_)) - { - throw invalid_argument("plain is not valid for encryption parameters"); - } - if (plain.is_ntt_form()) - { - throw invalid_argument("plain cannot be in NTT form"); - } - if (!pool) - { - throw invalid_argument("pool is uninitialized"); - } - - auto &context_data = *context_->first_context_data(); - uint64_t modulus = context_data.parms().plain_modulus().value(); - - // Set destination size - destination.resize(slots_); - - // Never include the leading zero coefficient (if present) - size_t plain_coeff_count = min(plain.coeff_count(), slots_); - - auto temp_dest(allocate_uint(slots_, pool)); - - // Make a copy of poly - set_uint_uint(plain.data(), plain_coeff_count, temp_dest.get()); - set_zero_uint(slots_ - plain_coeff_count, temp_dest.get() + plain_coeff_count); - - // Transform destination using negacyclic NTT. - ntt_negacyclic_harvey(temp_dest.get(), *context_data.plain_ntt_tables()); - - // Read top row, then bottom row - uint64_t plain_modulus_div_two = modulus >> 1; - for (size_t i = 0; i < slots_; i++) - { - uint64_t curr_value = temp_dest[matrix_reps_index_map_[i]]; - destination[i] = (curr_value > plain_modulus_div_two) ? - (static_cast(curr_value) - static_cast(modulus)) : - static_cast(curr_value); - } - } -#ifdef SEAL_USE_MSGSL_SPAN - void BatchEncoder::decode(const Plaintext &plain, gsl::span destination, - MemoryPoolHandle pool) - { - if (!is_valid_for(plain, context_)) - { - throw invalid_argument("plain is not valid for encryption parameters"); - } - if (plain.is_ntt_form()) - { - throw invalid_argument("plain cannot be in NTT form"); - } - if (!pool) - { - throw invalid_argument("pool is uninitialized"); - } - - auto &context_data = *context_->first_context_data(); - - using index_type = decltype(destination)::index_type; - if(unsigned_gt(destination.size(), numeric_limits::max()) || - unsigned_neq(destination.size(), slots_)) - { - throw invalid_argument("destination has incorrect size"); - } - - // Never include the leading zero coefficient (if present) - size_t plain_coeff_count = min(plain.coeff_count(), slots_); - - auto temp_dest(allocate_uint(slots_, pool)); - - // Make a copy of poly - set_uint_uint(plain.data(), plain_coeff_count, temp_dest.get()); - set_zero_uint(slots_ - plain_coeff_count, temp_dest.get() + plain_coeff_count); - - // Transform destination using negacyclic NTT. - ntt_negacyclic_harvey(temp_dest.get(), *context_data.plain_ntt_tables()); - - // Read top row - for (size_t i = 0; i < slots_; i++) - { - destination[static_cast(i)] = temp_dest[matrix_reps_index_map_[i]]; - } - } - - void BatchEncoder::decode(const Plaintext &plain, gsl::span destination, - MemoryPoolHandle pool) - { - if (!is_valid_for(plain, context_)) - { - throw invalid_argument("plain is not valid for encryption parameters"); - } - if (plain.is_ntt_form()) - { - throw invalid_argument("plain cannot be in NTT form"); - } - if (!pool) - { - throw invalid_argument("pool is uninitialized"); - } - - auto &context_data = *context_->first_context_data(); - uint64_t modulus = context_data.parms().plain_modulus().value(); - - using index_type = decltype(destination)::index_type; - if(unsigned_gt(destination.size(), numeric_limits::max()) || - unsigned_neq(destination.size(), slots_)) - { - throw invalid_argument("destination has incorrect size"); - } - - // Never include the leading zero coefficient (if present) - size_t plain_coeff_count = min(plain.coeff_count(), slots_); - - auto temp_dest(allocate_uint(slots_, pool)); - - // Make a copy of poly - set_uint_uint(plain.data(), plain_coeff_count, temp_dest.get()); - set_zero_uint(slots_ - plain_coeff_count, temp_dest.get() + plain_coeff_count); - - // Transform destination using negacyclic NTT. - ntt_negacyclic_harvey(temp_dest.get(), *context_data.plain_ntt_tables()); - - // Read top row, then bottom row - uint64_t plain_modulus_div_two = modulus >> 1; - for (size_t i = 0; i < slots_; i++) - { - uint64_t curr_value = temp_dest[matrix_reps_index_map_[i]]; - destination[static_cast(i)] = (curr_value > plain_modulus_div_two) ? - (static_cast(curr_value) - static_cast(modulus)) : - static_cast(curr_value); - } - } -#endif - void BatchEncoder::decode(Plaintext &plain, MemoryPoolHandle pool) - { - if (!is_valid_for(plain, context_)) - { - throw invalid_argument("plain is not valid for encryption parameters"); - } - if (plain.is_ntt_form()) - { - throw invalid_argument("plain cannot be in NTT form"); - } - if (!pool) - { - throw invalid_argument("pool is uninitialized"); - } - - auto &context_data = *context_->first_context_data(); - - // Never include the leading zero coefficient (if present) - size_t plain_coeff_count = min(plain.coeff_count(), slots_); - - // Allocate temporary space to store a wide copy of plain - auto temp(allocate_uint(slots_, pool)); - - // Make a copy of poly - set_uint_uint(plain.data(), plain_coeff_count, temp.get()); - set_zero_uint(slots_ - plain_coeff_count, temp.get() + plain_coeff_count); - - // Transform destination using negacyclic NTT. - ntt_negacyclic_harvey(temp.get(), *context_data.plain_ntt_tables()); - - // Set plain to full slot count size (note that all new coefficients are - // set to zero). - plain.resize(slots_); - - // Read top row, then bottom row - for (size_t i = 0; i < slots_; i++) - { - *(plain.data() + i) = temp[matrix_reps_index_map_[i]]; - } - } -} \ No newline at end of file diff --git a/SEAL/native/src/seal/batchencoder.h b/SEAL/native/src/seal/batchencoder.h deleted file mode 100644 index 00807e6..0000000 --- a/SEAL/native/src/seal/batchencoder.h +++ /dev/null @@ -1,387 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include "seal/util/defines.h" -#include "seal/util/common.h" -#include "seal/util/uintcore.h" -#include "seal/util/uintarithsmallmod.h" -#include "seal/plaintext.h" -#include "seal/context.h" - -namespace seal -{ - /** - Provides functionality for CRT batching. If the polynomial modulus degree is N, and - the plaintext modulus is a prime number T such that T is congruent to 1 modulo 2N, - then BatchEncoder allows the plaintext elements to be viewed as 2-by-(N/2) - matrices of integers modulo T. Homomorphic operations performed on such encrypted - matrices are applied coefficient (slot) wise, enabling powerful SIMD functionality - for computations that are vectorizable. This functionality is often called "batching" - in the homomorphic encryption literature. - - @par Mathematical Background - Mathematically speaking, if the polynomial modulus is X^N+1, N is a power of two, and - plain_modulus is a prime number T such that 2N divides T-1, then integers modulo T - contain a primitive 2N-th root of unity and the polynomial X^N+1 splits into n distinct - linear factors as X^N+1 = (X-a_1)*...*(X-a_N) mod T, where the constants a_1, ..., a_n - are all the distinct primitive 2N-th roots of unity in integers modulo T. The Chinese - Remainder Theorem (CRT) states that the plaintext space Z_T[X]/(X^N+1) in this case is - isomorphic (as an algebra) to the N-fold direct product of fields Z_T. The isomorphism - is easy to compute explicitly in both directions, which is what this class does. - Furthermore, the Galois group of the extension is (Z/2NZ)* ~= Z/2Z x Z/(N/2) whose - action on the primitive roots of unity is easy to describe. Since the batching slots - correspond 1-to-1 to the primitive roots of unity, applying Galois automorphisms on the - plaintext act by permuting the slots. By applying generators of the two cyclic - subgroups of the Galois group, we can effectively view the plaintext as a 2-by-(N/2) - matrix, and enable cyclic row rotations, and column rotations (row swaps). - - @par Valid Parameters - Whether batching can be used depends on whether the plaintext modulus has been chosen - appropriately. Thus, to construct a BatchEncoder the user must provide an instance - of SEALContext such that its associated EncryptionParameterQualifiers object has the - flags parameters_set and enable_batching set to true. - - @see EncryptionParameters for more information about encryption parameters. - @see EncryptionParameterQualifiers for more information about parameter qualifiers. - @see Evaluator for rotating rows and columns of encrypted matrices. - */ - class SEAL_NODISCARD BatchEncoder - { - public: - /** - Creates a BatchEncoder. It is necessary that the encryption parameters - given through the SEALContext object support batching. - - @param[in] context The SEALContext - @throws std::invalid_argument if the context is not set or encryption - parameters are not valid for batching - @throws std::invalid_argument if scheme is not scheme_type::BFV - */ - BatchEncoder(std::shared_ptr context); - - /** - Creates a plaintext from a given matrix. This function "batches" a given matrix - of integers modulo the plaintext modulus into a plaintext element, and stores - the result in the destination parameter. The input vector must have size at most equal - to the degree of the polynomial modulus. The first half of the elements represent the - first row of the matrix, and the second half represent the second row. The numbers - in the matrix can be at most equal to the plaintext modulus for it to represent - a valid plaintext. - - If the destination plaintext overlaps the input values in memory, the behavior of - this function is undefined. - - @param[in] values The matrix of integers modulo plaintext modulus to batch - @param[out] destination The plaintext polynomial to overwrite with the result - @throws std::invalid_argument if values is too large - */ - void encode(const std::vector &values, Plaintext &destination); - - /** - Creates a plaintext from a given matrix. This function "batches" a given matrix - of integers modulo the plaintext modulus into a plaintext element, and stores - the result in the destination parameter. The input vector must have size at most equal - to the degree of the polynomial modulus. The first half of the elements represent the - first row of the matrix, and the second half represent the second row. The numbers - in the matrix can be at most equal to the plaintext modulus for it to represent - a valid plaintext. - - If the destination plaintext overlaps the input values in memory, the behavior of - this function is undefined. - - @param[in] values The matrix of integers modulo plaintext modulus to batch - @param[out] destination The plaintext polynomial to overwrite with the result - @throws std::invalid_argument if values is too large - */ - void encode(const std::vector &values, Plaintext &destination); -#ifdef SEAL_USE_MSGSL_SPAN - /** - Creates a plaintext from a given matrix. This function "batches" a given matrix - of integers modulo the plaintext modulus into a plaintext element, and stores - the result in the destination parameter. The input vector must have size at most equal - to the degree of the polynomial modulus. The first half of the elements represent the - first row of the matrix, and the second half represent the second row. The numbers - in the matrix can be at most equal to the plaintext modulus for it to represent - a valid plaintext. - - If the destination plaintext overlaps the input values in memory, the behavior of - this function is undefined. - - @param[in] values The matrix of integers modulo plaintext modulus to batch - @param[out] destination The plaintext polynomial to overwrite with the result - @throws std::invalid_argument if values is too large - */ - void encode(gsl::span values, Plaintext &destination); - - /** - Creates a plaintext from a given matrix. This function "batches" a given matrix - of integers modulo the plaintext modulus into a plaintext element, and stores - the result in the destination parameter. The input vector must have size at most equal - to the degree of the polynomial modulus. The first half of the elements represent the - first row of the matrix, and the second half represent the second row. The numbers - in the matrix can be at most equal to the plaintext modulus for it to represent - a valid plaintext. - - If the destination plaintext overlaps the input values in memory, the behavior of - this function is undefined. - - @param[in] values The matrix of integers modulo plaintext modulus to batch - @param[out] destination The plaintext polynomial to overwrite with the result - @throws std::invalid_argument if values is too large - */ - void encode(gsl::span values, Plaintext &destination); -#ifdef SEAL_USE_MSGSL_MULTISPAN - /** - Creates a plaintext from a given matrix. This function "batches" a given matrix - of integers modulo the plaintext modulus into a plaintext element, and stores - the result in the destination parameter. The input must have dimensions [2, N/2], - where N denotes the degree of the polynomial modulus, representing a 2 x (N/2) - matrix. The numbers in the matrix can be at most equal to the plaintext modulus for - it to represent a valid plaintext. - - If the destination plaintext overlaps the input values in memory, the behavior of - this function is undefined. - - @param[in] values The matrix of integers modulo plaintext modulus to batch - @param[out] destination The plaintext polynomial to overwrite with the result - @throws std::invalid_argument if values is too large or has incorrect size - */ - inline void encode(gsl::multi_span< - const std::uint64_t, - static_cast(2), - gsl::dynamic_range> values, Plaintext &destination) - { - encode(gsl::span(values.data(), values.size()), - destination); - } - - /** - Creates a plaintext from a given matrix. This function "batches" a given matrix - of integers modulo the plaintext modulus into a plaintext element, and stores - the result in the destination parameter. The input must have dimensions [2, N/2], - where N denotes the degree of the polynomial modulus, representing a 2 x (N/2) - matrix. The numbers in the matrix can be at most equal to the plaintext modulus for - it to represent a valid plaintext. - - If the destination plaintext overlaps the input values in memory, the behavior of - this function is undefined. - - @param[in] values The matrix of integers modulo plaintext modulus to batch - @param[out] destination The plaintext polynomial to overwrite with the result - @throws std::invalid_argument if values is too large or has incorrect size - */ - inline void encode(gsl::multi_span< - const std::int64_t, - static_cast(2), - gsl::dynamic_range> values, Plaintext &destination) - { - encode(gsl::span(values.data(), values.size()), - destination); - } -#endif -#endif - /** - Creates a plaintext from a given matrix. This function "batches" a given matrix - of integers modulo the plaintext modulus in-place into a plaintext ready to be - encrypted. The matrix is given as a plaintext element whose first N/2 coefficients - represent the first row of the matrix, and the second N/2 coefficients represent the - second row, where N denotes the degree of the polynomial modulus. The input plaintext - must have degress less than the polynomial modulus, and coefficients less than the - plaintext modulus, i.e. it must be a valid plaintext for the encryption parameters. - Dynamic memory allocations in the process are allocated from the memory pool pointed - to by the given MemoryPoolHandle. - - @param[in] plain The matrix of integers modulo plaintext modulus to batch - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if plain is not valid for the encryption parameters - @throws std::invalid_argument if plain is in NTT form - @throws std::invalid_argument if pool is uninitialized - */ - void encode(Plaintext &plain, MemoryPoolHandle pool = MemoryManager::GetPool()); - - /** - Inverse of encode. This function "unbatches" a given plaintext into a matrix - of integers modulo the plaintext modulus, and stores the result in the destination - parameter. The input plaintext must have degress less than the polynomial modulus, - and coefficients less than the plaintext modulus, i.e. it must be a valid plaintext - for the encryption parameters. Dynamic memory allocations in the process are - allocated from the memory pool pointed to by the given MemoryPoolHandle. - - @param[in] plain The plaintext polynomial to unbatch - @param[out] destination The matrix to be overwritten with the values in the slots - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if plain is not valid for the encryption parameters - @throws std::invalid_argument if plain is in NTT form - @throws std::invalid_argument if pool is uninitialized - */ - void decode(const Plaintext &plain, std::vector &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()); - - /** - Inverse of encode. This function "unbatches" a given plaintext into a matrix - of integers modulo the plaintext modulus, and stores the result in the destination - parameter. The input plaintext must have degress less than the polynomial modulus, - and coefficients less than the plaintext modulus, i.e. it must be a valid plaintext - for the encryption parameters. Dynamic memory allocations in the process are - allocated from the memory pool pointed to by the given MemoryPoolHandle. - - @param[in] plain The plaintext polynomial to unbatch - @param[out] destination The matrix to be overwritten with the values in the slots - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if plain is not valid for the encryption parameters - @throws std::invalid_argument if plain is in NTT form - @throws std::invalid_argument if pool is uninitialized - */ - void decode(const Plaintext &plain, std::vector &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()); -#ifdef SEAL_USE_MSGSL_SPAN - /** - Inverse of encode. This function "unbatches" a given plaintext into a matrix - of integers modulo the plaintext modulus, and stores the result in the destination - parameter. The input plaintext must have degress less than the polynomial modulus, - and coefficients less than the plaintext modulus, i.e. it must be a valid plaintext - for the encryption parameters. Dynamic memory allocations in the process are - allocated from the memory pool pointed to by the given MemoryPoolHandle. - - @param[in] plain The plaintext polynomial to unbatch - @param[out] destination The matrix to be overwritten with the values in the slots - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if plain is not valid for the encryption parameters - @throws std::invalid_argument if plain is in NTT form - @throws std::invalid_argument if destination has incorrect size - @throws std::invalid_argument if pool is uninitialized - */ - void decode(const Plaintext &plain, gsl::span destination, - MemoryPoolHandle pool = MemoryManager::GetPool()); - - /** - Inverse of encode. This function "unbatches" a given plaintext into a matrix - of integers modulo the plaintext modulus, and stores the result in the destination - parameter. The input plaintext must have degress less than the polynomial modulus, - and coefficients less than the plaintext modulus, i.e. it must be a valid plaintext - for the encryption parameters. Dynamic memory allocations in the process are - allocated from the memory pool pointed to by the given MemoryPoolHandle. - - @param[in] plain The plaintext polynomial to unbatch - @param[out] destination The matrix to be overwritten with the values in the slots - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if plain is not valid for the encryption parameters - @throws std::invalid_argument if plain is in NTT form - @throws std::invalid_argument if destination has incorrect size - @throws std::invalid_argument if pool is uninitialized - */ - void decode(const Plaintext &plain, gsl::span destination, - MemoryPoolHandle pool = MemoryManager::GetPool()); -#ifdef SEAL_USE_MSGSL_MULTISPAN - /** - Inverse of encode. This function "unbatches" a given plaintext into a matrix - of integers modulo the plaintext modulus, and stores the result in the destination - parameter. The destination must have dimensions [2, N/2], where N denotes the degree - of the polynomial modulus, representing a 2 x (N/2) matrix. The input plaintext must - have degress less than the polynomial modulus, and coefficients less than the - plaintext modulus, i.e. it must be a valid plaintext for the encryption parameters. - Dynamic memory allocations in the process are allocated from the memory pool pointed - to by the given MemoryPoolHandle. - - @param[in] plain The plaintext polynomial to unbatch - @param[out] destination The matrix to be overwritten with the values in the slots - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if plain is not valid for the encryption parameters - @throws std::invalid_argument if plain is in NTT form - @throws std::invalid_argument if destination has incorrect size - @throws std::invalid_argument if pool is uninitialized - */ - inline void decode(const Plaintext &plain, - gsl::multi_span(2), - gsl::dynamic_range> destination, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - decode(plain, gsl::span(destination.data(), - destination.size()), std::move(pool)); - } - - /** - Inverse of encode. This function "unbatches" a given plaintext into a matrix - of integers modulo the plaintext modulus, and stores the result in the destination - parameter. The destination must have dimensions [2, N/2], where N denotes the degree - of the polynomial modulus, representing a 2 x (N/2) matrix. The input plaintext must - have degress less than the polynomial modulus, and coefficients less than the - plaintext modulus, i.e. it must be a valid plaintext for the encryption parameters. - Dynamic memory allocations in the process are allocated from the memory pool pointed - to by the given MemoryPoolHandle. - - @param[in] plain The plaintext polynomial to unbatch - @param[out] destination The matrix to be overwritten with the values in the slots - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if plain is not valid for the encryption parameters - @throws std::invalid_argument if plain is in NTT form - @throws std::invalid_argument if destination has incorrect size - @throws std::invalid_argument if pool is uninitialized - */ - inline void decode(const Plaintext &plain, - gsl::multi_span(2), - gsl::dynamic_range> destination, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - decode(plain, gsl::span(destination.data(), - destination.size()), std::move(pool)); - } -#endif -#endif - /** - Inverse of encode. This function "unbatches" a given plaintext in-place into - a matrix of integers modulo the plaintext modulus. The input plaintext must have - degress less than the polynomial modulus, and coefficients less than the plaintext - modulus, i.e. it must be a valid plaintext for the encryption parameters. Dynamic - memory allocations in the process are allocated from the memory pool pointed to by - the given MemoryPoolHandle. - - @param[in] plain The plaintext polynomial to unbatch - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if plain is not valid for the encryption parameters - @throws std::invalid_argument if plain is in NTT form - @throws std::invalid_argument if pool is uninitialized - */ - void decode(Plaintext &plain, MemoryPoolHandle pool = MemoryManager::GetPool()); - - /** - Returns the number of slots. - */ - SEAL_NODISCARD inline auto slot_count() const noexcept - { - return slots_; - } - - private: - BatchEncoder(const BatchEncoder ©) = delete; - - BatchEncoder(BatchEncoder &&source) = delete; - - BatchEncoder &operator =(const BatchEncoder &assign) = delete; - - BatchEncoder &operator =(BatchEncoder &&assign) = delete; - - void populate_roots_of_unity_vector( - const SEALContext::ContextData &context_data); - - void populate_matrix_reps_index_map(); - - void reverse_bits(std::uint64_t *input); - - MemoryPoolHandle pool_ = MemoryManager::GetPool(); - - std::shared_ptr context_{ nullptr }; - - std::size_t slots_; - - util::Pointer roots_of_unity_; - - util::Pointer matrix_reps_index_map_; - }; -} diff --git a/SEAL/native/src/seal/biguint.cpp b/SEAL/native/src/seal/biguint.cpp deleted file mode 100644 index 45b1b37..0000000 --- a/SEAL/native/src/seal/biguint.cpp +++ /dev/null @@ -1,312 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "seal/biguint.h" -#include "seal/util/common.h" -#include "seal/util/uintcore.h" -#include "seal/util/uintarith.h" -#include "seal/util/uintarithmod.h" -#include - -using namespace std; -using namespace seal::util; - -namespace seal -{ - BigUInt::BigUInt(int bit_count) - { - resize(bit_count); - } - - BigUInt::BigUInt(const string &hex_value) - { - operator =(hex_value); - } - - BigUInt::BigUInt(int bit_count, const string &hex_value) - { - resize(bit_count); - operator =(hex_value); - if (bit_count != bit_count_) - { - resize(bit_count); - } - } - - BigUInt::BigUInt(int bit_count, uint64_t *value) : - value_(decltype(value_)::Aliasing(value)), bit_count_(bit_count) - { - if (bit_count < 0) - { - throw invalid_argument("bit_count must be non-negative"); - } - if (value == nullptr && bit_count > 0) - { - throw invalid_argument("value must be non-null for non-zero bit count"); - } - } -#ifdef SEAL_USE_MSGSL_SPAN - BigUInt::BigUInt(gsl::span value) - { - if(unsigned_gt(value.size(), numeric_limits::max() / bits_per_uint64)) - { - throw std::invalid_argument("value has too large size"); - } - value_ = decltype(value_)::Aliasing(value.data()); - bit_count_ = static_cast(value.size()) * bits_per_uint64; - } -#endif - BigUInt::BigUInt(int bit_count, uint64_t value) - { - resize(bit_count); - operator =(value); - if (bit_count != bit_count_) - { - resize(bit_count); - } - } - - BigUInt::BigUInt(const BigUInt ©) - { - resize(copy.bit_count()); - operator =(copy); - } - - BigUInt::BigUInt(BigUInt &&source) noexcept : - pool_(move(source.pool_)), - value_(move(source.value_)), - bit_count_(source.bit_count_) - { - // Pointer in source has been taken over so set it to nullptr - source.bit_count_ = 0; - } - - BigUInt::~BigUInt() noexcept - { - reset(); - } - - string BigUInt::to_string() const - { - return uint_to_hex_string(value_.get(), uint64_count()); - } - - string BigUInt::to_dec_string() const - { - return uint_to_dec_string(value_.get(), uint64_count(), pool_); - } - - void BigUInt::resize(int bit_count) - { - if (bit_count < 0) - { - throw invalid_argument("bit_count must be non-negative"); - } - if (value_.is_alias()) - { - throw logic_error("Cannot resize an aliased BigUInt"); - } - if (bit_count == bit_count_) - { - return; - } - - // Lazy initialization of MemoryPoolHandle - if (!pool_) - { - pool_ = MemoryManager::GetPool(); - } - - // Fast path if allocation size doesn't change. - size_t old_uint64_count = uint64_count(); - size_t new_uint64_count = safe_cast( - divide_round_up(bit_count, bits_per_uint64)); - if (old_uint64_count == new_uint64_count) - { - bit_count_ = bit_count; - return; - } - - // Allocate new space. - decltype(value_) new_value; - if (new_uint64_count > 0) - { - new_value = allocate_uint(new_uint64_count, pool_); - } - - // Copy over old value. - if (new_uint64_count > 0) - { - set_uint_uint(value_.get(), old_uint64_count, new_uint64_count, new_value.get()); - filter_highbits_uint(new_value.get(), new_uint64_count, bit_count); - } - - // Deallocate any owned pointers. - reset(); - - // Update class. - swap(value_, new_value); - bit_count_ = bit_count; - } - - BigUInt &BigUInt::operator =(const BigUInt& assign) - { - // Do nothing if same thing. - if (&assign == this) - { - return *this; - } - - // Verify assigned value will fit within bit count. - int assign_sig_bit_count = assign.significant_bit_count(); - if (assign_sig_bit_count > bit_count_) - { - // Size is too large to currently fit, so resize. - resize(assign_sig_bit_count); - } - - // Copy over value. - size_t assign_uint64_count = safe_cast( - divide_round_up(assign_sig_bit_count, bits_per_uint64)); - if (uint64_count() > 0) - { - set_uint_uint(assign.value_.get(), assign_uint64_count, - uint64_count(), value_.get()); - } - return *this; - } - - BigUInt &BigUInt::operator =(const string &hex_value) - { - int hex_value_length = safe_cast(hex_value.size()); - - int assign_bit_count = get_hex_string_bit_count(hex_value.data(), hex_value_length); - if (assign_bit_count > bit_count_) - { - // Size is too large to currently fit, so resize. - resize(assign_bit_count); - } - if (bit_count_ > 0) - { - // Copy over value. - hex_string_to_uint(hex_value.data(), hex_value_length, uint64_count(), value_.get()); - } - return *this; - } - - BigUInt BigUInt::operator /(const BigUInt& operand2) const - { - int result_bits = significant_bit_count(); - int operand2_bits = operand2.significant_bit_count(); - if (operand2_bits == 0) - { - throw invalid_argument("operand2 must be positive"); - } - if (operand2_bits > result_bits) - { - BigUInt zero(result_bits); - return zero; - } - BigUInt result(result_bits); - BigUInt remainder(result_bits); - size_t result_uint64_count = result.uint64_count(); - if (result_uint64_count > operand2.uint64_count()) - { - BigUInt operand2resized(result_bits); - operand2resized = operand2; - divide_uint_uint(value_.get(), operand2resized.data(), result_uint64_count, - result.data(), remainder.data(), pool_); - } - else - { - divide_uint_uint(value_.get(), operand2.data(), result_uint64_count, - result.data(), remainder.data(), pool_); - } - return result; - } - - BigUInt BigUInt::divrem(const BigUInt& operand2, BigUInt &remainder) const - { - int result_bits = significant_bit_count(); - remainder = *this; - int operand2_bits = operand2.significant_bit_count(); - if (operand2_bits > result_bits) - { - BigUInt zero; - return zero; - } - BigUInt quotient(result_bits); - size_t uint64_count = remainder.uint64_count(); - if (uint64_count > operand2.uint64_count()) - { - BigUInt operand2resized(result_bits); - operand2resized = operand2; - divide_uint_uint_inplace(remainder.data(), operand2resized.data(), - uint64_count, quotient.data(), pool_); - } - else - { - divide_uint_uint_inplace(remainder.data(), operand2.data(), - uint64_count, quotient.data(), pool_); - } - return quotient; - } - - void BigUInt::save(ostream &stream) const - { - auto old_except_mask = stream.exceptions(); - try - { - // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit - stream.exceptions(ios_base::badbit | ios_base::failbit); - - int32_t bit_count32 = safe_cast(bit_count_); - streamsize data_size = safe_cast(mul_safe(uint64_count(), sizeof(uint64_t))); - stream.write(reinterpret_cast(&bit_count32), sizeof(int32_t)); - stream.write(reinterpret_cast(value_.get()), data_size); - } - catch (const exception &) - { - stream.exceptions(old_except_mask); - throw; - } - - stream.exceptions(old_except_mask); - } - - void BigUInt::load(istream &stream) - { - auto old_except_mask = stream.exceptions(); - try - { - // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit - stream.exceptions(ios_base::badbit | ios_base::failbit); - - int32_t read_bit_count = 0; - stream.read(reinterpret_cast(&read_bit_count), sizeof(int32_t)); - if (read_bit_count > bit_count_) - { - // Size is too large to currently fit, so resize. - resize(read_bit_count); - } - size_t read_uint64_count = safe_cast( - divide_round_up(read_bit_count, bits_per_uint64)); - streamsize data_size = safe_cast(mul_safe(read_uint64_count, sizeof(uint64_t))); - stream.read(reinterpret_cast(value_.get()), data_size); - - // Zero any extra space. - if (uint64_count() > read_uint64_count) - { - set_zero_uint(uint64_count() - read_uint64_count, - value_.get() + read_uint64_count); - } - } - catch (const exception &) - { - stream.exceptions(old_except_mask); - throw; - } - - stream.exceptions(old_except_mask); - } -} diff --git a/SEAL/native/src/seal/biguint.h b/SEAL/native/src/seal/biguint.h deleted file mode 100644 index db1f56c..0000000 --- a/SEAL/native/src/seal/biguint.h +++ /dev/null @@ -1,1673 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include -#include "seal/memorymanager.h" -#include "seal/util/defines.h" -#include "seal/util/pointer.h" -#include "seal/util/common.h" -#include "seal/util/uintcore.h" -#include "seal/util/uintarith.h" -#include "seal/util/uintarithmod.h" - -namespace seal -{ - /** - Represents an unsigned integer with a specified bit width. Non-const - BigUInts are mutable and able to be resized. The bit count for a BigUInt - (which can be read with bit_count()) is set initially by the constructor - and can be resized either explicitly with the resize() function or - implicitly with an assignment operation (e.g., operator=(), operator+=(), - etc.). A rich set of unsigned integer operations are provided by the - BigUInt class, including comparison, traditional arithmetic (addition, - subtraction, multiplication, division), and modular arithmetic functions. - - @par Backing Array - The backing array for a BigUInt stores its unsigned integer value as - a contiguous std::uint64_t array. Each std::uint64_t in the array - sequentially represents 64-bits of the integer value, with the least - significant quad-word storing the lower 64-bits and the order of the bits - for each quad word dependent on the architecture's std::uint64_t - representation. The size of the array equals the bit count of the BigUInt - (which can be read with bit_count()) rounded up to the next std::uint64_t - boundary (i.e., rounded up to the next 64-bit boundary). The uint64_count() - function returns the number of std::uint64_t in the backing array. The - data() function returns a pointer to the first std::uint64_t in the array. - Additionally, the operator [] function allows accessing the individual - bytes of the integer value in a platform-independent way - for example, - reading the third byte will always return bits 16-24 of the BigUInt value - regardless of the platform being little-endian or big-endian. - - @par Implicit Resizing - Both the copy constructor and operator=() allocate more memory for the - backing array when needed, i.e. when the source BigUInt has a larger - backing array than the destination. Conversely, when the destination - backing array is already large enough, the data is only copied and the - unnecessary higher order bits are set to zero. When new memory has to be - allocated, only the significant bits of the source BigUInt are taken - into account. This is is important, because it avoids unnecessary zero - bits to be included in the destination, which in some cases could - accumulate and result in very large unnecessary allocations. However, - sometimes it is necessary to preserve the original size, even if some - of the leading bits are zero. For this purpose BigUInt contains functions - duplicate_from and duplicate_to, which create an exact copy of the source - BigUInt. - - @par Alias BigUInts - An aliased BigUInt (which can be determined with is_alias()) is a special - type of BigUInt that does not manage its underlying std::uint64_t pointer - used to store the value. An aliased BigUInt supports most of the same - operations as a non-aliased BigUInt, including reading and writing the - value, however an aliased BigUInt does not internally allocate or - deallocate its backing array and, therefore, does not support resizing. - Any attempt, either explicitly or implicitly, to resize the BigUInt will - result in an exception being thrown. An aliased BigUInt can be created - with the BigUInt(int, std::uint64_t*) constructor or the alias() function. - Note that the pointer specified to be aliased must be deallocated - externally after the BigUInt is no longer in use. Aliasing is useful in - cases where it is desirable to not have each BigUInt manage its own memory - allocation and/or to prevent unnecessary copying. - - @par Thread Safety - In general, reading a BigUInt is thread-safe while mutating is not. - Specifically, the backing array may be freed whenever a resize occurs, - the BigUInt is destroyed, or alias() is called, which would invalidate - the address returned by data() and the byte references returned by - operator []. When it is known that a resize will not occur, concurrent - reading and mutating will not inherently fail but it is possible for - a read to see a partially updated value from a concurrent write. - A non-aliased BigUInt allocates its backing array from the global - (thread-safe) memory pool. Consequently, creating or resizing a large - number of BigUInt can result in a performance loss due to thread - contention. - */ - class BigUInt - { - public: - /** - Creates an empty BigUInt with zero bit width. No memory is allocated - by this constructor. - */ - BigUInt() = default; - - /** - Creates a zero-initialized BigUInt of the specified bit width. - - @param[in] bit_count The bit width - @throws std::invalid_argument if bit_count is negative - */ - BigUInt(int bit_count); - - /** - Creates a BigUInt initialized and minimally sized to fit the unsigned - hexadecimal integer specified by the string. The string matches the format - returned by to_string() and must consist of only the characters 0-9, A-F, - or a-f, most-significant nibble first. - - @param[in] hex_value The hexadecimal integer string specifying the initial - value - @throws std::invalid_argument if hex_value does not adhere to the expected - format - */ - BigUInt(const std::string &hex_value); - - /** - Creates a BigUInt of the specified bit width and initializes it with the - unsigned hexadecimal integer specified by the string. The string must match - the format returned by to_string() and must consist of only the characters - 0-9, A-F, or a-f, most-significant nibble first. - - @param[in] bit_count The bit width - @param[in] hex_value The hexadecimal integer string specifying the initial - value - @throws std::invalid_argument if bit_count is negative - @throws std::invalid_argument if hex_value does not adhere to the expected - format - */ - BigUInt(int bit_count, const std::string &hex_value); - - /** - Creates an aliased BigUInt with the specified bit width and backing array. - An aliased BigUInt does not internally allocate or deallocate the backing - array, and instead uses the specified backing array for all read/write - operations. Note that resizing is not supported by an aliased BigUInt and - any required deallocation of the specified backing array must occur - externally after the aliased BigUInt is no longer in use. - - @param[in] bit_count The bit width - @param[in] value The backing array to use - @throws std::invalid_argument if bit_count is negative or value is null - and bit_count is positive - */ - BigUInt(int bit_count, std::uint64_t *value); -#ifdef SEAL_USE_MSGSL_SPAN - /** - Creates an aliased BigUInt with given backing array and bit width set to - the size of the backing array. An aliased BigUInt does not internally - allocate or deallocate the backing array, and instead uses the specified - backing array for all read/write operations. Note that resizing is not - supported by an aliased BigUInt and any required deallocation of the - specified backing array must occur externally after the aliased BigUInt - is no longer in use. - - @param[in] value The backing array to use - @throws std::invalid_argument if value has too large size - */ - BigUInt(gsl::span value); -#endif - /** - Creates a BigUInt of the specified bit width and initializes it to the - specified unsigned integer value. - - @param[in] bit_count The bit width - @param[in] value The initial value to set the BigUInt - @throws std::invalid_argument if bit_count is negative - */ - BigUInt(int bit_count, std::uint64_t value); - - /** - Creates a deep copy of a BigUInt. The created BigUInt will have the same - bit count and value as the original. - - @param[in] copy The BigUInt to copy from - */ - BigUInt(const BigUInt ©); - - /** - Creates a new BigUInt by moving an old one. - - @param[in] source The BigUInt to move from - */ - BigUInt(BigUInt &&source) noexcept; - - /** - Destroys the BigUInt and deallocates the backing array if it is not - an aliased BigUInt. - */ - ~BigUInt() noexcept; - - /** - Returns whether or not the BigUInt is an alias. - - @see BigUInt for a detailed description of aliased BigUInt. - */ - SEAL_NODISCARD inline bool is_alias() const noexcept - { - return value_.is_alias(); - } - - /** - Returns the bit count for the BigUInt. - - @see significant_bit_count() to instead ignore leading zero bits. - */ - SEAL_NODISCARD inline int bit_count() const noexcept - { - return bit_count_; - } - - /** - Returns a pointer to the backing array storing the BigUInt value. - The pointer points to the beginning of the backing array at the - least-significant quad word. - - @warning The pointer is valid only until the backing array is freed, - which occurs when the BigUInt is resized, destroyed, or the alias() - function is called. - @see uint64_count() to determine the number of std::uint64_t values - in the backing array. - */ - SEAL_NODISCARD inline std::uint64_t *data() - { - return value_.get(); - } - - /** - Returns a const pointer to the backing array storing the BigUInt value. - The pointer points to the beginning of the backing array at the - least-significant quad word. - - @warning The pointer is valid only until the backing array is freed, which - occurs when the BigUInt is resized, destroyed, or the alias() function is - called. - @see uint64_count() to determine the number of std::uint64_t values in the - backing array. - */ - SEAL_NODISCARD inline const std::uint64_t *data() const noexcept - { - return value_.get(); - } -#ifdef SEAL_USE_MSGSL_SPAN - /** - Returns the backing array storing the BigUInt value. - - @warning The span is valid only until the backing array is freed, which - occurs when the BigUInt is resized, destroyed, or the alias() function is - called. - */ - SEAL_NODISCARD inline gsl::span data_span() - { - return gsl::span(value_.get(), - static_cast(uint64_count())); - } - - /** - Returns the backing array storing the BigUInt value. - - @warning The span is valid only until the backing array is freed, which - occurs when the BigUInt is resized, destroyed, or the alias() function is - called. - */ - SEAL_NODISCARD inline gsl::span data_span() const - { - return gsl::span(value_.get(), - static_cast(uint64_count())); - } -#endif - /** - Returns the number of bytes in the backing array used to store the BigUInt - value. - - @see BigUInt for a detailed description of the format of the backing array. - */ - SEAL_NODISCARD inline std::size_t byte_count() const - { - return static_cast( - util::divide_round_up(bit_count_, util::bits_per_byte)); - } - - /** - Returns the number of std::uint64_t in the backing array used to store - the BigUInt value. - - @see BigUInt for a detailed description of the format of the backing array. - */ - SEAL_NODISCARD inline std::size_t uint64_count() const - { - return static_cast( - util::divide_round_up(bit_count_, util::bits_per_uint64)); - } - - /** - Returns the number of significant bits for the BigUInt. - - @see bit_count() to instead return the bit count regardless of leading zero - bits. - */ - SEAL_NODISCARD inline int significant_bit_count() const - { - if (bit_count_ == 0) - { - return 0; - } - return util::get_significant_bit_count_uint(value_.get(), uint64_count()); - } - - /** - Returns the BigUInt value as a double. Note that precision may be lost during - the conversion. - */ - SEAL_NODISCARD double to_double() const noexcept - { - const double TwoToThe64 = 18446744073709551616.0; - double result = 0; - for (std::size_t i = uint64_count(); i--; ) - { - result *= TwoToThe64; - result += static_cast(value_[i]); - } - return result; - } - - /** - Returns the BigUInt value as a hexadecimal string. - */ - SEAL_NODISCARD std::string to_string() const; - - /** - Returns the BigUInt value as a decimal string. - */ - SEAL_NODISCARD std::string to_dec_string() const; - - /** - Returns whether or not the BigUInt has the value zero. - */ - SEAL_NODISCARD inline bool is_zero() const - { - if (bit_count_ == 0) - { - return true; - } - return util::is_zero_uint(value_.get(), uint64_count()); - } - - /** - Returns the byte at the corresponding byte index of the BigUInt's integer - value. The bytes of the BigUInt are indexed least-significant byte first. - - @param[in] index The index of the byte to read - @throws std::out_of_range if index is not within [0, byte_count()) - @see BigUInt for a detailed description of the format of the backing array. - */ - SEAL_NODISCARD inline const SEAL_BYTE &operator []( - std::size_t index) const - { - if (index >= byte_count()) - { - throw std::out_of_range("index must be within [0, byte count)"); - } - return *util::get_uint64_byte(value_.get(), index); - } - - /** - Returns an byte reference that can read/write the byte at the corresponding - byte index of the BigUInt's integer value. The bytes of the BigUInt are - indexed least-significant byte first. - - @warning The returned byte is an reference backed by the BigUInt's backing - array. As such, it is only valid until the BigUInt is resized, destroyed, - or alias() is called. - - @param[in] index The index of the byte to read - @throws std::out_of_range if index is not within [0, byte_count()) - @see BigUInt for a detailed description of the format of the backing array. - */ - SEAL_NODISCARD inline SEAL_BYTE &operator [](std::size_t index) - { - if (index >= byte_count()) - { - throw std::out_of_range("index must be within [0, byte count)"); - } - return *util::get_uint64_byte(value_.get(), index); - } - - /** - Sets the BigUInt value to zero. This does not resize the BigUInt. - */ - inline void set_zero() - { - if (bit_count_) - { - return util::set_zero_uint(uint64_count(), value_.get()); - } - } - - /** - Resizes the BigUInt to the specified bit width, copying over the old value - as much as will fit. - - @param[in] bit_count The bit width - @throws std::invalid_argument if bit_count is negative - @throws std::logic_error if the BigUInt is an alias - */ - void resize(int bit_count); - - /** - Makes the BigUInt an aliased BigUInt with the specified bit width and - backing array. An aliased BigUInt does not internally allocate or - deallocate the backing array, and instead uses the specified backing array - for all read/write operations. Note that resizing is not supported by - an aliased BigUInt and any required deallocation of the specified backing - array must occur externally after the aliased BigUInt is no longer in use. - - @param[in] bit_count The bit width - @param[in] value The backing array to use - @throws std::invalid_argument if bit_count is negative or value is null - */ - inline void alias(int bit_count, std::uint64_t *value) - { - if (bit_count < 0) - { - throw std::invalid_argument("bit_count must be non-negative"); - } - if (value == nullptr && bit_count > 0) - { - throw std::invalid_argument("value must be non-null for non-zero bit count"); - } - - // Deallocate any owned pointers. - reset(); - - // Update class. - value_ = util::Pointer::Aliasing(value); - bit_count_ = bit_count; - } -#ifdef SEAL_USE_MSGSL_SPAN - /** - Makes the BigUInt an aliased BigUInt with the given backing array - and bit width set equal to the size of the backing array. An aliased - BigUInt does not internally allocate or deallocate the backing array, - and instead uses the specified backing array for all read/write - operations. Note that resizing is not supported by an aliased BigUInt - and any required deallocation of the specified backing array must - occur externally after the aliased BigUInt is no longer in use. - - @param[in] value The backing array to use - @throws std::invalid_argument if value has too large size - */ - inline void alias(gsl::span value) - { - if(util::unsigned_gt(value.size(), std::numeric_limits::max())) - { - throw std::invalid_argument("value has too large size"); - } - - // Deallocate any owned pointers. - reset(); - - // Update class. - value_ = util::Pointer::Aliasing(value.data()); - bit_count_ = static_cast(value.size());; - } -#endif - /** - Resets an aliased BigUInt into an empty non-alias BigUInt with bit count - of zero. - - @throws std::logic_error if BigUInt is not an alias - */ - inline void unalias() - { - if (!value_.is_alias()) - { - throw std::logic_error("BigUInt is not an alias"); - } - - // Reset class. - reset(); - } - - /** - Overwrites the BigUInt with the value of the specified BigUInt, enlarging - if needed to fit the assigned value. Only significant bits are used to - size the BigUInt. - - @param[in] assign The BigUInt whose value should be assigned to the - current BigUInt - @throws std::logic_error if BigUInt is an alias and the assigned BigUInt is - too large to fit the current bit width - */ - BigUInt &operator =(const BigUInt &assign); - - /** - Overwrites the BigUInt with the unsigned hexadecimal value specified by - the string, enlarging if needed to fit the assigned value. The string must - match the format returned by to_string() and must consist of only the - characters 0-9, A-F, or a-f, most-significant nibble first. - - @param[in] hex_value The hexadecimal integer string specifying the value - to assign - @throws std::invalid_argument if hex_value does not adhere to the - expected format - @throws std::logic_error if BigUInt is an alias and the assigned value - is too large to fit the current bit width - */ - BigUInt &operator =(const std::string &hex_value); - - /** - Overwrites the BigUInt with the specified integer value, enlarging if - needed to fit the value. - - @param[in] value The value to assign - @throws std::logic_error if BigUInt is an alias and the significant bit - count of value is too large to fit the - current bit width - */ - inline BigUInt &operator =(std::uint64_t value) - { - int assign_bit_count = util::get_significant_bit_count(value); - if (assign_bit_count > bit_count_) - { - // Size is too large to currently fit, so resize. - resize(assign_bit_count); - } - if (bit_count_ > 0) - { - util::set_uint(value, uint64_count(), value_.get()); - } - return *this; - } - - /** - Returns a copy of the BigUInt value resized to the significant bit count. - */ - SEAL_NODISCARD inline BigUInt operator +() const - { - BigUInt result; - result = *this; - return result; - } - - /** - Returns a negated copy of the BigUInt value. The bit count does not change. - */ - SEAL_NODISCARD inline BigUInt operator -() const - { - BigUInt result(bit_count_); - util::negate_uint(value_.get(), result.uint64_count(), result.data()); - util::filter_highbits_uint(result.data(), result.uint64_count(), result.bit_count()); - return result; - } - - /** - Returns an inverted copy of the BigUInt value. The bit count does not change. - */ - SEAL_NODISCARD inline BigUInt operator ~() const - { - BigUInt result(bit_count_); - util::not_uint(value_.get(), result.uint64_count(), result.data()); - util::filter_highbits_uint(result.data(), result.uint64_count(), result.bit_count()); - return result; - } - - /** - Increments the BigUInt and returns the incremented value. The BigUInt will - increment the bit count if needed to fit the carry. - - @throws std::logic_error if BigUInt is an alias and a carry occurs requiring - the BigUInt to be resized - */ - inline BigUInt &operator ++() - { - if (util::increment_uint(value_.get(), uint64_count(), value_.get())) - { - resize(util::add_safe(bit_count_, 1)); - util::set_bit_uint(value_.get(), uint64_count(), bit_count_); - } - bit_count_ = std::max(bit_count_, significant_bit_count()); - return *this; - } - - /** - Decrements the BigUInt and returns the decremented value. The bit count - does not change. - */ - inline BigUInt &operator --() - { - util::decrement_uint(value_.get(), uint64_count(), value_.get()); - util::filter_highbits_uint(value_.get(), uint64_count(), bit_count_); - return *this; - } -#ifndef SEAL_USE_MAYBE_UNUSED -#if (SEAL_COMPILER == SEAL_COMPILER_GCC) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#elif (SEAL_COMPILER == SEAL_COMPILER_CLANG) -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wunused-parameter" -#endif -#endif - /** - Increments the BigUInt but returns its old value. The BigUInt will increment - the bit count if needed to fit the carry. - */ - inline BigUInt operator ++(int postfix SEAL_MAYBE_UNUSED) - { - BigUInt result; - result = *this; - if (util::increment_uint(value_.get(), uint64_count(), value_.get())) - { - resize(util::add_safe(bit_count_, 1)); - util::set_bit_uint(value_.get(), uint64_count(), bit_count_); - } - bit_count_ = std::max(bit_count_, significant_bit_count()); - return result; - } - - /** - Decrements the BigUInt but returns its old value. The bit count does not change. - */ - inline BigUInt operator --(int postfix SEAL_MAYBE_UNUSED) - { - BigUInt result; - result = *this; - util::decrement_uint(value_.get(), uint64_count(), value_.get()); - util::filter_highbits_uint(value_.get(), uint64_count(), bit_count_); - return result; - } -#ifndef SEAL_USE_MAYBE_UNUSED -#if (SEAL_COMPILER == SEAL_COMPILER_GCC) -#pragma GCC diagnostic pop -#elif (SEAL_COMPILER == SEAL_COMPILER_CLANG) -#pragma clang diagnostic pop -#endif -#endif - /** - Adds two BigUInts and returns the sum. The input operands are not modified. - The bit count of the sum is set to be one greater than the significant bit - count of the larger of the two input operands. - - @param[in] operand2 The second operand to add - */ - SEAL_NODISCARD inline BigUInt operator +(const BigUInt &operand2) const - { - int result_bits = util::add_safe(std::max(significant_bit_count(), - operand2.significant_bit_count()), 1); - BigUInt result(result_bits); - util::add_uint_uint(value_.get(), uint64_count(), operand2.data(), - operand2.uint64_count(), false, result.uint64_count(), result.data()); - return result; - } - - /** - Adds a BigUInt and an unsigned integer and returns the sum. The input - operands are not modified. The bit count of the sum is set to be one greater - than the significant bit count of the larger of the two operands. - - @param[in] operand2 The second operand to add - */ - SEAL_NODISCARD inline BigUInt operator +(std::uint64_t operand2) const - { - BigUInt operand2uint; - operand2uint = operand2; - return *this + operand2uint; - } - - /** - Subtracts two BigUInts and returns the difference. The input operands are - not modified. The bit count of the difference is set to be the significant - bit count of the larger of the two input operands. - - @param[in] operand2 The second operand to subtract - */ - SEAL_NODISCARD inline BigUInt operator -(const BigUInt &operand2) const - { - int result_bits = std::max(bit_count_, operand2.bit_count()); - BigUInt result(result_bits); - util::sub_uint_uint(value_.get(), uint64_count(), operand2.data(), - operand2.uint64_count(), false, result.uint64_count(), result.data()); - util::filter_highbits_uint(result.data(), result.uint64_count(), result_bits); - return result; - } - - /** - Subtracts a BigUInt and an unsigned integer and returns the difference. - The input operands are not modified. The bit count of the difference is set - to be the significant bit count of the larger of the two operands. - - @param[in] operand2 The second operand to subtract - */ - SEAL_NODISCARD inline BigUInt operator -(std::uint64_t operand2) const - { - BigUInt operand2uint; - operand2uint = operand2; - return *this - operand2uint; - } - - /** - Multiplies two BigUInts and returns the product. The input operands are - not modified. The bit count of the product is set to be the sum of the - significant bit counts of the two input operands. - - @param[in] operand2 The second operand to multiply - */ - SEAL_NODISCARD inline BigUInt operator *(const BigUInt &operand2) const - { - int result_bits = util::add_safe(significant_bit_count(), - operand2.significant_bit_count()); - BigUInt result(result_bits); - util::multiply_uint_uint(value_.get(), uint64_count(), operand2.data(), - operand2.uint64_count(), result.uint64_count(), result.data()); - return result; - } - - /** - Multiplies a BigUInt and an unsigned integer and returns the product. - The input operands are not modified. The bit count of the product is set - to be the sum of the significant bit counts of the two input operands. - - @param[in] operand2 The second operand to multiply - */ - SEAL_NODISCARD inline BigUInt operator *(std::uint64_t operand2) const - { - BigUInt operand2uint; - operand2uint = operand2; - return *this * operand2uint; - } - - /** - Divides two BigUInts and returns the quotient. The input operands are - not modified. The bit count of the quotient is set to be the significant - bit count of the first input operand. - - @param[in] operand2 The second operand to divide - @throws std::invalid_argument if operand2 is zero - */ - SEAL_NODISCARD BigUInt operator /(const BigUInt &operand2) const; - - /** - Divides a BigUInt and an unsigned integer and returns the quotient. The - input operands are not modified. The bit count of the quotient is set - to be the significant bit count of the first input operand. - - @param[in] operand2 The second operand to divide - @throws std::invalid_argument if operand2 is zero - */ - SEAL_NODISCARD inline BigUInt operator /(std::uint64_t operand2) const - { - BigUInt operand2uint; - operand2uint = operand2; - return *this / operand2uint; - } - - /** - Performs a bit-wise XOR operation between two BigUInts and returns the - result. The input operands are not modified. The bit count of the result - is set to the maximum of the two input operand bit counts. - - @param[in] operand2 The second operand to XOR - */ - SEAL_NODISCARD inline BigUInt operator ^(const BigUInt &operand2) const - { - int result_bits = std::max(bit_count_, operand2.bit_count()); - BigUInt result(result_bits); - std::size_t uint64_count = result.uint64_count(); - if (uint64_count != this->uint64_count()) - { - result = *this; - util::xor_uint_uint( - result.data(), operand2.data(), uint64_count, result.data()); - } - else if (uint64_count != operand2.uint64_count()) - { - result = operand2; - util::xor_uint_uint( - result.data(), value_.get(), uint64_count, result.data()); - } - else - { - util::xor_uint_uint( - value_.get(), operand2.data(), uint64_count, result.data()); - } - return result; - } - - /** - Performs a bit-wise XOR operation between a BigUInt and an unsigned - integer and returns the result. The input operands are not modified. - The bit count of the result is set to the maximum of the two input - operand bit counts. - - @param[in] operand2 The second operand to XOR - */ - SEAL_NODISCARD inline BigUInt operator ^(std::uint64_t operand2) const - { - BigUInt operand2uint; - operand2uint = operand2; - return *this ^ operand2uint; - } - - /** - Performs a bit-wise AND operation between two BigUInts and returns the - result. The input operands are not modified. The bit count of the result - is set to the maximum of the two input operand bit counts. - - @param[in] operand2 The second operand to AND - */ - SEAL_NODISCARD inline BigUInt operator &(const BigUInt &operand2) const - { - int result_bits = std::max(bit_count_, operand2.bit_count()); - BigUInt result(result_bits); - std::size_t uint64_count = result.uint64_count(); - if (uint64_count != this->uint64_count()) - { - result = *this; - util::and_uint_uint( - result.data(), operand2.data(), uint64_count, result.data()); - } - else if (uint64_count != operand2.uint64_count()) - { - result = operand2; - util::and_uint_uint( - result.data(), value_.get(), uint64_count, result.data()); - } - else - { - util::and_uint_uint( - value_.get(), operand2.data(), uint64_count, result.data()); - } - return result; - } - - /** - Performs a bit-wise AND operation between a BigUInt and an unsigned - integer and returns the result. The input operands are not modified. - The bit count of the result is set to the maximum of the two input - operand bit counts. - - @param[in] operand2 The second operand to AND - */ - SEAL_NODISCARD inline BigUInt operator &(std::uint64_t operand2) const - { - BigUInt operand2uint; - operand2uint = operand2; - return *this & operand2uint; - } - - /** - Performs a bit-wise OR operation between two BigUInts and returns the - result. The input operands are not modified. The bit count of the result - is set to the maximum of the two input operand bit counts. - - @param[in] operand2 The second operand to OR - */ - SEAL_NODISCARD inline BigUInt operator |(const BigUInt &operand2) const - { - int result_bits = std::max(bit_count_, operand2.bit_count()); - BigUInt result(result_bits); - std::size_t uint64_count = result.uint64_count(); - if (uint64_count != this->uint64_count()) - { - result = *this; - util::or_uint_uint( - result.data(), operand2.data(), uint64_count, result.data()); - } - else if (uint64_count != operand2.uint64_count()) - { - result = operand2; - util::or_uint_uint( - result.data(), value_.get(), uint64_count, result.data()); - } - else - { - util::or_uint_uint( - value_.get(), operand2.data(), uint64_count, result.data()); - } - return result; - } - - /** - Performs a bit-wise OR operation between a BigUInt and an unsigned - integer and returns the result. The input operands are not modified. - The bit count of the result is set to the maximum of the two input - operand bit counts. - - @param[in] operand2 The second operand to OR - */ - SEAL_NODISCARD inline BigUInt operator |(std::uint64_t operand2) const - { - BigUInt operand2uint; - operand2uint = operand2; - return *this | operand2uint; - } - - /** - Compares two BigUInts and returns -1, 0, or 1 if the BigUInt is - less-than, equal-to, or greater-than the second operand respectively. - The input operands are not modified. - - @param[in] compare The value to compare against - */ - SEAL_NODISCARD inline int compareto(const BigUInt &compare) const - { - return util::compare_uint_uint(value_.get(), uint64_count(), - compare.value_.get(), compare.uint64_count()); - } - - /** - Compares a BigUInt and an unsigned integer and returns -1, 0, or 1 if - the BigUInt is less-than, equal-to, or greater-than the second operand - respectively. The input operands are not modified. - - @param[in] compare The value to compare against - */ - SEAL_NODISCARD inline int compareto(std::uint64_t compare) const - { - BigUInt compareuint; - compareuint = compare; - return compareto(compareuint); - } - - /** - Returns whether or not a BigUInt is less-than a second BigUInt. The - input operands are not modified. - - @param[in] operand2 The value to compare against - */ - SEAL_NODISCARD inline bool operator <(const BigUInt &compare) const - { - return util::compare_uint_uint(value_.get(), uint64_count(), - compare.value_.get(), compare.uint64_count()) < 0; - } - - /** - Returns whether or not a BigUInt is less-than an unsigned integer. - The input operands are not modified. - - @param[in] operand2 The value to compare against - */ - SEAL_NODISCARD inline bool operator <(std::uint64_t compare) const - { - BigUInt compareuint; - compareuint = compare; - return *this < compareuint; - } - - /** - Returns whether or not a BigUInt is greater-than a second BigUInt. - The input operands are not modified. - - @param[in] operand2 The value to compare against - */ - SEAL_NODISCARD inline bool operator >(const BigUInt &compare) const - { - return util::compare_uint_uint(value_.get(), uint64_count(), - compare.value_.get(), compare.uint64_count()) > 0; - } - - /** - Returns whether or not a BigUInt is greater-than an unsigned integer. - The input operands are not modified. - - @param[in] operand2 The value to compare against - */ - SEAL_NODISCARD inline bool operator >(std::uint64_t compare) const - { - BigUInt compareuint; - compareuint = compare; - return *this > compareuint; - } - - /** - Returns whether or not a BigUInt is less-than or equal to a second - BigUInt. The input operands are not modified. - - @param[in] operand2 The value to compare against - */ - SEAL_NODISCARD inline bool operator <=(const BigUInt &compare) const - { - return util::compare_uint_uint(value_.get(), uint64_count(), - compare.value_.get(), compare.uint64_count()) <= 0; - } - - /** - Returns whether or not a BigUInt is less-than or equal to an unsigned - integer. The input operands are not modified. - - @param[in] operand2 The value to compare against - */ - SEAL_NODISCARD inline bool operator <=(std::uint64_t compare) const - { - BigUInt compareuint; - compareuint = compare; - return *this <= compareuint; - } - - /** - Returns whether or not a BigUInt is greater-than or equal to a second - BigUInt. The input operands are not modified. - - @param[in] operand2 The value to compare against - */ - SEAL_NODISCARD inline bool operator >=(const BigUInt &compare) const - { - return util::compare_uint_uint(value_.get(), uint64_count(), - compare.value_.get(), compare.uint64_count()) >= 0; - } - - /** - Returns whether or not a BigUInt is greater-than or equal to an unsigned - integer. The input operands are not modified. - - @param[in] operand2 The value to compare against - */ - SEAL_NODISCARD inline bool operator >=(std::uint64_t compare) const - { - BigUInt compareuint; - compareuint = compare; - return *this >= compareuint; - } - - /** - Returns whether or not a BigUInt is equal to a second BigUInt. - The input operands are not modified. - - @param[in] compare The value to compare against - */ - SEAL_NODISCARD inline bool operator ==(const BigUInt &compare) const - { - return util::compare_uint_uint(value_.get(), uint64_count(), - compare.value_.get(), compare.uint64_count()) == 0; - } - - /** - Returns whether or not a BigUInt is equal to an unsigned integer. - The input operands are not modified. - - @param[in] compare The value to compare against - */ - SEAL_NODISCARD inline bool operator ==(std::uint64_t compare) const - { - BigUInt compareuint; - compareuint = compare; - return *this == compareuint; - } - - /** - Returns whether or not a BigUInt is not equal to a second BigUInt. - The input operands are not modified. - - @param[in] compare The value to compare against - */ - SEAL_NODISCARD inline bool operator !=(const BigUInt &compare) const - { - return !(operator ==(compare)); - } - - /** - Returns whether or not a BigUInt is not equal to an unsigned integer. - The input operands are not modified. - - @param[in] operand2 The value to compare against - */ - SEAL_NODISCARD inline bool operator !=(std::uint64_t compare) const - { - BigUInt compareuint; - compareuint = compare; - return *this != compareuint; - } - - /** - Returns a left-shifted copy of the BigUInt. The bit count of the - returned value is the sum of the original significant bit count and - the shift amount. - - @param[in] shift The number of bits to shift by - @throws std::invalid_argument if shift is negative - */ - SEAL_NODISCARD inline BigUInt operator <<(int shift) const - { - if (shift < 0) - { - throw std::invalid_argument("shift must be non-negative"); - } - int result_bits = util::add_safe(significant_bit_count(), shift); - BigUInt result(result_bits); - result = *this; - util::left_shift_uint( - result.data(), shift, result.uint64_count(), result.data()); - return result; - } - - /** - Returns a right-shifted copy of the BigUInt. The bit count of the - returned value is the original significant bit count subtracted by - the shift amount (clipped to zero if negative). - - @param[in] shift The number of bits to shift by - @throws std::invalid_argument if shift is negative - */ - SEAL_NODISCARD inline BigUInt operator >>(int shift) const - { - if (shift < 0) - { - throw std::invalid_argument("shift must be non-negative"); - } - int result_bits = util::sub_safe(significant_bit_count(), shift); - if (result_bits <= 0) - { - BigUInt zero; - return zero; - } - BigUInt result(result_bits); - result = *this; - util::right_shift_uint( - result.data(), shift, result.uint64_count(), result.data()); - return result; - } - - /** - Adds two BigUInts saving the sum to the first operand, returning - a reference of the first operand. The second input operand is not - modified. The first operand is resized if and only if its bit count - is smaller than one greater than the significant bit count of the - larger of the two input operands. - - @param[in] operand2 The second operand to add - @throws std::logic_error if the BigUInt is an alias and the operator - attempts to enlarge the BigUInt to fit the result - */ - inline BigUInt &operator +=(const BigUInt &operand2) - { - int result_bits = util::add_safe(std::max( - significant_bit_count(), operand2.significant_bit_count()), 1); - if (bit_count_ < result_bits) - { - resize(result_bits); - } - util::add_uint_uint(value_.get(), uint64_count(), operand2.data(), - operand2.uint64_count(), false, uint64_count(), value_.get()); - return *this; - } - - /** - Adds a BigUInt and an unsigned integer saving the sum to the first operand, - returning a reference of the first operand. The second input operand is not - modified. The first operand is resized if and only if its bit count is - smaller than one greater than the significant bit count of the larger of - the two input operands. - - @param[in] operand2 The second operand to add - @throws std::logic_error if the BigUInt is an alias and the operator - attempts to enlarge the BigUInt to fit the result - */ - inline BigUInt &operator +=(std::uint64_t operand2) - { - BigUInt operand2uint; - operand2uint = operand2; - return operator +=(operand2uint); - } - - /** - Subtracts two BigUInts saving the difference to the first operand, - returning a reference of the first operand. The second input operand is - not modified. The first operand is resized if and only if its bit count - is smaller than the significant bit count of the second operand. - - @param[in] operand2 The second operand to subtract - @throws std::logic_error if the BigUInt is an alias and the operator - attempts to enlarge the BigUInt to fit the result - */ - inline BigUInt &operator -=(const BigUInt &operand2) - { - int result_bits = std::max(bit_count_, operand2.bit_count()); - if (bit_count_ < result_bits) - { - resize(result_bits); - } - util::sub_uint_uint(value_.get(), uint64_count(), operand2.data(), - operand2.uint64_count(), false, uint64_count(), value_.get()); - util::filter_highbits_uint(value_.get(), uint64_count(), result_bits); - return *this; - } - - /** - Subtracts a BigUInt and an unsigned integer saving the difference to - the first operand, returning a reference of the first operand. The second - input operand is not modified. The first operand is resized if and only - if its bit count is smaller than the significant bit count of the second - operand. - - @param[in] operand2 The second operand to subtract - @throws std::logic_error if the BigUInt is an alias and the operator - attempts to enlarge the BigUInt to fit the result - */ - inline BigUInt &operator -=(std::uint64_t operand2) - { - BigUInt operand2uint; - operand2uint = operand2; - return operator -=(operand2uint); - } - - /** - Multiplies two BigUInts saving the product to the first operand, - returning a reference of the first operand. The second input operand - is not modified. The first operand is resized if and only if its bit - count is smaller than the sum of the significant bit counts of the two - input operands. - - @param[in] operand2 The second operand to multiply - @throws std::logic_error if the BigUInt is an alias and the operator - attempts to enlarge the BigUInt to fit the result - */ - inline BigUInt &operator *=(const BigUInt &operand2) - { - *this = *this * operand2; - return *this; - } - - /** - Multiplies a BigUInt and an unsigned integer saving the product to - the first operand, returning a reference of the first operand. The - second input operand is not modified. The first operand is resized if - and only if its bit count is smaller than the sum of the significant - bit counts of the two input operands. - - @param[in] operand2 The second operand to multiply - @throws std::logic_error if the BigUInt is an alias and the operator - attempts to enlarge the BigUInt to fit the result - */ - inline BigUInt &operator *=(std::uint64_t operand2) - { - BigUInt operand2uint; - operand2uint = operand2; - return operator *=(operand2uint); - } - - /** - Divides two BigUInts saving the quotient to the first operand, - returning a reference of the first operand. The second input operand - is not modified. The first operand is never resized. - - @param[in] operand2 The second operand to divide - @throws std::invalid_argument if operand2 is zero - */ - inline BigUInt &operator /=(const BigUInt &operand2) - { - *this = *this / operand2; - return *this; - } - - /** - Divides a BigUInt and an unsigned integer saving the quotient to - the first operand, returning a reference of the first operand. The - second input operand is not modified. The first operand is never resized. - - @param[in] operand2 The second operand to divide - @throws std::invalid_argument if operand2 is zero - */ - inline BigUInt &operator /=(std::uint64_t operand2) - { - BigUInt operand2uint; - operand2uint = operand2; - return operator /=(operand2uint); - } - - /** - Performs a bit-wise XOR operation between two BigUInts saving the result - to the first operand, returning a reference of the first operand. The - second input operand is not modified. The first operand is resized if - and only if its bit count is smaller than the significant bit count of - the second operand. - - @param[in] operand2 The second operand to XOR - @throws std::logic_error if the BigUInt is an alias and the operator - attempts to enlarge the BigUInt to fit the result - */ - inline BigUInt &operator ^=(const BigUInt &operand2) - { - int result_bits = std::max(bit_count_, operand2.bit_count()); - if (bit_count_ != result_bits) - { - resize(result_bits); - } - util::xor_uint_uint( - value_.get(), operand2.data(), operand2.uint64_count(), value_.get()); - return *this; - } - - /** - Performs a bit-wise XOR operation between a BigUInt and an unsigned integer - saving the result to the first operand, returning a reference of the first - operand. The second input operand is not modified. The first operand is - resized if and only if its bit count is smaller than the significant bit - count of the second operand. - - @param[in] operand2 The second operand to XOR - @throws std::logic_error if the BigUInt is an alias and the operator - attempts to enlarge the BigUInt to fit the result - */ - inline BigUInt &operator ^=(std::uint64_t operand2) - { - BigUInt operand2uint; - operand2uint = operand2; - return operator ^=(operand2uint); - } - - /** - Performs a bit-wise AND operation between two BigUInts saving the result - to the first operand, returning a reference of the first operand. The - second input operand is not modified. The first operand is resized if - and only if its bit count is smaller than the significant bit count of - the second operand. - - @param[in] operand2 The second operand to AND - @throws std::logic_error if the BigUInt is an alias and the operator - attempts to enlarge the BigUInt to fit the result - */ - inline BigUInt &operator &=(const BigUInt &operand2) - { - int result_bits = std::max(bit_count_, operand2.bit_count()); - if (bit_count_ != result_bits) - { - resize(result_bits); - } - util::and_uint_uint( - value_.get(), operand2.data(), operand2.uint64_count(), value_.get()); - return *this; - } - - /** - Performs a bit-wise AND operation between a BigUInt and an unsigned integer - saving the result to the first operand, returning a reference of the first - operand. The second input operand is not modified. The first operand is - resized if and only if its bit count is smaller than the significant bit - count of the second operand. - - @param[in] operand2 The second operand to AND - @throws std::logic_error if the BigUInt is an alias and the operator - attempts to enlarge the BigUInt to fit the result - */ - inline BigUInt &operator &=(std::uint64_t operand2) - { - BigUInt operand2uint; - operand2uint = operand2; - return operator &=(operand2uint); - } - - /** - Performs a bit-wise OR operation between two BigUInts saving the result to - the first operand, returning a reference of the first operand. The second - input operand is not modified. The first operand is resized if and only if - its bit count is smaller than the significant bit count of the second - operand. - - @param[in] operand2 The second operand to OR - @throws std::logic_error if the BigUInt is an alias and the operator - attempts to enlarge the BigUInt to fit the result - */ - inline BigUInt &operator |=(const BigUInt &operand2) - { - int result_bits = std::max(bit_count_, operand2.bit_count()); - if (bit_count_ != result_bits) - { - resize(result_bits); - } - util::or_uint_uint(value_.get(), operand2.data(), - operand2.uint64_count(), value_.get()); - return *this; - } - - /** - Performs a bit-wise OR operation between a BigUInt and an unsigned integer - saving the result to the first operand, returning a reference of the first - operand. The second input operand is not modified. The first operand is - resized if and only if its bit count is smaller than the significant bit - count of the second operand. - - @param[in] operand2 The second operand to OR - @throws std::logic_error if the BigUInt is an alias and the operator - attempts to enlarge the BigUInt to fit the result - */ - inline BigUInt &operator |=(std::uint64_t operand2) - { - BigUInt operand2uint; - operand2uint = operand2; - return operator |=(operand2uint); - } - - /** - Left-shifts a BigUInt by the specified amount. The BigUInt is resized if - and only if its bit count is smaller than the sum of its significant bit - count and the shift amount. - - @param[in] shift The number of bits to shift by - @throws std::Invalid_argument if shift is negative - @throws std::logic_error if the BigUInt is an alias and the operator - attempts to enlarge the BigUInt to fit the result - */ - inline BigUInt &operator <<=(int shift) - { - if (shift < 0) - { - throw std::invalid_argument("shift must be non-negative"); - } - int result_bits = util::add_safe(significant_bit_count(), shift); - if (bit_count_ < result_bits) - { - resize(result_bits); - } - util::left_shift_uint(value_.get(), shift, uint64_count(), value_.get()); - return *this; - } - - /** - Right-shifts a BigUInt by the specified amount. The BigUInt is never - resized. - - @param[in] shift The number of bits to shift by - @throws std::Invalid_argument if shift is negative - */ - inline BigUInt &operator >>=(int shift) - { - if (shift < 0) - { - throw std::invalid_argument("shift must be non-negative"); - } - if (shift > bit_count_) - { - set_zero(); - return *this; - } - util::right_shift_uint(value_.get(), shift, uint64_count(), value_.get()); - return *this; - } - - /** - Divides two BigUInts and returns the quotient and sets the remainder - parameter to the remainder. The bit count of the quotient is set to be - the significant bit count of the BigUInt. The remainder is resized if - and only if it is smaller than the bit count of the BigUInt. - - @param[in] operand2 The second operand to divide - @param[out] remainder The BigUInt to store the remainder - @throws std::Invalid_argument if operand2 is zero - @throws std::logic_error if the remainder is an alias and the operator - attempts to enlarge the BigUInt to fit the result - */ - BigUInt divrem(const BigUInt &operand2, BigUInt &remainder) const; - - /** - Divides a BigUInt and an unsigned integer and returns the quotient and - sets the remainder parameter to the remainder. The bit count of the - quotient is set to be the significant bit count of the BigUInt. The - remainder is resized if and only if it is smaller than the bit count - of the BigUInt. - - @param[in] operand2 The second operand to divide - @param[out] remainder The BigUInt to store the remainder - @throws std::Invalid_argument if operand2 is zero - @throws std::logic_error if the remainder is an alias which the - function attempts to enlarge to fit the result - */ - inline BigUInt divrem(std::uint64_t operand2, BigUInt &remainder) const - { - BigUInt operand2uint; - operand2uint = operand2; - return divrem(operand2uint, remainder); - } - - /** - Returns the inverse of a BigUInt with respect to the specified modulus. - The original BigUInt is not modified. The bit count of the inverse is - set to be the significant bit count of the modulus. - - @param[in] modulus The modulus to calculate the inverse with respect to - @throws std::Invalid_argument if modulus is zero - @throws std::Invalid_argument if modulus is not greater than the BigUInt value - @throws std::Invalid_argument if the BigUInt value and modulus are not co-prime - */ - SEAL_NODISCARD inline BigUInt modinv(const BigUInt &modulus) const - { - if (modulus.is_zero()) - { - throw std::invalid_argument("modulus must be positive"); - } - int result_bits = modulus.significant_bit_count(); - if (*this >= modulus) - { - throw std::invalid_argument("modulus must be greater than BigUInt"); - } - BigUInt result(result_bits); - result = *this; - if (!util::try_invert_uint_mod(result.data(), modulus.data(), - result.uint64_count(), result.data(), pool_)) - { - throw std::invalid_argument("BigUInt and modulus are not co-prime"); - } - return result; - } - - /** - Returns the inverse of a BigUInt with respect to the specified modulus. - The original BigUInt is not modified. The bit count of the inverse is set - to be the significant bit count of the modulus. - - @param[in] modulus The modulus to calculate the inverse with respect to - @throws std::Invalid_argument if modulus is zero - @throws std::Invalid_argument if modulus is not greater than the BigUInt value - @throws std::Invalid_argument if the BigUInt value and modulus are not co-prime - */ - SEAL_NODISCARD inline BigUInt modinv(std::uint64_t modulus) const - { - BigUInt modulusuint; - modulusuint = modulus; - return modinv(modulusuint); - } - - /** - Attempts to calculate the inverse of a BigUInt with respect to the - specified modulus, returning whether or not the inverse was successful - and setting the inverse parameter to the inverse. The original BigUInt - is not modified. The inverse parameter is resized if and only if its bit - count is smaller than the significant bit count of the modulus. - - @param[in] modulus The modulus to calculate the inverse with respect to - @param[out] inverse Stores the inverse if the inverse operation was - successful - @throws std::Invalid_argument if modulus is zero - @throws std::Invalid_argument if modulus is not greater than the BigUInt - value - @throws std::logic_error if the inverse is an alias which the function - attempts to enlarge to fit the result - */ - inline bool trymodinv(const BigUInt &modulus, BigUInt &inverse) const - { - if (modulus.is_zero()) - { - throw std::invalid_argument("modulus must be positive"); - } - int result_bits = modulus.significant_bit_count(); - if (*this >= modulus) - { - throw std::invalid_argument("modulus must be greater than BigUInt"); - } - if (inverse.bit_count() < result_bits) - { - inverse.resize(result_bits); - } - inverse = *this; - return util::try_invert_uint_mod(inverse.data(), modulus.data(), - inverse.uint64_count(), inverse.data(), pool_); - } - - /** - Attempts to calculate the inverse of a BigUInt with respect to the - specified modulus, returning whether or not the inverse was successful - and setting the inverse parameter to the inverse. The original BigUInt - is not modified. The inverse parameter is resized if and only if its - bit count is smaller than the significant bit count of the modulus. - - @param[in] modulus The modulus to calculate the inverse with respect to - @param[out] inverse Stores the inverse if the inverse operation was - successful - @throws std::Invalid_argument if modulus is zero - @throws std::Invalid_argument if modulus is not greater than the BigUInt - value - @throws std::logic_error if the inverse is an alias which the function - attempts to enlarge to fit the result - */ - inline bool trymodinv(std::uint64_t modulus, BigUInt &inverse) const - { - BigUInt modulusuint; - modulusuint = modulus; - return trymodinv(modulusuint, inverse); - } - - /** - Saves the BigUInt to an output stream. The full state of the BigUInt is - serialized, including insignificant bits. The output is in binary format - and not human-readable. The output stream must have the "binary" flag set. - - @param[in] stream The stream to save the BigUInt to - @throws std::exception if the BigUInt could not be written to stream - */ - void save(std::ostream &stream) const; - - /** - Loads a BigUInt from an input stream overwriting the current BigUInt - and enlarging if needed to fit the loaded BigUInt. - - @param[in] stream The stream to load the BigUInt from - @throws std::logic_error if BigUInt is an alias and the loaded BigUInt - is too large to fit with the current bit - @throws std::exception if a valid BigUInt could not be read from stream - */ - void load(std::istream &stream); - - /** - Creates a minimally sized BigUInt initialized to the specified unsigned - integer value. - - @param[in] value The value to initialized the BigUInt to - */ - SEAL_NODISCARD inline static BigUInt of(std::uint64_t value) - { - BigUInt result; - result = value; - return result; - } - - /** - Duplicates the current BigUInt. The bit count and the value of the - given BigUInt are set to be exactly the same as in the current one. - - @param[out] destination The BigUInt to overwrite with the duplicate - @throws std::logic_error if the destination BigUInt is an alias - */ - inline void duplicate_to(BigUInt &destination) const - { - destination.resize(this->bit_count_); - destination = *this; - } - - /** - Duplicates a given BigUInt. The bit count and the value of the current - BigUInt are set to be exactly the same as in the given one. - - @param[in] value The BigUInt to duplicate - @throws std::logic_error if the current BigUInt is an alias - */ - inline void duplicate_from(const BigUInt &value) - { - this->resize(value.bit_count_); - *this = value; - } - - private: - MemoryPoolHandle pool_; - - /** - Resets the entire state of the BigUInt to an empty, zero-sized state, - freeing any memory it internally allocated. If the BigUInt was an alias, - the backing array is not freed but the alias is no longer referenced. - */ - inline void reset() noexcept - { - value_.release(); - bit_count_ = 0; - } - - /** - Points to the backing array for the BigUInt. This pointer will be set - to nullptr if and only if the bit count is zero. This pointer is - automatically allocated and freed by the BigUInt if and only if - the BigUInt is not an alias. If the BigUInt is an alias, then the - pointer was passed-in to a constructor or alias() call, and will not be - deallocated by the BigUInt. - - @see BigUInt for more information about aliased BigUInts or the format - of the backing array. - */ - util::Pointer value_; - - /** - The bit count for the BigUInt. - */ - int bit_count_ = 0; - }; -} diff --git a/SEAL/native/src/seal/ciphertext.cpp b/SEAL/native/src/seal/ciphertext.cpp deleted file mode 100644 index 73c514d..0000000 --- a/SEAL/native/src/seal/ciphertext.cpp +++ /dev/null @@ -1,211 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "seal/ciphertext.h" -#include "seal/util/polycore.h" - -using namespace std; -using namespace seal::util; - -namespace seal -{ - Ciphertext &Ciphertext::operator =(const Ciphertext &assign) - { - // Check for self-assignment - if (this == &assign) - { - return *this; - } - - // Copy over fields - parms_id_ = assign.parms_id_; - is_ntt_form_ = assign.is_ntt_form_; - scale_ = assign.scale_; - - // Then resize - resize_internal(assign.size_, assign.poly_modulus_degree_, - assign.coeff_mod_count_); - - // Size is guaranteed to be OK now so copy over - copy(assign.data_.cbegin(), assign.data_.cend(), data_.begin()); - - return *this; - } - - void Ciphertext::reserve(shared_ptr context, - parms_id_type parms_id, size_type size_capacity) - { - // Verify parameters - if (!context) - { - throw invalid_argument("invalid context"); - } - if (!context->parameters_set()) - { - throw invalid_argument("encryption parameters are not set correctly"); - } - - auto context_data_ptr = context->get_context_data(parms_id); - if (!context_data_ptr) - { - throw invalid_argument("parms_id is not valid for encryption parameters"); - } - - // Need to set parms_id first - auto &parms = context_data_ptr->parms(); - parms_id_ = context_data_ptr->parms_id(); - - reserve_internal(size_capacity, parms.poly_modulus_degree(), - safe_cast(parms.coeff_modulus().size())); - } - - void Ciphertext::reserve_internal(size_type size_capacity, - size_type poly_modulus_degree, size_type coeff_mod_count) - { - if (size_capacity < SEAL_CIPHERTEXT_SIZE_MIN || - size_capacity > SEAL_CIPHERTEXT_SIZE_MAX) - { - throw invalid_argument("invalid size_capacity"); - } - - size_type new_data_capacity = - mul_safe(size_capacity, poly_modulus_degree, coeff_mod_count); - size_type new_data_size = min(new_data_capacity, data_.size()); - - // First reserve, then resize - data_.reserve(new_data_capacity); - data_.resize(new_data_size); - - // Set the size - size_ = min(size_capacity, size_); - poly_modulus_degree_ = poly_modulus_degree; - coeff_mod_count_ = coeff_mod_count; - } - - void Ciphertext::resize(shared_ptr context, - parms_id_type parms_id, size_type size) - { - // Verify parameters - if (!context) - { - throw invalid_argument("invalid context"); - } - if (!context->parameters_set()) - { - throw invalid_argument("encryption parameters are not set correctly"); - } - - auto context_data_ptr = context->get_context_data(parms_id); - if (!context_data_ptr) - { - throw invalid_argument("parms_id is not valid for encryption parameters"); - } - - // Need to set parms_id first - auto &parms = context_data_ptr->parms(); - parms_id_ = context_data_ptr->parms_id(); - - resize_internal(size, parms.poly_modulus_degree(), - safe_cast(parms.coeff_modulus().size())); - } - - void Ciphertext::resize_internal(size_type size, - size_type poly_modulus_degree, size_type coeff_mod_count) - { - if ((size < SEAL_CIPHERTEXT_SIZE_MIN && size != 0) || - size > SEAL_CIPHERTEXT_SIZE_MAX) - { - throw invalid_argument("invalid size"); - } - - // Resize the data - size_type new_data_size = - mul_safe(size, poly_modulus_degree, coeff_mod_count); - data_.resize(new_data_size); - - // Set the size parameters - size_ = size; - poly_modulus_degree_ = poly_modulus_degree; - coeff_mod_count_ = coeff_mod_count; - } - - void Ciphertext::save(ostream &stream) const - { - auto old_except_mask = stream.exceptions(); - try - { - // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit - stream.exceptions(ios_base::badbit | ios_base::failbit); - - stream.write(reinterpret_cast(&parms_id_), sizeof(parms_id_type)); - SEAL_BYTE is_ntt_form_byte = static_cast(is_ntt_form_); - stream.write(reinterpret_cast(&is_ntt_form_byte), sizeof(SEAL_BYTE)); - uint64_t size64 = safe_cast(size_); - stream.write(reinterpret_cast(&size64), sizeof(uint64_t)); - uint64_t poly_modulus_degree64 = safe_cast(poly_modulus_degree_); - stream.write(reinterpret_cast(&poly_modulus_degree64), sizeof(uint64_t)); - uint64_t coeff_mod_count64 = safe_cast(coeff_mod_count_); - stream.write(reinterpret_cast(&coeff_mod_count64), sizeof(uint64_t)); - stream.write(reinterpret_cast(&scale_), sizeof(double)); - - // Save the data - data_.save(stream); - } - catch (const exception &) - { - stream.exceptions(old_except_mask); - throw; - } - - stream.exceptions(old_except_mask); - } - - void Ciphertext::unsafe_load(istream &stream) - { - Ciphertext new_data(data_.pool()); - - auto old_except_mask = stream.exceptions(); - try - { - // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit - stream.exceptions(ios_base::badbit | ios_base::failbit); - - parms_id_type parms_id{}; - stream.read(reinterpret_cast(&parms_id), sizeof(parms_id_type)); - SEAL_BYTE is_ntt_form_byte; - stream.read(reinterpret_cast(&is_ntt_form_byte), sizeof(SEAL_BYTE)); - uint64_t size64 = 0; - stream.read(reinterpret_cast(&size64), sizeof(uint64_t)); - uint64_t poly_modulus_degree64 = 0; - stream.read(reinterpret_cast(&poly_modulus_degree64), sizeof(uint64_t)); - uint64_t coeff_mod_count64 = 0; - stream.read(reinterpret_cast(&coeff_mod_count64), sizeof(uint64_t)); - double scale = 0; - stream.read(reinterpret_cast(&scale), sizeof(double)); - - // Load the data - new_data.data_.load(stream); - if (unsigned_neq(new_data.data_.size(), - mul_safe(size64, poly_modulus_degree64, coeff_mod_count64))) - { - throw invalid_argument("ciphertext data is invalid"); - } - - // Set values - new_data.parms_id_ = parms_id; - new_data.is_ntt_form_ = (is_ntt_form_byte == SEAL_BYTE(0)) ? false : true; - new_data.size_ = safe_cast(size64); - new_data.poly_modulus_degree_ = safe_cast(poly_modulus_degree64); - new_data.coeff_mod_count_ = safe_cast(coeff_mod_count64); - new_data.scale_ = scale; - } - catch (const exception &) - { - stream.exceptions(old_except_mask); - throw; - } - stream.exceptions(old_except_mask); - - swap(*this, new_data); - } -} \ No newline at end of file diff --git a/SEAL/native/src/seal/ciphertext.h b/SEAL/native/src/seal/ciphertext.h deleted file mode 100644 index 1542063..0000000 --- a/SEAL/native/src/seal/ciphertext.h +++ /dev/null @@ -1,710 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include "seal/util/defines.h" -#include "seal/context.h" -#include "seal/memorymanager.h" -#include "seal/intarray.h" -#include "seal/valcheck.h" - -namespace seal -{ - /** - Class to store a ciphertext element. The data for a ciphertext consists - of two or more polynomials, which are in Microsoft SEAL stored in a CRT - form with respect to the factors of the coefficient modulus. This data - itself is not meant to be modified directly by the user, but is instead - operated on by functions in the Evaluator class. The size of the backing - array of a ciphertext depends on the encryption parameters and the size - of theciphertext (at least 2). If the degree of the poly_modulus encryption - parameter is N, and the number of primes in the coeff_modulus encryption - parameter is K, then the ciphertext backing array requires precisely - 8*N*K*size bytes of memory. A ciphertext also carries with it the - parms_id of its associated encryption parameters, which is used to check - the validity of the ciphertext for homomorphic operations and decryption. - - @par Memory Management - The size of a ciphertext refers to the number of polynomials it contains, - whereas its capacity refers to the number of polynomials that fit in the - current memory allocation. In high-performance applications unnecessary - re-allocations should be avoided by reserving enough memory for the - ciphertext to begin with either by providing the desired capacity to the - constructor as an extra argument, or by calling the reserve function at - any time. - - @par Thread Safety - In general, reading from ciphertext is thread-safe as long as no other - thread is concurrently mutating it. This is due to the underlying data - structure storing the ciphertext not being thread-safe. - - @see Plaintext for the class that stores plaintexts. - */ - class Ciphertext - { - public: - using ct_coeff_type = std::uint64_t; - - using size_type = IntArray::size_type; - - /** - Constructs an empty ciphertext allocating no memory. - - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if pool is uninitialized - */ - Ciphertext(MemoryPoolHandle pool = MemoryManager::GetPool()) : - data_(std::move(pool)) - { - } - - /** - Constructs an empty ciphertext with capacity 2. In addition to the - capacity, the allocation size is determined by the highest-level - parameters associated to the given SEALContext. - - @param[in] context The SEALContext - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if the context is not set or encryption - parameters are not valid - @throws std::invalid_argument if pool is uninitialized - */ - explicit Ciphertext(std::shared_ptr context, - MemoryPoolHandle pool = MemoryManager::GetPool()) : - data_(std::move(pool)) - { - // Allocate memory but don't resize - reserve(std::move(context), 2); - } - - /** - Constructs an empty ciphertext with capacity 2. In addition to the - capacity, the allocation size is determined by the encryption parameters - with given parms_id. - - @param[in] context The SEALContext - @param[in] parms_id The parms_id corresponding to the encryption - parameters to be used - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if the context is not set or encryption - parameters are not valid - @throws std::invalid_argument if parms_id is not valid for the encryption - parameters - @throws std::invalid_argument if pool is uninitialized - */ - explicit Ciphertext(std::shared_ptr context, - parms_id_type parms_id, - MemoryPoolHandle pool = MemoryManager::GetPool()) : - data_(std::move(pool)) - { - // Allocate memory but don't resize - reserve(std::move(context), parms_id, 2); - } - - /** - Constructs an empty ciphertext with given capacity. In addition to - the capacity, the allocation size is determined by the given - encryption parameters. - - @param[in] context The SEALContext - @param[in] parms_id The parms_id corresponding to the encryption - parameters to be used - @param[in] size_capacity The capacity - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if the context is not set or encryption - parameters are not valid - @throws std::invalid_argument if parms_id is not valid for the encryption - parameters - @throws std::invalid_argument if size_capacity is less than 2 or too large - @throws std::invalid_argument if pool is uninitialized - */ - explicit Ciphertext(std::shared_ptr context, - parms_id_type parms_id, size_type size_capacity, - MemoryPoolHandle pool = MemoryManager::GetPool()) : - data_(std::move(pool)) - { - // Allocate memory but don't resize - reserve(std::move(context), parms_id, size_capacity); - } - - /** - Constructs a new ciphertext by copying a given one. - - @param[in] copy The ciphertext to copy from - */ - Ciphertext(const Ciphertext ©) = default; - - /** - Creates a new ciphertext by moving a given one. - - @param[in] source The ciphertext to move from - */ - Ciphertext(Ciphertext &&source) = default; - - /** - Constructs a new ciphertext by copying a given one. - - @param[in] copy The ciphertext to copy from - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if pool is uninitialized - */ - Ciphertext(const Ciphertext ©, MemoryPoolHandle pool) : - Ciphertext(std::move(pool)) - { - *this = copy; - } - - /** - Allocates enough memory to accommodate the backing array of a ciphertext - with given capacity. In addition to the capacity, the allocation size is - determined by the encryption parameters corresponing to the given - parms_id. - - @param[in] context The SEALContext - @param[in] parms_id The parms_id corresponding to the encryption - parameters to be used - @param[in] size_capacity The capacity - @throws std::invalid_argument if the context is not set or encryption - parameters are not valid - @throws std::invalid_argument if parms_id is not valid for the encryption - parameters - @throws std::invalid_argument if size_capacity is less than 2 or too large - */ - void reserve(std::shared_ptr context, - parms_id_type parms_id, size_type size_capacity); - - /** - Allocates enough memory to accommodate the backing array of a ciphertext - with given capacity. In addition to the capacity, the allocation size is - determined by the highest-level parameters associated to the given - SEALContext. - - @param[in] context The SEALContext - @param[in] size_capacity The capacity - @throws std::invalid_argument if the context is not set or encryption - parameters are not valid - @throws std::invalid_argument if size_capacity is less than 2 or too large - */ - inline void reserve(std::shared_ptr context, - size_type size_capacity) - { - // Verify parameters - if (!context) - { - throw std::invalid_argument("invalid context"); - } - auto parms_id = context->first_parms_id(); - reserve(std::move(context), parms_id, size_capacity); - } - - /** - Allocates enough memory to accommodate the backing array of a ciphertext - with given capacity. In addition to the capacity, the allocation size is - determined by the current encryption parameters. - - @param[in] size_capacity The capacity - @throws std::invalid_argument if size_capacity is less than 2 or too large - @throws std::logic_error if the encryption parameters are not - */ - inline void reserve(size_type size_capacity) - { - // Note: poly_modulus_degree_ and coeff_mod_count_ are either valid - // or coeff_mod_count_ is zero (in which case no memory is allocated). - reserve_internal(size_capacity, poly_modulus_degree_, - coeff_mod_count_); - } - - /** - Resizes the ciphertext to given size, reallocating if the capacity - of the ciphertext is too small. The ciphertext parameters are - determined by the given SEALContext and parms_id. - - This function is mainly intended for internal use and is called - automatically by functions such as Evaluator::multiply and - Evaluator::relinearize. A normal user should never have a reason - to manually resize a ciphertext. - - @param[in] context The SEALContext - @param[in] parms_id The parms_id corresponding to the encryption - parameters to be used - @param[in] size The new size - @throws std::invalid_argument if the context is not set or encryption - parameters are not valid - @throws std::invalid_argument if parms_id is not valid for the encryption - parameters - @throws std::invalid_argument if size is less than 2 or too large - */ - void resize(std::shared_ptr context, - parms_id_type parms_id, size_type size); - - /** - Resizes the ciphertext to given size, reallocating if the capacity - of the ciphertext is too small. The ciphertext parameters are - determined by the highest-level parameters associated to the given - SEALContext. - - This function is mainly intended for internal use and is called - automatically by functions such as Evaluator::multiply and - Evaluator::relinearize. A normal user should never have a reason - to manually resize a ciphertext. - - @param[in] context The SEALContext - @param[in] size The new size - @throws std::invalid_argument if the context is not set or encryption - parameters are not valid - @throws std::invalid_argument if size is less than 2 or too large - */ - inline void resize(std::shared_ptr context, - size_type size) - { - // Verify parameters - if (!context) - { - throw std::invalid_argument("invalid context"); - } - auto parms_id = context->first_parms_id(); - resize(std::move(context), parms_id, size); - } - - /** - Resizes the ciphertext to given size, reallocating if the capacity - of the ciphertext is too small. - - This function is mainly intended for internal use and is called - automatically by functions such as Evaluator::multiply and - Evaluator::relinearize. A normal user should never have a reason - to manually resize a ciphertext. - - @param[in] size The new size - @throws std::invalid_argument if size is less than 2 or too large - */ - inline void resize(size_type size) - { - // Note: poly_modulus_degree_ and coeff_mod_count_ are either valid - // or coeff_mod_count_ is zero (in which case no memory is allocated). - resize_internal(size, poly_modulus_degree_, coeff_mod_count_); - } - - /** - Resets the ciphertext. This function releases any memory allocated - by the ciphertext, returning it to the memory pool. It also sets all - encryption parameter specific size information to zero. - */ - inline void release() noexcept - { - parms_id_ = parms_id_zero; - is_ntt_form_ = false; - size_ = 0; - poly_modulus_degree_ = 0; - coeff_mod_count_ = 0; - scale_ = 1.0; - data_.release(); - } - - /** - Copies a given ciphertext to the current one. - - @param[in] assign The ciphertext to copy from - */ - Ciphertext &operator =(const Ciphertext &assign); - - /** - Moves a given ciphertext to the current one. - - @param[in] assign The ciphertext to move from - */ - Ciphertext &operator =(Ciphertext &&assign) = default; - - /** - Returns a pointer to the beginning of the ciphertext data. - */ - SEAL_NODISCARD inline ct_coeff_type *data() noexcept - { - return data_.begin(); - } - - /** - Returns a const pointer to the beginning of the ciphertext data. - */ - SEAL_NODISCARD inline const ct_coeff_type *data() const noexcept - { - return data_.cbegin(); - } -#ifdef SEAL_USE_MSGSL_MULTISPAN - /** - Returns the ciphertext data. - */ - SEAL_NODISCARD inline auto data_span() - -> gsl::multi_span< - ct_coeff_type, - gsl::dynamic_range, - gsl::dynamic_range, - gsl::dynamic_range> - { - return gsl::as_multi_span< - ct_coeff_type, - gsl::dynamic_range, - gsl::dynamic_range, - gsl::dynamic_range>( - data_.begin(), - util::safe_cast(size_), - util::safe_cast(coeff_mod_count_), - util::safe_cast(poly_modulus_degree_)); - } - - /** - Returns the backing array storing all of the coefficient values. - */ - SEAL_NODISCARD inline auto data_span() const - -> gsl::multi_span< - const ct_coeff_type, - gsl::dynamic_range, - gsl::dynamic_range, - gsl::dynamic_range> - { - return gsl::as_multi_span< - const ct_coeff_type, - gsl::dynamic_range, - gsl::dynamic_range, - gsl::dynamic_range>( - data_.cbegin(), - util::safe_cast(size_), - util::safe_cast(coeff_mod_count_), - util::safe_cast(poly_modulus_degree_)); - } -#endif - /** - Returns a pointer to a particular polynomial in the ciphertext - data. Note that Microsoft SEAL stores each polynomial in the ciphertext - modulo all of the K primes in the coefficient modulus. The pointer - returned by this function is to the beginning (constant coefficient) - of the first one of these K polynomials. - - @param[in] poly_index The index of the polynomial in the ciphertext - @throws std::out_of_range if poly_index is less than 0 or bigger - than the size of the ciphertext - */ - SEAL_NODISCARD inline ct_coeff_type *data( - size_type poly_index) - { - auto poly_uint64_count = util::mul_safe( - poly_modulus_degree_, coeff_mod_count_); - if (poly_uint64_count == 0) - { - return nullptr; - } - if (poly_index >= size_) - { - throw std::out_of_range("poly_index must be within [0, size)"); - } - return data_.begin() + util::safe_cast( - util::mul_safe(poly_index, poly_uint64_count)); - } - - /** - Returns a const pointer to a particular polynomial in the - ciphertext data. Note that Microsoft SEAL stores each polynomial in the - ciphertext modulo all of the K primes in the coefficient modulus. - The pointer returned by this function is to the beginning - (constant coefficient) of the first one of these K polynomials. - - @param[in] poly_index The index of the polynomial in the ciphertext - @throws std::out_of_range if poly_index is out of range - */ - SEAL_NODISCARD inline const ct_coeff_type *data( - size_type poly_index) const - { - auto poly_uint64_count = util::mul_safe( - poly_modulus_degree_, coeff_mod_count_); - if (poly_uint64_count == 0) - { - return nullptr; - } - if (poly_index >= size_) - { - throw std::out_of_range("poly_index must be within [0, size)"); - } - return data_.cbegin() + util::safe_cast( - util::mul_safe(poly_index, poly_uint64_count)); - } - - /** - Returns a reference to a polynomial coefficient at a particular - index in the ciphertext data. If the polynomial modulus has degree N, - and the number of primes in the coefficient modulus is K, then the - ciphertext contains size*N*K coefficients. Thus, the coeff_index has - a range of [0, size*N*K). - - @param[in] coeff_index The index of the coefficient - @throws std::out_of_range if coeff_index is out of range - */ - SEAL_NODISCARD inline ct_coeff_type &operator []( - size_type coeff_index) - { - return data_.at(coeff_index); - } - - /** - Returns a const reference to a polynomial coefficient at a particular - index in the ciphertext data. If the polynomial modulus has degree N, - and the number of primes in the coefficient modulus is K, then the - ciphertext contains size*N*K coefficients. Thus, the coeff_index has - a range of [0, size*N*K). - - @param[in] coeff_index The index of the coefficient - @throws std::out_of_range if coeff_index is out of range - */ - SEAL_NODISCARD inline const ct_coeff_type &operator []( - size_type coeff_index) const - { - return data_.at(coeff_index); - } - - /** - Returns the number of primes in the coefficient modulus of the - associated encryption parameters. This directly affects the - allocation size of the ciphertext. - */ - SEAL_NODISCARD inline size_type coeff_mod_count() const noexcept - { - return coeff_mod_count_; - } - - /** - Returns the degree of the polynomial modulus of the associated - encryption parameters. This directly affects the allocation size - of the ciphertext. - */ - SEAL_NODISCARD inline size_type poly_modulus_degree() const noexcept - { - return poly_modulus_degree_; - } - - /** - Returns the size of the ciphertext. - */ - SEAL_NODISCARD inline size_type size() const noexcept - { - return size_; - } - - /** - Returns the total size of the current allocation in 64-bit words. - */ - SEAL_NODISCARD inline size_type uint64_count_capacity() const noexcept - { - return data_.capacity(); - } - - /** - Returns the capacity of the allocation. This means the largest size - of the ciphertext that can be stored in the current allocation with - the current encryption parameters. - */ - SEAL_NODISCARD inline size_type size_capacity() const noexcept - { - size_type poly_uint64_count = poly_modulus_degree_ * coeff_mod_count_; - return poly_uint64_count ? - uint64_count_capacity() / poly_uint64_count : size_type(0); - } - - /** - Returns the total size of the current ciphertext in 64-bit words. - */ - SEAL_NODISCARD inline size_type uint64_count() const noexcept - { - return data_.size(); - } - - /** - Check whether the current ciphertext is transparent, i.e. does not require - a secret key to decrypt. In typical security models such transparent - ciphertexts would not be considered to be valid. Starting from the second - polynomial in the current ciphertext, this function returns true if all - following coefficients are identically zero. Otherwise, returns false. - */ - SEAL_NODISCARD inline bool is_transparent() const - { - return (!uint64_count() || - (size_ < SEAL_CIPHERTEXT_SIZE_MIN) || - std::all_of(data(1), data_.cend(), util::is_zero)); - } - - /** - Saves the ciphertext to an output stream. The output is in binary format - and not human-readable. The output stream must have the "binary" flag set. - - @param[in] stream The stream to save the ciphertext to - @throws std::exception if the ciphertext could not be written to stream - */ - void save(std::ostream &stream) const; - - void python_save(std::string &path) const - { - try - { - std::ofstream out(path, std::ofstream::binary); - this->save(out); - out.close(); - } - catch (const std::exception &) - { - throw "Ciphertext write exception"; - } - } - - /** - Loads a ciphertext from an input stream overwriting the current ciphertext. - No checking of the validity of the ciphertext data against encryption - parameters is performed. This function should not be used unless the - ciphertext comes from a fully trusted source. - - @param[in] stream The stream to load the ciphertext from - @throws std::exception if a valid ciphertext could not be read from stream - */ - void unsafe_load(std::istream &stream); - - void python_load(std::shared_ptr context, - std::string &path) - { - try - { - std::ifstream in(path, std::ifstream::binary); - this->load(context, in); - in.close(); - } - catch (const std::exception &) - { - throw "Ciphertext read exception"; - } - } - - /** - Loads a ciphertext from an input stream overwriting the current ciphertext. - The loaded ciphertext is verified to be valid for the given SEALContext. - - @param[in] context The SEALContext - @param[in] stream The stream to load the ciphertext from - @throws std::invalid_argument if the context is not set or encryption - parameters are not valid - @throws std::exception if a valid ciphertext could not be read from stream - @throws std::invalid_argument if the loaded ciphertext is invalid for the - context - */ - inline void load(std::shared_ptr context, - std::istream &stream) - { - Ciphertext new_data(pool()); - new_data.unsafe_load(stream); - if (!is_valid_for(new_data, std::move(context))) - { - throw std::invalid_argument("ciphertext data is invalid"); - } - std::swap(*this, new_data); - } - - /** - Returns whether the ciphertext is in NTT form. - */ - SEAL_NODISCARD inline bool is_ntt_form() const noexcept - { - return is_ntt_form_; - } - - /** - Returns whether the ciphertext is in NTT form. - */ - SEAL_NODISCARD inline bool &is_ntt_form() noexcept - { - return is_ntt_form_; - } - - /** - Returns a reference to parms_id. - - @see EncryptionParameters for more information about parms_id. - */ - SEAL_NODISCARD inline auto &parms_id() noexcept - { - return parms_id_; - } - - /** - Returns a const reference to parms_id. - - @see EncryptionParameters for more information about parms_id. - */ - SEAL_NODISCARD inline auto &parms_id() const noexcept - { - return parms_id_; - } - - /** - Returns a reference to the scale. This is only needed when using the - CKKS encryption scheme. The user should have little or no reason to ever - change the scale by hand. - */ - SEAL_NODISCARD inline auto &scale() noexcept - { - return scale_; - } - - /** - Returns a constant reference to the scale. This is only needed when - using the CKKS encryption scheme. - */ - SEAL_NODISCARD inline auto &scale() const noexcept - { - return scale_; - } - - /** - Set the scale. - */ - inline void set_scale(double scale) - { - scale_ = scale; - } - - /** - Returns the currently used MemoryPoolHandle. - */ - SEAL_NODISCARD inline MemoryPoolHandle pool() const noexcept - { - return data_.pool(); - } - - /** - Enables access to private members of seal::Ciphertext for .NET wrapper. - */ - struct CiphertextPrivateHelper; - - private: - void reserve_internal(size_type size_capacity, - size_type poly_modulus_degree, size_type coeff_mod_count); - - void resize_internal(size_type size, size_type poly_modulus_degree, - size_type coeff_mod_count); - - parms_id_type parms_id_ = parms_id_zero; - - bool is_ntt_form_ = false; - - size_type size_ = 0; - - size_type poly_modulus_degree_ = 0; - - size_type coeff_mod_count_ = 0; - - double scale_ = 1.0; - - IntArray data_; - }; -} diff --git a/SEAL/native/src/seal/ckks.cpp b/SEAL/native/src/seal/ckks.cpp deleted file mode 100644 index d3f387b..0000000 --- a/SEAL/native/src/seal/ckks.cpp +++ /dev/null @@ -1,273 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include -#include -#include -#include -#include "seal/ckks.h" - -using namespace std; -using namespace seal::util; - -namespace seal -{ - // For C++14 compatibility need to define static constexpr - // member variables with no initialization here. - constexpr double CKKSEncoder::PI_; - - CKKSEncoder::CKKSEncoder(shared_ptr context) : - context_(context) - { - // Verify parameters - if (!context_) - { - throw invalid_argument("invalid context"); - } - if (!context_->parameters_set()) - { - throw invalid_argument("encryption parameters are not set correctly"); - } - - auto &context_data = *context_->first_context_data(); - if (context_data.parms().scheme() != scheme_type::CKKS) - { - throw invalid_argument("unsupported scheme"); - } - - size_t coeff_count = context_data.parms().poly_modulus_degree(); - slots_ = coeff_count >> 1; - int logn = get_power_of_two(coeff_count); - - matrix_reps_index_map_ = allocate_uint(coeff_count, pool_); - - // Copy from the matrix to the value vectors - uint64_t gen = 3; - uint64_t pos = 1; - uint64_t m = coeff_count << 1; - for (size_t i = 0; i < slots_; i++) - { - // Position in normal bit order - uint64_t index1 = (pos - 1) >> 1; - uint64_t index2 = (m - pos - 1) >> 1; - - // Set the bit-reversed locations - matrix_reps_index_map_[i] = reverse_bits(index1, logn); - matrix_reps_index_map_[slots_ | i] = reverse_bits(index2, logn); - - // Next primitive root - pos *= gen; - pos &= (m - 1); - } - - roots_ = allocate>(coeff_count, pool_); - inv_roots_ = allocate>(coeff_count, pool_); - double psi_arg = 2 * PI_ / static_cast(m); - for (size_t i = 0; i < coeff_count; i++) - { - roots_[i] = polar(1.0, psi_arg * reverse_bits(i, logn)); - inv_roots_[i] = 1.0 / roots_[i]; - } - } - - void CKKSEncoder::encode_internal(double value, parms_id_type parms_id, - double scale, Plaintext &destination, MemoryPoolHandle pool) - { - // Verify parameters. - auto context_data_ptr = context_->get_context_data(parms_id); - if (!context_data_ptr) - { - throw invalid_argument("parms_id is not valid for encryption parameters"); - } - if (!pool) - { - throw invalid_argument("pool is uninitialized"); - } - - auto &context_data = *context_data_ptr; - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_mod_count = coeff_modulus.size(); - size_t coeff_count = parms.poly_modulus_degree(); - - // Quick sanity check - if (!product_fits_in(coeff_mod_count, coeff_count)) - { - throw logic_error("invalid parameters"); - } - - // Check that scale is positive and not too large - if (scale <= 0 || (static_cast(log2(scale)) >= - context_data.total_coeff_modulus_bit_count())) - { - throw invalid_argument("scale out of bounds"); - } - - // Compute the scaled value - value *= scale; - - int coeff_bit_count = static_cast(log2(fabs(value))) + 2; - if (coeff_bit_count >= context_data.total_coeff_modulus_bit_count()) - { - throw invalid_argument("encoded value is too large"); - } - - double two_pow_64 = pow(2.0, 64); - - // Resize destination to appropriate size - // Need to first set parms_id to zero, otherwise resize - // will throw an exception. - destination.parms_id() = parms_id_zero; - destination.resize(coeff_count * coeff_mod_count); - - double coeffd = round(value); - bool is_negative = signbit(coeffd); - coeffd = fabs(coeffd); - - // Use faster decomposition methods when possible - if (coeff_bit_count <= 64) - { - uint64_t coeffu = static_cast(fabs(coeffd)); - - if (is_negative) - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - fill_n(destination.data() + (j * coeff_count), coeff_count, - negate_uint_mod(coeffu % coeff_modulus[j].value(), - coeff_modulus[j])); - } - } - else - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - fill_n(destination.data() + (j * coeff_count), coeff_count, - coeffu % coeff_modulus[j].value()); - } - } - } - else if (coeff_bit_count <= 128) - { - uint64_t coeffu[2]{ - static_cast(fmod(coeffd, two_pow_64)), - static_cast(coeffd / two_pow_64) }; - - if (is_negative) - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - fill_n(destination.data() + (j * coeff_count), coeff_count, - negate_uint_mod(barrett_reduce_128( - coeffu, coeff_modulus[j]), coeff_modulus[j])); - } - } - else - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - fill_n(destination.data() + (j * coeff_count), coeff_count, - barrett_reduce_128(coeffu, coeff_modulus[j])); - } - } - } - else - { - // Slow case - auto coeffu(allocate_uint(coeff_mod_count, pool)); - auto decomp_coeffu(allocate_uint(coeff_mod_count, pool)); - - // We are at this point guaranteed to fit in the allocated space - set_zero_uint(coeff_mod_count, coeffu.get()); - auto coeffu_ptr = coeffu.get(); - while (coeffd >= 1) - { - *coeffu_ptr++ = static_cast(fmod(coeffd, two_pow_64)); - coeffd /= two_pow_64; - } - - // Next decompose this coefficient - decompose_single_coeff(context_data, coeffu.get(), decomp_coeffu.get(), pool); - - // Finally replace the sign if necessary - if (is_negative) - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - fill_n(destination.data() + (j * coeff_count), coeff_count, - negate_uint_mod(decomp_coeffu[j], coeff_modulus[j])); - } - } - else - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - fill_n(destination.data() + (j * coeff_count), coeff_count, - decomp_coeffu[j]); - } - } - } - - destination.parms_id() = parms_id; - destination.scale() = scale; - } - - void CKKSEncoder::encode_internal(int64_t value, parms_id_type parms_id, - Plaintext &destination) - { - // Verify parameters. - auto context_data_ptr = context_->get_context_data(parms_id); - if (!context_data_ptr) - { - throw invalid_argument("parms_id is not valid for encryption parameters"); - } - - auto &context_data = *context_data_ptr; - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_mod_count = coeff_modulus.size(); - size_t coeff_count = parms.poly_modulus_degree(); - - // Quick sanity check - if (!product_fits_in(coeff_mod_count, coeff_count)) - { - throw logic_error("invalid parameters"); - } - - int coeff_bit_count = get_significant_bit_count( - static_cast(llabs(value))) + 2; - if (coeff_bit_count >= context_data.total_coeff_modulus_bit_count()) - { - throw invalid_argument("encoded value is too large"); - } - - // Resize destination to appropriate size - // Need to first set parms_id to zero, otherwise resize - // will throw an exception. - destination.parms_id() = parms_id_zero; - destination.resize(coeff_count * coeff_mod_count); - - if (value < 0) - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - uint64_t tmp = static_cast(value); - tmp += coeff_modulus[j].value(); - tmp %= coeff_modulus[j].value(); - fill_n(destination.data() + (j * coeff_count), coeff_count, tmp); - } - } - else - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - uint64_t tmp = static_cast(value); - tmp %= coeff_modulus[j].value(); - fill_n(destination.data() + (j * coeff_count), coeff_count, tmp); - } - } - - destination.parms_id() = parms_id; - destination.scale() = 1.0; - } -} diff --git a/SEAL/native/src/seal/ckks.h b/SEAL/native/src/seal/ckks.h deleted file mode 100644 index 168e076..0000000 --- a/SEAL/native/src/seal/ckks.h +++ /dev/null @@ -1,764 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include "seal/plaintext.h" -#include "seal/context.h" -#include "seal/util/common.h" -#include "seal/util/uintcore.h" -#include "seal/util/uintarithsmallmod.h" - -namespace seal -{ - template::value || - std::is_same>::value>> - SEAL_NODISCARD inline T_out from_complex(std::complex in); - - template<> - SEAL_NODISCARD inline double from_complex(std::complex in) - { - return in.real(); - } - - template<> - SEAL_NODISCARD inline std::complex from_complex(std::complex in) - { - return in; - } - - /** - Provides functionality for encoding vectors of complex or real numbers into - plaintext polynomials to be encrypted and computed on using the CKKS scheme. - If the polynomial modulus degree is N, then CKKSEncoder converts vectors of - N/2 complex numbers into plaintext elements. Homomorphic operations performed - on such encrypted vectors are applied coefficient (slot-)wise, enabling - powerful SIMD functionality for computations that are vectorizable. This - functionality is often called "batching" in the homomorphic encryption - literature. - - @par Mathematical Background - Mathematically speaking, if the polynomial modulus is X^N+1, N is a power of - two, the CKKSEncoder implements an approximation of the canonical embedding - of the ring of integers Z[X]/(X^N+1) into C^(N/2), where C denotes the complex - numbers. The Galois group of the extension is (Z/2NZ)* ~= Z/2Z x Z/(N/2) - whose action on the primitive roots of unity modulo coeff_modulus is easy to - describe. Since the batching slots correspond 1-to-1 to the primitive roots - of unity, applying Galois automorphisms on the plaintext acts by permuting - the slots. By applying generators of the two cyclic subgroups of the Galois - group, we can effectively enable cyclic rotations and complex conjugations - of the encrypted complex vectors. - */ - class SEAL_NODISCARD CKKSEncoder - { - public: - /** - Creates a CKKSEncoder instance initialized with the specified SEALContext. - - @param[in] context The SEALContext - @throws std::invalid_argument if the context is not set or encryption - parameters are not valid - @throws std::invalid_argument if scheme is not scheme_type::CKKS - */ - CKKSEncoder(std::shared_ptr context); - - /** - Encodes double-precision floating-point real or complex numbers into - a plaintext polynomial. Dynamic memory allocations in the process are - allocated from the memory pool pointed to by the given MemoryPoolHandle. - - @tparam T Vector value type (double or std::complex) - @param[in] values The vector of double-precision floating-point numbers - (of type T) to encode - @param[in] parms_id parms_id determining the encryption parameters to - be used by the result plaintext - @param[in] scale Scaling parameter defining encoding precision - @param[out] destination The plaintext polynomial to overwrite with the - result - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if values has invalid size - @throws std::invalid_argument if parms_id is not valid for the encryption - parameters - @throws std::invalid_argument if scale is not strictly positive - @throws std::invalid_argument if encoding is too large for the encryption - parameters - @throws std::invalid_argument if pool is uninitialized - */ - template::value || - std::is_same>::value>> - inline void encode(const std::vector &values, - parms_id_type parms_id, double scale, Plaintext &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - encode_internal(values, parms_id, scale, destination, std::move(pool)); - } - - /** - Encodes double-precision floating-point real or complex numbers into - a plaintext polynomial. The encryption parameters used are the top - level parameters for the given context. Dynamic memory allocations in - the process are allocated from the memory pool pointed to by the given - MemoryPoolHandle. - - @tparam T Vector value type (double or std::complex) - @param[in] values The vector of double-precision floating-point numbers - (of type T) to encode - @param[in] scale Scaling parameter defining encoding precision - @param[out] destination The plaintext polynomial to overwrite with the - result - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if values has invalid size - @throws std::invalid_argument if scale is not strictly positive - @throws std::invalid_argument if encoding is too large for the encryption - parameters - @throws std::invalid_argument if pool is uninitialized - */ - template::value || - std::is_same>::value>> - inline void encode(const std::vector &values, - double scale, Plaintext &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - encode(values, context_->first_parms_id(), scale, - destination, std::move(pool)); - } - - /** - Encodes a double-precision floating-point number into a plaintext - polynomial. Dynamic memory allocations in the process are allocated from - the memory pool pointed to by the given MemoryPoolHandle. - - @param[in] value The double-precision floating-point number to encode - @param[in] parms_id parms_id determining the encryption parameters to be - used by the result plaintext - @param[in] scale Scaling parameter defining encoding precision - @param[out] destination The plaintext polynomial to overwrite with the - result - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if parms_id is not valid for the encryption - parameters - @throws std::invalid_argument if scale is not strictly positive - @throws std::invalid_argument if encoding is too large for the encryption - parameters - @throws std::invalid_argument if pool is uninitialized - */ - inline void encode(double value, parms_id_type parms_id, - double scale, Plaintext &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - encode_internal(value, parms_id, scale, destination, std::move(pool)); - } - - /** - Encodes a double-precision floating-point number into a plaintext - polynomial. The encryption parameters used are the top level parameters - for the given context. Dynamic memory allocations in the process are - allocated from the memory pool pointed to by the given MemoryPoolHandle. - - @param[in] value The double-precision floating-point number to encode - @param[in] scale Scaling parameter defining encoding precision - @param[out] destination The plaintext polynomial to overwrite with the - result - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if scale is not strictly positive - @throws std::invalid_argument if encoding is too large for the encryption - parameters - @throws std::invalid_argument if pool is uninitialized - */ - inline void encode(double value, - double scale, Plaintext &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - encode(value, context_->first_parms_id(), scale, - destination, std::move(pool)); - } - - /** - Encodes a double-precision complex number into a plaintext polynomial. - Dynamic memory allocations in the process are allocated from the memory - pool pointed to by the given MemoryPoolHandle. - - @param[in] value The double-precision complex number to encode - @param[in] parms_id parms_id determining the encryption parameters to be - used by the result plaintext - @param[in] scale Scaling parameter defining encoding precision - @param[out] destination The plaintext polynomial to overwrite with the - result - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if parms_id is not valid for the encryption - parameters - @throws std::invalid_argument if scale is not strictly positive - @throws std::invalid_argument if encoding is too large for the encryption - parameters - @throws std::invalid_argument if pool is uninitialized - */ - inline void encode(std::complex value, - parms_id_type parms_id, double scale, Plaintext &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - encode_internal(value, parms_id, scale, destination, std::move(pool)); - } - - /** - Encodes a double-precision complex number into a plaintext polynomial. - The encryption parameters used are the top level parameters for the - given context. Dynamic memory allocations in the process are allocated - from the memory pool pointed to by the given MemoryPoolHandle. - - @param[in] value The double-precision complex number to encode - @param[in] scale Scaling parameter defining encoding precision - @param[out] destination The plaintext polynomial to overwrite with the - result - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if scale is not strictly positive - @throws std::invalid_argument if encoding is too large for the encryption - parameters - @throws std::invalid_argument if pool is uninitialized - */ - inline void encode(std::complex value, - double scale, Plaintext &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - encode(value, context_->first_parms_id(), scale, - destination, std::move(pool)); - } - - /** - Encodes an integer number into a plaintext polynomial without any scaling. - - @param[in] value The integer number to encode - @param[in] parms_id parms_id determining the encryption parameters to be - used by the result plaintext - @param[out] destination The plaintext polynomial to overwrite with the - result - @throws std::invalid_argument if parms_id is not valid for the encryption - parameters - */ - inline void encode(std::int64_t value, - parms_id_type parms_id, Plaintext &destination) - { - encode_internal(value, parms_id, destination); - } - - /** - Encodes an integer number into a plaintext polynomial without any scaling. - The encryption parameters used are the top level parameters for the given - context. - - @param[in] value The integer number to encode - @param[out] destination The plaintext polynomial to overwrite with the - result - */ - inline void encode(std::int64_t value, Plaintext &destination) - { - encode(value, context_->first_parms_id(), destination); - } - - /** - Decodes a plaintext polynomial into double-precision floating-point - real or complex numbers. Dynamic memory allocations in the process are - allocated from the memory pool pointed to by the given MemoryPoolHandle. - - @tparam T Vector value type (double or std::complex) - @param[in] plain The plaintext to decode - @param[out] destination The vector to be overwritten with the values in - the slots - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if plain is not in NTT form or is invalid - for the encryption parameters - @throws std::invalid_argument if pool is uninitialized - */ - template::value || - std::is_same>::value>> - inline void decode(const Plaintext &plain, std::vector &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - decode_internal(plain, destination, std::move(pool)); - } - - /** - Returns the number of complex numbers encoded. - */ - SEAL_NODISCARD inline std::size_t slot_count() const noexcept - { - return slots_; - } - - private: - // This is the same function as in evaluator.h - inline void decompose_single_coeff( - const SEALContext::ContextData &context_data, - const std::uint64_t *value, std::uint64_t *destination, - util::MemoryPool &pool) - { - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - std::size_t coeff_mod_count = coeff_modulus.size(); -#ifdef SEAL_DEBUG - if (value == nullptr) - { - throw std::invalid_argument("value cannot be null"); - } - if (destination == nullptr) - { - throw std::invalid_argument("destination cannot be null"); - } - if (destination == value) - { - throw std::invalid_argument("value cannot be the same as destination"); - } -#endif - if (coeff_mod_count == 1) - { - util::set_uint_uint(value, coeff_mod_count, destination); - return; - } - - auto value_copy(util::allocate_uint(coeff_mod_count, pool)); - for (std::size_t j = 0; j < coeff_mod_count; j++) - { - //destination[j] = util::modulo_uint( - // value, coeff_mod_count, coeff_modulus_[j], pool); - - // Manually inlined for efficiency - // Make a fresh copy of value - util::set_uint_uint(value, coeff_mod_count, value_copy.get()); - - // Starting from the top, reduce always 128-bit blocks - for (std::size_t k = coeff_mod_count - 1; k--; ) - { - value_copy[k] = util::barrett_reduce_128( - value_copy.get() + k, coeff_modulus[j]); - } - destination[j] = value_copy[0]; - } - } - - template::value || - std::is_same>::value>> - void encode_internal(const std::vector &values, - parms_id_type parms_id, double scale, Plaintext &destination, - MemoryPoolHandle pool) - { - // Verify parameters. - auto context_data_ptr = context_->get_context_data(parms_id); - if (!context_data_ptr) - { - throw std::invalid_argument("parms_id is not valid for encryption parameters"); - } - if (values.size() > slots_) - { - throw std::invalid_argument("values has invalid size"); - } - if (!pool) - { - throw std::invalid_argument("pool is uninitialized"); - } - - auto &context_data = *context_data_ptr; - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - std::size_t coeff_mod_count = coeff_modulus.size(); - std::size_t coeff_count = parms.poly_modulus_degree(); - - // Quick sanity check - if (!util::product_fits_in(coeff_mod_count, coeff_count)) - { - throw std::logic_error("invalid parameters"); - } - - // Check that scale is positive and not too large - if (scale <= 0 || (static_cast(log2(scale)) + 1 >= - context_data.total_coeff_modulus_bit_count())) - { - throw std::invalid_argument("scale out of bounds"); - } - - auto &small_ntt_tables = context_data.small_ntt_tables(); - - // input_size is guaranteed to be no bigger than slots_ - std::size_t input_size = values.size(); - std::size_t n = util::mul_safe(slots_, std::size_t(2)); - - auto conj_values = util::allocate>(n, pool, 0); - for (std::size_t i = 0; i < input_size; i++) - { - conj_values[matrix_reps_index_map_[i]] = values[i]; - conj_values[matrix_reps_index_map_[i + slots_]] = std::conj(values[i]); - } - - int logn = util::get_power_of_two(n); - std::size_t tt = 1; - for (int i = 0; i < logn; i++) - { - std::size_t mm = std::size_t(1) << (logn - i); - std::size_t k_start = 0; - std::size_t h = mm / 2; - - for (std::size_t j = 0; j < h; j++) - { - std::size_t k_end = k_start + tt; - auto s = inv_roots_[h + j]; - - for (std::size_t k = k_start; k < k_end; k++) - { - auto u = conj_values[k]; - auto v = conj_values[k + tt]; - conj_values[k] = u + v; - conj_values[k + tt] = (u - v) * s; - } - - k_start += 2 * tt; - } - tt *= 2; - } - - double n_inv = double(1.0) / static_cast(n); - - // Put the scale in at this point - n_inv *= scale; - - int max_coeff_bit_count = 1; - for (std::size_t i = 0; i < n; i++) - { - // Multiply by scale and n_inv (see above) - conj_values[i] *= n_inv; - - // Verify that the values are not too large to fit in coeff_modulus - // Note that we have an extra + 1 for the sign bit - max_coeff_bit_count = std::max(max_coeff_bit_count, - static_cast(std::log2(std::fabs(conj_values[i].real()))) + 2); - } - if (max_coeff_bit_count >= context_data.total_coeff_modulus_bit_count()) - { - throw std::invalid_argument("encoded values are too large"); - } - - double two_pow_64 = std::pow(2.0, 64); - - // Resize destination to appropriate size - // Need to first set parms_id to zero, otherwise resize - // will throw an exception. - destination.parms_id() = parms_id_zero; - destination.resize(util::mul_safe(coeff_count, coeff_mod_count)); - - // Use faster decomposition methods when possible - if (max_coeff_bit_count <= 64) - { - for (std::size_t i = 0; i < n; i++) - { - double coeffd = std::round(conj_values[i].real()); - bool is_negative = std::signbit(coeffd); - - std::uint64_t coeffu = - static_cast(std::fabs(coeffd)); - - if (is_negative) - { - for (std::size_t j = 0; j < coeff_mod_count; j++) - { - destination[i + (j * coeff_count)] = util::negate_uint_mod( - coeffu % coeff_modulus[j].value(), coeff_modulus[j]); - } - } - else - { - for (std::size_t j = 0; j < coeff_mod_count; j++) - { - destination[i + (j * coeff_count)] = - coeffu % coeff_modulus[j].value(); - } - } - } - } - else if (max_coeff_bit_count <= 128) - { - for (std::size_t i = 0; i < n; i++) - { - double coeffd = std::round(conj_values[i].real()); - bool is_negative = std::signbit(coeffd); - coeffd = std::fabs(coeffd); - - std::uint64_t coeffu[2]{ - static_cast(std::fmod(coeffd, two_pow_64)), - static_cast(coeffd / two_pow_64) }; - - if (is_negative) - { - for (std::size_t j = 0; j < coeff_mod_count; j++) - { - destination[i + (j * coeff_count)] = - util::negate_uint_mod(util::barrett_reduce_128( - coeffu, coeff_modulus[j]), coeff_modulus[j]); - } - } - else - { - for (std::size_t j = 0; j < coeff_mod_count; j++) - { - destination[i + (j * coeff_count)] = - util::barrett_reduce_128(coeffu, coeff_modulus[j]); - } - } - } - } - else - { - // Slow case - auto coeffu(util::allocate_uint(coeff_mod_count, pool)); - auto decomp_coeffu(util::allocate_uint(coeff_mod_count, pool)); - for (std::size_t i = 0; i < n; i++) - { - double coeffd = std::round(conj_values[i].real()); - bool is_negative = std::signbit(coeffd); - coeffd = std::fabs(coeffd); - - // We are at this point guaranteed to fit in the allocated space - util::set_zero_uint(coeff_mod_count, coeffu.get()); - auto coeffu_ptr = coeffu.get(); - while (coeffd >= 1) - { - *coeffu_ptr++ = static_cast( - std::fmod(coeffd, two_pow_64)); - coeffd /= two_pow_64; - } - - // Next decompose this coefficient - decompose_single_coeff(context_data, coeffu.get(), - decomp_coeffu.get(), pool); - - // Finally replace the sign if necessary - if (is_negative) - { - for (std::size_t j = 0; j < coeff_mod_count; j++) - { - destination[i + (j * coeff_count)] = - util::negate_uint_mod(decomp_coeffu[j], coeff_modulus[j]); - } - } - else - { - for (std::size_t j = 0; j < coeff_mod_count; j++) - { - destination[i + (j * coeff_count)] = decomp_coeffu[j]; - } - } - } - } - - // Transform to NTT domain - for (std::size_t i = 0; i < coeff_mod_count; i++) - { - util::ntt_negacyclic_harvey( - destination.data(i * coeff_count), small_ntt_tables[i]); - } - - destination.parms_id() = parms_id; - destination.scale() = scale; - } - - template::value || - std::is_same>::value>> - void decode_internal(const Plaintext &plain, std::vector &destination, - MemoryPoolHandle pool) - { - // Verify parameters. - if (!is_valid_for(plain, context_)) - { - throw std::invalid_argument("plain is not valid for encryption parameters"); - } - if (!plain.is_ntt_form()) - { - throw std::invalid_argument("plain is not in NTT form"); - } - if (!pool) - { - throw std::invalid_argument("pool is uninitialized"); - } - - auto context_data_ptr = context_->get_context_data(plain.parms_id()); - auto &parms = context_data_ptr->parms(); - auto &coeff_modulus = parms.coeff_modulus(); - std::size_t coeff_mod_count = coeff_modulus.size(); - std::size_t coeff_count = parms.poly_modulus_degree(); - std::size_t rns_poly_uint64_count = - util::mul_safe(coeff_count, coeff_mod_count); - - auto &small_ntt_tables = context_data_ptr->small_ntt_tables(); - - // Check that scale is positive and not too large - if (plain.scale() <= 0 || (static_cast(log2(plain.scale())) >= - context_data_ptr->total_coeff_modulus_bit_count())) - { - throw std::invalid_argument("scale out of bounds"); - } - - auto decryption_modulus = context_data_ptr->total_coeff_modulus(); - auto upper_half_threshold = context_data_ptr->upper_half_threshold(); - - auto &inv_coeff_products_mod_coeff_array = - context_data_ptr->base_converter()->get_inv_coeff_mod_coeff_array(); - auto coeff_products_array = - context_data_ptr->base_converter()->get_coeff_products_array(); - - int logn = util::get_power_of_two(coeff_count); - - // Quick sanity check - if ((logn < 0) || (coeff_count < SEAL_POLY_MOD_DEGREE_MIN) || - (coeff_count > SEAL_POLY_MOD_DEGREE_MAX)) - { - throw std::logic_error("invalid parameters"); - } - - double inv_scale = double(1.0) / plain.scale(); - - // Create mutable copy of input - auto plain_copy = util::allocate_uint(rns_poly_uint64_count, pool); - util::set_uint_uint(plain.data(), rns_poly_uint64_count, plain_copy.get()); - - // Array to keep number bigger than std::uint64_t - auto temp(util::allocate_uint(coeff_mod_count, pool)); - - // destination mod q - auto wide_tmp_dest(util::allocate_zero_uint(rns_poly_uint64_count, pool)); - - // Transform each polynomial from NTT domain - for (std::size_t i = 0; i < coeff_mod_count; i++) - { - util::inverse_ntt_negacyclic_harvey( - plain_copy.get() + (i * coeff_count), small_ntt_tables[i]); - } - - auto res = util::allocate>(coeff_count, pool); - - double two_pow_64 = std::pow(2.0, 64); - for (std::size_t i = 0; i < coeff_count; i++) - { - for (std::size_t j = 0; j < coeff_mod_count; j++) - { - std::uint64_t tmp = util::multiply_uint_uint_mod( - plain_copy[(j * coeff_count) + i], - inv_coeff_products_mod_coeff_array[j], // (qi/q * plain[i]) mod qi - coeff_modulus[j]); - util::multiply_uint_uint64( - coeff_products_array + (j * coeff_mod_count), - coeff_mod_count, tmp, coeff_mod_count, temp.get()); - util::add_uint_uint_mod(temp.get(), - wide_tmp_dest.get() + (i * coeff_mod_count), - decryption_modulus, coeff_mod_count, - wide_tmp_dest.get() + (i * coeff_mod_count)); - } - - res[i] = 0.0; - if (util::is_greater_than_or_equal_uint_uint( - wide_tmp_dest.get() + (i * coeff_mod_count), - upper_half_threshold, coeff_mod_count)) - { - double scaled_two_pow_64 = inv_scale; - for (std::size_t j = 0; j < coeff_mod_count; - j++, scaled_two_pow_64 *= two_pow_64) - { - if (wide_tmp_dest[i * coeff_mod_count + j] > decryption_modulus[j]) - { - auto diff = wide_tmp_dest[i * coeff_mod_count + j] - decryption_modulus[j]; - res[i] += diff ? - static_cast(diff) * scaled_two_pow_64 : 0.0; - } - else - { - auto diff = decryption_modulus[j] - wide_tmp_dest[i * coeff_mod_count + j]; - res[i] -= diff ? - static_cast(diff) * scaled_two_pow_64 : 0.0; - } - } - } - else - { - double scaled_two_pow_64 = inv_scale; - for (std::size_t j = 0; j < coeff_mod_count; - j++, scaled_two_pow_64 *= two_pow_64) - { - auto curr_coeff = wide_tmp_dest[i * coeff_mod_count + j]; - res[i] += curr_coeff ? - static_cast(curr_coeff) * scaled_two_pow_64 : 0.0; - } - } - - // Scaling instead incorporated above; this can help in cases - // where otherwise pow(two_pow_64, j) would overflow due to very - // large coeff_mod_count and very large scale - // res[i] = res_accum * inv_scale; - } - - std::size_t tt = coeff_count; - for (int i = 0; i < logn; i++) - { - std::size_t mm = std::size_t(1) << i; - tt >>= 1; - - for (std::size_t j = 0; j < mm; j++) - { - std::size_t j1 = 2 * j * tt; - std::size_t j2 = j1 + tt - 1; - auto s = roots_[mm + j]; - - for (std::size_t k = j1; k < j2 + 1; k++) - { - auto u = res[k]; - auto v = res[k + tt] * s; - res[k] = u + v; - res[k + tt] = u - v; - } - } - } - - destination.clear(); - destination.reserve(slots_); - for (std::size_t i = 0; i < slots_; i++) - { - destination.emplace_back( - from_complex(res[matrix_reps_index_map_[i]])); - } - } - - void encode_internal(double value, parms_id_type parms_id, - double scale, Plaintext &destination, MemoryPoolHandle pool); - - inline void encode_internal(std::complex value, - parms_id_type parms_id, double scale, Plaintext &destination, - MemoryPoolHandle pool) - { - encode_internal(std::vector>(1, value), - parms_id, scale, destination, std::move(pool)); - } - - void encode_internal(std::int64_t value, - parms_id_type parms_id, Plaintext &destination); - - MemoryPoolHandle pool_ = MemoryManager::GetPool(); - - static constexpr double PI_ = 3.1415926535897932384626433832795028842; - - std::shared_ptr context_{ nullptr }; - - std::size_t slots_; - - util::Pointer> roots_; - - util::Pointer> inv_roots_; - - util::Pointer matrix_reps_index_map_; - }; -} diff --git a/SEAL/native/src/seal/context.cpp b/SEAL/native/src/seal/context.cpp deleted file mode 100644 index b46d5b1..0000000 --- a/SEAL/native/src/seal/context.cpp +++ /dev/null @@ -1,404 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "seal/context.h" -#include "seal/util/pointer.h" -#include "seal/util/polycore.h" -#include "seal/util/uintarith.h" -#include "seal/util/uintarithsmallmod.h" -#include "seal/util/numth.h" -#include -#include - -using namespace std; -using namespace seal::util; - -namespace seal -{ - SEALContext::ContextData SEALContext::validate(EncryptionParameters parms) - { - ContextData context_data(parms, pool_); - context_data.qualifiers_.parameters_set = true; - - auto &coeff_modulus = parms.coeff_modulus(); - auto &plain_modulus = parms.plain_modulus(); - - // The number of coeff moduli is restricted to 62 for lazy reductions - // in baseconverter.cpp to work - if (coeff_modulus.size() > SEAL_COEFF_MOD_COUNT_MAX || - coeff_modulus.size() < SEAL_COEFF_MOD_COUNT_MIN) - { - context_data.qualifiers_.parameters_set = false; - return context_data; - } - - size_t coeff_mod_count = coeff_modulus.size(); - for (size_t i = 0; i < coeff_mod_count; i++) - { - // Check coefficient moduli bounds - if (coeff_modulus[i].value() >> SEAL_USER_MOD_BIT_COUNT_MAX || - !(coeff_modulus[i].value() >> (SEAL_USER_MOD_BIT_COUNT_MIN - 1))) - { - context_data.qualifiers_.parameters_set = false; - return context_data; - } - - // Check that all coeff moduli are pairwise relatively prime - for (size_t j = 0; j < i; j++) - { - if (gcd(coeff_modulus[i].value(), coeff_modulus[j].value()) > 1) - { - context_data.qualifiers_.parameters_set = false; - return context_data; - } - } - } - - // Compute the product of all coeff moduli - context_data.total_coeff_modulus_ = allocate_uint(coeff_mod_count, pool_); - auto temp(allocate_uint(coeff_mod_count, pool_)); - set_uint(1, coeff_mod_count, context_data.total_coeff_modulus_.get()); - for (size_t i = 0; i < coeff_mod_count; i++) - { - multiply_uint_uint64(context_data.total_coeff_modulus_.get(), - coeff_mod_count, coeff_modulus[i].value(), coeff_mod_count, - temp.get()); - set_uint_uint(temp.get(), coeff_mod_count, - context_data.total_coeff_modulus_.get()); - } - context_data.total_coeff_modulus_bit_count_ = get_significant_bit_count_uint( - context_data.total_coeff_modulus_.get(), coeff_mod_count); - - // Check polynomial modulus degree and create poly_modulus - size_t poly_modulus_degree = parms.poly_modulus_degree(); - int coeff_count_power = get_power_of_two(poly_modulus_degree); - if (poly_modulus_degree < SEAL_POLY_MOD_DEGREE_MIN || - poly_modulus_degree > SEAL_POLY_MOD_DEGREE_MAX || - coeff_count_power < 0) - { - // Parameters are not valid - context_data.qualifiers_.parameters_set = false; - return context_data; - } - - // Quick sanity check - if (!product_fits_in(coeff_mod_count, poly_modulus_degree)) - { - throw logic_error("invalid parameters"); - } - - // Polynomial modulus X^(2^k) + 1 is guaranteed at this point - context_data.qualifiers_.using_fft = true; - - // Assume parameters satisfy desired security level - context_data.qualifiers_.sec_level = sec_level_; - - // Check if the parameters are secure according to HomomorphicEncryption.org - // security standard - if (context_data.total_coeff_modulus_bit_count_ > - CoeffModulus::MaxBitCount(poly_modulus_degree, sec_level_)) - { - // Not secure according to HomomorphicEncryption.org security standard - context_data.qualifiers_.sec_level = sec_level_type::none; - if (sec_level_ != sec_level_type::none) - { - // Parameters are not valid - context_data.qualifiers_.parameters_set = false; - return context_data; - } - } - - // Can we use NTT with coeff_modulus? - context_data.qualifiers_.using_ntt = true; - context_data.small_ntt_tables_ = - allocate(coeff_mod_count, pool_, pool_); - for (size_t i = 0; i < coeff_mod_count; i++) - { - if (!context_data.small_ntt_tables_[i].generate(coeff_count_power, - coeff_modulus[i])) - { - // Parameters are not valid - context_data.qualifiers_.using_ntt = false; - context_data.qualifiers_.parameters_set = false; - return context_data; - } - } - - if (parms.scheme() == scheme_type::BFV) - { - // Plain modulus must be at least 2 and at most 60 bits - if (plain_modulus.value() >> SEAL_PLAIN_MOD_MAX || - !(plain_modulus.value() >> (SEAL_PLAIN_MOD_MIN - 1))) - { - context_data.qualifiers_.parameters_set = false; - return context_data; - } - - // Check that all coeff moduli are relatively prime to plain_modulus - for (size_t i = 0; i < coeff_mod_count; i++) - { - if (gcd(coeff_modulus[i].value(), plain_modulus.value()) > 1) - { - context_data.qualifiers_.parameters_set = false; - return context_data; - } - } - - // Check that plain_modulus is smaller than total coeff modulus - if (!is_less_than_uint_uint(plain_modulus.data(), plain_modulus.uint64_count(), - context_data.total_coeff_modulus_.get(), coeff_mod_count)) - { - // Parameters are not valid - context_data.qualifiers_.parameters_set = false; - return context_data; - } - - // Can we use batching? (NTT with plain_modulus) - context_data.qualifiers_.using_batching = false; - context_data.plain_ntt_tables_ = allocate(pool_); - if (context_data.plain_ntt_tables_->generate(coeff_count_power, plain_modulus)) - { - context_data.qualifiers_.using_batching = true; - } - - // Check for plain_lift - // If all the small coefficient moduli are larger than plain modulus, - // we can quickly lift plain coefficients to RNS form - context_data.qualifiers_.using_fast_plain_lift = true; - for (size_t i = 0; i < coeff_mod_count; i++) - { - context_data.qualifiers_.using_fast_plain_lift &= - (coeff_modulus[i].value() > plain_modulus.value()); - } - - // Calculate coeff_div_plain_modulus (BFV-"Delta") and the remainder - // upper_half_increment - context_data.coeff_div_plain_modulus_ = allocate_uint(coeff_mod_count, pool_); - context_data.upper_half_increment_ = allocate_uint(coeff_mod_count, pool_); - auto wide_plain_modulus(duplicate_uint_if_needed(plain_modulus.data(), - plain_modulus.uint64_count(), coeff_mod_count, false, pool_)); - divide_uint_uint(context_data.total_coeff_modulus_.get(), - wide_plain_modulus.get(), coeff_mod_count, - context_data.coeff_div_plain_modulus_.get(), - context_data.upper_half_increment_.get(), pool_); - - // Decompose coeff_div_plain_modulus into RNS factors - for (size_t i = 0; i < coeff_mod_count; i++) - { - temp[i] = modulo_uint(context_data.coeff_div_plain_modulus_.get(), - coeff_mod_count, coeff_modulus[i], pool_); - } - set_uint_uint(temp.get(), coeff_mod_count, - context_data.coeff_div_plain_modulus_.get()); - - // Decompose upper_half_increment into RNS factors - for (size_t i = 0; i < coeff_mod_count; i++) - { - temp[i] = modulo_uint(context_data.upper_half_increment_.get(), - coeff_mod_count, coeff_modulus[i], pool_); - } - set_uint_uint(temp.get(), coeff_mod_count, - context_data.upper_half_increment_.get()); - - // Calculate (plain_modulus + 1) / 2. - context_data.plain_upper_half_threshold_ = (plain_modulus.value() + 1) >> 1; - - // Calculate coeff_modulus - plain_modulus. - context_data.plain_upper_half_increment_ = - allocate_uint(coeff_mod_count, pool_); - if (context_data.qualifiers_.using_fast_plain_lift) - { - // Calculate coeff_modulus[i] - plain_modulus if using_fast_plain_lift - for (size_t i = 0; i < coeff_mod_count; i++) - { - context_data.plain_upper_half_increment_[i] = - coeff_modulus[i].value() - plain_modulus.value(); - } - } - else - { - sub_uint_uint(context_data.total_coeff_modulus(), - wide_plain_modulus.get(), coeff_mod_count, - context_data.plain_upper_half_increment_.get()); - } - } - else if (parms.scheme() == scheme_type::CKKS) - { - // Check that plain_modulus is set to zero - if (!plain_modulus.is_zero()) - { - // Parameters are not valid - context_data.qualifiers_.parameters_set = false; - return context_data; - } - - // When using CKKS batching (BatchEncoder) is always enabled - context_data.qualifiers_.using_batching = true; - - // Cannot use fast_plain_lift for CKKS since the plaintext coefficients - // can easily be larger than coefficient moduli - context_data.qualifiers_.using_fast_plain_lift = false; - - // Calculate 2^64 / 2 (most negative plaintext coefficient value) - context_data.plain_upper_half_threshold_ = uint64_t(1) << 63; - - // Calculate plain_upper_half_increment = 2^64 mod coeff_modulus for CKKS plaintexts - context_data.plain_upper_half_increment_ = allocate_uint(coeff_mod_count, pool_); - for (size_t i = 0; i < coeff_mod_count; i++) - { - uint64_t tmp = (uint64_t(1) << 63) % coeff_modulus[i].value(); - context_data.plain_upper_half_increment_[i] = multiply_uint_uint_mod( - tmp, - sub_safe(coeff_modulus[i].value(), uint64_t(2)), - coeff_modulus[i]); - } - - // Compute the upper_half_threshold for this modulus. - context_data.upper_half_threshold_ = allocate_uint( - coeff_mod_count, pool_); - increment_uint(context_data.total_coeff_modulus(), - coeff_mod_count, context_data.upper_half_threshold_.get()); - right_shift_uint(context_data.upper_half_threshold_.get(), 1, - coeff_mod_count, context_data.upper_half_threshold_.get()); - } - else - { - throw invalid_argument("unsupported scheme"); - } - - // Create BaseConverter - context_data.base_converter_ = allocate(pool_, pool_); - context_data.base_converter_->generate(coeff_modulus, poly_modulus_degree, - plain_modulus); - if (!context_data.base_converter_->is_generated()) - { - // Parameters are not valid - context_data.qualifiers_.parameters_set = false; - return context_data; - } - - // Check whether the coefficient modulus consists of a set of primes that - // are in decreasing order - context_data.qualifiers_.using_descending_modulus_chain = true; - for (size_t i = 0; i < coeff_mod_count - 1; i++) - { - context_data.qualifiers_.using_descending_modulus_chain - &= (coeff_modulus[i].value() > coeff_modulus[i + 1].value()); - } - - // Done with validation and pre-computations - return context_data; - } - - parms_id_type SEALContext::create_next_context_data( - const parms_id_type &prev_parms_id) - { - // Create the next set of parameters by removing last modulus - auto next_parms = context_data_map_.at(prev_parms_id)->parms_; - auto next_coeff_modulus = next_parms.coeff_modulus(); - next_coeff_modulus.pop_back(); - next_parms.set_coeff_modulus(next_coeff_modulus); - auto next_parms_id = next_parms.parms_id(); - - // Validate next parameters and create next context_data - auto next_context_data = validate(next_parms); - - // If not valid then return zero parms_id - if (!next_context_data.qualifiers_.parameters_set) - { - return parms_id_zero; - } - - // Add them to the context_data_map_ - context_data_map_.emplace(make_pair(next_parms_id, - make_shared(move(next_context_data)))); - - // Add pointer to next context_data to the previous one (linked list) - // Add pointer to previous context_data to the next one (doubly linked list) - // We need to remove constness first to modify this - const_pointer_cast( - context_data_map_.at(prev_parms_id))->next_context_data_ = - context_data_map_.at(next_parms_id); - const_pointer_cast( - context_data_map_.at(next_parms_id))->prev_context_data_ = - context_data_map_.at(prev_parms_id); - - return next_parms_id; - } - - SEALContext::SEALContext(EncryptionParameters parms, bool expand_mod_chain, - sec_level_type sec_level, MemoryPoolHandle pool) - : pool_(move(pool)), sec_level_(sec_level) - { - if (!pool_) - { - throw invalid_argument("pool is uninitialized"); - } - - // Set random generator - if (!parms.random_generator()) - { - parms.set_random_generator( - UniformRandomGeneratorFactory::default_factory()); - } - - // Validate parameters and add new ContextData to the map - // Note that this happens even if parameters are not valid - - // First create key_parms_id_. - context_data_map_.emplace(make_pair(parms.parms_id(), - make_shared(validate(parms)))); - key_parms_id_ = parms.parms_id(); - - // Then create first_parms_id_ if the parameters are valid and there is - // more than one modulus in coeff_modulus. This is equivalent to expanding - // the chain by one step. Otherwise, we set first_parms_id_ to equal - // key_parms_id_. - if (!context_data_map_.at(key_parms_id_)->qualifiers_.parameters_set || - parms.coeff_modulus().size() == 1) - { - first_parms_id_ = key_parms_id_; - } - else - { - auto next_parms_id = create_next_context_data(key_parms_id_); - first_parms_id_ = (next_parms_id == parms_id_zero) ? - key_parms_id_ : next_parms_id; - } - - // Set last_parms_id_ to point to first_parms_id_ - last_parms_id_ = first_parms_id_; - - // Check if keyswitching is available - using_keyswitching_ = (first_parms_id_ != key_parms_id_); - - // If modulus switching chain is to be created, compute the remaining - // parameter sets as long as they are valid to use (parameters_set == true) - if (expand_mod_chain && - context_data_map_.at(first_parms_id_)->qualifiers_.parameters_set) - { - auto prev_parms_id = first_parms_id_; - while (context_data_map_.at(prev_parms_id)->parms().coeff_modulus().size() > 1) - { - auto next_parms_id = create_next_context_data(prev_parms_id); - if (next_parms_id == parms_id_zero) - { - break; - } - prev_parms_id = next_parms_id; - last_parms_id_ = next_parms_id; - } - } - - // Set the chain_index for each context_data - size_t parms_count = context_data_map_.size(); - auto context_data_ptr = context_data_map_.at(key_parms_id_); - while (context_data_ptr) - { - // We need to remove constness first to modify this - const_pointer_cast( - context_data_ptr)->chain_index_ = --parms_count; - context_data_ptr = context_data_ptr->next_context_data_; - } - } -} diff --git a/SEAL/native/src/seal/context.h b/SEAL/native/src/seal/context.h deleted file mode 100644 index f511fa8..0000000 --- a/SEAL/native/src/seal/context.h +++ /dev/null @@ -1,547 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include -#include "seal/encryptionparams.h" -#include "seal/memorymanager.h" -#include "seal/modulus.h" -#include "seal/util/smallntt.h" -#include "seal/util/baseconverter.h" -#include "seal/util/pointer.h" - -namespace seal -{ - /** - Stores a set of attributes (qualifiers) of a set of encryption parameters. - These parameters are mainly used internally in various parts of the library, - e.g., to determine which algorithmic optimizations the current support. The - qualifiers are automatically created by the SEALContext class, silently passed - on to classes such as Encryptor, Evaluator, and Decryptor, and the only way to - change them is by changing the encryption parameters themselves. In other - words, a user will never have to create their own instance of this class, and - in most cases never have to worry about it at all. - */ - class EncryptionParameterQualifiers - { - public: - /** - If the encryption parameters are set in a way that is considered valid by - Microsoft SEAL, the variable parameters_set is set to true. - */ - bool parameters_set; - - /** - Tells whether FFT can be used for polynomial multiplication. If the - polynomial modulus is of the form X^N+1, where N is a power of two, then - FFT can be used for fast multiplication of polynomials modulo the polynomial - modulus. In this case the variable using_fft will be set to true. However, - currently Microsoft SEAL requires this to be the case for the parameters - to be valid. Therefore, parameters_set can only be true if using_fft is - true. - */ - bool using_fft; - - /** - Tells whether NTT can be used for polynomial multiplication. If the primes - in the coefficient modulus are congruent to 1 modulo 2N, where X^N+1 is the - polynomial modulus and N is a power of two, then the number-theoretic - transform (NTT) can be used for fast multiplications of polynomials modulo - the polynomial modulus and coefficient modulus. In this case the variable - using_ntt will be set to true. However, currently Microsoft SEAL requires - this to be the case for the parameters to be valid. Therefore, parameters_set - can only be true if using_ntt is true. - */ - bool using_ntt; - - /** - Tells whether batching is supported by the encryption parameters. If the - plaintext modulus is congruent to 1 modulo 2N, where X^N+1 is the polynomial - modulus and N is a power of two, then it is possible to use the BatchEncoder - class to view plaintext elements as 2-by-(N/2) matrices of integers modulo - the plaintext modulus. This is called batching, and allows the user to - operate on the matrix elements (slots) in a SIMD fashion, and rotate the - matrix rows and columns. When the computation is easily vectorizable, using - batching can yield a huge performance boost. If the encryption parameters - support batching, the variable using_batching is set to true. - */ - bool using_batching; - - /** - Tells whether fast plain lift is supported by the encryption parameters. - A certain performance optimization in multiplication of a ciphertext by - a plaintext (Evaluator::multiply_plain) and in transforming a plaintext - element to NTT domain (Evaluator::transform_to_ntt) can be used when the - plaintext modulus is smaller than each prime in the coefficient modulus. - In this case the variable using_fast_plain_lift is set to true. - */ - bool using_fast_plain_lift; - - /** - Tells whether the coefficient modulus consists of a set of primes that - are in decreasing order. If this is true, certain modular reductions in - base conversion can be omitted, improving performance. - */ - bool using_descending_modulus_chain; - - /** - Tells whether the encryption parameters are secure based on the standard - parameters from HomomorphicEncryption.org security standard. - */ - sec_level_type sec_level; - - private: - EncryptionParameterQualifiers() : - parameters_set(false), - using_fft(false), - using_ntt(false), - using_batching(false), - using_fast_plain_lift(false), - using_descending_modulus_chain(false), - sec_level(sec_level_type::none) - { - } - - friend class SEALContext; - }; - - /** - Performs sanity checks (validation) and pre-computations for a given set of encryption - parameters. While the EncryptionParameters class is intended to be a light-weight class - to store the encryption parameters, the SEALContext class is a heavy-weight class that - is constructed from a given set of encryption parameters. It validates the parameters - for correctness, evaluates their properties, and performs and stores the results of - several costly pre-computations. - - After the user has set at least the poly_modulus, coeff_modulus, and plain_modulus - parameters in a given EncryptionParameters instance, the parameters can be validated - for correctness and functionality by constructing an instance of SEALContext. The - constructor of SEALContext does all of its work automatically, and concludes by - constructing and storing an instance of the EncryptionParameterQualifiers class, with - its flags set according to the properties of the given parameters. If the created - instance of EncryptionParameterQualifiers has the parameters_set flag set to true, the - given parameter set has been deemed valid and is ready to be used. If the parameters - were for some reason not appropriately set, the parameters_set flag will be false, - and a new SEALContext will have to be created after the parameters are corrected. - - By default, SEALContext creates a chain of SEALContext::ContextData instances. The - first one in the chain corresponds to special encryption parameters that are reserved - to be used by the various key classes (SecretKey, PublicKey, etc.). These are the exact - same encryption parameters that are created by the user and passed to th constructor of - SEALContext. The functions key_context_data() and key_parms_id() return the ContextData - and the parms_id corresponding to these special parameters. The rest of the ContextData - instances in the chain correspond to encryption parameters that are derived from the - first encryption parameters by always removing the last one of the moduli in the - coeff_modulus, until the resulting parameters are no longer valid, e.g., there are no - more primes left. These derived encryption parameters are used by ciphertexts and - plaintexts and their respective ContextData can be accessed through the - get_context_data(parms_id_type) function. The functions first_context_data() and - last_context_data() return the ContextData corresponding to the first and the last - set of parameters in the "data" part of the chain, i.e., the second and the last element - in the full chain. The chain itself is a doubly linked list, and is referred to as the - modulus switching chain. - - @see EncryptionParameters for more details on the parameters. - @see EncryptionParameterQualifiers for more details on the qualifiers. - */ - class SEALContext - { - public: - /** - Class to hold pre-computation data for a given set of encryption parameters. - */ - class ContextData - { - friend class SEALContext; - - public: - ContextData() = delete; - - ContextData(const ContextData ©) = delete; - - ContextData(ContextData &&move) = default; - - ContextData &operator =(ContextData &&move) = default; - - /** - Returns a const reference to the underlying encryption parameters. - */ - SEAL_NODISCARD inline auto &parms() const noexcept - { - return parms_; - } - - /** - Returns the parms_id of the current parameters. - */ - SEAL_NODISCARD inline auto &parms_id() const noexcept - { - return parms_.parms_id(); - } - - /** - Returns a copy of EncryptionParameterQualifiers corresponding to the - current encryption parameters. Note that to change the qualifiers it is - necessary to create a new instance of SEALContext once appropriate changes - to the encryption parameters have been made. - */ - SEAL_NODISCARD inline auto qualifiers() const noexcept - { - return qualifiers_; - } - - /** - Returns a pointer to a pre-computed product of all primes in the coefficient - modulus. The security of the encryption parameters largely depends on the - bit-length of this product, and on the degree of the polynomial modulus. - */ - SEAL_NODISCARD inline auto total_coeff_modulus() const noexcept - -> const std::uint64_t* - { - return total_coeff_modulus_.get(); - } - - /** - Returns the significant bit count of the total coefficient modulus. - */ - SEAL_NODISCARD inline int total_coeff_modulus_bit_count() const noexcept - { - return total_coeff_modulus_bit_count_; - } - - /** - Returns a const reference to the base converter. - */ - SEAL_NODISCARD inline auto &base_converter() const noexcept - { - return base_converter_; - } - - /** - Returns a const reference to the NTT tables. - */ - SEAL_NODISCARD inline auto &small_ntt_tables() const noexcept - { - return small_ntt_tables_; - } - - /** - Returns a const reference to the NTT tables. - */ - SEAL_NODISCARD inline auto &plain_ntt_tables() const noexcept - { - return plain_ntt_tables_; - } - - /** - Return a pointer to BFV "Delta", i.e. coefficient modulus divided by - plaintext modulus. - */ - SEAL_NODISCARD inline auto coeff_div_plain_modulus() const noexcept - -> const std::uint64_t* - { - return coeff_div_plain_modulus_.get(); - } - - /** - Return the threshold for the upper half of integers modulo plain_modulus. - This is simply (plain_modulus + 1) / 2. - */ - SEAL_NODISCARD inline auto plain_upper_half_threshold() const noexcept - -> std::uint64_t - { - return plain_upper_half_threshold_; - } - - /** - Return a pointer to the plaintext upper half increment, i.e. coeff_modulus - minus plain_modulus. The upper half increment is represented as an integer - for the full product coeff_modulus if using_fast_plain_lift is false and is - otherwise represented modulo each of the coeff_modulus primes in order. - */ - SEAL_NODISCARD inline auto plain_upper_half_increment() const noexcept - -> const std::uint64_t* - { - return plain_upper_half_increment_.get(); - } - - /** - Return a pointer to the upper half threshold with respect to the total - coefficient modulus. This is needed in CKKS decryption. - */ - SEAL_NODISCARD inline auto upper_half_threshold() const noexcept - -> const std::uint64_t* - { - return upper_half_threshold_.get(); - } - - /** - Return a pointer to the upper half increment used for computing Delta*m - and converting the coefficients to modulo coeff_modulus. For example, - t-1 in plaintext should change into - q - Delta = Delta*t + r_t(q) - Delta - = Delta*(t-1) + r_t(q) - so multiplying the message by Delta is not enough and requires also an - addition of r_t(q). This is precisely the upper_half_increment. Note that - this operation is only done for negative message coefficients, i.e. those - that exceed plain_upper_half_threshold. - */ - SEAL_NODISCARD inline auto upper_half_increment() const noexcept - -> const std::uint64_t* - { - return upper_half_increment_.get(); - } - - /** - Returns a shared_ptr to the context data corresponding to the previous parameters - in the modulus switching chain. If the current data is the first one in the - chain, then the result is nullptr. - */ - SEAL_NODISCARD inline auto prev_context_data() const noexcept - { - return prev_context_data_.lock(); - } - - /** - Returns a shared_ptr to the context data corresponding to the next parameters - in the modulus switching chain. If the current data is the last one in the - chain, then the result is nullptr. - */ - SEAL_NODISCARD inline auto next_context_data() const noexcept - { - return next_context_data_; - } - - /** - Returns the index of the parameter set in a chain. The initial parameters - have index 0 and the index increases sequentially in the parameter chain. - */ - SEAL_NODISCARD inline std::size_t chain_index() const noexcept - { - return chain_index_; - } - - private: - ContextData(EncryptionParameters parms, MemoryPoolHandle pool) : - pool_(std::move(pool)), parms_(parms) - { - if (!pool_) - { - throw std::invalid_argument("pool is uninitialized"); - } - } - - MemoryPoolHandle pool_; - - EncryptionParameters parms_; - - EncryptionParameterQualifiers qualifiers_; - - util::Pointer base_converter_; - - util::Pointer small_ntt_tables_; - - util::Pointer plain_ntt_tables_; - - util::Pointer total_coeff_modulus_; - - int total_coeff_modulus_bit_count_ = 0; - - util::Pointer coeff_div_plain_modulus_; - - std::uint64_t plain_upper_half_threshold_ = 0; - - util::Pointer plain_upper_half_increment_; - - util::Pointer upper_half_threshold_; - - util::Pointer upper_half_increment_; - - std::weak_ptr prev_context_data_; - - std::shared_ptr next_context_data_{ nullptr }; - - std::size_t chain_index_ = 0; - }; - - SEALContext() = delete; - - /** - Creates an instance of SEALContext and performs several pre-computations - on the given EncryptionParameters. - - @param[in] parms The encryption parameters - @param[in] expand_mod_chain Determines whether the modulus switching chain - should be created - @param[in] sec_level Determines whether a specific security level should be - enforced according to HomomorphicEncryption.org security standard - */ - SEAL_NODISCARD static auto Create( - const EncryptionParameters &parms, - bool expand_mod_chain = true, - sec_level_type sec_level = sec_level_type::tc128) - { - return std::shared_ptr( - new SEALContext( - parms, - expand_mod_chain, - sec_level, - MemoryManager::GetPool()) - ); - } - - /** - Returns the ContextData corresponding to encryption parameters with a given - parms_id. If parameters with the given parms_id are not found then the - function returns nullptr. - - @param[in] parms_id The parms_id of the encryption parameters - */ - SEAL_NODISCARD inline auto get_context_data( - parms_id_type parms_id) const - { - auto data = context_data_map_.find(parms_id); - return (data != context_data_map_.end()) ? - data->second : std::shared_ptr{ nullptr }; - } - - /** - Returns the ContextData corresponding to encryption parameters that are - used for keys. - */ - SEAL_NODISCARD inline auto key_context_data() const - { - auto data = context_data_map_.find(key_parms_id_); - return (data != context_data_map_.end()) ? - data->second : std::shared_ptr{ nullptr }; - } - - /** - Returns the ContextData corresponding to the first encryption parameters - that are used for data. - */ - SEAL_NODISCARD inline auto first_context_data() const - { - auto data = context_data_map_.find(first_parms_id_); - return (data != context_data_map_.end()) ? - data->second : std::shared_ptr{ nullptr }; - } - - /** - Returns the ContextData corresponding to the last encryption parameters - that are used for data. - */ - SEAL_NODISCARD inline auto last_context_data() const - { - auto data = context_data_map_.find(last_parms_id_); - return (data != context_data_map_.end()) ? - data->second : std::shared_ptr{ nullptr }; - } - - /** - Returns whether the encryption parameters are valid. - */ - SEAL_NODISCARD inline auto parameters_set() const - { - return first_context_data() ? - first_context_data()->qualifiers_.parameters_set : false; - } - - /** - Returns a parms_id_type corresponding to the set of encryption parameters - that are used for keys. - */ - SEAL_NODISCARD inline auto &key_parms_id() const noexcept - { - return key_parms_id_; - } - - /** - Returns a parms_id_type corresponding to the first encryption parameters - that are used for data. - */ - SEAL_NODISCARD inline auto &first_parms_id() const noexcept - { - return first_parms_id_; - } - - /** - Returns a parms_id_type corresponding to the last encryption parameters - that are used for data. - */ - SEAL_NODISCARD inline auto &last_parms_id() const noexcept - { - return last_parms_id_; - } - - /** - Returns whether the coefficient modulus supports keyswitching. In practice, - support for keyswitching is required by Evaluator::relinearize, - Evaluator::apply_galois, and all rotation and conjugation operations. For - keyswitching to be available, the coefficient modulus parameter must consist - of at least two prime number factors. - */ - SEAL_NODISCARD inline bool using_keyswitching() const noexcept - { - return using_keyswitching_; - } - - private: - SEALContext(const SEALContext ©) = delete; - - SEALContext(SEALContext &&source) = delete; - - SEALContext &operator =(const SEALContext &assign) = delete; - - SEALContext &operator =(SEALContext &&assign) = delete; - - /** - Creates an instance of SEALContext, and performs several pre-computations - on the given EncryptionParameters. - - @param[in] parms The encryption parameters - @param[in] expand_mod_chain Determines whether the modulus switching chain - should be created - @param[in] sec_level Determines whether a specific security level should be - enforced according to HomomorphicEncryption.org security standard - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if pool is uninitialized - */ - SEALContext(EncryptionParameters parms, bool expand_mod_chain, - sec_level_type sec_level, MemoryPoolHandle pool); - - ContextData validate(EncryptionParameters parms); - - /** - Create the next context_data by dropping the last element from coeff_modulus. - If the new encryption parameters are not valid, returns parms_id_zero. - Otherwise, returns the parms_id of the next parameter and appends the next - context_data to the chain. - */ - parms_id_type create_next_context_data(const parms_id_type &prev_parms); - - MemoryPoolHandle pool_; - - parms_id_type key_parms_id_; - - parms_id_type first_parms_id_; - - parms_id_type last_parms_id_; - - std::unordered_map< - parms_id_type, std::shared_ptr> context_data_map_{}; - - /** - Is HomomorphicEncryption.org security standard enforced? - */ - sec_level_type sec_level_; - - /** - Is keyswitching supported by the encryption parameters? - */ - bool using_keyswitching_; - }; -} diff --git a/SEAL/native/src/seal/decryptor.cpp b/SEAL/native/src/seal/decryptor.cpp deleted file mode 100644 index c6b9963..0000000 --- a/SEAL/native/src/seal/decryptor.cpp +++ /dev/null @@ -1,557 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include -#include -#include "seal/decryptor.h" -#include "seal/valcheck.h" -#include "seal/util/common.h" -#include "seal/util/uintcore.h" -#include "seal/util/uintarith.h" -#include "seal/util/uintarithmod.h" -#include "seal/util/uintarithsmallmod.h" -#include "seal/util/polycore.h" -#include "seal/util/polyarithmod.h" -#include "seal/util/polyarithsmallmod.h" - -using namespace std; -using namespace seal::util; - -namespace seal -{ - Decryptor::Decryptor(std::shared_ptr context, - const SecretKey &secret_key) : context_(move(context)) - { - // Verify parameters - if (!context_) - { - throw invalid_argument("invalid context"); - } - if (!context_->parameters_set()) - { - throw invalid_argument("encryption parameters are not set correctly"); - } - if (secret_key.parms_id() != context_->key_parms_id()) - { - throw invalid_argument("secret key is not valid for encryption parameters"); - } - - auto &parms = context_->key_context_data()->parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = coeff_modulus.size(); - - // Set the secret_key_array to have size 1 (first power of secret) - // and copy over data - secret_key_array_ = allocate_poly(coeff_count, coeff_mod_count, pool_); - set_poly_poly(secret_key.data().data(), coeff_count, coeff_mod_count, - secret_key_array_.get()); - secret_key_array_size_ = 1; - } - - void Decryptor::decrypt(const Ciphertext &encrypted, Plaintext &destination) - { - // Verify that encrypted is valid. - if (!is_valid_for(encrypted, context_)) - { - throw invalid_argument("encrypted is not valid for encryption parameters"); - } - - auto &context_data = *context_->first_context_data(); - auto &parms = context_data.parms(); - - switch (parms.scheme()) - { - case scheme_type::BFV: - bfv_decrypt(encrypted, destination, pool_); - return; - - case scheme_type::CKKS: - ckks_decrypt(encrypted, destination, pool_); - return; - - default: - throw invalid_argument("unsupported scheme"); - } - } - - void Decryptor::bfv_decrypt(const Ciphertext &encrypted, - Plaintext &destination, MemoryPoolHandle pool) - { - if (encrypted.is_ntt_form()) - { - throw invalid_argument("encrypted cannot be in NTT form"); - } - - auto &context_data = *context_->get_context_data(encrypted.parms_id()); - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = coeff_modulus.size(); - size_t rns_poly_uint64_count = mul_safe(coeff_count, coeff_mod_count); - size_t key_rns_poly_uint64_count = mul_safe(coeff_count, - context_->key_context_data()->parms().coeff_modulus().size()); - size_t encrypted_size = encrypted.size(); - - auto &small_ntt_tables = context_data.small_ntt_tables(); - auto &base_converter = context_data.base_converter(); - auto &plain_gamma_product = base_converter->get_plain_gamma_product(); - auto &plain_gamma_array = base_converter->get_plain_gamma_array(); - auto &neg_inv_coeff = base_converter->get_neg_inv_coeff(); - auto inv_gamma = base_converter->get_inv_gamma(); - - // The number of uint64 count for plain_modulus and gamma together - size_t plain_gamma_uint64_count = 2; - - // Allocate a full size destination to write to - auto wide_destination(allocate_uint(coeff_count, pool)); - - // Make sure we have enough secret key powers computed - compute_secret_key_array(encrypted_size - 1); - - /* - Firstly find c_0 + c_1 *s + ... + c_{count-1} * s^{count-1} mod q - This is equal to Delta m + v where ||v|| < Delta/2. - So, add Delta / 2 and now we have something which is Delta * (m + epsilon) where epsilon < 1 - Therefore, we can (integer) divide by Delta and the answer will round down to m. - */ - - // Make a temp destination for all the arithmetic mod qi before calling FastBConverse - auto tmp_dest_modq(allocate_zero_poly(coeff_count, coeff_mod_count, pool)); - - // put < (c_1 , c_2, ... , c_{count-1}) , (s,s^2,...,s^{count-1}) > mod q in destination - - // Now do the dot product of encrypted_copy and the secret key array using NTT. - // The secret key powers are already NTT transformed. - auto copy_operand1(allocate_uint(coeff_count, pool)); - for (size_t i = 0; i < coeff_mod_count; i++) - { - // Initialize pointers for multiplication - const uint64_t *current_array1 = encrypted.data(1) + (i * coeff_count); - const uint64_t *current_array2 = secret_key_array_.get() + (i * coeff_count); - - for (size_t j = 0; j < encrypted_size - 1; j++) - { - // Perform the dyadic product. - set_uint_uint(current_array1, coeff_count, copy_operand1.get()); - - // Lazy reduction - ntt_negacyclic_harvey_lazy(copy_operand1.get(), small_ntt_tables[i]); - - dyadic_product_coeffmod(copy_operand1.get(), current_array2, coeff_count, - coeff_modulus[i], copy_operand1.get()); - add_poly_poly_coeffmod(tmp_dest_modq.get() + (i * coeff_count), - copy_operand1.get(), coeff_count, coeff_modulus[i], - tmp_dest_modq.get() + (i * coeff_count)); - - current_array1 += rns_poly_uint64_count; - current_array2 += key_rns_poly_uint64_count; - } - - // Perform inverse NTT - inverse_ntt_negacyclic_harvey(tmp_dest_modq.get() + (i * coeff_count), - small_ntt_tables[i]); - } - - // add c_0 into destination - for (size_t i = 0; i < coeff_mod_count; i++) - { - //add_poly_poly_coeffmod(tmp_dest_modq.get() + (i * coeff_count), - // encrypted.data() + (i * coeff_count), coeff_count, coeff_modulus_[i], - // tmp_dest_modq.get() + (i * coeff_count)); - - // Lazy reduction - for (size_t j = 0; j < coeff_count; j++) - { - tmp_dest_modq[j + (i * coeff_count)] += encrypted[j + (i * coeff_count)]; - } - - // Compute |gamma * plain|qi * ct(s) - multiply_poly_scalar_coeffmod(tmp_dest_modq.get() + (i * coeff_count), coeff_count, - plain_gamma_product[i], coeff_modulus[i], tmp_dest_modq.get() + (i * coeff_count)); - } - - // Make another temp destination to get the poly in mod {gamma U plain_modulus} - auto tmp_dest_plain_gamma(allocate_poly(coeff_count, plain_gamma_uint64_count, pool)); - - // Compute FastBConvert from q to {gamma, plain_modulus} - base_converter->fastbconv_plain_gamma(tmp_dest_modq.get(), tmp_dest_plain_gamma.get(), pool); - - // Compute result multiply by coeff_modulus inverse in mod {gamma U plain_modulus} - for (size_t i = 0; i < plain_gamma_uint64_count; i++) - { - multiply_poly_scalar_coeffmod(tmp_dest_plain_gamma.get() + (i * coeff_count), - coeff_count, neg_inv_coeff[i], plain_gamma_array[i], - tmp_dest_plain_gamma.get() + (i * coeff_count)); - } - - // First correct the values which are larger than floor(gamma/2) - uint64_t gamma_div_2 = plain_gamma_array[1].value() >> 1; - - // Now compute the subtraction to remove error and perform final multiplication by - // gamma inverse mod plain_modulus - for (size_t i = 0; i < coeff_count; i++) - { - // Need correction beacuse of center mod - if (tmp_dest_plain_gamma[i + coeff_count] > gamma_div_2) - { - // Compute -(gamma - a) instead of (a - gamma) - tmp_dest_plain_gamma[i + coeff_count] = plain_gamma_array[1].value() - - tmp_dest_plain_gamma[i + coeff_count]; - tmp_dest_plain_gamma[i + coeff_count] %= plain_gamma_array[0].value(); - wide_destination[i] = add_uint_uint_mod(tmp_dest_plain_gamma[i], - tmp_dest_plain_gamma[i + coeff_count], plain_gamma_array[0]); - } - // No correction needed - else - { - tmp_dest_plain_gamma[i + coeff_count] %= plain_gamma_array[0].value(); - wide_destination[i] = sub_uint_uint_mod(tmp_dest_plain_gamma[i], - tmp_dest_plain_gamma[i + coeff_count], plain_gamma_array[0]); - } - } - - // How many non-zero coefficients do we really have in the result? - size_t plain_coeff_count = get_significant_uint64_count_uint( - wide_destination.get(), coeff_count); - - // Resize destination to appropriate size - destination.resize(max(plain_coeff_count, size_t(1))); - destination.parms_id() = parms_id_zero; - - // Perform final multiplication by gamma inverse mod plain_modulus - multiply_poly_scalar_coeffmod(wide_destination.get(), - max(plain_coeff_count, size_t(1)), - inv_gamma, plain_gamma_array[0], destination.data()); - } - - void Decryptor::ckks_decrypt(const Ciphertext &encrypted, - Plaintext &destination, MemoryPoolHandle pool) - { - if (!encrypted.is_ntt_form()) - { - throw invalid_argument("encrypted must be in NTT form"); - } - - // We already know that the parameters are valid - auto &context_data = *context_->get_context_data(encrypted.parms_id()); - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = coeff_modulus.size(); - size_t rns_poly_uint64_count = mul_safe(coeff_count, coeff_mod_count); - size_t key_rns_poly_uint64_count = mul_safe(coeff_count, - context_->key_context_data()->parms().coeff_modulus().size()); - size_t encrypted_size = encrypted.size(); - - // Make sure we have enough secret key powers computed - compute_secret_key_array(encrypted_size - 1); - - /* - Decryption consists in finding c_0 + c_1 *s + ... + c_{count-1} * s^{count-1} mod q_1 * q_2 * q_3 - as long as ||m + v|| < q_1 * q_2 * q_3 - This is equal to m + v where ||v|| is small enough. - */ - - // Since we overwrite destination, we zeroize destination parameters - // This is necessary, otherwise resize will throw an exception. - destination.parms_id() = parms_id_zero; - - // Resize destination to appropriate size - destination.resize(rns_poly_uint64_count); - - // Make a temp destination for all the arithmetic mod q1, q2, q3 - //auto tmp_dest_modq(allocate_zero_poly(coeff_count, decryption_coeff_mod_count, pool)); - - // put < (c_1 , c_2, ... , c_{count-1}) , (s,s^2,...,s^{count-1}) > mod q in destination - - // Now do the dot product of encrypted_copy and the secret key array using NTT. - // The secret key powers are already NTT transformed. - - auto copy_operand1(allocate_uint(coeff_count, pool)); - for (size_t i = 0; i < coeff_mod_count; i++) - { - // Initialize pointers for multiplication - // c_1 mod qi - const uint64_t *current_array1 = encrypted.data(1) + (i * coeff_count); - // s mod qi - const uint64_t *current_array2 = secret_key_array_.get() + (i * coeff_count); - // set destination coefficients to zero modulo q_i - set_zero_uint(coeff_count, destination.data() + (i * coeff_count)); - - for (size_t j = 0; j < encrypted_size - 1; j++) - { - // Perform the dyadic product. - set_uint_uint(current_array1, coeff_count, copy_operand1.get()); - - // Lazy reduction - //ntt_negacyclic_harvey_lazy(copy_operand1.get(), small_ntt_tables[i]); - dyadic_product_coeffmod(copy_operand1.get(), current_array2, coeff_count, - coeff_modulus[i], copy_operand1.get()); - add_poly_poly_coeffmod(destination.data() + (i * coeff_count), - copy_operand1.get(), coeff_count, coeff_modulus[i], - destination.data() + (i * coeff_count)); - - // go to c_{1+j+1} and s^{1+j+1} mod qi - current_array1 += rns_poly_uint64_count; - current_array2 += key_rns_poly_uint64_count; - } - - // add c_0 into destination - add_poly_poly_coeffmod(destination.data() + (i * coeff_count), - encrypted.data() + (i * coeff_count), coeff_count, - coeff_modulus[i], destination.data() + (i * coeff_count)); - } - - // Set destination parameters as in encrypted - destination.parms_id() = encrypted.parms_id(); - destination.scale() = encrypted.scale(); - } - - void Decryptor::compute_secret_key_array(size_t max_power) - { -#ifdef SEAL_DEBUG - if (max_power < 1) - { - throw invalid_argument("max_power must be at least 1"); - } - if (!secret_key_array_size_ || !secret_key_array_) - { - throw logic_error("secret_key_array_ is uninitialized"); - } -#endif - // WARNING: This function must be called with the original context_data - auto &context_data = *context_->key_context_data(); - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = coeff_modulus.size(); - size_t key_rns_poly_uint64_count = mul_safe(coeff_count, coeff_mod_count); - - ReaderLock reader_lock(secret_key_array_locker_.acquire_read()); - - size_t old_size = secret_key_array_size_; - size_t new_size = max(max_power, old_size); - - if (old_size == new_size) - { - return; - } - - reader_lock.unlock(); - - // Need to extend the array - // Compute powers of secret key until max_power - auto new_secret_key_array(allocate_poly( - mul_safe(new_size, coeff_count), coeff_mod_count, pool_)); - set_poly_poly(secret_key_array_.get(), old_size * coeff_count, - coeff_mod_count, new_secret_key_array.get()); - - set_poly_poly(secret_key_array_.get(), mul_safe(old_size, coeff_count), - coeff_mod_count, new_secret_key_array.get()); - - uint64_t *prev_poly_ptr = new_secret_key_array.get() + - mul_safe(old_size - 1, key_rns_poly_uint64_count); - uint64_t *next_poly_ptr = prev_poly_ptr + key_rns_poly_uint64_count; - - // Since all of the key powers in secret_key_array_ are already NTT transformed, - // to get the next one we simply need to compute a dyadic product of the last - // one with the first one [which is equal to NTT(secret_key_)]. - for (size_t i = old_size; i < new_size; i++) - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - dyadic_product_coeffmod(prev_poly_ptr + (j * coeff_count), - new_secret_key_array.get() + (j * coeff_count), - coeff_count, coeff_modulus[j], - next_poly_ptr + (j * coeff_count)); - } - prev_poly_ptr = next_poly_ptr; - next_poly_ptr += key_rns_poly_uint64_count; - } - - - // Take writer lock to update array - WriterLock writer_lock(secret_key_array_locker_.acquire_write()); - - // Do we still need to update size? - old_size = secret_key_array_size_; - new_size = max(max_power, secret_key_array_size_); - - if (old_size == new_size) - { - return; - } - - // Acquire new array - secret_key_array_size_ = new_size; - secret_key_array_.acquire(new_secret_key_array); - } - - void Decryptor::compose( - const SEALContext::ContextData &context_data, uint64_t *value) - { -#ifdef SEAL_DEBUG - if (value == nullptr) - { - throw invalid_argument("input cannot be null"); - } -#endif - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = coeff_modulus.size(); - size_t rns_poly_uint64_count = mul_safe(coeff_count, coeff_mod_count); - - auto &base_converter = context_data.base_converter(); - auto coeff_products_array = base_converter->get_coeff_products_array(); - auto &inv_coeff_mod_coeff_array = base_converter->get_inv_coeff_mod_coeff_array(); - - // Set temporary coefficients_ptr pointer to point to either an existing - // allocation given as parameter, or else to a new allocation from the memory pool. - auto coefficients(allocate_uint(rns_poly_uint64_count, pool_)); - uint64_t *coefficients_ptr = coefficients.get(); - - // Re-merge the coefficients first - for (size_t i = 0; i < coeff_count; i++) - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - coefficients_ptr[j + (i * coeff_mod_count)] = value[(j * coeff_count) + i]; - } - } - - auto temp(allocate_uint(coeff_mod_count, pool_)); - set_zero_uint(rns_poly_uint64_count, value); - - for (size_t i = 0; i < coeff_count; i++) - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - uint64_t tmp = multiply_uint_uint_mod(coefficients_ptr[j], - inv_coeff_mod_coeff_array[j], coeff_modulus[j]); - multiply_uint_uint64(coeff_products_array + (j * coeff_mod_count), - coeff_mod_count, tmp, coeff_mod_count, temp.get()); - add_uint_uint_mod(temp.get(), value + (i * coeff_mod_count), - context_data.total_coeff_modulus(), - coeff_mod_count, value + (i * coeff_mod_count)); - } - set_zero_uint(coeff_mod_count, temp.get()); - coefficients_ptr += coeff_mod_count; - } - } - - int Decryptor::invariant_noise_budget(const Ciphertext &encrypted) - { - // Verify that encrypted is valid. - if (!is_valid_for(encrypted, context_)) - { - throw invalid_argument("encrypted is not valid for encryption parameters"); - } - - if (context_->key_context_data()->parms().scheme() != scheme_type::BFV) - { - throw logic_error("unsupported scheme"); - } - if (encrypted.is_ntt_form()) - { - throw invalid_argument("encrypted cannot be in NTT form"); - } - - auto &context_data = *context_->get_context_data(encrypted.parms_id()); - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = coeff_modulus.size(); - size_t rns_poly_uint64_count = mul_safe(coeff_count, coeff_mod_count); - size_t key_rns_poly_uint64_count = mul_safe(coeff_count, - context_->key_context_data()->parms().coeff_modulus().size()); - size_t encrypted_size = encrypted.size(); - uint64_t plain_modulus = parms.plain_modulus().value(); - - auto &small_ntt_tables = context_data.small_ntt_tables(); - - // Storage for noise uint - auto destination(allocate_uint(coeff_mod_count, pool_)); - - // Storage for noise poly - auto noise_poly(allocate_zero_poly(coeff_count, coeff_mod_count, pool_)); - - // Now need to compute c(s) - Delta*m (mod q) - - // Make sure we have enough secret keys computed - compute_secret_key_array(encrypted_size - 1); - - /* - Firstly find c_0 + c_1 *s + ... + c_{count-1} * s^{count-1} mod q - This is equal to Delta m + v where ||v|| < Delta/2. - */ - // put < (c_1 , c_2, ... , c_{count-1}) , (s,s^2,...,s^{count-1}) > mod q - // in destination_poly. - // Make a copy of the encryption for NTT (except the first polynomial is - // not needed). - auto encrypted_copy(allocate_poly( - mul_safe(encrypted_size - 1, coeff_count), coeff_mod_count, pool_)); - set_poly_poly(encrypted.data(1), mul_safe(encrypted_size - 1, coeff_count), - coeff_mod_count, encrypted_copy.get()); - - // Now do the dot product of encrypted_copy and the secret key array using NTT. - // The secret key powers are already NTT transformed. - auto copy_operand1(allocate_uint(coeff_count, pool_)); - for (size_t i = 0; i < coeff_mod_count; i++) - { - // Initialize pointers for multiplication - const uint64_t *current_array1 = encrypted.data(1) + (i * coeff_count); - const uint64_t *current_array2 = secret_key_array_.get() + (i * coeff_count); - - for (size_t j = 0; j < encrypted_size - 1; j++) - { - // Perform the dyadic product. - set_uint_uint(current_array1, coeff_count, copy_operand1.get()); - - // Lazy reduction - ntt_negacyclic_harvey_lazy(copy_operand1.get(), small_ntt_tables[i]); - - dyadic_product_coeffmod(copy_operand1.get(), current_array2, coeff_count, - coeff_modulus[i], copy_operand1.get()); - add_poly_poly_coeffmod(noise_poly.get() + (i * coeff_count), - copy_operand1.get(), - coeff_count, coeff_modulus[i], - noise_poly.get() + (i * coeff_count)); - - current_array1 += rns_poly_uint64_count; - current_array2 += key_rns_poly_uint64_count; - } - - // Perform inverse NTT - inverse_ntt_negacyclic_harvey(noise_poly.get() + (i * coeff_count), - small_ntt_tables[i]); - } - - for (size_t i = 0; i < coeff_mod_count; i++) - { - // add c_0 into noise_poly - add_poly_poly_coeffmod(noise_poly.get() + (i * coeff_count), - encrypted.data() + (i * coeff_count), coeff_count, coeff_modulus[i], - noise_poly.get() + (i * coeff_count)); - - // Multiply by parms.plain_modulus() and reduce mod parms.coeff_modulus() to get - // parms.coeff_modulus()*noise - multiply_poly_scalar_coeffmod(noise_poly.get() + (i * coeff_count), - coeff_count, plain_modulus, coeff_modulus[i], - noise_poly.get() + (i * coeff_count)); - } - - // Compose the noise - compose(context_data, noise_poly.get()); - - // Next we compute the infinity norm mod parms.coeff_modulus() - poly_infty_norm_coeffmod(noise_poly.get(), coeff_count, coeff_mod_count, - context_data.total_coeff_modulus(), destination.get(), pool_); - - // The -1 accounts for scaling the invariant noise by 2 - int bit_count_diff = context_data.total_coeff_modulus_bit_count() - - get_significant_bit_count_uint(destination.get(), coeff_mod_count) - 1; - return max(0, bit_count_diff); - } -} diff --git a/SEAL/native/src/seal/decryptor.h b/SEAL/native/src/seal/decryptor.h deleted file mode 100644 index 8ec7796..0000000 --- a/SEAL/native/src/seal/decryptor.h +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include "seal/util/defines.h" -#include "seal/randomgen.h" -#include "seal/encryptionparams.h" -#include "seal/context.h" -#include "seal/memorymanager.h" -#include "seal/ciphertext.h" -#include "seal/plaintext.h" -#include "seal/secretkey.h" -#include "seal/smallmodulus.h" -#include "seal/util/smallntt.h" -#include "seal/util/baseconverter.h" -#include "seal/util/locks.h" - -namespace seal -{ - /** - Decrypts Ciphertext objects into Plaintext objects. Constructing a Decryptor - requires a SEALContext with valid encryption parameters, and the secret key. - The Decryptor is also used to compute the invariant noise budget in a given - ciphertext. - - @par Overloads - For the decrypt function we provide two overloads concerning the memory pool - used in allocations needed during the operation. In one overload the global - memory pool is used for this purpose, and in another overload the user can - supply a MemoryPoolHandle to be used instead. This is to allow one single - Decryptor to be used concurrently by several threads without running into - thread contention in allocations taking place during operations. For example, - one can share one single Decryptor across any number of threads, but in each - thread call the decrypt function by giving it a thread-local MemoryPoolHandle - to use. It is important for a developer to understand how this works to avoid - unnecessary performance bottlenecks. - - - @par NTT form - When using the BFV scheme (scheme_type::BFV), all plaintext and ciphertexts - should remain by default in the usual coefficient representation, i.e. not in - NTT form. When using the CKKS scheme (scheme_type::CKKS), all plaintexts and - ciphertexts should remain by default in NTT form. We call these scheme-specific - NTT states the "default NTT form". Decryption requires the input ciphertexts - to be in the default NTT form, and will throw an exception if this is not the - case. - */ - class SEAL_NODISCARD Decryptor - { - public: - /** - Creates a Decryptor instance initialized with the specified SEALContext - and secret key. - - @param[in] context The SEALContext - @param[in] secret_key The secret key - @throws std::invalid_argument if the context is not set or encryption - parameters are not valid - @throws std::invalid_argument if secret_key is not valid - */ - Decryptor(std::shared_ptr context, const SecretKey &secret_key); - - /* - Decrypts a Ciphertext and stores the result in the destination parameter. - - @param[in] encrypted The ciphertext to decrypt - @param[out] destination The plaintext to overwrite with the decrypted - ciphertext - @throws std::invalid_argument if encrypted is not valid for the encryption - parameters - @throws std::invalid_argument if encrypted is not in the default NTT form - */ - void decrypt(const Ciphertext &encrypted, Plaintext &destination); - - /* - Computes the invariant noise budget (in bits) of a ciphertext. The - invariant noise budget measures the amount of room there is for the noise - to grow while ensuring correct decryptions. This function works only with - the BFV scheme. - - @par Invariant Noise Budget - The invariant noise polynomial of a ciphertext is a rational coefficient - polynomial, such that a ciphertext decrypts correctly as long as the - coefficients of the invariantnoise polynomial are of absolute value less - than 1/2. Thus, we call the infinity-norm of the invariant noise polynomial - the invariant noise, and for correct decryption requireit to be less than - 1/2. If v denotes the invariant noise, we define the invariant noise budget - as -log2(2v). Thus, the invariant noise budget starts from some initial - value, which depends on the encryption parameters, and decreases when - computations are performed. When the budget reaches zero, the ciphertext - becomes too noisy to decrypt correctly. - - @param[in] encrypted The ciphertext - @throws std::invalid_argument if the scheme is not BFV - @throws std::invalid_argument if encrypted is not valid for the encryption - parameters - @throws std::invalid_argument if encrypted is in NTT form - */ - SEAL_NODISCARD int invariant_noise_budget(const Ciphertext &encrypted); - - private: - void bfv_decrypt(const Ciphertext &encrypted, Plaintext &destination, - MemoryPoolHandle pool); - - void ckks_decrypt(const Ciphertext &encrypted, Plaintext &destination, - MemoryPoolHandle pool); - - Decryptor(const Decryptor ©) = delete; - - Decryptor(Decryptor &&source) = delete; - - Decryptor &operator =(const Decryptor &assign) = delete; - - Decryptor &operator =(Decryptor &&assign) = delete; - - void compute_secret_key_array(std::size_t max_power); - - void compose(const SEALContext::ContextData &context_data, - std::uint64_t *value); - - // We use a fresh memory pool with `clear_on_destruction' enabled. - MemoryPoolHandle pool_ = MemoryManager::GetPool(mm_prof_opt::FORCE_NEW, true); - - std::shared_ptr context_{ nullptr }; - - std::size_t secret_key_array_size_ = 0; - - util::Pointer secret_key_array_; - - mutable util::ReaderWriterLocker secret_key_array_locker_; - }; -} diff --git a/SEAL/native/src/seal/encryptionparams.cpp b/SEAL/native/src/seal/encryptionparams.cpp deleted file mode 100644 index 705076c..0000000 --- a/SEAL/native/src/seal/encryptionparams.cpp +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "seal/encryptionparams.h" -#include "seal/util/uintcore.h" -#include - -using namespace std; -using namespace seal::util; - -namespace seal -{ - void EncryptionParameters::Save(const EncryptionParameters &parms, ostream &stream) - { - // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit - auto old_except_mask = stream.exceptions(); - try - { - stream.exceptions(ios_base::badbit | ios_base::failbit); - - uint64_t poly_modulus_degree64 = static_cast(parms.poly_modulus_degree()); - uint64_t coeff_mod_count64 = static_cast(parms.coeff_modulus().size()); - uint8_t scheme = static_cast(parms.scheme()); - - stream.write(reinterpret_cast(&scheme), sizeof(uint8_t)); - stream.write(reinterpret_cast(&poly_modulus_degree64), sizeof(uint64_t)); - stream.write(reinterpret_cast(&coeff_mod_count64), sizeof(uint64_t)); - for (const auto &mod : parms.coeff_modulus()) - { - mod.save(stream); - } - // CKKS does not use plain_modulus - if (parms.scheme() == scheme_type::BFV) - { - parms.plain_modulus().save(stream); - } - } - catch (const exception &) - { - stream.exceptions(old_except_mask); - throw; - } - } - - EncryptionParameters EncryptionParameters::Load(istream &stream) - { - // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit - auto old_except_mask = stream.exceptions(); - try - { - stream.exceptions(ios_base::badbit | ios_base::failbit); - - // Read the scheme identifier - uint8_t scheme; - stream.read(reinterpret_cast(&scheme), sizeof(uint8_t)); - - // This constructor will throw if scheme is invalid - EncryptionParameters parms(scheme); - - // Read the poly_modulus_degree - uint64_t poly_modulus_degree64 = 0; - stream.read(reinterpret_cast(&poly_modulus_degree64), sizeof(uint64_t)); - if (poly_modulus_degree64 < SEAL_POLY_MOD_DEGREE_MIN || - poly_modulus_degree64 > SEAL_POLY_MOD_DEGREE_MAX) - { - throw invalid_argument("poly_modulus_degree is invalid"); - } - - // Read the coeff_modulus size - uint64_t coeff_mod_count64 = 0; - stream.read(reinterpret_cast(&coeff_mod_count64), sizeof(uint64_t)); - if (coeff_mod_count64 > SEAL_COEFF_MOD_COUNT_MAX || - coeff_mod_count64 < SEAL_COEFF_MOD_COUNT_MIN) - { - throw invalid_argument("coeff_modulus is invalid"); - } - - // Read the coeff_modulus - vector coeff_modulus(coeff_mod_count64); - for (auto &mod : coeff_modulus) - { - mod.load(stream); - } - - // Read the plain_modulus - SmallModulus plain_modulus; - // CKKS does not use plain_modulus - if (parms.scheme() == scheme_type::BFV) - { - plain_modulus.load(stream); - } - - // Supposedly everything worked so set the values of member variables - parms.set_poly_modulus_degree(safe_cast(poly_modulus_degree64)); - parms.set_coeff_modulus(coeff_modulus); - // CKKS does not use plain_modulus - if (parms.scheme() == scheme_type::BFV) - { - parms.set_plain_modulus(plain_modulus); - } - - stream.exceptions(old_except_mask); - return parms; - } - catch (const exception &) - { - stream.exceptions(old_except_mask); - throw; - } - catch (...) - { - stream.exceptions(old_except_mask); - throw; - } - } - - void EncryptionParameters::compute_parms_id() - { - size_t coeff_mod_count = coeff_modulus_.size(); - - size_t total_uint64_count = add_safe( - size_t(1), // scheme - size_t(1), // poly_modulus_degree - coeff_mod_count, - plain_modulus_.uint64_count() - ); - - auto param_data(allocate_uint(total_uint64_count, pool_)); - uint64_t *param_data_ptr = param_data.get(); - - // Write the scheme identifier - *param_data_ptr++ = static_cast(scheme_); - - // Write the poly_modulus_degree. Note that it will always be positive. - *param_data_ptr++ = static_cast(poly_modulus_degree_); - - for(const auto &mod : coeff_modulus_) - { - *param_data_ptr++ = mod.value(); - } - - set_uint_uint(plain_modulus_.data(), plain_modulus_.uint64_count(), param_data_ptr); - param_data_ptr += plain_modulus_.uint64_count(); - - HashFunction::sha3_hash(param_data.get(), total_uint64_count, parms_id_); - - // Did we somehow manage to get a zero block as result? This is reserved for - // plaintexts to indicate non-NTT-transformed form. - if (parms_id_ == parms_id_zero) - { - throw logic_error("parms_id cannot be zero"); - } - } -} diff --git a/SEAL/native/src/seal/encryptionparams.h b/SEAL/native/src/seal/encryptionparams.h deleted file mode 100644 index 11eecbb..0000000 --- a/SEAL/native/src/seal/encryptionparams.h +++ /dev/null @@ -1,398 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include -#include "seal/util/defines.h" -#include "seal/util/globals.h" -#include "seal/randomgen.h" -#include "seal/smallmodulus.h" -#include "seal/util/hash.h" -#include "seal/memorymanager.h" - -namespace seal -{ - /** - Describes the type of encryption scheme to be used. - */ - enum class scheme_type : std::uint8_t - { - // Brakerski/Fan-Vercauteren scheme - BFV = 0x1, - - // Cheon-Kim-Kim-Song scheme - CKKS = 0x2 - }; - - /** - The data type to store unique identifiers of encryption parameters. - */ - using parms_id_type = util::HashFunction::sha3_block_type; - - /** - A parms_id_type value consisting of zeros. - */ - static constexpr parms_id_type parms_id_zero = - util::HashFunction::sha3_zero_block; - - /** - Represents user-customizable encryption scheme settings. The parameters (most - importantly poly_modulus, coeff_modulus, plain_modulus) significantly affect - the performance, capabilities, and security of the encryption scheme. Once - an instance of EncryptionParameters is populated with appropriate parameters, - it can be used to create an instance of the SEALContext class, which verifies - the validity of the parameters, and performs necessary pre-computations. - - Picking appropriate encryption parameters is essential to enable a particular - application while balancing performance and security. Some encryption settings - will not allow some inputs (e.g. attempting to encrypt a polynomial with more - coefficients than poly_modulus or larger coefficients than plain_modulus) or, - support the desired computations (with noise growing too fast due to too large - plain_modulus and too small coeff_modulus). - - @par parms_id - The EncryptionParameters class maintains at all times a 256-bit SHA-3 hash of - the currently set encryption parameters called the parms_id. This hash acts as - a unique identifier of the encryption parameters and is used by all further - objects created for these encryption parameters. The parms_id is not intended - to be directly modified by the user but is used internally for pre-computation - data lookup and input validity checks. In modulus switching the user can use - the parms_id to keep track of the chain of encryption parameters. The parms_id - is not exposed in the public API of EncryptionParameters, but can be accessed - through the SEALContext::ContextData class once the SEALContext has been created. - - @par Thread Safety - In general, reading from EncryptionParameters is thread-safe, while mutating - is not. - - @warning Choosing inappropriate encryption parameters may lead to an encryption - scheme that is not secure, does not perform well, and/or does not support the - input and computation of the desired application. We highly recommend consulting - an expert in RLWE-based encryption when selecting parameters, as this is where - inexperienced users seem to most often make critical mistakes. - */ - class EncryptionParameters - { - friend class SEALContext; - - public: - /** - Creates an empty set of encryption parameters. At a minimum, the user needs - to specify the parameters poly_modulus, coeff_modulus, and plain_modulus - for the parameters to be usable. - */ - EncryptionParameters(scheme_type scheme) : scheme_(scheme) - { - compute_parms_id(); - } - - /** - Creates an empty set of encryption parameters. At a minimum, the user needs - to specify the parameters poly_modulus, coeff_modulus, and plain_modulus - for the parameters to be usable. - - @throws std::invalid_argument if scheme is not supported - @see scheme_type for the supported schemes - */ - EncryptionParameters(std::uint8_t scheme) - { - // Check that a valid scheme is given - if (!is_valid_scheme(scheme)) - { - throw std::invalid_argument("unsupported scheme"); - } - - scheme_ = static_cast(scheme); - compute_parms_id(); - } - - /** - Creates a copy of a given instance of EncryptionParameters. - - @param[in] copy The EncryptionParameters to copy from - */ - EncryptionParameters(const EncryptionParameters ©) = default; - - /** - Overwrites the EncryptionParameters instance with a copy of a given instance. - - @param[in] assign The EncryptionParameters to copy from - */ - EncryptionParameters &operator =(const EncryptionParameters &assign) = default; - - /** - Creates a new EncryptionParameters instance by moving a given instance. - - @param[in] source The EncryptionParameters to move from - */ - EncryptionParameters(EncryptionParameters &&source) = default; - - /** - Overwrites the EncryptionParameters instance by moving a given instance. - - @param[in] assign The EncryptionParameters to move from - */ - EncryptionParameters &operator =(EncryptionParameters &&assign) = default; - - /** - Sets the degree of the polynomial modulus parameter to the specified value. - The polynomial modulus directly affects the number of coefficients in - plaintext polynomials, the size of ciphertext elements, the computational - performance of the scheme (bigger is worse), and the security level (bigger - is better). In Microsoft SEAL the degree of the polynomial modulus must be a power - of 2 (e.g. 1024, 2048, 4096, 8192, 16384, or 32768). - - @param[in] poly_modulus_degree The new polynomial modulus degree - */ - inline void set_poly_modulus_degree(std::size_t poly_modulus_degree) - { - // Set the degree - poly_modulus_degree_ = poly_modulus_degree; - - // Re-compute the parms_id - compute_parms_id(); - } - - /** - Sets the coefficient modulus parameter. The coefficient modulus consists - of a list of distinct prime numbers, and is represented by a vector of - SmallModulus objects. The coefficient modulus directly affects the size - of ciphertext elements, the amount of computation that the scheme can perform - (bigger is better), and the security level (bigger is worse). In Microsoft SEAL each - of the prime numbers in the coefficient modulus must be at most 60 bits, - and must be congruent to 1 modulo 2*poly_modulus_degree. - - @param[in] coeff_modulus The new coefficient modulus - @throws std::invalid_argument if size of coeff_modulus is invalid - */ - inline void set_coeff_modulus(const std::vector &coeff_modulus) - { - // Set the coeff_modulus_ - if (coeff_modulus.size() > SEAL_COEFF_MOD_COUNT_MAX || - coeff_modulus.size() < SEAL_COEFF_MOD_COUNT_MIN) - { - throw std::invalid_argument("coeff_modulus is invalid"); - } - - coeff_modulus_ = coeff_modulus; - - // Re-compute the parms_id - compute_parms_id(); - } - - /** - Sets the plaintext modulus parameter. The plaintext modulus is an integer - modulus represented by the SmallModulus class. The plaintext modulus - determines the largest coefficient that plaintext polynomials can represent. - It also affects the amount of computation that the scheme can perform - (bigger is worse). In Microsoft SEAL the plaintext modulus can be at most 60 bits - long, but can otherwise be any integer. Note, however, that some features - (e.g. batching) require the plaintext modulus to be of a particular form. - - @param[in] plain_modulus The new plaintext modulus - @throws std::logic_error if scheme is not scheme_type::BFV - */ - inline void set_plain_modulus(const SmallModulus &plain_modulus) - { - // CKKS does not use plain_modulus - if (scheme_ != scheme_type::BFV) - { - throw std::logic_error("unsupported scheme"); - } - - plain_modulus_ = plain_modulus; - - // Re-compute the parms_id - compute_parms_id(); - } - - /** - Sets the plaintext modulus parameter. The plaintext modulus is an integer - modulus represented by the SmallModulus class. This constructor instead - takes a std::uint64_t and automatically creates the SmallModulus object. - The plaintext modulus determines the largest coefficient that plaintext - polynomials can represent. It also affects the amount of computation that - the scheme can perform (bigger is worse). In Microsoft SEAL the plaintext modulus - can be at most 60 bits long, but can otherwise be any integer. Note, - however, that some features (e.g. batching) require the plaintext modulus - to be of a particular form. - - @param[in] plain_modulus The new plaintext modulus - @throws std::invalid_argument if plain_modulus is invalid - */ - inline void set_plain_modulus(std::uint64_t plain_modulus) - { - set_plain_modulus(SmallModulus(plain_modulus)); - } - - /** - Sets the random number generator factory to use for encryption. By default, - the random generator is set to UniformRandomGeneratorFactory::default_factory(). - Setting this value allows a user to specify a custom random number generator - source. - - @param[in] random_generator Pointer to the random generator factory - */ - inline void set_random_generator( - std::shared_ptr random_generator) noexcept - { - random_generator_ = std::move(random_generator); - } - - /** - Returns the encryption scheme type. - */ - SEAL_NODISCARD inline scheme_type scheme() const noexcept - { - return scheme_; - } - - /** - Returns the degree of the polynomial modulus parameter. - */ - SEAL_NODISCARD inline std::size_t poly_modulus_degree() const noexcept - { - return poly_modulus_degree_; - } - - /** - Returns a const reference to the currently set coefficient modulus parameter. - */ - SEAL_NODISCARD inline auto coeff_modulus() const noexcept - -> const std::vector& - { - return coeff_modulus_; - } - - /** - Returns a const reference to the currently set plaintext modulus parameter. - */ - SEAL_NODISCARD inline const SmallModulus &plain_modulus() const noexcept - { - return plain_modulus_; - } - - /** - Returns a pointer to the random number generator factory to use for encryption. - */ - SEAL_NODISCARD inline auto random_generator() const noexcept - -> std::shared_ptr - { - return random_generator_; - } - - /** - Compares a given set of encryption parameters to the current set of - encryption parameters. The comparison is performed by comparing the - parms_ids of the parameter sets rather than comparing the parameters - individually. - - @parms[in] other The EncryptionParameters to compare against - */ - SEAL_NODISCARD inline bool operator ==( - const EncryptionParameters &other) const noexcept - { - return (parms_id_ == other.parms_id_); - } - - /** - Compares a given set of encryption parameters to the current set of - encryption parameters. The comparison is performed by comparing - parms_ids of the parameter sets rather than comparing the parameters - individually. - - @parms[in] other The EncryptionParameters to compare against - */ - SEAL_NODISCARD inline bool operator !=( - const EncryptionParameters &other) const noexcept - { - return (parms_id_ != other.parms_id_); - } - - /** - Saves EncryptionParameters to an output stream. The output is in binary - format and is not human-readable. The output stream must have the "binary" - flag set. - - @param[in] stream The stream to save the EncryptionParameters to - @throws std::exception if the EncryptionParameters could not be written - to stream - */ - static void Save(const EncryptionParameters &parms, std::ostream &stream); - - /** - Loads EncryptionParameters from an input stream. - - @param[in] stream The stream to load the EncryptionParameters from - @throws std::exception if valid EncryptionParameters could not be read - from stream - */ - SEAL_NODISCARD static EncryptionParameters Load(std::istream &stream); - - /** - Enables access to private members of seal::EncryptionParameters for .NET - wrapper. - */ - struct EncryptionParametersPrivateHelper; - - private: - /** - Helper function to determine whether given std::uint8_t represents a valid - value for scheme_type. - */ - SEAL_NODISCARD bool is_valid_scheme(std::uint8_t scheme) const noexcept - { - return (scheme == static_cast(scheme_type::BFV) || - (scheme == static_cast(scheme_type::CKKS))); - } - - /** - Returns the parms_id of the current parameters. This function is intended - for internal use. - */ - SEAL_NODISCARD inline auto &parms_id() const noexcept - { - return parms_id_; - } - - void compute_parms_id(); - - MemoryPoolHandle pool_ = MemoryManager::GetPool(); - - scheme_type scheme_; - - std::size_t poly_modulus_degree_ = 0; - - std::vector coeff_modulus_{}; - - std::shared_ptr random_generator_{ nullptr }; - - SmallModulus plain_modulus_{}; - - parms_id_type parms_id_ = parms_id_zero; - }; -} - -/** -Specializes the std::hash template for parms_id_type. -*/ -namespace std -{ - template<> - struct hash - { - std::size_t operator()( - const seal::parms_id_type &parms_id) const - { - std::uint64_t result = 17; - result = 31 * result + parms_id[0]; - result = 31 * result + parms_id[1]; - result = 31 * result + parms_id[2]; - result = 31 * result + parms_id[3]; - return static_cast(result); - } - }; -} diff --git a/SEAL/native/src/seal/encryptor.cpp b/SEAL/native/src/seal/encryptor.cpp deleted file mode 100644 index 63c6858..0000000 --- a/SEAL/native/src/seal/encryptor.cpp +++ /dev/null @@ -1,245 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include -#include -#include "seal/encryptor.h" -#include "seal/randomgen.h" -#include "seal/randomtostd.h" -#include "seal/smallmodulus.h" -#include "seal/util/common.h" -#include "seal/util/uintarith.h" -#include "seal/util/polyarithsmallmod.h" -#include "seal/util/clipnormal.h" -#include "seal/util/smallntt.h" -#include "seal/util/rlwe.h" - -using namespace std; -using namespace seal::util; - -namespace seal -{ - Encryptor::Encryptor(shared_ptr context, - const PublicKey &public_key) : context_(move(context)), - public_key_(public_key) - { - // Verify parameters - if (!context_) - { - throw invalid_argument("invalid context"); - } - if (!context_->parameters_set()) - { - throw invalid_argument("encryption parameters are not set correctly"); - } - if (public_key.parms_id() != context_->key_parms_id()) - { - throw invalid_argument("public key is not valid for encryption parameters"); - } - - auto &parms = context_->key_context_data()->parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = coeff_modulus.size(); - - // Quick sanity check - if (!product_fits_in(coeff_count, coeff_mod_count, size_t(2))) - { - throw logic_error("invalid parameters"); - } - } - - void Encryptor::encrypt_zero(parms_id_type parms_id, - Ciphertext &destination, - MemoryPoolHandle pool) - { - // Verify parameters. - auto context_data_ptr = context_->get_context_data(parms_id); - if (!context_data_ptr) - { - throw invalid_argument("parms_id is not valid for encryption parameters"); - } - - auto &context_data = *context_->get_context_data(parms_id); - auto &parms = context_data.parms(); - size_t coeff_mod_count = parms.coeff_modulus().size(); - size_t coeff_count = parms.poly_modulus_degree(); - - bool is_ntt_form = false; - if (parms.scheme() == scheme_type::CKKS) - { - is_ntt_form = true; - } - else if (parms.scheme() != scheme_type::BFV) - { - throw invalid_argument("unsupported scheme"); - } - - shared_ptr random( - parms.random_generator()->create()); - - // Resize destination and save results - destination.resize(context_, parms_id, 2); - - auto prev_context_data_ptr = context_data.prev_context_data(); - if (prev_context_data_ptr) - { - auto &prev_context_data = *prev_context_data_ptr; - auto &prev_parms_id = prev_context_data.parms_id(); - auto &base_converter = prev_context_data.base_converter(); - - // Zero encryption without modulus switching - Ciphertext temp(pool); - encrypt_zero_asymmetric(public_key_, context_, prev_parms_id, - random, is_ntt_form, temp, pool); - if (temp.is_ntt_form() != is_ntt_form) - { - throw invalid_argument("NTT form mismatch"); - } - - // Modulus switching - for (size_t j = 0; j < 2; j++) - { - if (is_ntt_form) - { - base_converter->round_last_coeff_modulus_ntt_inplace( - temp.data(j), - prev_context_data.small_ntt_tables(), - pool); - } - else - { - base_converter->round_last_coeff_modulus_inplace( - temp.data(j), - pool); - } - set_poly_poly( - temp.data(j), - coeff_count, - coeff_mod_count, - destination.data(j)); - } - - destination.is_ntt_form() = is_ntt_form; - - // Need to set the scale here since encrypt_zero_asymmetric only sets - // it for temp - destination.scale() = temp.scale(); - } - else - { - encrypt_zero_asymmetric(public_key_, context_, - parms_id, random, is_ntt_form, destination, pool); - } - } - - void Encryptor::encrypt(const Plaintext &plain, - Ciphertext &destination, - MemoryPoolHandle pool) - { - // Verify parameters. - if (!pool) - { - throw invalid_argument("pool is uninitialized"); - } - // Verify that plain is valid. - if (!is_valid_for(plain, context_)) - { - throw invalid_argument("plain is not valid for encryption parameters"); - } - - auto scheme = context_->key_context_data()->parms().scheme(); - if (scheme == scheme_type::BFV) - { - if (plain.is_ntt_form()) - { - throw invalid_argument("plain cannot be in NTT form"); - } - - encrypt_zero(context_->first_parms_id(), destination); - - // Multiply plain by scalar coeff_div_plaintext and reposition if in upper-half. - // Result gets added into the c_0 term of ciphertext (c_0,c_1). - preencrypt(plain.data(), - plain.coeff_count(), - *context_->first_context_data(), - destination.data()); - } - else if (scheme == scheme_type::CKKS) - { - if (!plain.is_ntt_form()) - { - throw invalid_argument("plain must be in NTT form"); - } - auto context_data_ptr = context_->get_context_data(plain.parms_id()); - if (!context_data_ptr) - { - throw invalid_argument("plain is not valid for encryption parameters"); - } - auto &context_data = *context_->get_context_data(plain.parms_id()); - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_mod_count = coeff_modulus.size(); - size_t coeff_count = parms.poly_modulus_degree(); - - encrypt_zero(context_data.parms_id(), destination); - - // The plaintext gets added into the c_0 term of ciphertext (c_0,c_1). - for (size_t i = 0; i < coeff_mod_count; i++) - { - add_poly_poly_coeffmod( - destination.data() + (i * coeff_count), - plain.data() + (i * coeff_count), - coeff_count, - coeff_modulus[i], - destination.data() + (i * coeff_count)); - } - destination.scale() = plain.scale(); - } - else - { - throw invalid_argument("unsupported scheme"); - } - } - - void Encryptor::preencrypt(const uint64_t *plain, size_t plain_coeff_count, - const SEALContext::ContextData &context_data, uint64_t *destination) - { - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = coeff_modulus.size(); - - auto coeff_div_plain_modulus = context_data.coeff_div_plain_modulus(); - auto plain_upper_half_threshold = context_data.plain_upper_half_threshold(); - auto upper_half_increment = context_data.upper_half_increment(); - - // Multiply plain by scalar coeff_div_plain_modulus_ and reposition if in upper-half. - for (size_t i = 0; i < plain_coeff_count; i++, destination++) - { - if (plain[i] >= plain_upper_half_threshold) - { - // Loop over primes - for (size_t j = 0; j < coeff_mod_count; j++) - { - unsigned long long temp[2]{ 0, 0 }; - multiply_uint64(coeff_div_plain_modulus[j], plain[i], temp); - temp[1] += add_uint64(temp[0], upper_half_increment[j], 0, temp); - uint64_t scaled_plain_coeff = barrett_reduce_128(temp, coeff_modulus[j]); - destination[j * coeff_count] = add_uint_uint_mod( - destination[j * coeff_count], scaled_plain_coeff, coeff_modulus[j]); - } - } - else - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - uint64_t scaled_plain_coeff = multiply_uint_uint_mod( - coeff_div_plain_modulus[j], plain[i], coeff_modulus[j]); - destination[j * coeff_count] = add_uint_uint_mod( - destination[j * coeff_count], scaled_plain_coeff, coeff_modulus[j]); - } - } - } - } -} diff --git a/SEAL/native/src/seal/encryptor.h b/SEAL/native/src/seal/encryptor.h deleted file mode 100644 index b51b0e8..0000000 --- a/SEAL/native/src/seal/encryptor.h +++ /dev/null @@ -1,125 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include "seal/util/defines.h" -#include "seal/encryptionparams.h" -#include "seal/plaintext.h" -#include "seal/ciphertext.h" -#include "seal/memorymanager.h" -#include "seal/context.h" -#include "seal/publickey.h" -#include "seal/util/smallntt.h" - -namespace seal -{ - /** - Encrypts Plaintext objects into Ciphertext objects. Constructing an Encryptor - requires a SEALContext with valid encryption parameters, and the public key. - - @par Overloads - For the encrypt function we provide two overloads concerning the memory pool - used in allocations needed during the operation. In one overload the global - memory pool is used for this purpose, and in another overload the user can - supply a MemoryPoolHandle to to be used instead. This is to allow one single - Encryptor to be used concurrently by several threads without running into thread - contention in allocations taking place during operations. For example, one can - share one single Encryptor across any number of threads, but in each thread - call the encrypt function by giving it a thread-local MemoryPoolHandle to use. - It is important for a developer to understand how this works to avoid unnecessary - performance bottlenecks. - - @par NTT form - When using the BFV scheme (scheme_type::BFV), all plaintext and ciphertexts should - remain by default in the usual coefficient representation, i.e. not in NTT form. - When using the CKKS scheme (scheme_type::CKKS), all plaintexts and ciphertexts - should remain by default in NTT form. We call these scheme-specific NTT states - the "default NTT form". Decryption requires the input ciphertexts to be in - the default NTT form, and will throw an exception if this is not the case. - */ - class SEAL_NODISCARD Encryptor - { - public: - /** - Creates an Encryptor instance initialized with the specified SEALContext - and public key. - - @param[in] context The SEALContext - @param[in] public_key The public key - @throws std::invalid_argument if the context is not set or encryption - parameters are not valid - @throws std::invalid_argument if public_key is not valid - */ - Encryptor(std::shared_ptr context, const PublicKey &public_key); - - /** - Encrypts a plaintext and stores the result in the destination parameter. - Dynamic memory allocations in the process are allocated from the memory - pool pointed to by the given MemoryPoolHandle. - - @param[in] plain The plaintext to encrypt - @param[out] destination The ciphertext to overwrite with the encrypted plaintext - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if plain is not valid for the encryption parameters - @throws std::invalid_argument if plain is not in default NTT form - @throws std::invalid_argument if pool is uninitialized - */ - void encrypt(const Plaintext &plain, Ciphertext &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()); - - /** - Encrypts a zero plaintext and stores the result in the destination parameter. - The encryption parameters for the resulting ciphertext correspond to the given - parms_id. Dynamic memory allocations in the process are allocated from the memory - pool pointed to by the given MemoryPoolHandle. - - @param[in] parms_id The parms_id for the resulting ciphertext - @param[out] destination The ciphertext to overwrite with the encrypted plaintext - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if parms_id is not valid for the encryption parameters - @throws std::invalid_argument if pool is uninitialized - */ - void encrypt_zero( - parms_id_type parms_id, - Ciphertext &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()); - - /** - Encrypts a zero plaintext and stores the result in the destination parameter. - The encryption parameters for the resulting ciphertext correspond to the - highest (data) level in the modulus switching chain. Dynamic memory allocations - in the process are allocated from the memory pool pointed to by the given - MemoryPoolHandle. - - @param[out] destination The ciphertext to overwrite with the encrypted plaintext - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if pool is uninitialized - */ - inline void encrypt_zero(Ciphertext &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - encrypt_zero(context_->first_parms_id(), destination, pool); - } - - private: - Encryptor(const Encryptor ©) = delete; - - Encryptor(Encryptor &&source) = delete; - - Encryptor &operator =(const Encryptor &assign) = delete; - - Encryptor &operator =(Encryptor &&assign) = delete; - - void preencrypt(const std::uint64_t *plain, std::size_t plain_coeff_count, - const SEALContext::ContextData &context_data, std::uint64_t *destination); - - MemoryPoolHandle pool_ = MemoryManager::GetPool(); - - std::shared_ptr context_{ nullptr }; - - PublicKey public_key_; - }; -} diff --git a/SEAL/native/src/seal/evaluator.cpp b/SEAL/native/src/seal/evaluator.cpp deleted file mode 100644 index c07177a..0000000 --- a/SEAL/native/src/seal/evaluator.cpp +++ /dev/null @@ -1,2860 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include -#include -#include -#include -#include -#include "seal/evaluator.h" -#include "seal/util/common.h" -#include "seal/util/uintarith.h" -#include "seal/util/polycore.h" -#include "seal/util/polyarithsmallmod.h" - -using namespace std; -using namespace seal::util; - -namespace seal -{ - namespace - { - template - inline bool are_same_scale(const T &value1, const S &value2) noexcept - { - return util::are_close(value1.scale(), value2.scale()); - } - } - - Evaluator::Evaluator(shared_ptr context) : context_(move(context)) - { - // Verify parameters - if (!context_) - { - throw invalid_argument("invalid context"); - } - if (!context_->parameters_set()) - { - throw invalid_argument("encryption parameters are not set correctly"); - } - - // Calculate map from Zmstar to generator representation - populate_Zmstar_to_generator(); - } - - void Evaluator::populate_Zmstar_to_generator() - { - uint64_t n = static_cast( - context_->first_context_data()->parms().poly_modulus_degree()); - uint64_t m = n << 1; - - for (uint64_t i = 0; i < n / 2; i++) - { - uint64_t galois_elt = exponentiate_uint64(3, i) & (m - 1); - pair temp_pair1{ i, 0 }; - Zmstar_to_generator_.emplace(galois_elt, temp_pair1); - galois_elt = (exponentiate_uint64(3, i) * (m - 1)) & (m - 1); - pair temp_pair2{ i, 1 }; - Zmstar_to_generator_.emplace(galois_elt, temp_pair2); - } - } - - void Evaluator::negate_inplace(Ciphertext &encrypted) - { - // Verify parameters. - if (!is_metadata_valid_for(encrypted, context_)) - { - throw invalid_argument("encrypted is not valid for encryption parameters"); - } - - // Extract encryption parameters. - auto &context_data = *context_->get_context_data(encrypted.parms_id()); - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = coeff_modulus.size(); - size_t encrypted_size = encrypted.size(); - - // Negate each poly in the array - for (size_t j = 0; j < encrypted_size; j++) - { - for (size_t i = 0; i < coeff_mod_count; i++) - { - negate_poly_coeffmod(encrypted.data(j) + (i * coeff_count), - coeff_count, coeff_modulus[i], encrypted.data(j) + (i * coeff_count)); - } - } -#ifdef SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT - // Transparent ciphertext output is not allowed. - if (encrypted.is_transparent()) - { - throw logic_error("result ciphertext is transparent"); - } -#endif - } - - void Evaluator::add_inplace(Ciphertext &encrypted1, const Ciphertext &encrypted2) - { - // Verify parameters. - if (!is_metadata_valid_for(encrypted1, context_)) - { - throw invalid_argument("encrypted1 is not valid for encryption parameters"); - } - if (!is_metadata_valid_for(encrypted2, context_)) - { - throw invalid_argument("encrypted2 is not valid for encryption parameters"); - } - if (encrypted1.parms_id() != encrypted2.parms_id()) - { - throw invalid_argument("encrypted1 and encrypted2 parameter mismatch"); - } - if (encrypted1.is_ntt_form() != encrypted2.is_ntt_form()) - { - throw invalid_argument("NTT form mismatch"); - } - if (!are_same_scale(encrypted1, encrypted2)) - { - throw invalid_argument("scale mismatch"); - } - - // Extract encryption parameters. - auto &context_data = *context_->get_context_data(encrypted1.parms_id()); - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = coeff_modulus.size(); - size_t encrypted1_size = encrypted1.size(); - size_t encrypted2_size = encrypted2.size(); - size_t max_count = max(encrypted1_size, encrypted2_size); - size_t min_count = min(encrypted1_size, encrypted2_size); - - // Size check - if (!product_fits_in(max_count, coeff_count)) - { - throw logic_error("invalid parameters"); - } - - // Prepare destination - encrypted1.resize(context_, context_data.parms_id(), max_count); - - // Add ciphertexts - for (size_t j = 0; j < min_count; j++) - { - uint64_t *encrypted1_ptr = encrypted1.data(j); - const uint64_t *encrypted2_ptr = encrypted2.data(j); - for (size_t i = 0; i < coeff_mod_count; i++) - { - add_poly_poly_coeffmod(encrypted1_ptr + (i * coeff_count), - encrypted2_ptr + (i * coeff_count), coeff_count, coeff_modulus[i], - encrypted1_ptr + (i * coeff_count)); - } - } - - // Copy the remainding polys of the array with larger count into encrypted1 - if (encrypted1_size < encrypted2_size) - { - set_poly_poly(encrypted2.data(min_count), - coeff_count * (encrypted2_size - encrypted1_size), - coeff_mod_count, encrypted1.data(encrypted1_size)); - } -#ifdef SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT - // Transparent ciphertext output is not allowed. - if (encrypted1.is_transparent()) - { - throw logic_error("result ciphertext is transparent"); - } -#endif - } - - void Evaluator::add_many(const vector &encrypteds, Ciphertext &destination) - { - if (encrypteds.empty()) - { - throw invalid_argument("encrypteds cannot be empty"); - } - for (size_t i = 0; i < encrypteds.size(); i++) - { - if (&encrypteds[i] == &destination) - { - throw invalid_argument("encrypteds must be different from destination"); - } - } - destination = encrypteds[0]; - for (size_t i = 1; i < encrypteds.size(); i++) - { - add_inplace(destination, encrypteds[i]); - } - } - - void Evaluator::sub_inplace(Ciphertext &encrypted1, const Ciphertext &encrypted2) - { - // Verify parameters. - if (!is_metadata_valid_for(encrypted1, context_)) - { - throw invalid_argument("encrypted1 is not valid for encryption parameters"); - } - if (!is_metadata_valid_for(encrypted2, context_)) - { - throw invalid_argument("encrypted2 is not valid for encryption parameters"); - } - if (encrypted1.parms_id() != encrypted2.parms_id()) - { - throw invalid_argument("encrypted1 and encrypted2 parameter mismatch"); - } - if (encrypted1.is_ntt_form() != encrypted2.is_ntt_form()) - { - throw invalid_argument("NTT form mismatch"); - } - if (!are_same_scale(encrypted1, encrypted2)) - { - throw invalid_argument("scale mismatch"); - } - - // Extract encryption parameters. - auto &context_data = *context_->get_context_data(encrypted1.parms_id()); - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = coeff_modulus.size(); - size_t encrypted1_size = encrypted1.size(); - size_t encrypted2_size = encrypted2.size(); - size_t max_count = max(encrypted1_size, encrypted2_size); - size_t min_count = min(encrypted1_size, encrypted2_size); - - // Size check - if (!product_fits_in(max_count, coeff_count)) - { - throw logic_error("invalid parameters"); - } - - // Prepare destination - encrypted1.resize(context_, context_data.parms_id(), max_count); - - // Subtract polynomials. - for (size_t j = 0; j < min_count; j++) - { - uint64_t *encrypted1_ptr = encrypted1.data(j); - const uint64_t *encrypted2_ptr = encrypted2.data(j); - for (size_t i = 0; i < coeff_mod_count; i++) - { - sub_poly_poly_coeffmod(encrypted1_ptr + (i * coeff_count), - encrypted2_ptr + (i * coeff_count), coeff_count, coeff_modulus[i], - encrypted1_ptr + (i * coeff_count)); - } - } - - // If encrypted2 has larger count, negate remaining entries - if (encrypted1_size < encrypted2_size) - { - for (size_t i = 0; i < coeff_mod_count; i++) - { - negate_poly_coeffmod(encrypted2.data(encrypted1_size) + (i * coeff_count), - coeff_count * (encrypted2_size - encrypted1_size), coeff_modulus[i], - encrypted1.data(encrypted1_size) + (i * coeff_count)); - } - } -#ifdef SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT - // Transparent ciphertext output is not allowed. - if (encrypted1.is_transparent()) - { - throw logic_error("result ciphertext is transparent"); - } -#endif - } - - void Evaluator::multiply_inplace(Ciphertext &encrypted1, - const Ciphertext &encrypted2, MemoryPoolHandle pool) - { - // Verify parameters. - if (!is_metadata_valid_for(encrypted1, context_)) - { - throw invalid_argument("encrypted1 is not valid for encryption parameters"); - } - if (!is_metadata_valid_for(encrypted2, context_)) - { - throw invalid_argument("encrypted2 is not valid for encryption parameters"); - } - if (encrypted1.parms_id() != encrypted2.parms_id()) - { - throw invalid_argument("encrypted1 and encrypted2 parameter mismatch"); - } - - auto context_data_ptr = context_->first_context_data(); - switch (context_data_ptr->parms().scheme()) - { - case scheme_type::BFV: - bfv_multiply(encrypted1, encrypted2, pool); - break; - - case scheme_type::CKKS: - ckks_multiply(encrypted1, encrypted2, pool); - break; - - default: - throw invalid_argument("unsupported scheme"); - } -#ifdef SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT - // Transparent ciphertext output is not allowed. - if (encrypted1.is_transparent()) - { - throw logic_error("result ciphertext is transparent"); - } -#endif - } - - void Evaluator::bfv_multiply(Ciphertext &encrypted1, - const Ciphertext &encrypted2, MemoryPoolHandle pool) - { - if (encrypted1.is_ntt_form() || encrypted2.is_ntt_form()) - { - throw invalid_argument("encrypted1 or encrypted2 cannot be in NTT form"); - } - - // Extract encryption parameters. - auto &context_data = *context_->get_context_data(encrypted1.parms_id()); - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = coeff_modulus.size(); - size_t encrypted1_size = encrypted1.size(); - size_t encrypted2_size = encrypted2.size(); - - uint64_t plain_modulus = parms.plain_modulus().value(); - auto &base_converter = context_data.base_converter(); - auto &bsk_modulus = base_converter->get_bsk_mod_array(); - size_t bsk_base_mod_count = base_converter->bsk_base_mod_count(); - size_t bsk_mtilde_count = add_safe(bsk_base_mod_count, size_t(1)); - auto &coeff_small_ntt_tables = context_data.small_ntt_tables(); - auto &bsk_small_ntt_tables = base_converter->get_bsk_small_ntt_tables(); - - // Determine destination.size() - // Default is 3 (c_0, c_1, c_2) - size_t dest_count = sub_safe(add_safe(encrypted1_size, encrypted2_size), size_t(1)); - - // Size check - if (!product_fits_in(dest_count, coeff_count, bsk_mtilde_count)) - { - throw logic_error("invalid parameters"); - } - - // Prepare destination - encrypted1.resize(context_, context_data.parms_id(), dest_count); - - size_t encrypted_ptr_increment = coeff_count * coeff_mod_count; - size_t encrypted_bsk_mtilde_ptr_increment = coeff_count * bsk_mtilde_count; - size_t encrypted_bsk_ptr_increment = coeff_count * bsk_base_mod_count; - - // Make temp polys for FastBConverter result from q ---> Bsk U {m_tilde} - auto tmp_encrypted1_bsk_mtilde(allocate_poly( - coeff_count * encrypted1_size, bsk_mtilde_count, pool)); - auto tmp_encrypted2_bsk_mtilde(allocate_poly( - coeff_count * encrypted2_size, bsk_mtilde_count, pool)); - - // Make temp polys for FastBConverter result from Bsk U {m_tilde} -----> Bsk - auto tmp_encrypted1_bsk(allocate_poly( - coeff_count * encrypted1_size, bsk_base_mod_count, pool)); - auto tmp_encrypted2_bsk(allocate_poly( - coeff_count * encrypted2_size, bsk_base_mod_count, pool)); - - // Step 0: fast base convert from q to Bsk U {m_tilde} - // Step 1: reduce q-overflows in Bsk - // Iterate over all the ciphertexts inside encrypted1 - for (size_t i = 0; i < encrypted1_size; i++) - { - base_converter->fastbconv_mtilde( - encrypted1.data(i), - tmp_encrypted1_bsk_mtilde.get() + (i * encrypted_bsk_mtilde_ptr_increment), - pool); - base_converter->mont_rq( - tmp_encrypted1_bsk_mtilde.get() + (i * encrypted_bsk_mtilde_ptr_increment), - tmp_encrypted1_bsk.get() + (i * encrypted_bsk_ptr_increment)); - } - - // Iterate over all the ciphertexts inside encrypted2 - for (size_t i = 0; i < encrypted2_size; i++) - { - base_converter->fastbconv_mtilde( - encrypted2.data(i), - tmp_encrypted2_bsk_mtilde.get() + (i * encrypted_bsk_mtilde_ptr_increment), pool); - base_converter->mont_rq( - tmp_encrypted2_bsk_mtilde.get() + (i * encrypted_bsk_mtilde_ptr_increment), - tmp_encrypted2_bsk.get() + (i * encrypted_bsk_ptr_increment)); - } - - // Step 2: compute product and multiply plain modulus to the result - // We need to multiply both in q and Bsk. Values in encrypted_safe are in - // base q and values in tmp_encrypted_bsk are in base Bsk. We iterate over - // destination poly array and generate each poly based on the indices of - // inputs (arbitrary sizes for ciphertexts). First allocate two temp polys: - // one for results in base q and the other for the result in base Bsk. These - // need to be zero for the arbitrary size multiplication; not for 2x2 though - auto tmp_des_coeff_base(allocate_zero_poly( - coeff_count * dest_count, coeff_mod_count, pool)); - auto tmp_des_bsk_base(allocate_zero_poly( - coeff_count * dest_count, bsk_base_mod_count, pool)); - - // Allocate two tmp polys: one for NTT multiplication results in base q and - // one for result in base Bsk - auto tmp1_poly_coeff_base(allocate_poly(coeff_count, coeff_mod_count, pool)); - auto tmp1_poly_bsk_base(allocate_poly(coeff_count, bsk_base_mod_count, pool)); - auto tmp2_poly_coeff_base(allocate_poly(coeff_count, coeff_mod_count, pool)); - auto tmp2_poly_bsk_base(allocate_poly(coeff_count, bsk_base_mod_count, pool)); - - size_t current_encrypted1_limit = 0; - - // First convert all the inputs into NTT form - auto copy_encrypted1_ntt_coeff_mod(allocate_poly( - coeff_count * encrypted1_size, coeff_mod_count, pool)); - set_poly_poly(encrypted1.data(), coeff_count * encrypted1_size, - coeff_mod_count, copy_encrypted1_ntt_coeff_mod.get()); - - auto copy_encrypted1_ntt_bsk_base_mod(allocate_poly( - coeff_count * encrypted1_size, bsk_base_mod_count, pool)); - set_poly_poly(tmp_encrypted1_bsk.get(), coeff_count * encrypted1_size, - bsk_base_mod_count, copy_encrypted1_ntt_bsk_base_mod.get()); - - auto copy_encrypted2_ntt_coeff_mod(allocate_poly( - coeff_count * encrypted2_size, coeff_mod_count, pool)); - set_poly_poly(encrypted2.data(), coeff_count * encrypted2_size, - coeff_mod_count, copy_encrypted2_ntt_coeff_mod.get()); - - auto copy_encrypted2_ntt_bsk_base_mod(allocate_poly( - coeff_count * encrypted2_size, bsk_base_mod_count, pool)); - set_poly_poly(tmp_encrypted2_bsk.get(), coeff_count * encrypted2_size, - bsk_base_mod_count, copy_encrypted2_ntt_bsk_base_mod.get()); - - for (size_t i = 0; i < encrypted1_size; i++) - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - // Lazy reduction - ntt_negacyclic_harvey_lazy(copy_encrypted1_ntt_coeff_mod.get() + - (j * coeff_count) + (i * encrypted_ptr_increment), coeff_small_ntt_tables[j]); - } - for (size_t j = 0; j < bsk_base_mod_count; j++) - { - // Lazy reduction - ntt_negacyclic_harvey_lazy(copy_encrypted1_ntt_bsk_base_mod.get() + - (j * coeff_count) + (i * encrypted_bsk_ptr_increment), bsk_small_ntt_tables[j]); - } - } - - for (size_t i = 0; i < encrypted2_size; i++) - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - // Lazy reduction - ntt_negacyclic_harvey_lazy(copy_encrypted2_ntt_coeff_mod.get() + - (j * coeff_count) + (i * encrypted_ptr_increment), coeff_small_ntt_tables[j]); - } - for (size_t j = 0; j < bsk_base_mod_count; j++) - { - // Lazy reduction - ntt_negacyclic_harvey_lazy(copy_encrypted2_ntt_bsk_base_mod.get() + - (j * coeff_count) + (i * encrypted_bsk_ptr_increment), bsk_small_ntt_tables[j]); - } - } - - // Perform multiplication on arbitrary size ciphertexts - for (size_t secret_power_index = 0; - secret_power_index < dest_count; secret_power_index++) - { - // Loop over encrypted1 components [i], seeing if a match exists with an encrypted2 - // component [j] such that [i+j]=[secret_power_index] - // Only need to check encrypted1 components up to and including [secret_power_index], - // and strictly less than [encrypted_array.size()] - current_encrypted1_limit = min(encrypted1_size, secret_power_index + 1); - - for (size_t encrypted1_index = 0; - encrypted1_index < current_encrypted1_limit; encrypted1_index++) - { - // check if a corresponding component in encrypted2 exists - if (encrypted2_size > secret_power_index - encrypted1_index) - { - size_t encrypted2_index = secret_power_index - encrypted1_index; - - // NTT Multiplication and addition for results in q - for (size_t i = 0; i < coeff_mod_count; i++) - { - dyadic_product_coeffmod( - copy_encrypted1_ntt_coeff_mod.get() + (i * coeff_count) + - (encrypted_ptr_increment * encrypted1_index), - copy_encrypted2_ntt_coeff_mod.get() + (i * coeff_count) + - (encrypted_ptr_increment * encrypted2_index), - coeff_count, coeff_modulus[i], - tmp1_poly_coeff_base.get() + (i * coeff_count)); - add_poly_poly_coeffmod( - tmp1_poly_coeff_base.get() + (i * coeff_count), - tmp_des_coeff_base.get() + (i * coeff_count) + - (secret_power_index * coeff_count * coeff_mod_count), - coeff_count, coeff_modulus[i], - tmp_des_coeff_base.get() + (i * coeff_count) + - (secret_power_index * coeff_count * coeff_mod_count)); - } - - // NTT Multiplication and addition for results in Bsk - for (size_t i = 0; i < bsk_base_mod_count; i++) - { - dyadic_product_coeffmod( - copy_encrypted1_ntt_bsk_base_mod.get() + (i * coeff_count) + - (encrypted_bsk_ptr_increment * encrypted1_index), - copy_encrypted2_ntt_bsk_base_mod.get() + (i * coeff_count) + - (encrypted_bsk_ptr_increment * encrypted2_index), - coeff_count, bsk_modulus[i], - tmp1_poly_bsk_base.get() + (i * coeff_count)); - add_poly_poly_coeffmod( - tmp1_poly_bsk_base.get() + (i * coeff_count), - tmp_des_bsk_base.get() + (i * coeff_count) + - (secret_power_index * coeff_count * bsk_base_mod_count), - coeff_count, bsk_modulus[i], - tmp_des_bsk_base.get() + (i * coeff_count) + - (secret_power_index * coeff_count * bsk_base_mod_count)); - } - } - } - } - - // Convert back outputs from NTT form - for (size_t i = 0; i < dest_count; i++) - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - inverse_ntt_negacyclic_harvey( - tmp_des_coeff_base.get() + (i * (encrypted_ptr_increment)) + - (j * coeff_count), coeff_small_ntt_tables[j]); - } - for (size_t j = 0; j < bsk_base_mod_count; j++) - { - inverse_ntt_negacyclic_harvey( - tmp_des_bsk_base.get() + (i * (encrypted_bsk_ptr_increment)) + - (j * coeff_count), bsk_small_ntt_tables[j]); - } - } - - // Now we multiply plain modulus to both results in base q and Bsk and - // allocate them together in one container as - // (te0)q(te'0)Bsk | ... |te count)q (te' count)Bsk to make it ready for - // fast_floor - auto tmp_coeff_bsk_together(allocate_poly( - coeff_count, dest_count * (coeff_mod_count + bsk_base_mod_count), pool)); - uint64_t *tmp_coeff_bsk_together_ptr = tmp_coeff_bsk_together.get(); - - // Base q - for (size_t i = 0; i < dest_count; i++) - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - multiply_poly_scalar_coeffmod( - tmp_des_coeff_base.get() + (j * coeff_count) + (i * encrypted_ptr_increment), - coeff_count, plain_modulus, coeff_modulus[j], - tmp_coeff_bsk_together_ptr + (j * coeff_count)); - } - tmp_coeff_bsk_together_ptr += encrypted_ptr_increment; - - for (size_t k = 0; k < bsk_base_mod_count; k++) - { - multiply_poly_scalar_coeffmod( - tmp_des_bsk_base.get() + (k * coeff_count) + (i * encrypted_bsk_ptr_increment), - coeff_count, plain_modulus, bsk_modulus[k], - tmp_coeff_bsk_together_ptr + (k * coeff_count)); - } - tmp_coeff_bsk_together_ptr += encrypted_bsk_ptr_increment; - } - - // Allocate a new poly for fast floor result in Bsk - auto tmp_result_bsk(allocate_poly( - coeff_count, dest_count * bsk_base_mod_count, pool)); - for (size_t i = 0; i < dest_count; i++) - { - // Step 3: fast floor from q U {Bsk} to Bsk - base_converter->fast_floor( - tmp_coeff_bsk_together.get() + - (i * (encrypted_ptr_increment + encrypted_bsk_ptr_increment)), - tmp_result_bsk.get() + (i * encrypted_bsk_ptr_increment), pool); - - // Step 4: fast base convert from Bsk to q - base_converter->fastbconv_sk( - tmp_result_bsk.get() + (i * encrypted_bsk_ptr_increment), - encrypted1.data(i), pool); - } - } - - void Evaluator::ckks_multiply(Ciphertext &encrypted1, - const Ciphertext &encrypted2, MemoryPoolHandle pool) - { - if (!(encrypted1.is_ntt_form() && encrypted2.is_ntt_form())) - { - throw invalid_argument("encrypted1 or encrypted2 must be in NTT form"); - } - - // Extract encryption parameters. - auto &context_data = *context_->get_context_data(encrypted1.parms_id()); - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = coeff_modulus.size(); - size_t encrypted1_size = encrypted1.size(); - size_t encrypted2_size = encrypted2.size(); - - double new_scale = encrypted1.scale() * encrypted2.scale(); - - // Check that scale is positive and not too large - if (new_scale <= 0 || (static_cast(log2(new_scale)) >= - context_data.total_coeff_modulus_bit_count())) - { - throw invalid_argument("scale out of bounds"); - } - - // Determine destination.size() - // Default is 3 (c_0, c_1, c_2) - size_t dest_count = sub_safe(add_safe(encrypted1_size, encrypted2_size), size_t(1)); - - // Size check - if (!product_fits_in(dest_count, coeff_count, coeff_mod_count)) - { - throw logic_error("invalid parameters"); - } - - // Prepare destination - encrypted1.resize(context_, context_data.parms_id(), dest_count); - - //pointer increment to switch to a next polynomial - size_t encrypted_ptr_increment = coeff_count * coeff_mod_count; - - //Step 1: naive multiplication modulo the coefficient modulus - //First allocate two temp polys : - //one for results in base q. This need to be zero - //for the arbitrary size multiplication; not for 2x2 though - auto tmp_des(allocate_zero_poly( - coeff_count * dest_count, coeff_mod_count, pool)); - - //Allocate tmp polys for NTT multiplication results in base q - auto tmp1_poly(allocate_poly(coeff_count, coeff_mod_count, pool)); - auto tmp2_poly(allocate_poly(coeff_count, coeff_mod_count, pool)); - - // First convert all the inputs into NTT form - auto copy_encrypted1_ntt(allocate_poly( - coeff_count * encrypted1_size, coeff_mod_count, pool)); - set_poly_poly(encrypted1.data(), coeff_count * encrypted1_size, - coeff_mod_count, copy_encrypted1_ntt.get()); - - auto copy_encrypted2_ntt(allocate_poly( - coeff_count * encrypted2_size, coeff_mod_count, pool)); - set_poly_poly(encrypted2.data(), coeff_count * encrypted2_size, - coeff_mod_count, copy_encrypted2_ntt.get()); - - // Perform multiplication on arbitrary size ciphertexts - - // Loop over encrypted1 components [i], seeing if a match exists with an encrypted2 - // component [j] such that [i+j]=[secret_power_index] - // Only need to check encrypted1 components up to and including [secret_power_index], - // and strictly less than [encrypted_array.size()] - - // Number of encrypted1 components to check - size_t current_encrypted1_limit = 0; - - for (size_t secret_power_index = 0; - secret_power_index < dest_count; secret_power_index++) - { - current_encrypted1_limit = min(encrypted1_size, secret_power_index + 1); - - for (size_t encrypted1_index = 0; - encrypted1_index < current_encrypted1_limit; encrypted1_index++) - { - // check if a corresponding component in encrypted2 exists - if (encrypted2_size > secret_power_index - encrypted1_index) - { - size_t encrypted2_index = secret_power_index - encrypted1_index; - - // NTT Multiplication and addition for results in q - for (size_t i = 0; i < coeff_mod_count; i++) - { - // ci * dj - dyadic_product_coeffmod( - copy_encrypted1_ntt.get() + (i * coeff_count) + - (encrypted_ptr_increment * encrypted1_index), - copy_encrypted2_ntt.get() + (i * coeff_count) + - (encrypted_ptr_increment * encrypted2_index), - coeff_count, coeff_modulus[i], - tmp1_poly.get() + (i * coeff_count)); - // Dest[i+j] - add_poly_poly_coeffmod( - tmp1_poly.get() + (i * coeff_count), - tmp_des.get() + (i * coeff_count) + - (secret_power_index * coeff_count * coeff_mod_count), - coeff_count, coeff_modulus[i], - tmp_des.get() + (i * coeff_count) + - (secret_power_index * coeff_count * coeff_mod_count)); - } - } - } - } - - // Set the final result - set_poly_poly(tmp_des.get(), coeff_count * dest_count, - coeff_mod_count, encrypted1.data()); - - // Set the scale - encrypted1.scale() = new_scale; - } - - void Evaluator::square_inplace(Ciphertext &encrypted, MemoryPoolHandle pool) - { - // Verify parameters. - if (!is_metadata_valid_for(encrypted, context_)) - { - throw invalid_argument("encrypted is not valid for encryption parameters"); - } - - auto context_data_ptr = context_->first_context_data(); - switch (context_data_ptr->parms().scheme()) - { - case scheme_type::BFV: - bfv_square(encrypted, move(pool)); - break; - - case scheme_type::CKKS: - ckks_square(encrypted, move(pool)); - break; - - default: - throw invalid_argument("unsupported scheme"); - } -#ifdef SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT - // Transparent ciphertext output is not allowed. - if (encrypted.is_transparent()) - { - throw logic_error("result ciphertext is transparent"); - } -#endif - } - - void Evaluator::bfv_square(Ciphertext &encrypted, MemoryPoolHandle pool) - { - if (encrypted.is_ntt_form()) - { - throw invalid_argument("encrypted cannot be in NTT form"); - } - - // Extract encryption parameters. - auto &context_data = *context_->get_context_data(encrypted.parms_id()); - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = coeff_modulus.size(); - size_t encrypted_size = encrypted.size(); - - uint64_t plain_modulus = parms.plain_modulus().value(); - auto &base_converter = context_data.base_converter(); - auto &bsk_modulus = base_converter->get_bsk_mod_array(); - size_t bsk_base_mod_count = base_converter->bsk_base_mod_count(); - size_t bsk_mtilde_count = add_safe(bsk_base_mod_count, size_t(1)); - auto &coeff_small_ntt_tables = context_data.small_ntt_tables(); - auto &bsk_small_ntt_tables = base_converter->get_bsk_small_ntt_tables(); - - // Optimization implemented currently only for size 2 ciphertexts - if (encrypted_size != 2) - { - bfv_multiply(encrypted, encrypted, move(pool)); - return; - } - - // Determine destination_array.size() - size_t dest_count = sub_safe(add_safe(encrypted_size, encrypted_size), size_t(1)); - - // Size check - if (!product_fits_in(dest_count, coeff_count, bsk_mtilde_count)) - { - throw logic_error("invalid parameters"); - } - - size_t encrypted_ptr_increment = coeff_count * coeff_mod_count; - size_t encrypted_bsk_mtilde_ptr_increment = coeff_count * bsk_mtilde_count; - size_t encrypted_bsk_ptr_increment = coeff_count * bsk_base_mod_count; - - // Prepare destination - encrypted.resize(context_, context_data.parms_id(), dest_count); - - // Make temp poly for FastBConverter result from q ---> Bsk U {m_tilde} - auto tmp_encrypted_bsk_mtilde(allocate_poly( - coeff_count * encrypted_size, bsk_mtilde_count, pool)); - - // Make temp poly for FastBConverter result from Bsk U {m_tilde} -----> Bsk - auto tmp_encrypted_bsk(allocate_poly( - coeff_count * encrypted_size, bsk_base_mod_count, pool)); - - // Step 0: fast base convert from q to Bsk U {m_tilde} - // Step 1: reduce q-overflows in Bsk - // Iterate over all the ciphertexts inside encrypted1 - for (size_t i = 0; i < encrypted_size; i++) - { - base_converter->fastbconv_mtilde( - encrypted.data(i), - tmp_encrypted_bsk_mtilde.get() + - (i * encrypted_bsk_mtilde_ptr_increment), pool); - base_converter->mont_rq( - tmp_encrypted_bsk_mtilde.get() + - (i * encrypted_bsk_mtilde_ptr_increment), - tmp_encrypted_bsk.get() + (i * encrypted_bsk_ptr_increment)); - } - - // Step 2: compute product and multiply plain modulus to the result. - // We need to multiply both in q and Bsk. Values in encrypted_safe are - // in base q and values in tmp_encrypted_bsk are in base Bsk. We iterate - // over destination poly array and generate each poly based on the indices - // of inputs (arbitrary sizes for ciphertexts). First allocate two temp polys: - // one for results in base q and the other for the result in base Bsk. - auto tmp_des_coeff_base(allocate_poly( - coeff_count * dest_count, coeff_mod_count, pool)); - auto tmp_des_bsk_base(allocate_poly( - coeff_count * dest_count, bsk_base_mod_count, pool)); - - // First convert all the inputs into NTT form - auto copy_encrypted_ntt_coeff_mod(allocate_poly( - coeff_count * encrypted_size, coeff_mod_count, pool)); - set_poly_poly(encrypted.data(), coeff_count * encrypted_size, - coeff_mod_count, copy_encrypted_ntt_coeff_mod.get()); - - auto copy_encrypted_ntt_bsk_base_mod(allocate_poly( - coeff_count * encrypted_size, bsk_base_mod_count, pool)); - set_poly_poly(tmp_encrypted_bsk.get(), coeff_count * encrypted_size, - bsk_base_mod_count, copy_encrypted_ntt_bsk_base_mod.get()); - - for (size_t i = 0; i < encrypted_size; i++) - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - ntt_negacyclic_harvey_lazy( - copy_encrypted_ntt_coeff_mod.get() + (j * coeff_count) + - (i * encrypted_ptr_increment), coeff_small_ntt_tables[j]); - } - for (size_t j = 0; j < bsk_base_mod_count; j++) - { - ntt_negacyclic_harvey_lazy( - copy_encrypted_ntt_bsk_base_mod.get() + (j * coeff_count) + - (i * encrypted_bsk_ptr_increment), bsk_small_ntt_tables[j]); - } - } - - // Perform fast squaring - // Compute c0^2 in base q - for (size_t i = 0; i < coeff_mod_count; i++) - { - // Des[0] in q - dyadic_product_coeffmod( - copy_encrypted_ntt_coeff_mod.get() + (i * coeff_count), - copy_encrypted_ntt_coeff_mod.get() + (i * coeff_count), - coeff_count, coeff_modulus[i], - tmp_des_coeff_base.get() + (i * coeff_count)); - - // Des[2] in q - dyadic_product_coeffmod( - copy_encrypted_ntt_coeff_mod.get() + (i * coeff_count) + encrypted_ptr_increment, - copy_encrypted_ntt_coeff_mod.get() + (i * coeff_count) + encrypted_ptr_increment, - coeff_count, coeff_modulus[i], - tmp_des_coeff_base.get() + (i * coeff_count) + (2 * encrypted_ptr_increment)); - } - - // Compute c0^2 in base bsk - for (size_t i = 0; i < bsk_base_mod_count; i++) - { - // Des[0] in bsk - dyadic_product_coeffmod( - copy_encrypted_ntt_bsk_base_mod.get() + (i * coeff_count), - copy_encrypted_ntt_bsk_base_mod.get() + (i * coeff_count), - coeff_count, bsk_modulus[i], - tmp_des_bsk_base.get() + (i * coeff_count)); - - // Des[2] in bsk - dyadic_product_coeffmod( - copy_encrypted_ntt_bsk_base_mod.get() + (i * coeff_count) + encrypted_bsk_ptr_increment, - copy_encrypted_ntt_bsk_base_mod.get() + (i * coeff_count) + encrypted_bsk_ptr_increment, - coeff_count, bsk_modulus[i], - tmp_des_bsk_base.get() + (i * coeff_count) + (2 * encrypted_bsk_ptr_increment)); - } - - auto tmp_second_mul_coeff_base(allocate_poly(coeff_count, coeff_mod_count, pool)); - - // Compute 2*c0*c1 in base q - for (size_t i = 0; i < coeff_mod_count; i++) - { - dyadic_product_coeffmod( - copy_encrypted_ntt_coeff_mod.get() + (i * coeff_count), - copy_encrypted_ntt_coeff_mod.get() + (i * coeff_count) + encrypted_ptr_increment, - coeff_count, coeff_modulus[i], - tmp_second_mul_coeff_base.get() + (i * coeff_count)); - add_poly_poly_coeffmod( - tmp_second_mul_coeff_base.get() + (i * coeff_count), - tmp_second_mul_coeff_base.get() + (i * coeff_count), - coeff_count, coeff_modulus[i], - tmp_des_coeff_base.get() + (i * coeff_count) + encrypted_ptr_increment); - } - - auto tmp_second_mul_bsk_base(allocate_poly(coeff_count, bsk_base_mod_count, pool)); - - // Compute 2*c0*c1 in base bsk - for (size_t i = 0; i < bsk_base_mod_count; i++) - { - dyadic_product_coeffmod( - copy_encrypted_ntt_bsk_base_mod.get() + (i * coeff_count), - copy_encrypted_ntt_bsk_base_mod.get() + (i * coeff_count) + encrypted_bsk_ptr_increment, - coeff_count, bsk_modulus[i], - tmp_second_mul_bsk_base.get() + (i * coeff_count)); - add_poly_poly_coeffmod( - tmp_second_mul_bsk_base.get() + (i * coeff_count), - tmp_second_mul_bsk_base.get() + (i * coeff_count), - coeff_count, bsk_modulus[i], - tmp_des_bsk_base.get() + (i * coeff_count) + encrypted_bsk_ptr_increment); - } - - // Convert back outputs from NTT form - for (size_t i = 0; i < dest_count; i++) - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - inverse_ntt_negacyclic_harvey_lazy( - tmp_des_coeff_base.get() + (i * (encrypted_ptr_increment)) + (j * coeff_count), - coeff_small_ntt_tables[j]); - } - for (size_t j = 0; j < bsk_base_mod_count; j++) - { - inverse_ntt_negacyclic_harvey_lazy( - tmp_des_bsk_base.get() + (i * (encrypted_bsk_ptr_increment)) + - (j * coeff_count), bsk_small_ntt_tables[j]); - } - } - - // Now we multiply plain modulus to both results in base q and Bsk and - // allocate them together in one container as (te0)q(te'0)Bsk | ... |te count)q (te' count)Bsk - // to make it ready for fast_floor - auto tmp_coeff_bsk_together(allocate_poly( - coeff_count, dest_count * (coeff_mod_count + bsk_base_mod_count), pool)); - uint64_t *tmp_coeff_bsk_together_ptr = tmp_coeff_bsk_together.get(); - - // Base q - for (size_t i = 0; i < dest_count; i++) - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - multiply_poly_scalar_coeffmod( - tmp_des_coeff_base.get() + (j * coeff_count) + (i * encrypted_ptr_increment), - coeff_count, plain_modulus, coeff_modulus[j], - tmp_coeff_bsk_together_ptr + (j * coeff_count)); - } - tmp_coeff_bsk_together_ptr += encrypted_ptr_increment; - - for (size_t k = 0; k < bsk_base_mod_count; k++) - { - multiply_poly_scalar_coeffmod( - tmp_des_bsk_base.get() + (k * coeff_count) + (i * encrypted_bsk_ptr_increment), - coeff_count, plain_modulus, bsk_modulus[k], - tmp_coeff_bsk_together_ptr + (k * coeff_count)); - } - tmp_coeff_bsk_together_ptr += encrypted_bsk_ptr_increment; - } - - // Allocate a new poly for fast floor result in Bsk - auto tmp_result_bsk(allocate_poly(coeff_count, dest_count * bsk_base_mod_count, pool)); - for (size_t i = 0; i < dest_count; i++) - { - // Step 3: fast floor from q U {Bsk} to Bsk - base_converter->fast_floor( - tmp_coeff_bsk_together.get() + (i * (encrypted_ptr_increment + encrypted_bsk_ptr_increment)), - tmp_result_bsk.get() + (i * encrypted_bsk_ptr_increment), pool); - - // Step 4: fast base convert from Bsk to q - base_converter->fastbconv_sk( - tmp_result_bsk.get() + (i * encrypted_bsk_ptr_increment), encrypted.data(i), pool); - } - } - - void Evaluator::ckks_square(Ciphertext &encrypted, MemoryPoolHandle pool) - { - if (!encrypted.is_ntt_form()) - { - throw invalid_argument("encrypted must be in NTT form"); - } - - // Extract encryption parameters. - auto &context_data = *context_->get_context_data(encrypted.parms_id()); - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = coeff_modulus.size(); - size_t encrypted_size = encrypted.size(); - - double new_scale = encrypted.scale() * encrypted.scale(); - - // Check that scale is positive and not too large - if (new_scale <= 0 || (static_cast(log2(new_scale)) >= - context_data.total_coeff_modulus_bit_count())) - { - throw invalid_argument("scale out of bounds"); - } - - // Determine destination.size() - // Default is 3 (c_0, c_1, c_2) - size_t dest_count = sub_safe(add_safe(encrypted_size, encrypted_size), size_t(1)); - - // Size check - if (!product_fits_in(dest_count, coeff_count, coeff_mod_count)) - { - throw logic_error("invalid parameters"); - } - - // Prepare destination - encrypted.resize(context_, context_data.parms_id(), dest_count); - - //pointer increment to switch to a next polynomial - size_t encrypted_ptr_increment = coeff_count * coeff_mod_count; - - //Step 1: naive multiplication modulo the coefficient modulus - //First allocate two temp polys : - //one for results in base q. This need to be zero - //for the arbitrary size multiplication; not for 2x2 though - auto tmp_des(allocate_zero_poly( - coeff_count * dest_count, coeff_mod_count, pool)); - - //Allocate tmp polys for NTT multiplication results in base q - auto tmp1_poly(allocate_poly(coeff_count, coeff_mod_count, pool)); - auto tmp2_poly(allocate_poly(coeff_count, coeff_mod_count, pool)); - - // First convert all the inputs into NTT form - auto copy_encrypted_ntt(allocate_poly( - coeff_count * encrypted_size, coeff_mod_count, pool)); - set_poly_poly(encrypted.data(), coeff_count * encrypted_size, - coeff_mod_count, copy_encrypted_ntt.get()); - - // The simplest case when the ciphertext dimension is 2 - if (encrypted_size == 2) - { - //Compute c0^2, 2*c0 + c1 and c1^2 modulo q - //tmp poly to keep 2 * c0 * c1 - auto tmp_second_mul(allocate_poly(coeff_count, coeff_mod_count, pool)); - - for (size_t i = 0; i < coeff_mod_count; i++) - { - //Des[0] = c0^2 in NTT - dyadic_product_coeffmod( - copy_encrypted_ntt.get() + (i * coeff_count), - copy_encrypted_ntt.get() + (i * coeff_count), - coeff_count, coeff_modulus[i], - tmp_des.get() + (i * coeff_count)); - - //Des[1] = 2 * c0 * c1 - dyadic_product_coeffmod( - copy_encrypted_ntt.get() + (i * coeff_count), - copy_encrypted_ntt.get() + (i * coeff_count) + encrypted_ptr_increment, - coeff_count, coeff_modulus[i], - tmp_second_mul.get() + (i * coeff_count)); - add_poly_poly_coeffmod( - tmp_second_mul.get() + (i * coeff_count), - tmp_second_mul.get() + (i * coeff_count), - coeff_count, coeff_modulus[i], - tmp_des.get() + (i * coeff_count) + encrypted_ptr_increment); - - //Des[2] = c1^2 in NTT - dyadic_product_coeffmod( - copy_encrypted_ntt.get() + (i * coeff_count) + encrypted_ptr_increment, - copy_encrypted_ntt.get() + (i * coeff_count) + encrypted_ptr_increment, - coeff_count, coeff_modulus[i], - tmp_des.get() + (i * coeff_count) + (2 * encrypted_ptr_increment)); - } - } - else - { - // Perform multiplication on arbitrary size ciphertexts - - // Loop over encrypted1 components [i], seeing if a match exists with an encrypted2 - // component [j] such that [i+j]=[secret_power_index] - // Only need to check encrypted1 components up to and including [secret_power_index], - // and strictly less than [encrypted_array.size()] - - // Number of encrypted1 components to check - size_t current_encrypted_limit = 0; - - for (size_t secret_power_index = 0; secret_power_index < dest_count; secret_power_index++) - { - current_encrypted_limit = min(encrypted_size, secret_power_index + 1); - - for (size_t encrypted1_index = 0; encrypted1_index < current_encrypted_limit; - encrypted1_index++) - { - // check if a corresponding component in encrypted2 exists - if (encrypted_size > secret_power_index - encrypted1_index) - { - size_t encrypted2_index = secret_power_index - encrypted1_index; - - // NTT Multiplication and addition for results in q - for (size_t i = 0; i < coeff_mod_count; i++) - { - // ci * dj - dyadic_product_coeffmod( - copy_encrypted_ntt.get() + (i * coeff_count) + - (encrypted_ptr_increment * encrypted1_index), - copy_encrypted_ntt.get() + (i * coeff_count) + - (encrypted_ptr_increment * encrypted2_index), - coeff_count, coeff_modulus[i], - tmp1_poly.get() + (i * coeff_count)); - - // Dest[i+j] - add_poly_poly_coeffmod( - tmp1_poly.get() + (i * coeff_count), - tmp_des.get() + (i * coeff_count) + - (secret_power_index * coeff_count * coeff_mod_count), - coeff_count, coeff_modulus[i], - tmp_des.get() + (i * coeff_count) + - (secret_power_index * coeff_count * coeff_mod_count)); - } - } - } - } - } - - // Set the final result - set_poly_poly(tmp_des.get(), coeff_count * dest_count, coeff_mod_count, encrypted.data()); - - // Set the scale - encrypted.scale() = new_scale; - } - - void Evaluator::relinearize_internal(Ciphertext &encrypted, - const RelinKeys &relin_keys, size_t destination_size, - MemoryPoolHandle pool) - { - // Verify parameters. - auto context_data_ptr = context_->get_context_data(encrypted.parms_id()); - if (!context_data_ptr) - { - throw invalid_argument("encrypted is not valid for encryption parameters"); - } - if (relin_keys.parms_id() != context_->key_parms_id()) - { - throw invalid_argument("relin_keys is not valid for encryption parameters"); - } - - size_t encrypted_size = encrypted.size(); - - // Verify parameters. - if (destination_size < 2 || destination_size > encrypted_size) - { - throw invalid_argument("destination_size must be at least 2 and less than or equal to current count"); - } - if (relin_keys.size() < sub_safe(encrypted_size, size_t(2))) - { - throw invalid_argument("not enough relinearization keys"); - } - - // If encrypted is already at the desired level, return - if (destination_size == encrypted_size) - { - return; - } - - // Calculate number of relinearize_one_step calls needed - size_t relins_needed = encrypted_size - destination_size; - for (size_t i = 0; i < relins_needed; i++) - { - switch_key_inplace( - encrypted, - encrypted.data(encrypted_size - 1), - static_cast(relin_keys), - RelinKeys::get_index(encrypted_size - 1), - pool); - encrypted_size--; - } - - // Put the output of final relinearization into destination. - // Prepare destination only at this point because we are resizing down - encrypted.resize(context_, context_data_ptr->parms_id(), destination_size); -#ifdef SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT - // Transparent ciphertext output is not allowed. - if (encrypted.is_transparent()) - { - throw logic_error("result ciphertext is transparent"); - } -#endif - } - - void Evaluator::mod_switch_scale_to_next(const Ciphertext &encrypted, - Ciphertext &destination, MemoryPoolHandle pool) - { - auto context_data_ptr = context_->get_context_data(encrypted.parms_id()); - if (context_data_ptr->parms().scheme() == scheme_type::BFV && - encrypted.is_ntt_form()) - { - throw invalid_argument("BFV encrypted cannot be in NTT form"); - } - if (context_data_ptr->parms().scheme() == scheme_type::CKKS && - !encrypted.is_ntt_form()) - { - throw invalid_argument("CKKS encrypted must be in NTT form"); - } - if (!pool) - { - throw invalid_argument("pool is uninitialized"); - } - - // Extract encryption parameters. - auto &context_data = *context_data_ptr; - auto &next_context_data = *context_data.next_context_data(); - auto &next_parms = next_context_data.parms(); - - // q_1,...,q_{k-1} - auto &next_coeff_modulus = next_parms.coeff_modulus(); - size_t next_coeff_mod_count = next_coeff_modulus.size(); - size_t coeff_count = next_parms.poly_modulus_degree(); - size_t encrypted_size = encrypted.size(); - auto &inv_last_coeff_mod_array = - context_data.base_converter()->get_inv_last_coeff_mod_array(); - - // Size test - if (!product_fits_in(coeff_count, encrypted_size, next_coeff_mod_count)) - { - throw logic_error("invalid parameters"); - } - - // In CKKS need to transform away from NTT form - Ciphertext encrypted_copy(pool); - encrypted_copy = encrypted; - if (next_parms.scheme() == scheme_type::CKKS) - { - transform_from_ntt_inplace(encrypted_copy); - } - - auto temp1(allocate_uint(coeff_count, pool)); - - // Allocate enough room for the result - auto temp2(allocate_poly(coeff_count * encrypted_size, next_coeff_mod_count, pool)); - auto temp2_ptr = temp2.get(); - - for (size_t poly_index = 0; poly_index < encrypted_size; poly_index++) - { - // Set temp1 to ct mod qk - set_uint_uint( - encrypted_copy.data(poly_index) + next_coeff_mod_count * coeff_count, - coeff_count, temp1.get()); - // Add (p-1)/2 to change from flooring to rounding. - auto last_modulus = context_data.parms().coeff_modulus().back(); - uint64_t half = last_modulus.value() >> 1; - for (size_t j = 0; j < coeff_count; j++) - { - temp1.get()[j] = barrett_reduce_63(temp1.get()[j] + half, last_modulus); - } - for (size_t mod_index = 0; mod_index < next_coeff_mod_count; mod_index++, - temp2_ptr += coeff_count) - { - // (ct mod qk) mod qi - modulo_poly_coeffs_63(temp1.get(), coeff_count, - next_coeff_modulus[mod_index], temp2_ptr); - uint64_t half_mod = barrett_reduce_63(half, next_coeff_modulus[mod_index]); - for (size_t j = 0; j < coeff_count; j++) - { - temp2_ptr[j] = sub_uint_uint_mod(temp2_ptr[j], half_mod, next_coeff_modulus[mod_index]); - } - // ((ct mod qi) - (ct mod qk)) mod qi - sub_poly_poly_coeffmod( - encrypted_copy.data(poly_index) + mod_index * coeff_count, temp2_ptr, - coeff_count, next_coeff_modulus[mod_index], temp2_ptr); - // qk^(-1) * ((ct mod qi) - (ct mod qk)) mod qi - multiply_poly_scalar_coeffmod(temp2_ptr, coeff_count, - inv_last_coeff_mod_array[mod_index], - next_coeff_modulus[mod_index], temp2_ptr); - } - } - - // Resize destination - destination.resize(context_, next_context_data.parms_id(), encrypted_size); - destination.is_ntt_form() = false; - - set_poly_poly(temp2.get(), coeff_count * encrypted_size, next_coeff_mod_count, - destination.data()); - - // In CKKS need to transform back to NTT form - if (next_parms.scheme() == scheme_type::CKKS) - { - transform_to_ntt_inplace(destination); - - // Also change the scale - destination.scale() = encrypted.scale() / - static_cast(context_data.parms().coeff_modulus().back().value()); - } - } - - void Evaluator::mod_switch_drop_to_next(const Ciphertext &encrypted, - Ciphertext &destination, MemoryPoolHandle pool) - { - // Assuming at this point encrypted is already validated. - auto context_data_ptr = context_->get_context_data(encrypted.parms_id()); - if (context_data_ptr->parms().scheme() == scheme_type::CKKS && - !encrypted.is_ntt_form()) - { - throw invalid_argument("CKKS encrypted must be in NTT form"); - } - - // Extract encryption parameters. - auto &next_context_data = *context_data_ptr->next_context_data(); - auto &next_parms = next_context_data.parms(); - - // Check that scale is positive and not too large - if (encrypted.scale() <= 0 || (static_cast(log2(encrypted.scale())) >= - next_context_data.total_coeff_modulus_bit_count())) - { - throw invalid_argument("scale out of bounds"); - } - - // q_1,...,q_{k-1} - size_t next_coeff_mod_count = next_parms.coeff_modulus().size(); - size_t coeff_count = next_parms.poly_modulus_degree(); - size_t encrypted_size = encrypted.size(); - - // Size check - if (!product_fits_in(encrypted_size, coeff_count, next_coeff_mod_count)) - { - throw logic_error("invalid parameters"); - } - - size_t rns_poly_total_count = next_coeff_mod_count * coeff_count; - - if (&encrypted == &destination) - { - // Switching in-place so need temporary space - auto temp(allocate_uint(rns_poly_total_count * encrypted_size, pool)); - - // Copy data over to temp - for (size_t i = 0; i < encrypted_size; i++) - { - const uint64_t *encrypted_ptr = encrypted.data(i); - for (size_t j = 0; j < next_coeff_mod_count; j++) - { - set_uint_uint(encrypted_ptr + (j * coeff_count), coeff_count, - temp.get() + (i * rns_poly_total_count) + (j * coeff_count)); - } - } - - // Resize destination before writing - destination.resize(context_, next_context_data.parms_id(), encrypted_size); - destination.is_ntt_form() = true; - destination.scale() = encrypted.scale(); - - // Copy data to destination - set_uint_uint(temp.get(), rns_poly_total_count * encrypted_size, - destination.data()); - } - else - { - // Resize destination before writing - destination.resize(context_, next_context_data.parms_id(), encrypted_size); - destination.is_ntt_form() = true; - destination.scale() = encrypted.scale(); - - // Copy data directly to new destination - for (size_t i = 0; i < encrypted_size; i++) - { - for (size_t j = 0; j < next_coeff_mod_count; j++) - { - const uint64_t *encrypted_ptr = encrypted.data(i); - set_uint_uint(encrypted_ptr + (j * coeff_count), coeff_count, - destination.data() + (i * rns_poly_total_count) + (j * coeff_count)); - } - } - } - } - - void Evaluator::mod_switch_drop_to_next(Plaintext &plain) - { - // Assuming at this point plain is already validated. - auto context_data_ptr = context_->get_context_data(plain.parms_id()); - if (!plain.is_ntt_form()) - { - throw invalid_argument("plain is not in NTT form"); - } - if (!context_data_ptr->next_context_data()) - { - throw invalid_argument("end of modulus switching chain reached"); - } - - // Extract encryption parameters. - auto &next_context_data = *context_data_ptr->next_context_data(); - auto &next_parms = context_data_ptr->next_context_data()->parms(); - - // Check that scale is positive and not too large - if (plain.scale() <= 0 || (static_cast(log2(plain.scale())) >= - next_context_data.total_coeff_modulus_bit_count())) - { - throw invalid_argument("scale out of bounds"); - } - - // q_1,...,q_{k-1} - auto &next_coeff_modulus = next_parms.coeff_modulus(); - size_t next_coeff_mod_count = next_coeff_modulus.size(); - size_t coeff_count = next_parms.poly_modulus_degree(); - - // Compute destination size first for exception safety - auto dest_size = mul_safe(next_coeff_mod_count, coeff_count); - - plain.parms_id() = parms_id_zero; - plain.resize(dest_size); - plain.parms_id() = next_context_data.parms_id(); - } - - void Evaluator::mod_switch_to_next(const Ciphertext &encrypted, - Ciphertext &destination, MemoryPoolHandle pool) - { - // Verify parameters. - if (!is_metadata_valid_for(encrypted, context_)) - { - throw invalid_argument("encrypted is not valid for encryption parameters"); - } - - auto context_data_ptr = context_->get_context_data(encrypted.parms_id()); - if (context_->last_parms_id() == encrypted.parms_id()) - { - throw invalid_argument("end of modulus switching chain reached"); - } - if (!pool) - { - throw invalid_argument("pool is uninitialized"); - } - - switch (context_->first_context_data()->parms().scheme()) - { - case scheme_type::BFV: - // Modulus switching with scaling - mod_switch_scale_to_next(encrypted, destination, move(pool)); - break; - - case scheme_type::CKKS: - // Modulus switching without scaling - mod_switch_drop_to_next(encrypted, destination, move(pool)); - break; - - default: - throw invalid_argument("unsupported scheme"); - } -#ifdef SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT - // Transparent ciphertext output is not allowed. - if (destination.is_transparent()) - { - throw logic_error("result ciphertext is transparent"); - } -#endif - } - - void Evaluator::mod_switch_to_inplace(Ciphertext &encrypted, - parms_id_type parms_id, MemoryPoolHandle pool) - { - // Verify parameters. - auto context_data_ptr = context_->get_context_data(encrypted.parms_id()); - auto target_context_data_ptr = context_->get_context_data(parms_id); - if (!context_data_ptr) - { - throw invalid_argument("encrypted is not valid for encryption parameters"); - } - if (!target_context_data_ptr) - { - throw invalid_argument("parms_id is not valid for encryption parameters"); - } - if (context_data_ptr->chain_index() < target_context_data_ptr->chain_index()) - { - throw invalid_argument("cannot switch to higher level modulus"); - } - - while (encrypted.parms_id() != parms_id) - { - mod_switch_to_next_inplace(encrypted, pool); - } - } - - void Evaluator::mod_switch_to_inplace(Plaintext &plain, parms_id_type parms_id) - { - // Verify parameters. - auto context_data_ptr = context_->get_context_data(plain.parms_id()); - auto target_context_data_ptr = context_->get_context_data(parms_id); - if (!context_data_ptr) - { - throw invalid_argument("plain is not valid for encryption parameters"); - } - if (!context_->get_context_data(parms_id)) - { - throw invalid_argument("parms_id is not valid for encryption parameters"); - } - if (!plain.is_ntt_form()) - { - throw invalid_argument("plain is not in NTT form"); - } - if (context_data_ptr->chain_index() < target_context_data_ptr->chain_index()) - { - throw invalid_argument("cannot switch to higher level modulus"); - } - - while (plain.parms_id() != parms_id) - { - mod_switch_to_next_inplace(plain); - } - } - - void Evaluator::rescale_to_next(const Ciphertext &encrypted, Ciphertext &destination, - MemoryPoolHandle pool) - { - // Verify parameters. - if (!is_metadata_valid_for(encrypted, context_)) - { - throw invalid_argument("encrypted is not valid for encryption parameters"); - } - if (context_->last_parms_id() == encrypted.parms_id()) - { - throw invalid_argument("end of modulus switching chain reached"); - } - if (!pool) - { - throw invalid_argument("pool is uninitialized"); - } - - switch (context_->first_context_data()->parms().scheme()) - { - case scheme_type::BFV: - throw invalid_argument("unsupported operation for scheme type"); - - case scheme_type::CKKS: - // Modulus switching with scaling - mod_switch_scale_to_next(encrypted, destination, move(pool)); - break; - - default: - throw invalid_argument("unsupported scheme"); - } -#ifdef SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT - // Transparent ciphertext output is not allowed. - if (destination.is_transparent()) - { - throw logic_error("result ciphertext is transparent"); - } -#endif - } - - void Evaluator::rescale_to_inplace(Ciphertext &encrypted, parms_id_type parms_id, - MemoryPoolHandle pool) - { - // Verify parameters. - if (!is_metadata_valid_for(encrypted, context_)) - { - throw invalid_argument("encrypted is not valid for encryption parameters"); - } - - auto context_data_ptr = context_->get_context_data(encrypted.parms_id()); - auto target_context_data_ptr = context_->get_context_data(parms_id); - if (!context_data_ptr) - { - throw invalid_argument("encrypted is not valid for encryption parameters"); - } - if (!target_context_data_ptr) - { - throw invalid_argument("parms_id is not valid for encryption parameters"); - } - if (context_data_ptr->chain_index() < target_context_data_ptr->chain_index()) - { - throw invalid_argument("cannot switch to higher level modulus"); - } - if (!pool) - { - throw invalid_argument("pool is uninitialized"); - } - - switch (context_data_ptr->parms().scheme()) - { - case scheme_type::BFV: - throw invalid_argument("unsupported operation for scheme type"); - - case scheme_type::CKKS: - while (encrypted.parms_id() != parms_id) - { - // Modulus switching with scaling - mod_switch_scale_to_next(encrypted, encrypted, pool); - } - break; - - default: - throw invalid_argument("unsupported scheme"); - } -#ifdef SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT - // Transparent ciphertext output is not allowed. - if (encrypted.is_transparent()) - { - throw logic_error("result ciphertext is transparent"); - } -#endif - } - - void Evaluator::multiply_many(const vector &encrypteds, - const RelinKeys &relin_keys, Ciphertext &destination, - MemoryPoolHandle pool) - { - // Verify parameters. - if (encrypteds.size() == 0) - { - throw invalid_argument("encrypteds vector must not be empty"); - } - if (!pool) - { - throw invalid_argument("pool is uninitialized"); - } - for (size_t i = 0; i < encrypteds.size(); i++) - { - if (&encrypteds[i] == &destination) - { - throw invalid_argument("encrypteds must be different from destination"); - } - } - - // There is at least one ciphertext - auto context_data_ptr = context_->get_context_data(encrypteds[0].parms_id()); - if (!context_data_ptr) - { - throw invalid_argument("encrypteds is not valid for encryption parameters"); - } - - // Extract encryption parameters. - auto &context_data = *context_data_ptr; - auto &parms = context_data.parms(); - - if (parms.scheme() != scheme_type::BFV) - { - throw logic_error("unsupported scheme"); - } - - // If there is only one ciphertext, return it. - if (encrypteds.size() == 1) - { - destination = encrypteds[0]; - return; - } - - // Do first level of multiplications - vector product_vec; - for (size_t i = 0; i < encrypteds.size() - 1; i += 2) - { - Ciphertext temp(context_, context_data.parms_id(), pool); - if (encrypteds[i].data() == encrypteds[i + 1].data()) - { - square(encrypteds[i], temp); - } - else - { - multiply(encrypteds[i], encrypteds[i + 1], temp); - } - relinearize_inplace(temp, relin_keys, pool); - product_vec.emplace_back(move(temp)); - } - if (encrypteds.size() & 1) - { - product_vec.emplace_back(encrypteds.back()); - } - - // Repeatedly multiply and add to the back of the vector until the end is reached - for (size_t i = 0; i < product_vec.size() - 1; i += 2) - { - Ciphertext temp(context_, context_data.parms_id(), pool); - multiply(product_vec[i], product_vec[i + 1], temp); - relinearize_inplace(temp, relin_keys, pool); - product_vec.emplace_back(move(temp)); - } - - destination = product_vec.back(); - } - - void Evaluator::exponentiate_inplace(Ciphertext &encrypted, uint64_t exponent, - const RelinKeys &relin_keys, MemoryPoolHandle pool) - { - // Verify parameters. - auto context_data_ptr = context_->get_context_data(encrypted.parms_id()); - if (!context_data_ptr) - { - throw invalid_argument("encrypted is not valid for encryption parameters"); - } - if (!context_->get_context_data(relin_keys.parms_id())) - { - throw invalid_argument("relin_keys is not valid for encryption parameters"); - } - if (!pool) - { - throw invalid_argument("pool is uninitialized"); - } - if (exponent == 0) - { - throw invalid_argument("exponent cannot be 0"); - } - - // Fast case - if (exponent == 1) - { - return; - } - - // Create a vector of copies of encrypted - vector exp_vector(exponent, encrypted); - multiply_many(exp_vector, relin_keys, encrypted, move(pool)); - } - - void Evaluator::add_plain_inplace(Ciphertext &encrypted, const Plaintext &plain) - { - // Verify parameters. - if (!is_metadata_valid_for(encrypted, context_)) - { - throw invalid_argument("encrypted is not valid for encryption parameters"); - } - if (!is_valid_for(plain, context_)) - { - throw invalid_argument("plain is not valid for encryption parameters"); - } - - auto &context_data = *context_->get_context_data(encrypted.parms_id()); - auto &parms = context_data.parms(); - if (parms.scheme() == scheme_type::BFV && encrypted.is_ntt_form()) - { - throw invalid_argument("BFV encrypted cannot be in NTT form"); - } - if (parms.scheme() == scheme_type::CKKS && !encrypted.is_ntt_form()) - { - throw invalid_argument("CKKS encrypted must be in NTT form"); - } - if (plain.is_ntt_form() != encrypted.is_ntt_form()) - { - throw invalid_argument("NTT form mismatch"); - } - if (encrypted.is_ntt_form() && - (encrypted.parms_id() != plain.parms_id())) - { - throw invalid_argument("encrypted and plain parameter mismatch"); - } - if (!are_same_scale(encrypted, plain)) - { - throw invalid_argument("scale mismatch"); - } - - // Extract encryption parameters. - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = coeff_modulus.size(); - - // Size check - if (!product_fits_in(coeff_count, coeff_mod_count)) - { - throw logic_error("invalid parameters"); - } - - switch (parms.scheme()) - { - case scheme_type::BFV: - { - auto coeff_div_plain_modulus = context_data.coeff_div_plain_modulus(); - auto plain_upper_half_threshold = context_data.plain_upper_half_threshold(); - auto upper_half_increment = context_data.upper_half_increment(); - - for (size_t i = 0; i < plain.coeff_count(); i++) - { - // This is Encryptor::preencrypt - // Multiply plain by scalar coeff_div_plain_modulus and reposition - // if in upper-half. - if (plain[i] >= plain_upper_half_threshold) - { - // Loop over primes - for (size_t j = 0; j < coeff_mod_count; j++) - { - unsigned long long temp[2]{ 0, 0 }; - multiply_uint64(coeff_div_plain_modulus[j], plain[i], temp); - temp[1] += add_uint64(temp[0], upper_half_increment[j], temp); - uint64_t scaled_plain_coeff = barrett_reduce_128(temp, coeff_modulus[j]); - *(encrypted.data() + i + (j * coeff_count)) = add_uint_uint_mod( - *(encrypted.data() + i + (j * coeff_count)), - scaled_plain_coeff, coeff_modulus[j]); - } - } - else - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - uint64_t scaled_plain_coeff = multiply_uint_uint_mod( - coeff_div_plain_modulus[j], plain[i], coeff_modulus[j]); - *(encrypted.data() + i + (j * coeff_count)) = add_uint_uint_mod( - *(encrypted.data() + i + (j * coeff_count)), - scaled_plain_coeff, coeff_modulus[j]); - } - } - } - break; - } - - case scheme_type::CKKS: - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - add_poly_poly_coeffmod(encrypted.data() + (j * coeff_count), - plain.data() + (j*coeff_count), coeff_count, - coeff_modulus[j], encrypted.data() + (j * coeff_count)); - } - break; - } - - default: - throw invalid_argument("unsupported scheme"); - } -#ifdef SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT - // Transparent ciphertext output is not allowed. - if (encrypted.is_transparent()) - { - throw logic_error("result ciphertext is transparent"); - } -#endif - } - - void Evaluator::sub_plain_inplace(Ciphertext &encrypted, const Plaintext &plain) - { - // Verify parameters. - if (!is_metadata_valid_for(encrypted, context_)) - { - throw invalid_argument("encrypted is not valid for encryption parameters"); - } - if (!is_valid_for(plain, context_)) - { - throw invalid_argument("plain is not valid for encryption parameters"); - } - - auto &context_data = *context_->get_context_data(encrypted.parms_id()); - auto &parms = context_data.parms(); - if (parms.scheme() == scheme_type::BFV && encrypted.is_ntt_form()) - { - throw invalid_argument("BFV encrypted cannot be in NTT form"); - } - if (parms.scheme() == scheme_type::CKKS && !encrypted.is_ntt_form()) - { - throw invalid_argument("CKKS encrypted must be in NTT form"); - } - if (plain.is_ntt_form() != encrypted.is_ntt_form()) - { - throw invalid_argument("NTT form mismatch"); - } - if (encrypted.is_ntt_form() && - (encrypted.parms_id() != plain.parms_id())) - { - throw invalid_argument("encrypted and plain parameter mismatch"); - } - if (!are_same_scale(encrypted, plain)) - { - throw invalid_argument("scale mismatch"); - } - - // Extract encryption parameters. - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = coeff_modulus.size(); - - // Size check - if (!product_fits_in(coeff_count, coeff_mod_count)) - { - throw logic_error("invalid parameters"); - } - - switch (parms.scheme()) - { - case scheme_type::BFV: - { - auto coeff_div_plain_modulus = context_data.coeff_div_plain_modulus(); - auto plain_upper_half_threshold = context_data.plain_upper_half_threshold(); - auto upper_half_increment = context_data.upper_half_increment(); - - for (size_t i = 0; i < plain.coeff_count(); i++) - { - // This is Encryptor::preencrypt changed to subtract instead - // Multiply plain by scalar coeff_div_plain_modulus and reposition - // if in upper-half. - if (plain[i] >= plain_upper_half_threshold) - { - // Loop over primes - for (size_t j = 0; j < coeff_mod_count; j++) - { - unsigned long long temp[2]{ 0, 0 }; - multiply_uint64(coeff_div_plain_modulus[j], plain[i], temp); - temp[1] += add_uint64(temp[0], upper_half_increment[j], temp); - uint64_t scaled_plain_coeff = barrett_reduce_128(temp, coeff_modulus[j]); - *(encrypted.data() + i + (j * coeff_count)) = sub_uint_uint_mod( - *(encrypted.data() + i + (j * coeff_count)), - scaled_plain_coeff, coeff_modulus[j]); - } - } - else - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - uint64_t scaled_plain_coeff = multiply_uint_uint_mod( - coeff_div_plain_modulus[j], plain[i], coeff_modulus[j]); - *(encrypted.data() + i + (j * coeff_count)) = sub_uint_uint_mod( - *(encrypted.data() + i + (j * coeff_count)), - scaled_plain_coeff, coeff_modulus[j]); - } - } - } - break; - } - - case scheme_type::CKKS: - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - sub_poly_poly_coeffmod(encrypted.data() + (j * coeff_count), - plain.data() + (j * coeff_count), coeff_count, - coeff_modulus[j], encrypted.data() + (j * coeff_count)); - } - break; - } - - default: - throw invalid_argument("unsupported scheme"); - } -#ifdef SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT - // Transparent ciphertext output is not allowed. - if (encrypted.is_transparent()) - { - throw logic_error("result ciphertext is transparent"); - } -#endif - } - - void Evaluator::multiply_plain_inplace(Ciphertext &encrypted, - const Plaintext &plain, MemoryPoolHandle pool) - { - // Verify parameters. - if (!is_metadata_valid_for(encrypted, context_)) - { - throw invalid_argument("encrypted is not valid for encryption parameters"); - } - if (!is_valid_for(plain, context_)) - { - throw invalid_argument("plain is not valid for encryption parameters"); - } - if (!context_->get_context_data(encrypted.parms_id())) - { - throw invalid_argument("encrypted is not valid for encryption parameters"); - } - if (encrypted.is_ntt_form() != plain.is_ntt_form()) - { - throw invalid_argument("NTT form mismatch"); - } - if (!pool) - { - throw invalid_argument("pool is uninitialized"); - } - - if (encrypted.is_ntt_form()) - { - multiply_plain_ntt(encrypted, plain); - } - else - { - multiply_plain_normal(encrypted, plain, move(pool)); - } -#ifdef SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT - // Transparent ciphertext output is not allowed. - if (encrypted.is_transparent()) - { - throw logic_error("result ciphertext is transparent"); - } -#endif - } - - void Evaluator::multiply_plain_normal(Ciphertext &encrypted, - const Plaintext &plain, MemoryPool &pool) - { - // Extract encryption parameters. - auto &context_data = *context_->get_context_data(encrypted.parms_id()); - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = coeff_modulus.size(); - - auto plain_upper_half_threshold = context_data.plain_upper_half_threshold(); - auto plain_upper_half_increment = context_data.plain_upper_half_increment(); - auto &coeff_small_ntt_tables = context_data.small_ntt_tables(); - - size_t encrypted_size = encrypted.size(); - size_t plain_coeff_count = plain.coeff_count(); - size_t plain_nonzero_coeff_count = plain.nonzero_coeff_count(); - - // Size check - if (!product_fits_in(encrypted_size, coeff_count, coeff_mod_count)) - { - throw logic_error("invalid parameters"); - } - - double new_scale = encrypted.scale() * plain.scale(); - - // Check that scale is positive and not too large - if (new_scale <= 0 || (static_cast(log2(new_scale)) >= - context_data.total_coeff_modulus_bit_count())) - { - throw invalid_argument("scale out of bounds"); - } - - // Set the scale - encrypted.scale() = new_scale; - - /* - Optimizations for constant / monomial multiplication can lead to the - presence of a timing side-channel in use-cases where the plaintext - data should also be kept private. - */ - if (plain_nonzero_coeff_count == 1) - { - // Multiplying by a monomial? - size_t mono_exponent = plain.significant_coeff_count() - 1; - - if (plain[mono_exponent] >= plain_upper_half_threshold) - { - if (!context_data.qualifiers().using_fast_plain_lift) - { - auto adjusted_coeff(allocate_uint(coeff_mod_count, pool)); - auto decomposed_coeff(allocate_uint(coeff_mod_count, pool)); - add_uint_uint64(plain_upper_half_increment, plain[mono_exponent], - coeff_mod_count, adjusted_coeff.get()); - decompose_single_coeff(context_data, adjusted_coeff.get(), - decomposed_coeff.get(), pool); - - for (size_t i = 0; i < encrypted_size; i++) - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - negacyclic_multiply_poly_mono_coeffmod( - encrypted.data(i) + (j * coeff_count), coeff_count, - decomposed_coeff[j], mono_exponent, coeff_modulus[j], - encrypted.data(i) + (j * coeff_count), pool); - } - } - } - else - { - for (size_t i = 0; i < encrypted_size; i++) - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - negacyclic_multiply_poly_mono_coeffmod( - encrypted.data(i) + (j * coeff_count), coeff_count, - plain[mono_exponent] + plain_upper_half_increment[j], - mono_exponent, coeff_modulus[j], - encrypted.data(i) + (j * coeff_count), pool); - } - } - } - } - else - { - for (size_t i = 0; i < encrypted_size; i++) - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - negacyclic_multiply_poly_mono_coeffmod( - encrypted.data(i) + (j * coeff_count), coeff_count, - plain[mono_exponent], mono_exponent, coeff_modulus[j], - encrypted.data(i) + (j * coeff_count), pool); - } - } - } - - return; - } - - // Generic plain case - auto adjusted_poly(allocate_zero_uint(coeff_count * coeff_mod_count, pool)); - auto decomposed_poly(allocate_uint(coeff_count * coeff_mod_count, pool)); - uint64_t *poly_to_transform = nullptr; - if (!context_data.qualifiers().using_fast_plain_lift) - { - // Reposition coefficients. - const uint64_t *plain_ptr = plain.data(); - uint64_t *adjusted_poly_ptr = adjusted_poly.get(); - for (size_t i = 0; i < plain_coeff_count; i++, plain_ptr++, - adjusted_poly_ptr += coeff_mod_count) - { - if (*plain_ptr >= plain_upper_half_threshold) - { - add_uint_uint64(plain_upper_half_increment, - *plain_ptr, coeff_mod_count, adjusted_poly_ptr); - } - else - { - *adjusted_poly_ptr = *plain_ptr; - } - } - decompose(context_data, adjusted_poly.get(), decomposed_poly.get(), pool); - poly_to_transform = decomposed_poly.get(); - } - else - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - const uint64_t *plain_ptr = plain.data(); - uint64_t *adjusted_poly_ptr = adjusted_poly.get() + (j * coeff_count); - uint64_t current_plain_upper_half_increment = plain_upper_half_increment[j]; - for (size_t i = 0; i < plain_coeff_count; i++, plain_ptr++, adjusted_poly_ptr++) - { - // Need to lift the coefficient in each qi - if (*plain_ptr >= plain_upper_half_threshold) - { - *adjusted_poly_ptr = *plain_ptr + current_plain_upper_half_increment; - } - // No need for lifting - else - { - *adjusted_poly_ptr = *plain_ptr; - } - } - } - poly_to_transform = adjusted_poly.get(); - } - - // Need to multiply each component in encrypted with decomposed_poly (plain poly) - // Transform plain poly only once - for (size_t i = 0; i < coeff_mod_count; i++) - { - ntt_negacyclic_harvey( - poly_to_transform + (i * coeff_count), coeff_small_ntt_tables[i]); - } - - for (size_t i = 0; i < encrypted_size; i++) - { - uint64_t *encrypted_ptr = encrypted.data(i); - for (size_t j = 0; j < coeff_mod_count; j++, encrypted_ptr += coeff_count) - { - // Explicit inline to avoid unnecessary copy - //ntt_multiply_poly_nttpoly(encrypted.data(i) + (j * coeff_count), - //poly_to_transform + (j * coeff_count), - // coeff_small_ntt_tables_[j], encrypted.data(i) + (j * coeff_count), pool); - - // Lazy reduction - ntt_negacyclic_harvey_lazy(encrypted_ptr, coeff_small_ntt_tables[j]); - dyadic_product_coeffmod(encrypted_ptr, poly_to_transform + (j * coeff_count), - coeff_count, coeff_modulus[j], encrypted_ptr); - inverse_ntt_negacyclic_harvey(encrypted_ptr, coeff_small_ntt_tables[j]); - } - } - } - - void Evaluator::multiply_plain_ntt(Ciphertext &encrypted_ntt, - const Plaintext &plain_ntt) - { - // Verify parameters. - if (!plain_ntt.is_ntt_form()) - { - throw invalid_argument("plain_ntt is not in NTT form"); - } - if (encrypted_ntt.parms_id() != plain_ntt.parms_id()) - { - throw invalid_argument("encrypted_ntt and plain_ntt parameter mismatch"); - } - - // Extract encryption parameters. - auto &context_data = *context_->get_context_data(encrypted_ntt.parms_id()); - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = coeff_modulus.size(); - size_t encrypted_ntt_size = encrypted_ntt.size(); - - // Size check - if (!product_fits_in(encrypted_ntt_size, coeff_count, coeff_mod_count)) - { - throw logic_error("invalid parameters"); - } - - double new_scale = encrypted_ntt.scale() * plain_ntt.scale(); - - // Check that scale is positive and not too large - if (new_scale <= 0 || (static_cast(log2(new_scale)) >= - context_data.total_coeff_modulus_bit_count())) - { - throw invalid_argument("scale out of bounds"); - } - - for (size_t i = 0; i < encrypted_ntt_size; i++) - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - dyadic_product_coeffmod( - encrypted_ntt.data(i) + (j * coeff_count), - plain_ntt.data() + (j * coeff_count), - coeff_count, coeff_modulus[j], - encrypted_ntt.data(i) + (j * coeff_count)); - } - } - - // Set the scale - encrypted_ntt.scale() = new_scale; - } - - void Evaluator::transform_to_ntt_inplace(Plaintext &plain, - parms_id_type parms_id, MemoryPoolHandle pool) - { - // Verify parameters. - if (!is_valid_for(plain, context_)) - { - throw invalid_argument("plain is not valid for encryption parameters"); - } - - auto context_data_ptr = context_->get_context_data(parms_id); - if (!context_data_ptr) - { - throw invalid_argument("parms_id is not valid for the current context"); - } - if (plain.is_ntt_form()) - { - throw invalid_argument("plain is already in NTT form"); - } - if (!pool) - { - throw invalid_argument("pool is uninitialized"); - } - - // Extract encryption parameters. - auto &context_data = *context_data_ptr; - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = coeff_modulus.size(); - size_t plain_coeff_count = plain.coeff_count(); - - auto plain_upper_half_threshold = context_data.plain_upper_half_threshold(); - auto plain_upper_half_increment = context_data.plain_upper_half_increment(); - - auto &coeff_small_ntt_tables = context_data.small_ntt_tables(); - - // Size check - if (!product_fits_in(coeff_count, coeff_mod_count)) - { - throw logic_error("invalid parameters"); - } - - // Resize to fit the entire NTT transformed (ciphertext size) polynomial - // Note that the new coefficients are automatically set to 0 - plain.resize(coeff_count * coeff_mod_count); - - // Verify if plain lift is needed - if (!context_data.qualifiers().using_fast_plain_lift) - { - auto adjusted_poly(allocate_zero_uint(coeff_count * coeff_mod_count, pool)); - for (size_t i = 0; i < plain_coeff_count; i++) - { - if (plain[i] >= plain_upper_half_threshold) - { - add_uint_uint64(plain_upper_half_increment, plain[i], - coeff_mod_count, adjusted_poly.get() + (i * coeff_mod_count)); - } - else - { - adjusted_poly[i * coeff_mod_count] = plain[i]; - } - } - decompose(context_data, adjusted_poly.get(), plain.data(), pool); - } - // No need for composed plain lift and decomposition - else - { - for (size_t j = coeff_mod_count; j--; ) - { - const uint64_t *plain_ptr = plain.data(); - uint64_t *adjusted_poly_ptr = plain.data() + (j * coeff_count); - uint64_t current_plain_upper_half_increment = plain_upper_half_increment[j]; - for (size_t i = 0; i < plain_coeff_count; i++, plain_ptr++, adjusted_poly_ptr++) - { - // Need to lift the coefficient in each qi - if (*plain_ptr >= plain_upper_half_threshold) - { - *adjusted_poly_ptr = *plain_ptr + current_plain_upper_half_increment; - } - // No need for lifting - else - { - *adjusted_poly_ptr = *plain_ptr; - } - } - } - } - - // Transform to NTT domain - for (size_t i = 0; i < coeff_mod_count; i++) - { - ntt_negacyclic_harvey( - plain.data() + (i * coeff_count), coeff_small_ntt_tables[i]); - } - - plain.parms_id() = parms_id; - } - - void Evaluator::transform_to_ntt_inplace(Ciphertext &encrypted) - { - // Verify parameters. - if (!is_metadata_valid_for(encrypted, context_)) - { - throw invalid_argument("encrypted is not valid for encryption parameters"); - } - - auto context_data_ptr = context_->get_context_data(encrypted.parms_id()); - if (!context_data_ptr) - { - throw invalid_argument("encrypted is not valid for encryption parameters"); - } - if (encrypted.is_ntt_form()) - { - throw invalid_argument("encrypted is already in NTT form"); - } - - // Extract encryption parameters. - auto &context_data = *context_data_ptr; - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = coeff_modulus.size(); - size_t encrypted_size = encrypted.size(); - - auto &coeff_small_ntt_tables = context_data.small_ntt_tables(); - - // Size check - if (!product_fits_in(coeff_count, coeff_mod_count)) - { - throw logic_error("invalid parameters"); - } - - // Transform each polynomial to NTT domain - for (size_t i = 0; i < encrypted_size; i++) - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - ntt_negacyclic_harvey( - encrypted.data(i) + (j * coeff_count), coeff_small_ntt_tables[j]); - } - } - - // Finally change the is_ntt_transformed flag - encrypted.is_ntt_form() = true; -#ifdef SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT - // Transparent ciphertext output is not allowed. - if (encrypted.is_transparent()) - { - throw logic_error("result ciphertext is transparent"); - } -#endif - } - - void Evaluator::transform_from_ntt_inplace(Ciphertext &encrypted_ntt) - { - // Verify parameters. - if (!is_metadata_valid_for(encrypted_ntt, context_)) - { - throw invalid_argument("encrypted is not valid for encryption parameters"); - } - - auto context_data_ptr = context_->get_context_data(encrypted_ntt.parms_id()); - if (!context_data_ptr) - { - throw invalid_argument("encrypted_ntt is not valid for encryption parameters"); - } - if (!encrypted_ntt.is_ntt_form()) - { - throw invalid_argument("encrypted_ntt is not in NTT form"); - } - - // Extract encryption parameters. - auto &context_data = *context_data_ptr; - auto &parms = context_data.parms(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = parms.coeff_modulus().size(); - size_t encrypted_ntt_size = encrypted_ntt.size(); - - auto &coeff_small_ntt_tables = context_data.small_ntt_tables(); - - // Size check - if (!product_fits_in(coeff_count, coeff_mod_count)) - { - throw logic_error("invalid parameters"); - } - - // Transform each polynomial from NTT domain - for (size_t i = 0; i < encrypted_ntt_size; i++) - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - inverse_ntt_negacyclic_harvey( - encrypted_ntt.data(i) + (j * coeff_count), coeff_small_ntt_tables[j]); - } - } - - // Finally change the is_ntt_transformed flag - encrypted_ntt.is_ntt_form() = false; -#ifdef SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT - // Transparent ciphertext output is not allowed. - if (encrypted_ntt.is_transparent()) - { - throw logic_error("result ciphertext is transparent"); - } -#endif - } - - void Evaluator::apply_galois_inplace(Ciphertext &encrypted, uint64_t galois_elt, - const GaloisKeys &galois_keys, MemoryPoolHandle pool) - { - // Verify parameters. - if (!is_metadata_valid_for(encrypted, context_)) - { - throw invalid_argument("encrypted is not valid for encryption parameters"); - } - - // Don't validate all of galois_keys but just check the parms_id. - if (galois_keys.parms_id() != context_->key_parms_id()) - { - throw invalid_argument("galois_keys is not valid for encryption parameters"); - } - - auto &context_data = *context_->get_context_data(encrypted.parms_id()); - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = coeff_modulus.size(); - size_t encrypted_size = encrypted.size(); - - // Size check - if (!product_fits_in(coeff_count, coeff_mod_count)) - { - throw logic_error("invalid parameters"); - } - - uint64_t m = mul_safe(static_cast(coeff_count), uint64_t(2)); - uint64_t subgroup_size = static_cast(coeff_count >> 1); - int n_power_of_two = get_power_of_two(static_cast(coeff_count)); - - // Verify parameters - if (!(galois_elt & 1) || unsigned_geq(galois_elt, m)) - { - throw invalid_argument("Galois element is not valid"); - } - if (encrypted_size > 2) - { - throw invalid_argument("encrypted size must be 2"); - } - - // Check if Galois key is generated or not. - // If not, attempt a bit decomposition; maybe we have log(n) many keys - if (!galois_keys.has_key(galois_elt)) - { - // galois_elt = 3^order1 * (-1)^order2 - uint64_t order1 = Zmstar_to_generator_.at(galois_elt).first; - uint64_t order2 = Zmstar_to_generator_.at(galois_elt).second; - - // We use either 3 or -3 as our generator, depending on which gives smaller HW - uint64_t two_power_of_gen = 3; - - // Does order1 or n/2-order1 have smaller Hamming weight? - if (hamming_weight(subgroup_size - order1) < hamming_weight(order1)) - { - order1 = subgroup_size - order1; - try_mod_inverse(3, m, two_power_of_gen); - } - - while(order1) - { - if (order1 & 1) - { - if (!galois_keys.has_key(two_power_of_gen)) - { - throw invalid_argument("Galois key not present"); - } - apply_galois_inplace(encrypted, two_power_of_gen, galois_keys, pool); - } - two_power_of_gen = mul_safe(two_power_of_gen, two_power_of_gen); - two_power_of_gen &= (m - 1); - order1 >>= 1; - } - if (order2) - { - if (!galois_keys.has_key(m - 1)) - { - throw invalid_argument("Galois key not present"); - } - apply_galois_inplace(encrypted, m - 1, galois_keys, pool); - } - return; - } - - auto temp(allocate_poly(coeff_count, coeff_mod_count, pool)); - - // DO NOT CHANGE EXECUTION ORDER OF FOLLOWING SECTION - // BEGIN: Apply Galois for each ciphertext - // Execution order is sensitive, since apply_galois is not inplace! - if (parms.scheme() == scheme_type::BFV) - { - // !!! DO NOT CHANGE EXECUTION ORDER!!! - for (size_t i = 0; i < coeff_mod_count; i++) - { - util::apply_galois( - encrypted.data(0) + i * coeff_count, - n_power_of_two, - galois_elt, - coeff_modulus[i], - temp.get() + i * coeff_count); - } - // copy result to encrypted.data(0) - set_poly_poly(temp.get(), coeff_count, coeff_mod_count, - encrypted.data(0)); - for (size_t i = 0; i < coeff_mod_count; i++) - { - util::apply_galois( - encrypted.data(1) + i * coeff_count, - n_power_of_two, - galois_elt, - coeff_modulus[i], - temp.get() + i * coeff_count); - } - } - else if (parms.scheme() == scheme_type::CKKS) - { - // !!! DO NOT CHANGE EXECUTION ORDER!!! - for (size_t i = 0; i < coeff_mod_count; i++) - { - util::apply_galois_ntt( - encrypted.data(0) + i * coeff_count, - n_power_of_two, - galois_elt, - temp.get() + i * coeff_count); - } - // copy result to encrypted.data(0) - set_poly_poly(temp.get(), coeff_count, coeff_mod_count, - encrypted.data(0)); - for (size_t i = 0; i < coeff_mod_count; i++) - { - util::apply_galois_ntt( - encrypted.data(1) + i * coeff_count, - n_power_of_two, - galois_elt, - temp.get() + i * coeff_count); - } - } - else - { - throw logic_error("scheme not implemented"); - } - // wipe encrypted.data(1) - set_zero_poly(coeff_count, coeff_mod_count, encrypted.data(1)); - // END: Apply Galois for each ciphertext - // REORDERING IS SAFE NOW - - // Calculate (temp * galois_key[0], temp * galois_key[1]) + (ct[0], 0) - switch_key_inplace( - encrypted, - temp.get(), - static_cast(galois_keys), - GaloisKeys::get_index(galois_elt), - pool); -#ifdef SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT - // Transparent ciphertext output is not allowed. - if (encrypted.is_transparent()) - { - throw logic_error("result ciphertext is transparent"); - } -#endif - } - - void Evaluator::rotate_internal(Ciphertext &encrypted, int steps, - const GaloisKeys &galois_keys, MemoryPoolHandle pool) - { - auto context_data_ptr = context_->get_context_data(encrypted.parms_id()); - if (!context_data_ptr) - { - throw invalid_argument("encrypted is not valid for encryption parameters"); - } - if (!context_data_ptr->qualifiers().using_batching) - { - throw logic_error("encryption parameters do not support batching"); - } - if (galois_keys.parms_id() != context_->key_parms_id()) - { - throw invalid_argument("galois_keys is not valid for encryption parameters"); - } - - // Is there anything to do? - if (steps == 0) - { - return; - } - - size_t coeff_count = context_data_ptr->parms().poly_modulus_degree(); - - // Perform rotation and key switching - apply_galois_inplace(encrypted, - steps_to_galois_elt(steps, coeff_count), - galois_keys, move(pool)); - } - - void Evaluator::switch_key_inplace( - Ciphertext &encrypted, - const uint64_t *target, - const KSwitchKeys &kswitch_keys, - size_t kswitch_keys_index, - MemoryPoolHandle pool) - { - auto parms_id = encrypted.parms_id(); - auto &context_data = *context_->get_context_data(parms_id); - auto &parms = context_data.parms(); - auto &key_context_data = *context_->key_context_data(); - auto &key_parms = key_context_data.parms(); - auto scheme = parms.scheme(); - - // Verify parameters. - if (!is_metadata_valid_for(encrypted, context_)) - { - throw invalid_argument("encrypted is not valid for encryption parameters"); - } - if (!target) - { - throw invalid_argument("target"); - } - if (!context_->using_keyswitching()) - { - throw logic_error("keyswitching is not supported by the context"); - } - - // Don't validate all of kswitch_keys but just check the parms_id. - if (kswitch_keys.parms_id() != context_->key_parms_id()) - { - throw invalid_argument("parameter mismatch"); - } - - if (kswitch_keys_index >= kswitch_keys.data().size()) - { - throw out_of_range("kswitch_keys_index"); - } - if (!pool) - { - throw invalid_argument("pool is uninitialized"); - } - if (scheme == scheme_type::BFV && encrypted.is_ntt_form()) - { - throw invalid_argument("BFV encrypted cannot be in NTT form"); - } - if (scheme == scheme_type::CKKS && !encrypted.is_ntt_form()) - { - throw invalid_argument("CKKS encrypted must be in NTT form"); - } - - // Extract encryption parameters. - size_t coeff_count = parms.poly_modulus_degree(); - size_t decomp_mod_count = parms.coeff_modulus().size(); - auto &key_modulus = key_parms.coeff_modulus(); - size_t key_mod_count = key_modulus.size(); - size_t rns_mod_count = decomp_mod_count + 1; - auto &small_ntt_tables = key_context_data.small_ntt_tables(); - auto &modswitch_factors = key_context_data.base_converter()-> - get_inv_last_coeff_mod_array(); - - // Size check - if (!product_fits_in(coeff_count, rns_mod_count, size_t(2))) - { - throw logic_error("invalid parameters"); - } - - // Prepare input - auto &key_vector = kswitch_keys.data()[kswitch_keys_index]; - - // Check only the used component in KSwitchKeys. - for (auto &each_key : key_vector) - { - if (!is_metadata_valid_for(each_key, context_)) - { - throw invalid_argument( - "kswitch_keys is not valid for encryption parameters"); - } - } - - // Temporary results - Pointer temp_poly[2] { - allocate_zero_poly(2 * coeff_count, rns_mod_count, pool), - allocate_zero_poly(2 * coeff_count, rns_mod_count, pool) - }; - - // RNS decomposition index = key index - for (size_t i = 0; i < decomp_mod_count; i++) - { - // For each RNS decomposition, multiply with key data and sum up. - auto local_small_poly_0(allocate_uint(coeff_count, pool)); - auto local_small_poly_1(allocate_uint(coeff_count, pool)); - auto local_small_poly_2(allocate_uint(coeff_count, pool)); - - const uint64_t *local_encrypted_ptr = nullptr; - set_uint_uint( - target + i * coeff_count, - coeff_count, - local_small_poly_0.get()); - if (scheme == scheme_type::CKKS) - { - inverse_ntt_negacyclic_harvey( - local_small_poly_0.get(), - small_ntt_tables[i]); - } - // Key RNS representation - for (size_t j = 0; j < rns_mod_count; j++) - { - size_t index = (j == decomp_mod_count ? key_mod_count - 1 : j); - if (scheme == scheme_type::CKKS && i == j) - { - local_encrypted_ptr = target + j * coeff_count; - } - else - { - // Reduce modulus only if needed - if (key_modulus[i].value() <= key_modulus[index].value()) - { - set_uint_uint( - local_small_poly_0.get(), - coeff_count, - local_small_poly_1.get()); - } - else - { - modulo_poly_coeffs_63( - local_small_poly_0.get(), - coeff_count, - key_modulus[index], - local_small_poly_1.get()); - } - - // Lazy reduction, output in [0, 4q). - ntt_negacyclic_harvey_lazy( - local_small_poly_1.get(), - small_ntt_tables[index]); - local_encrypted_ptr = local_small_poly_1.get(); - } - // Two components in key - for (size_t k = 0; k < 2; k++) - { - // dyadic_product_coeffmod( - // local_encrypted_ptr, - // key_vector[i].data(k) + index * coeff_count, - // coeff_count, - // key_modulus[index], - // local_small_poly_2.get()); - // add_poly_poly_coeffmod( - // local_small_poly_2.get(), - // temp_poly[k].get() + j * coeff_count, - // coeff_count, - // key_modulus[index], - // temp_poly[k].get() + j * coeff_count); - const uint64_t *key_ptr = key_vector[i].data().data(k); - for (size_t l = 0; l < coeff_count; l++) - { - unsigned long long local_wide_product[2]; - unsigned long long local_low_word; - unsigned char local_carry; - - multiply_uint64( - local_encrypted_ptr[l], - key_ptr[(index * coeff_count) + l], - local_wide_product); - local_carry = add_uint64( - temp_poly[k].get()[(j * coeff_count + l) * 2], - local_wide_product[0], - &local_low_word); - temp_poly[k].get()[(j * coeff_count + l) * 2] = - local_low_word; - temp_poly[k].get()[(j * coeff_count + l) * 2 + 1] += - local_wide_product[1] + local_carry; - } - } - } - } - - // Results are now stored in temp_poly[k] - // Modulus switching should be performed - auto local_small_poly(allocate_uint(coeff_count, pool)); - for (size_t k = 0; k < 2; k++) - { - // Reduce (ct mod 4qk) mod qk - uint64_t *temp_poly_ptr = temp_poly[k].get() + - decomp_mod_count * coeff_count * 2; - for (size_t l = 0; l < coeff_count; l++) - { - temp_poly_ptr[l] = barrett_reduce_128( - temp_poly_ptr + l * 2, - key_modulus[key_mod_count - 1]); - } - // Lazy reduction, they are then reduced mod qi - uint64_t *temp_last_poly_ptr = temp_poly[k].get() + decomp_mod_count * coeff_count * 2; - inverse_ntt_negacyclic_harvey_lazy( - temp_last_poly_ptr, - small_ntt_tables[key_mod_count - 1]); - - // Add (p-1)/2 to change from flooring to rounding. - uint64_t half = key_modulus[key_mod_count - 1].value() >> 1; - for (size_t l = 0; l < coeff_count; l++) - { - temp_last_poly_ptr[l] = barrett_reduce_63(temp_last_poly_ptr[l] + half, - key_modulus[key_mod_count - 1]); - } - - uint64_t *encrypted_ptr = encrypted.data(k); - for (size_t j = 0; j < decomp_mod_count; j++) - { - temp_poly_ptr = temp_poly[k].get() + j * coeff_count * 2; - // (ct mod 4qi) mod qi - for (size_t l = 0; l < coeff_count; l++) - { - temp_poly_ptr[l] = barrett_reduce_128( - temp_poly_ptr + l * 2, - key_modulus[j]); - } - // (ct mod 4qk) mod qi - modulo_poly_coeffs_63( - temp_last_poly_ptr, - coeff_count, - key_modulus[j], - local_small_poly.get()); - - uint64_t half_mod = barrett_reduce_63(half, key_modulus[j]); - for (size_t l = 0; l < coeff_count; l++) - { - local_small_poly.get()[l] = sub_uint_uint_mod(local_small_poly.get()[l], - half_mod, - key_modulus[j]); - } - - if (scheme == scheme_type::CKKS) - { - ntt_negacyclic_harvey( - local_small_poly.get(), - small_ntt_tables[j]); - } - else if (scheme == scheme_type::BFV) - { - inverse_ntt_negacyclic_harvey( - temp_poly_ptr, - small_ntt_tables[j]); - } - // ((ct mod qi) - (ct mod qk)) mod qi - sub_poly_poly_coeffmod( - temp_poly_ptr, - local_small_poly.get(), - coeff_count, - key_modulus[j], - temp_poly_ptr); - // qk^(-1) * ((ct mod qi) - (ct mod qk)) mod qi - multiply_poly_scalar_coeffmod( - temp_poly_ptr, - coeff_count, - modswitch_factors[j], - key_modulus[j], - temp_poly_ptr); - add_poly_poly_coeffmod( - temp_poly_ptr, - encrypted_ptr + j * coeff_count, - coeff_count, - key_modulus[j], - encrypted_ptr + j * coeff_count); - } - } - } -} diff --git a/SEAL/native/src/seal/evaluator.h b/SEAL/native/src/seal/evaluator.h deleted file mode 100644 index 8daca7d..0000000 --- a/SEAL/native/src/seal/evaluator.h +++ /dev/null @@ -1,1502 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include -#include "seal/context.h" -#include "seal/relinkeys.h" -#include "seal/smallmodulus.h" -#include "seal/memorymanager.h" -#include "seal/ciphertext.h" -#include "seal/plaintext.h" -#include "seal/galoiskeys.h" -#include "seal/util/pointer.h" -#include "seal/secretkey.h" -#include "seal/util/uintarithsmallmod.h" -#include "seal/util/common.h" -#include "seal/kswitchkeys.h" -#include "seal/valcheck.h" - -namespace seal -{ - /** - Provides operations on ciphertexts. Due to the properties of the encryption - scheme, the arithmetic operations pass through the encryption layer to the - underlying plaintext, changing it according to the type of the operation. Since - the plaintext elements are fundamentally polynomials in the polynomial quotient - ring Z_T[x]/(X^N+1), where T is the plaintext modulus and X^N+1 is the polynomial - modulus, this is the ring where the arithmetic operations will take place. - BatchEncoder (batching) provider an alternative possibly more convenient view - of the plaintext elements as 2-by-(N2/2) matrices of integers modulo the plaintext - modulus. In the batching view the arithmetic operations act on the matrices - element-wise. Some of the operations only apply in the batching view, such as - matrix row and column rotations. Other operations such as relinearization have - no semantic meaning but are necessary for performance reasons. - - @par Arithmetic Operations - The core operations are arithmetic operations, in particular multiplication - and addition of ciphertexts. In addition to these, we also provide negation, - subtraction, squaring, exponentiation, and multiplication and addition of - several ciphertexts for convenience. in many cases some of the inputs to a - computation are plaintext elements rather than ciphertexts. For this we - provide fast "plain" operations: plain addition, plain subtraction, and plain - multiplication. - - @par Relinearization - One of the most important non-arithmetic operations is relinearization, which - takes as input a ciphertext of size K+1 and relinearization keys (at least K-1 - keys are needed), and changes the size of the ciphertext down to 2 (minimum size). - For most use-cases only one relinearization key suffices, in which case - relinearization should be performed after every multiplication. Homomorphic - multiplication of ciphertexts of size K+1 and L+1 outputs a ciphertext of size - K+L+1, and the computational cost of multiplication is proportional to K*L. - Plain multiplication and addition operations of any type do not change the - size. Relinearization requires relinearization keys to have been generated. - - @par Rotations - When batching is enabled, we provide operations for rotating the plaintext matrix - rows cyclically left or right, and for rotating the columns (swapping the rows). - Rotations require Galois keys to have been generated. - - @par Other Operations - We also provide operations for transforming ciphertexts to NTT form and back, - and for transforming plaintext polynomials to NTT form. These can be used in - a very fast plain multiplication variant, that assumes the inputs to be in NTT - form. Since the NTT has to be done in any case in plain multiplication, this - function can be used when e.g. one plaintext input is used in several plain - multiplication, and transforming it several times would not make sense. - - @par NTT form - When using the BFV scheme (scheme_type::BFV), all plaintexts and ciphertexts - should remain by default in the usual coefficient representation, i.e., not - in NTT form. When using the CKKS scheme (scheme_type::CKKS), all plaintexts - and ciphertexts should remain by default in NTT form. We call these scheme- - specific NTT states the "default NTT form". Some functions, such as add, work - even if the inputs are not in the default state, but others, such as multiply, - will throw an exception. The output of all evaluation functions will be in - the same state as the input(s), with the exception of the transform_to_ntt - and transform_from_ntt functions, which change the state. Ideally, unless these - two functions are called, all other functions should "just work". - - @see EncryptionParameters for more details on encryption parameters. - @see BatchEncoder for more details on batching - @see RelinKeys for more details on relinearization keys. - @see GaloisKeys for more details on Galois keys. - */ - class Evaluator - { - public: - /** - Creates an Evaluator instance initialized with the specified SEALContext. - - @param[in] context The SEALContext - @throws std::invalid_argument if the context is not set or encryption - parameters are not valid - */ - Evaluator(std::shared_ptr context); - - /** - Negates a ciphertext. - - @param[in] encrypted The ciphertext to negate - @throws std::invalid_argument if encrypted is not valid for the encryption - parameters - */ - void negate_inplace(Ciphertext &encrypted); - - /** - Negates a ciphertext and stores the result in the destination parameter. - - @param[in] encrypted The ciphertext to negate - @param[out] destination The ciphertext to overwrite with the negated result - @throws std::invalid_argument if encrypted is not valid for the encryption - parameters - @throws std::logic_error if result ciphertext is transparent - */ - inline void negate(const Ciphertext &encrypted, Ciphertext &destination) - { - destination = encrypted; - negate_inplace(destination); - } - - /** - Adds two ciphertexts. This function adds together encrypted1 and encrypted2 - and stores the result in encrypted1. - - @param[in] encrypted1 The first ciphertext to add - @param[in] encrypted2 The second ciphertext to add - @throws std::invalid_argument if encrypted1 or encrypted2 is not valid for - the encryption parameters - @throws std::invalid_argument if encrypted1 and encrypted2 are in different - NTT forms - @throws std::invalid_argument if encrypted1 and encrypted2 have different scale - @throws std::logic_error if result ciphertext is transparent - */ - void add_inplace(Ciphertext &encrypted1, const Ciphertext &encrypted2); - - /** - Adds two ciphertexts. This function adds together encrypted1 and encrypted2 - and stores the result in the destination parameter. - - @param[in] encrypted1 The first ciphertext to add - @param[in] encrypted2 The second ciphertext to add - @param[out] destination The ciphertext to overwrite with the addition result - @throws std::invalid_argument if encrypted1 or encrypted2 is not valid for - the encryption parameters - @throws std::invalid_argument if encrypted1 and encrypted2 are in different - NTT forms - @throws std::invalid_argument if encrypted1 and encrypted2 have different scale - @throws std::logic_error if result ciphertext is transparent - */ - inline void add(const Ciphertext &encrypted1, const Ciphertext &encrypted2, - Ciphertext &destination) - { - if (&encrypted2 == &destination) - { - add_inplace(destination, encrypted1); - } - else - { - destination = encrypted1; - add_inplace(destination, encrypted2); - } - } - - /** - Adds together a vector of ciphertexts and stores the result in the destination - parameter. - - @param[in] encrypteds The ciphertexts to add - @param[out] destination The ciphertext to overwrite with the addition result - @throws std::invalid_argument if encrypteds is empty - @throws std::invalid_argument if the encrypteds are not valid for the encryption - parameters - @throws std::invalid_argument if encrypteds are in different NTT forms - @throws std::invalid_argument if encrypteds have different scale - @throws std::invalid_argument if destination is one of encrypteds - @throws std::logic_error if result ciphertext is transparent - */ - void add_many(const std::vector &encrypteds, Ciphertext &destination); - - /** - Subtracts two ciphertexts. This function computes the difference of encrypted1 - and encrypted2, and stores the result in encrypted1. - - @param[in] encrypted1 The ciphertext to subtract from - @param[in] encrypted2 The ciphertext to subtract - @throws std::invalid_argument if encrypted1 or encrypted2 is not valid for the - encryption parameters - @throws std::invalid_argument if encrypted1 and encrypted2 are in different - NTT forms - @throws std::invalid_argument if encrypted1 and encrypted2 have different scale - @throws std::logic_error if result ciphertext is transparent - */ - void sub_inplace(Ciphertext &encrypted1, const Ciphertext &encrypted2); - - /** - Subtracts two ciphertexts. This function computes the difference of encrypted1 - and encrypted2 and stores the result in the destination parameter. - - @param[in] encrypted1 The ciphertext to subtract from - @param[in] encrypted2 The ciphertext to subtract - @param[out] destination The ciphertext to overwrite with the subtraction result - @throws std::invalid_argument if encrypted1 or encrypted2 is not valid for the - encryption parameters - @throws std::invalid_argument if encrypted1 and encrypted2 are in different - NTT forms - @throws std::invalid_argument if encrypted1 and encrypted2 have different scale - @throws std::logic_error if result ciphertext is transparent - */ - inline void sub(const Ciphertext &encrypted1, const Ciphertext &encrypted2, - Ciphertext &destination) - { - if (&encrypted2 == &destination) - { - sub_inplace(destination, encrypted1); - negate_inplace(destination); - } - else - { - destination = encrypted1; - sub_inplace(destination, encrypted2); - } - } - - /** - Multiplies two ciphertexts. This functions computes the product of encrypted1 - and encrypted2 and stores the result in encrypted1. Dynamic memory allocations - in the process are allocated from the memory pool pointed to by the given - MemoryPoolHandle. - - @param[in] encrypted1 The first ciphertext to multiply - @param[in] encrypted2 The second ciphertext to multiply - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if encrypted1 or encrypted2 is not valid for the - encryption parameters - @throws std::invalid_argument if encrypted1 or encrypted2 is not in the default - NTT form - @throws std::invalid_argument if, when using scheme_type::CKKS, the output scale - is too large for the encryption parameters - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if result ciphertext is transparent - */ - void multiply_inplace(Ciphertext &encrypted1, const Ciphertext &encrypted2, - MemoryPoolHandle pool = MemoryManager::GetPool()); - - /** - Multiplies two ciphertexts. This functions computes the product of encrypted1 - and encrypted2 and stores the result in the destination parameter. Dynamic - memory allocations in the process are allocated from the memory pool pointed - to by the given MemoryPoolHandle. - - @param[in] encrypted1 The first ciphertext to multiply - @param[in] encrypted2 The second ciphertext to multiply - @param[out] destination The ciphertext to overwrite with the multiplication result - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if encrypted1 or encrypted2 is not valid for the - encryption parameters - @throws std::invalid_argument if encrypted1 or encrypted2 is not in the default - NTT form - @throws std::invalid_argument if, when using scheme_type::CKKS, the output scale - is too large for the encryption parameters - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if result ciphertext is transparent - */ - inline void multiply(const Ciphertext &encrypted1, - const Ciphertext &encrypted2, Ciphertext &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - if (&encrypted2 == &destination) - { - multiply_inplace(destination, encrypted1, std::move(pool)); - } - else - { - destination = encrypted1; - multiply_inplace(destination, encrypted2, std::move(pool)); - } - } - - /** - Squares a ciphertext. This functions computes the square of encrypted. Dynamic - memory allocations in the process are allocated from the memory pool pointed - to by the given MemoryPoolHandle. - - @param[in] encrypted The ciphertext to square - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if encrypted is not valid for the encryption - parameters - @throws std::invalid_argument if encrypted is not in the default NTT form - @throws std::invalid_argument if, when using scheme_type::CKKS, the output scale - is too large for the encryption parameters - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if result ciphertext is transparent - */ - void square_inplace(Ciphertext &encrypted, - MemoryPoolHandle pool = MemoryManager::GetPool()); - - /** - Squares a ciphertext. This functions computes the square of encrypted and - stores the result in the destination parameter. Dynamic memory allocations - in the process are allocated from the memory pool pointed to by the given - MemoryPoolHandle. - - @param[in] encrypted The ciphertext to square - @param[out] destination The ciphertext to overwrite with the square - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if encrypted is not valid for the encryption - parameters - @throws std::invalid_argument if encrypted is not in the default NTT form - @throws std::invalid_argument if, when using scheme_type::CKKS, the output scale - is too large for the encryption parameters - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if result ciphertext is transparent - */ - inline void square(const Ciphertext &encrypted, Ciphertext &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - destination = encrypted; - square_inplace(destination, std::move(pool)); - } - - /** - Relinearizes a ciphertext. This functions relinearizes encrypted, reducing - its size down to 2. If the size of encrypted is K+1, the given relinearization - keys need to have size at least K-1. Dynamic memory allocations in the - process are allocated from the memory pool pointed to by the given - MemoryPoolHandle. - - @param[in] encrypted The ciphertext to relinearize - @param[in] relin_keys The relinearization keys - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if encrypted or relin_keys is not valid for the - encryption parameters - @throws std::invalid_argument if encrypted is not in the default NTT form - @throws std::invalid_argument if relin_keys do not correspond to the top level - parameters in the current context - @throws std::invalid_argument if the size of relin_keys is too small - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if keyswitching is not supported by the context - @throws std::logic_error if result ciphertext is transparent - */ - inline void relinearize_inplace(Ciphertext &encrypted, const RelinKeys &relin_keys, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - relinearize_internal(encrypted, relin_keys, 2, std::move(pool)); - } - - /** - Relinearizes a ciphertext. This functions relinearizes encrypted, reducing - its size down to 2, and stores the result in the destination parameter. - If the size of encrypted is K+1, the given relinearization keys need to - have size at least K-1. Dynamic memory allocations in the process are allocated - from the memory pool pointed to by the given MemoryPoolHandle. - - @param[in] encrypted The ciphertext to relinearize - @param[in] relin_keys The relinearization keys - @param[out] destination The ciphertext to overwrite with the relinearized result - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if encrypted or relin_keys is not valid for the - encryption parameters - @throws std::invalid_argument if encrypted is not in the default NTT form - @throws std::invalid_argument if relin_keys do not correspond to the top level - parameters in the current context - @throws std::invalid_argument if the size of relin_keys is too small - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if keyswitching is not supported by the context - @throws std::logic_error if result ciphertext is transparent - */ - inline void relinearize(const Ciphertext &encrypted, - const RelinKeys &relin_keys, Ciphertext &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - destination = encrypted; - relinearize_inplace(destination, relin_keys, std::move(pool)); - } - - /** - Given a ciphertext encrypted modulo q_1...q_k, this function switches the - modulus down to q_1...q_{k-1} and stores the result in the destination - parameter. Dynamic memory allocations in the process are allocated from - the memory pool pointed to by the given MemoryPoolHandle. - - @param[in] encrypted The ciphertext to be switched to a smaller modulus - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @param[out] destination The ciphertext to overwrite with the modulus switched result - @throws std::invalid_argument if encrypted is not valid for the encryption parameters - @throws std::invalid_argument if encrypted is not in the default NTT form - @throws std::invalid_argument if encrypted is already at lowest level - @throws std::invalid_argument if, when using scheme_type::CKKS, the scale is too - large for the new encryption parameters - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if result ciphertext is transparent - */ - void mod_switch_to_next(const Ciphertext &encrypted, Ciphertext &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()); - - /** - Given a ciphertext encrypted modulo q_1...q_k, this function switches the - modulus down to q_1...q_{k-1}. Dynamic memory allocations in the process - are allocated from the memory pool pointed to by the given MemoryPoolHandle. - - @param[in] encrypted The ciphertext to be switched to a smaller modulus - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if encrypted is not valid for the encryption parameters - @throws std::invalid_argument if encrypted is not in the default NTT form - @throws std::invalid_argument if encrypted is already at lowest level - @throws std::invalid_argument if, when using scheme_type::CKKS, the scale is too - large for the new encryption parameters - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if result ciphertext is transparent - */ - inline void mod_switch_to_next_inplace(Ciphertext &encrypted, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - mod_switch_to_next(encrypted, encrypted, std::move(pool)); - } - - /** - Modulus switches an NTT transformed plaintext from modulo q_1...q_k down - to modulo q_1...q_{k-1}. - - @param[in] plain The plaintext to be switched to a smaller modulus - @throws std::invalid_argument if plain is not in NTT form - @throws std::invalid_argument if plain is not valid for the encryption parameters - @throws std::invalid_argument if plain is already at lowest level - @throws std::invalid_argument if, when using scheme_type::CKKS, the scale is too - large for the new encryption parameters - */ - inline void mod_switch_to_next_inplace(Plaintext &plain) - { - // Verify parameters. - if (!is_valid_for(plain, context_)) - { - throw std::invalid_argument("plain is not valid for encryption parameters"); - } - mod_switch_drop_to_next(plain); - } - - /** - Modulus switches an NTT transformed plaintext from modulo q_1...q_k down - to modulo q_1...q_{k-1} and stores the result in the destination parameter. - - @param[in] plain The plaintext to be switched to a smaller modulus - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @param[out] destination The plaintext to overwrite with the modulus switched result - @throws std::invalid_argument if plain is not in NTT form - @throws std::invalid_argument if plain is not valid for the encryption parameters - @throws std::invalid_argument if plain is already at lowest level - @throws std::invalid_argument if, when using scheme_type::CKKS, the scale is too - large for the new encryption parameters - @throws std::invalid_argument if pool is uninitialized - */ - inline void mod_switch_to_next(const Plaintext &plain, Plaintext &destination) - { - destination = plain; - mod_switch_to_next_inplace(destination); - } - - /** - Given a ciphertext encrypted modulo q_1...q_k, this function switches the - modulus down until the parameters reach the given parms_id. Dynamic memory - allocations in the process are allocated from the memory pool pointed to - by the given MemoryPoolHandle. - - @param[in] encrypted The ciphertext to be switched to a smaller modulus - @param[in] parms_id The target parms_id - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if encrypted is not valid for the encryption parameters - @throws std::invalid_argument if encrypted is not in the default NTT form - @throws std::invalid_argument if parms_id is not valid for the encryption parameters - @throws std::invalid_argument if encrypted is already at lower level in modulus chain - than the parameters corresponding to parms_id - @throws std::invalid_argument if, when using scheme_type::CKKS, the scale is too - large for the new encryption parameters - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if result ciphertext is transparent - */ - void mod_switch_to_inplace(Ciphertext &encrypted, parms_id_type parms_id, - MemoryPoolHandle pool = MemoryManager::GetPool()); - - /** - Given a ciphertext encrypted modulo q_1...q_k, this function switches the - modulus down until the parameters reach the given parms_id and stores the - result in the destination parameter. Dynamic memory allocations in the process - are allocated from the memory pool pointed to by the given MemoryPoolHandle. - - @param[in] encrypted The ciphertext to be switched to a smaller modulus - @param[in] parms_id The target parms_id - @param[out] destination The ciphertext to overwrite with the modulus switched result - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if encrypted is not valid for the encryption parameters - @throws std::invalid_argument if encrypted is not in the default NTT form - @throws std::invalid_argument if parms_id is not valid for the encryption parameters - @throws std::invalid_argument if encrypted is already at lower level in modulus chain - than the parameters corresponding to parms_id - @throws std::invalid_argument if, when using scheme_type::CKKS, the scale is too - large for the new encryption parameters - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if result ciphertext is transparent - */ - inline void mod_switch_to(const Ciphertext &encrypted, - parms_id_type parms_id, Ciphertext &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - destination = encrypted; - mod_switch_to_inplace(destination, parms_id, std::move(pool)); - } - - /** - Given an NTT transformed plaintext modulo q_1...q_k, this function switches - the modulus down until the parameters reach the given parms_id. - - @param[in] plain The plaintext to be switched to a smaller modulus - @param[in] parms_id The target parms_id - @throws std::invalid_argument if plain is not in NTT form - @throws std::invalid_argument if plain is not valid for the encryption parameters - @throws std::invalid_argument if parms_id is not valid for the encryption parameters - @throws std::invalid_argument if plain is already at lower level in modulus chain - than the parameters corresponding to parms_id - @throws std::invalid_argument if, when using scheme_type::CKKS, the scale is too - large for the new encryption parameters - */ - void mod_switch_to_inplace(Plaintext &plain, parms_id_type parms_id); - - /** - Given an NTT transformed plaintext modulo q_1...q_k, this function switches - the modulus down until the parameters reach the given parms_id and stores - the result in the destination parameter. - - @param[in] plain The plaintext to be switched to a smaller modulus - @param[in] parms_id The target parms_id - @param[out] destination The plaintext to overwrite with the modulus switched result - @throws std::invalid_argument if plain is not in NTT form - @throws std::invalid_argument if plain is not valid for the encryption parameters - @throws std::invalid_argument if parms_id is not valid for the encryption parameters - @throws std::invalid_argument if plain is already at lower level in modulus chain - than the parameters corresponding to parms_id - @throws std::invalid_argument if, when using scheme_type::CKKS, the scale is too - large for the new encryption parameters - */ - inline void mod_switch_to(const Plaintext &plain, parms_id_type parms_id, - Plaintext &destination) - { - destination = plain; - mod_switch_to_inplace(destination, parms_id); - } - - /** - Given a ciphertext encrypted modulo q_1...q_k, this function switches the - modulus down to q_1...q_{k-1}, scales the message down accordingly, and - stores the result in the destination parameter. Dynamic memory allocations - in the process are allocated from the memory pool pointed to by the given - MemoryPoolHandle. - - @param[in] encrypted The ciphertext to be switched to a smaller modulus - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @param[out] destination The ciphertext to overwrite with the modulus switched result - @throws std::invalid_argument if the scheme is invalid for rescaling - @throws std::invalid_argument if encrypted is not valid for the encryption parameters - @throws std::invalid_argument if encrypted is not in the default NTT form - @throws std::invalid_argument if encrypted is already at lowest level - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if result ciphertext is transparent - */ - void rescale_to_next(const Ciphertext &encrypted, Ciphertext &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()); - - /** - Given a ciphertext encrypted modulo q_1...q_k, this function switches the - modulus down to q_1...q_{k-1} and scales the message down accordingly. Dynamic - memory allocations in the process are allocated from the memory pool pointed - to by the given MemoryPoolHandle. - - @param[in] encrypted The ciphertext to be switched to a smaller modulus - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if the scheme is invalid for rescaling - @throws std::invalid_argument if encrypted is not valid for the encryption parameters - @throws std::invalid_argument if encrypted is not in the default NTT form - @throws std::invalid_argument if encrypted is already at lowest level - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if result ciphertext is transparent - */ - inline void rescale_to_next_inplace(Ciphertext &encrypted, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - rescale_to_next(encrypted, encrypted, std::move(pool)); - } - - /** - Given a ciphertext encrypted modulo q_1...q_k, this function switches the - modulus down until the parameters reach the given parms_id and scales the - message down accordingly. Dynamic memory allocations in the process are - allocated from the memory pool pointed to by the given MemoryPoolHandle. - - @param[in] encrypted The ciphertext to be switched to a smaller modulus - @param[in] parms_id The target parms_id - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if the scheme is invalid for rescaling - @throws std::invalid_argument if encrypted is not valid for the encryption parameters - @throws std::invalid_argument if encrypted is not in the default NTT form - @throws std::invalid_argument if parms_id is not valid for the encryption parameters - @throws std::invalid_argument if encrypted is already at lower level in modulus chain - than the parameters corresponding to parms_id - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if result ciphertext is transparent - */ - void rescale_to_inplace(Ciphertext &encrypted, parms_id_type parms_id, - MemoryPoolHandle pool = MemoryManager::GetPool()); - - /** - Given a ciphertext encrypted modulo q_1...q_k, this function switches the - modulus down until the parameters reach the given parms_id, scales the message - down accordingly, and stores the result in the destination parameter. Dynamic - memory allocations in the process are allocated from the memory pool pointed - to by the given MemoryPoolHandle. - - @param[in] encrypted The ciphertext to be switched to a smaller modulus - @param[in] parms_id The target parms_id - @param[out] destination The ciphertext to overwrite with the modulus switched result - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if the scheme is invalid for rescaling - @throws std::invalid_argument if encrypted is not valid for the encryption parameters - @throws std::invalid_argument if encrypted is not in the default NTT form - @throws std::invalid_argument if parms_id is not valid for the encryption parameters - @throws std::invalid_argument if encrypted is already at lower level in modulus chain - than the parameters corresponding to parms_id - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if result ciphertext is transparent - */ - inline void rescale_to(const Ciphertext &encrypted, - parms_id_type parms_id, Ciphertext &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - destination = encrypted; - rescale_to_inplace(destination, parms_id, std::move(pool)); - } - - /** - Multiplies several ciphertexts together. This function computes the product - of several ciphertext given as an std::vector and stores the result in the - destination parameter. The multiplication is done in a depth-optimal order, - and relinearization is performed automatically after every multiplication - in the process. In relinearization the given relinearization keys are used. - Dynamic memory allocations in the process are allocated from the memory - pool pointed to by the given MemoryPoolHandle. - - @param[in] encrypteds The ciphertexts to multiply - @param[in] relin_keys The relinearization keys - @param[out] destination The ciphertext to overwrite with the multiplication result - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::logic_error if scheme is not scheme_type::BFV - @throws std::invalid_argument if encrypteds is empty - @throws std::invalid_argument if the ciphertexts or relin_keys are not valid for - the encryption parameters - @throws std::invalid_argument if encrypteds are not in the default NTT form - @throws std::invalid_argument if, when using scheme_type::CKKS, the output scale - is too large for the encryption parameters - @throws std::invalid_argument if the size of relin_keys is too small - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if keyswitching is not supported by the context - @throws std::logic_error if result ciphertext is transparent - */ - void multiply_many(const std::vector &encrypteds, - const RelinKeys &relin_keys, Ciphertext &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()); - - /** - Exponentiates a ciphertext. This functions raises encrypted to a power. - Dynamic memory allocations in the process are allocated from the memory - pool pointed to by the given MemoryPoolHandle. The exponentiation is done - in a depth-optimal order, and relinearization is performed automatically - after every multiplication in the process. In relinearization the given - relinearization keys are used. - - @param[in] encrypted The ciphertext to exponentiate - @param[in] exponent The power to raise the ciphertext to - @param[in] relin_keys The relinearization keys - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::logic_error if scheme is not scheme_type::BFV - @throws std::invalid_argument if encrypted or relin_keys is not valid for the - encryption parameters - @throws std::invalid_argument if encrypted is not in the default NTT form - @throws std::invalid_argument if, when using scheme_type::CKKS, the output scale - is too large for the encryption parameters - @throws std::invalid_argument if exponent is zero - @throws std::invalid_argument if the size of relin_keys is too small - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if keyswitching is not supported by the context - @throws std::logic_error if result ciphertext is transparent - */ - void exponentiate_inplace(Ciphertext &encrypted, - std::uint64_t exponent, const RelinKeys &relin_keys, - MemoryPoolHandle pool = MemoryManager::GetPool()); - - /** - Exponentiates a ciphertext. This functions raises encrypted to a power and - stores the result in the destination parameter. Dynamic memory allocations - in the process are allocated from the memory pool pointed to by the given - MemoryPoolHandle. The exponentiation is done in a depth-optimal order, and - relinearization is performed automatically after every multiplication in - the process. In relinearization the given relinearization keys are used. - - @param[in] encrypted The ciphertext to exponentiate - @param[in] exponent The power to raise the ciphertext to - @param[in] relin_keys The relinearization keys - @param[out] destination The ciphertext to overwrite with the power - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::logic_error if scheme is not scheme_type::BFV - @throws std::invalid_argument if encrypted or relin_keys is not valid for the - encryption parameters - @throws std::invalid_argument if encrypted is not in the default NTT form - @throws std::invalid_argument if, when using scheme_type::CKKS, the output scale - is too large for the encryption parameters - @throws std::invalid_argument if exponent is zero - @throws std::invalid_argument if the size of relin_keys is too small - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if keyswitching is not supported by the context - @throws std::logic_error if result ciphertext is transparent - */ - inline void exponentiate(const Ciphertext &encrypted, std::uint64_t exponent, - const RelinKeys &relin_keys, Ciphertext &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - destination = encrypted; - exponentiate_inplace(destination, exponent, relin_keys, std::move(pool)); - } - - /** - Adds a ciphertext and a plaintext. The plaintext must be valid for the current - encryption parameters. - - @param[in] encrypted The ciphertext to add - @param[in] plain The plaintext to add - @throws std::invalid_argument if encrypted or plain is not valid for the - encryption parameters - @throws std::invalid_argument if encrypted or plain is in NTT form - @throws std::logic_error if result ciphertext is transparent - */ - void add_plain_inplace(Ciphertext &encrypted, const Plaintext &plain); - - /** - Adds a ciphertext and a plaintext. This function adds a ciphertext and - a plaintext and stores the result in the destination parameter. The plaintext - must be valid for the current encryption parameters. - - @param[in] encrypted The ciphertext to add - @param[in] plain The plaintext to add - @param[out] destination The ciphertext to overwrite with the addition result - @throws std::invalid_argument if encrypted or plain is not valid for the - encryption parameters - @throws std::invalid_argument if encrypted or plain is in NTT form - @throws std::logic_error if result ciphertext is transparent - */ - inline void add_plain(const Ciphertext &encrypted, const Plaintext &plain, - Ciphertext &destination) - { - destination = encrypted; - add_plain_inplace(destination, plain); - } - - /** - Subtracts a plaintext from a ciphertext. The plaintext must be valid for the - current encryption parameters. - - @param[in] encrypted The ciphertext to subtract from - @param[in] plain The plaintext to subtract - @throws std::invalid_argument if encrypted or plain is not valid for the - encryption parameters - @throws std::invalid_argument if encrypted or plain is in NTT form - @throws std::logic_error if result ciphertext is transparent - */ - void sub_plain_inplace(Ciphertext &encrypted, const Plaintext &plain); - - /** - Subtracts a plaintext from a ciphertext. This function subtracts a plaintext - from a ciphertext and stores the result in the destination parameter. The - plaintext must be valid for the current encryption parameters. - - @param[in] encrypted The ciphertext to subtract from - @param[in] plain The plaintext to subtract - @param[out] destination The ciphertext to overwrite with the subtraction result - @throws std::invalid_argument if encrypted or plain is not valid for the - encryption parameters - @throws std::invalid_argument if encrypted or plain is in NTT form - @throws std::logic_error if result ciphertext is transparent - */ - inline void sub_plain(const Ciphertext &encrypted, const Plaintext &plain, - Ciphertext &destination) - { - destination = encrypted; - sub_plain_inplace(destination, plain); - } - - /** - Multiplies a ciphertext with a plaintext. The plaintext must be valid for the - current encryption parameters, and cannot be identially 0. Dynamic memory - allocations in the process are allocated from the memory pool pointed to by - the given MemoryPoolHandle. - - @param[in] encrypted The ciphertext to multiply - @param[in] plain The plaintext to multiply - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if the encrypted or plain is not valid for - the encryption parameters - @throws std::invalid_argument if encrypted and plain are in different NTT forms - @throws std::invalid_argument if, when using scheme_type::CKKS, the output - scale is too large for the encryption parameters - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if result ciphertext is transparent - */ - void multiply_plain_inplace(Ciphertext &encrypted, const Plaintext &plain, - MemoryPoolHandle pool = MemoryManager::GetPool()); - - /** - Multiplies a ciphertext with a plaintext. This function multiplies - a ciphertext with a plaintext and stores the result in the destination - parameter. The plaintext must be a valid for the current encryption parameters, - and cannot be identially 0. Dynamic memory allocations in the process are - allocated from the memory pool pointed to by the given MemoryPoolHandle. - - @param[in] encrypted The ciphertext to multiply - @param[in] plain The plaintext to multiply - @param[out] destination The ciphertext to overwrite with the multiplication result - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if the encrypted or plain is not valid for - the encryption parameters - @throws std::invalid_argument if encrypted and plain are in different NTT forms - @throws std::invalid_argument if plain is zero - @throws std::invalid_argument if, when using scheme_type::CKKS, the output - scale is too large for the encryption parameters - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if result ciphertext is transparent - */ - inline void multiply_plain(const Ciphertext &encrypted, - const Plaintext &plain, Ciphertext &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - destination = encrypted; - multiply_plain_inplace(destination, plain, std::move(pool)); - } - - /** - Transforms a plaintext to NTT domain. This functions applies the Number - Theoretic Transform to a plaintext by first embedding integers modulo the - plaintext modulus to integers modulo the coefficient modulus and then - performing David Harvey's NTT on the resulting polynomial. The transformation - is done with respect to encryption parameters corresponding to a given parms_id. - For the operation to be valid, the plaintext must have degree less than - poly_modulus_degree and each coefficient must be less than the plaintext - modulus, i.e., the plaintext must be a valid plaintext under the current - encryption parameters. Dynamic memory allocations in the process are allocated - from the memory pool pointed to by the given MemoryPoolHandle. - - @param[in] plain The plaintext to transform - @param[in] parms_id The parms_id with respect to which the NTT is done - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if plain is already in NTT form - @throws std::invalid_argument if plain or parms_id is not valid for the - encryption parameters - @throws std::invalid_argument if pool is uninitialized - */ - void transform_to_ntt_inplace(Plaintext &plain, parms_id_type parms_id, - MemoryPoolHandle pool = MemoryManager::GetPool()); - - /** - Transforms a plaintext to NTT domain. This functions applies the Number - Theoretic Transform to a plaintext by first embedding integers modulo the - plaintext modulus to integers modulo the coefficient modulus and then - performing David Harvey's NTT on the resulting polynomial. The transformation - is done with respect to encryption parameters corresponding to a given - parms_id. The result is stored in the destination_ntt parameter. For the - operation to be valid, the plaintext must have degree less than poly_modulus_degree - and each coefficient must be less than the plaintext modulus, i.e., the plaintext - must be a valid plaintext under the current encryption parameters. Dynamic - memory allocations in the process are allocated from the memory pool pointed - to by the given MemoryPoolHandle. - - @param[in] plain The plaintext to transform - @param[in] parms_id The parms_id with respect to which the NTT is done - @param[out] destinationNTT The plaintext to overwrite with the transformed result - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if plain is already in NTT form - @throws std::invalid_argument if plain or parms_id is not valid for the - encryption parameters - @throws std::invalid_argument if pool is uninitialized - */ - inline void transform_to_ntt(const Plaintext &plain, - parms_id_type parms_id, Plaintext &destination_ntt, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - destination_ntt = plain; - transform_to_ntt_inplace(destination_ntt, parms_id, std::move(pool)); - } - - /** - Transforms a ciphertext to NTT domain. This functions applies David Harvey's - Number Theoretic Transform separately to each polynomial of a ciphertext. - - @param[in] encrypted The ciphertext to transform - @throws std::invalid_argument if encrypted is not valid for the encryption - parameters - @throws std::invalid_argument if encrypted is already in NTT form - @throws std::logic_error if result ciphertext is transparent - */ - void transform_to_ntt_inplace(Ciphertext &encrypted); - - /** - Transforms a ciphertext to NTT domain. This functions applies David Harvey's - Number Theoretic Transform separately to each polynomial of a ciphertext. - The result is stored in the destination_ntt parameter. - - @param[in] encrypted The ciphertext to transform - @param[out] destination_ntt The ciphertext to overwrite with the transformed result - @throws std::invalid_argument if encrypted is not valid for the encryption - parameters - @throws std::invalid_argument if encrypted is already in NTT form - @throws std::logic_error if result ciphertext is transparent - */ - inline void transform_to_ntt(const Ciphertext &encrypted, - Ciphertext &destination_ntt) - { - destination_ntt = encrypted; - transform_to_ntt_inplace(destination_ntt); - } - - /** - Transforms a ciphertext back from NTT domain. This functions applies the - inverse of David Harvey's Number Theoretic Transform separately to each - polynomial of a ciphertext. - - @param[in] encrypted_ntt The ciphertext to transform - @throws std::invalid_argument if encrypted_ntt is not valid for the encryption - parameters - @throws std::invalid_argument if encrypted_ntt is not in NTT form - @throws std::logic_error if result ciphertext is transparent - */ - void transform_from_ntt_inplace(Ciphertext &encrypted_ntt); - - /** - Transforms a ciphertext back from NTT domain. This functions applies the - inverse of David Harvey's Number Theoretic Transform separately to each - polynomial of a ciphertext. The result is stored in the destination parameter. - - @param[in] encrypted_ntt The ciphertext to transform - @param[out] destination The ciphertext to overwrite with the transformed result - @throws std::invalid_argument if encrypted_ntt is not valid for the encryption - parameters - @throws std::invalid_argument if encrypted_ntt is not in NTT form - @throws std::logic_error if result ciphertext is transparent - */ - inline void transform_from_ntt(const Ciphertext &encrypted_ntt, - Ciphertext &destination) - { - destination = encrypted_ntt; - transform_from_ntt_inplace(destination); - } - - /** - Applies a Galois automorphism to a ciphertext. To evaluate the Galois - automorphism, an appropriate set of Galois keys must also be provided. - Dynamic memory allocations in the process are allocated from the memory - pool pointed to by the given MemoryPoolHandle. - - - The desired Galois automorphism is given as a Galois element, and must be - an odd integer in the interval [1, M-1], where M = 2*N, and N = poly_modulus_degree. - Used with batching, a Galois element 3^i % M corresponds to a cyclic row - rotation i steps to the left, and a Galois element 3^(N/2-i) % M corresponds - to a cyclic row rotation i steps to the right. The Galois element M-1 corresponds - to a column rotation (row swap) in BFV, and complex conjugation in CKKS. - In the polynomial view (not batching), a Galois automorphism by a Galois - element p changes Enc(plain(x)) to Enc(plain(x^p)). - - @param[in] encrypted The ciphertext to apply the Galois automorphism to - @param[in] galois_elt The Galois element - @param[in] galois_keys The Galois keys - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if encrypted or galois_keys is not valid for - the encryption parameters - @throws std::invalid_argument if galois_keys do not correspond to the top - level parameters in the current context - @throws std::invalid_argument if encrypted is not in the default NTT form - @throws std::invalid_argument if encrypted has size larger than 2 - @throws std::invalid_argument if the Galois element is not valid - @throws std::invalid_argument if necessary Galois keys are not present - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if keyswitching is not supported by the context - @throws std::logic_error if result ciphertext is transparent - */ - void apply_galois_inplace(Ciphertext &encrypted, - std::uint64_t galois_elt, const GaloisKeys &galois_keys, - MemoryPoolHandle pool = MemoryManager::GetPool()); - - /** - Applies a Galois automorphism to a ciphertext and writes the result to the - destination parameter. To evaluate the Galois automorphism, an appropriate - set of Galois keys must also be provided. Dynamic memory allocations in - the process are allocated from the memory pool pointed to by the given - MemoryPoolHandle. - - The desired Galois automorphism is given as a Galois element, and must be - an odd integer in the interval [1, M-1], where M = 2*N, and N = poly_modulus_degree. - Used with batching, a Galois element 3^i % M corresponds to a cyclic row - rotation i steps to the left, and a Galois element 3^(N/2-i) % M corresponds - to a cyclic row rotation i steps to the right. The Galois element M-1 corresponds - to a column rotation (row swap) in BFV, and complex conjugation in CKKS. - In the polynomial view (not batching), a Galois automorphism by a Galois - element p changes Enc(plain(x)) to Enc(plain(x^p)). - - @param[in] encrypted The ciphertext to apply the Galois automorphism to - @param[in] galois_elt The Galois element - @param[in] galois_keys The Galois keys - @param[out] destination The ciphertext to overwrite with the result - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if encrypted or galois_keys is not valid for - the encryption parameters - @throws std::invalid_argument if galois_keys do not correspond to the top - level parameters in the current context - @throws std::invalid_argument if encrypted is not in the default NTT form - @throws std::invalid_argument if encrypted has size larger than 2 - @throws std::invalid_argument if the Galois element is not valid - @throws std::invalid_argument if necessary Galois keys are not present - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if keyswitching is not supported by the context - @throws std::logic_error if result ciphertext is transparent - */ - inline void apply_galois(const Ciphertext &encrypted, - std::uint64_t galois_elt, const GaloisKeys &galois_keys, - Ciphertext &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - destination = encrypted; - apply_galois_inplace(destination, galois_elt, galois_keys, std::move(pool)); - } - - /** - Rotates plaintext matrix rows cyclically. When batching is used with the - BFV scheme, this function rotates the encrypted plaintext matrix rows - cyclically to the left (steps > 0) or to the right (steps < 0). Since - the size of the batched matrix is 2-by-(N/2), where N is the degree of - the polynomial modulus, the number of steps to rotate must have absolute - value at most N/2-1. Dynamic memory allocations in the process are allocated - from the memory pool pointed to by the given MemoryPoolHandle. - - - @param[in] encrypted The ciphertext to rotate - @param[in] steps The number of steps to rotate (negative left, positive right) - @param[in] galois_keys The Galois keys - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::logic_error if scheme is not scheme_type::BFV - @throws std::logic_error if the encryption parameters do not support batching - @throws std::invalid_argument if encrypted or galois_keys is not valid for - the encryption parameters - @throws std::invalid_argument if galois_keys do not correspond to the top - level parameters in the current context - @throws std::invalid_argument if encrypted is not in the default NTT form - @throws std::invalid_argument if encrypted has size larger than 2 - @throws std::invalid_argument if steps has too big absolute value - @throws std::invalid_argument if necessary Galois keys are not present - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if keyswitching is not supported by the context - @throws std::logic_error if result ciphertext is transparent - */ - inline void rotate_rows_inplace(Ciphertext &encrypted, - int steps, const GaloisKeys &galois_keys, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - if (context_->key_context_data()->parms().scheme() != scheme_type::BFV) - { - throw std::logic_error("unsupported scheme"); - } - rotate_internal(encrypted, steps, galois_keys, std::move(pool)); - } - - /** - Rotates plaintext matrix rows cyclically. When batching is used with the - BFV scheme, this function rotates the encrypted plaintext matrix rows - cyclically to the left (steps > 0) or to the right (steps < 0) and writes - the result to the destination parameter. Since the size of the batched - matrix is 2-by-(N/2), where N is the degree of the polynomial modulus, - the number of steps to rotate must have absolute value at most N/2-1. Dynamic - memory allocations in the process are allocated from the memory pool pointed - to by the given MemoryPoolHandle. - - @param[in] encrypted The ciphertext to rotate - @param[in] steps The number of steps to rotate (negative left, positive right) - @param[in] galois_keys The Galois keys - @param[out] destination The ciphertext to overwrite with the rotated result - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::logic_error if scheme is not scheme_type::BFV - @throws std::logic_error if the encryption parameters do not support batching - @throws std::invalid_argument if encrypted or galois_keys is not valid for - the encryption parameters - @throws std::invalid_argument if galois_keys do not correspond to the top - level parameters in the current context - @throws std::invalid_argument if encrypted is in NTT form - @throws std::invalid_argument if encrypted has size larger than 2 - @throws std::invalid_argument if steps has too big absolute value - @throws std::invalid_argument if necessary Galois keys are not present - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if keyswitching is not supported by the context - @throws std::logic_error if result ciphertext is transparent - */ - inline void rotate_rows(const Ciphertext &encrypted, int steps, - const GaloisKeys &galois_keys, Ciphertext &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - destination = encrypted; - rotate_rows_inplace(destination, steps, galois_keys, std::move(pool)); - } - - /** - Rotates plaintext matrix columns cyclically. When batching is used with - the BFV scheme, this function rotates the encrypted plaintext matrix - columns cyclically. Since the size of the batched matrix is 2-by-(N/2), - where N is the degree of the polynomial modulus, this means simply swapping - the two rows. Dynamic memory allocations in the process are allocated from - the memory pool pointed to by the given MemoryPoolHandle. - - - @param[in] encrypted The ciphertext to rotate - @param[in] galois_keys The Galois keys - @param[out] destination The ciphertext to overwrite with the rotated result - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::logic_error if scheme is not scheme_type::BFV - @throws std::logic_error if the encryption parameters do not support batching - @throws std::invalid_argument if encrypted or galois_keys is not valid for - the encryption parameters - @throws std::invalid_argument if galois_keys do not correspond to the top - level parameters in the current context - @throws std::invalid_argument if encrypted is in NTT form - @throws std::invalid_argument if encrypted has size larger than 2 - @throws std::invalid_argument if necessary Galois keys are not present - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if keyswitching is not supported by the context - @throws std::logic_error if result ciphertext is transparent - */ - inline void rotate_columns_inplace(Ciphertext &encrypted, - const GaloisKeys &galois_keys, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - if (context_->key_context_data()->parms().scheme() != scheme_type::BFV) - { - throw std::logic_error("unsupported scheme"); - } - conjugate_internal(encrypted, galois_keys, std::move(pool)); - } - - /** - Rotates plaintext matrix columns cyclically. When batching is used with - the BFV scheme, this function rotates the encrypted plaintext matrix columns - cyclically, and writes the result to the destination parameter. Since the - size of the batched matrix is 2-by-(N/2), where N is the degree of the - polynomial modulus, this means simply swapping the two rows. Dynamic memory - allocations in the process are allocated from the memory pool pointed to - by the given MemoryPoolHandle. - - @param[in] encrypted The ciphertext to rotate - @param[in] galois_keys The Galois keys - @param[out] destination The ciphertext to overwrite with the rotated result - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::logic_error if scheme is not scheme_type::BFV - @throws std::logic_error if the encryption parameters do not support batching - @throws std::invalid_argument if encrypted or galois_keys is not valid for - the encryption parameters - @throws std::invalid_argument if galois_keys do not correspond to the top - level parameters in the current context - @throws std::invalid_argument if encrypted is in NTT form - @throws std::invalid_argument if encrypted has size larger than 2 - @throws std::invalid_argument if necessary Galois keys are not present - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if keyswitching is not supported by the context - @throws std::logic_error if result ciphertext is transparent - */ - inline void rotate_columns(const Ciphertext &encrypted, - const GaloisKeys &galois_keys, Ciphertext &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - destination = encrypted; - rotate_columns_inplace(destination, galois_keys, std::move(pool)); - } - - /** - Rotates plaintext vector cyclically. When using the CKKS scheme, this function - rotates the encrypted plaintext vector cyclically to the left (steps > 0) - or to the right (steps < 0). Since the size of the batched matrix is - 2-by-(N/2), where N is the degree of the polynomial modulus, the number - of steps to rotate must have absolute value at most N/2-1. Dynamic memory - allocations in the process are allocated from the memory pool pointed to - by the given MemoryPoolHandle. - - @param[in] encrypted The ciphertext to rotate - @param[in] steps The number of steps to rotate (negative left, positive right) - @param[in] galois_keys The Galois keys - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::logic_error if scheme is not scheme_type::CKKS - @throws std::invalid_argument if encrypted or galois_keys is not valid for - the encryption parameters - @throws std::invalid_argument if galois_keys do not correspond to the top - level parameters in the current context - @throws std::invalid_argument if encrypted is not in the default NTT form - @throws std::invalid_argument if encrypted has size larger than 2 - @throws std::invalid_argument if steps has too big absolute value - @throws std::invalid_argument if necessary Galois keys are not present - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if keyswitching is not supported by the context - @throws std::logic_error if result ciphertext is transparent - */ - inline void rotate_vector_inplace(Ciphertext &encrypted, - int steps, const GaloisKeys &galois_keys, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - if (context_->key_context_data()->parms().scheme() != scheme_type::CKKS) - { - throw std::logic_error("unsupported scheme"); - } - rotate_internal(encrypted, steps, galois_keys, std::move(pool)); - } - - /** - Rotates plaintext vector cyclically. When using the CKKS scheme, this function - rotates the encrypted plaintext vector cyclically to the left (steps > 0) - or to the right (steps < 0) and writes the result to the destination parameter. - Since the size of the batched matrix is 2-by-(N/2), where N is the degree - of the polynomial modulus, the number of steps to rotate must have absolute - value at most N/2-1. Dynamic memory allocations in the process are allocated - from the memory pool pointed to by the given MemoryPoolHandle. - - @param[in] encrypted The ciphertext to rotate - @param[in] steps The number of steps to rotate (negative left, positive right) - @param[in] galois_keys The Galois keys - @param[out] destination The ciphertext to overwrite with the rotated result - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::logic_error if scheme is not scheme_type::CKKS - @throws std::invalid_argument if encrypted or galois_keys is not valid for - the encryption parameters - @throws std::invalid_argument if galois_keys do not correspond to the top - level parameters in the current context - @throws std::invalid_argument if encrypted is in NTT form - @throws std::invalid_argument if encrypted has size larger than 2 - @throws std::invalid_argument if steps has too big absolute value - @throws std::invalid_argument if necessary Galois keys are not present - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if keyswitching is not supported by the context - @throws std::logic_error if result ciphertext is transparent - */ - inline void rotate_vector(const Ciphertext &encrypted, int steps, - const GaloisKeys &galois_keys, Ciphertext &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - destination = encrypted; - rotate_vector_inplace(destination, steps, galois_keys, std::move(pool)); - } - - /** - Complex conjugates plaintext slot values. When using the CKKS scheme, this - function complex conjugates all values in the underlying plaintext. Dynamic - memory allocations in the process are allocated from the memory pool pointed - to by the given MemoryPoolHandle. - - @param[in] encrypted The ciphertext to rotate - @param[in] galois_keys The Galois keys - @param[out] destination The ciphertext to overwrite with the rotated result - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::logic_error if scheme is not scheme_type::CKKS - @throws std::invalid_argument if encrypted or galois_keys is not valid for - the encryption parameters - @throws std::invalid_argument if galois_keys do not correspond to the top - level parameters in the current context - @throws std::invalid_argument if encrypted is in NTT form - @throws std::invalid_argument if encrypted has size larger than 2 - @throws std::invalid_argument if necessary Galois keys are not present - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if keyswitching is not supported by the context - @throws std::logic_error if result ciphertext is transparent - */ - inline void complex_conjugate_inplace(Ciphertext &encrypted, - const GaloisKeys &galois_keys, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - if (context_->key_context_data()->parms().scheme() != scheme_type::CKKS) - { - throw std::logic_error("unsupported scheme"); - } - conjugate_internal(encrypted, galois_keys, std::move(pool)); - } - - /** - Complex conjugates plaintext slot values. When using the CKKS scheme, this - function complex conjugates all values in the underlying plaintext, and - writes the result to the destination parameter. Dynamic memory allocations - in the process are allocated from the memory pool pointed to by the given - MemoryPoolHandle. - - @param[in] encrypted The ciphertext to rotate - @param[in] galois_keys The Galois keys - @param[out] destination The ciphertext to overwrite with the rotated result - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::logic_error if scheme is not scheme_type::CKKS - @throws std::invalid_argument if encrypted or galois_keys is not valid for - the encryption parameters - @throws std::invalid_argument if galois_keys do not correspond to the top - level parameters in the current context - @throws std::invalid_argument if encrypted is in NTT form - @throws std::invalid_argument if encrypted has size larger than 2 - @throws std::invalid_argument if necessary Galois keys are not present - @throws std::invalid_argument if pool is uninitialized - @throws std::logic_error if keyswitching is not supported by the context - @throws std::logic_error if result ciphertext is transparent - */ - inline void complex_conjugate(const Ciphertext &encrypted, - const GaloisKeys &galois_keys, Ciphertext &destination, - MemoryPoolHandle pool = MemoryManager::GetPool()) - { - destination = encrypted; - complex_conjugate_inplace(destination, galois_keys, std::move(pool)); - } - - /** - Enables access to private members of seal::Evaluator for .NET wrapper. - */ - struct EvaluatorPrivateHelper; - - private: - Evaluator(const Evaluator ©) = delete; - - Evaluator(Evaluator &&source) = delete; - - Evaluator &operator =(const Evaluator &assign) = delete; - - Evaluator &operator =(Evaluator &&assign) = delete; - - void bfv_multiply(Ciphertext &encrypted1, const Ciphertext &encrypted2, - MemoryPoolHandle pool); - - void ckks_multiply(Ciphertext &encrypted1, const Ciphertext &encrypted2, - MemoryPoolHandle pool); - - void bfv_square(Ciphertext &encrypted, MemoryPoolHandle pool); - - void ckks_square(Ciphertext &encrypted, MemoryPoolHandle pool); - - void relinearize_internal(Ciphertext &encrypted, const RelinKeys &relin_keys, - std::size_t destination_size, MemoryPoolHandle pool); - - void mod_switch_scale_to_next(const Ciphertext &encrypted, Ciphertext &destination, - MemoryPoolHandle pool); - - void mod_switch_drop_to_next(const Ciphertext &encrypted, Ciphertext &destination, - MemoryPoolHandle pool); - - void mod_switch_drop_to_next(Plaintext &plain); - - void rotate_internal(Ciphertext &encrypted, int steps, - const GaloisKeys &galois_keys, MemoryPoolHandle pool); - - inline void conjugate_internal(Ciphertext &encrypted, - const GaloisKeys &galois_keys, MemoryPoolHandle pool) - { - // Verify parameters. - auto context_data_ptr = context_->get_context_data(encrypted.parms_id()); - if (!context_data_ptr) - { - throw std::invalid_argument("encrypted is not valid for encryption parameters"); - } - - // Extract encryption parameters. - auto &context_data = *context_data_ptr; - if (!context_data.qualifiers().using_batching) - { - throw std::logic_error("encryption parameters do not support batching"); - } - - auto &parms = context_data.parms(); - std::size_t coeff_count = parms.poly_modulus_degree(); - - // Perform rotation and key switching - apply_galois_inplace(encrypted, util::steps_to_galois_elt(0, coeff_count), - galois_keys, std::move(pool)); - } - - inline void decompose_single_coeff(const SEALContext::ContextData &context_data, - const std::uint64_t *value, std::uint64_t *destination, util::MemoryPool &pool) - { - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - std::size_t coeff_mod_count = coeff_modulus.size(); -#ifdef SEAL_DEBUG - if (value == nullptr) - { - throw std::invalid_argument("value cannot be null"); - } - if (destination == nullptr) - { - throw std::invalid_argument("destination cannot be null"); - } - if (destination == value) - { - throw std::invalid_argument("value cannot be the same as destination"); - } -#endif - if (coeff_mod_count == 1) - { - util::set_uint_uint(value, coeff_mod_count, destination); - return; - } - - auto value_copy(util::allocate_uint(coeff_mod_count, pool)); - for (std::size_t j = 0; j < coeff_mod_count; j++) - { - //destination[j] = util::modulo_uint( - // value, coeff_mod_count, coeff_modulus_[j], pool); - - // Manually inlined for efficiency - // Make a fresh copy of value - util::set_uint_uint(value, coeff_mod_count, value_copy.get()); - - // Starting from the top, reduce always 128-bit blocks - for (std::size_t k = coeff_mod_count - 1; k--; ) - { - value_copy[k] = util::barrett_reduce_128( - value_copy.get() + k, coeff_modulus[j]); - } - destination[j] = value_copy[0]; - } - } - - inline void decompose(const SEALContext::ContextData &context_data, - const std::uint64_t *value, std::uint64_t *destination, util::MemoryPool &pool) - { - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - std::size_t coeff_count = parms.poly_modulus_degree(); - std::size_t coeff_mod_count = coeff_modulus.size(); - std::size_t rns_poly_uint64_count = - util::mul_safe(coeff_mod_count, coeff_count); -#ifdef SEAL_DEBUG - if (value == nullptr) - { - throw std::invalid_argument("value cannot be null"); - } - if (destination == nullptr) - { - throw std::invalid_argument("destination cannot be null"); - } - if (destination == value) - { - throw std::invalid_argument("value cannot be the same as destination"); - } -#endif - if (coeff_mod_count == 1) - { - util::set_uint_uint(value, rns_poly_uint64_count, destination); - return; - } - - auto value_copy(util::allocate_uint(coeff_mod_count, pool)); - for (std::size_t i = 0; i < coeff_count; i++) - { - for (std::size_t j = 0; j < coeff_mod_count; j++) - { - //destination[i + (j * coeff_count)] = - // util::modulo_uint(value + (i * coeff_mod_count), - // coeff_mod_count, coeff_modulus_[j], pool); - - // Manually inlined for efficiency - // Make a fresh copy of value + (i * coeff_mod_count) - util::set_uint_uint( - value + (i * coeff_mod_count), coeff_mod_count, value_copy.get()); - - // Starting from the top, reduce always 128-bit blocks - for (std::size_t k = coeff_mod_count - 1; k--; ) - { - value_copy[k] = util::barrett_reduce_128( - value_copy.get() + k, coeff_modulus[j]); - } - destination[i + (j * coeff_count)] = value_copy[0]; - } - } - } - - void switch_key_inplace(Ciphertext &encrypted, - const std::uint64_t *target, - const KSwitchKeys &kswitch_keys, - std::size_t key_index, - MemoryPoolHandle pool = MemoryManager::GetPool()); - - void multiply_plain_normal(Ciphertext &encrypted, const Plaintext &plain, - util::MemoryPool &pool); - - void multiply_plain_ntt(Ciphertext &encrypted_ntt, const Plaintext &plain_ntt); - - void populate_Zmstar_to_generator(); - - std::shared_ptr context_{ nullptr }; - - std::map> Zmstar_to_generator_{}; - }; -} diff --git a/SEAL/native/src/seal/galoiskeys.h b/SEAL/native/src/seal/galoiskeys.h deleted file mode 100644 index 4d972ad..0000000 --- a/SEAL/native/src/seal/galoiskeys.h +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include -#include "seal/util/defines.h" -#include "seal/ciphertext.h" -#include "seal/memorymanager.h" -#include "seal/encryptionparams.h" -#include "seal/kswitchkeys.h" -#include "seal/util/common.h" - -namespace seal -{ - /** - Class to store Galois keys. - - @par Slot Rotations - Galois keys are used together with batching (BatchEncoder). If the polynomial modulus - is a polynomial of degree N, in batching the idea is to view a plaintext polynomial as - a 2-by-(N/2) matrix of integers modulo plaintext modulus. Normal homomorphic computations - operate on such encrypted matrices element (slot) wise. However, special rotation - operations allow us to also rotate the matrix rows cyclically in either direction, and - rotate the columns (swap the rows). These operations require the Galois keys. - - @par Thread Safety - In general, reading from GaloisKeys is thread-safe as long as no other thread is - concurrently mutating it. This is due to the underlying data structure storing the - Galois keys not being thread-safe. - - @see SecretKey for the class that stores the secret key. - @see PublicKey for the class that stores the public key. - @see RelinKeys for the class that stores the relinearization keys. - @see KeyGenerator for the class that generates the Galois keys. - */ - class GaloisKeys : public KSwitchKeys - { - public: - /** - Returns the index of a Galois key in the backing KSwitchKeys instance that - corresponds to the given Galois element, assuming that it exists in the - backing KSwitchKeys. - - @param[in] galois_elt The Galois element - @throws std::invalid_argument if galois_elt is not valid - */ - SEAL_NODISCARD inline static std::size_t get_index( - std::uint64_t galois_elt) - { - // Verify parameters - if (!(galois_elt & 1)) - { - throw std::invalid_argument("galois_elt is not valid"); - } - return util::safe_cast((galois_elt - 1) >> 1); - } - - /** - Returns whether a Galois key corresponding to a given Galois element exists. - - @param[in] galois_elt The Galois element - @throws std::invalid_argument if galois_elt is not valid - */ - SEAL_NODISCARD inline bool has_key(std::uint64_t galois_elt) const - { - std::size_t index = get_index(galois_elt); - return data().size() > index && !data()[index].empty(); - } - - /** - Returns a const reference to a Galois key. The returned Galois key corresponds - to the given Galois element. - - @param[in] galois_elt The Galois element - @throws std::invalid_argument if the key corresponding to galois_elt does not exist - */ - SEAL_NODISCARD inline const auto &key(std::uint64_t galois_elt) const - { - return KSwitchKeys::data(get_index(galois_elt)); - } - }; -} diff --git a/SEAL/native/src/seal/intarray.h b/SEAL/native/src/seal/intarray.h deleted file mode 100644 index 521b990..0000000 --- a/SEAL/native/src/seal/intarray.h +++ /dev/null @@ -1,477 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include "seal/memorymanager.h" -#include "seal/util/pointer.h" -#include "seal/util/defines.h" -#include "seal/util/common.h" -#include -#include -#include -#include -#include - -namespace seal -{ - /** - A resizable container for storing an array of integral data types. The - allocations are done from a memory pool. The IntArray class is mainly - intended for internal use and provides the underlying data structure for - Plaintext and Ciphertext classes. - - @par Size and Capacity - IntArray allows the user to pre-allocate memory (capacity) for the array - in cases where the array is known to be resized in the future and memory - moves are to be avoided at the time of resizing. The size of the IntArray - can never exceed its capacity. The capacity and size can be changed using - the reserve and resize functions, respectively. - - @par Thread Safety - In general, reading from IntArray is thread-safe as long as no other thread - is concurrently mutating it. - */ - template::value>> - class IntArray - { - public: - using size_type = std::size_t; - using T = typename std::decay::type; - - /** - Creates a new IntArray. No memory is allocated by this constructor. - - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if pool is uninitialized - */ - IntArray(MemoryPoolHandle pool = MemoryManager::GetPool()) : - pool_(std::move(pool)) - { - if (!pool_) - { - throw std::invalid_argument("pool is uninitialized"); - } - } - - /** - Creates a new IntArray with given size. - - @param[in] size The size of the array - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if size is less than zero - @throws std::invalid_argument if pool is uninitialized - */ - explicit IntArray(size_type size, - MemoryPoolHandle pool = MemoryManager::GetPool()) : - pool_(std::move(pool)) - { - if (!pool_) - { - throw std::invalid_argument("pool is uninitialized"); - } - - // Reserve memory, resize, and set to zero - resize(size); - } - - /** - Creates a new IntArray with given capacity and size. - - @param[in] capacity The capacity of the array - @param[in] size The size of the array - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if capacity is less than size - @throws std::invalid_argument if capacity is less than zero - @throws std::invalid_argument if size is less than zero - @throws std::invalid_argument if pool is uninitialized - */ - explicit IntArray(size_type capacity, size_type size, - MemoryPoolHandle pool = MemoryManager::GetPool()) : - pool_(std::move(pool)) - { - if (!pool_) - { - throw std::invalid_argument("pool is uninitialized"); - } - if (capacity < size) - { - throw std::invalid_argument("capacity cannot be smaller than size"); - } - - // Reserve memory, resize, and set to zero - reserve(capacity); - resize(size); - } - - /** - Constructs a new IntArray by copying a given one. - - @param[in] copy The IntArray to copy from - */ - IntArray(const IntArray ©) : - pool_(MemoryManager::GetPool()), - capacity_(copy.size_), - size_(copy.size_), - data_(util::allocate(copy.size_, pool_)) - { - // Copy over value - std::copy_n(copy.cbegin(), copy.size_, begin()); - } - - /** - Constructs a new IntArray by moving a given one. - - @param[in] source The IntArray to move from - */ - IntArray(IntArray &&source) noexcept : - pool_(std::move(source.pool_)), - capacity_(source.capacity_), - size_(source.size_), - data_(std::move(source.data_)) - { - } - - /** - Returns a pointer to the beginning of the array data. - */ - SEAL_NODISCARD inline T* begin() noexcept - { - return data_.get(); - } - - /** - Returns a constant pointer to the beginning of the array data. - */ - SEAL_NODISCARD inline const T* cbegin() const noexcept - { - return data_.get(); - } - - /** - Returns a pointer to the end of the array data. - */ - SEAL_NODISCARD inline T* end() noexcept - { - return size_ ? begin() + size_ : begin(); - } - - /** - Returns a constant pointer to the end of the array data. - */ - SEAL_NODISCARD inline const T* cend() const noexcept - { - return size_ ? cbegin() + size_ : cbegin(); - } -#ifdef SEAL_USE_MSGSL_SPAN - /** - Returns a span pointing to the beginning of the IntArray. - */ - SEAL_NODISCARD inline gsl::span span() - { - return gsl::span( - begin(), static_cast(size_)); - } - - /** - Returns a span pointing to the beginning of the IntArray. - */ - SEAL_NODISCARD inline gsl::span span() const - { - return gsl::span( - cbegin(), static_cast(size_)); - } -#endif - /** - Returns a constant reference to the array element at a given index. - This function performs bounds checking and will throw an error if - the index is out of range. - - @param[in] index The index of the array element - @throws std::out_of_range if index is out of range - */ - SEAL_NODISCARD inline const T &at(size_type index) const - { - if (index >= size_) - { - throw std::out_of_range("index must be within [0, size)"); - } - return data_[index]; - } - - /** - Returns a reference to the array element at a given index. This - function performs bounds checking and will throw an error if the - index is out of range. - - @param[in] index The index of the array element - @throws std::out_of_range if index is out of range - */ - SEAL_NODISCARD inline T &at(size_type index) - { - if (index >= size_) - { - throw std::out_of_range("index must be within [0, size)"); - } - return data_[index]; - } - - /** - Returns a constant reference to the array element at a given index. - This function does not perform bounds checking. - - @param[in] index The index of the array element - */ - SEAL_NODISCARD inline const T &operator [](size_type index) const - { - return data_[index]; - } - - /** - Returns a reference to the array element at a given index. This - function does not perform bounds checking. - - @param[in] index The index of the array element - */ - SEAL_NODISCARD inline T &operator [](size_type index) - { - return data_[index]; - } - - /** - Returns whether the array has size zero. - */ - SEAL_NODISCARD inline bool empty() const noexcept - { - return (size_ == 0); - } - - /** - Returns the largest possible array size. - */ - SEAL_NODISCARD inline size_type max_size() const noexcept - { - return std::numeric_limits::max(); - } - - /** - Returns the size of the array. - */ - SEAL_NODISCARD inline size_type size() const noexcept - { - return size_; - } - - /** - Returns the capacity of the array. - */ - SEAL_NODISCARD inline size_type capacity() const noexcept - { - return capacity_; - } - - /** - Returns the currently used MemoryPoolHandle. - */ - SEAL_NODISCARD inline MemoryPoolHandle pool() const noexcept - { - return pool_; - } - - /** - Releases any allocated memory to the memory pool and sets the size - and capacity of the array to zero. - */ - inline void release() noexcept - { - capacity_ = 0; - size_ = 0; - data_.release(); - } - - /** - Sets the size of the array to zero. The capacity is not changed. - */ - inline void clear() noexcept - { - size_ = 0; - } - - /** - Allocates enough memory for storing a given number of elements without - changing the size of the array. If the given capacity is smaller than - the current size, the size is automatically set to equal the new capacity. - - @param[in] capacity The capacity of the array - */ - inline void reserve(size_type capacity) - { - size_type copy_size = std::min(capacity, size_); - - // Create new allocation and copy over value - auto new_data(util::allocate(capacity, pool_)); - std::copy_n(cbegin(), copy_size, new_data.get()); - std::swap(data_, new_data); - - // Set the coeff_count and capacity - capacity_ = capacity; - size_ = copy_size; - } - - /** - Reallocates the array so that its capacity exactly matches its size. - */ - inline void shrink_to_fit() - { - reserve(size_); - } - - /** - Resizes the array to given size. When resizing to larger size the data - in the array remains unchanged and any new space is initialized to zero; - when resizing to smaller size the last elements of the array are dropped. - If the capacity is not already large enough to hold the new size, the - array is also reallocated. - - @param[in] size The size of the array - */ - inline void resize(size_type size) - { - if (size <= capacity_) - { - // Are we changing size to bigger within current capacity? - // If so, need to set top terms to zero - if (size > size_) - { - std::fill(end(), begin() + size, T{ 0 }); - } - - // Set the size - size_ = size; - - return; - } - - // At this point we know for sure that size_ <= capacity_ < size so need - // to reallocate to bigger - auto new_data(util::allocate(size, pool_)); - std::copy_n(cbegin(), size_, new_data.get()); - std::fill(new_data.get() + size_, new_data.get() + size, T{ 0 }); - std::swap(data_, new_data); - - // Set the coeff_count and capacity - capacity_ = size; - size_ = size; - } - - /** - Copies a given IntArray to the current one. - - @param[in] assign The IntArray to copy from - */ - inline IntArray &operator =(const IntArray &assign) - { - // Check for self-assignment - if (this == &assign) - { - return *this; - } - - // First resize to correct size - resize(assign.size_); - - // Size is guaranteed to be OK now so copy over - std::copy_n(assign.cbegin(), assign.size_, begin()); - - return *this; - } - - /** - Moves a given IntArray to the current one. - - @param[in] assign The IntArray to move from - */ - IntArray &operator =(IntArray &&assign) noexcept - { - pool_ = std::move(assign.pool_); - capacity_ = assign.capacity_; - size_ = assign.size_; - data_ = std::move(assign.data_); - - return *this; - } - - /** - Saves the IntArray to an output stream. The output is in binary format - and not human-readable. The output stream must have the "binary" flag set. - - @param[in] stream The stream to save the IntArray to - @throws std::exception if the IntArray could not be written to stream - */ - inline void save(std::ostream &stream) const - { - auto old_except_mask = stream.exceptions(); - try - { - // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit - stream.exceptions(std::ios_base::badbit | std::ios_base::failbit); - - std::uint64_t size64 = size_; - stream.write(reinterpret_cast(&size64), sizeof(std::uint64_t)); - stream.write(reinterpret_cast(cbegin()), - util::safe_cast( - util::mul_safe(size_, util::safe_cast(sizeof(T))))); - } - catch (const std::exception &) - { - stream.exceptions(old_except_mask); - throw; - } - - stream.exceptions(old_except_mask); - } - - /** - Loads a IntArray from an input stream overwriting the current IntArray. - - @param[in] stream The stream to load the IntArray from - @throws std::exception if a valid IntArray could not be read from stream - */ - inline void load(std::istream &stream) - { - auto old_except_mask = stream.exceptions(); - try - { - // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit - stream.exceptions(std::ios_base::badbit | std::ios_base::failbit); - - std::uint64_t size64 = 0; - stream.read(reinterpret_cast(&size64), sizeof(std::uint64_t)); - - // Set new size - resize(util::safe_cast(size64)); - - // Read data - stream.read(reinterpret_cast(begin()), - util::safe_cast( - util::mul_safe(size_, util::safe_cast(sizeof(T))))); - } - catch (const std::exception &) - { - stream.exceptions(old_except_mask); - throw; - } - - stream.exceptions(old_except_mask); - } - - private: - MemoryPoolHandle pool_; - - size_type capacity_ = 0; - - size_type size_ = 0; - - util::Pointer data_; - }; -} diff --git a/SEAL/native/src/seal/intencoder.cpp b/SEAL/native/src/seal/intencoder.cpp deleted file mode 100644 index e3fdc47..0000000 --- a/SEAL/native/src/seal/intencoder.cpp +++ /dev/null @@ -1,299 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include -#include -#include -#include "seal/intencoder.h" -#include "seal/util/common.h" -#include "seal/util/polyarith.h" -#include "seal/util/pointer.h" -#include "seal/util/defines.h" -#include "seal/util/uintarithsmallmod.h" - -using namespace std; -using namespace seal::util; - -namespace seal -{ - IntegerEncoder::IntegerEncoder(std::shared_ptr context) : - context_(std::move(context)) - { - // Verify parameters - if (!context_) - { - throw invalid_argument("invalid context"); - } - - // Unlike in other classes, we do not check "context_->parameters_set()". - // The IntegerEncoder should function without valid encryption parameters - // as long as the scheme is BFV and the plaintext modulus is at least 2. - auto &context_data = *context_->first_context_data(); - if (context_data.parms().scheme() != scheme_type::BFV) - { - throw invalid_argument("unsupported scheme"); - } - if (plain_modulus().bit_count() <= 1) - { - throw invalid_argument("plain_modulus must be at least 2"); - } - - if (plain_modulus().value() == 2) - { - // In this case we don't allow any negative numbers - coeff_neg_threshold_ = 2; - } - else - { - // Normal negative threshold case - coeff_neg_threshold_ = (plain_modulus().value() + 1) >> 1; - } - neg_one_ = plain_modulus().value() - 1; - } - - Plaintext IntegerEncoder::encode(uint64_t value) - { - Plaintext result; - encode(value, result); - return result; - } - - void IntegerEncoder::encode(uint64_t value, Plaintext &destination) - { - size_t encode_coeff_count = safe_cast( - get_significant_bit_count(value)); - destination.resize(encode_coeff_count); - destination.set_zero(); - - size_t coeff_index = 0; - while (value != 0) - { - if ((value & 1) != 0) - { - destination[coeff_index] = 1; - } - value >>= 1; - coeff_index++; - } - } - - Plaintext IntegerEncoder::encode(int64_t value) - { - Plaintext result; - encode(value, result); - return result; - } - - void IntegerEncoder::encode(int64_t value, Plaintext &destination) - { - if (value < 0) - { - uint64_t pos_value = static_cast(-value); - size_t encode_coeff_count = safe_cast( - get_significant_bit_count(pos_value)); - destination.resize(encode_coeff_count); - destination.set_zero(); - - size_t coeff_index = 0; - while (pos_value != 0) - { - if ((pos_value & 1) != 0) - { - destination[coeff_index] = neg_one_; - } - pos_value >>= 1; - coeff_index++; - } - } - else - { - encode(static_cast(value), destination); - } - } - - Plaintext IntegerEncoder::encode(const BigUInt &value) - { - Plaintext result; - encode(value, result); - return result; - } - - void IntegerEncoder::encode(const BigUInt &value, Plaintext &destination) - { - size_t encode_coeff_count = safe_cast( - value.significant_bit_count()); - destination.resize(encode_coeff_count); - destination.set_zero(); - - size_t coeff_index = 0; - size_t coeff_count = safe_cast(value.significant_bit_count()); - size_t coeff_uint64_count = value.uint64_count(); - while (coeff_index < coeff_count) - { - if (is_bit_set_uint(value.data(), coeff_uint64_count, - safe_cast(coeff_index))) - { - destination[coeff_index] = 1; - } - coeff_index++; - } - } - - uint32_t IntegerEncoder::decode_uint32(const Plaintext &plain) - { - uint64_t value64 = decode_uint64(plain); - if (value64 > UINT32_MAX) - { - throw invalid_argument("output out of range"); - } - return static_cast(value64); - } - - uint64_t IntegerEncoder::decode_uint64(const Plaintext &plain) - { - BigUInt bigvalue = decode_biguint(plain); - int bit_count = bigvalue.significant_bit_count(); - if (bit_count > bits_per_uint64) - { - // Decoded value has more bits than fit in a 64-bit uint. - throw invalid_argument("output out of range"); - } - return bit_count > 0 ? bigvalue.data()[0] : 0; - } - - int32_t IntegerEncoder::decode_int32(const Plaintext &plain) - { - int64_t value64 = decode_int64(plain); - return safe_cast(value64); - } - - int64_t IntegerEncoder::decode_int64(const Plaintext &plain) - { - int64_t result = 0; - for (size_t bit_index = plain.significant_coeff_count(); bit_index--; ) - { - unsigned long long coeff = plain[bit_index]; - - // Left shift result. - int64_t next_result = result << 1; - if ((next_result < 0) != (result < 0)) - { - // Check for overflow. - throw invalid_argument("output out of range"); - } - - // Get sign/magnitude of coefficient. - int coeff_bit_count = get_significant_bit_count(coeff); - if (coeff >= plain_modulus().value()) - { - // Coefficient is bigger than plaintext modulus - throw invalid_argument("plain does not represent a valid plaintext polynomial"); - } - bool coeff_is_negative = coeff >= coeff_neg_threshold_; - unsigned long long pos_value = coeff; - if (coeff_is_negative) - { - pos_value = plain_modulus().value() - pos_value; - coeff_bit_count = get_significant_bit_count(pos_value); - } - - if (coeff_bit_count > bits_per_uint64 - 1) - { - // Absolute value of coefficient is too large to represent in a int64_t, so overflow. - throw invalid_argument("output out of range"); - } - - int64_t coeff_value = safe_cast(pos_value); - if (coeff_is_negative) - { - coeff_value = -coeff_value; - } - bool next_result_was_negative = next_result < 0; - next_result += coeff_value; - bool next_result_is_negative = next_result < 0; - if ((next_result_was_negative == coeff_is_negative) && - (next_result_was_negative != next_result_is_negative)) - { - // Accumulation and coefficient had same signs, but accumulator changed signs after addition, so must be overflow. - throw invalid_argument("output out of range"); - } - result = next_result; - } - return result; - } - - BigUInt IntegerEncoder::decode_biguint(const Plaintext &plain) - { - size_t result_uint64_count = 1; - size_t bits_per_uint64_sz = safe_cast(bits_per_uint64); - size_t result_bit_capacity = result_uint64_count * bits_per_uint64_sz; - BigUInt resultint(safe_cast(result_bit_capacity)); - bool result_is_negative = false; - uint64_t *result = resultint.data(); - for (size_t bit_index = plain.significant_coeff_count(); bit_index--; ) - { - unsigned long long coeff = plain[bit_index]; - - // Left shift result, resizing if highest bit set. - if (is_bit_set_uint(result, result_uint64_count, - safe_cast(result_bit_capacity) - 1)) - { - // Resize to make bigger. - result_uint64_count++; - result_bit_capacity = mul_safe(result_uint64_count, bits_per_uint64_sz); - resultint.resize(safe_cast(result_bit_capacity)); - result = resultint.data(); - } - left_shift_uint(result, 1, result_uint64_count, result); - - // Get sign/magnitude of coefficient. - if (coeff >= plain_modulus().value()) - { - // Coefficient is bigger than plaintext modulus - throw invalid_argument("plain does not represent a valid plaintext polynomial"); - } - bool coeff_is_negative = coeff >= coeff_neg_threshold_; - unsigned long long pos_value = coeff; - if (coeff_is_negative) - { - pos_value = plain_modulus().value() - pos_value; - } - - // Add or subtract-in coefficient. - if (result_is_negative == coeff_is_negative) - { - // Result and coefficient have same signs so add. - if (add_uint_uint64(result, pos_value, result_uint64_count, result)) - { - // Add produced a carry that didn't fit, so resize and put it in. - int carry_bit_index = safe_cast(mul_safe( - result_uint64_count, bits_per_uint64_sz)); - result_uint64_count++; - result_bit_capacity = mul_safe( - result_uint64_count, bits_per_uint64_sz); - resultint.resize(safe_cast(result_bit_capacity)); - result = resultint.data(); - set_bit_uint(result, result_uint64_count, carry_bit_index); - } - } - else - { - // Result and coefficient have opposite signs so subtract. - if (sub_uint_uint64(result, pos_value, result_uint64_count, result)) - { - // Subtraction produced a borrow so coefficient is larger (in magnitude) - // than result, so need to negate result. - negate_uint(result, result_uint64_count, result); - result_is_negative = !result_is_negative; - } - } - } - - // Verify result is non-negative. - if (result_is_negative && !resultint.is_zero()) - { - throw invalid_argument("poly must decode to positive value"); - } - return resultint; - } -} diff --git a/SEAL/native/src/seal/intencoder.h b/SEAL/native/src/seal/intencoder.h deleted file mode 100644 index 46d6c46..0000000 --- a/SEAL/native/src/seal/intencoder.h +++ /dev/null @@ -1,235 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include "seal/context.h" -#include "seal/biguint.h" -#include "seal/plaintext.h" -#include "seal/smallmodulus.h" -#include "seal/memorymanager.h" - -namespace seal -{ - /** - Encodes integers into plaintext polynomials that Encryptor can encrypt. An instance of - the IntegerEncoder class converts an integer into a plaintext polynomial by placing its - binary digits as the coefficients of the polynomial. Decoding the integer amounts to - evaluating the plaintext polynomial at x=2. - - Addition and multiplication on the integer side translate into addition and multiplication - on the encoded plaintext polynomial side, provided that the length of the polynomial - never grows to be of the size of the polynomial modulus (poly_modulus), and that the - coefficients of the plaintext polynomials appearing throughout the computations never - experience coefficients larger than the plaintext modulus (plain_modulus). - - @par Negative Integers - Negative integers are represented by using -1 instead of 1 in the binary representation, - and the negative coefficients are stored in the plaintext polynomials as unsigned integers - that represent them modulo the plaintext modulus. Thus, for example, a coefficient of -1 - would be stored as a polynomial coefficient plain_modulus-1. - */ - class SEAL_NODISCARD IntegerEncoder - { - public: - /** - Creates a IntegerEncoder object. The constructor takes as input a pointer to - a SEALContext object which contains the plaintext modulus. - - @param[in] context The SEALContext - @throws std::invalid_argument if the context is not set - @throws std::invalid_argument if the plain_modulus set in context is not - at least 2 - */ - IntegerEncoder(std::shared_ptr context); - - /** - Destroys the IntegerEncoder. - */ - ~IntegerEncoder() = default; - - /** - Encodes an unsigned integer (represented by std::uint64_t) into a plaintext polynomial. - - @param[in] value The unsigned integer to encode - */ - SEAL_NODISCARD Plaintext encode(std::uint64_t value); - - /** - Encodes an unsigned integer (represented by std::uint64_t) into a plaintext polynomial. - - @param[in] value The unsigned integer to encode - @param[out] destination The plaintext to overwrite with the encoding - */ - void encode(std::uint64_t value, Plaintext &destination); - - /** - Decodes a plaintext polynomial and returns the result as std::uint32_t. - Mathematically this amounts to evaluating the input polynomial at x=2. - - @param[in] plain The plaintext to be decoded - @throws std::invalid_argument if the output does not fit in std::uint32_t - */ - SEAL_NODISCARD std::uint32_t decode_uint32(const Plaintext &plain); - - /** - Decodes a plaintext polynomial and returns the result as std::uint64_t. - Mathematically this amounts to evaluating the input polynomial at x=2. - - @param[in] plain The plaintext to be decoded - @throws std::invalid_argument if the output does not fit in std::uint64_t - */ - SEAL_NODISCARD std::uint64_t decode_uint64(const Plaintext &plain); - - /** - Encodes a signed integer (represented by std::uint64_t) into a plaintext polynomial. - - @par Negative Integers - Negative integers are represented by using -1 instead of 1 in the binary representation, - and the negative coefficients are stored in the plaintext polynomials as unsigned integers - that represent them modulo the plaintext modulus. Thus, for example, a coefficient of -1 - would be stored as a polynomial coefficient plain_modulus-1. - - @param[in] value The signed integer to encode - */ - SEAL_NODISCARD Plaintext encode(std::int64_t value); - - /** - Encodes a signed integer (represented by std::int64_t) into a plaintext polynomial. - - @par Negative Integers - Negative integers are represented by using -1 instead of 1 in the binary representation, - and the negative coefficients are stored in the plaintext polynomials as unsigned integers - that represent them modulo the plaintext modulus. Thus, for example, a coefficient of -1 - would be stored as a polynomial coefficient plain_modulus-1. - - @param[in] value The signed integer to encode - @param[out] destination The plaintext to overwrite with the encoding - */ - void encode(std::int64_t value, Plaintext &destination); - - /** - Encodes an unsigned integer (represented by BigUInt) into a plaintext polynomial. - - @param[in] value The unsigned integer to encode - */ - SEAL_NODISCARD Plaintext encode(const BigUInt &value); - - /** - Encodes an unsigned integer (represented by BigUInt) into a plaintext polynomial. - - @param[in] value The unsigned integer to encode - @param[out] destination The plaintext to overwrite with the encoding - */ - void encode(const BigUInt &value, Plaintext &destination); - - /** - Decodes a plaintext polynomial and returns the result as std::int32_t. - Mathematically this amounts to evaluating the input polynomial at x=2. - - @param[in] plain The plaintext to be decoded - @throws std::invalid_argument if plain does not represent a valid plaintext polynomial - @throws std::invalid_argument if the output does not fit in std::int32_t - */ - SEAL_NODISCARD std::int32_t decode_int32(const Plaintext &plain); - - /** - Decodes a plaintext polynomial and returns the result as std::int64_t. - Mathematically this amounts to evaluating the input polynomial at x=2. - - @param[in] plain The plaintext to be decoded - @throws std::invalid_argument if plain does not represent a valid plaintext polynomial - @throws std::invalid_argument if the output does not fit in std::int64_t - */ - SEAL_NODISCARD std::int64_t decode_int64(const Plaintext &plain); - - /** - Decodes a plaintext polynomial and returns the result as BigUInt. - Mathematically this amounts to evaluating the input polynomial at x=2. - - @param[in] plain The plaintext to be decoded - @throws std::invalid_argument if plain does not represent a valid plaintext polynomial - @throws std::invalid_argument if the output is negative - */ - SEAL_NODISCARD BigUInt decode_biguint(const Plaintext &plain); - - /** - Encodes a signed integer (represented by std::int32_t) into a plaintext polynomial. - - @par Negative Integers - Negative integers are represented by using -1 instead of 1 in the binary representation, - and the negative coefficients are stored in the plaintext polynomials as unsigned integers - that represent them modulo the plaintext modulus. Thus, for example, a coefficient of -1 - would be stored as a polynomial coefficient plain_modulus-1. - - @param[in] value The signed integer to encode - */ - SEAL_NODISCARD inline Plaintext encode(std::int32_t value) - { - return encode(static_cast(value)); - } - - /** - Encodes an unsigned integer (represented by std::uint32_t) into a plaintext polynomial. - - @param[in] value The unsigned integer to encode - */ - SEAL_NODISCARD inline Plaintext encode(std::uint32_t value) - { - return encode(static_cast(value)); - } - - /** - Encodes a signed integer (represented by std::int32_t) into a plaintext polynomial. - - @par Negative Integers - Negative integers are represented by using -1 instead of 1 in the binary representation, - and the negative coefficients are stored in the plaintext polynomials as unsigned integers - that represent them modulo the plaintext modulus. Thus, for example, a coefficient of -1 - would be stored as a polynomial coefficient plain_modulus-1. - - @param[in] value The signed integer to encode - @param[out] destination The plaintext to overwrite with the encoding - */ - void inline encode(std::int32_t value, Plaintext &destination) - { - encode(static_cast(value), destination); - } - - /** - Encodes an unsigned integer (represented by std::uint32_t) into a plaintext polynomial. - - @param[in] value The unsigned integer to encode - @param[out] destination The plaintext to overwrite with the encoding - */ - void inline encode(std::uint32_t value, Plaintext &destination) - { - encode(static_cast(value), destination); - } - - /** - Returns a reference to the plaintext modulus. - */ - SEAL_NODISCARD inline const SmallModulus &plain_modulus() const - { - auto &context_data = *context_->first_context_data(); - return context_data.parms().plain_modulus(); - } - - private: - IntegerEncoder(const IntegerEncoder ©) = delete; - - IntegerEncoder(IntegerEncoder &&source) = delete; - - IntegerEncoder &operator =(const IntegerEncoder &assign) = delete; - - IntegerEncoder &operator =(IntegerEncoder &&assign) = delete; - - std::shared_ptr context_{ nullptr }; - - std::uint64_t coeff_neg_threshold_; - - std::uint64_t neg_one_; - }; -} diff --git a/SEAL/native/src/seal/keygenerator.cpp b/SEAL/native/src/seal/keygenerator.cpp deleted file mode 100644 index eb9e6c6..0000000 --- a/SEAL/native/src/seal/keygenerator.cpp +++ /dev/null @@ -1,530 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include -#include "seal/keygenerator.h" -#include "seal/randomtostd.h" -#include "seal/util/common.h" -#include "seal/util/uintcore.h" -#include "seal/util/uintarith.h" -#include "seal/util/uintarithsmallmod.h" -#include "seal/util/polyarithsmallmod.h" -#include "seal/util/clipnormal.h" -#include "seal/util/polycore.h" -#include "seal/util/smallntt.h" -#include "seal/util/rlwe.h" - -using namespace std; -using namespace seal::util; - -namespace seal -{ - KeyGenerator::KeyGenerator(shared_ptr context) : - context_(move(context)) - { - // Verify parameters - if (!context_) - { - throw invalid_argument("invalid context"); - } - if (!context_->parameters_set()) - { - throw invalid_argument("encryption parameters are not set correctly"); - } - - // Secret key and public key have not been generated - sk_generated_ = false; - pk_generated_ = false; - - // Generate the secret and public key - generate_sk(); - generate_pk(); - } - - KeyGenerator::KeyGenerator(shared_ptr context, - const SecretKey &secret_key) : context_(move(context)) - { - // Verify parameters - if (!context_) - { - throw invalid_argument("invalid context"); - } - if (!context_->parameters_set()) - { - throw invalid_argument("encryption parameters are not set correctly"); - } - if (!is_valid_for(secret_key, context_)) - { - throw invalid_argument("secret key is not valid for encryption parameters"); - } - - // Set the secret key - secret_key_ = secret_key; - sk_generated_ = true; - - // Generate the public key - generate_sk(sk_generated_); - generate_pk(); - } - - KeyGenerator::KeyGenerator(shared_ptr context, - const SecretKey &secret_key, const PublicKey &public_key) : - context_(move(context)) - { - // Verify parameters - if (!context_) - { - throw invalid_argument("invalid context"); - } - if (!context_->parameters_set()) - { - throw invalid_argument("encryption parameters are not set correctly"); - } - if (!is_valid_for(secret_key, context_)) - { - throw invalid_argument("secret key is not valid for encryption parameters"); - } - if (!is_valid_for(public_key, context_)) - { - throw invalid_argument("public key is not valid for encryption parameters"); - } - - // Set the secret and public keys - secret_key_ = secret_key; - public_key_ = public_key; - - // Secret key and public key are generated - sk_generated_ = true; - pk_generated_ = true; - - generate_sk(sk_generated_); - } - - void KeyGenerator::generate_sk(bool is_initialized) - { - // Extract encryption parameters. - auto &context_data = *context_->key_context_data(); - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = coeff_modulus.size(); - - if (!is_initialized) - { - // Initialize secret key. - secret_key_ = SecretKey(); - sk_generated_ = false; - secret_key_.data().resize(mul_safe(coeff_count, coeff_mod_count)); - - shared_ptr random(parms.random_generator()->create()); - - // Generate secret key - uint64_t *secret_key = secret_key_.data().data(); - sample_poly_ternary(secret_key, random, parms); - - auto &small_ntt_tables = context_data.small_ntt_tables(); - for (size_t i = 0; i < coeff_mod_count; i++) - { - // Transform the secret s into NTT representation. - ntt_negacyclic_harvey(secret_key + (i * coeff_count), small_ntt_tables[i]); - } - - // Set the parms_id for secret key - secret_key_.parms_id() = context_data.parms_id(); - } - - // Set the secret_key_array to have size 1 (first power of secret) - secret_key_array_ = allocate_poly(coeff_count, coeff_mod_count, pool_); - set_poly_poly(secret_key_.data().data(), coeff_count, coeff_mod_count, - secret_key_array_.get()); - secret_key_array_size_ = 1; - - // Secret key has been generated - sk_generated_ = true; - } - - void KeyGenerator::generate_pk() - { - if (!sk_generated_) - { - throw logic_error("cannot generate public key for unspecified secret key"); - } - - // Extract encryption parameters. - auto &context_data = *context_->key_context_data(); - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = coeff_modulus.size(); - - // Size check - if (!product_fits_in(coeff_count, coeff_mod_count)) - { - throw logic_error("invalid parameters"); - } - - // Initialize public key. - // PublicKey data allocated from pool given by MemoryManager::GetPool. - public_key_ = PublicKey(); - pk_generated_ = false; - - shared_ptr random( - parms.random_generator()->create()); - encrypt_zero_symmetric(secret_key_, context_, context_data.parms_id(), - random, true, public_key_.data(), pool_); - - // Set the parms_id for public key - public_key_.parms_id() = context_data.parms_id(); - - // Public key has been generated - pk_generated_ = true; - } - - RelinKeys KeyGenerator::relin_keys(size_t count) - { - // Check to see if secret key and public key have been generated - if (!sk_generated_) - { - throw logic_error("cannot generate relinearization keys for unspecified secret key"); - } - if (!count || count > SEAL_CIPHERTEXT_SIZE_MAX - 2) - { - throw invalid_argument("invalid count"); - } - - // Extract encryption parameters. - auto &context_data = *context_->key_context_data(); - auto &parms = context_data.parms(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = parms.coeff_modulus().size(); - - // Size check - if (!product_fits_in(coeff_count, coeff_mod_count)) - { - throw logic_error("invalid parameters"); - } - - shared_ptr random(parms.random_generator()->create()); - - // Make sure we have enough secret keys computed - compute_secret_key_array(context_data, count + 1); - - // Create the RelinKeys object to return - RelinKeys relin_keys; - - // Assume the secret key is already transformed into NTT form. - generate_kswitch_keys( - secret_key_array_.get() + coeff_mod_count * coeff_count, - count, - static_cast(relin_keys)); - - // Set the parms_id - relin_keys.parms_id() = context_data.parms_id(); - - return relin_keys; - } - - GaloisKeys KeyGenerator::galois_keys(const vector &galois_elts) - { - // Check to see if secret key and public key have been generated - if (!sk_generated_) - { - throw logic_error("cannot generate Galois keys for unspecified secret key"); - } - - // Extract encryption parameters. - auto &context_data = *context_->key_context_data(); - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = coeff_modulus.size(); - int coeff_count_power = get_power_of_two(coeff_count); - - // Size check - if (!product_fits_in(coeff_count, coeff_mod_count, size_t(2))) - { - throw logic_error("invalid parameters"); - } - - // Create the GaloisKeys object to return - GaloisKeys galois_keys; - - // The max number of keys is equal to number of coefficients - galois_keys.data().resize(coeff_count); - - for (uint64_t galois_elt : galois_elts) - { - // Verify coprime conditions. - if (!(galois_elt & 1) || (galois_elt >= 2 * coeff_count)) - { - throw invalid_argument("Galois element is not valid"); - } - - // Do we already have the key? - if (galois_keys.has_key(galois_elt)) - { - continue; - } - - // Rotate secret key for each coeff_modulus - auto rotated_secret_key( - allocate_poly(coeff_count, coeff_mod_count, pool_)); - for (size_t i = 0; i < coeff_mod_count; i++) - { - apply_galois_ntt( - secret_key_.data().data() + i * coeff_count, - coeff_count_power, - galois_elt, - rotated_secret_key.get() + i * coeff_count); - } - - // Initialize Galois key - // This is the location in the galois_keys vector - uint64_t index = GaloisKeys::get_index(galois_elt); - shared_ptr random(parms.random_generator()->create()); - - // Create Galois keys. - generate_one_kswitch_key( - rotated_secret_key.get(), - galois_keys.data()[index]); - } - - // Set the parms_id - galois_keys.parms_id_ = context_data.parms_id(); - - return galois_keys; - } - - GaloisKeys KeyGenerator::galois_keys(const vector &steps) - { - // Check to see if secret key and public key have been generated - if (!sk_generated_) - { - throw logic_error("cannot generate Galois keys for unspecified secret key"); - } - - // Extract encryption parameters. - auto &context_data = *context_->key_context_data(); - if (!context_data.qualifiers().using_batching) - { - throw logic_error("encryption parameters do not support batching"); - } - - auto &parms = context_data.parms(); - size_t coeff_count = parms.poly_modulus_degree(); - - vector galois_elts; - transform(steps.begin(), steps.end(), back_inserter(galois_elts), - [&](auto s) { return steps_to_galois_elt(s, coeff_count); }); - - return galois_keys(galois_elts); - } - - GaloisKeys KeyGenerator::galois_keys() - { - // Check to see if secret key and public key have been generated - if (!sk_generated_) - { - throw logic_error("cannot generate Galois keys for unspecified secret key"); - } - - size_t coeff_count = context_->key_context_data()->parms().poly_modulus_degree(); - uint64_t m = coeff_count << 1; - int logn = get_power_of_two(static_cast(coeff_count)); - - vector logn_galois_keys{}; - - // Generate Galois keys for m - 1 (X -> X^{m-1}) - logn_galois_keys.push_back(m - 1); - - // Generate Galois key for power of 3 mod m (X -> X^{3^k}) and - // for negative power of 3 mod m (X -> X^{-3^k}) - uint64_t two_power_of_three = 3; - uint64_t neg_two_power_of_three = 0; - try_mod_inverse(3, m, neg_two_power_of_three); - for (int i = 0; i < logn - 1; i++) - { - logn_galois_keys.push_back(two_power_of_three); - two_power_of_three *= two_power_of_three; - two_power_of_three &= (m - 1); - - logn_galois_keys.push_back(neg_two_power_of_three); - neg_two_power_of_three *= neg_two_power_of_three; - neg_two_power_of_three &= (m - 1); - } - - return galois_keys(logn_galois_keys); - } - - const SecretKey &KeyGenerator::secret_key() const - { - if (!sk_generated_) - { - throw logic_error("secret key has not been generated"); - } - return secret_key_; - } - - const PublicKey &KeyGenerator::public_key() const - { - if (!pk_generated_) - { - throw logic_error("public key has not been generated"); - } - return public_key_; - } - - void KeyGenerator::compute_secret_key_array( - const SEALContext::ContextData &context_data, size_t max_power) - { -#ifdef SEAL_DEBUG - if (max_power < 1) - { - throw invalid_argument("max_power must be at least 1"); - } - if (!secret_key_array_size_ || !secret_key_array_) - { - throw logic_error("secret_key_array_ is uninitialized"); - } -#endif - // Extract encryption parameters. - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_count = parms.poly_modulus_degree(); - size_t coeff_mod_count = coeff_modulus.size(); - - // Size check - if (!product_fits_in(coeff_count, coeff_mod_count, max_power)) - { - throw logic_error("invalid parameters"); - } - - ReaderLock reader_lock(secret_key_array_locker_.acquire_read()); - - size_t old_size = secret_key_array_size_; - size_t new_size = max(max_power, old_size); - - if (old_size == new_size) - { - return; - } - - reader_lock.unlock(); - - // Need to extend the array - // Compute powers of secret key until max_power - auto new_secret_key_array(allocate_poly( - new_size * coeff_count, coeff_mod_count, pool_)); - set_poly_poly(secret_key_array_.get(), old_size * coeff_count, - coeff_mod_count, new_secret_key_array.get()); - - size_t poly_ptr_increment = coeff_count * coeff_mod_count; - uint64_t *prev_poly_ptr = new_secret_key_array.get() + - (old_size - 1) * poly_ptr_increment; - uint64_t *next_poly_ptr = prev_poly_ptr + poly_ptr_increment; - - // Since all of the key powers in secret_key_array_ are already - // NTT transformed, to get the next one we simply need to compute - // a dyadic product of the last one with the first one - // [which is equal to NTT(secret_key_)]. - for (size_t i = old_size; i < new_size; i++) - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - dyadic_product_coeffmod( - prev_poly_ptr + (j * coeff_count), - new_secret_key_array.get() + (j * coeff_count), - coeff_count, coeff_modulus[j], - next_poly_ptr + (j * coeff_count)); - } - prev_poly_ptr = next_poly_ptr; - next_poly_ptr += poly_ptr_increment; - } - - // Take writer lock to update array - WriterLock writer_lock(secret_key_array_locker_.acquire_write()); - - // Do we still need to update size? - old_size = secret_key_array_size_; - new_size = max(max_power, secret_key_array_size_); - - if (old_size == new_size) - { - return; - } - - // Acquire new array - secret_key_array_size_ = new_size; - secret_key_array_.acquire(new_secret_key_array); - } - - void KeyGenerator::generate_one_kswitch_key( - const uint64_t *new_key, - std::vector &destination) - { - size_t coeff_count = context_->key_context_data()->parms().poly_modulus_degree(); - size_t decomp_mod_count = context_->first_context_data()->parms().coeff_modulus().size(); - auto &key_context_data = *context_->key_context_data(); - auto &key_parms = key_context_data.parms(); - auto &key_modulus = key_parms.coeff_modulus(); - shared_ptr random(key_parms.random_generator()->create()); - - // Size check - if (!product_fits_in(coeff_count, decomp_mod_count)) - { - throw logic_error("invalid parameters"); - } - - // KSwitchKeys data allocated from pool given by MemoryManager::GetPool. - destination.resize(decomp_mod_count); - - auto temp(allocate_uint(coeff_count, pool_)); - uint64_t factor = 0; - for (size_t j = 0; j < decomp_mod_count; j++) - { - encrypt_zero_symmetric(secret_key_, context_, - key_context_data.parms_id(), random, true, - destination[j].data(), pool_); - - factor = key_modulus.back().value() % key_modulus[j].value(); - multiply_poly_scalar_coeffmod( - new_key + j * coeff_count, - coeff_count, - factor, - key_modulus[j], - temp.get()); - add_poly_poly_coeffmod( - destination[j].data().data() + j * coeff_count, - temp.get(), - coeff_count, - key_modulus[j], - destination[j].data().data() + j * coeff_count); - } - } - - void KeyGenerator::generate_kswitch_keys( - const uint64_t *new_keys, - size_t num_keys, - KSwitchKeys &destination) - { - size_t coeff_count = context_->key_context_data()->parms().poly_modulus_degree(); - auto &key_context_data = *context_->key_context_data(); - auto &key_parms = key_context_data.parms(); - size_t coeff_mod_count = key_parms.coeff_modulus().size(); - shared_ptr random(key_parms.random_generator()->create()); - - // Size check - if (!product_fits_in(coeff_count, coeff_mod_count, num_keys)) - { - throw logic_error("invalid parameters"); - } - - destination.data().resize(num_keys); - auto temp(allocate_uint(coeff_count, pool_)); - for (size_t l = 0; l < num_keys; l++) - { - const uint64_t *new_key_ptr = new_keys + l * coeff_mod_count * coeff_count; - generate_one_kswitch_key(new_key_ptr, destination.data()[l]); - } - } -} \ No newline at end of file diff --git a/SEAL/native/src/seal/keygenerator.h b/SEAL/native/src/seal/keygenerator.h deleted file mode 100644 index dc93808..0000000 --- a/SEAL/native/src/seal/keygenerator.h +++ /dev/null @@ -1,208 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include "seal/util/defines.h" -#include "seal/context.h" -#include "seal/util/smallntt.h" -#include "seal/memorymanager.h" -#include "seal/publickey.h" -#include "seal/secretkey.h" -#include "seal/relinkeys.h" -#include "seal/galoiskeys.h" -#include "seal/randomgen.h" - -namespace seal -{ - /** - Generates matching secret key and public key. An existing KeyGenerator can - also at any time be used to generate relinearization keys and Galois keys. - Constructing a KeyGenerator requires only a SEALContext. - - @see EncryptionParameters for more details on encryption parameters. - @see SecretKey for more details on secret key. - @see PublicKey for more details on public key. - @see RelinKeys for more details on relinearization keys. - @see GaloisKeys for more details on Galois keys. - */ - class SEAL_NODISCARD KeyGenerator - { - public: - /** - Creates a KeyGenerator initialized with the specified SEALContext. - - @param[in] context The SEALContext - @throws std::invalid_argument if the context is not set or encryption - parameters are not valid - */ - KeyGenerator(std::shared_ptr context); - - /** - Creates an KeyGenerator instance initialized with the specified SEALContext - and specified previously secret key. This can e.g. be used to increase - the number of relinearization keys from what had earlier been generated, - or to generate Galois keys in case they had not been generated earlier. - - - @param[in] context The SEALContext - @param[in] secret_key A previously generated secret key - @throws std::invalid_argument if encryption parameters are not valid - @throws std::invalid_argument if secret_key or public_key is not valid - for encryption parameters - */ - KeyGenerator(std::shared_ptr context, - const SecretKey &secret_key); - - /** - Creates an KeyGenerator instance initialized with the specified SEALContext - and specified previously secret and public keys. This can e.g. be used - to increase the number of relinearization keys from what had earlier been - generated, or to generate Galois keys in case they had not been generated - earlier. - - @param[in] context The SEALContext - @param[in] secret_key A previously generated secret key - @param[in] public_key A previously generated public key - @throws std::invalid_argument if encryption parameters are not valid - @throws std::invalid_argument if secret_key or public_key is not valid - for encryption parameters - */ - KeyGenerator(std::shared_ptr context, - const SecretKey &secret_key, const PublicKey &public_key); - - /** - Returns a const reference to the secret key. - */ - SEAL_NODISCARD const SecretKey &secret_key() const; - - /** - Returns a const reference to the public key. - */ - SEAL_NODISCARD const PublicKey &public_key() const; - - /** - Generates and returns relinearization keys. - */ - SEAL_NODISCARD inline RelinKeys relin_keys() - { - return relin_keys(1); - } - - /** - Generates and returns Galois keys. This function creates specific Galois - keys that can be used to apply specific Galois automorphisms on encrypted - data. The user needs to give as input a vector of Galois elements - corresponding to the keys that are to be created. - - The Galois elements are odd integers in the interval [1, M-1], where - M = 2*N, and N = poly_modulus_degree. Used with batching, a Galois element - 3^i % M corresponds to a cyclic row rotation i steps to the left, and - a Galois element 3^(N/2-i) % M corresponds to a cyclic row rotation i - steps to the right. The Galois element M-1 corresponds to a column rotation - (row swap) in BFV, and complex conjugation in CKKS. In the polynomial view - (not batching), a Galois automorphism by a Galois element p changes Enc(plain(x)) - to Enc(plain(x^p)). - - @param[in] galois_elts The Galois elements for which to generate keys - @throws std::invalid_argument if the Galois elements are not valid - */ - SEAL_NODISCARD GaloisKeys galois_keys( - const std::vector &galois_elts); - - /** - Generates and returns Galois keys. This function creates specific Galois - keys that can be used to apply specific Galois automorphisms on encrypted - data. The user needs to give as input a vector of desired Galois rotation - step counts, where negative step counts correspond to rotations to the - right and positive step counts correspond to rotations to the left. - A step count of zero can be used to indicate a column rotation in the BFV - scheme complex conjugation in the CKKS scheme. - - @param[in] galois_elts The rotation step counts for which to generate keys - @throws std::logic_error if the encryption parameters do not support batching - and scheme is scheme_type::BFV - @throws std::invalid_argument if the step counts are not valid - */ - SEAL_NODISCARD GaloisKeys galois_keys(const std::vector &steps); - - /** - Generates and returns Galois keys. This function creates logarithmically - many (in degree of the polynomial modulus) Galois keys that is sufficient - to apply any Galois automorphism (e.g. rotations) on encrypted data. Most - users will want to use this overload of the function. - */ - SEAL_NODISCARD GaloisKeys galois_keys(); - - private: - KeyGenerator(const KeyGenerator ©) = delete; - - KeyGenerator &operator =(const KeyGenerator &assign) = delete; - - KeyGenerator(KeyGenerator &&source) = delete; - - KeyGenerator &operator =(KeyGenerator &&assign) = delete; - - void compute_secret_key_array( - const SEALContext::ContextData &context_data, - std::size_t max_power); - - /** - Generates new secret key. - - @param[in] is_initialized True if the secret_key_ has already been initialized so that only the - secret_key_array_ should be initialized (it may be the case, for instance, if the secret_key_ - was provided in the constructor - */ - void generate_sk(bool is_initialized = false); - - /** - Generates new public key matching to existing secret key. - */ - void generate_pk(); - - /** - Generates new key switching keys for an array of new keys. - */ - void generate_kswitch_keys( - const std::uint64_t *new_keys, - std::size_t num_keys, - KSwitchKeys &destination); - - /** - Generates one key switching key for a new key. - */ - void generate_one_kswitch_key( - const uint64_t *new_key, - std::vector &destination); - - /** - Generates and returns the specified number of relinearization keys. - - @param[in] count The number of relinearization keys to generate - @throws std::invalid_argument if count is zero or too large - */ - RelinKeys relin_keys(std::size_t count); - - // We use a fresh memory pool with `clear_on_destruction' enabled. - MemoryPoolHandle pool_ = MemoryManager::GetPool(mm_prof_opt::FORCE_NEW, true); - - std::shared_ptr context_{ nullptr }; - - PublicKey public_key_; - - SecretKey secret_key_; - - std::size_t secret_key_array_size_ = 0; - - util::Pointer secret_key_array_; - - mutable util::ReaderWriterLocker secret_key_array_locker_; - - bool sk_generated_ = false; - - bool pk_generated_ = false; - }; -} diff --git a/SEAL/native/src/seal/kswitchkeys.cpp b/SEAL/native/src/seal/kswitchkeys.cpp deleted file mode 100644 index 545575e..0000000 --- a/SEAL/native/src/seal/kswitchkeys.cpp +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "seal/kswitchkeys.h" -#include "seal/util/defines.h" -#include - -using namespace std; -using namespace seal::util; - -namespace seal -{ - KSwitchKeys &KSwitchKeys::operator =(const KSwitchKeys &assign) - { - // Check for self-assignment - if (this == &assign) - { - return *this; - } - - // Copy over fields - parms_id_ = assign.parms_id_; - - // Then copy over keys - keys_.clear(); - size_t keys_dim1 = assign.keys_.size(); - keys_.reserve(keys_dim1); - for (size_t i = 0; i < keys_dim1; i++) - { - size_t keys_dim2 = assign.keys_[i].size(); - keys_.emplace_back(); - keys_[i].reserve(keys_dim2); - for (size_t j = 0; j < keys_dim2; j++) - { - keys_[i].emplace_back(PublicKey(pool_)); - keys_[i][j] = assign.keys_[i][j]; - } - } - - return *this; - } - - void KSwitchKeys::save(ostream &stream) const - { - auto old_except_mask = stream.exceptions(); - try - { - // Throw exceptions on ios_base::badbit and ios_base::failbit - stream.exceptions(ios_base::badbit | ios_base::failbit); - - uint64_t keys_dim1 = static_cast(keys_.size()); - - // Save the parms_id - stream.write(reinterpret_cast(&parms_id_), - sizeof(parms_id_type)); - - // Save the size of keys_ - stream.write(reinterpret_cast(&keys_dim1), sizeof(uint64_t)); - - // Now loop again over keys_dim1 - for (size_t index = 0; index < keys_dim1; index++) - { - // Save second dimension of keys_ - uint64_t keys_dim2 = static_cast(keys_[index].size()); - stream.write(reinterpret_cast(&keys_dim2), sizeof(uint64_t)); - - // Loop over keys_dim2 and save all (or none) - for (size_t j = 0; j < keys_dim2; j++) - { - // Save the key - keys_[index][j].save(stream); - } - } - } - catch (const exception &) - { - stream.exceptions(old_except_mask); - throw; - } - - stream.exceptions(old_except_mask); - } - - void KSwitchKeys::unsafe_load(istream &stream) - { - // Create new keys - vector> new_keys; - - auto old_except_mask = stream.exceptions(); - try - { - // Throw exceptions on ios_base::badbit and ios_base::failbit - stream.exceptions(ios_base::badbit | ios_base::failbit); - - // Read the parms_id - stream.read(reinterpret_cast(&parms_id_), - sizeof(parms_id_type)); - - // Read in the size of keys_ - uint64_t keys_dim1 = 0; - stream.read(reinterpret_cast(&keys_dim1), sizeof(uint64_t)); - - // Reserve first for dimension of keys_ - new_keys.reserve(safe_cast(keys_dim1)); - - // Loop over the first dimension of keys_ - for (size_t index = 0; index < keys_dim1; index++) - { - // Read the size of the second dimension - uint64_t keys_dim2 = 0; - stream.read(reinterpret_cast(&keys_dim2), sizeof(uint64_t)); - - // Don't resize; only reserve - new_keys.emplace_back(); - new_keys.back().reserve(safe_cast(keys_dim2)); - for (size_t j = 0; j < keys_dim2; j++) - { - PublicKey key(pool_); - key.unsafe_load(stream); - new_keys[index].emplace_back(move(key)); - } - } - } - catch (const exception &) - { - stream.exceptions(old_except_mask); - throw; - } - stream.exceptions(old_except_mask); - - swap(keys_, new_keys); - } -} diff --git a/SEAL/native/src/seal/kswitchkeys.h b/SEAL/native/src/seal/kswitchkeys.h deleted file mode 100644 index 4a3be02..0000000 --- a/SEAL/native/src/seal/kswitchkeys.h +++ /dev/null @@ -1,249 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include -#include -#include "seal/publickey.h" -#include "seal/memorymanager.h" -#include "seal/encryptionparams.h" -#include "seal/valcheck.h" - -namespace seal -{ - /** - Class to store keyswitching keys. It should never be necessary for normal - users to create an instance of KSwitchKeys. This class is used strictly as - a base class for RelinKeys and GaloisKeys classes. - - @par Keyswitching - Concretely, keyswitching is used to change a ciphertext encrypted with one - key to be encrypted with another key. It is a general technique and is used - in relinearization and Galois rotations. A keyswitching key contains a sequence - (vector) of keys. In RelinKeys, each key is an encryption of a power of the - secret key. In GaloisKeys, each key corresponds to a type of rotation. - - @par Thread Safety - In general, reading from KSwitchKeys is thread-safe as long as no - other thread is concurrently mutating it. This is due to the underlying - data structure storing the keyswitching keys not being thread-safe. - - @see RelinKeys for the class that stores the relinearization keys. - @see GaloisKeys for the class that stores the Galois keys. - */ - class KSwitchKeys - { - friend class KeyGenerator; - friend class RelinKeys; - friend class GaloisKeys; - - public: - /** - Creates an empty KSwitchKeys. - */ - KSwitchKeys() = default; - - /** - Creates a new KSwitchKeys instance by copying a given instance. - - @param[in] copy The KSwitchKeys to copy from - */ - KSwitchKeys(const KSwitchKeys ©) = default; - - /** - Creates a new KSwitchKeys instance by moving a given instance. - - @param[in] source The RelinKeys to move from - */ - KSwitchKeys(KSwitchKeys &&source) = default; - - /** - Copies a given KSwitchKeys instance to the current one. - - @param[in] assign The KSwitchKeys to copy from - */ - KSwitchKeys &operator =(const KSwitchKeys &assign); - - /** - Moves a given KSwitchKeys instance to the current one. - - @param[in] assign The KSwitchKeys to move from - */ - KSwitchKeys &operator =(KSwitchKeys &&assign) = default; - - /** - Returns the current number of keyswitching keys. Only keys that are - non-empty are counted. - */ - SEAL_NODISCARD inline std::size_t size() const noexcept - { - return std::accumulate(keys_.cbegin(), keys_.cend(), std::size_t(0), - [](std::size_t res, auto &next_key) - { - return res + (next_key.empty() ? 0 : 1); - }); - } - - /** - Returns a reference to the KSwitchKeys data. - */ - SEAL_NODISCARD inline auto &data() noexcept - { - return keys_; - } - - /** - Returns a const reference to the KSwitchKeys data. - */ - SEAL_NODISCARD inline auto &data() const noexcept - { - return keys_; - } - - /** - Returns a reference to a keyswitching key at a given index. - - @param[in] index The index of the keyswitching key - @throws std::invalid_argument if the key at the given index does not exist - */ - SEAL_NODISCARD inline auto &data(std::size_t index) - { - if (index >= keys_.size() || keys_[index].empty()) - { - throw std::invalid_argument("keyswitching key does not exist"); - } - return keys_[index]; - } - - /** - Returns a const reference to a keyswitching key at a given index. - - @param[in] index The index of the keyswitching key - @throws std::invalid_argument if the key at the given index does not exist - */ - SEAL_NODISCARD inline const auto &data(std::size_t index) const - { - if (index >= keys_.size() || keys_[index].empty()) - { - throw std::invalid_argument("keyswitching key does not exist"); - } - return keys_[index]; - } - - /** - Returns a reference to parms_id. - - @see EncryptionParameters for more information about parms_id. - */ - SEAL_NODISCARD inline auto &parms_id() noexcept - { - return parms_id_; - } - - /** - Returns a const reference to parms_id. - - @see EncryptionParameters for more information about parms_id. - */ - SEAL_NODISCARD inline auto &parms_id() const noexcept - { - return parms_id_; - } - - /** - Saves the KSwitchKeys instance to an output stream. The output is - in binary format and not human-readable. The output stream must have - the "binary" flag set. - - @param[in] stream The stream to save the KSwitchKeys to - @throws std::exception if the KSwitchKeys could not be written to stream - */ - void save(std::ostream &stream) const; - - void python_save(std::string &path) const - { - try - { - std::ofstream out(path, std::ofstream::binary); - this->save(out); - out.close(); - } - catch (const std::exception &) - { - throw "KSwitchKeys write exception"; - } - } - - /** - Loads a KSwitchKeys from an input stream overwriting the current KSwitchKeys. - No checking of the validity of the KSwitchKeys data against encryption - parameters is performed. This function should not be used unless the - KSwitchKeys comes from a fully trusted source. - - @param[in] stream The stream to load the KSwitchKeys from - @throws std::exception if a valid KSwitchKeys could not be read from stream - */ - void unsafe_load(std::istream &stream); - - void python_load(std::shared_ptr context, - std::string &path) - { - try - { - std::ifstream in(path, std::ifstream::binary); - this->load(context, in); - in.close(); - } - catch (const std::exception &) - { - throw "KSwitchKeys read exception"; - } - } - - /** - Loads a KSwitchKeys from an input stream overwriting the current KSwitchKeys. - The loaded KSwitchKeys is verified to be valid for the given SEALContext. - - @param[in] context The SEALContext - @param[in] stream The stream to load the KSwitchKeys from - @throws std::invalid_argument if the context is not set or encryption - parameters are not valid - @throws std::exception if a valid KSwitchKeys could not be read from stream - @throws std::invalid_argument if the loaded KSwitchKeys is invalid for the - context - */ - inline void load(std::shared_ptr context, - std::istream &stream) - { - KSwitchKeys new_keys; - new_keys.pool_ = pool_; - new_keys.unsafe_load(stream); - if (!is_valid_for(new_keys, std::move(context))) - { - throw std::invalid_argument("KSwitchKeys data is invalid"); - } - std::swap(*this, new_keys); - } - - /** - Returns the currently used MemoryPoolHandle. - */ - SEAL_NODISCARD inline MemoryPoolHandle pool() const noexcept - { - return pool_; - } - - private: - MemoryPoolHandle pool_ = MemoryManager::GetPool(); - - parms_id_type parms_id_ = parms_id_zero; - - /** - The vector of keyswitching keys. - */ - std::vector> keys_{}; - }; -} diff --git a/SEAL/native/src/seal/memorymanager.cpp b/SEAL/native/src/seal/memorymanager.cpp deleted file mode 100644 index e854155..0000000 --- a/SEAL/native/src/seal/memorymanager.cpp +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "seal/memorymanager.h" - -namespace seal -{ - std::unique_ptr - MemoryManager::mm_prof_{ new MMProfGlobal }; -#ifndef _M_CEE - std::mutex MemoryManager::switch_mutex_; -#else -#pragma message("WARNING: MemoryManager compiled thread-unsafe and MMProfGuard disabled to support /clr") -#endif -} \ No newline at end of file diff --git a/SEAL/native/src/seal/memorymanager.h b/SEAL/native/src/seal/memorymanager.h deleted file mode 100644 index b5742ef..0000000 --- a/SEAL/native/src/seal/memorymanager.h +++ /dev/null @@ -1,821 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include -#include -#include "seal/util/defines.h" -#include "seal/util/mempool.h" -#include "seal/util/globals.h" - -/* -For .NET Framework wrapper support (C++/CLI) we need to - (1) compile the MemoryManager class as thread-unsafe because C++ - mutexes cannot be brought through C++/CLI layer; - (2) disable thread-safe memory pools. -*/ -#ifndef _M_CEE -#include -#include -#endif - -namespace seal -{ - /** - Manages a shared pointer to a memory pool. Microsoft SEAL uses memory pools - for improved performance due to the large number of memory allocations - needed by the homomorphic encryption operations, and the underlying polynomial - arithmetic. The library automatically creates a shared global memory pool - that is used for all dynamic allocations by default, and the user can - optionally create any number of custom memory pools to be used instead. - - @par Uses in Multi-Threaded Applications - Sometimes the user might want to use specific memory pools for dynamic - allocations in certain functions. For example, in heavily multi-threaded - applications allocating concurrently from a shared memory pool might lead - to significant performance issues due to thread contention. For these cases - Microsoft SEAL provides overloads of the functions that take a MemoryPoolHandle - as an additional argument, and uses the associated memory pool for all dynamic - allocations inside the function. Whenever these functions are called, the - user can then simply pass a thread-local MemoryPoolHandle to be used. - - @par Thread-Unsafe Memory Pools - While memory pools are by default thread-safe, in some cases it suffices - to have a memory pool be thread-unsafe. To get a little extra performance, - the user can optionally create such thread-unsafe memory pools and use them - just as they would use thread-safe memory pools. - - @par Initialized and Uninitialized Handles - A MemoryPoolHandle has to be set to point either to the global memory pool, - or to a new memory pool. If this is not done, the MemoryPoolHandle is - said to be uninitialized, and cannot be used. Initialization simple means - assigning MemoryPoolHandle::Global() or MemoryPoolHandle::New() to it. - - @par Managing Lifetime - Internally, the MemoryPoolHandle wraps an std::shared_ptr pointing to - a memory pool class. Thus, as long as a MemoryPoolHandle pointing to - a particular memory pool exists, the pool stays alive. Classes such as - Evaluator and Ciphertext store their own local copies of a MemoryPoolHandle - to guarantee that the pool stays alive as long as the managing object - itself stays alive. The global memory pool is implemented as a global - std::shared_ptr to a memory pool class, and is thus expected to stay - alive for the entire duration of the program execution. Note that it can - be problematic to create other global objects that use the memory pool - e.g. in their constructor, as one would have to ensure the initialization - order of these global variables to be correct (i.e. global memory pool - first). - */ - class MemoryPoolHandle - { - public: - /** - Creates a new uninitialized MemoryPoolHandle. - */ - MemoryPoolHandle() = default; - - /** - Creates a MemoryPoolHandle pointing to a given MemoryPool object. - */ - MemoryPoolHandle(std::shared_ptr pool) noexcept : - pool_(std::move(pool)) - { - } - - /** - Creates a copy of a given MemoryPoolHandle. As a result, the created - MemoryPoolHandle will point to the same underlying memory pool as the - copied instance. - - - @param[in] copy The MemoryPoolHandle to copy from - */ - MemoryPoolHandle(const MemoryPoolHandle ©) noexcept - { - operator =(copy); - } - - /** - Creates a new MemoryPoolHandle by moving a given one. As a result, the - moved MemoryPoolHandle will become uninitialized. - - - @param[in] source The MemoryPoolHandle to move from - */ - MemoryPoolHandle(MemoryPoolHandle &&source) noexcept - { - operator =(std::move(source)); - } - - /** - Overwrites the MemoryPoolHandle instance with the specified instance. As - a result, the current MemoryPoolHandle will point to the same underlying - memory pool as the assigned instance. - - @param[in] assign The MemoryPoolHandle instance to assign to the current - instance - */ - inline MemoryPoolHandle &operator =(const MemoryPoolHandle &assign) noexcept - { - pool_ = assign.pool_; - return *this; - } - - /** - Moves a specified MemoryPoolHandle instance to the current instance. As - a result, the assigned MemoryPoolHandle will become uninitialized. - - @param[in] assign The MemoryPoolHandle instance to assign to the current - instance - */ - inline MemoryPoolHandle &operator =(MemoryPoolHandle &&assign) noexcept - { - pool_ = std::move(assign.pool_); - return *this; - } - - /** - Returns a MemoryPoolHandle pointing to the global memory pool. - */ - SEAL_NODISCARD inline static MemoryPoolHandle Global() noexcept - { - return util::global_variables::global_memory_pool; - } -#ifndef _M_CEE - /** - Returns a MemoryPoolHandle pointing to the thread-local memory pool. - */ - SEAL_NODISCARD inline static MemoryPoolHandle ThreadLocal() noexcept - { - return util::global_variables::tls_memory_pool; - } -#endif - /** - Returns a MemoryPoolHandle pointing to a new thread-safe memory pool. - - @param[in] clear_on_destruction Indicates whether the memory pool data - should be cleared when destroyed. This can be important when memory pools - are used to store private data. - */ - SEAL_NODISCARD inline static MemoryPoolHandle New( - bool clear_on_destruction = false) - { - return MemoryPoolHandle( - std::make_shared(clear_on_destruction)); - } - - /** - Returns a reference to the internal memory pool that the MemoryPoolHandle - points to. This function is mainly for internal use. - - @throws std::logic_error if the MemoryPoolHandle is uninitialized - */ - SEAL_NODISCARD inline operator util::MemoryPool &() const - { - if (!pool_) - { - throw std::logic_error("pool not initialized"); - } - return *pool_.get(); - } - - /** - Returns the number of different allocation sizes. This function returns - the number of different allocation sizes the memory pool pointed to by - the current MemoryPoolHandle has made. For example, if the memory pool has - only allocated two allocations of sizes 128 KB, this function returns 1. - If it has instead allocated one allocation of size 64 KB and one of 128 KB, - this function returns 2. - */ - SEAL_NODISCARD inline std::size_t pool_count() const noexcept - { - return !pool_ ? std::size_t(0) : pool_->pool_count(); - } - - /** - Returns the size of allocated memory. This functions returns the total - amount of memory (in bytes) allocated by the memory pool pointed to by - the current MemoryPoolHandle. - */ - SEAL_NODISCARD inline std::size_t alloc_byte_count() const noexcept - { - return !pool_ ? std::size_t(0) : pool_->alloc_byte_count(); - } - - /** - Returns the number of MemoryPoolHandle objects sharing this memory pool. - */ - SEAL_NODISCARD inline long use_count() const noexcept - { - return !pool_ ? 0 : pool_.use_count(); - } - - /** - Returns whether the MemoryPoolHandle is initialized. - */ - SEAL_NODISCARD inline operator bool () const noexcept - { - return pool_.operator bool(); - } - - /** - Compares MemoryPoolHandles. This function returns whether the current - MemoryPoolHandle points to the same memory pool as a given MemoryPoolHandle. - */ - inline bool operator ==(const MemoryPoolHandle &compare) noexcept - { - return pool_ == compare.pool_; - } - - /** - Compares MemoryPoolHandles. This function returns whether the current - MemoryPoolHandle points to a different memory pool than a given - MemoryPoolHandle. - */ - inline bool operator !=(const MemoryPoolHandle &compare) noexcept - { - return pool_ != compare.pool_; - } - - private: - std::shared_ptr pool_ = nullptr; - }; - - using mm_prof_opt_t = std::uint64_t; - - /** - Control options for MemoryManager::GetPool function. These force the MemoryManager - to override the current MMProf and instead return a MemoryPoolHandle pointing - to a memory pool of the indicated type. - */ - enum mm_prof_opt : mm_prof_opt_t - { - DEFAULT = 0x0, - FORCE_GLOBAL = 0x1, - FORCE_NEW = 0x2, - FORCE_THREAD_LOCAL = 0x4 - }; - - /** - The MMProf is a pure virtual class that every profile for the MemoryManager - should inherit from. The only functionality this class implements is the - get_pool(mm_prof_opt_t) function that returns a MemoryPoolHandle pointing - to a pool selected by internal logic optionally using the input parameter - of type mm_prof_opt_t. The returned MemoryPoolHandle must point to a valid - memory pool. - */ - class MMProf - { - public: - /** - Creates a new MMProf. - */ - MMProf() = default; - - /** - Destroys the MMProf. - */ - virtual ~MMProf() noexcept - { - } - - /** - Returns a MemoryPoolHandle pointing to a pool selected by internal logic - in a derived class and by the mm_prof_opt_t input parameter. - - */ - virtual MemoryPoolHandle get_pool(mm_prof_opt_t) = 0; - - private: - }; - - /** - A memory manager profile that always returns a MemoryPoolHandle pointing to - the global memory pool. Microsoft SEAL uses this memory manager profile by default. - */ - class MMProfGlobal : public MMProf - { - public: - /** - Creates a new MMProfGlobal. - */ - MMProfGlobal() = default; - - /** - Destroys the MMProfGlobal. - */ - virtual ~MMProfGlobal() noexcept override - { - } - - /** - Returns a MemoryPoolHandle pointing to the global memory pool. The - mm_prof_opt_t input parameter has no effect. - */ - SEAL_NODISCARD inline virtual MemoryPoolHandle get_pool( - mm_prof_opt_t) override - { - return MemoryPoolHandle::Global(); - } - - private: - }; - - /** - A memory manager profile that always returns a MemoryPoolHandle pointing to - the new thread-safe memory pool. This profile should not be used except in - special circumstances, as it does not result in any reuse of allocated memory. - */ - class MMProfNew : public MMProf - { - public: - /** - Creates a new MMProfNew. - */ - MMProfNew() = default; - - /** - Destroys the MMProfNew. - */ - virtual ~MMProfNew() noexcept override - { - } - - /** - Returns a MemoryPoolHandle pointing to a new thread-safe memory pool. The - mm_prof_opt_t input parameter has no effect. - */ - SEAL_NODISCARD inline virtual MemoryPoolHandle get_pool( - mm_prof_opt_t) override - { - return MemoryPoolHandle::New(); - } - - private: - }; - - /** - A memory manager profile that always returns a MemoryPoolHandle pointing to - specific memory pool. - */ - class MMProfFixed : public MMProf - { - public: - /** - Creates a new MMProfFixed. The MemoryPoolHandle given as argument is returned - by every call to get_pool(mm_prof_opt_t). - - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if pool is uninitialized - */ - MMProfFixed(MemoryPoolHandle pool) : pool_(std::move(pool)) - { - if (!pool_) - { - throw std::invalid_argument("pool is uninitialized"); - } - } - - /** - Destroys the MMProfFixed. - */ - virtual ~MMProfFixed() noexcept override - { - } - - /** - Returns a MemoryPoolHandle pointing to the stored memory pool. The - mm_prof_opt_t input parameter has no effect. - */ - SEAL_NODISCARD inline virtual MemoryPoolHandle get_pool( - mm_prof_opt_t) override - { - return pool_; - } - - private: - MemoryPoolHandle pool_; - }; -#ifndef _M_CEE - /** - A memory manager profile that always returns a MemoryPoolHandle pointing to - the thread-local memory pool. This profile should be used with care, as any - memory allocated by it will be released once the thread exits. In other words, - the thread-local memory pool cannot be used to share memory across different - threads. On the other hand, this profile can be useful when a very high number - of threads doing simultaneous allocations would cause contention in the - global memory pool. - */ - class MMProfThreadLocal : public MMProf - { - public: - /** - Creates a new MMProfThreadLocal. - */ - MMProfThreadLocal() = default; - - /** - Destroys the MMProfThreadLocal. - */ - virtual ~MMProfThreadLocal() noexcept override - { - } - - /** - Returns a MemoryPoolHandle pointing to the thread-local memory pool. The - mm_prof_opt_t input parameter has no effect. - */ - SEAL_NODISCARD inline virtual MemoryPoolHandle get_pool( - mm_prof_opt_t) override - { - return MemoryPoolHandle::ThreadLocal(); - } - - private: - }; -#endif - /** - The MemoryManager class can be used to create instances of MemoryPoolHandle - based on a given "profile". A profile is implemented by inheriting from the - MMProf class (pure virtual) and encapsulates internal logic for deciding which - memory pool to use. - */ - class MemoryManager - { - friend class MMProfGuard; - - public: - MemoryManager() = delete; - - /** - Sets the current profile to a given one and returns a unique_ptr pointing - to the previously set profile. - - @param[in] mm_prof Pointer to a new memory manager profile - @throws std::invalid_argument if mm_prof is nullptr - */ - static inline std::unique_ptr - SwitchProfile(MMProf* &&mm_prof) noexcept - { -#ifndef _M_CEE - std::lock_guard switching_lock(switch_mutex_); -#endif - return SwitchProfileThreadUnsafe(std::move(mm_prof)); - } - - /** - Sets the current profile to a given one and returns a unique_ptr pointing - to the previously set profile. - - @param[in] mm_prof Pointer to a new memory manager profile - @throws std::invalid_argument if mm_prof is nullptr - */ - static inline std::unique_ptr SwitchProfile( - std::unique_ptr &&mm_prof) noexcept - { -#ifndef _M_CEE - std::lock_guard switch_lock(switch_mutex_); -#endif - return SwitchProfileThreadUnsafe(std::move(mm_prof)); - } - - /** - Returns a MemoryPoolHandle according to the currently set memory manager - profile and prof_opt. The following values for prof_opt have an effect - independent of the current profile: - - mm_prof_opt::FORCE_NEW: return MemoryPoolHandle::New() - mm_prof_opt::FORCE_GLOBAL: return MemoryPoolHandle::Global() - mm_prof_opt::FORCE_THREAD_LOCAL: return MemoryPoolHandle::ThreadLocal() - - Other values for prof_opt are forwarded to the current profile and, depending - on the profile, may or may not have an effect. The value mm_prof_opt::DEFAULT - will always invoke a default behavior for the current profile. - - @param[in] prof_opt A mm_prof_opt_t parameter used to provide additional - instructions to the memory manager profile for internal logic. - */ - template - SEAL_NODISCARD static inline MemoryPoolHandle GetPool( - mm_prof_opt_t prof_opt, Args &&...args) - { - switch (prof_opt) - { - case mm_prof_opt::FORCE_GLOBAL: - return MemoryPoolHandle::Global(); - - case mm_prof_opt::FORCE_NEW: - return MemoryPoolHandle::New(std::forward(args)...); -#ifndef _M_CEE - case mm_prof_opt::FORCE_THREAD_LOCAL: - return MemoryPoolHandle::ThreadLocal(); -#endif - default: -#ifdef SEAL_DEBUG - { - auto pool = mm_prof_->get_pool(prof_opt); - if (!pool) - { - throw std::logic_error("cannot return uninitialized pool"); - } - return pool; - } -#endif - return mm_prof_->get_pool(prof_opt); - } - } - - SEAL_NODISCARD static inline MemoryPoolHandle GetPool() - { - return GetPool(mm_prof_opt::DEFAULT); - } - - private: - SEAL_NODISCARD static inline std::unique_ptr - SwitchProfileThreadUnsafe( - MMProf* &&mm_prof) - { - if (!mm_prof) - { - throw std::invalid_argument("mm_prof cannot be nullptr"); - } - auto ret_mm_prof = std::move(mm_prof_); - mm_prof_.reset(mm_prof); - return ret_mm_prof; - } - - SEAL_NODISCARD static inline std::unique_ptr - SwitchProfileThreadUnsafe( - std::unique_ptr &&mm_prof) - { - if (!mm_prof) - { - throw std::invalid_argument("mm_prof cannot be nullptr"); - } - std::swap(mm_prof_, mm_prof); - return std::move(mm_prof); - } - - static std::unique_ptr mm_prof_; -#ifndef _M_CEE - static std::mutex switch_mutex_; -#endif - }; -#ifndef _M_CEE - /** - Class for a scoped switch of memory manager profile. This class acts as a scoped - "guard" for changing the memory manager profile so that the programmer does - not have to explicitly switch back afterwards and that other threads cannot - change the MMProf. It can also help with exception safety by guaranteeing that - the profile is switched back to the original if a function throws an exception - after changing the profile for local use. - */ - class MMProfGuard - { - public: - /** - Creates a new MMProfGuard. If start_locked is true, this function will - attempt to lock the MemoryManager for profile switch to mm_prof, perform - the switch, and keep the lock until unlocked or destroyed. If start_lock - is false, mm_prof will be stored but the switch will not be performed and - a lock will not be obtained until lock() is explicitly called. - - @param[in] mm_prof Pointer to a new memory manager profile - @param[in] start_locked Bool indicating whether the lock should be - immediately obtained (true by default) - */ - MMProfGuard(std::unique_ptr &&mm_prof, - bool start_locked = true) noexcept : - mm_switch_lock_(MemoryManager::switch_mutex_, std::defer_lock) - { - if (start_locked) - { - lock(std::move(mm_prof)); - } - else - { - old_prof_ = std::move(mm_prof); - } - } - - /** - Creates a new MMProfGuard. If start_locked is true, this function will - attempt to lock the MemoryManager for profile switch to mm_prof, perform - the switch, and keep the lock until unlocked or destroyed. If start_lock - is false, mm_prof will be stored but the switch will not be performed and - a lock will not be obtained until lock() is explicitly called. - - @param[in] mm_prof Pointer to a new memory manager profile - @param[in] start_locked Bool indicating whether the lock should be - immediately obtained (true by default) - */ - MMProfGuard(MMProf* &&mm_prof, - bool start_locked = true) noexcept : - mm_switch_lock_(MemoryManager::switch_mutex_, std::defer_lock) - { - if (start_locked) - { - lock(std::move(mm_prof)); - } - else - { - old_prof_.reset(std::move(mm_prof)); - } - } - - /** - Attempts to lock the MemoryManager for profile switch, perform the switch - to currently stored memory manager profile, store the previously held profile, - and keep the lock until unlocked or destroyed. If the lock cannot be obtained - on the first attempt, the function returns false; otherwise returns true. - - @throws std::runtime_error if the lock is already owned - */ - inline bool try_lock() - { - if (mm_switch_lock_.owns_lock()) - { - throw std::runtime_error("lock is already owned"); - } - if (!mm_switch_lock_.try_lock()) - { - return false; - } - old_prof_ = MemoryManager::SwitchProfileThreadUnsafe( - std::move(old_prof_)); - return true; - } - - /** - Locks the MemoryManager for profile switch, performs the switch to currently - stored memory manager profile, stores the previously held profile, and - keep the lock until unlocked or destroyed. The calling thread will block - until the lock can be obtained. - - @throws std::runtime_error if the lock is already owned - */ -#ifdef _MSC_VER - _Acquires_lock_(mm_switch_lock_) -#endif - inline void lock() - { - if (mm_switch_lock_.owns_lock()) - { - throw std::runtime_error("lock is already owned"); - } - mm_switch_lock_.lock(); - old_prof_ = MemoryManager::SwitchProfileThreadUnsafe( - std::move(old_prof_)); - } - - /** - Attempts to lock the MemoryManager for profile switch, perform the switch - to the given memory manager profile, store the previously held profile, - and keep the lock until unlocked or destroyed. If the lock cannot be - obtained on the first attempt, the function returns false; otherwise - returns true. - - @param[in] mm_prof Pointer to a new memory manager profile - @throws std::runtime_error if the lock is already owned - */ - inline bool try_lock( - std::unique_ptr &&mm_prof) - { - if (mm_switch_lock_.owns_lock()) - { - throw std::runtime_error("lock is already owned"); - } - if (!mm_switch_lock_.try_lock()) - { - return false; - } - old_prof_ = MemoryManager::SwitchProfileThreadUnsafe( - std::move(mm_prof)); - return true; - } - - /** - Locks the MemoryManager for profile switch, performs the switch to the given - memory manager profile, stores the previously held profile, and keep the - lock until unlocked or destroyed. The calling thread will block until the - lock can be obtained. - - @param[in] mm_prof Pointer to a new memory manager profile - @throws std::runtime_error if the lock is already owned - */ -#ifdef _MSC_VER - _Acquires_lock_(mm_switch_lock_) -#endif - inline void lock( - std::unique_ptr &&mm_prof) - { - if (mm_switch_lock_.owns_lock()) - { - throw std::runtime_error("lock is already owned"); - } - mm_switch_lock_.lock(); - old_prof_ = MemoryManager::SwitchProfileThreadUnsafe( - std::move(mm_prof)); - } - - /** - Attempts to lock the MemoryManager for profile switch, perform the switch - to the given memory manager profile, store the previously held profile, - and keep the lock until unlocked or destroyed. If the lock cannot be - obtained on the first attempt, the function returns false; otherwise returns - true. - - @param[in] mm_prof Pointer to a new memory manager profile - @throws std::runtime_error if the lock is already owned - */ - inline bool try_lock(MMProf* &&mm_prof) - { - if (mm_switch_lock_.owns_lock()) - { - throw std::runtime_error("lock is already owned"); - } - if (!mm_switch_lock_.try_lock()) - { - return false; - } - old_prof_ = MemoryManager::SwitchProfileThreadUnsafe( - std::move(mm_prof)); - return true; - } - - /** - Locks the MemoryManager for profile switch, performs the switch to the - given memory manager profile, stores the previously held profile, and keep - the lock until unlocked or destroyed. The calling thread will block until - the lock can be obtained. - - @param[in] mm_prof Pointer to a new memory manager profile - @throws std::runtime_error if the lock is already owned - */ -#ifdef _MSC_VER - _Acquires_lock_(mm_switch_lock_) -#endif - inline void lock(MMProf* &&mm_prof) - { - if (mm_switch_lock_.owns_lock()) - { - throw std::runtime_error("lock is already owned"); - } - mm_switch_lock_.lock(); - old_prof_ = MemoryManager::SwitchProfileThreadUnsafe( - std::move(mm_prof)); - } - - /** - Releases the memory manager profile switch lock for MemoryManager, stores - the current profile, and resets the profile to the one used before locking. - - @throws std::runtime_error if the lock is not owned - */ -#ifdef _MSC_VER - _Releases_lock_(mm_switch_lock_) -#endif - inline void unlock() - { - if (!mm_switch_lock_.owns_lock()) - { - throw std::runtime_error("lock is not owned"); - } - old_prof_ = MemoryManager::SwitchProfileThreadUnsafe( - std::move(old_prof_)); - mm_switch_lock_.unlock(); - } - - /** - Destroys the MMProfGuard. If the memory manager profile switch lock is - owned, releases the lock, and resets the profile to the one used before - locking. - */ - ~MMProfGuard() - { - if (mm_switch_lock_.owns_lock()) - { - old_prof_ = MemoryManager::SwitchProfileThreadUnsafe( - std::move(old_prof_)); - mm_switch_lock_.unlock(); - } - } - - /** - Returns whether the current MMProfGuard owns the memory manager profile - switch lock. - */ - inline bool owns_lock() noexcept - { - return mm_switch_lock_.owns_lock(); - } - - private: - std::unique_ptr old_prof_; - - std::unique_lock mm_switch_lock_; - }; -#endif -} diff --git a/SEAL/native/src/seal/modulus.cpp b/SEAL/native/src/seal/modulus.cpp deleted file mode 100644 index 5dce6a7..0000000 --- a/SEAL/native/src/seal/modulus.cpp +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include -#include -#include -#include "seal/modulus.h" -#include "seal/util/numth.h" -#include "seal/util/common.h" - -using namespace std; -using namespace seal::util; - -namespace seal -{ - vector CoeffModulus::BFVDefault( - size_t poly_modulus_degree, sec_level_type sec_level) - { - if (!MaxBitCount(poly_modulus_degree, sec_level)) - { - throw invalid_argument("non-standard poly_modulus_degree"); - } - if (sec_level == sec_level_type::none) - { - throw invalid_argument("invalid security level"); - } - - switch (sec_level) - { - case sec_level_type::tc128: - return global_variables::default_coeff_modulus_128. - at(poly_modulus_degree); - - case sec_level_type::tc192: - return global_variables::default_coeff_modulus_192. - at(poly_modulus_degree); - - case sec_level_type::tc256: - return global_variables::default_coeff_modulus_256. - at(poly_modulus_degree); - - default: - throw runtime_error("invalid security level"); - } - } - - vector CoeffModulus::Create( - size_t poly_modulus_degree, vector bit_sizes) - { - if (poly_modulus_degree > SEAL_POLY_MOD_DEGREE_MAX || - poly_modulus_degree < SEAL_POLY_MOD_DEGREE_MIN || - get_power_of_two(static_cast(poly_modulus_degree)) < 0) - { - throw invalid_argument("poly_modulus_degree is invalid"); - } - if (bit_sizes.size() > SEAL_COEFF_MOD_COUNT_MAX) - { - throw invalid_argument("bit_sizes is invalid"); - } - if (accumulate(bit_sizes.cbegin(), bit_sizes.cend(), - SEAL_USER_MOD_BIT_COUNT_MIN, [](int a, int b) { - return max(a, b); }) > SEAL_USER_MOD_BIT_COUNT_MAX || - accumulate(bit_sizes.cbegin(), bit_sizes.cend(), - SEAL_USER_MOD_BIT_COUNT_MAX, [](int a, int b) { - return min(a, b); }) < SEAL_USER_MOD_BIT_COUNT_MIN) - { - throw invalid_argument("bit_sizes is invalid"); - } - - unordered_map count_table; - unordered_map> prime_table; - for (int size : bit_sizes) - { - ++count_table[size]; - } - for (const auto &table_elt : count_table) - { - prime_table[table_elt.first] = get_primes( - poly_modulus_degree, table_elt.first, table_elt.second); - } - - vector result; - for (int size : bit_sizes) - { - result.emplace_back(prime_table[size].back()); - prime_table[size].pop_back(); - } - return result; - } -} \ No newline at end of file diff --git a/SEAL/native/src/seal/modulus.h b/SEAL/native/src/seal/modulus.h deleted file mode 100644 index bc7b25c..0000000 --- a/SEAL/native/src/seal/modulus.h +++ /dev/null @@ -1,185 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include -#include "seal/smallmodulus.h" -#include "seal/util/hestdparms.h" - -namespace seal -{ - /** - Represents a standard security level according to the HomomorphicEncryption.org - security standard. The value sec_level_type::none signals that no standard - security level should be imposed. The value sec_level_type::tc128 provides - a very high level of security and is the default security level enforced by - Microsoft SEAL when constructing a SEALContext object. Normal users should not - have to specify the security level explicitly anywhere. - */ - enum class sec_level_type : int - { - /** - No security level specified. - */ - none = 0, - - /** - 128-bit security level according to HomomorphicEncryption.org standard. - */ - tc128 = 128, - - /** - 192-bit security level according to HomomorphicEncryption.org standard. - */ - tc192 = 192, - - /** - 256-bit security level according to HomomorphicEncryption.org standard. - */ - tc256 = 256 - }; - - /** - This class contains static methods for creating a coefficient modulus easily. - Note that while these functions take a sec_level_type argument, all security - guarantees are lost if the output is used with encryption parameters with - a mismatching value for the poly_modulus_degree. - - The default value sec_level_type::tc128 provides a very high level of security - and is the default security level enforced by Microsoft SEAL when constructing - a SEALContext object. Normal users should not have to specify the security - level explicitly anywhere. - */ - class CoeffModulus - { - public: - CoeffModulus() = delete; - - /** - Returns the largest bit-length of the coefficient modulus, i.e., bit-length - of the product of the primes in the coefficient modulus, that guarantees - a given security level when using a given poly_modulus_degree, according - to the HomomorphicEncryption.org security standard. - - @param[in] poly_modulus_degree The value of the poly_modulus_degree - encryption parameter - @param[in] sec_level The desired standard security level - */ - SEAL_NODISCARD static constexpr int MaxBitCount( - std::size_t poly_modulus_degree, - sec_level_type sec_level = sec_level_type::tc128) noexcept - { - switch (sec_level) - { - case sec_level_type::tc128: - return util::SEAL_HE_STD_PARMS_128_TC(poly_modulus_degree); - - case sec_level_type::tc192: - return util::SEAL_HE_STD_PARMS_192_TC(poly_modulus_degree); - - case sec_level_type::tc256: - return util::SEAL_HE_STD_PARMS_256_TC(poly_modulus_degree); - - case sec_level_type::none: - return std::numeric_limits::max(); - - default: - return 0; - } - } - - /** - Returns a default coefficient modulus for the BFV scheme that guarantees - a given security level when using a given poly_modulus_degree, according - to the HomomorphicEncryption.org security standard. Note that all security - guarantees are lost if the output is used with encryption parameters with - a mismatching value for the poly_modulus_degree. - - The coefficient modulus returned by this function will not perform well - if used with the CKKS scheme. - - @param[in] poly_modulus_degree The value of the poly_modulus_degree - encryption parameter - @param[in] sec_level The desired standard security level - @throws std::invalid_argument if poly_modulus_degree is not a power-of-two - or is too large - @throws std::invalid_argument if sec_level is sec_level_type::none - */ - SEAL_NODISCARD static std::vector BFVDefault( - std::size_t poly_modulus_degree, - sec_level_type sec_level = sec_level_type::tc128); - - /** - Returns a custom coefficient modulus suitable for use with the specified - poly_modulus_degree. The return value will be a vector consisting of - SmallModulus elements representing distinct prime numbers of bit-lengths - as given in the bit_sizes parameter. The bit sizes of the prime numbers - can be at most 60 bits. - - @param[in] poly_modulus_degree The value of the poly_modulus_degree - encryption parameter - @param[in] bit_sizes The bit-lengths of the primes to be generated - @throws std::invalid_argument if poly_modulus_degree is not a power-of-two - or is too large - @throws std::invalid_argument if bit_sizes is too large or if its elements - are out of bounds - @throws std::logic_error if not enough suitable primes could be found - */ - SEAL_NODISCARD static std::vector Create( - std::size_t poly_modulus_degree, - std::vector bit_sizes); - }; - - /** - This class contains static methods for creating a plaintext modulus easily. - */ - class PlainModulus - { - public: - PlainModulus() = delete; - - /** - Creates a prime number SmallModulus for use as plain_modulus encryption - parameter that supports batching with a given poly_modulus_degree. - - @param[in] poly_modulus_degree The value of the poly_modulus_degree - encryption parameter - @param[in] bit_size The bit-length of the prime to be generated - @throws std::invalid_argument if poly_modulus_degree is not a power-of-two - or is too large - @throws std::invalid_argument if bit_size is out of bounds - @throws std::logic_error if a suitable prime could not be found - */ - SEAL_NODISCARD static inline SmallModulus Batching( - std::size_t poly_modulus_degree, - int bit_size) - { - return CoeffModulus::Create(poly_modulus_degree, { bit_size })[0]; - } - - - /** - Creates several prime number SmallModulus elements that can be used as - plain_modulus encryption parameters, each supporting batching with a given - poly_modulus_degree. - - @param[in] poly_modulus_degree The value of the poly_modulus_degree - encryption parameter - @param[in] bit_sizes The bit-lengths of the primes to be generated - @throws std::invalid_argument if poly_modulus_degree is not a power-of-two - or is too large - @throws std::invalid_argument if bit_sizes is too large or if its elements - are out of bounds - @throws std::logic_error if not enough suitable primes could be found - */ - SEAL_NODISCARD static inline std::vector Batching( - std::size_t poly_modulus_degree, - std::vector bit_sizes) - { - return CoeffModulus::Create(poly_modulus_degree, bit_sizes); - } - }; -} diff --git a/SEAL/native/src/seal/plaintext.cpp b/SEAL/native/src/seal/plaintext.cpp deleted file mode 100644 index 42d6870..0000000 --- a/SEAL/native/src/seal/plaintext.cpp +++ /dev/null @@ -1,261 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "seal/plaintext.h" -#include "seal/util/common.h" - -using namespace std; -using namespace seal::util; - -namespace seal -{ - namespace - { - bool is_dec_char(char c) - { - return c >= '0' && c <= '9'; - } - - int get_dec_value(char c) - { - return c - '0'; - } - - int get_coeff_length(const char *poly) - { - int length = 0; - while (is_hex_char(*poly)) - { - length++; - poly++; - } - return length; - } - - int get_coeff_power(const char *poly, int *power_length) - { - int length = 0; - if (*poly == '\0') - { - *power_length = 0; - return 0; - } - if (*poly != 'x') - { - return -1; - } - poly++; - length++; - - if (*poly != '^') - { - return -1; - } - poly++; - length++; - - int power = 0; - while (is_dec_char(*poly)) - { - power *= 10; - power += get_dec_value(*poly); - poly++; - length++; - } - *power_length = length; - return power; - } - - int get_plus(const char *poly) - { - if (*poly == '\0') - { - return 0; - } - if (*poly++ != ' ') - { - return -1; - } - if (*poly++ != '+') - { - return -1; - } - if (*poly != ' ') - { - return -1; - } - return 3; - } - } - - Plaintext &Plaintext::operator =(const string &hex_poly) - { - if (is_ntt_form()) - { - throw logic_error("cannot set an NTT transformed Plaintext"); - } - if (unsigned_gt(hex_poly.size(), numeric_limits::max())) - { - throw invalid_argument("hex_poly too long"); - } - int length = safe_cast(hex_poly.size()); - - // Determine size needed to store string coefficient. - int assign_coeff_count = 0; - - int assign_coeff_bit_count = 0; - int pos = 0; - int last_power = safe_cast( - min(data_.max_size(), safe_cast(numeric_limits::max()))); - const char *hex_poly_ptr = hex_poly.data(); - while (pos < length) - { - // Determine length of coefficient starting at pos. - int coeff_length = get_coeff_length(hex_poly_ptr + pos); - if (coeff_length == 0) - { - throw invalid_argument("unable to parse hex_poly"); - } - - // Determine bit length of coefficient. - int coeff_bit_count = - get_hex_string_bit_count(hex_poly_ptr + pos, coeff_length); - if (coeff_bit_count > assign_coeff_bit_count) - { - assign_coeff_bit_count = coeff_bit_count; - } - pos += coeff_length; - - // Extract power-term. - int power_length = 0; - int power = get_coeff_power(hex_poly_ptr + pos, &power_length); - if (power == -1 || power >= last_power) - { - throw invalid_argument("unable to parse hex_poly"); - } - if (assign_coeff_count == 0) - { - assign_coeff_count = power + 1; - } - pos += power_length; - last_power = power; - - // Extract plus (unless it is the end). - int plus_length = get_plus(hex_poly_ptr + pos); - if (plus_length == -1) - { - throw invalid_argument("unable to parse hex_poly"); - } - pos += plus_length; - } - - // If string is empty, then done. - if (assign_coeff_count == 0 || assign_coeff_bit_count == 0) - { - set_zero(); - return *this; - } - - // Resize polynomial. - if (assign_coeff_bit_count > bits_per_uint64) - { - throw invalid_argument("hex_poly has too large coefficients"); - } - resize(safe_cast(assign_coeff_count)); - - // Populate polynomial from string. - pos = 0; - last_power = safe_cast(coeff_count()); - while (pos < length) - { - // Determine length of coefficient starting at pos. - const char *coeff_start = hex_poly_ptr + pos; - int coeff_length = get_coeff_length(coeff_start); - pos += coeff_length; - - // Extract power-term. - int power_length = 0; - int power = get_coeff_power(hex_poly_ptr + pos, &power_length); - pos += power_length; - - // Extract plus (unless it is the end). - int plus_length = get_plus(hex_poly_ptr + pos); - pos += plus_length; - - // Zero coefficients not set by string. - for (int zero_power = last_power - 1; zero_power > power; --zero_power) - { - data_[static_cast(zero_power)] = 0; - } - - // Populate coefficient. - uint64_t *coeff_ptr = data_.begin() + power; - hex_string_to_uint(coeff_start, coeff_length, size_t(1), coeff_ptr); - last_power = power; - } - - // Zero coefficients not set by string. - for (int zero_power = last_power - 1; zero_power >= 0; --zero_power) - { - data_[static_cast(zero_power)] = 0; - } - - return *this; - } - - void Plaintext::save(ostream &stream) const - { - auto old_except_mask = stream.exceptions(); - try - { - // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit - stream.exceptions(ios_base::badbit | ios_base::failbit); - - stream.write(reinterpret_cast(&parms_id_), sizeof(parms_id_type)); - stream.write(reinterpret_cast(&scale_), sizeof(double)); - data_.save(stream); - } - catch (const exception &) - { - stream.exceptions(old_except_mask); - throw; - } - - stream.exceptions(old_except_mask); - } - - void Plaintext::unsafe_load(istream &stream) - { - Plaintext new_data(data_.pool()); - - auto old_except_mask = stream.exceptions(); - try - { - // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit - stream.exceptions(ios_base::badbit | ios_base::failbit); - - parms_id_type parms_id{}; - stream.read(reinterpret_cast(&parms_id), sizeof(parms_id_type)); - - double scale = 0; - stream.read(reinterpret_cast(&scale), sizeof(double)); - - // Load the data - new_data.data_.load(stream); - - // Set the parms_id - new_data.parms_id_ = parms_id; - - // Set the scale - new_data.scale_ = scale; - } - catch (const exception &) - { - stream.exceptions(old_except_mask); - throw; - } - stream.exceptions(old_except_mask); - - swap(*this, new_data); - } -} diff --git a/SEAL/native/src/seal/plaintext.h b/SEAL/native/src/seal/plaintext.h deleted file mode 100644 index cf63826..0000000 --- a/SEAL/native/src/seal/plaintext.h +++ /dev/null @@ -1,680 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include "seal/util/common.h" -#include "seal/util/polycore.h" -#include "seal/util/defines.h" -#include "seal/memorymanager.h" -#include "seal/encryptionparams.h" -#include "seal/intarray.h" -#include "seal/context.h" -#include "seal/valcheck.h" - -namespace seal -{ - /** - Class to store a plaintext element. The data for the plaintext is a polynomial - with coefficients modulo the plaintext modulus. The degree of the plaintext - polynomial must be one less than the degree of the polynomial modulus. The - backing array always allocates one 64-bit word per each coefficient of the - polynomial. - - @par Memory Management - The coefficient count of a plaintext refers to the number of word-size - coefficients in the plaintext, whereas its capacity refers to the number of - word-size coefficients that fit in the current memory allocation. In high- - performance applications unnecessary re-allocations should be avoided by - reserving enough memory for the plaintext to begin with either by providing - the desired capacity to the constructor as an extra argument, or by calling - the reserve function at any time. - - When the scheme is scheme_type::BFV each coefficient of a plaintext is a 64-bit - word, but when the scheme is scheme_type::CKKS the plaintext is by default - stored in an NTT transformed form with respect to each of the primes in the - coefficient modulus. Thus, the size of the allocation that is needed is the - size of the coefficient modulus (number of primes) times the degree of the - polynomial modulus. In addition, a valid CKKS plaintext also store the parms_id - for the corresponding encryption parameters. - - @par Thread Safety - In general, reading from plaintext is thread-safe as long as no other thread - is concurrently mutating it. This is due to the underlying data structure - storing the plaintext not being thread-safe. - - @see Ciphertext for the class that stores ciphertexts. - */ - class Plaintext - { - public: - using pt_coeff_type = std::uint64_t; - - using size_type = IntArray::size_type; - - /** - Constructs an empty plaintext allocating no memory. - - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if pool is uninitialized - */ - Plaintext(MemoryPoolHandle pool = MemoryManager::GetPool()) : - data_(std::move(pool)) - { - } - - /** - Constructs a plaintext representing a constant polynomial 0. The coefficient - count of the polynomial is set to the given value. The capacity is set to - the same value. - - @param[in] coeff_count The number of (zeroed) coefficients in the plaintext - polynomial - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if coeff_count is negative - @throws std::invalid_argument if pool is uninitialized - */ - explicit Plaintext(size_type coeff_count, - MemoryPoolHandle pool = MemoryManager::GetPool()) : - data_(coeff_count, std::move(pool)) - { - } - - /** - Constructs a plaintext representing a constant polynomial 0. The coefficient - count of the polynomial and the capacity are set to the given values. - - @param[in] capacity The capacity - @param[in] coeff_count The number of (zeroed) coefficients in the plaintext - polynomial - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if capacity is less than coeff_count - @throws std::invalid_argument if coeff_count is negative - @throws std::invalid_argument if pool is uninitialized - */ - explicit Plaintext(size_type capacity, size_type coeff_count, - MemoryPoolHandle pool = MemoryManager::GetPool()) : - data_(capacity, coeff_count, std::move(pool)) - { - } - - /** - Constructs a plaintext from a given hexadecimal string describing the - plaintext polynomial. - - The string description of the polynomial must adhere to the format returned - by to_string(), - which is of the form "7FFx^3 + 1x^1 + 3" and summarized by the following - rules: - 1. Terms are listed in order of strictly decreasing exponent - 2. Coefficient values are non-negative and in hexadecimal format (upper - and lower case letters are both supported) - 3. Exponents are positive and in decimal format - 4. Zero coefficient terms (including the constant term) may be (but do - not have to be) omitted - 5. Term with the exponent value of one must be exactly written as x^1 - 6. Term with the exponent value of zero (the constant term) must be written - as just a hexadecimal number without exponent - 7. Terms must be separated by exactly + and minus is not - allowed - 8. Other than the +, no other terms should have whitespace - - @param[in] hex_poly The formatted polynomial string specifying the plaintext - polynomial - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if hex_poly does not adhere to the expected - format - @throws std::invalid_argument if pool is uninitialized - */ - Plaintext(const std::string &hex_poly, - MemoryPoolHandle pool = MemoryManager::GetPool()) : - data_(std::move(pool)) - { - operator =(hex_poly); - } - - /** - Constructs a new plaintext by copying a given one. - - @param[in] copy The plaintext to copy from - */ - Plaintext(const Plaintext ©) = default; - - /** - Constructs a new plaintext by moving a given one. - - @param[in] source The plaintext to move from - */ - Plaintext(Plaintext &&source) = default; - - /** - Constructs a new plaintext by copying a given one. - - @param[in] copy The plaintext to copy from - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if pool is uninitialized - */ - Plaintext(const Plaintext ©, MemoryPoolHandle pool) : - Plaintext(std::move(pool)) - { - *this = copy; - } - - /** - Allocates enough memory to accommodate the backing array of a plaintext - with given capacity. - - @param[in] capacity The capacity - @throws std::invalid_argument if capacity is negative - @throws std::logic_error if the plaintext is NTT transformed - */ - void reserve(size_type capacity) - { - if (is_ntt_form()) - { - throw std::logic_error("cannot reserve for an NTT transformed Plaintext"); - } - data_.reserve(capacity); - } - - /** - Allocates enough memory to accommodate the backing array of the current - plaintext and copies it over to the new location. This function is meant - to reduce the memory use of the plaintext to smallest possible and can be - particularly important after modulus switching. - */ - inline void shrink_to_fit() - { - data_.shrink_to_fit(); - } - - /** - Resets the plaintext. This function releases any memory allocated by the - plaintext, returning it to the memory pool. - */ - inline void release() noexcept - { - parms_id_ = parms_id_zero; - scale_ = 1.0; - data_.release(); - } - - /** - Resizes the plaintext to have a given coefficient count. The plaintext - is automatically reallocated if the new coefficient count does not fit in - the current capacity. - - @param[in] coeff_count The number of coefficients in the plaintext polynomial - @throws std::invalid_argument if coeff_count is negative - @throws std::logic_error if the plaintext is NTT transformed - */ - inline void resize(size_type coeff_count) - { - if (is_ntt_form()) - { - throw std::logic_error("cannot reserve for an NTT transformed Plaintext"); - } - data_.resize(coeff_count); - } - - /** - Copies a given plaintext to the current one. - - @param[in] assign The plaintext to copy from - */ - Plaintext &operator =(const Plaintext &assign) = default; - - /** - Moves a given plaintext to the current one. - - @param[in] assign The plaintext to move from - */ - Plaintext &operator =(Plaintext &&assign) = default; - - /** - Sets the value of the current plaintext to the polynomial represented by - the a given hexadecimal string. - - The string description of the polynomial must adhere to the format returned - by to_string(), which is of the form "7FFx^3 + 1x^1 + 3" and summarized - by the following rules: - 1. Terms are listed in order of strictly decreasing exponent - 2. Coefficient values are non-negative and in hexadecimal format (upper - and lower case letters are both supported) - 3. Exponents are positive and in decimal format - 4. Zero coefficient terms (including the constant term) may be (but do - not have to be) omitted - 5. Term with the exponent value of one must be exactly written as x^1 - 6. Term with the exponent value of zero (the constant term) must be - written as just a hexadecimal number without exponent - 7. Terms must be separated by exactly + and minus is not - allowed - 8. Other than the +, no other terms should have whitespace - - @param[in] hex_poly The formatted polynomial string specifying the plaintext - polynomial - @throws std::invalid_argument if hex_poly does not adhere to the expected - format - @throws std::invalid_argument if the coefficients of hex_poly are too wide - */ - Plaintext &operator =(const std::string &hex_poly); - - /** - Sets the value of the current plaintext to a given constant polynomial. - The coefficient count is set to one. - - @param[in] const_coeff The constant coefficient - @throws std::logic_error if the plaintext is NTT transformed - */ - Plaintext &operator =(pt_coeff_type const_coeff) - { - data_.resize(1); - data_[0] = const_coeff; - return *this; - } - - /** - Sets a given range of coefficients of a plaintext polynomial to zero; does - nothing if length is zero. - - @param[in] start_coeff The index of the first coefficient to set to zero - @param[in] length The number of coefficients to set to zero - @throws std::out_of_range if start_coeff + length - 1 is not within [0, coeff_count) - */ - inline void set_zero(size_type start_coeff, size_type length) - { - if (!length) - { - return; - } - if (start_coeff + length - 1 >= coeff_count()) - { - throw std::out_of_range("length must be non-negative and start_coeff + length - 1 must be within [0, coeff_count)"); - } - std::fill_n(data_.begin() + start_coeff, length, pt_coeff_type(0)); - } - - /** - Sets the plaintext polynomial coefficients to zero starting at a given index. - - @param[in] start_coeff The index of the first coefficient to set to zero - @throws std::out_of_range if start_coeff is not within [0, coeff_count) - */ - inline void set_zero(size_type start_coeff) - { - if (start_coeff >= coeff_count()) - { - throw std::out_of_range("start_coeff must be within [0, coeff_count)"); - } - std::fill(data_.begin() + start_coeff, data_.end(), pt_coeff_type(0)); - } - - /** - Sets the plaintext polynomial to zero. - */ - inline void set_zero() - { - std::fill(data_.begin(), data_.end(), pt_coeff_type(0)); - } - - /** - Returns a pointer to the beginning of the plaintext polynomial. - */ - SEAL_NODISCARD inline pt_coeff_type *data() - { - return data_.begin(); - } - - /** - Returns a const pointer to the beginning of the plaintext polynomial. - */ - SEAL_NODISCARD inline const pt_coeff_type *data() const - { - return data_.cbegin(); - } -#ifdef SEAL_USE_MSGSL_SPAN - /** - Returns a span pointing to the beginning of the text polynomial. - */ - SEAL_NODISCARD inline gsl::span data_span() - { - return gsl::span(data_.begin(), - static_cast(coeff_count())); - } - - /** - Returns a span pointing to the beginning of the text polynomial. - */ - SEAL_NODISCARD inline gsl::span data_span() const - { - return gsl::span(data_.cbegin(), - static_cast(coeff_count())); - } -#endif - /** - Returns a pointer to a given coefficient of the plaintext polynomial. - - @param[in] coeff_index The index of the coefficient in the plaintext polynomial - @throws std::out_of_range if coeff_index is not within [0, coeff_count) - */ - SEAL_NODISCARD inline pt_coeff_type *data(size_type coeff_index) - { - if (coeff_count() == 0) - { - return nullptr; - } - if (coeff_index >= coeff_count()) - { - throw std::out_of_range("coeff_index must be within [0, coeff_count)"); - } - return data_.begin() + coeff_index; - } - - /** - Returns a const pointer to a given coefficient of the plaintext polynomial. - - @param[in] coeff_index The index of the coefficient in the plaintext polynomial - */ - SEAL_NODISCARD inline const pt_coeff_type *data( - size_type coeff_index) const - { - if (coeff_count() == 0) - { - return nullptr; - } - if (coeff_index >= coeff_count()) - { - throw std::out_of_range("coeff_index must be within [0, coeff_count)"); - } - return data_.cbegin() + coeff_index; - } - - /** - Returns a const reference to a given coefficient of the plaintext polynomial. - - @param[in] coeff_index The index of the coefficient in the plaintext polynomial - @throws std::out_of_range if coeff_index is not within [0, coeff_count) - */ - SEAL_NODISCARD inline const pt_coeff_type &operator []( - size_type coeff_index) const - { - return data_.at(coeff_index); - } - - /** - Returns a reference to a given coefficient of the plaintext polynomial. - - @param[in] coeff_index The index of the coefficient in the plaintext polynomial - @throws std::out_of_range if coeff_index is not within [0, coeff_count) - */ - SEAL_NODISCARD inline pt_coeff_type &operator []( - size_type coeff_index) - { - return data_.at(coeff_index); - } - - /** - Returns whether or not the plaintext has the same semantic value as a given - plaintext. Leading zero coefficients are ignored by the comparison. - - @param[in] compare The plaintext to compare against - */ - SEAL_NODISCARD inline bool operator ==(const Plaintext &compare) const - { - std::size_t sig_coeff_count = significant_coeff_count(); - std::size_t sig_coeff_count_compare = compare.significant_coeff_count(); - bool parms_id_compare = (is_ntt_form() && compare.is_ntt_form() - && (parms_id_ == compare.parms_id_)) || - (!is_ntt_form() && !compare.is_ntt_form()); - return parms_id_compare - && (sig_coeff_count == sig_coeff_count_compare) - && std::equal(data_.cbegin(), - data_.cbegin() + sig_coeff_count, - compare.data_.cbegin(), - compare.data_.cbegin() + sig_coeff_count) - && std::all_of(data_.cbegin() + sig_coeff_count, - data_.cend(), util::is_zero) - && std::all_of(compare.data_.cbegin() + sig_coeff_count, - compare.data_.cend(), util::is_zero) - && util::are_close(scale_, compare.scale_); - } - - /** - Returns whether or not the plaintext has a different semantic value than - a given plaintext. Leading zero coefficients are ignored by the comparison. - - @param[in] compare The plaintext to compare against - */ - SEAL_NODISCARD inline bool operator !=(const Plaintext &compare) const - { - return !operator ==(compare); - } - - /** - Returns whether the current plaintext polynomial has all zero coefficients. - */ - SEAL_NODISCARD inline bool is_zero() const - { - return (coeff_count() == 0) || - std::all_of(data_.cbegin(), data_.cend(), - util::is_zero); - } - - /** - Returns the capacity of the current allocation. - */ - SEAL_NODISCARD inline size_type capacity() const noexcept - { - return data_.capacity(); - } - - /** - Returns the coefficient count of the current plaintext polynomial. - */ - SEAL_NODISCARD inline size_type coeff_count() const noexcept - { - return data_.size(); - } - - /** - Returns the significant coefficient count of the current plaintext polynomial. - */ - SEAL_NODISCARD inline size_type significant_coeff_count() const - { - if (coeff_count() == 0) - { - return 0; - } - return util::get_significant_uint64_count_uint(data_.cbegin(), coeff_count()); - } - - /** - Returns the non-zero coefficient count of the current plaintext polynomial. - */ - SEAL_NODISCARD inline size_type nonzero_coeff_count() const - { - if (coeff_count() == 0) - { - return 0; - } - return util::get_nonzero_uint64_count_uint(data_.cbegin(), coeff_count()); - } - - /** - Returns a human-readable string description of the plaintext polynomial. - - The returned string is of the form "7FFx^3 + 1x^1 + 3" with a format - summarized by the following: - 1. Terms are listed in order of strictly decreasing exponent - 2. Coefficient values are non-negative and in hexadecimal format (hexadecimal - letters are in upper-case) - 3. Exponents are positive and in decimal format - 4. Zero coefficient terms (including the constant term) are omitted unless - the polynomial is exactly 0 (see rule 9) - 5. Term with the exponent value of one is written as x^1 - 6. Term with the exponent value of zero (the constant term) is written as - just a hexadecimal number without x or exponent - 7. Terms are separated exactly by + - 8. Other than the +, no other terms have whitespace - 9. If the polynomial is exactly 0, the string "0" is returned - - @throws std::invalid_argument if the plaintext is in NTT transformed form - */ - SEAL_NODISCARD inline std::string to_string() const - { - if (is_ntt_form()) - { - throw std::invalid_argument("cannot convert NTT transformed plaintext to string"); - } - return util::poly_to_hex_string(data_.cbegin(), coeff_count(), 1); - } - - /** - Saves the plaintext to an output stream. The output is in binary format - and not human-readable. The output stream must have the "binary" flag set. - - @param[in] stream The stream to save the plaintext to - @throws std::exception if the plaintext could not be written to stream - */ - void save(std::ostream &stream) const; - - void python_save(std::string &path) const - { - try - { - std::ofstream out(path, std::ofstream::binary); - this->save(out); - out.close(); - } - catch (const std::exception &) - { - throw "Plaintext write exception"; - } - } - - /** - Loads a plaintext from an input stream overwriting the current plaintext. - No checking of the validity of the plaintext data against encryption - parameters is performed. This function should not be used unless the - plaintext comes from a fully trusted source. - - @param[in] stream The stream to load the plaintext from - @throws std::exception if a valid plaintext could not be read from stream - */ - void unsafe_load(std::istream &stream); - - void python_load(std::shared_ptr context, - std::string &path) - { - try - { - std::ifstream in(path, std::ifstream::binary); - this->load(context, in); - in.close(); - } - catch (const std::exception &) - { - throw "Plaintext read exception"; - } - } - - /** - Loads a plaintext from an input stream overwriting the current plaintext. - The loaded plaintext is verified to be valid for the given SEALContext. - - @param[in] context The SEALContext - @param[in] stream The stream to load the plaintext from - @throws std::invalid_argument if the context is not set or encryption - parameters are not valid - @throws std::exception if a valid plaintext could not be read from stream - @throws std::invalid_argument if the loaded plaintext is invalid for the - context - */ - inline void load(std::shared_ptr context, - std::istream &stream) - { - Plaintext new_data(pool()); - new_data.unsafe_load(stream); - if (!is_valid_for(new_data, std::move(context))) - { - throw std::invalid_argument("Plaintext data is invalid"); - } - std::swap(*this, new_data); - } - - /** - Returns whether the plaintext is in NTT form. - */ - SEAL_NODISCARD inline bool is_ntt_form() const noexcept - { - return (parms_id_ != parms_id_zero); - } - - /** - Returns a reference to parms_id. The parms_id must remain zero unless the - plaintext polynomial is in NTT form. - - @see EncryptionParameters for more information about parms_id. - */ - SEAL_NODISCARD inline auto &parms_id() noexcept - { - return parms_id_; - } - - /** - Returns a const reference to parms_id. The parms_id must remain zero unless - the plaintext polynomial is in NTT form. - - @see EncryptionParameters for more information about parms_id. - */ - SEAL_NODISCARD inline auto &parms_id() const noexcept - { - return parms_id_; - } - - /** - Returns a reference to the scale. This is only needed when using the CKKS - encryption scheme. The user should have little or no reason to ever change - the scale by hand. - */ - SEAL_NODISCARD inline auto &scale() noexcept - { - return scale_; - } - - /** - Returns a constant reference to the scale. This is only needed when using - the CKKS encryption scheme. - */ - SEAL_NODISCARD inline auto &scale() const noexcept - { - return scale_; - } - - /** - Returns the currently used MemoryPoolHandle. - */ - SEAL_NODISCARD inline MemoryPoolHandle pool() const noexcept - { - return data_.pool(); - } - - /** - Enables access to private members of seal::Plaintext for .NET wrapper. - */ - struct PlaintextPrivateHelper; - - private: - parms_id_type parms_id_ = parms_id_zero; - - double scale_ = 1.0; - - IntArray data_; - }; -} diff --git a/SEAL/native/src/seal/publickey.h b/SEAL/native/src/seal/publickey.h deleted file mode 100644 index d849d0e..0000000 --- a/SEAL/native/src/seal/publickey.h +++ /dev/null @@ -1,207 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include -#include "seal/ciphertext.h" -#include "seal/context.h" -#include "seal/valcheck.h" - -namespace seal -{ - /** - Class to store a public key. - - @par Thread Safety - In general, reading from PublicKey is thread-safe as long as no other thread - is concurrently mutating it. This is due to the underlying data structure - storing the public key not being thread-safe. - - @see KeyGenerator for the class that generates the public key. - @see SecretKey for the class that stores the secret key. - @see RelinKeys for the class that stores the relinearization keys. - @see GaloisKeys for the class that stores the Galois keys. - */ - class PublicKey - { - friend class KeyGenerator; - friend class KSwitchKeys; - - public: - /** - Creates an empty public key. - */ - PublicKey() = default; - - /** - Creates a new PublicKey by copying an old one. - - @param[in] copy The PublicKey to copy from - */ - PublicKey(const PublicKey ©) = default; - - /** - Creates a new PublicKey by moving an old one. - - @param[in] source The PublicKey to move from - */ - PublicKey(PublicKey &&source) = default; - - /** - Copies an old PublicKey to the current one. - - @param[in] assign The PublicKey to copy from - */ - PublicKey &operator =(const PublicKey &assign) = default; - - /** - Moves an old PublicKey to the current one. - - @param[in] assign The PublicKey to move from - */ - PublicKey &operator =(PublicKey &&assign) = default; - - /** - Returns a reference to the underlying data. - */ - SEAL_NODISCARD inline auto &data() noexcept - { - return pk_; - } - - /** - Returns a const reference to the underlying data. - */ - SEAL_NODISCARD inline auto &data() const noexcept - { - return pk_; - } - - /** - Saves the PublicKey to an output stream. The output is in binary format - and not human-readable. The output stream must have the "binary" flag set. - - @param[in] stream The stream to save the PublicKey to - @throws std::exception if the PublicKey could not be written to stream - */ - inline void save(std::ostream &stream) const - { - pk_.save(stream); - } - - void python_save(std::string &path) const - { - try - { - std::ofstream out(path, std::ofstream::binary); - this->save(out); - out.close(); - } - catch (const std::exception &) - { - throw "PublicKey write exception"; - } - } - - /** - Loads a PublicKey from an input stream overwriting the current PublicKey. - No checking of the validity of the PublicKey data against encryption - parameters is performed. This function should not be used unless the - PublicKey comes from a fully trusted source. - - @param[in] stream The stream to load the PublicKey from - @throws std::exception if a valid PublicKey could not be read from stream - */ - inline void unsafe_load(std::istream &stream) - { - Ciphertext new_pk(pk_.pool()); - new_pk.unsafe_load(stream); - std::swap(pk_, new_pk); - } - - void python_load(std::shared_ptr context, - std::string &path) - { - try - { - std::ifstream in(path, std::ifstream::binary); - this->load(context, in); - in.close(); - } - catch (const std::exception &) - { - throw "PublicKey read exception"; - } - } - - /** - Loads a PublicKey from an input stream overwriting the current PublicKey. - The loaded PublicKey is verified to be valid for the given SEALContext. - - @param[in] context The SEALContext - @param[in] stream The stream to load the PublicKey from - @throws std::invalid_argument if the context is not set or encryption - parameters are not valid - @throws std::exception if a valid PublicKey could not be read from stream - @throws std::invalid_argument if the loaded PublicKey is invalid for the - context - */ - inline void load(std::shared_ptr context, - std::istream &stream) - { - PublicKey new_pk(pool()); - new_pk.unsafe_load(stream); - if (!is_valid_for(new_pk, std::move(context))) - { - throw std::invalid_argument("PublicKey data is invalid"); - } - std::swap(*this, new_pk); - } - - /** - Returns a reference to parms_id. - */ - SEAL_NODISCARD inline auto &parms_id() noexcept - { - return pk_.parms_id(); - } - - /** - Returns a const reference to parms_id. - */ - SEAL_NODISCARD inline auto &parms_id() const noexcept - { - return pk_.parms_id(); - } - - /** - Returns the currently used MemoryPoolHandle. - */ - SEAL_NODISCARD inline MemoryPoolHandle pool() const noexcept - { - return pk_.pool(); - } - - /** - Enables access to private members of seal::PublicKey for .NET wrapper. - */ - struct PublicKeyPrivateHelper; - - private: - /** - Creates an empty public key. - - @param[in] pool The MemoryPoolHandle pointing to a valid memory pool - @throws std::invalid_argument if pool is uninitialized - */ - PublicKey(MemoryPoolHandle pool) : - pk_(std::move(pool)) - { - } - - Ciphertext pk_; - }; -} diff --git a/SEAL/native/src/seal/randomgen.cpp b/SEAL/native/src/seal/randomgen.cpp deleted file mode 100644 index 0a53553..0000000 --- a/SEAL/native/src/seal/randomgen.cpp +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "seal/randomgen.h" - -using namespace std; - -namespace seal -{ - /** - Returns the default random number generator factory. This instance should - not be destroyed. - */ - auto UniformRandomGeneratorFactory::default_factory() - -> const shared_ptr - { - static const shared_ptr - default_factory{ new SEAL_DEFAULT_RNG_FACTORY }; - return default_factory; - } -#ifdef SEAL_USE_AES_NI_PRNG - auto FastPRNGFactory::create() -> shared_ptr - { - if (!(seed_[0] | seed_[1])) - { - return make_shared(random_uint64(), random_uint64()); - } - else - { - return make_shared(seed_[0], seed_[1]); - } - } -#endif -} diff --git a/SEAL/native/src/seal/randomgen.h b/SEAL/native/src/seal/randomgen.h deleted file mode 100644 index 4794e8b..0000000 --- a/SEAL/native/src/seal/randomgen.h +++ /dev/null @@ -1,288 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include -#include -#include "seal/util/defines.h" -#include "seal/util/common.h" -#include "seal/util/aes.h" - -namespace seal -{ - /** - Provides the base class for a uniform random number generator. Instances of - this class are typically returned from the UniformRandomGeneratorFactory class. - This class is meant for users to sub-class to implement their own random number - generators. The implementation should provide a uniform random unsigned 32-bit - value for each call to generate(). Note that the library will never make - concurrent calls to generate() to the same instance (but individual instances - of the same class may have concurrent calls). The uniformity and unpredictability - of the numbers generated is essential for making a secure cryptographic system. - - @see UniformRandomGeneratorFactory for the base class of a factory class that - generates UniformRandomGenerator instances. - */ - class UniformRandomGenerator - { - public: - /** - Generates a new uniform unsigned 32-bit random number. Note that the - implementation does not need to be thread-safe. - */ - virtual std::uint32_t generate() = 0; - - /** - Destroys the random number generator. - */ - virtual ~UniformRandomGenerator() = default; - }; - - /** - Provides the base class for a factory instance that creates instances of - UniformRandomGenerator. This class is meant for users to sub-class to implement - their own random number generators. Note that each instance returned may be - used concurrently across separate threads, but each individual instance does - not need to be thread-safe. - - @see UniformRandomGenerator for details relating to the random number generator - instances. - @see StandardRandomAdapterFactory for an implementation of - UniformRandomGeneratorFactory that supports the standard C++ library's - random number generators. - */ - class UniformRandomGeneratorFactory - { - public: - /** - Creates a new uniform random number generator. - */ - virtual auto create() - -> std::shared_ptr = 0; - - /** - Destroys the random number generator factory. - */ - virtual ~UniformRandomGeneratorFactory() = default; - - /** - Returns the default random number generator factory. This instance should - not be destroyed. - */ - static auto default_factory() - -> const std::shared_ptr; - - private: - }; -#ifdef SEAL_USE_AES_NI_PRNG - /** - Provides an implementation of UniformRandomGenerator for using very fast - AES-NI randomness with given 128-bit seed. - */ - class FastPRNG : public UniformRandomGenerator - { - public: - /** - Creates a new FastPRNGFactory instance that initializes every FastPRNG - instance it creates with the given seed. - */ - FastPRNG(std::uint64_t seed_lw, std::uint64_t seed_hw) : - aes_enc_{ seed_lw, seed_hw } - { - refill_buffer(); - } - - /** - Generates a new uniform unsigned 32-bit random number. Note that the - implementation does not need to be thread-safe. - */ - SEAL_NODISCARD virtual std::uint32_t generate() override - { - std::uint32_t result; - std::copy_n(buffer_head_, util::bytes_per_uint32, - reinterpret_cast(&result)); - buffer_head_ += util::bytes_per_uint32; - if (buffer_head_ == buffer_.cend()) - { - refill_buffer(); - } - return result; - } - - /** - Destroys the random number generator. - */ - virtual ~FastPRNG() override = default; - - private: - AESEncryptor aes_enc_; - - static constexpr std::size_t bytes_per_block_ = - sizeof(aes_block) / sizeof(SEAL_BYTE); - - static constexpr std::size_t buffer_block_size_ = 8; - - static constexpr std::size_t buffer_size_ = - buffer_block_size_ * bytes_per_block_; - - std::array buffer_; - - std::size_t counter_ = 0; - - typename decltype(buffer_)::const_iterator buffer_head_; - - void refill_buffer() - { - // Fill the randomness buffer - aes_block *buffer_ptr = reinterpret_cast(&*buffer_.begin()); - aes_enc_.counter_encrypt(counter_, buffer_block_size_, buffer_ptr); - counter_ += buffer_block_size_; - buffer_head_ = buffer_.cbegin(); - } - }; - - class FastPRNGFactory : public UniformRandomGeneratorFactory - { - public: - /** - Creates a new FastPRNGFactory instance that initializes every FastPRNG - instance it creates with the given seed. A zero seed (default value) - signals that each random number generator created by the factory should - use a different random seed obtained from std::random_device. - - @param[in] seed_lw Low-word for seed for the PRNG - @param[in] seed_hw High-word for seed for the PRNG - */ - FastPRNGFactory(std::uint64_t seed_lw = 0, std::uint64_t seed_hw = 0) : - seed_{ seed_lw, seed_hw } - { - } - - /** - Creates a new uniform random number generator. The caller of create needs - to ensure the returned instance is destroyed once it is no longer in-use - to prevent a memory leak. - */ - SEAL_NODISCARD virtual auto create() - -> std::shared_ptr override; - - /** - Destroys the random number generator factory. - */ - virtual ~FastPRNGFactory() = default; - - private: - SEAL_NODISCARD std::uint64_t random_uint64() const noexcept - { - std::random_device rd; - return (static_cast(rd()) << 32) - + static_cast(rd()); - } - - std::uint64_t seed_[2]; - }; -#endif //SEAL_USE_AES_NI_PRNG - /** - Provides an implementation of UniformRandomGenerator for the standard C++ - library's uniform random number generators. - - @tparam RNG specifies the type of the standard C++ library's random number - generator (e.g., std::default_random_engine) - */ - template - class StandardRandomAdapter : public UniformRandomGenerator - { - public: - /** - Creates a new random number generator (of type RNG). - */ - StandardRandomAdapter() = default; - - /** - Returns a reference to the random number generator. - */ - SEAL_NODISCARD inline const RNG &generator() const noexcept - { - return generator_; - } - - /** - Returns a reference to the random number generator. - */ - SEAL_NODISCARD inline RNG &generator() noexcept - { - return generator_; - } - - /** - Generates a new uniform unsigned 32-bit random number. - */ - SEAL_NODISCARD std::uint32_t generate() noexcept override - { - SEAL_IF_CONSTEXPR (RNG::min() == 0 && RNG::max() >= UINT32_MAX) - { - return static_cast(generator_()); - } - else SEAL_IF_CONSTEXPR (RNG::max() - RNG::min() >= UINT32_MAX) - { - return static_cast(generator_() - RNG::min()); - } - else SEAL_IF_CONSTEXPR (RNG::min() == 0) - { - std::uint64_t max_value = RNG::max(); - std::uint64_t value = static_cast(generator_()); - std::uint64_t max = max_value; - while (max < UINT32_MAX) - { - value *= max_value; - max *= max_value; - value += static_cast(generator_()); - } - return static_cast(value); - } - else - { - std::uint64_t max_value = RNG::max() - RNG::min(); - std::uint64_t value = static_cast(generator_() - RNG::min()); - std::uint64_t max = max_value; - while (max < UINT32_MAX) - { - value *= max_value; - max *= max_value; - value += static_cast(generator_() - RNG::min()); - } - return static_cast(value); - } - } - - private: - RNG generator_; - }; - - /** - Provides an implementation of UniformRandomGeneratorFactory for the standard - C++ library's random number generators. - - @tparam RNG specifies the type of the standard C++ library's random number - generator (e.g., std::default_random_engine) - */ - template - class StandardRandomAdapterFactory : public UniformRandomGeneratorFactory - { - public: - /** - Creates a new uniform random number generator. - */ - SEAL_NODISCARD auto create() - -> std::shared_ptr override - { - return std::shared_ptr{ - new StandardRandomAdapter() }; - } - - private: - }; -} diff --git a/SEAL/native/src/seal/randomtostd.h b/SEAL/native/src/seal/randomtostd.h deleted file mode 100644 index 6549797..0000000 --- a/SEAL/native/src/seal/randomtostd.h +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include -#include -#include "seal/randomgen.h" - -namespace seal -{ - /** - A simple wrapper class to implement C++ UniformRandomBitGenerator type properties - for a given polymorphic UniformRandomGenerator instance. The resulting object can - be used as a randomness source in C++ standard random number distribution classes, - such as std::uniform_int_distribution, std::normal_distribution, or any of the - standard RandomNumberEngine classes. - */ - class RandomToStandardAdapter - { - public: - using result_type = std::uint32_t; - - /** - Creates a new RandomToStandardAdapter backed by a given UniformRandomGenerator. - - @param[in] generator A backing UniformRandomGenerator instance - @throws std::invalid_argument if generator is null - */ - RandomToStandardAdapter( - std::shared_ptr generator) : generator_(generator) - { - if (!generator_) - { - throw std::invalid_argument("generator cannot be null"); - } - } - - /** - Returns a new random number from the backing UniformRandomGenerator. - */ - SEAL_NODISCARD inline result_type operator()() - { - return generator_->generate(); - } - - /** - Returns the backing UniformRandomGenerator. - */ - SEAL_NODISCARD inline auto generator() const noexcept - { - return generator_; - } - - /** - Returns the smallest possible output value. - */ - SEAL_NODISCARD inline static constexpr result_type min() noexcept - { - return std::numeric_limits::min(); - } - - /** - Returns the largest possible output value. - */ - SEAL_NODISCARD static constexpr result_type max() noexcept - { - return std::numeric_limits::max(); - } - - private: - std::shared_ptr generator_; - }; -} diff --git a/SEAL/native/src/seal/relinkeys.h b/SEAL/native/src/seal/relinkeys.h deleted file mode 100644 index 2efe72e..0000000 --- a/SEAL/native/src/seal/relinkeys.h +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include -#include "seal/util/defines.h" -#include "seal/ciphertext.h" -#include "seal/memorymanager.h" -#include "seal/encryptionparams.h" -#include "seal/kswitchkeys.h" - -namespace seal -{ - /** - Class to store relinearization keys. - - @par Relinearization - Freshly encrypted ciphertexts have a size of 2, and multiplying ciphertexts - of sizes K and L results in a ciphertext of size K+L-1. Unfortunately, this - growth in size slows down further multiplications and increases noise growth. - Relinearization is an operation that has no semantic meaning, but it reduces - the size of ciphertexts back to 2. Microsoft SEAL can only relinearize size 3 - ciphertexts back to size 2, so if the ciphertexts grow larger than size 3, - there is no way to reduce their size. Relinearization requires an instance of - RelinKeys to be created by the secret key owner and to be shared with the - evaluator. Note that plain multiplication is fundamentally different from - normal multiplication and does not result in ciphertext size growth. - - @par When to Relinearize - Typically, one should always relinearize after each multiplications. However, - in some cases relinearization should be postponed as late as possible due to - its computational cost. For example, suppose the computation involves several - homomorphic multiplications followed by a sum of the results. In this case it - makes sense to not relinearize each product, but instead add them first and - only then relinearize the sum. This is particularly important when using the - CKKS scheme, where relinearization is much more computationally costly than - multiplications and additions. - - @par Thread Safety - In general, reading from RelinKeys is thread-safe as long as no other thread - is concurrently mutating it. This is due to the underlying data structure - storing the relinearization keys not being thread-safe. - - @see SecretKey for the class that stores the secret key. - @see PublicKey for the class that stores the public key. - @see GaloisKeys for the class that stores the Galois keys. - @see KeyGenerator for the class that generates the relinearization keys. - */ - class RelinKeys : public KSwitchKeys - { - public: - /** - Returns the index of a relinearization key in the backing KSwitchKeys - instance that corresponds to the given secret key power, assuming that - it exists in the backing KSwitchKeys. - - @param[in] key_power The power of the secret key - @throws std::invalid_argument if key_power is less than 2 - */ - SEAL_NODISCARD inline static std::size_t get_index( - std::size_t key_power) - { - if (key_power < 2) - { - throw std::invalid_argument("key_power cannot be less than 2"); - } - return key_power - 2; - } - - /** - Returns whether a relinearization key corresponding to a given power of - the secret key exists. - - @param[in] key_power The power of the secret key - @throws std::invalid_argument if key_power is less than 2 - */ - SEAL_NODISCARD inline bool has_key(std::size_t key_power) const - { - std::size_t index = get_index(key_power); - return data().size() > index && !data()[index].empty(); - } - - /** - Returns a const reference to a relinearization key. The returned - relinearization key corresponds to the given power of the secret key. - - @param[in] key_power The power of the secret key - @throws std::invalid_argument if the key corresponding to key_power does not exist - */ - SEAL_NODISCARD inline auto &key(std::size_t key_power) const - { - return KSwitchKeys::data(get_index(key_power)); - } - }; -} diff --git a/SEAL/native/src/seal/seal.h b/SEAL/native/src/seal/seal.h deleted file mode 100644 index d78892f..0000000 --- a/SEAL/native/src/seal/seal.h +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include "seal/biguint.h" -#include "seal/ciphertext.h" -#include "seal/ckks.h" -#include "seal/modulus.h" -#include "seal/context.h" -#include "seal/decryptor.h" -#include "seal/intencoder.h" -#include "seal/encryptionparams.h" -#include "seal/encryptor.h" -#include "seal/evaluator.h" -#include "seal/intarray.h" -#include "seal/keygenerator.h" -#include "seal/memorymanager.h" -#include "seal/plaintext.h" -#include "seal/batchencoder.h" -#include "seal/publickey.h" -#include "seal/randomgen.h" -#include "seal/randomtostd.h" -#include "seal/relinkeys.h" -#include "seal/secretkey.h" -#include "seal/smallmodulus.h" -#include "seal/valcheck.h" \ No newline at end of file diff --git a/SEAL/native/src/seal/secretkey.h b/SEAL/native/src/seal/secretkey.h deleted file mode 100644 index 5aef7ff..0000000 --- a/SEAL/native/src/seal/secretkey.h +++ /dev/null @@ -1,221 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include "seal/util/defines.h" -#include "seal/randomgen.h" -#include "seal/plaintext.h" -#include "seal/memorymanager.h" -#include "seal/util/common.h" -#include "seal/valcheck.h" -#include -#include -#include -#include -#include - -namespace seal -{ - /** - Class to store a secret key. - - @par Thread Safety - In general, reading from SecretKey is thread-safe as long as no other thread - is concurrently mutating it. This is due to the underlying data structure - storing the secret key not being thread-safe. - - @see KeyGenerator for the class that generates the secret key. - @see PublicKey for the class that stores the public key. - @see RelinKeys for the class that stores the relinearization keys. - @see GaloisKeys for the class that stores the Galois keys. - */ - class SecretKey - { - friend class KeyGenerator; - - public: - /** - Creates an empty secret key. - */ - SecretKey() = default; - - /** - Creates a new SecretKey by copying an old one. - - @param[in] copy The SecretKey to copy from - */ - SecretKey(const SecretKey ©) - { - // Note: sk_ is at this point initialized to use a custom (new) - // memory pool with the `clear_on_destruction' property. Now use - // Plaintext::operator =(const Plaintext &) to copy over the data. - // This is very important to do right, otherwise newly created - // SecretKey may use a normal memory pool obtained from - // MemoryManager::GetPool() with currently active profile (MMProf). - sk_ = copy.sk_; - } - - /** - Destroys the SecretKey object. - */ - ~SecretKey() = default; - - /** - Creates a new SecretKey by moving an old one. - - @param[in] source The SecretKey to move from - */ - SecretKey(SecretKey &&source) = default; - - /** - Copies an old SecretKey to the current one. - - @param[in] assign The SecretKey to copy from - */ - SecretKey &operator =(const SecretKey &assign) - { - Plaintext new_sk(MemoryManager::GetPool(mm_prof_opt::FORCE_NEW, true)); - new_sk = assign.sk_; - std::swap(sk_, new_sk); - return *this; - } - - /** - Moves an old SecretKey to the current one. - - @param[in] assign The SecretKey to move from - */ - SecretKey &operator =(SecretKey &&assign) = default; - - /** - Returns a reference to the underlying polynomial. - */ - SEAL_NODISCARD inline auto &data() noexcept - { - return sk_; - } - - /** - Returns a const reference to the underlying polynomial. - */ - SEAL_NODISCARD inline auto &data() const noexcept - { - return sk_; - } - - /** - Saves the SecretKey to an output stream. The output is in binary format - and not human-readable. The output stream must have the "binary" flag set. - - @param[in] stream The stream to save the SecretKey to - @throws std::exception if the plaintext could not be written to stream - */ - inline void save(std::ostream &stream) const - { - sk_.save(stream); - } - - void python_save(std::string &path) const - { - try - { - std::ofstream out(path, std::ofstream::binary); - this->save(out); - out.close(); - } - catch (const std::exception &) - { - throw "SecretKey write exception"; - } - } - - /** - Loads a SecretKey from an input stream overwriting the current SecretKey. - No checking of the validity of the SecretKey data against encryption - parameters is performed. This function should not be used unless the - SecretKey comes from a fully trusted source. - - @param[in] stream The stream to load the SecretKey from - @throws std::exception if a valid SecretKey could not be read from stream - */ - inline void unsafe_load(std::istream &stream) - { - // We use a fresh memory pool with `clear_on_destruction' enabled. - Plaintext new_sk(MemoryManager::GetPool(mm_prof_opt::FORCE_NEW, true)); - new_sk.unsafe_load(stream); - std::swap(sk_, new_sk); - } - - void python_load(std::shared_ptr context, - std::string &path) - { - try - { - std::ifstream in(path, std::ifstream::binary); - this->load(context, in); - in.close(); - } - catch (const std::exception &) - { - throw "SecretKey read exception"; - } - } - - /** - Loads a SecretKey from an input stream overwriting the current SecretKey. - The loaded SecretKey is verified to be valid for the given SEALContext. - - @param[in] context The SEALContext - @param[in] stream The stream to load the SecretKey from - @throws std::invalid_argument if the context is not set or encryption - parameters are not valid - @throws std::exception if a valid SecretKey could not be read from stream - @throws std::invalid_argument if the loaded SecretKey is invalid for the - context - */ - inline void load(std::shared_ptr context, - std::istream &stream) - { - SecretKey new_sk; - new_sk.unsafe_load(stream); - if (!is_valid_for(new_sk, std::move(context))) - { - throw std::invalid_argument("SecretKey data is invalid"); - } - std::swap(*this, new_sk); - } - - /** - Returns a reference to parms_id. - - @see EncryptionParameters for more information about parms_id. - */ - SEAL_NODISCARD inline auto &parms_id() noexcept - { - return sk_.parms_id(); - } - - /** - Returns a const reference to parms_id. - - @see EncryptionParameters for more information about parms_id. - */ - SEAL_NODISCARD inline auto &parms_id() const noexcept - { - return sk_.parms_id(); - } - - /** - Returns the currently used MemoryPoolHandle. - */ - SEAL_NODISCARD inline MemoryPoolHandle pool() const noexcept - { - return sk_.pool(); - } - - private: - // We use a fresh memory pool with `clear_on_destruction' enabled. - Plaintext sk_{ MemoryManager::GetPool(mm_prof_opt::FORCE_NEW, true) }; - }; -} diff --git a/SEAL/native/src/seal/smallmodulus.cpp b/SEAL/native/src/seal/smallmodulus.cpp deleted file mode 100644 index 7cc13b4..0000000 --- a/SEAL/native/src/seal/smallmodulus.cpp +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "seal/smallmodulus.h" -#include "seal/util/uintarith.h" -#include "seal/util/uintarithsmallmod.h" -#include "seal/util/common.h" -#include "seal/util/numth.h" -#include - -using namespace seal::util; -using namespace std; - -namespace seal -{ - void SmallModulus::save(ostream &stream) const - { - auto old_except_mask = stream.exceptions(); - try - { - // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit - stream.exceptions(ios_base::badbit | ios_base::failbit); - - stream.write(reinterpret_cast(&value_), sizeof(uint64_t)); - } - catch (const exception &) - { - stream.exceptions(old_except_mask); - throw; - } - - stream.exceptions(old_except_mask); - } - - void SmallModulus::load(istream &stream) - { - auto old_except_mask = stream.exceptions(); - try - { - // Throw exceptions on std::ios_base::badbit and std::ios_base::failbit - stream.exceptions(ios_base::badbit | ios_base::failbit); - - uint64_t value; - stream.read(reinterpret_cast(&value), sizeof(uint64_t)); - set_value(value); - } - catch (const exception &) - { - stream.exceptions(old_except_mask); - throw; - } - - stream.exceptions(old_except_mask); - } - - void SmallModulus::set_value(uint64_t value) - { - if (value == 0) - { - // Zero settings - bit_count_ = 0; - uint64_count_ = 1; - value_ = 0; - const_ratio_ = { { 0, 0, 0 } }; - is_prime_ = false; - } - else if ((value >> 62 != 0) || (value == uint64_t(0x4000000000000000)) || - (value == 1)) - { - throw invalid_argument("value can be at most 62 bits and cannot be 1"); - } - else - { - // All normal, compute const_ratio and set everything - value_ = value; - bit_count_ = get_significant_bit_count(value_); - - // Compute Barrett ratios for 64-bit words (barrett_reduce_128) - uint64_t numerator[3]{ 0, 0, 1 }; - uint64_t quotient[3]{ 0, 0, 0 }; - - // Use a special method to avoid using memory pool - divide_uint192_uint64_inplace(numerator, value_, quotient); - - const_ratio_[0] = quotient[0]; - const_ratio_[1] = quotient[1]; - - // We store also the remainder - const_ratio_[2] = numerator[0]; - - uint64_count_ = 1; - - // Set the primality flag - is_prime_ = util::is_prime(*this); - } - } -} \ No newline at end of file diff --git a/SEAL/native/src/seal/smallmodulus.h b/SEAL/native/src/seal/smallmodulus.h deleted file mode 100644 index 8f7cf28..0000000 --- a/SEAL/native/src/seal/smallmodulus.h +++ /dev/null @@ -1,325 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include -#include "seal/util/defines.h" -#include "seal/util/uintcore.h" -#include "seal/memorymanager.h" - -namespace seal -{ - /** - Represent an integer modulus of up to 62 bits. An instance of the SmallModulus - class represents a non-negative integer modulus up to 62 bits. In particular, - the encryption parameter plain_modulus, and the primes in coeff_modulus, are - represented by instances of SmallModulus. The purpose of this class is to - perform and store the pre-computation required by Barrett reduction. - - @par Thread Safety - In general, reading from SmallModulus is thread-safe as long as no other thread - is concurrently mutating it. - - @see EncryptionParameters for a description of the encryption parameters. - */ - class SmallModulus - { - public: - /** - Creates a SmallModulus instance. The value of the SmallModulus is set to - the given value, or to zero by default. - - @param[in] value The integer modulus - @throws std::invalid_argument if value is 1 or more than 62 bits - */ - SmallModulus(std::uint64_t value = 0) - { - set_value(value); - } - - /** - Creates a new SmallModulus by copying a given one. - - @param[in] copy The SmallModulus to copy from - */ - SmallModulus(const SmallModulus ©) = default; - - /** - Creates a new SmallModulus by copying a given one. - - @param[in] source The SmallModulus to move from - */ - SmallModulus(SmallModulus &&source) = default; - - /** - Copies a given SmallModulus to the current one. - - @param[in] assign The SmallModulus to copy from - */ - SmallModulus &operator =(const SmallModulus &assign) = default; - - /** - Moves a given SmallModulus to the current one. - - @param[in] assign The SmallModulus to move from - */ - SmallModulus &operator =(SmallModulus &&assign) = default; - - /** - Sets the value of the SmallModulus. - - @param[in] value The new integer modulus - @throws std::invalid_argument if value is 1 or more than 62 bits - */ - inline SmallModulus &operator =(std::uint64_t value) - { - set_value(value); - return *this; - } - - /** - Returns the significant bit count of the value of the current SmallModulus. - */ - SEAL_NODISCARD inline int bit_count() const noexcept - { - return bit_count_; - } - - /** - Returns the size (in 64-bit words) of the value of the current SmallModulus. - */ - SEAL_NODISCARD inline std::size_t uint64_count() const noexcept - { - return uint64_count_; - } - - /** - Returns a const pointer to the value of the current SmallModulus. - */ - SEAL_NODISCARD inline const uint64_t *data() const noexcept - { - return &value_; - } - - /** - Returns the value of the current SmallModulus. - */ - SEAL_NODISCARD inline std::uint64_t value() const noexcept - { - return value_; - } - - /** - Returns the Barrett ratio computed for the value of the current SmallModulus. - The first two components of the Barrett ratio are the floor of 2^128/value, - and the third component is the remainder. - */ - SEAL_NODISCARD inline auto &const_ratio() const noexcept - { - return const_ratio_; - } - - /** - Returns whether the value of the current SmallModulus is zero. - */ - SEAL_NODISCARD inline bool is_zero() const noexcept - { - return value_ == 0; - } - - /** - Returns whether the value of the current SmallModulus is a prime number. - */ - SEAL_NODISCARD inline bool is_prime() const noexcept - { - return is_prime_; - } - - /** - Compares two SmallModulus instances. - - @param[in] compare The SmallModulus to compare against - */ - SEAL_NODISCARD inline bool operator ==( - const SmallModulus &compare) const noexcept - { - return value_ == compare.value_; - } - - /** - Compares a SmallModulus value to an unsigned integer. - - @param[in] compare The unsigned integer to compare against - */ - SEAL_NODISCARD inline bool operator ==( - std::uint64_t compare) const noexcept - { - return value_ == compare; - } - - /** - Compares two SmallModulus instances. - - @param[in] compare The SmallModulus to compare against - */ - SEAL_NODISCARD inline bool operator !=( - const SmallModulus &compare) const noexcept - { - return !(value_ == compare.value_); - } - - /** - Compares a SmallModulus value to an unsigned integer. - - @param[in] compare The unsigned integer to compare against - */ - SEAL_NODISCARD inline bool operator !=( - std::uint64_t compare) const noexcept - { - return value_ != compare; - } - - /** - Compares two SmallModulus instances. - - @param[in] compare The SmallModulus to compare against - */ - SEAL_NODISCARD inline bool operator <( - const SmallModulus &compare) const noexcept - { - return value_ < compare.value_; - } - - /** - Compares a SmallModulus value to an unsigned integer. - - @param[in] compare The unsigned integer to compare against - */ - SEAL_NODISCARD inline bool operator <( - std::uint64_t compare) const noexcept - { - return value_ < compare; - } - - /** - Compares two SmallModulus instances. - - @param[in] compare The SmallModulus to compare against - */ - SEAL_NODISCARD inline bool operator <=( - const SmallModulus &compare) const noexcept - { - return value_ <= compare.value_; - } - - /** - Compares a SmallModulus value to an unsigned integer. - - @param[in] compare The unsigned integer to compare against - */ - SEAL_NODISCARD inline bool operator <=( - std::uint64_t compare) const noexcept - { - return value_ <= compare; - } - - /** - Compares two SmallModulus instances. - - @param[in] compare The SmallModulus to compare against - */ - SEAL_NODISCARD inline bool operator >( - const SmallModulus &compare) const noexcept - { - return value_ > compare.value_; - } - - /** - Compares a SmallModulus value to an unsigned integer. - - @param[in] compare The unsigned integer to compare against - */ - SEAL_NODISCARD inline bool operator >( - std::uint64_t compare) const noexcept - { - return value_ > compare; - } - - /** - Compares two SmallModulus instances. - - @param[in] compare The SmallModulus to compare against - */ - SEAL_NODISCARD inline bool operator >=( - const SmallModulus &compare) const noexcept - { - return value_ >= compare.value_; - } - - /** - Compares a SmallModulus value to an unsigned integer. - - @param[in] compare The unsigned integer to compare against - */ - SEAL_NODISCARD inline bool operator >=( - std::uint64_t compare) const noexcept - { - return value_ >= compare; - } - - /** - Saves the SmallModulus to an output stream. The full state of the modulus is - serialized. The output is in binary format and not human-readable. The output - stream must have the "binary" flag set. - - @param[in] stream The stream to save the SmallModulus to - @throws std::exception if the SmallModulus could not be written to stream - */ - void save(std::ostream &stream) const; - - /** - Loads a SmallModulus from an input stream overwriting the current SmallModulus. - - @param[in] stream The stream to load the SmallModulus from - @throws std::exception if a valid SmallModulus could not be read from stream - */ - void load(std::istream &stream); - - /** - Returns in decreasing order a vector of the largest prime numbers of a given - length that all support NTTs of a given size. More precisely, the generated - primes are all congruent to 1 modulo 2 * ntt_size. Typically, the user might - call this function by passing poly_modulus_degree as ntt_size if the primes - are to be used as a coefficient modulus primes for encryption parameters. - - @param[in] bit_size the bit-size of primes to be generated, no less than 2 and - no larger than 62 - @param[in] count The total number of primes to be generated - @param[in] ntt_size The size of NTT that should be supported - @throws std::invalid_argument if bit_size is less than 2 - @throws std::invalid_argument if count or ntt_size is zero - @throws std::logic_error if enough qualifying primes cannot be found - */ - //static std::vector GetPrimes(int bit_size, std::size_t count, - // std::size_t ntt_size); - - //static std::vector BuildCoeffModulus( - // std::size_t poly_modulus_degree, std::vector bit_sizes); - - private: - void set_value(std::uint64_t value); - - std::uint64_t value_ = 0; - - std::array const_ratio_{ { 0, 0, 0 } }; - - std::size_t uint64_count_ = 0; - - int bit_count_ = 0; - - bool is_prime_ = false; - }; -} diff --git a/SEAL/native/src/seal/util/CMakeLists.txt b/SEAL/native/src/seal/util/CMakeLists.txt deleted file mode 100644 index da4c1f0..0000000 --- a/SEAL/native/src/seal/util/CMakeLists.txt +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT license. - -target_sources(seal - PRIVATE - ${CMAKE_CURRENT_LIST_DIR}/aes.cpp - ${CMAKE_CURRENT_LIST_DIR}/baseconverter.cpp - ${CMAKE_CURRENT_LIST_DIR}/clipnormal.cpp - ${CMAKE_CURRENT_LIST_DIR}/globals.cpp - ${CMAKE_CURRENT_LIST_DIR}/hash.cpp - ${CMAKE_CURRENT_LIST_DIR}/mempool.cpp - ${CMAKE_CURRENT_LIST_DIR}/numth.cpp - ${CMAKE_CURRENT_LIST_DIR}/polyarith.cpp - ${CMAKE_CURRENT_LIST_DIR}/polyarithmod.cpp - ${CMAKE_CURRENT_LIST_DIR}/polyarithsmallmod.cpp - ${CMAKE_CURRENT_LIST_DIR}/rlwe.cpp - ${CMAKE_CURRENT_LIST_DIR}/smallntt.cpp - ${CMAKE_CURRENT_LIST_DIR}/uintarith.cpp - ${CMAKE_CURRENT_LIST_DIR}/uintarithmod.cpp - ${CMAKE_CURRENT_LIST_DIR}/uintarithsmallmod.cpp - ${CMAKE_CURRENT_LIST_DIR}/uintcore.cpp -) - -# Create the config file -configure_file(${CMAKE_CURRENT_LIST_DIR}/config.h.in ${CMAKE_CURRENT_LIST_DIR}/config.h) - -install( - FILES - ${CMAKE_CURRENT_LIST_DIR}/aes.h - ${CMAKE_CURRENT_LIST_DIR}/baseconverter.h - ${CMAKE_CURRENT_LIST_DIR}/clang.h - ${CMAKE_CURRENT_LIST_DIR}/clipnormal.h - ${CMAKE_CURRENT_LIST_DIR}/common.h - ${CMAKE_CURRENT_LIST_DIR}/config.h - ${CMAKE_CURRENT_LIST_DIR}/defines.h - ${CMAKE_CURRENT_LIST_DIR}/gcc.h - ${CMAKE_CURRENT_LIST_DIR}/globals.h - ${CMAKE_CURRENT_LIST_DIR}/hash.h - ${CMAKE_CURRENT_LIST_DIR}/hestdparms.h - ${CMAKE_CURRENT_LIST_DIR}/locks.h - ${CMAKE_CURRENT_LIST_DIR}/mempool.h - ${CMAKE_CURRENT_LIST_DIR}/msvc.h - ${CMAKE_CURRENT_LIST_DIR}/numth.h - ${CMAKE_CURRENT_LIST_DIR}/pointer.h - ${CMAKE_CURRENT_LIST_DIR}/polyarith.h - ${CMAKE_CURRENT_LIST_DIR}/polyarithmod.h - ${CMAKE_CURRENT_LIST_DIR}/polyarithsmallmod.h - ${CMAKE_CURRENT_LIST_DIR}/polycore.h - ${CMAKE_CURRENT_LIST_DIR}/rlwe.h - ${CMAKE_CURRENT_LIST_DIR}/smallntt.h - ${CMAKE_CURRENT_LIST_DIR}/uintarith.h - ${CMAKE_CURRENT_LIST_DIR}/uintarithmod.h - ${CMAKE_CURRENT_LIST_DIR}/uintarithsmallmod.h - ${CMAKE_CURRENT_LIST_DIR}/uintcore.h - DESTINATION - ${SEAL_INCLUDES_INSTALL_DIR}/seal/util -) diff --git a/SEAL/native/src/seal/util/aes.cpp b/SEAL/native/src/seal/util/aes.cpp deleted file mode 100644 index c654ed2..0000000 --- a/SEAL/native/src/seal/util/aes.cpp +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "seal/util/aes.h" - -#ifdef SEAL_USE_AES_NI_PRNG - -namespace seal -{ - namespace - { - __m128i keygen_helper(__m128i key, __m128i key_rcon) - { - key_rcon = _mm_shuffle_epi32(key_rcon, _MM_SHUFFLE(3, 3, 3, 3)); - key = _mm_xor_si128(key, _mm_slli_si128(key, 4)); - key = _mm_xor_si128(key, _mm_slli_si128(key, 4)); - key = _mm_xor_si128(key, _mm_slli_si128(key, 4)); - return _mm_xor_si128(key, key_rcon); - } - } - - void AESEncryptor::set_key(const aes_block &key) - { - round_key_[0] = key.i128; - round_key_[1] = keygen_helper(round_key_[0], _mm_aeskeygenassist_si128(round_key_[0], 0x01)); - round_key_[2] = keygen_helper(round_key_[1], _mm_aeskeygenassist_si128(round_key_[1], 0x02)); - round_key_[3] = keygen_helper(round_key_[2], _mm_aeskeygenassist_si128(round_key_[2], 0x04)); - round_key_[4] = keygen_helper(round_key_[3], _mm_aeskeygenassist_si128(round_key_[3], 0x08)); - round_key_[5] = keygen_helper(round_key_[4], _mm_aeskeygenassist_si128(round_key_[4], 0x10)); - round_key_[6] = keygen_helper(round_key_[5], _mm_aeskeygenassist_si128(round_key_[5], 0x20)); - round_key_[7] = keygen_helper(round_key_[6], _mm_aeskeygenassist_si128(round_key_[6], 0x40)); - round_key_[8] = keygen_helper(round_key_[7], _mm_aeskeygenassist_si128(round_key_[7], 0x80)); - round_key_[9] = keygen_helper(round_key_[8], _mm_aeskeygenassist_si128(round_key_[8], 0x1B)); - round_key_[10] = keygen_helper(round_key_[9], _mm_aeskeygenassist_si128(round_key_[9], 0x36)); - } - - void AESEncryptor::ecb_encrypt(const aes_block &plaintext, aes_block &ciphertext) const - { - ciphertext.i128 = _mm_xor_si128(plaintext.i128, round_key_[0]); - ciphertext.i128 = _mm_aesenc_si128(ciphertext.i128, round_key_[1]); - ciphertext.i128 = _mm_aesenc_si128(ciphertext.i128, round_key_[2]); - ciphertext.i128 = _mm_aesenc_si128(ciphertext.i128, round_key_[3]); - ciphertext.i128 = _mm_aesenc_si128(ciphertext.i128, round_key_[4]); - ciphertext.i128 = _mm_aesenc_si128(ciphertext.i128, round_key_[5]); - ciphertext.i128 = _mm_aesenc_si128(ciphertext.i128, round_key_[6]); - ciphertext.i128 = _mm_aesenc_si128(ciphertext.i128, round_key_[7]); - ciphertext.i128 = _mm_aesenc_si128(ciphertext.i128, round_key_[8]); - ciphertext.i128 = _mm_aesenc_si128(ciphertext.i128, round_key_[9]); - ciphertext.i128 = _mm_aesenclast_si128(ciphertext.i128, round_key_[10]); - } - - void AESEncryptor::ecb_encrypt(const aes_block *plaintext, - size_t aes_block_count, aes_block *ciphertext) const - { - for (; aes_block_count--; ciphertext++, plaintext++) - { - ciphertext->i128 = _mm_xor_si128(plaintext->i128, round_key_[0]); - ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[1]); - ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[2]); - ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[3]); - ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[4]); - ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[5]); - ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[6]); - ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[7]); - ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[8]); - ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[9]); - ciphertext->i128 = _mm_aesenclast_si128(ciphertext->i128, round_key_[10]); - } - } - - void AESEncryptor::counter_encrypt(size_t start_index, - size_t aes_block_count, aes_block *ciphertext) const - { - for (; aes_block_count--; start_index++, ciphertext++) - { - ciphertext->i128 = _mm_xor_si128( - _mm_set_epi64x(0, static_cast(start_index)), round_key_[0]); - ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[1]); - ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[2]); - ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[3]); - ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[4]); - ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[5]); - ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[6]); - ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[7]); - ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[8]); - ciphertext->i128 = _mm_aesenc_si128(ciphertext->i128, round_key_[9]); - ciphertext->i128 = _mm_aesenclast_si128(ciphertext->i128, round_key_[10]); - } - } - - AESDecryptor::AESDecryptor(const aes_block &key) - { - set_key(key); - } - - void AESDecryptor::set_key(const aes_block &key) - { - const __m128i &v0 = key.i128; - const __m128i v1 = keygen_helper(v0, _mm_aeskeygenassist_si128(v0, 0x01)); - const __m128i v2 = keygen_helper(v1, _mm_aeskeygenassist_si128(v1, 0x02)); - const __m128i v3 = keygen_helper(v2, _mm_aeskeygenassist_si128(v2, 0x04)); - const __m128i v4 = keygen_helper(v3, _mm_aeskeygenassist_si128(v3, 0x08)); - const __m128i v5 = keygen_helper(v4, _mm_aeskeygenassist_si128(v4, 0x10)); - const __m128i v6 = keygen_helper(v5, _mm_aeskeygenassist_si128(v5, 0x20)); - const __m128i v7 = keygen_helper(v6, _mm_aeskeygenassist_si128(v6, 0x40)); - const __m128i v8 = keygen_helper(v7, _mm_aeskeygenassist_si128(v7, 0x80)); - const __m128i v9 = keygen_helper(v8, _mm_aeskeygenassist_si128(v8, 0x1B)); - const __m128i v10 = keygen_helper(v9, _mm_aeskeygenassist_si128(v9, 0x36)); - - _mm_storeu_si128(round_key_, v10); - _mm_storeu_si128(round_key_ + 1, _mm_aesimc_si128(v9)); - _mm_storeu_si128(round_key_ + 2, _mm_aesimc_si128(v8)); - _mm_storeu_si128(round_key_ + 3, _mm_aesimc_si128(v7)); - _mm_storeu_si128(round_key_ + 4, _mm_aesimc_si128(v6)); - _mm_storeu_si128(round_key_ + 5, _mm_aesimc_si128(v5)); - _mm_storeu_si128(round_key_ + 6, _mm_aesimc_si128(v4)); - _mm_storeu_si128(round_key_ + 7, _mm_aesimc_si128(v3)); - _mm_storeu_si128(round_key_ + 8, _mm_aesimc_si128(v2)); - _mm_storeu_si128(round_key_ + 9, _mm_aesimc_si128(v1)); - _mm_storeu_si128(round_key_ + 10, v0); - } - - void AESDecryptor::ecb_decrypt(const aes_block &ciphertext, aes_block &plaintext) - { - plaintext.i128 = _mm_xor_si128(ciphertext.i128, round_key_[0]); - plaintext.i128 = _mm_aesdec_si128(plaintext.i128, round_key_[1]); - plaintext.i128 = _mm_aesdec_si128(plaintext.i128, round_key_[2]); - plaintext.i128 = _mm_aesdec_si128(plaintext.i128, round_key_[3]); - plaintext.i128 = _mm_aesdec_si128(plaintext.i128, round_key_[4]); - plaintext.i128 = _mm_aesdec_si128(plaintext.i128, round_key_[5]); - plaintext.i128 = _mm_aesdec_si128(plaintext.i128, round_key_[6]); - plaintext.i128 = _mm_aesdec_si128(plaintext.i128, round_key_[7]); - plaintext.i128 = _mm_aesdec_si128(plaintext.i128, round_key_[8]); - plaintext.i128 = _mm_aesdec_si128(plaintext.i128, round_key_[9]); - plaintext.i128 = _mm_aesdeclast_si128(plaintext.i128, round_key_[10]); - } -} - -#endif diff --git a/SEAL/native/src/seal/util/aes.h b/SEAL/native/src/seal/util/aes.h deleted file mode 100644 index e087824..0000000 --- a/SEAL/native/src/seal/util/aes.h +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include "seal/util/defines.h" - -#ifdef SEAL_USE_AES_NI_PRNG - -#include -#include -#include - -namespace seal -{ - union aes_block - { - std::uint32_t u32[4]; - std::uint64_t u64[2]; - __m128i i128; - }; - - class AESEncryptor - { - public: - AESEncryptor() = default; - - AESEncryptor(const aes_block &key) - { - set_key(key); - } - - AESEncryptor(std::uint64_t key_lw, std::uint64_t key_hw) - { - aes_block key; - key.u64[0] = key_lw; - key.u64[1] = key_hw; - set_key(key); - } - - void set_key(const aes_block &key); - - void ecb_encrypt(const aes_block &plaintext, aes_block &ciphertext) const; - - SEAL_NODISCARD inline aes_block ecb_encrypt(const aes_block &plaintext) const - { - aes_block ret; - ecb_encrypt(plaintext, ret); - return ret; - } - - // ECB mode encryption - void ecb_encrypt(const aes_block *plaintext, - std::size_t aes_block_count, aes_block *ciphertext) const; - - // Counter Mode encryption: encrypts the counter - void counter_encrypt(std::size_t start_index, - std::size_t aes_block_count, aes_block *ciphertext) const; - - private: - __m128i round_key_[11]; - }; - - class AESDecryptor - { - public: - AESDecryptor() = default; - - AESDecryptor(const aes_block &key); - - void set_key(const aes_block &key); - - void ecb_decrypt(const aes_block &ciphertext, aes_block &plaintext); - - SEAL_NODISCARD inline aes_block ecb_decrypt(const aes_block &ciphertext) - { - aes_block ret; - ecb_decrypt(ciphertext, ret); - return ret; - } - - private: - __m128i round_key_[11]; - }; -} - -#endif diff --git a/SEAL/native/src/seal/util/baseconverter.cpp b/SEAL/native/src/seal/util/baseconverter.cpp deleted file mode 100644 index e740ee3..0000000 --- a/SEAL/native/src/seal/util/baseconverter.cpp +++ /dev/null @@ -1,1153 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include -#include -#include -#include "seal/util/defines.h" -#include "seal/util/pointer.h" -#include "seal/util/uintcore.h" -#include "seal/util/polycore.h" -#include "seal/util/baseconverter.h" -#include "seal/util/uintarith.h" -#include "seal/util/uintarithsmallmod.h" -#include "seal/util/uintarithmod.h" -#include "seal/util/polyarithsmallmod.h" -#include "seal/util/smallntt.h" -#include "seal/util/globals.h" -#include "seal/smallmodulus.h" - -using namespace std; - -namespace seal -{ - namespace util - { - BaseConverter::BaseConverter(const std::vector &coeff_base, - size_t coeff_count, const SmallModulus &small_plain_mod, - MemoryPoolHandle pool) : pool_(move(pool)) - { -#ifdef SEAL_DEBUG - if (!pool) - { - throw std::invalid_argument("pool is uninitialized"); - } -#endif - generate(coeff_base, coeff_count, small_plain_mod); - } - - void BaseConverter::generate(const std::vector &coeff_base, - size_t coeff_count, const SmallModulus &small_plain_mod) - { -#ifdef SEAL_DEBUG - if (get_power_of_two(coeff_count) < 0) - { - throw invalid_argument("coeff_count must be a power of 2"); - } - if (coeff_base.size() < SEAL_COEFF_MOD_COUNT_MIN || - coeff_base.size() > SEAL_COEFF_MOD_COUNT_MAX) - { - throw invalid_argument("coeff_base has invalid size"); - } -#endif - int coeff_count_power = get_power_of_two(coeff_count); - - /** - Perform all the required pre-computations and populate the tables - */ - reset(); - - m_sk_ = global_variables::internal_mods::m_sk; - m_tilde_ = global_variables::internal_mods::m_tilde; - gamma_ = global_variables::internal_mods::gamma; - small_plain_mod_ = small_plain_mod; - coeff_count_ = coeff_count; - coeff_base_mod_count_ = coeff_base.size(); - aux_base_mod_count_ = coeff_base.size(); - - // In some cases we might need to increase the size of the aux base by one, namely - // we require K * n * t * q^2 < q * prod_i m_i * m_sk, where K takes into account - // cross terms when larger size ciphertexts are used, and n is the "delta factor" - // for the ring. We reserve 32 bits for K * n. Here the coeff modulus primes q_i - // are bounded to be 60 bits, and all m_i, m_sk are 61 bits. - int total_coeff_bit_count = accumulate(coeff_base.cbegin(), coeff_base.cend(), 0, - [](int result, auto &mod) { return result + mod.bit_count(); }); - - if (32 + small_plain_mod_.bit_count() + total_coeff_bit_count >= - 61 * safe_cast(coeff_base_mod_count_) + 61) - { - aux_base_mod_count_++; - } - - // Base sizes - bsk_base_mod_count_ = aux_base_mod_count_ + 1; - plain_gamma_count_ = 2; - - // Size check; should always pass - if (!product_fits_in(coeff_count_, coeff_base_mod_count_)) - { - throw logic_error("invalid parameters"); - } - if (!product_fits_in(coeff_count_, aux_base_mod_count_)) - { - throw logic_error("invalid parameters"); - } - if (!product_fits_in(coeff_count_, bsk_base_mod_count_)) - { - throw logic_error("invalid parameters"); - } - - // We use a reversed order here for performance reasons - coeff_base_products_mod_aux_bsk_array_ = - allocate>(bsk_base_mod_count_, pool_); - generate_n( - coeff_base_products_mod_aux_bsk_array_.get(), - bsk_base_mod_count_, - [&]() { return allocate_uint(coeff_base_mod_count_, pool_); }); - - // We use a reversed order here for performance reasons - aux_base_products_mod_coeff_array_ = - allocate>(coeff_base_mod_count_, pool_); - generate_n( - aux_base_products_mod_coeff_array_.get(), - coeff_base_mod_count_, - [&]() { return allocate_uint(aux_base_mod_count_, pool_); }); - - coeff_products_mod_plain_gamma_array_ = - allocate>(plain_gamma_count_, pool_); - generate_n( - coeff_products_mod_plain_gamma_array_.get(), - plain_gamma_count_, - [&]() { return allocate_uint(coeff_base_mod_count_, pool_); }); - - // Create moduli arrays - coeff_base_array_ = allocate(coeff_base_mod_count_, pool_); - aux_base_array_ = allocate(aux_base_mod_count_, pool_); - bsk_base_array_ = allocate(bsk_base_mod_count_, pool_); - - copy(coeff_base.cbegin(), coeff_base.cend(), coeff_base_array_.get()); - copy_n(global_variables::internal_mods::aux_small_mods.cbegin(), - aux_base_mod_count_, aux_base_array_.get()); - copy_n(aux_base_array_.get(), aux_base_mod_count_, bsk_base_array_.get()); - bsk_base_array_[bsk_base_mod_count_ - 1] = m_sk_; - - // Generate Bsk U {mtilde} small ntt tables which is used in Evaluator - bsk_small_ntt_tables_ = allocate(bsk_base_mod_count_, pool_); - for (size_t i = 0; i < bsk_base_mod_count_; i++) - { - if (!bsk_small_ntt_tables_[i].generate(coeff_count_power, bsk_base_array_[i])) - { - reset(); - return; - } - } - - size_t coeff_products_uint64_count = coeff_base_mod_count_; - size_t aux_products_uint64_count = aux_base_mod_count_; - - // Generate punctured products of coeff moduli - coeff_products_array_ = allocate_zero_uint( - coeff_products_uint64_count * coeff_base_mod_count_, pool_); - auto tmp_coeff(allocate_uint(coeff_products_uint64_count, pool_)); - - for (size_t i = 0; i < coeff_base_mod_count_; i++) - { - coeff_products_array_[i * coeff_products_uint64_count] = 1; - for (size_t j = 0; j < coeff_base_mod_count_; j++) - { - if (i != j) - { - multiply_uint_uint64(coeff_products_array_.get() + - (i * coeff_products_uint64_count), coeff_products_uint64_count, - coeff_base_array_[j].value(), coeff_products_uint64_count, - tmp_coeff.get()); - set_uint_uint(tmp_coeff.get(), coeff_products_uint64_count, - coeff_products_array_.get() + (i * coeff_products_uint64_count)); - } - } - } - - // Generate punctured products of aux moduli - auto aux_products_array(allocate_zero_uint( - aux_products_uint64_count * aux_base_mod_count_, pool_)); - auto tmp_aux(allocate_uint(aux_products_uint64_count, pool_)); - - for (size_t i = 0; i < aux_base_mod_count_; i++) - { - aux_products_array[i * aux_products_uint64_count] = 1; - for (size_t j = 0; j < aux_base_mod_count_; j++) - { - if (i != j) - { - multiply_uint_uint64(aux_products_array.get() + - (i * aux_products_uint64_count), aux_products_uint64_count, - aux_base_array_[j].value(), aux_products_uint64_count, - tmp_aux.get()); - set_uint_uint(tmp_aux.get(), aux_products_uint64_count, - aux_products_array.get() + (i * aux_products_uint64_count)); - } - } - } - - // Compute auxiliary base products mod m_sk - aux_base_products_mod_msk_array_ = allocate_uint(aux_base_mod_count_, pool_); - for (size_t i = 0; i < aux_base_mod_count_; i++) - { - aux_base_products_mod_msk_array_[i] = - modulo_uint(aux_products_array.get() + (i * aux_products_uint64_count), - aux_products_uint64_count, m_sk_, pool_); - } - - // Compute inverse coeff base mod coeff base array (qi^(-1)) mod qi and - // mtilde inv coeff products mod auxiliary moduli (m_tilda*qi^(-1)) mod qi - inv_coeff_base_products_mod_coeff_array_ = - allocate_uint(coeff_base_mod_count_, pool_); - mtilde_inv_coeff_base_products_mod_coeff_array_ = - allocate_uint(coeff_base_mod_count_, pool_); - for (size_t i = 0; i < coeff_base_mod_count_; i++) - { - inv_coeff_base_products_mod_coeff_array_[i] = - modulo_uint(coeff_products_array_.get() + (i * coeff_products_uint64_count), - coeff_products_uint64_count, coeff_base_array_[i], pool_); - if (!try_invert_uint_mod(inv_coeff_base_products_mod_coeff_array_[i], - coeff_base_array_[i], inv_coeff_base_products_mod_coeff_array_[i])) - { - reset(); - return; - } - mtilde_inv_coeff_base_products_mod_coeff_array_[i] = - multiply_uint_uint_mod(inv_coeff_base_products_mod_coeff_array_[i], - m_tilde_.value(), coeff_base_array_[i]); - } - - // Compute inverse auxiliary moduli mod auxiliary moduli (mi^(-1)) mod mi - inv_aux_base_products_mod_aux_array_ = allocate_uint(aux_base_mod_count_, pool_); - for (size_t i = 0; i < aux_base_mod_count_; i++) - { - inv_aux_base_products_mod_aux_array_[i] = - modulo_uint(aux_products_array.get() + (i * aux_products_uint64_count), - aux_products_uint64_count, aux_base_array_[i], pool_); - if (!try_invert_uint_mod(inv_aux_base_products_mod_aux_array_[i], - aux_base_array_[i], inv_aux_base_products_mod_aux_array_[i])) - { - reset(); - return; - } - } - - // Compute coeff modulus products mod mtilde (qi) mod m_tilde_ - coeff_base_products_mod_mtilde_array_ = allocate_uint(coeff_base_mod_count_, pool_); - for (size_t i = 0; i < coeff_base_mod_count_; i++) - { - coeff_base_products_mod_mtilde_array_[i] = - modulo_uint(coeff_products_array_.get() + (i * coeff_products_uint64_count), - coeff_products_uint64_count, m_tilde_, pool_); - } - - // Compute coeff modulus products mod auxiliary moduli (qi) mod mj U {msk} - coeff_base_products_mod_aux_bsk_array_ = - allocate>(bsk_base_mod_count_, pool_); - for (size_t i = 0; i < aux_base_mod_count_; i++) - { - coeff_base_products_mod_aux_bsk_array_[i] = - allocate_uint(coeff_base_mod_count_, pool_); - for (size_t j = 0; j < coeff_base_mod_count_; j++) - { - coeff_base_products_mod_aux_bsk_array_[i][j] = - modulo_uint(coeff_products_array_.get() + (j * coeff_products_uint64_count), - coeff_products_uint64_count, aux_base_array_[i], pool_); - } - } - - // Add qi mod msk at the end of the array - coeff_base_products_mod_aux_bsk_array_[bsk_base_mod_count_ - 1] = - allocate_uint(coeff_base_mod_count_, pool_); - for (size_t i = 0; i < coeff_base_mod_count_; i++) - { - coeff_base_products_mod_aux_bsk_array_[bsk_base_mod_count_ - 1][i] = - modulo_uint(coeff_products_array_.get() + (i * coeff_products_uint64_count), - coeff_products_uint64_count, m_sk_, pool_); - } - - // Compute auxiliary moduli products mod coeff moduli (mj) mod qi - aux_base_products_mod_coeff_array_ = - allocate>(coeff_base_mod_count_, pool_); - for (size_t i = 0; i < coeff_base_mod_count_; i++) - { - aux_base_products_mod_coeff_array_[i] = allocate_uint(aux_base_mod_count_, pool_); - for (size_t j = 0; j < aux_base_mod_count_; j++) - { - aux_base_products_mod_coeff_array_[i][j] = - modulo_uint(aux_products_array.get() + (j * aux_products_uint64_count), - aux_products_uint64_count, coeff_base_array_[i], pool_); - } - } - - // Compute coeff moduli products inverse mod auxiliary mods (qi^(-1)) mod mj U {msk} - auto coeff_products_all(allocate_uint(coeff_base_mod_count_, pool_)); - auto tmp_products_all(allocate_uint(coeff_base_mod_count_, pool_)); - set_uint(1, coeff_base_mod_count_, coeff_products_all.get()); - - // Compute the product of all coeff moduli - for (size_t i = 0; i < coeff_base_mod_count_; i++) - { - multiply_uint_uint64(coeff_products_all.get(), coeff_base_mod_count_, - coeff_base_array_[i].value(), coeff_base_mod_count_, tmp_products_all.get()); - set_uint_uint(tmp_products_all.get(), coeff_base_mod_count_, - coeff_products_all.get()); - } - - // Compute inverses of coeff_products_all modulo aux moduli - inv_coeff_products_all_mod_aux_bsk_array_ = allocate_uint(bsk_base_mod_count_, pool_); - for (size_t i = 0; i < aux_base_mod_count_; i++) - { - inv_coeff_products_all_mod_aux_bsk_array_[i] = modulo_uint(coeff_products_all.get(), - coeff_base_mod_count_, aux_base_array_[i], pool_); - if (!try_invert_uint_mod(inv_coeff_products_all_mod_aux_bsk_array_[i], - aux_base_array_[i], inv_coeff_products_all_mod_aux_bsk_array_[i])) - { - reset(); - return; - } - } - - // Add product of all coeffs mod msk at the end of the array - inv_coeff_products_all_mod_aux_bsk_array_[bsk_base_mod_count_ - 1] = - modulo_uint(coeff_products_all.get(), coeff_base_mod_count_, m_sk_, pool_); - if (!try_invert_uint_mod(inv_coeff_products_all_mod_aux_bsk_array_[bsk_base_mod_count_ - 1], - m_sk_, inv_coeff_products_all_mod_aux_bsk_array_[bsk_base_mod_count_ - 1])) - { - reset(); - return; - } - - // Compute the products of all aux moduli - auto aux_products_all(allocate_uint(aux_base_mod_count_, pool_)); - auto tmp_aux_products_all(allocate_uint(aux_base_mod_count_, pool_)); - set_uint(1, aux_base_mod_count_, aux_products_all.get()); - - for (size_t i = 0; i < aux_base_mod_count_; i++) - { - multiply_uint_uint64(aux_products_all.get(), aux_base_mod_count_, - aux_base_array_[i].value(), aux_base_mod_count_, tmp_aux_products_all.get()); - set_uint_uint(tmp_aux_products_all.get(), aux_base_mod_count_, - aux_products_all.get()); - } - - // Compute the auxiliary products inverse mod m_sk_ (M-1) mod m_sk_ - inv_aux_products_mod_msk_ = modulo_uint(aux_products_all.get(), - aux_base_mod_count_, m_sk_, pool_); - if (!try_invert_uint_mod(inv_aux_products_mod_msk_, m_sk_, - inv_aux_products_mod_msk_)) - { - reset(); - return; - } - - // Compute auxiliary products all mod coefficient moduli - aux_products_all_mod_coeff_array_ = allocate_uint(coeff_base_mod_count_, pool_); - for (size_t i = 0; i < coeff_base_mod_count_; i++) - { - aux_products_all_mod_coeff_array_[i] = modulo_uint(aux_products_all.get(), - aux_base_mod_count_, coeff_base_array_[i], pool_); - } - - // Compute m_tilde inverse mod bsk base - inv_mtilde_mod_bsk_array_ = allocate_uint(bsk_base_mod_count_, pool_); - for (size_t i = 0; i < aux_base_mod_count_; i++) - { - if (!try_invert_uint_mod(m_tilde_.value() % aux_base_array_[i].value(), - aux_base_array_[i], inv_mtilde_mod_bsk_array_[i])) - { - reset(); - return; - } - } - - // Add m_tilde inverse mod msk at the end of the array - if (!try_invert_uint_mod(m_tilde_.value() % m_sk_.value(), m_sk_, - inv_mtilde_mod_bsk_array_[bsk_base_mod_count_ - 1])) - { - reset(); - return; - } - - // Compute coeff moduli products inverse mod m_tilde - inv_coeff_products_mod_mtilde_ = modulo_uint(coeff_products_all.get(), - coeff_base_mod_count_, m_tilde_, pool_); - if (!try_invert_uint_mod(inv_coeff_products_mod_mtilde_, m_tilde_, - inv_coeff_products_mod_mtilde_)) - { - reset(); - return; - } - - // Compute coeff base products all mod Bsk - coeff_products_all_mod_bsk_array_ = allocate_uint(bsk_base_mod_count_, pool_); - for (size_t i = 0; i < aux_base_mod_count_; i++) - { - coeff_products_all_mod_bsk_array_[i] = - modulo_uint(coeff_products_all.get(), coeff_base_mod_count_, - aux_base_array_[i], pool_); - } - - // Add coeff base products all mod m_sk_ at the end of the array - coeff_products_all_mod_bsk_array_[bsk_base_mod_count_ - 1] = - modulo_uint(coeff_products_all.get(), coeff_base_mod_count_, m_sk_, pool_); - - // Compute inverses of last coeff base modulus modulo the first ones for - // modulus switching/rescaling. - inv_last_coeff_mod_array_ = allocate_uint(coeff_base_mod_count_ - 1, pool_); - for (size_t i = 0; i < coeff_base_mod_count_ - 1; i++) - { - if (!try_mod_inverse(coeff_base_array_[coeff_base_mod_count_ - 1].value(), - coeff_base_array_[i].value(), inv_last_coeff_mod_array_[i])) - { - reset(); - return; - } - } - - // Generate plain gamma array of small_plain_mod_ is set to non-zero. - // Otherwise assume we use CKKS and no plain_modulus is needed. - if (!small_plain_mod_.is_zero()) - { - plain_gamma_array_ = allocate(plain_gamma_count_, pool_); - plain_gamma_array_[0] = small_plain_mod_; - plain_gamma_array_[1] = gamma_; - - // Compute coeff moduli products mod plain gamma - coeff_products_mod_plain_gamma_array_ = - allocate>(plain_gamma_count_, pool_); - for (size_t i = 0; i < plain_gamma_count_; i++) - { - coeff_products_mod_plain_gamma_array_[i] = - allocate_uint(coeff_base_mod_count_, pool_); - for (size_t j = 0; j < coeff_base_mod_count_; j++) - { - coeff_products_mod_plain_gamma_array_[i][j] = - modulo_uint( - coeff_products_array_.get() + (j * coeff_products_uint64_count), - coeff_products_uint64_count, plain_gamma_array_[i], pool_ - ); - } - } - - // Compute inverse of all coeff moduli products mod plain gamma - neg_inv_coeff_products_all_mod_plain_gamma_array_ = - allocate_uint(plain_gamma_count_, pool_); - for (size_t i = 0; i < plain_gamma_count_; i++) - { - uint64_t temp = modulo_uint(coeff_products_all.get(), - coeff_base_mod_count_, plain_gamma_array_[i], pool_); - neg_inv_coeff_products_all_mod_plain_gamma_array_[i] = - negate_uint_mod(temp, plain_gamma_array_[i]); - if (!try_invert_uint_mod(neg_inv_coeff_products_all_mod_plain_gamma_array_[i], - plain_gamma_array_[i], neg_inv_coeff_products_all_mod_plain_gamma_array_[i])) - { - reset(); - return; - } - } - - // Compute inverse of gamma mod plain modulus - inv_gamma_mod_plain_ = modulo_uint(gamma_.data(), gamma_.uint64_count(), - small_plain_mod_, pool_); - if (!try_invert_uint_mod( - inv_gamma_mod_plain_, small_plain_mod_, inv_gamma_mod_plain_)) - { - reset(); - return; - } - - // Compute plain_gamma product mod coeff base moduli - plain_gamma_product_mod_coeff_array_ = - allocate_uint(coeff_base_mod_count_, pool_); - for (size_t i = 0; i < coeff_base_mod_count_; i++) - { - plain_gamma_product_mod_coeff_array_[i] = - multiply_uint_uint_mod(small_plain_mod_.value(), gamma_.value(), - coeff_base_array_[i]); - } - } - - // Everything went well - generated_ = true; - } - - void BaseConverter::reset() noexcept - { - generated_ = false; - coeff_base_array_.release(); - aux_base_array_.release(); - bsk_base_array_.release(); - plain_gamma_array_.release(); - coeff_products_array_.release(); - mtilde_inv_coeff_base_products_mod_coeff_array_.release(); - inv_aux_base_products_mod_aux_array_.release(); - inv_coeff_products_all_mod_aux_bsk_array_.release(); - inv_coeff_base_products_mod_coeff_array_.release(); - aux_base_products_mod_coeff_array_.release(); - coeff_base_products_mod_aux_bsk_array_.release(); - coeff_base_products_mod_mtilde_array_.release(); - aux_base_products_mod_msk_array_.release(); - aux_products_all_mod_coeff_array_.release(); - inv_mtilde_mod_bsk_array_.release(); - coeff_products_all_mod_bsk_array_.release(); - coeff_products_mod_plain_gamma_array_.release(); - neg_inv_coeff_products_all_mod_plain_gamma_array_.release(); - plain_gamma_product_mod_coeff_array_.release(); - bsk_small_ntt_tables_.release(); - inv_last_coeff_mod_array_.release(); - inv_coeff_products_mod_mtilde_ = 0; - m_tilde_ = 0; - m_sk_ = 0; - gamma_ = 0; - coeff_count_ = 0; - coeff_base_mod_count_ = 0; - aux_base_mod_count_ = 0; - plain_gamma_count_ = 0; - inv_gamma_mod_plain_ = 0; - } - - void BaseConverter::fastbconv(const uint64_t *input, - uint64_t *destination, MemoryPoolHandle pool) const - { -#ifdef SEAL_DEBUG - if (input == nullptr) - { - throw invalid_argument("input cannot be null"); - } - if (destination == nullptr) - { - throw invalid_argument("destination cannot be null"); - } - if (!pool) - { - throw invalid_argument("pool is not initialied"); - } - if (!generated_) - { - throw logic_error("BaseConverter is not generated"); - } -#endif - /** - Require: Input in q - Ensure: Output in Bsk = {m1,...,ml} U {msk} - */ - auto temp_coeff_transition(allocate_uint( - coeff_count_ * coeff_base_mod_count_, pool)); - for (size_t i = 0; i < coeff_base_mod_count_; i++) - { - uint64_t inv_coeff_base_products_mod_coeff_elt = - inv_coeff_base_products_mod_coeff_array_[i]; - SmallModulus coeff_base_array_elt = coeff_base_array_[i]; - for (size_t k = 0; k < coeff_count_; k++, input++) - { - temp_coeff_transition[i + (k * coeff_base_mod_count_)] = - multiply_uint_uint_mod( - *input, - inv_coeff_base_products_mod_coeff_elt, - coeff_base_array_elt - ); - } - } - - for (size_t j = 0; j < bsk_base_mod_count_; j++) - { - uint64_t *temp_coeff_transition_ptr = temp_coeff_transition.get(); - SmallModulus bsk_base_array_elt = bsk_base_array_[j]; - for (size_t k = 0; k < coeff_count_; k++, destination++) - { - const uint64_t *coeff_base_products_mod_aux_bsk_array_ptr = - coeff_base_products_mod_aux_bsk_array_[j].get(); - unsigned long long aux_transition[2]{ 0, 0 }; - for (size_t i = 0; i < coeff_base_mod_count_; - i++, temp_coeff_transition_ptr++, - coeff_base_products_mod_aux_bsk_array_ptr++) - { - // Lazy reduction - unsigned long long temp[2]; - - // Product is 60 bit + 61 bit = 121 bit, so can sum up to 127 of them with no reduction - // Thus need coeff_base_mod_count_ <= 127 to guarantee success - multiply_uint64(*temp_coeff_transition_ptr, - *coeff_base_products_mod_aux_bsk_array_ptr, temp); - unsigned char carry = add_uint64(aux_transition[0], - temp[0], aux_transition); - aux_transition[1] += temp[1] + carry; - } - *destination = barrett_reduce_128(aux_transition, bsk_base_array_elt); - } - } - } - - void BaseConverter::floor_last_coeff_modulus_inplace( - uint64_t *rns_poly, - MemoryPoolHandle pool) const - { - auto temp(allocate_uint(coeff_count_, pool)); - for (size_t i = 0; i < coeff_base_mod_count_ - 1; i++) - { - // (ct mod qk) mod qi - modulo_poly_coeffs_63( - rns_poly + (coeff_base_mod_count_ - 1) * coeff_count_, - coeff_count_, - coeff_base_array_[i], - temp.get()); - sub_poly_poly_coeffmod( - rns_poly + i * coeff_count_, - temp.get(), - coeff_count_, - coeff_base_array_[i], - rns_poly + i * coeff_count_); - // qk^(-1) * ((ct mod qi) - (ct mod qk)) mod qi - multiply_poly_scalar_coeffmod( - rns_poly + i * coeff_count_, - coeff_count_, - inv_last_coeff_mod_array_[i], - coeff_base_array_[i], - rns_poly + i * coeff_count_); - } - } - - void BaseConverter::floor_last_coeff_modulus_ntt_inplace( - std::uint64_t *rns_poly, - const Pointer &rns_ntt_tables, - MemoryPoolHandle pool) const - { - auto temp(allocate_uint(coeff_count_, pool)); - // Convert to non-NTT form - inverse_ntt_negacyclic_harvey( - rns_poly + (coeff_base_mod_count_ - 1) * coeff_count_, - rns_ntt_tables[coeff_base_mod_count_ - 1]); - for (size_t i = 0; i < coeff_base_mod_count_ - 1; i++) - { - // (ct mod qk) mod qi - modulo_poly_coeffs_63( - rns_poly + (coeff_base_mod_count_ - 1) * coeff_count_, - coeff_count_, - coeff_base_array_[i], - temp.get()); - // Convert to NTT form - ntt_negacyclic_harvey(temp.get(), rns_ntt_tables[i]); - // ((ct mod qi) - (ct mod qk)) mod qi - sub_poly_poly_coeffmod( - rns_poly + i * coeff_count_, - temp.get(), - coeff_count_, - coeff_base_array_[i], - rns_poly + i * coeff_count_); - // qk^(-1) * ((ct mod qi) - (ct mod qk)) mod qi - multiply_poly_scalar_coeffmod( - rns_poly + i * coeff_count_, - coeff_count_, - inv_last_coeff_mod_array_[i], - coeff_base_array_[i], - rns_poly + i * coeff_count_); - } - } - - void BaseConverter::round_last_coeff_modulus_inplace( - uint64_t *rns_poly, - MemoryPoolHandle pool) const - { - auto temp(allocate_uint(coeff_count_, pool)); - uint64_t *last_ptr = rns_poly + (coeff_base_mod_count_ - 1) * coeff_count_; - - // Add (p-1)/2 to change from flooring to rounding. - auto last_modulus = coeff_base_array_[coeff_base_mod_count_ - 1]; - uint64_t half = last_modulus.value() >> 1; - for (size_t j = 0; j < coeff_count_; j++) - { - last_ptr[j] = barrett_reduce_63(last_ptr[j] + half, last_modulus); - } - - for (size_t i = 0; i < coeff_base_mod_count_ - 1; i++) - { - // (ct mod qk) mod qi - modulo_poly_coeffs_63( - last_ptr, - coeff_count_, - coeff_base_array_[i], - temp.get()); - - uint64_t half_mod = barrett_reduce_63(half, coeff_base_array_[i]); - for (size_t j = 0; j < coeff_count_; j++) - { - temp.get()[j] = sub_uint_uint_mod(temp.get()[j], half_mod, coeff_base_array_[i]); - } - sub_poly_poly_coeffmod( - rns_poly + i * coeff_count_, - temp.get(), - coeff_count_, - coeff_base_array_[i], - rns_poly + i * coeff_count_); - // qk^(-1) * ((ct mod qi) - (ct mod qk)) mod qi - multiply_poly_scalar_coeffmod( - rns_poly + i * coeff_count_, - coeff_count_, - inv_last_coeff_mod_array_[i], - coeff_base_array_[i], - rns_poly + i * coeff_count_); - } - } - - void BaseConverter::round_last_coeff_modulus_ntt_inplace( - std::uint64_t *rns_poly, - const Pointer &rns_ntt_tables, - MemoryPoolHandle pool) const - { - auto temp(allocate_uint(coeff_count_, pool)); - uint64_t *last_ptr = rns_poly + (coeff_base_mod_count_ - 1) * coeff_count_; - // Convert to non-NTT form - inverse_ntt_negacyclic_harvey( - last_ptr, - rns_ntt_tables[coeff_base_mod_count_ - 1]); - - // Add (p-1)/2 to change from flooring to rounding. - auto last_modulus = coeff_base_array_[coeff_base_mod_count_ - 1]; - uint64_t half = last_modulus.value() >> 1; - for (size_t j = 0; j < coeff_count_; j++) - { - last_ptr[j] = barrett_reduce_63(last_ptr[j] + half, last_modulus); - } - - for (size_t i = 0; i < coeff_base_mod_count_ - 1; i++) - { - // (ct mod qk) mod qi - modulo_poly_coeffs_63( - last_ptr, - coeff_count_, - coeff_base_array_[i], - temp.get()); - - uint64_t half_mod = barrett_reduce_63(half, coeff_base_array_[i]); - for (size_t j = 0; j < coeff_count_; j++) { - temp.get()[j] = sub_uint_uint_mod(temp.get()[j], half_mod, coeff_base_array_[i]); - } - // Convert to NTT form - ntt_negacyclic_harvey(temp.get(), rns_ntt_tables[i]); - // ((ct mod qi) - (ct mod qk)) mod qi - sub_poly_poly_coeffmod( - rns_poly + i * coeff_count_, - temp.get(), - coeff_count_, - coeff_base_array_[i], - rns_poly + i * coeff_count_); - // qk^(-1) * ((ct mod qi) - (ct mod qk)) mod qi - multiply_poly_scalar_coeffmod( - rns_poly + i * coeff_count_, - coeff_count_, - inv_last_coeff_mod_array_[i], - coeff_base_array_[i], - rns_poly + i * coeff_count_); - } - } - - void BaseConverter::fastbconv_sk(const uint64_t *input, - uint64_t *destination, MemoryPoolHandle pool) const - { -#ifdef SEAL_DEBUG - if (input == nullptr) - { - throw invalid_argument("input cannot be null"); - } - if (destination == nullptr) - { - throw invalid_argument("destination cannot be null"); - } - if (!pool) - { - throw invalid_argument("pool is not initialied"); - } -#endif - /** - Require: Input in base Bsk = M U {msk} - Ensure: Output in base q - */ - - // Fast convert B -> q - auto temp_coeff_transition(allocate_uint( - coeff_count_ * aux_base_mod_count_, pool)); - const uint64_t *input_ptr = input; - for (size_t i = 0; i < aux_base_mod_count_; i++) - { - uint64_t inv_aux_base_products_mod_aux_array_elt = - inv_aux_base_products_mod_aux_array_[i]; - SmallModulus aux_base_array_elt = aux_base_array_[i]; - for (size_t k = 0; k < coeff_count_; k++) - { - temp_coeff_transition[i + (k * aux_base_mod_count_)] = - multiply_uint_uint_mod( - *input_ptr++, - inv_aux_base_products_mod_aux_array_elt, - aux_base_array_elt - ); - } - } - - uint64_t *destination_ptr = destination; - uint64_t *temp_ptr; - for (size_t j = 0; j < coeff_base_mod_count_; j++) - { - temp_ptr = temp_coeff_transition.get(); - SmallModulus coeff_base_array_elt = coeff_base_array_[j]; - for (size_t k = 0; k < coeff_count_; k++, destination_ptr++) - { - const uint64_t *aux_base_products_mod_coeff_array_ptr = - aux_base_products_mod_coeff_array_[j].get(); - unsigned long long aux_transition[2]{ 0, 0 }; - for (size_t i = 0; i < aux_base_mod_count_; i++, temp_ptr++, - aux_base_products_mod_coeff_array_ptr++) - { - // Lazy reduction - unsigned long long temp[2]; - - // Product is 61 bit + 60 bit = 121 bit, so can sum up to 127 of them with no reduction - // Thus need aux_base_mod_count_ <= 127, so coeff_base_mod_count_ <= 126 to guarantee success - multiply_uint64(*temp_ptr, *aux_base_products_mod_coeff_array_ptr, temp); - unsigned char carry = add_uint64(aux_transition[0], temp[0], aux_transition); - aux_transition[1] += temp[1] + carry; - } - *destination_ptr = barrett_reduce_128(aux_transition, coeff_base_array_elt); - } - } - - // Compute alpha_sk - // Require: Input is in Bsk - // we only use coefficient in B - // Fast convert B -> m_sk - auto tmp(allocate_uint(coeff_count_, pool)); - destination_ptr = tmp.get(); - temp_ptr = temp_coeff_transition.get(); - for (size_t k = 0; k < coeff_count_; k++, destination_ptr++) - { - unsigned long long msk_transition[2]{ 0, 0 }; - const uint64_t *aux_base_products_mod_msk_array_ptr = - aux_base_products_mod_msk_array_.get(); - for (size_t i = 0; i < aux_base_mod_count_; i++, temp_ptr++, - aux_base_products_mod_msk_array_ptr++) - { - // Lazy reduction - unsigned long long temp[2]; - - // Product is 61 bit + 61 bit = 122 bit, so can sum up to 63 of them with no reduction - // Thus need aux_base_mod_count_ <= 63, so coeff_base_mod_count_ <= 62 to guarantee success - // This gives the strongest restriction on the number of coeff modulus primes - multiply_uint64(*temp_ptr, *aux_base_products_mod_msk_array_ptr, temp); - unsigned char carry = add_uint64(msk_transition[0], temp[0], msk_transition); - msk_transition[1] += temp[1] + carry; - } - *destination_ptr = barrett_reduce_128(msk_transition, m_sk_); - } - - auto alpha_sk(allocate_uint(coeff_count_, pool)); - input_ptr = input + (aux_base_mod_count_ * coeff_count_); - destination_ptr = alpha_sk.get(); - temp_ptr = tmp.get(); - const uint64_t m_sk_value = m_sk_.value(); - // x_sk is allocated in input[aux_base_mod_count_] - for (size_t i = 0; i < coeff_count_; i++, input_ptr++, temp_ptr++, destination_ptr++) - { - // It is not necessary for the negation to be reduced modulo the small prime - uint64_t negated_input = m_sk_value - *input_ptr; - *destination_ptr = multiply_uint_uint_mod(*temp_ptr + negated_input, - inv_aux_products_mod_msk_, m_sk_); - } - - const uint64_t m_sk_div_2 = m_sk_value >> 1; - destination_ptr = destination; - for (size_t i = 0; i < coeff_base_mod_count_; i++) - { - uint64_t aux_products_all_mod_coeff_array_elt = - aux_products_all_mod_coeff_array_[i]; - temp_ptr = alpha_sk.get(); - SmallModulus coeff_base_array_elt = coeff_base_array_[i]; - uint64_t coeff_base_array_elt_value = coeff_base_array_elt.value(); - for (size_t k = 0; k < coeff_count_; k++, temp_ptr++, destination_ptr++) - { - unsigned long long m_alpha_sk[2]; - - // Correcting alpha_sk since it is a centered modulo - if (*temp_ptr > m_sk_div_2) - { - // Lazy reduction - multiply_uint64(aux_products_all_mod_coeff_array_elt, - m_sk_value - *temp_ptr, m_alpha_sk); - m_alpha_sk[1] += add_uint64(m_alpha_sk[0], *destination_ptr, m_alpha_sk); - *destination_ptr = barrett_reduce_128(m_alpha_sk, coeff_base_array_elt); - } - // No correction needed - else - { - // Lazy reduction - // It is not necessary for the negation to be reduced modulo the small prime - multiply_uint64( - coeff_base_array_elt_value - aux_products_all_mod_coeff_array_elt, - *temp_ptr, m_alpha_sk - ); - m_alpha_sk[1] += add_uint64(*destination_ptr, - m_alpha_sk[0], m_alpha_sk); - *destination_ptr = barrett_reduce_128(m_alpha_sk, coeff_base_array_elt); - } - } - } - } - - void BaseConverter::mont_rq(const uint64_t *input, uint64_t *destination) const - { -#ifdef SEAL_DEBUG - if (input == nullptr) - { - throw invalid_argument("input cannot be null"); - } - if (destination == nullptr) - { - throw invalid_argument("destination cannot be null"); - } -#endif - /** - Require: Input should in Bsk U {m_tilde} - Ensure: Destination array in Bsk = m U {msk} - */ - const uint64_t *input_m_tilde_ptr = - input + coeff_count_ * bsk_base_mod_count_; - for (size_t k = 0; k < bsk_base_mod_count_; k++) - { - uint64_t coeff_products_all_mod_bsk_array_elt = - coeff_products_all_mod_bsk_array_[k]; - uint64_t inv_mtilde_mod_bsk_array_elt = inv_mtilde_mod_bsk_array_[k]; - SmallModulus bsk_base_array_elt = bsk_base_array_[k]; - const uint64_t *input_m_tilde_ptr_copy = input_m_tilde_ptr; - - // Compute result for aux base - for (size_t i = 0; i < coeff_count_; i++, destination++, - input_m_tilde_ptr_copy++, input++) - { - // Compute r_mtilde - // Duplicate work here: - // This needs to be computed only once per coefficient, not per Bsk prime. - uint64_t r_mtilde = multiply_uint_uint_mod(*input_m_tilde_ptr_copy, - inv_coeff_products_mod_mtilde_, m_tilde_); - r_mtilde = negate_uint_mod(r_mtilde, m_tilde_); - - // Lazy reduction - unsigned long long tmp[2]; - multiply_uint64(coeff_products_all_mod_bsk_array_elt, r_mtilde, tmp); - tmp[1] += add_uint64(tmp[0], *input, tmp); - r_mtilde = barrett_reduce_128(tmp, bsk_base_array_elt); - *destination = multiply_uint_uint_mod( - r_mtilde, inv_mtilde_mod_bsk_array_elt, bsk_base_array_elt); - } - } - } - - void BaseConverter::fast_floor(const uint64_t *input, - uint64_t *destination, MemoryPoolHandle pool) const - { -#ifdef SEAL_DEBUG - if (input == nullptr) - { - throw invalid_argument("input cannot be null"); - } - if (destination == nullptr) - { - throw invalid_argument("destination cannot be null"); - } - if (!pool) - { - throw invalid_argument("pool is not initialied"); - } -#endif - /** - Require: Input in q U m U {msk} - Ensure: Destination array in Bsk - */ - fastbconv(input, destination, pool); //q -> Bsk - - size_t index_msk = coeff_base_mod_count_ * coeff_count_; - input += index_msk; - for (size_t i = 0; i < bsk_base_mod_count_; i++) - { - SmallModulus bsk_base_array_elt = bsk_base_array_[i]; - uint64_t bsk_base_array_value = bsk_base_array_elt.value(); - uint64_t inv_coeff_products_all_mod_aux_bsk_array_elt = - inv_coeff_products_all_mod_aux_bsk_array_[i]; - for (size_t k = 0; k < coeff_count_; k++, input++, destination++) - { - // It is not necessary for the negation to be reduced modulo the small prime - //negate_uint_smallmod(base_convert_Bsk.get() + k + (i * coeff_count_), - // bsk_base_array_[i], &negated_base_convert_Bsk); - *destination = multiply_uint_uint_mod( - *input + bsk_base_array_value - *destination, - inv_coeff_products_all_mod_aux_bsk_array_elt, - bsk_base_array_elt - ); - } - } - } - - void BaseConverter::fastbconv_mtilde(const uint64_t *input, - uint64_t *destination, MemoryPoolHandle pool) const - { -#ifdef SEAL_DEBUG - if (input == nullptr) - { - throw invalid_argument("input cannot be null"); - } - if (destination == nullptr) - { - throw invalid_argument("destination cannot be null"); - } - if (!pool) - { - throw invalid_argument("pool is not initialied"); - } -#endif - /** - Require: Input in q - Ensure: Output in Bsk U {m_tilde} - */ - - // Compute in Bsk first; we compute |m_tilde*q^-1i| mod qi - auto temp_coeff_transition(allocate_uint( - coeff_count_ * coeff_base_mod_count_, pool)); - for (size_t i = 0; i < coeff_base_mod_count_; i++) - { - SmallModulus coeff_base_array_elt = coeff_base_array_[i]; - uint64_t mtilde_inv_coeff_base_products_mod_coeff_elt = - mtilde_inv_coeff_base_products_mod_coeff_array_[i]; - for (size_t k = 0; k < coeff_count_; k++, input++) - { - temp_coeff_transition[i + (k * coeff_base_mod_count_)] = - multiply_uint_uint_mod( - *input, - mtilde_inv_coeff_base_products_mod_coeff_elt, - coeff_base_array_elt - ); - } - } - - uint64_t *destination_ptr = destination; - for (size_t j = 0; j < bsk_base_mod_count_; j++) - { - const uint64_t *coeff_base_products_mod_aux_bsk_array_ptr = - coeff_base_products_mod_aux_bsk_array_[j].get(); - uint64_t *temp_coeff_transition_ptr = temp_coeff_transition.get(); - SmallModulus bsk_base_array_elt = bsk_base_array_[j]; - for (size_t k = 0; k < coeff_count_; k++, destination_ptr++) - { - unsigned long long aux_transition[2]{ 0, 0 }; - const uint64_t *temp_ptr = coeff_base_products_mod_aux_bsk_array_ptr; - for (size_t i = 0; i < coeff_base_mod_count_; - i++, temp_ptr++, temp_coeff_transition_ptr++) - { - // Lazy reduction - unsigned long long temp[2]{ 0, 0 }; - - // Product is 60 bit + 61 bit = 121 bit, so can sum up to 127 of them with no reduction - // Thus need coeff_base_mod_count_ <= 127 - multiply_uint64(*temp_coeff_transition_ptr, *temp_ptr, temp); - unsigned char carry = add_uint64(aux_transition[0], - temp[0], aux_transition); - aux_transition[1] += temp[1] + carry; - } - *destination_ptr = barrett_reduce_128(aux_transition, bsk_base_array_elt); - } - } - - // Computing the last element (mod m_tilde) and add it at the end of destination array - uint64_t *temp_coeff_transition_ptr = temp_coeff_transition.get(); - destination += bsk_base_mod_count_ * coeff_count_; - for (size_t k = 0; k < coeff_count_; k++, destination++) - { - unsigned long long wide_result[2]{ 0, 0 }; - const uint64_t *coeff_base_products_mod_mtilde_array_ptr = - coeff_base_products_mod_mtilde_array_.get(); - for (size_t i = 0; i < coeff_base_mod_count_; i++, - temp_coeff_transition_ptr++, - coeff_base_products_mod_mtilde_array_ptr++) - { - // Lazy reduction - unsigned long long aux_transition[2]; - - // Product is 60 bit + 33 bit = 93 bit - multiply_uint64(*temp_coeff_transition_ptr, - *coeff_base_products_mod_mtilde_array_ptr, aux_transition); - unsigned char carry = add_uint64(aux_transition[0], - wide_result[0], wide_result); - wide_result[1] += aux_transition[1] + carry; - } - *destination = barrett_reduce_128(wide_result, m_tilde_); - } - } - - void BaseConverter::fastbconv_plain_gamma(const uint64_t *input, - uint64_t *destination, MemoryPoolHandle pool) const - { -#ifdef SEAL_DEBUG - if (small_plain_mod_.is_zero()) - { - throw logic_error("invalid operation"); - } - if (input == nullptr) - { - throw invalid_argument("input cannot be null"); - } - if (destination == nullptr) - { - throw invalid_argument("destination cannot be null"); - } -#endif - /** - Require: Input in q - Ensure: Output in t (plain modulus) U gamma - */ - auto temp_coeff_transition(allocate_uint( - coeff_count_ * coeff_base_mod_count_, pool)); - for (size_t i = 0; i < coeff_base_mod_count_; i++) - { - uint64_t inv_coeff_base_products_mod_coeff_elt = - inv_coeff_base_products_mod_coeff_array_[i]; - SmallModulus coeff_base_array_elt = coeff_base_array_[i]; - for (size_t k = 0; k < coeff_count_; k++, input++) - { - temp_coeff_transition[i + (k * coeff_base_mod_count_)] = - multiply_uint_uint_mod( - *input, - inv_coeff_base_products_mod_coeff_elt, - coeff_base_array_elt - ); - } - } - - for (size_t j = 0; j < plain_gamma_count_; j++) - { - SmallModulus plain_gamma_array_elt = plain_gamma_array_[j]; - uint64_t *temp_coeff_transition_ptr = temp_coeff_transition.get(); - const uint64_t *coeff_products_mod_plain_gamma_array_ptr = - coeff_products_mod_plain_gamma_array_[j].get(); - for (size_t k = 0; k < coeff_count_; k++, destination++) - { - unsigned long long wide_result[2]{ 0, 0 }; - const uint64_t *temp_ptr = coeff_products_mod_plain_gamma_array_ptr; - for (size_t i = 0; i < coeff_base_mod_count_; i++, - temp_coeff_transition_ptr++, temp_ptr++) - { - unsigned long long plain_transition[2]; - - // Lazy reduction - // Product is 60 bit + 61 bit = 121 bit, so can sum up to 127 of them with no reduction - // Thus need coeff_base_mod_count_ <= 127 - multiply_uint64(*temp_coeff_transition_ptr, *temp_ptr, plain_transition); - unsigned char carry = add_uint64(plain_transition[0], - wide_result[0], wide_result); - wide_result[1] += plain_transition[1] + carry; - } - *destination = barrett_reduce_128(wide_result, plain_gamma_array_elt); - } - } - } - } -} diff --git a/SEAL/native/src/seal/util/baseconverter.h b/SEAL/native/src/seal/util/baseconverter.h deleted file mode 100644 index 17ae6d3..0000000 --- a/SEAL/native/src/seal/util/baseconverter.h +++ /dev/null @@ -1,291 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include -#include "seal/util/pointer.h" -#include "seal/memorymanager.h" -#include "seal/smallmodulus.h" -#include "seal/util/smallntt.h" -#include "seal/biguint.h" - -namespace seal -{ - namespace util - { - class BaseConverter - { - public: - BaseConverter(MemoryPoolHandle pool) : pool_(std::move(pool)) - { - if (!pool_) - { - throw std::invalid_argument("pool is uninitialized"); - } - } - - BaseConverter(const std::vector &coeff_base, - std::size_t coeff_count, const SmallModulus &small_plain_mod, - MemoryPoolHandle pool); - - /** - Generates the pre-computations for the given parameters. - */ - void generate(const std::vector &coeff_base, - std::size_t coeff_count, const SmallModulus &small_plain_mod); - - void floor_last_coeff_modulus_inplace( - std::uint64_t *rns_poly, - MemoryPoolHandle pool) const; - - void floor_last_coeff_modulus_ntt_inplace( - std::uint64_t *rns_poly, - const Pointer &rns_ntt_tables, - MemoryPoolHandle pool) const; - - void round_last_coeff_modulus_inplace( - std::uint64_t *rns_poly, - MemoryPoolHandle pool) const; - - void round_last_coeff_modulus_ntt_inplace( - std::uint64_t *rns_poly, - const Pointer &rns_ntt_tables, - MemoryPoolHandle pool) const; - - /** - Fast base converter from q to Bsk - */ - void fastbconv(const std::uint64_t *input, - std::uint64_t *destination, MemoryPoolHandle pool) const; - - /** - Fast base converter from Bsk to q - */ - void fastbconv_sk(const std::uint64_t *input, - std::uint64_t *destination, MemoryPoolHandle pool) const; - - /** - Reduction from Bsk U {m_tilde} to Bsk - */ - void mont_rq(const std::uint64_t *input, - std::uint64_t *destination) const; - - /** - Fast base converter from q U Bsk to Bsk - */ - void fast_floor(const std::uint64_t *input, - std::uint64_t *destination, MemoryPoolHandle pool) const; - - /** - Fast base converter from q to Bsk U {m_tilde} - */ - void fastbconv_mtilde(const std::uint64_t *input, - std::uint64_t *destination, MemoryPoolHandle pool) const; - - /** - Fast base converter from q to plain_modulus U {gamma} - */ - void fastbconv_plain_gamma(const std::uint64_t *input, - std::uint64_t *destination, MemoryPoolHandle pool) const; - - void reset() noexcept; - - SEAL_NODISCARD inline auto is_generated() const noexcept - { - return generated_; - } - - SEAL_NODISCARD inline auto coeff_base_mod_count() const noexcept - { - return coeff_base_mod_count_; - } - - SEAL_NODISCARD inline auto aux_base_mod_count() const noexcept - { - return aux_base_mod_count_; - } - - SEAL_NODISCARD inline auto &get_plain_gamma_product() const noexcept - { - return plain_gamma_product_mod_coeff_array_; - } - - SEAL_NODISCARD inline auto &get_neg_inv_coeff() const noexcept - { - return neg_inv_coeff_products_all_mod_plain_gamma_array_; - } - - SEAL_NODISCARD inline auto &get_plain_gamma_array() const noexcept - { - return plain_gamma_array_; - } - - SEAL_NODISCARD inline auto get_coeff_products_array() const noexcept - -> const std::uint64_t * - { - return coeff_products_array_.get(); - } - - SEAL_NODISCARD inline std::uint64_t get_inv_gamma() const noexcept - { - return inv_gamma_mod_plain_; - } - - SEAL_NODISCARD inline auto &get_bsk_small_ntt_tables() const noexcept - { - return bsk_small_ntt_tables_; - } - - SEAL_NODISCARD inline auto bsk_base_mod_count() const noexcept - { - return bsk_base_mod_count_; - } - - SEAL_NODISCARD inline auto &get_bsk_mod_array() const noexcept - { - return bsk_base_array_; - } - - SEAL_NODISCARD inline auto &get_msk() const noexcept - { - return m_sk_; - } - - SEAL_NODISCARD inline auto &get_m_tilde() const noexcept - { - return m_tilde_; - } - - SEAL_NODISCARD inline auto &get_mtilde_inv_coeff_products_mod_coeff() const noexcept - { - return mtilde_inv_coeff_base_products_mod_coeff_array_; - } - - SEAL_NODISCARD inline auto &get_inv_coeff_mod_mtilde() const noexcept - { - return inv_coeff_products_mod_mtilde_; - } - - SEAL_NODISCARD inline auto &get_inv_coeff_mod_coeff_array() const noexcept - { - return inv_coeff_base_products_mod_coeff_array_; - } - - SEAL_NODISCARD inline auto &get_inv_last_coeff_mod_array() const noexcept - { - return inv_last_coeff_mod_array_; - } - - SEAL_NODISCARD inline auto &get_coeff_base_products_mod_msk() const noexcept - { - return coeff_base_products_mod_aux_bsk_array_[bsk_base_mod_count_ - 1]; - } - - private: - BaseConverter(const BaseConverter ©) = delete; - - BaseConverter(BaseConverter &&source) = delete; - - BaseConverter &operator =(const BaseConverter &assign) = delete; - - BaseConverter &operator =(BaseConverter &&assign) = delete; - - MemoryPoolHandle pool_; - - bool generated_ = false; - - std::size_t coeff_count_ = 0; - - std::size_t coeff_base_mod_count_ = 0; - - std::size_t aux_base_mod_count_ = 0; - - std::size_t bsk_base_mod_count_ = 0; - - std::size_t plain_gamma_count_ = 0; - - // Array of coefficient small moduli - Pointer coeff_base_array_; - - // Array of auxiliary moduli - Pointer aux_base_array_; - - // Array of auxiliary U {m_sk_} moduli - Pointer bsk_base_array_; - - // Array of plain modulus U gamma - Pointer plain_gamma_array_; - - // Punctured products of the coeff moduli - Pointer coeff_products_array_; - - // Matrix which contains the products of coeff moduli mod aux - Pointer> coeff_base_products_mod_aux_bsk_array_; - - // Array of inverse coeff modulus products mod each small coeff mods - Pointer inv_coeff_base_products_mod_coeff_array_; - - // Array of coeff moduli products mod m_tilde - Pointer coeff_base_products_mod_mtilde_array_; - - // Array of coeff modulus products times m_tilda mod each coeff modulus - Pointer mtilde_inv_coeff_base_products_mod_coeff_array_; - - // Matrix of the inversion of coeff modulus products mod each auxiliary mods - Pointer inv_coeff_products_all_mod_aux_bsk_array_; - - // Matrix of auxiliary mods products mod each coeff modulus - Pointer> aux_base_products_mod_coeff_array_; - - // Array of inverse auxiliary mod products mod each auxiliary mods - Pointer inv_aux_base_products_mod_aux_array_; - - // Array of auxiliary bases products mod m_sk_ - Pointer aux_base_products_mod_msk_array_; - - // Coeff moduli products inverse mod m_tilde - std::uint64_t inv_coeff_products_mod_mtilde_ = 0; - - // Auxiliary base products mod m_sk_ (m1*m2*...*ml)-1 mod m_sk - std::uint64_t inv_aux_products_mod_msk_ = 0; - - // Gamma inverse mod plain modulus - std::uint64_t inv_gamma_mod_plain_ = 0; - - // Auxiliary base products mod coeff moduli (m1*m2*...*ml) mod qi - Pointer aux_products_all_mod_coeff_array_; - - // Array of m_tilde inverse mod Bsk = m U {msk} - Pointer inv_mtilde_mod_bsk_array_; - - // Array of all coeff base products mod Bsk - Pointer coeff_products_all_mod_bsk_array_; - - // Matrix of coeff base product mod plain modulus and gamma - Pointer> coeff_products_mod_plain_gamma_array_; - - // Array of negative inverse all coeff base product mod plain modulus and gamma - Pointer neg_inv_coeff_products_all_mod_plain_gamma_array_; - - // Array of plain_gamma_product mod coeff base moduli - Pointer plain_gamma_product_mod_coeff_array_; - - // Array of small NTT tables for moduli in Bsk - Pointer bsk_small_ntt_tables_; - - // For modulus switching: inverses of the last coeff base modulus - Pointer inv_last_coeff_mod_array_; - - SmallModulus m_tilde_; - - SmallModulus m_sk_; - - SmallModulus small_plain_mod_; - - SmallModulus gamma_; - }; - } -} diff --git a/SEAL/native/src/seal/util/clang.h b/SEAL/native/src/seal/util/clang.h deleted file mode 100644 index 47f2852..0000000 --- a/SEAL/native/src/seal/util/clang.h +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#if SEAL_COMPILER == SEAL_COMPILER_CLANG - -// We require clang >= 5 -#if (__clang_major__ < 5) || not defined(__cplusplus) -#error "SEAL requires __clang_major__ >= 5" -#endif - -// Read in config.h -#include "seal/util/config.h" - -// Are we using MSGSL? -#ifdef SEAL_USE_MSGSL -#include -#endif - -// Are intrinsics enabled? -#ifdef SEAL_USE_INTRIN -#include - -#ifdef SEAL_USE___BUILTIN_CLZLL -#define SEAL_MSB_INDEX_UINT64(result, value) { \ - *result = 63UL - static_cast(__builtin_clzll(value)); \ -} -#endif - -#ifdef SEAL_USE___INT128 -#define SEAL_MULTIPLY_UINT64_HW64(operand1, operand2, hw64) { \ - *hw64 = static_cast( \ - ((static_cast(operand1) \ - * static_cast(operand2)) >> 64)); \ -} - -#define SEAL_MULTIPLY_UINT64(operand1, operand2, result128) { \ - unsigned __int128 product = static_cast(operand1) * operand2;\ - result128[0] = static_cast(product); \ - result128[1] = static_cast(product >> 64); \ -} - -#define SEAL_DIVIDE_UINT128_UINT64(numerator, denominator, result) { \ - unsigned __int128 n, q; \ - n = (static_cast(numerator[1]) << 64) | \ - (static_cast(numerator[0])); \ - q = n / denominator; \ - n -= q * denominator; \ - numerator[0] = static_cast(n); \ - numerator[1] = static_cast(n >> 64); \ - quotient[0] = static_cast(q); \ - quotient[1] = static_cast(q >> 64); \ -} -#endif - -#ifdef SEAL_USE__ADDCARRY_U64 -#define SEAL_ADD_CARRY_UINT64(operand1, operand2, carry, result) _addcarry_u64( \ - carry, operand1, operand2, result) -#endif - -#ifdef SEAL_USE__SUBBORROW_U64 -#define SEAL_SUB_BORROW_UINT64(operand1, operand2, borrow, result) _subborrow_u64( \ - borrow, operand1, operand2, result) -#endif - -#endif //SEAL_USE_INTRIN - -#endif diff --git a/SEAL/native/src/seal/util/clipnormal.cpp b/SEAL/native/src/seal/util/clipnormal.cpp deleted file mode 100644 index 665dbd6..0000000 --- a/SEAL/native/src/seal/util/clipnormal.cpp +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include -#include "seal/util/clipnormal.h" - -using namespace std; - -namespace seal -{ - namespace util - { - ClippedNormalDistribution::ClippedNormalDistribution( - result_type mean, - result_type standard_deviation, - result_type max_deviation) : - normal_(mean, standard_deviation), - max_deviation_(max_deviation) - { - // Verify arguments. - if (standard_deviation < 0) - { - throw invalid_argument("standard_deviation"); - } - if (max_deviation < 0) - { - throw invalid_argument("max_deviation"); - } - } - } -} diff --git a/SEAL/native/src/seal/util/clipnormal.h b/SEAL/native/src/seal/util/clipnormal.h deleted file mode 100644 index 38dc032..0000000 --- a/SEAL/native/src/seal/util/clipnormal.h +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include "seal/util/defines.h" -#include -#include - -namespace seal -{ - namespace util - { - class ClippedNormalDistribution - { - public: - using result_type = double; - - using param_type = ClippedNormalDistribution; - - ClippedNormalDistribution( - result_type mean, result_type standard_deviation, - result_type max_deviation); - - template - SEAL_NODISCARD inline result_type operator()( - RNG &engine, const param_type &parm) noexcept - { - param(parm); - return operator()(engine); - } - - template - SEAL_NODISCARD inline result_type operator()(RNG &engine) noexcept - { - result_type mean = normal_.mean(); - while (true) - { - result_type value = normal_(engine); - result_type deviation = std::abs(value - mean); - if (deviation <= max_deviation_) - { - return value; - } - } - } - - SEAL_NODISCARD inline result_type mean() const noexcept - { - return normal_.mean(); - } - - SEAL_NODISCARD inline result_type standard_deviation() const noexcept - { - return normal_.stddev(); - } - - SEAL_NODISCARD inline result_type max_deviation() const noexcept - { - return max_deviation_; - } - - SEAL_NODISCARD inline result_type min() const noexcept - { - return normal_.mean() - max_deviation_; - } - - SEAL_NODISCARD inline result_type max() const noexcept - { - return normal_.mean() + max_deviation_; - } - - SEAL_NODISCARD inline param_type param() const noexcept - { - return *this; - } - - inline void param(const param_type &parm) noexcept - { - *this = parm; - } - - inline void reset() noexcept - { - normal_.reset(); - } - - private: - std::normal_distribution normal_; - - result_type max_deviation_; - }; - } -} diff --git a/SEAL/native/src/seal/util/common.h b/SEAL/native/src/seal/util/common.h deleted file mode 100644 index b5efff7..0000000 --- a/SEAL/native/src/seal/util/common.h +++ /dev/null @@ -1,615 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include "seal/util/defines.h" - -namespace seal -{ - namespace util - { - template - struct is_uint64 : std::conditional< - std::is_integral::value && - std::is_unsigned::value && - (sizeof(T) == sizeof(std::uint64_t)), - std::true_type, std::false_type>::type - { - }; - - template - struct is_uint64 : std::conditional< - is_uint64::value && - is_uint64::value, - std::true_type, std::false_type>::type - { - }; - - template - constexpr bool is_uint64_v = is_uint64::value; - - template - struct is_uint32 : std::conditional< - std::is_integral::value && - std::is_unsigned::value && - (sizeof(T) == sizeof(std::uint32_t)), - std::true_type, std::false_type>::type - { - }; - - template - struct is_uint32 : std::conditional< - is_uint32::value && - is_uint32::value, - std::true_type, std::false_type>::type - { - }; - - template - constexpr bool is_uint32_v = is_uint32::value; - - template::value>, - typename = std::enable_if_t::value>> - SEAL_NODISCARD inline constexpr bool unsigned_lt(T in1, S in2) noexcept - { - return static_cast(in1) < static_cast(in2); - } - - template::value>, - typename = std::enable_if_t::value>> - SEAL_NODISCARD inline constexpr bool unsigned_leq(T in1, S in2) noexcept - { - return static_cast(in1) <= static_cast(in2); - } - - template::value>, - typename = std::enable_if_t::value>> - SEAL_NODISCARD inline constexpr bool unsigned_gt(T in1, S in2) noexcept - { - return static_cast(in1) > static_cast(in2); - } - - template::value>, - typename = std::enable_if_t::value>> - SEAL_NODISCARD inline constexpr bool unsigned_geq(T in1, S in2) noexcept - { - return static_cast(in1) >= static_cast(in2); - } - - template::value>, - typename = std::enable_if_t::value>> - SEAL_NODISCARD inline constexpr bool unsigned_eq(T in1, S in2) noexcept - { - return static_cast(in1) == static_cast(in2); - } - - template::value>, - typename = std::enable_if_t::value>> - SEAL_NODISCARD inline constexpr bool unsigned_neq(T in1, S in2) noexcept - { - return static_cast(in1) != static_cast(in2); - } - - template::value>> - SEAL_NODISCARD inline constexpr T mul_safe(T in1) noexcept - { - return in1; - } - - template::value>> - SEAL_NODISCARD inline constexpr T mul_safe(T in1, T in2) - { - SEAL_IF_CONSTEXPR (std::is_unsigned::value) - { - if (in1 && (in2 > std::numeric_limits::max() / in1)) - { - throw std::out_of_range("unsigned overflow"); - } - } - else - { - // Positive inputs - if ((in1 > 0) && (in2 > 0) && - (in2 > std::numeric_limits::max() / in1)) - { - throw std::out_of_range("signed overflow"); - } -#if (SEAL_COMPILER == SEAL_COMPILER_MSVC) && !defined(SEAL_USE_IF_CONSTEXPR) -#pragma warning(push) -#pragma warning(disable: 4146) -#endif - // Negative inputs - else if ((in1 < 0) && (in2 < 0) && - ((-in2) > std::numeric_limits::max() / (-in1))) - { - throw std::out_of_range("signed overflow"); - } - // Negative in1; positive in2 - else if ((in1 < 0) && (in2 > 0) && - (in2 > std::numeric_limits::max() / (-in1))) - { - throw std::out_of_range("signed underflow"); - } -#if (SEAL_COMPILER == SEAL_COMPILER_MSVC) && !defined(SEAL_USE_IF_CONSTEXPR) -#pragma warning(pop) -#endif - // Positive in1; negative in2 - else if ((in1 > 0) && (in2 < 0) && - (in2 < std::numeric_limits::min() / in1)) - { - throw std::out_of_range("signed underflow"); - } - } - return in1 * in2; - } - - template::value>> - SEAL_NODISCARD inline constexpr T mul_safe(T in1, T in2, Args &&...args) - { - return mul_safe(mul_safe(in1, in2), mul_safe(std::forward(args)...)); - } - - template::value>> - SEAL_NODISCARD inline constexpr T add_safe(T in1) noexcept - { - return in1; - } - - template::value>> - SEAL_NODISCARD inline constexpr T add_safe(T in1, T in2) - { - SEAL_IF_CONSTEXPR (std::is_unsigned::value) - { - T result = in1 + in2; - if (result < in1) - { - throw std::out_of_range("unsigned overflow"); - } - return result; - } - else - { - if (in1 > 0 && (in2 > std::numeric_limits::max() - in1)) - { - throw std::out_of_range("signed overflow"); - } - else if (in1 < 0 && - (in2 < std::numeric_limits::min() - in1)) - { - throw std::out_of_range("signed underflow"); - } - return in1 + in2; - } - } - - template::value>> - SEAL_NODISCARD inline constexpr T add_safe(T in1, T in2, Args &&...args) - { - return add_safe(add_safe(in1, in2), add_safe(std::forward(args)...)); - } - - template::value>> - SEAL_NODISCARD inline T sub_safe(T in1, T in2) - { - SEAL_IF_CONSTEXPR (std::is_unsigned::value) - { - T result = in1 - in2; - if (result > in1) - { - throw std::out_of_range("unsigned underflow"); - } - return result; - } - else - { - if (in1 < 0 && (in2 > std::numeric_limits::max() + in1)) - { - throw std::out_of_range("signed underflow"); - } - else if (in1 > 0 && - (in2 < std::numeric_limits::min() + in1)) - { - throw std::out_of_range("signed overflow"); - } - return in1 - in2; - } - } - - template::value>, - typename = std::enable_if_t::value>> - SEAL_NODISCARD inline constexpr bool fits_in(S value SEAL_MAYBE_UNUSED) noexcept - { - SEAL_IF_CONSTEXPR (std::is_same::value) - { - // Same type - return true; - } - - SEAL_IF_CONSTEXPR (sizeof(S) <= sizeof(T)) - { - // Converting to bigger type - SEAL_IF_CONSTEXPR (std::is_integral::value && std::is_integral::value) - { - // Converting to at least equally big integer type - SEAL_IF_CONSTEXPR ((std::is_unsigned::value && std::is_unsigned::value) - || (!std::is_unsigned::value && !std::is_unsigned::value)) - { - // Both either signed or unsigned - return true; - } - else SEAL_IF_CONSTEXPR (std::is_unsigned::value - && std::is_signed::value) - { - // Converting from signed to at least equally big unsigned type - return value >= 0; - } - } - else SEAL_IF_CONSTEXPR (std::is_floating_point::value - && std::is_floating_point::value) - { - // Both floating-point - return true; - } - - // Still need to consider integer-float conversions and all - // unsigned to signed conversions - } - - SEAL_IF_CONSTEXPR (std::is_integral::value && std::is_integral::value) - { - // Both integer types - if (value >= 0) - { - // Non-negative number; compare as std::uint64_t - // Cannot use unsigned_leq with C++14 for lack of `if constexpr' - return static_cast(value) <= - static_cast(std::numeric_limits::max()); - } - else - { - // Negative number; compare as std::int64_t - return (static_cast(value) >= - static_cast(std::numeric_limits::min())); - } - } - else SEAL_IF_CONSTEXPR (std::is_floating_point::value) - { - // Converting to floating-point - return (static_cast(value) <= - static_cast(std::numeric_limits::max())) && - (static_cast(value) >= - -static_cast(std::numeric_limits::max())); - } - else - { - // Converting from floating-point - return (static_cast(value) <= - static_cast(std::numeric_limits::max())) && - (static_cast(value) >= - static_cast(std::numeric_limits::min())); - } - } - - template::value>> - SEAL_NODISCARD inline constexpr bool sum_fits_in(Args &&...args) - { - return fits_in(add_safe(std::forward(args)...)); - } - - template::value>> - SEAL_NODISCARD inline constexpr bool sum_fits_in(T in1, Args &&...args) - { - return fits_in(add_safe(in1, std::forward(args)...)); - } - - template::value>> - SEAL_NODISCARD inline constexpr bool product_fits_in(Args &&...args) - { - return fits_in(mul_safe(std::forward(args)...)); - } - - template::value>> - SEAL_NODISCARD inline constexpr bool product_fits_in(T in1, Args &&...args) - { - return fits_in(mul_safe(in1, std::forward(args)...)); - } - - template::value>, - typename = std::enable_if_t::value>> - SEAL_NODISCARD inline T safe_cast(S value) - { - SEAL_IF_CONSTEXPR (!std::is_same::value) - { - if(!fits_in(value)) - { - throw std::out_of_range("cast failed"); - } - } - return static_cast(value); - } - - constexpr int bytes_per_uint64 = sizeof(std::uint64_t); - - constexpr int bytes_per_uint32 = sizeof(std::uint32_t); - - constexpr int uint32_per_uint64 = 2; - - constexpr int bits_per_nibble = 4; - - constexpr int bits_per_byte = 8; - - constexpr int bits_per_uint64 = bytes_per_uint64 * bits_per_byte; - - constexpr int bits_per_uint32 = bytes_per_uint32 * bits_per_byte; - - constexpr int nibbles_per_byte = 2; - - constexpr int nibbles_per_uint64 = bytes_per_uint64 * nibbles_per_byte; - - constexpr std::uint64_t uint64_high_bit = std::uint64_t(1) << (bits_per_uint64 - 1); - - template || is_uint64_v>> - SEAL_NODISCARD inline constexpr T reverse_bits(T operand) noexcept - { - SEAL_IF_CONSTEXPR (is_uint32_v) - { - operand = (((operand & T(0xaaaaaaaa)) >> 1) | ((operand & T(0x55555555)) << 1)); - operand = (((operand & T(0xcccccccc)) >> 2) | ((operand & T(0x33333333)) << 2)); - operand = (((operand & T(0xf0f0f0f0)) >> 4) | ((operand & T(0x0f0f0f0f)) << 4)); - operand = (((operand & T(0xff00ff00)) >> 8) | ((operand & T(0x00ff00ff)) << 8)); - return static_cast(operand >> 16) | static_cast(operand << 16); - } - else SEAL_IF_CONSTEXPR (is_uint64_v) - { -// Temporarily disable UB warnings when `if constexpr` is not available. -#ifndef SEAL_USE_IF_CONSTEXPR -#if (SEAL_COMPILER == SEAL_COMPILER_MSVC) -#pragma warning(push) -#pragma warning(disable: 4293) -#elif (SEAL_COMPILER == SEAL_COMPILER_GCC) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wshift-count-overflow" -#elif (SEAL_COMPILER == SEAL_COMPILER_CLANG) -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wshift-count-overflow" -#endif -#endif - return static_cast(reverse_bits(static_cast(operand >> 32))) | - (static_cast(reverse_bits(static_cast(operand & T(0xFFFFFFFF)))) << 32); -#ifndef SEAL_USE_IF_CONSTEXPR -#if (SEAL_COMPILER == SEAL_COMPILER_MSVC) -#pragma warning(pop) -#elif (SEAL_COMPILER == SEAL_COMPILER_GCC) -#pragma GCC diagnostic pop -#elif (SEAL_COMPILER == SEAL_COMPILER_CLANG) -#pragma clang diagnostic pop -#endif -#endif - } - } - - template || is_uint64_v>> - SEAL_NODISCARD inline T reverse_bits(T operand, int bit_count) - { -#ifdef SEAL_DEBUG - if (bit_count < 0 || - static_cast(bit_count) > - mul_safe(sizeof(T), static_cast(bits_per_byte))) - { - throw std::invalid_argument("bit_count"); - } -#endif - // Just return zero if bit_count is zero - return (bit_count == 0) ? T(0) : reverse_bits(operand) >> ( - sizeof(T) * static_cast(bits_per_byte) - - static_cast(bit_count)); - } - - inline void get_msb_index_generic(unsigned long *result, std::uint64_t value) - { -#ifdef SEAL_DEBUG - if (result == nullptr) - { - throw std::invalid_argument("result"); - } -#endif - static const unsigned long deBruijnTable64[64] = { - 63, 0, 58, 1, 59, 47, 53, 2, - 60, 39, 48, 27, 54, 33, 42, 3, - 61, 51, 37, 40, 49, 18, 28, 20, - 55, 30, 34, 11, 43, 14, 22, 4, - 62, 57, 46, 52, 38, 26, 32, 41, - 50, 36, 17, 19, 29, 10, 13, 21, - 56, 45, 25, 31, 35, 16, 9, 12, - 44, 24, 15, 8, 23, 7, 6, 5 - }; - - value |= value >> 1; - value |= value >> 2; - value |= value >> 4; - value |= value >> 8; - value |= value >> 16; - value |= value >> 32; - - *result = deBruijnTable64[((value - (value >> 1)) * std::uint64_t(0x07EDD5E59A4E28C2)) >> 58]; - } - - SEAL_NODISCARD inline int get_significant_bit_count(std::uint64_t value) - { - if (value == 0) - { - return 0; - } - - unsigned long result; - SEAL_MSB_INDEX_UINT64(&result, value); - return static_cast(result + 1); - } - - SEAL_NODISCARD inline bool is_hex_char(char hex) - { - if (hex >= '0' && hex <= '9') - { - return true; - } - if (hex >= 'A' && hex <= 'F') - { - return true; - } - if (hex >= 'a' && hex <= 'f') - { - return true; - } - return false; - } - - SEAL_NODISCARD inline char nibble_to_upper_hex(int nibble) - { -#ifdef SEAL_DEBUG - if (nibble < 0 || nibble > 15) - { - throw std::invalid_argument("nibble"); - } -#endif - if (nibble < 10) - { - return static_cast(nibble + static_cast('0')); - } - return static_cast(nibble + static_cast('A') - 10); - } - - SEAL_NODISCARD inline int hex_to_nibble(char hex) - { - if (hex >= '0' && hex <= '9') - { - return static_cast(hex) - static_cast('0'); - } - if (hex >= 'A' && hex <= 'F') - { - return static_cast(hex) - static_cast('A') + 10; - } - if (hex >= 'a' && hex <= 'f') - { - return static_cast(hex) - static_cast('a') + 10; - } -#ifdef SEAL_DEBUG - throw std::invalid_argument("hex"); -#endif - return -1; - } - - SEAL_NODISCARD inline SEAL_BYTE *get_uint64_byte( - std::uint64_t *value, std::size_t byte_index) - { -#ifdef SEAL_DEBUG - if (value == nullptr) - { - throw std::invalid_argument("value"); - } -#endif - return reinterpret_cast(value) + byte_index; - } - - SEAL_NODISCARD inline const SEAL_BYTE *get_uint64_byte( - const std::uint64_t *value, std::size_t byte_index) - { -#ifdef SEAL_DEBUG - if (value == nullptr) - { - throw std::invalid_argument("value"); - } -#endif - return reinterpret_cast(value) + byte_index; - } - - SEAL_NODISCARD inline int get_hex_string_bit_count( - const char *hex_string, int char_count) - { -#ifdef SEAL_DEBUG - if (hex_string == nullptr && char_count > 0) - { - throw std::invalid_argument("hex_string"); - } - if (char_count < 0) - { - throw std::invalid_argument("char_count"); - } -#endif - for (int i = 0; i < char_count; i++) - { - char hex = *hex_string++; - int nibble = hex_to_nibble(hex); - if (nibble != 0) - { - int nibble_bits = get_significant_bit_count( - static_cast(nibble)); - int remaining_nibbles = (char_count - i - 1) * bits_per_nibble; - return nibble_bits + remaining_nibbles; - } - } - return 0; - } - - template::value>> - SEAL_NODISCARD inline T divide_round_up(T value, T divisor) - { -#ifdef SEAL_DEBUG - if (value < 0) - { - throw std::invalid_argument("value"); - } - if (divisor <= 0) - { - throw std::invalid_argument("divisor"); - } -#endif - return (add_safe(value, divisor - 1)) / divisor; - } - - template - constexpr double epsilon = std::numeric_limits::epsilon(); - - template::value>> - SEAL_NODISCARD inline bool are_close(T value1, T value2) noexcept - { - double scale_factor = std::max( - { std::fabs(value1), std::fabs(value2), T{ 1.0 } }); - return std::fabs(value1 - value2) < epsilon * scale_factor; - } - - template::value>> - SEAL_NODISCARD inline constexpr bool is_zero(T value) noexcept - { - return value == T{ 0 }; - } - } -} diff --git a/SEAL/native/src/seal/util/config.h.in b/SEAL/native/src/seal/util/config.h.in deleted file mode 100644 index a80ab2f..0000000 --- a/SEAL/native/src/seal/util/config.h.in +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#define SEAL_VERSION "@SEAL_VERSION@" -#cmakedefine SEAL_DEBUG -#cmakedefine SEAL_USE_IF_CONSTEXPR -#cmakedefine SEAL_USE_MAYBE_UNUSED -#cmakedefine SEAL_USE_NODISCARD -#cmakedefine SEAL_USE_STD_BYTE -#cmakedefine SEAL_USE_SHARED_MUTEX -#cmakedefine SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT -#cmakedefine SEAL_USE_INTRIN -#cmakedefine SEAL_USE__UMUL128 -#cmakedefine SEAL_USE__BITSCANREVERSE64 -#cmakedefine SEAL_USE___BUILTIN_CLZLL -#cmakedefine SEAL_USE___INT128 -#cmakedefine SEAL_USE__ADDCARRY_U64 -#cmakedefine SEAL_USE__SUBBORROW_U64 -#cmakedefine SEAL_USE_AES_NI_PRNG -#cmakedefine SEAL_USE_MSGSL -#cmakedefine SEAL_USE_MSGSL_SPAN -#cmakedefine SEAL_USE_MSGSL_MULTISPAN diff --git a/SEAL/native/src/seal/util/defines.h b/SEAL/native/src/seal/util/defines.h deleted file mode 100644 index a811237..0000000 --- a/SEAL/native/src/seal/util/defines.h +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -// Debugging help -#define SEAL_ASSERT(condition) { if(!(condition)){ std::cerr << "ASSERT FAILED: " \ - << #condition << " @ " << __FILE__ << " (" << __LINE__ << ")" << std::endl; } } - -// String expansion -#define _SEAL_STRINGIZE(x) #x -#define SEAL_STRINGIZE(x) _SEAL_STRINGIZE(x) - -// Check that double is 64 bits -static_assert(sizeof(double) == 8, "Require sizeof(double) == 8"); - -// Check that int is 32 bits -static_assert(sizeof(int) == 4, "Require sizeof(int) == 4"); - -// Check that unsigned long long is 64 bits -static_assert(sizeof(unsigned long long) == 8, "Require sizeof(unsigned long long) == 8"); - -// Bounds for bit-length of user-defined coefficient moduli -#define SEAL_USER_MOD_BIT_COUNT_MAX 60 -#define SEAL_USER_MOD_BIT_COUNT_MIN 2 - -// Bounds for number of coefficient moduli -#define SEAL_COEFF_MOD_COUNT_MAX 62 -#define SEAL_COEFF_MOD_COUNT_MIN 1 - -// Bounds for polynomial modulus degree -#define SEAL_POLY_MOD_DEGREE_MAX 32768 -#define SEAL_POLY_MOD_DEGREE_MIN 2 - -// Bounds for the plaintext modulus -#define SEAL_PLAIN_MOD_MIN SEAL_USER_MOD_BIT_COUNT_MIN -#define SEAL_PLAIN_MOD_MAX SEAL_USER_MOD_BIT_COUNT_MAX - -// Upper bound on the size of a ciphertext -#define SEAL_CIPHERTEXT_SIZE_MIN 2 -#define SEAL_CIPHERTEXT_SIZE_MAX 16 - -// Detect compiler -#define SEAL_COMPILER_MSVC 1 -#define SEAL_COMPILER_CLANG 2 -#define SEAL_COMPILER_GCC 3 - -#if defined(_MSC_VER) -#define SEAL_COMPILER SEAL_COMPILER_MSVC -#elif defined(__clang__) -#define SEAL_COMPILER SEAL_COMPILER_CLANG -#elif defined(__GNUC__) && !defined(__clang__) -#define SEAL_COMPILER SEAL_COMPILER_GCC -#endif - -// MSVC support -#include "seal/util/msvc.h" - -// clang support -#include "seal/util/clang.h" - -// gcc support -#include "seal/util/gcc.h" - -// Create a true/false value for indicating debug mode -#ifdef SEAL_DEBUG -#define SEAL_DEBUG_V true -#else -#define SEAL_DEBUG_V false -#endif - -// Use std::byte as byte type -#ifdef SEAL_USE_STD_BYTE -#include -namespace seal -{ - using SEAL_BYTE = std::byte; -} -#else -namespace seal -{ - enum class SEAL_BYTE : unsigned char {}; -} -#endif - -// Use `if constexpr' from C++17 -#ifdef SEAL_USE_IF_CONSTEXPR -#define SEAL_IF_CONSTEXPR if constexpr -#else -#define SEAL_IF_CONSTEXPR if -#endif - -// Use [[maybe_unused]] from C++17 -#ifdef SEAL_USE_MAYBE_UNUSED -#define SEAL_MAYBE_UNUSED [[maybe_unused]] -#else -#define SEAL_MAYBE_UNUSED -#endif - -// Use [[nodiscard]] from C++17 -#ifdef SEAL_USE_NODISCARD -#define SEAL_NODISCARD [[nodiscard]] -#else -#define SEAL_NODISCARD -#endif - -// Which random number generator factory to use by default -#ifdef SEAL_USE_AES_NI_PRNG -// AES-PRNG with seed from std::random_device -#define SEAL_DEFAULT_RNG_FACTORY FastPRNGFactory() -#else -// std::random_device -#define SEAL_DEFAULT_RNG_FACTORY StandardRandomAdapterFactory -#endif - -// Use generic functions as (slower) fallback -#ifndef SEAL_ADD_CARRY_UINT64 -#define SEAL_ADD_CARRY_UINT64(operand1, operand2, carry, result) add_uint64_generic(operand1, operand2, carry, result) -#endif - -#ifndef SEAL_SUB_BORROW_UINT64 -#define SEAL_SUB_BORROW_UINT64(operand1, operand2, borrow, result) sub_uint64_generic(operand1, operand2, borrow, result) -#endif - -#ifndef SEAL_MULTIPLY_UINT64 -#define SEAL_MULTIPLY_UINT64(operand1, operand2, result128) { \ - multiply_uint64_generic(operand1, operand2, result128); \ -} -#endif - -#ifndef SEAL_DIVIDE_UINT128_UINT64 -#define SEAL_DIVIDE_UINT128_UINT64(numerator, denominator, result) { \ - divide_uint128_uint64_inplace_generic(numerator, denominator, result); \ -} -#endif - -#ifndef SEAL_MULTIPLY_UINT64_HW64 -#define SEAL_MULTIPLY_UINT64_HW64(operand1, operand2, hw64) { \ - multiply_uint64_hw64_generic(operand1, operand2, hw64); \ -} -#endif - -#ifndef SEAL_MSB_INDEX_UINT64 -#define SEAL_MSB_INDEX_UINT64(result, value) get_msb_index_generic(result, value) -#endif diff --git a/SEAL/native/src/seal/util/gcc.h b/SEAL/native/src/seal/util/gcc.h deleted file mode 100644 index c418737..0000000 --- a/SEAL/native/src/seal/util/gcc.h +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#if SEAL_COMPILER == SEAL_COMPILER_GCC - -// We require GCC >= 6 -#if (__GNUC__ < 6) || not defined(__cplusplus) -#pragma GCC error "SEAL requires __GNUC__ >= 6" -#endif - -// Read in config.h -#include "seal/util/config.h" - -#if (__GNUC__ == 6) && defined(SEAL_USE_IF_CONSTEXPR) -#pragma GCC error "g++-6 cannot compile Microsoft SEAL as C++17; set CMake build option `SEAL_USE_CXX17' to OFF" -#endif - -// Are we using MSGSL? -#ifdef SEAL_USE_MSGSL -#include -#endif - -// Are intrinsics enabled? -#ifdef SEAL_USE_INTRIN -#include - -#ifdef SEAL_USE___BUILTIN_CLZLL -#define SEAL_MSB_INDEX_UINT64(result, value) { \ - *result = 63UL - static_cast(__builtin_clzll(value)); \ -} -#endif - -#ifdef SEAL_USE___INT128 -#define SEAL_MULTIPLY_UINT64_HW64(operand1, operand2, hw64) { \ - *hw64 = static_cast( \ - ((static_cast(operand1) \ - * static_cast(operand2)) >> 64)); \ -} - -#define SEAL_MULTIPLY_UINT64(operand1, operand2, result128) { \ - unsigned __int128 product = static_cast(operand1) * operand2;\ - result128[0] = static_cast(product); \ - result128[1] = static_cast(product >> 64); \ -} - -#define SEAL_DIVIDE_UINT128_UINT64(numerator, denominator, result) { \ - unsigned __int128 n, q; \ - n = (static_cast(numerator[1]) << 64) | \ - (static_cast(numerator[0])); \ - q = n / denominator; \ - n -= q * denominator; \ - numerator[0] = static_cast(n); \ - numerator[1] = static_cast(n >> 64); \ - quotient[0] = static_cast(q); \ - quotient[1] = static_cast(q >> 64); \ -} -#endif - -#ifdef SEAL_USE__ADDCARRY_U64 -#define SEAL_ADD_CARRY_UINT64(operand1, operand2, carry, result) _addcarry_u64( \ - carry, operand1, operand2, result) -#endif - -#ifdef SEAL_USE__SUBBORROW_U64 -#if ((__GNUC__ == 7) && (__GNUC_MINOR__ >= 2)) || (__GNUC__ >= 8) -// The inverted arguments problem was fixed in GCC-7.2 -// (https://patchwork.ozlabs.org/patch/784309/) -#define SEAL_SUB_BORROW_UINT64(operand1, operand2, borrow, result) _subborrow_u64( \ - borrow, operand1, operand2, result) -#else -// Warning: Note the inverted order of operand1 and operand2 -#define SEAL_SUB_BORROW_UINT64(operand1, operand2, borrow, result) _subborrow_u64( \ - borrow, operand2, operand1, result) -#endif //(__GNUC__ == 7) && (__GNUC_MINOR__ >= 2) -#endif - -#endif //SEAL_USE_INTRIN - -#endif diff --git a/SEAL/native/src/seal/util/globals.cpp b/SEAL/native/src/seal/util/globals.cpp deleted file mode 100644 index ca27f8a..0000000 --- a/SEAL/native/src/seal/util/globals.cpp +++ /dev/null @@ -1,254 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include -#include "seal/util/globals.h" -#include "seal/smallmodulus.h" - -using namespace std; - -namespace seal -{ - namespace util - { - namespace global_variables - { - std::shared_ptr const global_memory_pool{ std::make_shared() }; -#ifndef _M_CEE - thread_local std::shared_ptr const tls_memory_pool{ std::make_shared() }; -#else -#pragma message("WARNING: Thread-local memory pools disabled to support /clr") -#endif - const map> default_coeff_modulus_128 - { - /* - Polynomial modulus: 1x^1024 + 1 - Modulus count: 1 - Total bit count: 27 - */ - { 1024,{ - 0x7e00001 - } }, - - /* - Polynomial modulus: 1x^2048 + 1 - Modulus count: 1 - Total bit count: 54 - */ - { 2048,{ - 0x3fffffff000001 - } }, - - /* - Polynomial modulus: 1x^4096 + 1 - Modulus count: 3 - Total bit count: 109 = 2 * 36 + 37 - */ - { 4096,{ - 0xffffee001, 0xffffc4001, 0x1ffffe0001 - } }, - - /* - Polynomial modulus: 1x^8192 + 1 - Modulus count: 5 - Total bit count: 218 = 2 * 43 + 3 * 44 - */ - { 8192,{ - 0x7fffffd8001, 0x7fffffc8001, - 0xfffffffc001, 0xffffff6c001, 0xfffffebc001 - } }, - - /* - Polynomial modulus: 1x^16384 + 1 - Modulus count: 9 - Total bit count: 438 = 3 * 48 + 6 * 49 - */ - { 16384,{ - 0xfffffffd8001, 0xfffffffa0001, 0xfffffff00001, - 0x1fffffff68001, 0x1fffffff50001, 0x1ffffffee8001, - 0x1ffffffea0001, 0x1ffffffe88001, 0x1ffffffe48001 - } }, - - /* - Polynomial modulus: 1x^32768 + 1 - Modulus count: 16 - Total bit count: 881 = 15 * 55 + 56 - */ - { 32768,{ - 0x7fffffffe90001, 0x7fffffffbf0001, 0x7fffffffbd0001, - 0x7fffffffba0001, 0x7fffffffaa0001, 0x7fffffffa50001, - 0x7fffffff9f0001, 0x7fffffff7e0001, 0x7fffffff770001, - 0x7fffffff380001, 0x7fffffff330001, 0x7fffffff2d0001, - 0x7fffffff170001, 0x7fffffff150001, 0x7ffffffef00001, - 0xfffffffff70001 - } } - }; - - const map> default_coeff_modulus_192 - { - /* - Polynomial modulus: 1x^1024 + 1 - Modulus count: 1 - Total bit count: 19 - */ - { 1024,{ - 0x7f001 - } }, - - /* - Polynomial modulus: 1x^2048 + 1 - Modulus count: 1 - Total bit count: 37 - */ - { 2048,{ - 0x1ffffc0001 - } }, - - /* - Polynomial modulus: 1x^4096 + 1 - Modulus count: 3 - Total bit count: 75 = 3 * 25 - */ - { 4096,{ - 0x1ffc001, 0x1fce001, 0x1fc0001 - } }, - - /* - Polynomial modulus: 1x^8192 + 1 - Modulus count: 4 - Total bit count: 152 = 4 * 38 - */ - { 8192,{ - 0x3ffffac001, 0x3ffff54001, - 0x3ffff48001, 0x3ffff28001 - } }, - - /* - Polynomial modulus: 1x^16384 + 1 - Modulus count: 6 - Total bit count: 300 = 6 * 50 - */ - { 16384,{ - 0x3ffffffdf0001, 0x3ffffffd48001, 0x3ffffffd20001, - 0x3ffffffd18001, 0x3ffffffcd0001, 0x3ffffffc70001 - } }, - - /* - Polynomial modulus: 1x^32768 + 1 - Modulus count: 11 - Total bit count: 600 = 5 * 54 + 6 * 55 - */ - { 32768,{ - 0x3fffffffd60001, 0x3fffffffca0001, 0x3fffffff6d0001, - 0x3fffffff5d0001, 0x3fffffff550001, 0x7fffffffe90001, - 0x7fffffffbf0001, 0x7fffffffbd0001, 0x7fffffffba0001, - 0x7fffffffaa0001, 0x7fffffffa50001 - } } - }; - - const map> default_coeff_modulus_256 - { - /* - Polynomial modulus: 1x^1024 + 1 - Modulus count: 1 - Total bit count: 14 - */ - { 1024,{ - 0x3001 - } }, - - /* - Polynomial modulus: 1x^2048 + 1 - Modulus count: 1 - Total bit count: 29 - */ - { 2048,{ - 0x1ffc0001 - } }, - - /* - Polynomial modulus: 1x^4096 + 1 - Modulus count: 1 - Total bit count: 58 - */ - { 4096,{ - 0x3ffffffff040001 - } }, - - /* - Polynomial modulus: 1x^8192 + 1 - Modulus count: 3 - Total bit count: 118 = 2 * 39 + 40 - */ - { 8192,{ - 0x7ffffec001, 0x7ffffb0001, 0xfffffdc001 - } }, - - /* - Polynomial modulus: 1x^16384 + 1 - Modulus count: 5 - Total bit count: 237 = 3 * 47 + 2 * 48 - */ - { 16384,{ - 0x7ffffffc8001, 0x7ffffff00001, 0x7fffffe70001, - 0xfffffffd8001, 0xfffffffa0001 - } }, - - /* - Polynomial modulus: 1x^32768 + 1 - Modulus count: 9 - Total bit count: 476 = 52 + 8 * 53 - */ - { 32768,{ - 0xffffffff00001, 0x1fffffffe30001, 0x1fffffffd80001, - 0x1fffffffd10001, 0x1fffffffc50001, 0x1fffffffbf0001, - 0x1fffffffb90001, 0x1fffffffb60001, 0x1fffffffa50001 - } } - }; - - namespace internal_mods - { - const SmallModulus m_sk(0x1fffffffffe00001); - - const SmallModulus m_tilde(uint64_t(1) << 32); - - const SmallModulus gamma(0x1fffffffffc80001); - - const vector aux_small_mods{ - 0x1fffffffffb40001, 0x1fffffffff500001, 0x1fffffffff380001, 0x1fffffffff000001, - 0x1ffffffffef00001, 0x1ffffffffee80001, 0x1ffffffffeb40001, 0x1ffffffffe780001, - 0x1ffffffffe600001, 0x1ffffffffe4c0001, 0x1ffffffffdf40001, 0x1ffffffffdac0001, - 0x1ffffffffda40001, 0x1ffffffffc680001, 0x1ffffffffc000001, 0x1ffffffffb880001, - 0x1ffffffffb7c0001, 0x1ffffffffb300001, 0x1ffffffffb1c0001, 0x1ffffffffadc0001, - 0x1ffffffffa400001, 0x1ffffffffa140001, 0x1ffffffff9d80001, 0x1ffffffff9140001, - 0x1ffffffff8ac0001, 0x1ffffffff8a80001, 0x1ffffffff81c0001, 0x1ffffffff7800001, - 0x1ffffffff7680001, 0x1ffffffff7080001, 0x1ffffffff6c80001, 0x1ffffffff6140001, - 0x1ffffffff5f40001, 0x1ffffffff5700001, 0x1ffffffff4bc0001, 0x1ffffffff4380001, - 0x1ffffffff3240001, 0x1ffffffff2dc0001, 0x1ffffffff1a40001, 0x1ffffffff11c0001, - 0x1ffffffff0fc0001, 0x1ffffffff0d80001, 0x1ffffffff0c80001, 0x1ffffffff08c0001, - 0x1fffffffefd00001, 0x1fffffffef9c0001, 0x1fffffffef600001, 0x1fffffffeef40001, - 0x1fffffffeed40001, 0x1fffffffeed00001, 0x1fffffffeebc0001, 0x1fffffffed540001, - 0x1fffffffed440001, 0x1fffffffed2c0001, 0x1fffffffed200001, 0x1fffffffec940001, - 0x1fffffffec6c0001, 0x1fffffffebe80001, 0x1fffffffebac0001, 0x1fffffffeba40001, - 0x1fffffffeb4c0001, 0x1fffffffeb280001, 0x1fffffffea780001, 0x1fffffffea440001, - 0x1fffffffe9f40001, 0x1fffffffe97c0001, 0x1fffffffe9300001, 0x1fffffffe8d00001, - 0x1fffffffe8400001, 0x1fffffffe7cc0001, 0x1fffffffe7bc0001, 0x1fffffffe7a80001, - 0x1fffffffe7600001, 0x1fffffffe7500001, 0x1fffffffe6fc0001, 0x1fffffffe6d80001, - 0x1fffffffe6ac0001, 0x1fffffffe6000001, 0x1fffffffe5d40001, 0x1fffffffe5a00001, - 0x1fffffffe5940001, 0x1fffffffe54c0001, 0x1fffffffe5340001, 0x1fffffffe4bc0001, - 0x1fffffffe4a40001, 0x1fffffffe3fc0001, 0x1fffffffe3540001, 0x1fffffffe2b00001, - 0x1fffffffe2680001, 0x1fffffffe0480001, 0x1fffffffe00c0001, 0x1fffffffdfd00001, - 0x1fffffffdfc40001, 0x1fffffffdf700001, 0x1fffffffdf340001, 0x1fffffffdef80001, - 0x1fffffffdea80001, 0x1fffffffde680001, 0x1fffffffde000001, 0x1fffffffdde40001, - 0x1fffffffddd80001, 0x1fffffffddd00001, 0x1fffffffddb40001, 0x1fffffffdd780001, - 0x1fffffffdd4c0001, 0x1fffffffdcb80001, 0x1fffffffdca40001, 0x1fffffffdc380001, - 0x1fffffffdc040001, 0x1fffffffdbb40001, 0x1fffffffdba80001, 0x1fffffffdb9c0001, - 0x1fffffffdb740001, 0x1fffffffdb380001, 0x1fffffffda600001, 0x1fffffffda340001, - 0x1fffffffda180001, 0x1fffffffd9700001, 0x1fffffffd9680001, 0x1fffffffd9440001, - 0x1fffffffd9080001, 0x1fffffffd8c80001, 0x1fffffffd8800001, 0x1fffffffd82c0001, - 0x1fffffffd7cc0001, 0x1fffffffd7b80001, 0x1fffffffd7840001, 0x1fffffffd73c0001 - }; - } - } - } -} \ No newline at end of file diff --git a/SEAL/native/src/seal/util/globals.h b/SEAL/native/src/seal/util/globals.h deleted file mode 100644 index 479dfe2..0000000 --- a/SEAL/native/src/seal/util/globals.h +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include -#include -#include -#include "seal/util/defines.h" -#include "seal/util/hestdparms.h" - -namespace seal -{ - class SmallModulus; - - namespace util - { - class MemoryPool; - - namespace global_variables - { - extern std::shared_ptr const global_memory_pool; - -/* -For .NET Framework wrapper support (C++/CLI) we need to - (1) compile the MemoryManager class as thread-unsafe because C++ - mutexes cannot be brought through C++/CLI layer; - (2) disable thread-safe memory pools. -*/ -#ifndef _M_CEE - extern thread_local std::shared_ptr const tls_memory_pool; -#endif - /** - Default value for the standard deviation of the noise (error) distribution. - */ - constexpr double noise_standard_deviation = SEAL_HE_STD_PARMS_ERROR_STD_DEV; - - constexpr double noise_distribution_width_multiplier = 6; - - constexpr double noise_max_deviation = noise_standard_deviation * - noise_distribution_width_multiplier; - - /** - This data structure is a key-value storage that maps degrees of the polynomial modulus - to vectors of SmallModulus elements so that when used with the default value for the - standard deviation of the noise distribution (noise_standard_deviation), the security - level is at least 128 bits according to http://HomomorphicEncryption.org. This makes - it easy for non-expert users to select secure parameters. - */ - extern const std::map> default_coeff_modulus_128; - - /** - This data structure is a key-value storage that maps degrees of the polynomial modulus - to vectors of SmallModulus elements so that when used with the default value for the - standard deviation of the noise distribution (noise_standard_deviation), the security - level is at least 192 bits according to http://HomomorphicEncryption.org. This makes - it easy for non-expert users to select secure parameters. - */ - extern const std::map> default_coeff_modulus_192; - - /** - This data structure is a key-value storage that maps degrees of the polynomial modulus - to vectors of SmallModulus elements so that when used with the default value for the - standard deviation of the noise distribution (noise_standard_deviation), the security - level is at least 256 bits according to http://HomomorphicEncryption.org. This makes - it easy for non-expert users to select secure parameters. - */ - extern const std::map> default_coeff_modulus_256; - - // For internal use only, do not modify - namespace internal_mods - { - // Prime, 61 bits, and congruent to 1 mod 2^18 - extern const SmallModulus m_sk; - - // Non-prime; 2^32 - extern const SmallModulus m_tilde; - - // Prime, 61 bits, and congruent to 1 mod 2^18 - extern const SmallModulus gamma; - - // For internal use only, all primes 61 bits and congruent to 1 mod 2^18 - extern const std::vector aux_small_mods; - } - } - } -} diff --git a/SEAL/native/src/seal/util/hash.cpp b/SEAL/native/src/seal/util/hash.cpp deleted file mode 100644 index 95445d5..0000000 --- a/SEAL/native/src/seal/util/hash.cpp +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "seal/util/hash.h" -#include "seal/util/common.h" -#include "seal/util/uintcore.h" -#include "seal/util/pointer.h" -#include "seal/util/globals.h" -#include "seal/memorymanager.h" -#include - -using namespace std; - -namespace seal -{ - namespace util - { - // For C++14 compatibility need to define static constexpr - // member variables with no initialization here. - constexpr std::uint64_t HashFunction::sha3_round_consts[sha3_round_count]; - - constexpr std::uint8_t HashFunction::sha3_rho[24]; - - constexpr HashFunction::sha3_block_type HashFunction::sha3_zero_block; - - void HashFunction::keccak_1600(sha3_state_type &state) noexcept - { - for (uint8_t round = 0; round < sha3_round_count; round++) - { - // theta - uint64_t C[5]; - uint64_t D[5]; - for (uint8_t x = 0; x < 5; x++) - { - C[x] = state[x][0]; - for (uint8_t y = 1; y < 5; y++) - { - C[x] ^= state[x][y]; - } - } - for (uint8_t x = 0; x < 5; x++) - { - D[x] = C[(x + 4) % 5] ^ rot(C[(x + 1) % 5], 1); - for (uint8_t y = 0; y < 5; y++) - { - state[x][y] ^= D[x]; - } - } - - // rho and pi - uint64_t ind_x = 1; - uint64_t ind_y = 0; - uint64_t curr = state[ind_x][ind_y]; - for (uint8_t i = 0; i < 24; i++) - { - uint64_t ind_X = ind_y; - uint64_t ind_Y = (2 * ind_x + 3 * ind_y) % 5; - uint64_t temp = state[ind_X][ind_Y]; - state[ind_X][ind_Y] = rot(curr, sha3_rho[i]); - curr = temp; - ind_x = ind_X; - ind_y = ind_Y; - } - - // xi - for (uint8_t y = 0; y < 5; y++) - { - for (uint8_t x = 0; x < 5; x++) - { - C[x] = state[x][y]; - } - for (uint8_t x = 0; x < 5; x++) - { - state[x][y] = C[x] ^ ((~C[(x + 1) % 5]) & C[(x + 2) % 5]); - } - } - - // iota - state[0][0] ^= sha3_round_consts[round]; - } - } - - void HashFunction::sha3_hash(const uint64_t *input, size_t uint64_count, - sha3_block_type &sha3_block) - { -#ifdef SEAL_DEBUG - if (input == nullptr) - { - throw invalid_argument("input cannot be null"); - } -#endif - // Padding - auto pool = MemoryManager::GetPool(); - size_t padded_uint64_count = sha3_rate_uint64_count * ((uint64_count / sha3_rate_uint64_count) + 1); - auto padded_input(allocate_uint(padded_uint64_count, pool)); - set_uint_uint(input, uint64_count, padded_input.get()); - for (size_t i = uint64_count; i < padded_uint64_count; i++) - { - padded_input[i] = 0; - if (i == uint64_count) - { - padded_input[i] |= 0x6; - } - if (i == padded_uint64_count - 1) - { - padded_input[i] |= uint64_t(1) << 63; - } - } - - // Absorb - sha3_state_type state; - memset(state, 0, sha3_state_uint64_count * static_cast(bytes_per_uint64)); - for (size_t i = 0; i < padded_uint64_count; i += sha3_rate_uint64_count) - { - sponge_absorb(padded_input.get() + i, state); - } - - sha3_block = sha3_zero_block; - sponge_squeeze(state, sha3_block); - } - } -} diff --git a/SEAL/native/src/seal/util/hash.h b/SEAL/native/src/seal/util/hash.h deleted file mode 100644 index 770f3ad..0000000 --- a/SEAL/native/src/seal/util/hash.h +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include "seal/util/defines.h" - -namespace seal -{ - namespace util - { - class HashFunction - { - public: - HashFunction() = delete; - - static constexpr std::size_t sha3_block_uint64_count = 4; - - using sha3_block_type = std::array; - - static constexpr sha3_block_type sha3_zero_block{ { 0, 0, 0, 0 } }; - - static void sha3_hash(const std::uint64_t *input, std::size_t uint64_count, - sha3_block_type &destination); - - inline static void sha3_hash(std::uint64_t input, sha3_block_type &destination) - { - sha3_hash(&input, 1, destination); - } - - private: - static constexpr std::uint8_t sha3_round_count = 24; - - // Rate 1088 = 17 * 64 bits - static constexpr std::uint8_t sha3_rate_uint64_count = 17; - - // Capacity 512 = 8 * 64 bits - static constexpr std::uint8_t sha3_capacity_uint64_count = 8; - - // State size = 1600 = 25 * 64 bits - static constexpr std::uint8_t sha3_state_uint64_count = 25; - - using sha3_state_type = std::uint64_t[5][5]; - - static constexpr std::uint8_t sha3_rho[24]{ - 1, 3, 6, 10, 15, 21, - 28, 36, 45, 55, 2, 14, - 27, 41, 56, 8, 25, 43, - 62, 18, 39, 61, 20, 44 - }; - - static constexpr std::uint64_t sha3_round_consts[sha3_round_count]{ - 0x0000000000000001, 0x0000000000008082, 0x800000000000808a, - 0x8000000080008000, 0x000000000000808b, 0x0000000080000001, - 0x8000000080008081, 0x8000000000008009, 0x000000000000008a, - 0x0000000000000088, 0x0000000080008009, 0x000000008000000a, - 0x000000008000808b, 0x800000000000008b, 0x8000000000008089, - 0x8000000000008003, 0x8000000000008002, 0x8000000000000080, - 0x000000000000800a, 0x800000008000000a, 0x8000000080008081, - 0x8000000000008080, 0x0000000080000001, 0x8000000080008008 - }; - - SEAL_NODISCARD inline static std::uint64_t rot( - std::uint64_t input, std::uint8_t s) noexcept - { - return (input << s) | (input >> (64 - s)); - } - - static void keccak_1600(sha3_state_type &state) noexcept; - - inline static void sponge_absorb( - const std::uint64_t sha3_block[sha3_rate_uint64_count], - sha3_state_type &state) noexcept - { - //for (std::uint8_t x = 0; x < 5; x++) - //{ - // for (std::uint8_t y = 0; y < 5; y++) - // { - // std::uint8_t index = 5 * y + x; - // state[x][y] ^= index < sha3_rate_uint64_count ? sha3_block[index] : std::uint64_t(0); - // } - //} - - state[0][0] ^= 0 < sha3_rate_uint64_count ? sha3_block[0] : std::uint64_t(0); - state[0][1] ^= 5 < sha3_rate_uint64_count ? sha3_block[5] : std::uint64_t(0); - state[0][2] ^= 10 < sha3_rate_uint64_count ? sha3_block[10] : std::uint64_t(0); - state[0][3] ^= 15 < sha3_rate_uint64_count ? sha3_block[15] : std::uint64_t(0); - state[0][4] ^= 20 < sha3_rate_uint64_count ? sha3_block[20] : std::uint64_t(0); - - state[1][0] ^= 1 < sha3_rate_uint64_count ? sha3_block[1] : std::uint64_t(0); - state[1][1] ^= 6 < sha3_rate_uint64_count ? sha3_block[6] : std::uint64_t(0); - state[1][2] ^= 11 < sha3_rate_uint64_count ? sha3_block[11] : std::uint64_t(0); - state[1][3] ^= 16 < sha3_rate_uint64_count ? sha3_block[16] : std::uint64_t(0); - state[1][4] ^= 21 < sha3_rate_uint64_count ? sha3_block[21] : std::uint64_t(0); - - state[2][0] ^= 2 < sha3_rate_uint64_count ? sha3_block[2] : std::uint64_t(0); - state[2][1] ^= 7 < sha3_rate_uint64_count ? sha3_block[7] : std::uint64_t(0); - state[2][2] ^= 12 < sha3_rate_uint64_count ? sha3_block[12] : std::uint64_t(0); - state[2][3] ^= 17 < sha3_rate_uint64_count ? sha3_block[17] : std::uint64_t(0); - state[2][4] ^= 22 < sha3_rate_uint64_count ? sha3_block[22] : std::uint64_t(0); - - state[3][0] ^= 3 < sha3_rate_uint64_count ? sha3_block[3] : std::uint64_t(0); - state[3][1] ^= 8 < sha3_rate_uint64_count ? sha3_block[8] : std::uint64_t(0); - state[3][2] ^= 13 < sha3_rate_uint64_count ? sha3_block[13] : std::uint64_t(0); - state[3][3] ^= 18 < sha3_rate_uint64_count ? sha3_block[18] : std::uint64_t(0); - state[3][4] ^= 23 < sha3_rate_uint64_count ? sha3_block[23] : std::uint64_t(0); - - state[4][0] ^= 4 < sha3_rate_uint64_count ? sha3_block[4] : std::uint64_t(0); - state[4][1] ^= 9 < sha3_rate_uint64_count ? sha3_block[9] : std::uint64_t(0); - state[4][2] ^= 14 < sha3_rate_uint64_count ? sha3_block[14] : std::uint64_t(0); - state[4][3] ^= 19 < sha3_rate_uint64_count ? sha3_block[19] : std::uint64_t(0); - state[4][4] ^= 24 < sha3_rate_uint64_count ? sha3_block[24] : std::uint64_t(0); - - keccak_1600(state); - } - - inline static void sponge_squeeze(const sha3_state_type &sha3_state, - sha3_block_type &sha3_block) noexcept - { - // Trivial in this case: we simply output the first blocks of the state - static_assert(sha3_block_uint64_count == 4, "sha3_block_uint64_count must equal 4"); - - sha3_block[0] = sha3_state[0][0]; - sha3_block[1] = sha3_state[1][0]; - sha3_block[2] = sha3_state[2][0]; - sha3_block[3] = sha3_state[3][0]; - } - }; - } -} diff --git a/SEAL/native/src/seal/util/hestdparms.h b/SEAL/native/src/seal/util/hestdparms.h deleted file mode 100644 index 55cfdf6..0000000 --- a/SEAL/native/src/seal/util/hestdparms.h +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include "seal/util/defines.h" - -namespace seal -{ - namespace util - { - /** - Largest allowed bit counts for coeff_modulus based on the security estimates from - HomomorphicEncryption.org security standard. Microsoft SEAL samples the secret key - from a ternary {-1, 0, 1} distribution. - */ - // Ternary secret; 128 bits classical security - SEAL_NODISCARD constexpr int SEAL_HE_STD_PARMS_128_TC( - std::size_t poly_modulus_degree) noexcept - { - switch (poly_modulus_degree) - { - case std::size_t(1024): return 27; - case std::size_t(2048): return 54; - case std::size_t(4096): return 109; - case std::size_t(8192): return 218; - case std::size_t(16384): return 438; - case std::size_t(32768): return 881; - } - return 0; - } - - // Ternary secret; 192 bits classical security - SEAL_NODISCARD constexpr int SEAL_HE_STD_PARMS_192_TC( - std::size_t poly_modulus_degree) noexcept - { - switch (poly_modulus_degree) - { - case std::size_t(1024): return 19; - case std::size_t(2048): return 37; - case std::size_t(4096): return 75; - case std::size_t(8192): return 152; - case std::size_t(16384): return 305; - case std::size_t(32768): return 611; - } - return 0; - } - - // Ternary secret; 256 bits classical security - SEAL_NODISCARD constexpr int SEAL_HE_STD_PARMS_256_TC( - std::size_t poly_modulus_degree) noexcept - { - switch (poly_modulus_degree) - { - case std::size_t(1024): return 14; - case std::size_t(2048): return 29; - case std::size_t(4096): return 58; - case std::size_t(8192): return 118; - case std::size_t(16384): return 237; - case std::size_t(32768): return 476; - } - return 0; - } - - // Ternary secret; 128 bits quantum security - SEAL_NODISCARD constexpr int SEAL_HE_STD_PARMS_128_TQ( - std::size_t poly_modulus_degree) noexcept - { - switch (poly_modulus_degree) - { - case std::size_t(1024): return 25; - case std::size_t(2048): return 51; - case std::size_t(4096): return 101; - case std::size_t(8192): return 202; - case std::size_t(16384): return 411; - case std::size_t(32768): return 827; - } - return 0; - } - - // Ternary secret; 192 bits quantum security - SEAL_NODISCARD constexpr int SEAL_HE_STD_PARMS_192_TQ( - std::size_t poly_modulus_degree) noexcept - { - switch (poly_modulus_degree) - { - case std::size_t(1024): return 17; - case std::size_t(2048): return 35; - case std::size_t(4096): return 70; - case std::size_t(8192): return 141; - case std::size_t(16384): return 284; - case std::size_t(32768): return 571; - } - return 0; - } - - // Ternary secret; 256 bits quantum security - SEAL_NODISCARD constexpr int SEAL_HE_STD_PARMS_256_TQ( - std::size_t poly_modulus_degree) noexcept - { - switch (poly_modulus_degree) - { - case std::size_t(1024): return 13; - case std::size_t(2048): return 27; - case std::size_t(4096): return 54; - case std::size_t(8192): return 109; - case std::size_t(16384): return 220; - case std::size_t(32768): return 443; - } - return 0; - } - - // Standard deviation for error distribution - constexpr double SEAL_HE_STD_PARMS_ERROR_STD_DEV = 3.20; - } -} diff --git a/SEAL/native/src/seal/util/locks.h b/SEAL/native/src/seal/util/locks.h deleted file mode 100644 index e0bd1c9..0000000 --- a/SEAL/native/src/seal/util/locks.h +++ /dev/null @@ -1,301 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include "seal/util/defines.h" - -#ifdef SEAL_USE_SHARED_MUTEX -#include - -namespace seal -{ - namespace util - { - using ReaderLock = std::shared_lock; - - using WriterLock = std::unique_lock; - - class SEAL_NODISCARD ReaderWriterLocker - { - public: - ReaderWriterLocker() = default; - - SEAL_NODISCARD inline ReaderLock acquire_read() - { - return ReaderLock(rw_lock_mutex_); - } - - SEAL_NODISCARD inline WriterLock acquire_write() - { - return WriterLock(rw_lock_mutex_); - } - - SEAL_NODISCARD inline ReaderLock try_acquire_read() noexcept - { - return ReaderLock(rw_lock_mutex_, std::try_to_lock); - } - - SEAL_NODISCARD inline WriterLock try_acquire_write() noexcept - { - return WriterLock(rw_lock_mutex_, std::try_to_lock); - } - - private: - ReaderWriterLocker(const ReaderWriterLocker ©) = delete; - - ReaderWriterLocker &operator =(const ReaderWriterLocker &assign) = delete; - - std::shared_mutex rw_lock_mutex_{}; - }; - } -} -#else -#include - -namespace seal -{ - namespace util - { - struct try_to_lock_t - { - }; - - constexpr try_to_lock_t try_to_lock{}; - - class ReaderWriterLocker; - - class SEAL_NODISCARD ReaderLock - { - public: - ReaderLock() noexcept : locker_(nullptr) - { - } - - ReaderLock(ReaderLock &&move) noexcept : locker_(move.locker_) - { - move.locker_ = nullptr; - } - - ReaderLock(ReaderWriterLocker &locker) noexcept : locker_(nullptr) - { - acquire(locker); - } - - ReaderLock(ReaderWriterLocker &locker, try_to_lock_t) noexcept : - locker_(nullptr) - { - try_acquire(locker); - } - - ~ReaderLock() noexcept - { - unlock(); - } - - SEAL_NODISCARD inline bool owns_lock() const noexcept - { - return locker_ != nullptr; - } - - void unlock() noexcept; - - inline void swap_with(ReaderLock &lock) noexcept - { - std::swap(locker_, lock.locker_); - } - - inline ReaderLock &operator =(ReaderLock &&lock) noexcept - { - swap_with(lock); - lock.unlock(); - return *this; - } - - private: - void acquire(ReaderWriterLocker &locker) noexcept; - - bool try_acquire(ReaderWriterLocker &locker) noexcept; - - ReaderWriterLocker *locker_; - }; - - class SEAL_NODISCARD WriterLock - { - public: - WriterLock() noexcept : locker_(nullptr) - { - } - - WriterLock(WriterLock &&move) noexcept : locker_(move.locker_) - { - move.locker_ = nullptr; - } - - WriterLock(ReaderWriterLocker &locker) noexcept : locker_(nullptr) - { - acquire(locker); - } - - WriterLock(ReaderWriterLocker &locker, try_to_lock_t) noexcept : - locker_(nullptr) - { - try_acquire(locker); - } - - ~WriterLock() noexcept - { - unlock(); - } - - SEAL_NODISCARD inline bool owns_lock() const noexcept - { - return locker_ != nullptr; - } - - void unlock() noexcept; - - inline void swap_with(WriterLock &lock) noexcept - { - std::swap(locker_, lock.locker_); - } - - inline WriterLock &operator =(WriterLock &&lock) noexcept - { - swap_with(lock); - lock.unlock(); - return *this; - } - - private: - void acquire(ReaderWriterLocker &locker) noexcept; - - bool try_acquire(ReaderWriterLocker &locker) noexcept; - - ReaderWriterLocker *locker_; - }; - - class SEAL_NODISCARD ReaderWriterLocker - { - friend class ReaderLock; - - friend class WriterLock; - - public: - ReaderWriterLocker() noexcept : reader_locks_(0), writer_locked_(false) - { - } - - SEAL_NODISCARD inline ReaderLock acquire_read() noexcept - { - return ReaderLock(*this); - } - - SEAL_NODISCARD inline WriterLock acquire_write() noexcept - { - return WriterLock(*this); - } - - SEAL_NODISCARD inline ReaderLock try_acquire_read() noexcept - { - return ReaderLock(*this, try_to_lock); - } - - SEAL_NODISCARD inline WriterLock try_acquire_write() noexcept - { - return WriterLock(*this, try_to_lock); - } - - private: - ReaderWriterLocker(const ReaderWriterLocker ©) = delete; - - ReaderWriterLocker &operator =(const ReaderWriterLocker &assign) = delete; - - std::atomic reader_locks_; - - std::atomic writer_locked_; - }; - - inline void ReaderLock::unlock() noexcept - { - if (locker_ == nullptr) - { - return; - } - locker_->reader_locks_.fetch_sub(1, std::memory_order_release); - locker_ = nullptr; - } - - inline void ReaderLock::acquire(ReaderWriterLocker &locker) noexcept - { - unlock(); - do - { - locker.reader_locks_.fetch_add(1, std::memory_order_acquire); - locker_ = &locker; - if (locker.writer_locked_.load(std::memory_order_acquire)) - { - unlock(); - while (locker.writer_locked_.load(std::memory_order_acquire)); - } - } while (locker_ == nullptr); - } - - SEAL_NODISCARD inline bool ReaderLock::try_acquire( - ReaderWriterLocker &locker) noexcept - { - unlock(); - locker.reader_locks_.fetch_add(1, std::memory_order_acquire); - locker_ = &locker; - if (locker.writer_locked_.load(std::memory_order_acquire)) - { - unlock(); - return false; - } - return true; - } - - inline void WriterLock::acquire(ReaderWriterLocker &locker) noexcept - { - unlock(); - bool expected = false; - while (!locker.writer_locked_.compare_exchange_strong( - expected, true, std::memory_order_acquire)) - { - expected = false; - } - locker_ = &locker; - while (locker.reader_locks_.load(std::memory_order_acquire) != 0); - } - - SEAL_NODISCARD inline bool WriterLock::try_acquire( - ReaderWriterLocker &locker) noexcept - { - unlock(); - bool expected = false; - if (!locker.writer_locked_.compare_exchange_strong( - expected, true, std::memory_order_acquire)) - { - return false; - } - locker_ = &locker; - if (locker.reader_locks_.load(std::memory_order_acquire) != 0) - { - unlock(); - return false; - } - return true; - } - - inline void WriterLock::unlock() noexcept - { - if (locker_ == nullptr) - { - return; - } - locker_->writer_locked_.store(false, std::memory_order_release); - locker_ = nullptr; - } - } -} -#endif diff --git a/SEAL/native/src/seal/util/mempool.cpp b/SEAL/native/src/seal/util/mempool.cpp deleted file mode 100644 index 37df0d2..0000000 --- a/SEAL/native/src/seal/util/mempool.cpp +++ /dev/null @@ -1,502 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include -#include -#include -#include -#include -#include "seal/util/mempool.h" -#include "seal/util/common.h" -#include "seal/util/uintarith.h" - -using namespace std; - -namespace seal -{ - namespace util - { - MemoryPoolHeadMT::MemoryPoolHeadMT(size_t item_byte_count, - bool clear_on_destruction) : - clear_on_destruction_(clear_on_destruction), - locked_(false), item_byte_count_(item_byte_count), - item_count_(MemoryPool::first_alloc_count), - first_item_(nullptr) - { - if ((item_byte_count_ == 0) || - (item_byte_count_ > MemoryPool::max_batch_alloc_byte_count) || - (mul_safe(item_byte_count_, MemoryPool::first_alloc_count) > - MemoryPool::max_batch_alloc_byte_count)) - { - throw invalid_argument("invalid allocation size"); - } - - // Initial allocation - allocation new_alloc; - try - { - new_alloc.data_ptr = new SEAL_BYTE[ - mul_safe(MemoryPool::first_alloc_count, item_byte_count_)]; - } - catch (const bad_alloc &) - { - // Allocation failed; rethrow - throw; - } - - new_alloc.size = MemoryPool::first_alloc_count; - new_alloc.free = MemoryPool::first_alloc_count; - new_alloc.head_ptr = new_alloc.data_ptr; - allocs_.clear(); - allocs_.push_back(new_alloc); - } - - MemoryPoolHeadMT::~MemoryPoolHeadMT() noexcept - { - bool expected = false; - while (!locked_.compare_exchange_strong( - expected, true, memory_order_acquire)) - { - expected = false; - } - - // Delete the items (but not the memory) - MemoryPoolItem *curr_item = first_item_; - while (curr_item) - { - MemoryPoolItem *next_item = curr_item->next(); - delete curr_item; - curr_item = next_item; - } - first_item_ = nullptr; - - // Do we need to clear the memory? - if (clear_on_destruction_) - { - // Delete the memory - for (auto &alloc : allocs_) - { - size_t curr_alloc_byte_count = mul_safe(item_byte_count_, alloc.size); - volatile SEAL_BYTE *data_ptr = reinterpret_cast(alloc.data_ptr); - while (curr_alloc_byte_count--) - { - *data_ptr++ = static_cast(0); - } - - // Delete this allocation - delete[] alloc.data_ptr; - } - } - else - { - // Delete the memory - for (auto &alloc : allocs_) - { - // Delete this allocation - delete[] alloc.data_ptr; - } - } - - allocs_.clear(); - } - - MemoryPoolItem *MemoryPoolHeadMT::get() - { - bool expected = false; - while (!locked_.compare_exchange_strong( - expected, true, memory_order_acquire)) - { - expected = false; - } - MemoryPoolItem *old_first = first_item_; - - // Is pool empty? - if (old_first == nullptr) - { - allocation &last_alloc = allocs_.back(); - MemoryPoolItem *new_item = nullptr; - if (last_alloc.free > 0) - { - // Pool is empty; there is memory - new_item = new MemoryPoolItem(last_alloc.head_ptr); - last_alloc.free--; - last_alloc.head_ptr += item_byte_count_; - } - else - { - // Pool is empty; there is no memory - allocation new_alloc; - - // Increase allocation size unless we are already at max - size_t new_size = safe_cast( - ceil(MemoryPool::alloc_size_multiplier * - static_cast(last_alloc.size))); - size_t new_alloc_byte_count = mul_safe(new_size, item_byte_count_); - if (new_alloc_byte_count > - MemoryPool::max_batch_alloc_byte_count) - { - new_size = last_alloc.size; - new_alloc_byte_count = new_size * item_byte_count_; - } - - try - { - new_alloc.data_ptr = new SEAL_BYTE[new_alloc_byte_count]; - } - catch (const bad_alloc &) - { - // Allocation failed; rethrow - throw; - } - - new_alloc.size = new_size; - new_alloc.free = new_size - 1; - new_alloc.head_ptr = new_alloc.data_ptr + item_byte_count_; - allocs_.push_back(new_alloc); - item_count_ += new_size; - new_item = new MemoryPoolItem(new_alloc.data_ptr); - } - - locked_.store(false, memory_order_release); - return new_item; - } - - // Pool is not empty - first_item_ = old_first->next(); - old_first->next() = nullptr; - locked_.store(false, memory_order_release); - return old_first; - } - - MemoryPoolHeadST::MemoryPoolHeadST(size_t item_byte_count, - bool clear_on_destruction) : - clear_on_destruction_(clear_on_destruction), - item_byte_count_(item_byte_count), - item_count_(MemoryPool::first_alloc_count), - first_item_(nullptr) - { - if ((item_byte_count_ == 0) || - (item_byte_count_ > MemoryPool::max_batch_alloc_byte_count) || - (mul_safe(item_byte_count_, MemoryPool::first_alloc_count) > - MemoryPool::max_batch_alloc_byte_count)) - { - throw invalid_argument("invalid allocation size"); - } - - // Initial allocation - allocation new_alloc; - try - { - new_alloc.data_ptr = new SEAL_BYTE[ - mul_safe(MemoryPool::first_alloc_count, item_byte_count_)]; - } - catch (const bad_alloc &) - { - // Allocation failed; rethrow - throw; - } - - new_alloc.size = MemoryPool::first_alloc_count; - new_alloc.free = MemoryPool::first_alloc_count; - new_alloc.head_ptr = new_alloc.data_ptr; - allocs_.clear(); - allocs_.push_back(new_alloc); - } - - MemoryPoolHeadST::~MemoryPoolHeadST() noexcept - { - // Delete the items (but not the memory) - MemoryPoolItem *curr_item = first_item_; - while(curr_item) - { - MemoryPoolItem *next_item = curr_item->next(); - delete curr_item; - curr_item = next_item; - } - first_item_ = nullptr; - - // Do we need to clear the memory? - if (clear_on_destruction_) - { - // Delete the memory - for (auto &alloc : allocs_) - { - size_t curr_alloc_byte_count = mul_safe(item_byte_count_, alloc.size); - volatile SEAL_BYTE *data_ptr = reinterpret_cast(alloc.data_ptr); - while (curr_alloc_byte_count--) - { - *data_ptr++ = static_cast(0); - } - - // Delete this allocation - delete[] alloc.data_ptr; - } - } - else - { - // Delete the memory - for (auto &alloc : allocs_) - { - // Delete this allocation - delete[] alloc.data_ptr; - } - } - - allocs_.clear(); - } - - MemoryPoolItem *MemoryPoolHeadST::get() - { - MemoryPoolItem *old_first = first_item_; - - // Is pool empty? - if (old_first == nullptr) - { - allocation &last_alloc = allocs_.back(); - MemoryPoolItem *new_item = nullptr; - if (last_alloc.free > 0) - { - // Pool is empty; there is memory - new_item = new MemoryPoolItem(last_alloc.head_ptr); - last_alloc.free--; - last_alloc.head_ptr += item_byte_count_; - } - else - { - // Pool is empty; there is no memory - allocation new_alloc; - - // Increase allocation size unless we are already at max - size_t new_size = safe_cast( - ceil(MemoryPool::alloc_size_multiplier * - static_cast(last_alloc.size))); - size_t new_alloc_byte_count = mul_safe(new_size, item_byte_count_); - if (new_alloc_byte_count > - MemoryPool::max_batch_alloc_byte_count) - { - new_size = last_alloc.size; - new_alloc_byte_count = new_size * item_byte_count_; - } - - try - { - new_alloc.data_ptr = new SEAL_BYTE[new_alloc_byte_count]; - } - catch (const bad_alloc &) - { - // Allocation failed; rethrow - throw; - } - - new_alloc.size = new_size; - new_alloc.free = new_size - 1; - new_alloc.head_ptr = new_alloc.data_ptr + item_byte_count_; - allocs_.push_back(new_alloc); - item_count_ += new_size; - new_item = new MemoryPoolItem(new_alloc.data_ptr); - } - - return new_item; - } - - // Pool is not empty - first_item_ = old_first->next(); - old_first->next() = nullptr; - return old_first; - } - - const size_t MemoryPool::max_single_alloc_byte_count = - []() -> size_t { - int bit_shift = static_cast( - ceil(log2(MemoryPool::alloc_size_multiplier))); - if (bit_shift < 0 || unsigned_geq(bit_shift, - sizeof(size_t) * static_cast(bits_per_byte))) - { - throw logic_error("alloc_size_multiplier too large"); - } - return numeric_limits::max() >> bit_shift; - }(); - - const size_t MemoryPool::max_batch_alloc_byte_count = - []() -> size_t { - int bit_shift = static_cast( - ceil(log2(MemoryPool::alloc_size_multiplier))); - if (bit_shift < 0 || unsigned_geq(bit_shift, - sizeof(size_t) * static_cast(bits_per_byte))) - { - throw logic_error("alloc_size_multiplier too large"); - } - return numeric_limits::max() >> bit_shift; - }(); - - MemoryPoolMT::~MemoryPoolMT() noexcept - { - WriterLock lock(pools_locker_.acquire_write()); - for (MemoryPoolHead *head : pools_) - { - delete head; - } - pools_.clear(); - } - - Pointer MemoryPoolMT::get_for_byte_count(size_t byte_count) - { - if (byte_count > max_single_alloc_byte_count) - { - throw invalid_argument("invalid allocation size"); - } - else if (byte_count == 0) - { - return Pointer(); - } - - // Attempt to find size. - ReaderLock reader_lock(pools_locker_.acquire_read()); - size_t start = 0; - size_t end = pools_.size(); - while (start < end) - { - size_t mid = (start + end) / 2; - MemoryPoolHead *mid_head = pools_[mid]; - size_t mid_byte_count = mid_head->item_byte_count(); - if (byte_count < mid_byte_count) - { - start = mid + 1; - } - else if (byte_count > mid_byte_count) - { - end = mid; - } - else - { - return Pointer(mid_head); - } - } - reader_lock.unlock(); - - // Size was not found, so obtain an exclusive lock and search again. - WriterLock writer_lock(pools_locker_.acquire_write()); - start = 0; - end = pools_.size(); - while (start < end) - { - size_t mid = (start + end) / 2; - MemoryPoolHead *mid_head = pools_[mid]; - size_t mid_byte_count = mid_head->item_byte_count(); - if (byte_count < mid_byte_count) - { - start = mid + 1; - } - else if (byte_count > mid_byte_count) - { - end = mid; - } - else - { - return Pointer(mid_head); - } - } - - // Size was still not found, but we own an exclusive lock so just add it, - // but first check if we are at maximum pool head count already. - if (pools_.size() >= max_pool_head_count) - { - throw runtime_error("maximum pool head count reached"); - } - - MemoryPoolHead *new_head = new MemoryPoolHeadMT(byte_count, clear_on_destruction_); - if (!pools_.empty()) - { - pools_.insert(pools_.begin() + static_cast(start), new_head); - } - else - { - pools_.emplace_back(new_head); - } - - return Pointer(new_head); - } - - size_t MemoryPoolMT::alloc_byte_count() const - { - ReaderLock lock(pools_locker_.acquire_read()); - - return accumulate(pools_.cbegin(), pools_.cend(), size_t(0), - [](size_t byte_count, MemoryPoolHead *head) { - return add_safe(byte_count, - mul_safe(head->item_count(), head->item_byte_count())); - }); - } - - MemoryPoolST::~MemoryPoolST() noexcept - { - for (MemoryPoolHead *head : pools_) - { - delete head; - } - pools_.clear(); - } - - Pointer MemoryPoolST::get_for_byte_count(size_t byte_count) - { - if (byte_count > MemoryPool::max_single_alloc_byte_count) - { - throw invalid_argument("invalid allocation size"); - } - else if (byte_count == 0) - { - return Pointer(); - } - - // Attempt to find size. - size_t start = 0; - size_t end = pools_.size(); - while (start < end) - { - size_t mid = (start + end) / 2; - MemoryPoolHead *mid_head = pools_[mid]; - size_t mid_byte_count = mid_head->item_byte_count(); - if (byte_count < mid_byte_count) - { - start = mid + 1; - } - else if (byte_count > mid_byte_count) - { - end = mid; - } - else - { - return Pointer(mid_head); - } - } - - // Size was not found so just add it, but first check if we are at - // maximum pool head count already. - if (pools_.size() >= max_pool_head_count) - { - throw runtime_error("maximum pool head count reached"); - } - - MemoryPoolHead *new_head = new MemoryPoolHeadST(byte_count, clear_on_destruction_); - if (!pools_.empty()) - { - pools_.insert(pools_.begin() + static_cast(start), new_head); - } - else - { - pools_.emplace_back(new_head); - } - - return Pointer(new_head); - } - - size_t MemoryPoolST::alloc_byte_count() const - { - return accumulate(pools_.cbegin(), pools_.cend(), size_t(0), - [](size_t byte_count, MemoryPoolHead *head) { - return add_safe(byte_count, - mul_safe(head->item_count(), head->item_byte_count())); - }); - } - } -} diff --git a/SEAL/native/src/seal/util/mempool.h b/SEAL/native/src/seal/util/mempool.h deleted file mode 100644 index 7f3615e..0000000 --- a/SEAL/native/src/seal/util/mempool.h +++ /dev/null @@ -1,300 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "seal/util/defines.h" -#include "seal/util/globals.h" -#include "seal/util/common.h" -#include "seal/util/locks.h" - -namespace seal -{ - namespace util - { - template::value>> - class ConstPointer; - - template<> - class ConstPointer; - - template::value>> - class Pointer; - - class MemoryPoolItem - { - public: - MemoryPoolItem(SEAL_BYTE *data) noexcept : data_(data) - { - } - - SEAL_NODISCARD inline SEAL_BYTE *data() noexcept - { - return data_; - } - - SEAL_NODISCARD inline const SEAL_BYTE *data() const noexcept - { - return data_; - } - - SEAL_NODISCARD inline MemoryPoolItem* &next() noexcept - { - return next_; - } - - SEAL_NODISCARD inline const MemoryPoolItem *next() const noexcept - { - return next_; - } - - private: - MemoryPoolItem(const MemoryPoolItem ©) = delete; - - MemoryPoolItem &operator =(const MemoryPoolItem &assign) = delete; - - SEAL_BYTE *data_ = nullptr; - - MemoryPoolItem *next_ = nullptr; - }; - - class MemoryPoolHead - { - public: - struct allocation - { - allocation() : - size(0), data_ptr(nullptr), free(0), head_ptr(nullptr) - { - } - - // Size of the allocation (number of items it can hold) - std::size_t size; - - // Pointer to start of the allocation - SEAL_BYTE *data_ptr; - - // How much free space is left (number of items that still fit) - std::size_t free; - - // Pointer to current head of allocation - SEAL_BYTE *head_ptr; - }; - - // The overriding functions are noexcept(false) - virtual ~MemoryPoolHead() = default; - - // Byte size of the allocations (items) owned by this pool - virtual std::size_t item_byte_count() const noexcept = 0; - - // Total number of items allocated - virtual std::size_t item_count() const noexcept = 0; - - virtual MemoryPoolItem *get() = 0; - - // Return item back to this pool - virtual void add(MemoryPoolItem *new_first) noexcept = 0; - }; - - class MemoryPoolHeadMT : public MemoryPoolHead - { - public: - // Creates a new MemoryPoolHeadMT with allocation for one single item. - MemoryPoolHeadMT(std::size_t item_byte_count, - bool clear_on_destruction = false); - - ~MemoryPoolHeadMT() noexcept override; - - // Byte size of the allocations (items) owned by this pool - SEAL_NODISCARD inline std::size_t item_byte_count() const noexcept override - { - return item_byte_count_; - } - - // Returns the total number of items allocated - SEAL_NODISCARD inline std::size_t item_count() const noexcept override - { - return item_count_; - } - - MemoryPoolItem *get() override; - - inline void add(MemoryPoolItem *new_first) noexcept override - { - bool expected = false; - while (!locked_.compare_exchange_strong( - expected, true, std::memory_order_acquire)) - { - expected = false; - } - MemoryPoolItem *old_first = first_item_; - new_first->next() = old_first; - first_item_ = new_first; - locked_.store(false, std::memory_order_release); - } - - private: - MemoryPoolHeadMT(const MemoryPoolHeadMT ©) = delete; - - MemoryPoolHeadMT &operator =(const MemoryPoolHeadMT &assign) = delete; - - const bool clear_on_destruction_; - - mutable std::atomic locked_; - - const std::size_t item_byte_count_; - - volatile std::size_t item_count_; - - std::vector allocs_; - - MemoryPoolItem* volatile first_item_; - }; - - class MemoryPoolHeadST : public MemoryPoolHead - { - public: - // Creates a new MemoryPoolHeadST with allocation for one single item. - MemoryPoolHeadST(std::size_t item_byte_count, - bool clear_on_destruction = false); - - ~MemoryPoolHeadST() noexcept override; - - // Byte size of the allocations (items) owned by this pool - SEAL_NODISCARD inline std::size_t item_byte_count() const noexcept override - { - return item_byte_count_; - } - - // Returns the total number of items allocated - SEAL_NODISCARD inline std::size_t item_count() const noexcept override - { - return item_count_; - } - - SEAL_NODISCARD MemoryPoolItem *get() override; - - inline void add(MemoryPoolItem *new_first) noexcept override - { - new_first->next() = first_item_; - first_item_ = new_first; - } - - private: - MemoryPoolHeadST(const MemoryPoolHeadST ©) = delete; - - MemoryPoolHeadST &operator =(const MemoryPoolHeadST &assign) = delete; - - const bool clear_on_destruction_; - - std::size_t item_byte_count_; - - std::size_t item_count_; - - std::vector allocs_; - - MemoryPoolItem *first_item_; - }; - - class MemoryPool - { - public: - static constexpr double alloc_size_multiplier = 1.05; - - // Largest size of single allocation that can be requested from memory pool - static const std::size_t max_single_alloc_byte_count; - - // Number of different size allocations allowed by a single memory pool - static constexpr std::size_t max_pool_head_count = - std::numeric_limits::max(); - - // Largest allowed size of batch allocation - static const std::size_t max_batch_alloc_byte_count; - - static constexpr std::size_t first_alloc_count = 1; - - virtual ~MemoryPool() = default; - - virtual Pointer get_for_byte_count(std::size_t byte_count) = 0; - - virtual std::size_t pool_count() const = 0; - - virtual std::size_t alloc_byte_count() const = 0; - }; - - class MemoryPoolMT : public MemoryPool - { - public: - MemoryPoolMT(bool clear_on_destruction = false) : - clear_on_destruction_(clear_on_destruction) - { - }; - - ~MemoryPoolMT() noexcept override; - - SEAL_NODISCARD Pointer get_for_byte_count( - std::size_t byte_count) override; - - SEAL_NODISCARD inline std::size_t pool_count() const override - { - ReaderLock lock(pools_locker_.acquire_read()); - return pools_.size(); - } - - SEAL_NODISCARD std::size_t alloc_byte_count() const override; - - protected: - MemoryPoolMT(const MemoryPoolMT ©) = delete; - - MemoryPoolMT &operator =(const MemoryPoolMT &assign) = delete; - - const bool clear_on_destruction_; - - mutable ReaderWriterLocker pools_locker_; - - std::vector pools_; - }; - - class MemoryPoolST : public MemoryPool - { - public: - MemoryPoolST(bool clear_on_destruction = false) : - clear_on_destruction_(clear_on_destruction) - { - }; - - ~MemoryPoolST() noexcept override; - - SEAL_NODISCARD Pointer get_for_byte_count( - std::size_t byte_count) override; - - SEAL_NODISCARD inline std::size_t pool_count() const override - { - return pools_.size(); - } - - std::size_t alloc_byte_count() const override; - - protected: - MemoryPoolST(const MemoryPoolST ©) = delete; - - MemoryPoolST &operator =(const MemoryPoolST &assign) = delete; - - const bool clear_on_destruction_; - - std::vector pools_; - }; - } -} diff --git a/SEAL/native/src/seal/util/msvc.h b/SEAL/native/src/seal/util/msvc.h deleted file mode 100644 index 51484b6..0000000 --- a/SEAL/native/src/seal/util/msvc.h +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#if SEAL_COMPILER == SEAL_COMPILER_MSVC - -// Require Visual Studio 2017 version 15.3 or newer -#if (_MSC_VER < 1911) -#error "Microsoft Visual Studio 2017 version 15.3 or newer required" -#endif - -// Read in config.h -#include "seal/util/config.h" - -// Do not throw when Evaluator produces transparent ciphertexts -//#undef SEAL_THROW_ON_TRANSPARENT_CIPHERTEXT - -// Try to check presence of additional headers using __has_include -#ifdef __has_include - -// Check for MSGSL -#if __has_include() -#include -#define SEAL_USE_MSGSL -#else -#undef SEAL_USE_MSGSL -#endif //__has_include() - -#endif - -// In Visual Studio redefine std::byte (SEAL_BYTE) -#undef SEAL_USE_STD_BYTE - -// In Visual Studio for now we disable the use of std::shared_mutex -#undef SEAL_USE_SHARED_MUTEX - -// Are we compiling with C++17 or newer -#if (__cplusplus >= 201703L) - -// Use `if constexpr' -#define SEAL_USE_IF_CONSTEXPR - -// Use [[maybe_unused]] -#define SEAL_USE_MAYBE_UNUSED - -// Use [[nodiscard]] -#define SEAL_USE_NODISCARD - -#else -#undef SEAL_USE_IF_CONSTEXPR -#undef SEAL_USE_MAYBE_UNUSED -#undef SEAL_USE_NODISCARD -#endif - -// X64 -#ifdef _M_X64 - -#ifdef SEAL_USE_INTRIN -#include - -#ifdef SEAL_USE__UMUL128 -#pragma intrinsic(_umul128) -#define SEAL_MULTIPLY_UINT64_HW64(operand1, operand2, hw64) { \ - _umul128(operand1, operand2, hw64); \ -} - -#define SEAL_MULTIPLY_UINT64(operand1, operand2, result128) { \ - result128[0] = _umul128(operand1, operand2, result128 + 1); \ -} -#endif - -#ifdef SEAL_USE__BITSCANREVERSE64 -#pragma intrinsic(_BitScanReverse64) -#define SEAL_MSB_INDEX_UINT64(result, value) _BitScanReverse64(result, value) -#endif - -#ifdef SEAL_USE__ADDCARRY_U64 -#pragma intrinsic(_addcarry_u64) -#define SEAL_ADD_CARRY_UINT64(operand1, operand2, carry, result) _addcarry_u64( \ - carry, operand1, operand2, result) -#endif - -#ifdef SEAL_USE__SUBBORROW_U64 -#pragma intrinsic(_subborrow_u64) -#define SEAL_SUB_BORROW_UINT64(operand1, operand2, borrow, result) _subborrow_u64( \ - borrow, operand1, operand2, result) -#endif - -#endif -#else -#undef SEAL_USE_INTRIN - -#endif //_M_X64 - -#endif diff --git a/SEAL/native/src/seal/util/numth.cpp b/SEAL/native/src/seal/util/numth.cpp deleted file mode 100644 index 7b63435..0000000 --- a/SEAL/native/src/seal/util/numth.cpp +++ /dev/null @@ -1,298 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include -#include "seal/util/numth.h" -#include "seal/util/uintcore.h" -#include "seal/util/uintarithsmallmod.h" - -using namespace std; - -namespace seal -{ - namespace util - { - vector conjugate_classes(uint64_t modulus, - uint64_t subgroup_generator) - { - if (!product_fits_in(modulus, subgroup_generator)) - { - throw invalid_argument("inputs too large"); - } - - vector classes{}; - for (uint64_t i = 0; i < modulus; i++) - { - if (gcd(i, modulus) > 1) - { - classes.push_back(0); - } - else - { - classes.push_back(i); - } - } - for (uint64_t i = 0; i < modulus; i++) - { - if (classes[i] == 0) - { - continue; - } - if (classes[i] < i) - { - // i is not a pivot, updated its pivot - classes[i] = classes[classes[i]]; - continue; - } - // If i is a pivot, update other pivots to point to it - uint64_t j = (i * subgroup_generator) % modulus; - while (classes[j] != i) - { - // Merge the equivalence classes of j and i - // Note: if classes[j] != j then classes[j] will be updated later, - // when we get to i = j and use the code for "i not pivot". - classes[classes[j]] = i; - j = (j * subgroup_generator) % modulus; - } - } - return classes; - } - - vector multiplicative_orders( - vector conjugate_classes, uint64_t modulus) - { - if (!product_fits_in(modulus, modulus)) - { - throw invalid_argument("inputs too large"); - } - - vector orders{}; - orders.push_back(0); - orders.push_back(1); - - for (uint64_t i = 2; i < modulus; i++) - { - if (conjugate_classes[i] <= 1) - { - orders.push_back(conjugate_classes[i]); - continue; - } - if (conjugate_classes[i] < i) - { - orders.push_back(orders[conjugate_classes[i]]); - continue; - } - uint64_t j = (i * i) % modulus; - uint64_t order = 2; - while (conjugate_classes[j] != 1) - { - j = (j * i) % modulus; - order++; - } - orders.push_back(order); - } - return orders; - } - - void babystep_giantstep(uint64_t modulus, - vector &baby_steps, vector &giant_steps) - { - int exponent = get_power_of_two(modulus); - if (exponent < 0) - { - throw invalid_argument("modulus must be a power of 2"); - } - - // Compute square root of modulus (k stores the baby steps) - uint64_t k = uint64_t(1) << (exponent / 2); - uint64_t l = modulus / k; - - baby_steps.clear(); - giant_steps.clear(); - - uint64_t m = mul_safe(modulus, uint64_t(2)); - uint64_t g = 3; // the generator - uint64_t kprime = k >> 1; - uint64_t value = 1; - for (uint64_t i = 0; i < kprime; i++) - { - baby_steps.push_back(value); - baby_steps.push_back(m - value); - value = mul_safe(value, g) % m; - } - - // now value should equal to g**kprime - uint64_t value2 = value; - for (uint64_t j = 0; j < l; j++) - { - giant_steps.push_back(value2); - value2 = mul_safe(value2, value) % m; - } - } - - pair decompose_babystep_giantstep( - uint64_t modulus, uint64_t input, - const vector &baby_steps, - const vector &giant_steps) - { - for (size_t i = 0; i < giant_steps.size(); i++) - { - uint64_t gs = giant_steps[i]; - for (size_t j = 0; j < baby_steps.size(); j++) - { - uint64_t bs = baby_steps[j]; - if (mul_safe(gs, bs) % modulus == input) - { - return { i, j }; - } - } - } - throw logic_error("failed to decompose input"); - } - - bool is_prime(const SmallModulus &modulus, size_t num_rounds) - { - uint64_t value = modulus.value(); - // First check the simplest cases. - if (value < 2) - { - return false; - } - if (2 == value) - { - return true; - } - if (0 == (value & 0x1)) - { - return false; - } - if (3 == value) - { - return true; - } - if (0 == (value % 3)) - { - return false; - } - if (5 == value) - { - return true; - } - if (0 == (value % 5)) - { - return false; - } - if (7 == value) - { - return true; - } - if (0 == (value % 7)) - { - return false; - } - if (11 == value) - { - return true; - } - if (0 == (value % 11)) - { - return false; - } - if (13 == value) - { - return true; - } - if (0 == (value % 13)) - { - return false; - } - - // Second, Miller-Rabin test. - // Find r and odd d that satisfy value = 2^r * d + 1. - uint64_t d = value - 1; - uint64_t r = 0; - while (0 == (d & 0x1)) - { - d >>= 1; - r++; - } - if (r == 0) - { - return false; - } - - // 1) Pick a = 2, check a^(value - 1). - // 2) Pick a randomly from [3, value - 1], check a^(value - 1). - // 3) Repeat 2) for another num_rounds - 2 times. - random_device rand; - uniform_int_distribution dist(3, value - 1); - for (size_t i = 0; i < num_rounds; i++) - { - uint64_t a = i ? dist(rand) : 2; - uint64_t x = exponentiate_uint_mod(a, d, modulus); - if (x == 1 || x == value - 1) - { - continue; - } - uint64_t count = 0; - do - { - x = multiply_uint_uint_mod(x, x, modulus); - count++; - } while (x != value - 1 && count < r - 1); - if (x != value - 1) - { - return false; - } - } - return true; - } - - vector get_primes(size_t ntt_size, int bit_size, size_t count) - { - if (!count) - { - throw invalid_argument("count must be positive"); - } - if (!ntt_size) - { - throw invalid_argument("ntt_size must be positive"); - } - if (bit_size >= 63 || bit_size <= 1) - { - throw invalid_argument("bit_size is invalid"); - } - - vector destination; - uint64_t factor = mul_safe(uint64_t(2), safe_cast(ntt_size)); - - // Start with 2^bit_size - 2 * ntt_size + 1 - uint64_t value = uint64_t(0x1) << bit_size; - try - { - value = sub_safe(value, factor) + 1; - } - catch (const out_of_range &) - { - throw logic_error("failed to find enough qualifying primes"); - } - - uint64_t lower_bound = uint64_t(0x1) << (bit_size - 1); - while (count > 0 && value > lower_bound) - { - SmallModulus new_mod(value); - if (new_mod.is_prime()) - { - destination.emplace_back(move(new_mod)); - count--; - } - value -= factor; - } - if (count > 0) - { - throw logic_error("failed to find enough qualifying primes"); - } - return destination; - } - } -} \ No newline at end of file diff --git a/SEAL/native/src/seal/util/numth.h b/SEAL/native/src/seal/util/numth.h deleted file mode 100644 index 7ac27a5..0000000 --- a/SEAL/native/src/seal/util/numth.h +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include "seal/util/defines.h" -#include "seal/smallmodulus.h" -#include "seal/util/common.h" -#include -#include -#include -#include - -namespace seal -{ - namespace util - { - SEAL_NODISCARD inline std::uint64_t gcd( - std::uint64_t x, std::uint64_t y) - { -#ifdef SEAL_DEBUG - if (x == 0) - { - std::invalid_argument("x cannot be zero"); - } - if (y == 0) - { - std::invalid_argument("y cannot be zero"); - } -#endif - if (x < y) - { - return gcd(y, x); - } - else if (y == 0) - { - return x; - } - else - { - std::uint64_t f = x % y; - if (f == 0) - { - return y; - } - else - { - return gcd(y, f); - } - } - } - - SEAL_NODISCARD inline auto xgcd(std::uint64_t x, std::uint64_t y) - -> std::tuple - { - /* Extended GCD: - Returns (gcd, x, y) where gcd is the greatest common divisor of a and b. - The numbers x, y are such that gcd = ax + by. - */ -#ifdef SEAL_DEBUG - if (x == 0) - { - std::invalid_argument("x cannot be zero"); - } - if (y == 0) - { - std::invalid_argument("y cannot be zero"); - } -#endif - std::int64_t prev_a = 1; - std::int64_t a = 0; - std::int64_t prev_b = 0; - std::int64_t b = 1; - - while (y != 0) - { - std::int64_t q = util::safe_cast(x / y); - std::int64_t temp = util::safe_cast(x % y); - x = y; - y = util::safe_cast(temp); - - temp = a; - a = util::sub_safe(prev_a, mul_safe(q, a)); - prev_a = temp; - - temp = b; - b = util::sub_safe(prev_b, mul_safe(q, b)); - prev_b = temp; - } - return std::make_tuple(x, prev_a, prev_b); - } - - inline bool try_mod_inverse(std::uint64_t value, - std::uint64_t modulus, std::uint64_t &result) - { -#ifdef SEAL_DEBUG - if (value == 0) - { - std::invalid_argument("value cannot be zero"); - } - if (modulus <= 1) - { - std::invalid_argument("modulus must be at least 2"); - } -#endif - auto gcd_tuple = xgcd(value, modulus); - if (std::get<0>(gcd_tuple) != 1) - { - return false; - } - else if (std::get<1>(gcd_tuple) < 0) - { - result = static_cast(std::get<1>(gcd_tuple)) + modulus; - return true; - } - else - { - result = static_cast(std::get<1>(gcd_tuple)); - return true; - } - } - - SEAL_NODISCARD std::vector multiplicative_orders( - std::vector conjugate_classes, - std::uint64_t modulus); - - SEAL_NODISCARD std::vector conjugate_classes( - std::uint64_t modulus, std::uint64_t subgroup_generator); - - void babystep_giantstep(std::uint64_t modulus, - std::vector &baby_steps, - std::vector &giant_steps); - - SEAL_NODISCARD auto decompose_babystep_giantstep( - std::uint64_t modulus, - std::uint64_t input, - const std::vector &baby_steps, - const std::vector &giant_steps) - -> std::pair; - - SEAL_NODISCARD bool is_prime( - const SmallModulus &modulus, std::size_t num_rounds = 40); - - SEAL_NODISCARD std::vector get_primes( - std::size_t ntt_size, int bit_size, std::size_t count); - - SEAL_NODISCARD inline SmallModulus get_prime( - std::size_t ntt_size, int bit_size) - { - return get_primes(ntt_size, bit_size, 1)[0]; - } - } -} diff --git a/SEAL/native/src/seal/util/pointer.h b/SEAL/native/src/seal/util/pointer.h deleted file mode 100644 index 48b00b8..0000000 --- a/SEAL/native/src/seal/util/pointer.h +++ /dev/null @@ -1,1206 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include "seal/util/defines.h" -#include "seal/util/common.h" -#include "seal/util/mempool.h" -#include -#include -#include - -namespace seal -{ - namespace util - { - // Specialization for SEAL_BYTE - template<> - class SEAL_NODISCARD Pointer - { - friend class MemoryPoolST; - friend class MemoryPoolMT; - - public: - template friend class Pointer; - template friend class ConstPointer; - - Pointer() = default; - - // Move of the same type - Pointer(Pointer &&source) noexcept : - data_(source.data_), head_(source.head_), - item_(source.item_), alias_(source.alias_) - { - source.data_ = nullptr; - source.head_ = nullptr; - source.item_ = nullptr; - source.alias_ = false; - } - - // Move of the same type - Pointer(Pointer &&source, SEAL_BYTE value) : - Pointer(std::move(source)) - { - std::fill_n(data_, head_->item_byte_count(), value); - } - - SEAL_NODISCARD inline SEAL_BYTE &operator [](std::size_t index) - { - return data_[index]; - } - - SEAL_NODISCARD inline const SEAL_BYTE &operator []( - std::size_t index) const - { - return data_[index]; - } - - inline auto &operator =(Pointer &&assign) noexcept - { - acquire(std::move(assign)); - return *this; - } - - SEAL_NODISCARD inline bool is_set() const noexcept - { - return data_ != nullptr; - } - - SEAL_NODISCARD inline SEAL_BYTE *get() noexcept - { - return data_; - } - - SEAL_NODISCARD inline const SEAL_BYTE *get() const noexcept - { - return data_; - } - - SEAL_NODISCARD inline SEAL_BYTE *operator ->() noexcept - { - return data_; - } - - SEAL_NODISCARD inline const SEAL_BYTE *operator ->() const noexcept - { - return data_; - } - - SEAL_NODISCARD inline SEAL_BYTE &operator *() - { - return *data_; - } - - SEAL_NODISCARD inline const SEAL_BYTE &operator *() const - { - return *data_; - } - - SEAL_NODISCARD inline bool is_alias() const noexcept - { - return alias_; - } - - inline void release() noexcept - { - if (head_) - { - // Return the memory to pool - head_->add(item_); - } - else if (data_ && !alias_) - { - // Free the memory - delete[] data_; - } - - data_ = nullptr; - head_ = nullptr; - item_ = nullptr; - alias_ = false; - } - - void acquire(Pointer &other) noexcept - { - if (this == &other) - { - return; - } - - release(); - - data_ = other.data_; - head_ = other.head_; - item_ = other.item_; - alias_ = other.alias_; - other.data_ = nullptr; - other.head_ = nullptr; - other.item_ = nullptr; - other.alias_ = false; - } - - inline void acquire(Pointer &&other) noexcept - { - acquire(other); - } - - ~Pointer() noexcept - { - release(); - } - - SEAL_NODISCARD operator bool() const noexcept - { - return (data_ != nullptr); - } - - SEAL_NODISCARD inline static Pointer Owning( - SEAL_BYTE *pointer) noexcept - { - return {pointer, false}; - } - - SEAL_NODISCARD inline static auto Aliasing( - SEAL_BYTE *pointer) noexcept -> Pointer - { - return {pointer, true}; - } - - private: - Pointer(const Pointer ©) = delete; - - Pointer &operator =(const Pointer &assign) = delete; - - Pointer(SEAL_BYTE *pointer, bool alias) noexcept : - data_(pointer), alias_(alias) - { - } - - Pointer(class MemoryPoolHead *head) - { -#ifdef SEAL_DEBUG - if (!head) - { - throw std::invalid_argument("head cannot be nullptr"); - } -#endif - head_ = head; - item_ = head->get(); - data_ = item_->data(); - } - - SEAL_BYTE *data_ = nullptr; - - MemoryPoolHead *head_ = nullptr; - - MemoryPoolItem *item_ = nullptr; - - bool alias_ = false; - }; - - template - class SEAL_NODISCARD Pointer - { - friend class MemoryPoolST; - friend class MemoryPoolMT; - - public: - friend class Pointer; - friend class ConstPointer; - friend class ConstPointer; - - Pointer() = default; - - // Move of the same type - Pointer(Pointer &&source) noexcept : - data_(source.data_), head_(source.head_), - item_(source.item_), alias_(source.alias_) - { - source.data_ = nullptr; - source.head_ = nullptr; - source.item_ = nullptr; - source.alias_ = false; - } - - // Move when T is not SEAL_BYTE - Pointer(Pointer &&source) - { - // Cannot acquire a non-pool pointer of different type - if (!source.head_ && source.data_) - { - throw std::invalid_argument("cannot acquire a non-pool pointer of different type"); - } - - head_ = source.head_; - item_ = source.item_; - if (head_) - { - data_ = reinterpret_cast(item_->data()); - SEAL_IF_CONSTEXPR (!std::is_trivially_constructible::value) - { - auto count = head_->item_byte_count() / sizeof(T); - for (auto alloc_ptr = data_; count--; alloc_ptr++) - { - new(alloc_ptr) T; - } - } - } - alias_ = source.alias_; - - source.data_ = nullptr; - source.head_ = nullptr; - source.item_ = nullptr; - source.alias_ = false; - } - - // Move when T is not SEAL_BYTE - template - Pointer(Pointer &&source, Args &&...args) - { - // Cannot acquire a non-pool pointer of different type - if (!source.head_ && source.data_) - { - throw std::invalid_argument("cannot acquire a non-pool pointer of different type"); - } - - head_ = source.head_; - item_ = source.item_; - if (head_) - { - data_ = reinterpret_cast(item_->data()); - auto count = head_->item_byte_count() / sizeof(T); - for (auto alloc_ptr = data_; count--; alloc_ptr++) - { - new(alloc_ptr) T(std::forward(args)...); - } - } - alias_ = source.alias_; - - source.data_ = nullptr; - source.head_ = nullptr; - source.item_ = nullptr; - source.alias_ = false; - } - - SEAL_NODISCARD inline T &operator [](std::size_t index) - { - return data_[index]; - } - - SEAL_NODISCARD inline const T &operator [](std::size_t index) const - { - return data_[index]; - } - - inline auto &operator =(Pointer &&assign) noexcept - { - acquire(std::move(assign)); - return *this; - } - - inline auto &operator =(Pointer &&assign) - { - acquire(std::move(assign)); - return *this; - } - - SEAL_NODISCARD inline bool is_set() const noexcept - { - return data_ != nullptr; - } - - SEAL_NODISCARD inline T *get() noexcept - { - return data_; - } - - SEAL_NODISCARD inline const T *get() const noexcept - { - return data_; - } - - SEAL_NODISCARD inline T *operator ->() noexcept - { - return data_; - } - - SEAL_NODISCARD inline const T *operator ->() const noexcept - { - return data_; - } - - SEAL_NODISCARD inline T &operator *() - { - return *data_; - } - - SEAL_NODISCARD inline const T &operator *() const - { - return *data_; - } - - SEAL_NODISCARD inline bool is_alias() const noexcept - { - return alias_; - } - - inline void release() noexcept - { - if (head_) - { - SEAL_IF_CONSTEXPR (!std::is_trivially_destructible::value) - { - // Manual destructor calls - auto count = head_->item_byte_count() / sizeof(T); - for (auto alloc_ptr = data_; count--; alloc_ptr++) - { - alloc_ptr->~T(); - } - } - - // Return the memory to pool - head_->add(item_); - } - else if (data_ && !alias_) - { - // Free the memory - delete[] data_; - } - - data_ = nullptr; - head_ = nullptr; - item_ = nullptr; - alias_ = false; - } - - void acquire(Pointer &other) noexcept - { - if (this == &other) - { - return; - } - - release(); - - data_ = other.data_; - head_ = other.head_; - item_ = other.item_; - alias_ = other.alias_; - other.data_ = nullptr; - other.head_ = nullptr; - other.item_ = nullptr; - other.alias_ = false; - } - - inline void acquire(Pointer &&other) noexcept - { - acquire(other); - } - - void acquire(Pointer &other) - { - // Cannot acquire a non-pool pointer of different type - if (!other.head_ && other.data_) - { - throw std::invalid_argument("cannot acquire a non-pool pointer of different type"); - } - - release(); - - head_ = other.head_; - item_ = other.item_; - if (head_) - { - data_ = reinterpret_cast(item_->data()); - SEAL_IF_CONSTEXPR (!std::is_trivially_constructible::value) - { - auto count = head_->item_byte_count() / sizeof(T); - for (auto alloc_ptr = data_; count--; alloc_ptr++) - { - new(alloc_ptr) T; - } - } - } - alias_ = other.alias_; - other.data_ = nullptr; - other.head_ = nullptr; - other.item_ = nullptr; - other.alias_ = false; - } - - inline void acquire(Pointer &&other) - { - acquire(other); - } - - ~Pointer() noexcept - { - release(); - } - - SEAL_NODISCARD operator bool() const noexcept - { - return (data_ != nullptr); - } - - SEAL_NODISCARD inline static Pointer Owning(T *pointer) noexcept - { - return {pointer, false}; - } - - SEAL_NODISCARD inline static auto Aliasing( - T *pointer) noexcept -> Pointer - { - return {pointer, true}; - } - - private: - Pointer(const Pointer ©) = delete; - - Pointer &operator =(const Pointer &assign) = delete; - - Pointer(T *pointer, bool alias) noexcept : - data_(pointer), alias_(alias) - { - } - - Pointer(class MemoryPoolHead *head) - { -#ifdef SEAL_DEBUG - if (!head) - { - throw std::invalid_argument("head cannot be nullptr"); - } -#endif - head_ = head; - item_ = head->get(); - data_ = reinterpret_cast(item_->data()); - SEAL_IF_CONSTEXPR (!std::is_trivially_constructible::value) - { - auto count = head_->item_byte_count() / sizeof(T); - for (auto alloc_ptr = data_; count--; alloc_ptr++) - { - new(alloc_ptr) T; - } - } - } - - template - Pointer(class MemoryPoolHead *head, Args &&...args) - { -#ifdef SEAL_DEBUG - if (!head) - { - throw std::invalid_argument("head cannot be nullptr"); - } -#endif - head_ = head; - item_ = head->get(); - data_ = reinterpret_cast(item_->data()); - auto count = head_->item_byte_count() / sizeof(T); - for (auto alloc_ptr = data_; count--; alloc_ptr++) - { - new(alloc_ptr) T(std::forward(args)...); - } - } - - T *data_ = nullptr; - - MemoryPoolHead *head_ = nullptr; - - MemoryPoolItem *item_ = nullptr; - - bool alias_ = false; - }; - - // Specialization for SEAL_BYTE - template<> - class SEAL_NODISCARD ConstPointer - { - friend class MemoryPoolST; - friend class MemoryPoolMT; - - public: - template friend class ConstPointer; - - ConstPointer() = default; - - // Move of the same type - ConstPointer(Pointer &&source) noexcept : - data_(source.data_), head_(source.head_), - item_(source.item_), alias_(source.alias_) - { - source.data_ = nullptr; - source.head_ = nullptr; - source.item_ = nullptr; - source.alias_ = false; - } - - // Move of the same type - ConstPointer(Pointer &&source, SEAL_BYTE value) noexcept : - ConstPointer(std::move(source)) - { - std::fill_n(data_, head_->item_byte_count(), value); - } - - // Move of the same type - ConstPointer(ConstPointer &&source) noexcept : - data_(source.data_), head_(source.head_), - item_(source.item_), alias_(source.alias_) - { - source.data_ = nullptr; - source.head_ = nullptr; - source.item_ = nullptr; - source.alias_ = false; - } - - // Move of the same type - ConstPointer(ConstPointer &&source, SEAL_BYTE value) noexcept : - ConstPointer(std::move(source)) - { - std::fill_n(data_, head_->item_byte_count(), value); - } - - inline auto &operator =(ConstPointer &&assign) noexcept - { - acquire(std::move(assign)); - return *this; - } - - inline auto &operator =(Pointer &&assign) noexcept - { - acquire(std::move(assign)); - return *this; - } - - SEAL_NODISCARD inline const SEAL_BYTE &operator []( - std::size_t index) const - { - return data_[index]; - } - - SEAL_NODISCARD inline bool is_set() const noexcept - { - return data_ != nullptr; - } - - SEAL_NODISCARD inline const SEAL_BYTE *get() const noexcept - { - return data_; - } - - SEAL_NODISCARD inline const SEAL_BYTE *operator ->() const noexcept - { - return data_; - } - - SEAL_NODISCARD inline const SEAL_BYTE &operator *() const - { - return *data_; - } - - inline void release() noexcept - { - if (head_) - { - // Return the memory to pool - head_->add(item_); - } - else if (data_ && !alias_) - { - // Free the memory - delete[] data_; - } - - data_ = nullptr; - head_ = nullptr; - item_ = nullptr; - alias_ = false; - } - - void acquire(Pointer &other) noexcept - { - release(); - - data_ = other.data_; - head_ = other.head_; - item_ = other.item_; - alias_ = other.alias_; - other.data_ = nullptr; - other.head_ = nullptr; - other.item_ = nullptr; - other.alias_ = false; - } - - inline void acquire(Pointer &&other) noexcept - { - acquire(other); - } - - void acquire(ConstPointer &other) noexcept - { - if (this == &other) - { - return; - } - - release(); - - data_ = other.data_; - head_ = other.head_; - item_ = other.item_; - alias_ = other.alias_; - other.data_ = nullptr; - other.head_ = nullptr; - other.item_ = nullptr; - other.alias_ = false; - } - - inline void acquire(ConstPointer &&other) noexcept - { - acquire(other); - } - - ~ConstPointer() noexcept - { - release(); - } - - SEAL_NODISCARD operator bool() const - { - return (data_ != nullptr); - } - - SEAL_NODISCARD inline static auto Owning(SEAL_BYTE *pointer) noexcept - -> ConstPointer - { - return {pointer, false}; - } - - SEAL_NODISCARD inline static auto Aliasing( - const SEAL_BYTE *pointer) noexcept -> ConstPointer - { - return {const_cast(pointer), true}; - } - - private: - ConstPointer(const ConstPointer ©) = delete; - - ConstPointer &operator =(const ConstPointer &assign) = delete; - - ConstPointer(SEAL_BYTE *pointer, bool alias) noexcept : - data_(pointer), alias_(alias) - { - } - - ConstPointer(class MemoryPoolHead *head) - { -#ifdef SEAL_DEBUG - if (!head) - { - throw std::invalid_argument("head cannot be nullptr"); - } -#endif - head_ = head; - item_ = head->get(); - data_ = item_->data(); - } - - SEAL_BYTE *data_ = nullptr; - - MemoryPoolHead *head_ = nullptr; - - MemoryPoolItem *item_ = nullptr; - - bool alias_ = false; - }; - - template - class SEAL_NODISCARD ConstPointer - { - friend class MemoryPoolST; - friend class MemoryPoolMT; - - public: - ConstPointer() = default; - - // Move of the same type - ConstPointer(Pointer &&source) noexcept : - data_(source.data_), head_(source.head_), - item_(source.item_), alias_(source.alias_) - { - source.data_ = nullptr; - source.head_ = nullptr; - source.item_ = nullptr; - source.alias_ = false; - } - - // Move when T is not SEAL_BYTE - ConstPointer(Pointer &&source) - { - // Cannot acquire a non-pool pointer of different type - if (!source.head_ && source.data_) - { - throw std::invalid_argument("cannot acquire a non-pool pointer of different type"); - } - - head_ = source.head_; - item_ = source.item_; - if (head_) - { - data_ = reinterpret_cast(item_->data()); - SEAL_IF_CONSTEXPR (!std::is_trivially_constructible::value) - { - auto count = head_->item_byte_count() / sizeof(T); - for (auto alloc_ptr = data_; count--; alloc_ptr++) - { - new(alloc_ptr) T; - } - } - } - alias_ = source.alias_; - - source.data_ = nullptr; - source.head_ = nullptr; - source.item_ = nullptr; - source.alias_ = false; - } - - // Move when T is not SEAL_BYTE - template - ConstPointer(Pointer &&source, Args &&...args) - { - // Cannot acquire a non-pool pointer of different type - if (!source.head_ && source.data_) - { - throw std::invalid_argument("cannot acquire a non-pool pointer of different type"); - } - - head_ = source.head_; - item_ = source.item_; - if (head_) - { - data_ = reinterpret_cast(item_->data()); - auto count = head_->item_byte_count() / sizeof(T); - for (auto alloc_ptr = data_; count--; alloc_ptr++) - { - new(alloc_ptr) T(std::forward(args)...); - } - } - alias_ = source.alias_; - - source.data_ = nullptr; - source.head_ = nullptr; - source.item_ = nullptr; - source.alias_ = false; - } - - // Move of the same type - ConstPointer(ConstPointer &&source) noexcept : - data_(source.data_), head_(source.head_), - item_(source.item_), alias_(source.alias_) - { - source.data_ = nullptr; - source.head_ = nullptr; - source.item_ = nullptr; - source.alias_ = false; - } - - // Move when T is not SEAL_BYTE - ConstPointer(ConstPointer &&source) - { - // Cannot acquire a non-pool pointer of different type - if (!source.head_ && source.data_) - { - throw std::invalid_argument("cannot acquire a non-pool pointer of different type"); - } - - head_ = source.head_; - item_ = source.item_; - if (head_) - { - data_ = reinterpret_cast(item_->data()); - SEAL_IF_CONSTEXPR (!std::is_trivially_constructible::value) - { - auto count = head_->item_byte_count() / sizeof(T); - for (auto alloc_ptr = data_; count--; alloc_ptr++) - { - new(alloc_ptr) T; - } - } - } - alias_ = source.alias_; - - source.data_ = nullptr; - source.head_ = nullptr; - source.item_ = nullptr; - source.alias_ = false; - } - - // Move when T is not SEAL_BYTE - template - ConstPointer(ConstPointer &&source, Args &&...args) - { - // Cannot acquire a non-pool pointer of different type - if (!source.head_ && source.data_) - { - throw std::invalid_argument("cannot acquire a non-pool pointer of different type"); - } - - head_ = source.head_; - item_ = source.item_; - if (head_) - { - data_ = reinterpret_cast(item_->data()); - auto count = head_->item_byte_count() / sizeof(T); - for (auto alloc_ptr = data_; count--; alloc_ptr++) - { - new(alloc_ptr) T(std::forward(args)...); - } - } - alias_ = source.alias_; - - source.data_ = nullptr; - source.head_ = nullptr; - source.item_ = nullptr; - source.alias_ = false; - } - - inline auto &operator =(ConstPointer &&assign) noexcept - { - acquire(std::move(assign)); - return *this; - } - - inline auto &operator =(ConstPointer &&assign) - { - acquire(std::move(assign)); - return *this; - } - - inline auto &operator =(Pointer &&assign) noexcept - { - acquire(std::move(assign)); - return *this; - } - - inline auto &operator =(Pointer &&assign) - { - acquire(std::move(assign)); - return *this; - } - - SEAL_NODISCARD inline const T &operator [](std::size_t index) const - { - return data_[index]; - } - - SEAL_NODISCARD inline bool is_set() const noexcept - { - return data_ != nullptr; - } - - SEAL_NODISCARD inline const T *get() const noexcept - { - return data_; - } - - SEAL_NODISCARD inline const T *operator ->() const noexcept - { - return data_; - } - - SEAL_NODISCARD inline const T &operator *() const - { - return *data_; - } - - inline void release() noexcept - { - if (head_) - { - SEAL_IF_CONSTEXPR (!std::is_trivially_destructible::value) - { - // Manual destructor calls - auto count = head_->item_byte_count() / sizeof(T); - for (auto alloc_ptr = data_; count--; alloc_ptr++) - { - alloc_ptr->~T(); - } - } - - // Return the memory to pool - head_->add(item_); - } - else if (data_ && !alias_) - { - // Free the memory - delete[] data_; - } - - data_ = nullptr; - head_ = nullptr; - item_ = nullptr; - alias_ = false; - } - - void acquire(ConstPointer &other) noexcept - { - if (this == &other) - { - return; - } - - release(); - - data_ = other.data_; - head_ = other.head_; - item_ = other.item_; - alias_ = other.alias_; - other.data_ = nullptr; - other.head_ = nullptr; - other.item_ = nullptr; - other.alias_ = false; - } - - inline void acquire(ConstPointer &&other) noexcept - { - acquire(other); - } - - void acquire(ConstPointer &other) - { - // Cannot acquire a non-pool pointer of different type - if (!other.head_ && other.data_) - { - throw std::invalid_argument("cannot acquire a non-pool pointer of different type"); - } - - release(); - - head_ = other.head_; - item_ = other.item_; - if (head_) - { - data_ = reinterpret_cast(item_->data()); - SEAL_IF_CONSTEXPR (!std::is_trivially_constructible::value) - { - auto count = head_->item_byte_count() / sizeof(T); - for (auto alloc_ptr = data_; count--; alloc_ptr++) - { - new(alloc_ptr) T; - } - } - } - alias_ = other.alias_; - other.data_ = nullptr; - other.head_ = nullptr; - other.item_ = nullptr; - other.alias_ = false; - } - - inline void acquire(ConstPointer &&other) - { - acquire(other); - } - - void acquire(Pointer &other) noexcept - { - release(); - - data_ = other.data_; - head_ = other.head_; - item_ = other.item_; - alias_ = other.alias_; - other.data_ = nullptr; - other.head_ = nullptr; - other.item_ = nullptr; - other.alias_ = false; - } - - inline void acquire(Pointer &&other) noexcept - { - acquire(other); - } - - void acquire(Pointer &other) - { - // Cannot acquire a non-pool pointer of different type - if (!other.head_ && other.data_) - { - throw std::invalid_argument("cannot acquire a non-pool pointer of different type"); - } - - release(); - - head_ = other.head_; - item_ = other.item_; - if (head_) - { - data_ = reinterpret_cast(item_->data()); - SEAL_IF_CONSTEXPR (!std::is_trivially_constructible::value) - { - auto count = head_->item_byte_count() / sizeof(T); - for (auto alloc_ptr = data_; count--; alloc_ptr++) - { - new(alloc_ptr) T; - } - } - } - alias_ = other.alias_; - other.data_ = nullptr; - other.head_ = nullptr; - other.item_ = nullptr; - other.alias_ = false; - } - - inline void acquire(Pointer &&other) - { - acquire(other); - } - - ~ConstPointer() noexcept - { - release(); - } - - SEAL_NODISCARD operator bool() const noexcept - { - return (data_ != nullptr); - } - - SEAL_NODISCARD inline static ConstPointer Owning(T *pointer) noexcept - { - return {pointer, false}; - } - - SEAL_NODISCARD inline static auto Aliasing( - const T *pointer) noexcept -> ConstPointer - { - return {const_cast(pointer), true}; - } - - private: - ConstPointer(const ConstPointer ©) = delete; - - ConstPointer &operator =(const ConstPointer &assign) = delete; - - ConstPointer(T *pointer, bool alias) noexcept : data_(pointer), alias_(alias) - { - } - - ConstPointer(class MemoryPoolHead *head) - { -#ifdef SEAL_DEBUG - if (!head) - { - throw std::invalid_argument("head cannot be nullptr"); - } -#endif - head_ = head; - item_ = head->get(); - data_ = reinterpret_cast(item_->data()); - SEAL_IF_CONSTEXPR (!std::is_trivially_constructible::value) - { - auto count = head_->item_byte_count() / sizeof(T); - for (auto alloc_ptr = data_; count--; alloc_ptr++) - { - new(alloc_ptr) T; - } - } - } - - template - ConstPointer(class MemoryPoolHead *head, Args &&...args) - { -#ifdef SEAL_DEBUG - if (!head) - { - throw std::invalid_argument("head cannot be nullptr"); - } -#endif - head_ = head; - item_ = head->get(); - data_ = reinterpret_cast(item_->data()); - auto count = head_->item_byte_count() / sizeof(T); - for (auto alloc_ptr = data_; count--; alloc_ptr++) - { - new(alloc_ptr) T(std::forward(args)...); - } - } - - T *data_ = nullptr; - - MemoryPoolHead *head_ = nullptr; - - MemoryPoolItem *item_ = nullptr; - - bool alias_ = false; - }; - - // Allocate single element - template::value>> - SEAL_NODISCARD inline auto allocate( - MemoryPool &pool, Args &&...args) - { - using T = typename std::remove_cv::type>::type; - return Pointer(pool.get_for_byte_count(sizeof(T)), - std::forward(args)...); - } - - // Allocate array of elements - template::value>> - SEAL_NODISCARD inline auto allocate( - std::size_t count, MemoryPool &pool, Args &&...args) - { - using T = typename std::remove_cv::type>::type; - return Pointer(pool.get_for_byte_count(util::mul_safe(count, sizeof(T))), - std::forward(args)...); - } - - template::value>> - SEAL_NODISCARD inline auto duplicate_if_needed( - T_ *original, std::size_t count, bool condition, MemoryPool &pool) - { - using T = typename std::remove_cv::type>::type; -#ifdef SEAL_DEBUG - if (original == nullptr && count > 0) - { - throw std::invalid_argument("original"); - } -#endif - if (condition == false) - { - return Pointer::Aliasing(original); - } - auto allocation(allocate(count, pool)); - std::copy_n(original, count, allocation.get()); - return allocation; - } - - template::value>> - SEAL_NODISCARD inline auto duplicate_if_needed( - const T_ *original, std::size_t count, bool condition, MemoryPool &pool) - { - using T = typename std::remove_cv::type>::type; -#ifdef SEAL_DEBUG - if (original == nullptr && count > 0) - { - throw std::invalid_argument("original"); - } -#endif - if (condition == false) - { - return ConstPointer::Aliasing(original); - } - auto allocation(allocate(count, pool)); - std::copy_n(original, count, allocation.get()); - return ConstPointer(std::move(allocation)); - } - } -} diff --git a/SEAL/native/src/seal/util/polyarith.cpp b/SEAL/native/src/seal/util/polyarith.cpp deleted file mode 100644 index da9be1e..0000000 --- a/SEAL/native/src/seal/util/polyarith.cpp +++ /dev/null @@ -1,250 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "seal/util/common.h" -#include "seal/util/uintcore.h" -#include "seal/util/uintarith.h" -#include "seal/util/uintarithmod.h" -#include "seal/util/polycore.h" -#include "seal/util/polyarith.h" - -using namespace std; - -namespace seal -{ - namespace util - { - void multiply_poly_poly(const uint64_t *operand1, - size_t operand1_coeff_count, size_t operand1_coeff_uint64_count, - const uint64_t *operand2, size_t operand2_coeff_count, - size_t operand2_coeff_uint64_count, size_t result_coeff_count, - size_t result_coeff_uint64_count, uint64_t *result, MemoryPool &pool) - { -#ifdef SEAL_DEBUG - if (operand1 == nullptr && operand1_coeff_count > 0 && - operand1_coeff_uint64_count > 0) - { - throw invalid_argument("operand1"); - } - if (operand2 == nullptr && operand2_coeff_count > 0 && - operand2_coeff_uint64_count > 0) - { - throw invalid_argument("operand2"); - } - if (result == nullptr && result_coeff_count > 0 && - result_coeff_uint64_count > 0) - { - throw invalid_argument("result"); - } - if (result != nullptr && - (operand1 == result || operand2 == result)) - { - throw invalid_argument("result cannot point to the same value as operand1 or operand2"); - } - if (!sum_fits_in(operand1_coeff_count, operand2_coeff_count)) - { - throw invalid_argument("operand1 and operand2 too large"); - } -#endif - auto intermediate(allocate_uint(result_coeff_uint64_count, pool)); - - // Clear product. - set_zero_poly(result_coeff_count, result_coeff_uint64_count, result); - - operand1_coeff_count = get_significant_coeff_count_poly( - operand1, operand1_coeff_count, operand1_coeff_uint64_count); - operand2_coeff_count = get_significant_coeff_count_poly( - operand2, operand2_coeff_count, operand2_coeff_uint64_count); - for (size_t operand1_index = 0; - operand1_index < operand1_coeff_count; operand1_index++) - { - const uint64_t *operand1_coeff = get_poly_coeff( - operand1, operand1_index, operand1_coeff_uint64_count); - for (size_t operand2_index = 0; - operand2_index < operand2_coeff_count; operand2_index++) - { - size_t product_coeff_index = operand1_index + operand2_index; - if (product_coeff_index >= result_coeff_count) - { - break; - } - - const uint64_t *operand2_coeff = get_poly_coeff( - operand2, operand2_index, operand2_coeff_uint64_count); - multiply_uint_uint(operand1_coeff, operand1_coeff_uint64_count, - operand2_coeff, operand2_coeff_uint64_count, - result_coeff_uint64_count, intermediate.get()); - uint64_t *result_coeff = get_poly_coeff( - result, product_coeff_index, result_coeff_uint64_count); - add_uint_uint(result_coeff, intermediate.get(), - result_coeff_uint64_count, result_coeff); - } - } - } - - void poly_eval_poly(const uint64_t *poly_to_eval, - size_t poly_to_eval_coeff_count, - size_t poly_to_eval_coeff_uint64_count, - const uint64_t *value, size_t value_coeff_count, - size_t value_coeff_uint64_count, size_t result_coeff_count, - size_t result_coeff_uint64_count, uint64_t *result, MemoryPool &pool) - { -#ifdef SEAL_DEBUG - if (poly_to_eval == nullptr) - { - throw invalid_argument("poly_to_eval"); - } - if (value == nullptr) - { - throw invalid_argument("value"); - } - if (result == nullptr) - { - throw invalid_argument("result"); - } - if (poly_to_eval_coeff_count == 0) - { - throw invalid_argument("poly_to_eval_coeff_count"); - } - if (poly_to_eval_coeff_uint64_count == 0) - { - throw invalid_argument("poly_to_eval_coeff_uint64_count"); - } - if (value_coeff_count == 0) - { - throw invalid_argument("value_coeff_count"); - } - if (value_coeff_uint64_count == 0) - { - throw invalid_argument("value_coeff_uint64_count"); - } - if (result_coeff_count == 0) - { - throw invalid_argument("result_coeff_count"); - } - if (result_coeff_uint64_count == 0) - { - throw invalid_argument("result_coeff_uint64_count"); - } -#endif - // Evaluate poly at value using Horner's method - auto temp1(allocate_poly(result_coeff_count, result_coeff_uint64_count, pool)); - auto temp2(allocate_zero_poly(result_coeff_count, result_coeff_uint64_count, pool)); - uint64_t *productptr = temp1.get(); - uint64_t *intermediateptr = temp2.get(); - - while (poly_to_eval_coeff_count--) - { - multiply_poly_poly(intermediateptr, result_coeff_count, - result_coeff_uint64_count, value, value_coeff_count, - value_coeff_uint64_count, result_coeff_count, - result_coeff_uint64_count, productptr, pool); - const uint64_t *curr_coeff = get_poly_coeff( - poly_to_eval, poly_to_eval_coeff_count, - poly_to_eval_coeff_uint64_count); - add_uint_uint(productptr, result_coeff_uint64_count, curr_coeff, - poly_to_eval_coeff_uint64_count, false, - result_coeff_uint64_count, productptr); - swap(productptr, intermediateptr); - } - set_poly_poly(intermediateptr, result_coeff_count, - result_coeff_uint64_count, result); - } - - void exponentiate_poly(const std::uint64_t *poly, size_t poly_coeff_count, - size_t poly_coeff_uint64_count, const uint64_t *exponent, - size_t exponent_uint64_count, size_t result_coeff_count, - size_t result_coeff_uint64_count, std::uint64_t *result, MemoryPool &pool) - { -#ifdef SEAL_DEBUG - if (poly == nullptr) - { - throw invalid_argument("poly"); - } - if (poly_coeff_count == 0) - { - throw invalid_argument("poly_coeff_count"); - } - if (poly_coeff_uint64_count == 0) - { - throw invalid_argument("poly_coeff_uint64_count"); - } - if (exponent == nullptr) - { - throw invalid_argument("exponent"); - } - if (exponent_uint64_count == 0) - { - throw invalid_argument("exponent_uint64_count"); - } - if (result == nullptr) - { - throw invalid_argument("result"); - } - if (result_coeff_count == 0) - { - throw invalid_argument("result_coeff_count"); - } - if (result_coeff_uint64_count == 0) - { - throw invalid_argument("result_coeff_uint64_count"); - } -#endif - // Fast cases - if (is_zero_uint(exponent, exponent_uint64_count)) - { - set_zero_poly(result_coeff_count, result_coeff_uint64_count, result); - *result = 1; - return; - } - if (is_equal_uint(exponent, exponent_uint64_count, 1)) - { - set_poly_poly(poly, poly_coeff_count, poly_coeff_uint64_count, - result_coeff_count, result_coeff_uint64_count, result); - return; - } - - // Need to make a copy of exponent - auto exponent_copy(allocate_uint(exponent_uint64_count, pool)); - set_uint_uint(exponent, exponent_uint64_count, exponent_copy.get()); - - // Perform binary exponentiation. - auto big_alloc(allocate_uint(mul_safe( - add_safe(result_coeff_count, result_coeff_count, result_coeff_count), - result_coeff_uint64_count), pool)); - - uint64_t *powerptr = big_alloc.get(); - uint64_t *productptr = get_poly_coeff( - powerptr, result_coeff_count, result_coeff_uint64_count); - uint64_t *intermediateptr = get_poly_coeff( - productptr, result_coeff_count, result_coeff_uint64_count); - - set_poly_poly(poly, poly_coeff_count, poly_coeff_uint64_count, result_coeff_count, - result_coeff_uint64_count, powerptr); - set_zero_poly(result_coeff_count, result_coeff_uint64_count, intermediateptr); - *intermediateptr = 1; - - // Initially: power = operand and intermediate = 1, product is not initialized. - while (true) - { - if ((*exponent_copy.get() % 2) == 1) - { - multiply_poly_poly(powerptr, result_coeff_count, result_coeff_uint64_count, - intermediateptr, result_coeff_count, result_coeff_uint64_count, - result_coeff_count, result_coeff_uint64_count, productptr, pool); - swap(productptr, intermediateptr); - } - right_shift_uint(exponent_copy.get(), 1, exponent_uint64_count, exponent_copy.get()); - if (is_zero_uint(exponent_copy.get(), exponent_uint64_count)) - { - break; - } - multiply_poly_poly(powerptr, result_coeff_count, result_coeff_uint64_count, - powerptr, result_coeff_count, result_coeff_uint64_count, - result_coeff_count, result_coeff_uint64_count, productptr, pool); - swap(productptr, powerptr); - } - set_poly_poly(intermediateptr, result_coeff_count, result_coeff_uint64_count, result); - } - } -} diff --git a/SEAL/native/src/seal/util/polyarith.h b/SEAL/native/src/seal/util/polyarith.h deleted file mode 100644 index caa899f..0000000 --- a/SEAL/native/src/seal/util/polyarith.h +++ /dev/null @@ -1,147 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include "seal/util/uintcore.h" -#include "seal/util/uintarith.h" -#include "seal/util/uintarithmod.h" -#include "seal/util/polycore.h" -#include "seal/util/pointer.h" - -namespace seal -{ - namespace util - { - inline void right_shift_poly_coeffs( - const std::uint64_t *poly, std::size_t coeff_count, - std::size_t coeff_uint64_count, int shift_amount, - std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (poly == nullptr && coeff_count > 0 && coeff_uint64_count > 0) - { - throw std::invalid_argument("poly"); - } -#endif - while (coeff_count--) - { - right_shift_uint(poly, shift_amount, coeff_uint64_count, result); - poly += coeff_uint64_count; - result += coeff_uint64_count; - } - } - - inline void negate_poly(const std::uint64_t *poly, - std::size_t coeff_count, std::size_t coeff_uint64_count, - std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (poly == nullptr && coeff_count > 0 && coeff_uint64_count > 0) - { - throw std::invalid_argument("poly"); - } - if (result == nullptr && coeff_count > 0 && coeff_uint64_count > 0) - { - throw std::invalid_argument("result"); - } -#endif - while(coeff_count--) - { - negate_uint(poly, coeff_uint64_count, result); - poly += coeff_uint64_count; - result += coeff_uint64_count; - } - } - - inline void add_poly_poly(const std::uint64_t *operand1, - const std::uint64_t *operand2, std::size_t coeff_count, - std::size_t coeff_uint64_count, std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (operand1 == nullptr && coeff_count > 0 && coeff_uint64_count > 0) - { - throw std::invalid_argument("operand1"); - } - if (operand2 == nullptr && coeff_count > 0 && coeff_uint64_count > 0) - { - throw std::invalid_argument("operand2"); - } - if (result == nullptr && coeff_count > 0 && coeff_uint64_count > 0) - { - throw std::invalid_argument("result"); - } -#endif - while(coeff_count--) - { - add_uint_uint(operand1, operand2, coeff_uint64_count, result); - operand1 += coeff_uint64_count; - operand2 += coeff_uint64_count; - result += coeff_uint64_count; - } - } - - inline void sub_poly_poly(const std::uint64_t *operand1, - const std::uint64_t *operand2, std::size_t coeff_count, - std::size_t coeff_uint64_count, std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (operand1 == nullptr && coeff_count > 0 && coeff_uint64_count > 0) - { - throw std::invalid_argument("operand1"); - } - if (operand2 == nullptr && coeff_count > 0 && coeff_uint64_count > 0) - { - throw std::invalid_argument("operand2"); - } - if (result == nullptr && coeff_count > 0 && coeff_uint64_count > 0) - { - throw std::invalid_argument("result"); - } -#endif - while(coeff_count--) - { - sub_uint_uint(operand1, operand2, coeff_uint64_count, result); - operand1 += coeff_uint64_count; - operand2 += coeff_uint64_count; - result += coeff_uint64_count; - } - } - - void multiply_poly_poly( - const std::uint64_t *operand1, std::size_t operand1_coeff_count, - std::size_t operand1_coeff_uint64_count, const std::uint64_t *operand2, - std::size_t operand2_coeff_count, std::size_t operand2_coeff_uint64_count, - std::size_t result_coeff_count, std::size_t result_coeff_uint64_count, - std::uint64_t *result, MemoryPool &pool); - - inline void poly_infty_norm(const std::uint64_t *poly, - std::size_t coeff_count, std::size_t coeff_uint64_count, - std::uint64_t *result) - { - set_zero_uint(coeff_uint64_count, result); - while(coeff_count--) - { - if (is_greater_than_uint_uint(poly, result, coeff_uint64_count)) - { - set_uint_uint(poly, coeff_uint64_count, result); - } - - poly += coeff_uint64_count; - } - } - - void poly_eval_poly(const std::uint64_t *poly_to_eval, - std::size_t poly_to_eval_coeff_count, - std::size_t poly_to_eval_coeff_uint64_count, const std::uint64_t *value, - std::size_t value_coeff_count, std::size_t value_coeff_uint64_count, - std::size_t result_coeff_count, std::size_t result_coeff_uint64_count, - std::uint64_t *result, MemoryPool &pool); - - void exponentiate_poly(const std::uint64_t *poly, std::size_t poly_coeff_count, - std::size_t poly_coeff_uint64_count, const std::uint64_t *exponent, - std::size_t exponent_uint64_count, std::size_t result_coeff_count, - std::size_t result_coeff_uint64_count, std::uint64_t *result, MemoryPool &pool); - } -} diff --git a/SEAL/native/src/seal/util/polyarithmod.cpp b/SEAL/native/src/seal/util/polyarithmod.cpp deleted file mode 100644 index 630e354..0000000 --- a/SEAL/native/src/seal/util/polyarithmod.cpp +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "seal/util/uintcore.h" -#include "seal/util/uintarith.h" -#include "seal/util/uintarithmod.h" -#include "seal/util/polycore.h" -#include "seal/util/polyarith.h" -#include "seal/util/polyarithmod.h" - -using namespace std; - -namespace seal -{ - namespace util - { - void poly_infty_norm_coeffmod(const uint64_t *poly, size_t coeff_count, - size_t coeff_uint64_count, const uint64_t *modulus, uint64_t *result, - MemoryPool &pool) - { - // Construct negative threshold (first negative modulus value) to compute - // absolute values of coeffs. - auto modulus_neg_threshold(allocate_uint(coeff_uint64_count, pool)); - - // Set to value of (modulus + 1) / 2. To prevent overflowing with the +1, just - // add 1 to the result if modulus was odd. - half_round_up_uint(modulus, coeff_uint64_count, modulus_neg_threshold.get()); - - // Mod out the poly coefficients and choose a symmetric representative from - // [-modulus,modulus). Keep track of the max. - set_zero_uint(coeff_uint64_count, result); - auto coeff_abs_value(allocate_uint(coeff_uint64_count, pool)); - for (size_t i = 0; i < coeff_count; i++, poly += coeff_uint64_count) - { - if (is_greater_than_or_equal_uint_uint( - poly, modulus_neg_threshold.get(), coeff_uint64_count)) - { - sub_uint_uint(modulus, poly, coeff_uint64_count, coeff_abs_value.get()); - } - else - { - set_uint_uint(poly, coeff_uint64_count, coeff_abs_value.get()); - } - if (is_greater_than_uint_uint(coeff_abs_value.get(), result, - coeff_uint64_count)) - { - set_uint_uint(coeff_abs_value.get(), coeff_uint64_count, result); - } - } - } - } -} diff --git a/SEAL/native/src/seal/util/polyarithmod.h b/SEAL/native/src/seal/util/polyarithmod.h deleted file mode 100644 index db6c77d..0000000 --- a/SEAL/native/src/seal/util/polyarithmod.h +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include "seal/util/pointer.h" -#include "seal/util/polycore.h" -#include "seal/util/uintarithmod.h" - -namespace seal -{ - namespace util - { - inline void negate_poly_coeffmod(const std::uint64_t *poly, - std::size_t coeff_count, const std::uint64_t *coeff_modulus, - std::size_t coeff_uint64_count, std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (poly == nullptr && coeff_count > 0) - { - throw std::invalid_argument("poly"); - } - if (coeff_modulus == nullptr) - { - throw std::invalid_argument("coeff_modulus"); - } - if (coeff_uint64_count == 0) - { - throw std::invalid_argument("coeff_uint64_count"); - } - if (result == nullptr && coeff_count > 0) - { - throw std::invalid_argument("result"); - } -#endif - for (std::size_t i = 0; i < coeff_count; i++) - { - negate_uint_mod(poly, coeff_modulus, coeff_uint64_count, result); - poly += coeff_uint64_count; - result += coeff_uint64_count; - } - } - - inline void add_poly_poly_coeffmod(const std::uint64_t *operand1, - const std::uint64_t *operand2, std::size_t coeff_count, - const std::uint64_t *coeff_modulus, std::size_t coeff_uint64_count, - std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (operand1 == nullptr && coeff_count > 0) - { - throw std::invalid_argument("operand1"); - } - if (operand2 == nullptr && coeff_count > 0) - { - throw std::invalid_argument("operand2"); - } - if (coeff_modulus == nullptr) - { - throw std::invalid_argument("coeff_modulus"); - } - if (coeff_uint64_count == 0) - { - throw std::invalid_argument("coeff_uint64_count"); - } - if (result == nullptr && coeff_count > 0) - { - throw std::invalid_argument("result"); - } -#endif - for (std::size_t i = 0; i < coeff_count; i++) - { - add_uint_uint_mod(operand1, operand2, coeff_modulus, - coeff_uint64_count, result); - operand1 += coeff_uint64_count; - operand2 += coeff_uint64_count; - result += coeff_uint64_count; - } - } - - inline void sub_poly_poly_coeffmod(const std::uint64_t *operand1, - const std::uint64_t *operand2, std::size_t coeff_count, - const std::uint64_t *coeff_modulus, std::size_t coeff_uint64_count, - std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (operand1 == nullptr && coeff_count > 0) - { - throw std::invalid_argument("operand1"); - } - if (operand2 == nullptr && coeff_count > 0) - { - throw std::invalid_argument("operand2"); - } - if (coeff_modulus == nullptr) - { - throw std::invalid_argument("coeff_modulus"); - } - if (coeff_uint64_count == 0) - { - throw std::invalid_argument("coeff_uint64_count"); - } - if (result == nullptr && coeff_count > 0) - { - throw std::invalid_argument("result"); - } -#endif - for (std::size_t i = 0; i < coeff_count; i++) - { - sub_uint_uint_mod(operand1, operand2, coeff_modulus, - coeff_uint64_count, result); - operand1 += coeff_uint64_count; - operand2 += coeff_uint64_count; - result += coeff_uint64_count; - } - } - - void poly_infty_norm_coeffmod(const std::uint64_t *poly, - std::size_t coeff_count, std::size_t coeff_uint64_count, - const std::uint64_t *modulus, std::uint64_t *result, MemoryPool &pool); - } -} diff --git a/SEAL/native/src/seal/util/polyarithsmallmod.cpp b/SEAL/native/src/seal/util/polyarithsmallmod.cpp deleted file mode 100644 index 58698f0..0000000 --- a/SEAL/native/src/seal/util/polyarithsmallmod.cpp +++ /dev/null @@ -1,725 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "seal/util/common.h" -#include "seal/util/uintcore.h" -#include "seal/util/uintarithsmallmod.h" -#include "seal/util/uintarith.h" -#include "seal/util/polycore.h" -#include "seal/util/polyarith.h" -#include "seal/util/polyarithsmallmod.h" -#include "seal/util/defines.h" - -using namespace std; - -namespace seal -{ - namespace util - { - void multiply_poly_scalar_coeffmod(const uint64_t *poly, - size_t coeff_count, uint64_t scalar, const SmallModulus &modulus, - uint64_t *result) - { -#ifdef SEAL_DEBUG - if (poly == nullptr && coeff_count > 0) - { - throw invalid_argument("poly"); - } - if (result == nullptr && coeff_count > 0) - { - throw invalid_argument("result"); - } - if (modulus.is_zero()) - { - throw invalid_argument("modulus"); - } -#endif - // Explicit inline - //for (int i = 0; i < coeff_count; i++) - //{ - // *result++ = multiply_uint_uint_mod(*poly++, scalar, modulus); - //} - const uint64_t modulus_value = modulus.value(); - const uint64_t const_ratio_0 = modulus.const_ratio()[0]; - const uint64_t const_ratio_1 = modulus.const_ratio()[1]; - for (; coeff_count--; poly++, result++) - { - unsigned long long z[2], tmp1, tmp2[2], tmp3, carry; - multiply_uint64(*poly, scalar, z); - - // Reduces z using base 2^64 Barrett reduction - - // Multiply input and const_ratio - // Round 1 - multiply_uint64_hw64(z[0], const_ratio_0, &carry); - multiply_uint64(z[0], const_ratio_1, tmp2); - tmp3 = tmp2[1] + add_uint64(tmp2[0], carry, &tmp1); - - // Round 2 - multiply_uint64(z[1], const_ratio_0, tmp2); - carry = tmp2[1] + add_uint64(tmp1, tmp2[0], &tmp1); - - // This is all we care about - tmp1 = z[1] * const_ratio_1 + tmp3 + carry; - - // Barrett subtraction - tmp3 = z[0] - tmp1 * modulus_value; - - // Claim: One more subtraction is enough - *result = tmp3 - (modulus_value & static_cast( - -static_cast(tmp3 >= modulus_value))); - } - } - - void multiply_poly_poly_coeffmod(const uint64_t *operand1, - size_t operand1_coeff_count, const uint64_t *operand2, - size_t operand2_coeff_count, const SmallModulus &modulus, - size_t result_coeff_count, uint64_t *result) - { -#ifdef SEAL_DEBUG - if (operand1 == nullptr && operand1_coeff_count > 0) - { - throw invalid_argument("operand1"); - } - if (operand2 == nullptr && operand2_coeff_count > 0) - { - throw invalid_argument("operand2"); - } - if (result == nullptr && result_coeff_count > 0) - { - throw invalid_argument("result"); - } - if (result != nullptr && (operand1 == result || operand2 == result)) - { - throw invalid_argument("result cannot point to the same value as operand1, operand2, or modulus"); - } - if (modulus.is_zero()) - { - throw invalid_argument("modulus"); - } - if (!sum_fits_in(operand1_coeff_count, operand2_coeff_count)) - { - throw invalid_argument("operand1 and operand2 too large"); - } -#endif - // Clear product. - set_zero_uint(result_coeff_count, result); - - operand1_coeff_count = get_significant_coeff_count_poly( - operand1, operand1_coeff_count, 1); - operand2_coeff_count = get_significant_coeff_count_poly( - operand2, operand2_coeff_count, 1); - for (size_t operand1_index = 0; - operand1_index < operand1_coeff_count; operand1_index++) - { - if (operand1[operand1_index] == 0) - { - // If coefficient is 0, then move on to next coefficient. - continue; - } - // Do expensive add - for (size_t operand2_index = 0; - operand2_index < operand2_coeff_count; operand2_index++) - { - size_t product_coeff_index = operand1_index + operand2_index; - if (product_coeff_index >= result_coeff_count) - { - break; - } - - if (operand2[operand2_index] == 0) - { - // If coefficient is 0, then move on to next coefficient. - continue; - } - - // Lazy reduction - unsigned long long temp[2]; - multiply_uint64(operand1[operand1_index], operand2[operand2_index], temp); - temp[1] += add_uint64(temp[0], result[product_coeff_index], 0, temp); - result[product_coeff_index] = barrett_reduce_128(temp, modulus); - } - } - } - - void multiply_poly_poly_coeffmod(const uint64_t *operand1, - const uint64_t *operand2, size_t coeff_count, - const SmallModulus &modulus, uint64_t *result) - { -#ifdef SEAL_DEBUG - if (operand1 == nullptr && coeff_count > 0) - { - throw invalid_argument("operand1"); - } - if (operand2 == nullptr && coeff_count > 0) - { - throw invalid_argument("operand2"); - } - if (result == nullptr && coeff_count > 0) - { - throw invalid_argument("result"); - } - if (result != nullptr && (operand1 == result || operand2 == result)) - { - throw invalid_argument("result cannot point to the same value as operand1, operand2, or modulus"); - } - if (modulus.is_zero()) - { - throw invalid_argument("modulus"); - } -#endif - size_t result_coeff_count = coeff_count + coeff_count - 1; - - // Clear product. - set_zero_uint(result_coeff_count, result); - - for (size_t operand1_index = 0; operand1_index < coeff_count; operand1_index++) - { - if (operand1[operand1_index] == 0) - { - // If coefficient is 0, then move on to next coefficient. - continue; - } - // Lastly, do more expensive add if other cases don't handle it. - for (size_t operand2_index = 0; operand2_index < coeff_count; operand2_index++) - { - uint64_t operand2_coeff = operand2[operand2_index]; - if (operand2_coeff == 0) - { - // If coefficient is 0, then move on to next coefficient. - continue; - } - - // Lazy reduction - unsigned long long temp[2]; - multiply_uint64(operand1[operand1_index], operand2_coeff, temp); - temp[1] += add_uint64(temp[0], result[operand1_index + operand2_index], 0, temp); - - result[operand1_index + operand2_index] = barrett_reduce_128(temp, modulus); - } - } - } - - void divide_poly_poly_coeffmod_inplace(uint64_t *numerator, - const uint64_t *denominator, size_t coeff_count, - const SmallModulus &modulus, uint64_t *quotient) - { -#ifdef SEAL_DEBUG - if (numerator == nullptr) - { - throw invalid_argument("numerator"); - } - if (denominator == nullptr) - { - throw invalid_argument("denominator"); - } - if (is_zero_poly(denominator, coeff_count, modulus.uint64_count())) - { - throw invalid_argument("denominator"); - } - if (quotient == nullptr) - { - throw invalid_argument("quotient"); - } - if (numerator == quotient || denominator == quotient) - { - throw invalid_argument("quotient cannot point to same value as numerator or denominator"); - } - if (numerator == denominator) - { - throw invalid_argument("numerator cannot point to same value as denominator"); - } - if (modulus.is_zero()) - { - throw invalid_argument("modulus"); - } -#endif - // Clear quotient. - set_zero_uint(coeff_count, quotient); - - // Determine most significant coefficients of numerator and denominator. - size_t numerator_coeffs = get_significant_uint64_count_uint( - numerator, coeff_count); - size_t denominator_coeffs = get_significant_uint64_count_uint( - denominator, coeff_count); - - // If numerator has lesser degree than denominator, then done. - if (numerator_coeffs < denominator_coeffs) - { - return; - } - - // Create scalar to store value that makes denominator monic. - uint64_t monic_denominator_scalar; - - // Create temporary scalars used during calculation of quotient. - // Both are purposely twice as wide to store intermediate product prior to modulo operation. - uint64_t temp_quotient; - uint64_t subtrahend; - - // Determine scalar necessary to make denominator monic. - uint64_t leading_denominator_coeff = denominator[denominator_coeffs - 1]; - if (!try_invert_uint_mod(leading_denominator_coeff, modulus, monic_denominator_scalar)) - { - throw invalid_argument("modulus is not coprime with leading denominator coefficient"); - } - - // Perform coefficient-wise division algorithm. - while (numerator_coeffs >= denominator_coeffs) - { - // Determine leading numerator coefficient. - uint64_t leading_numerator_coeff = numerator[numerator_coeffs - 1]; - - // If leading numerator coefficient is not zero, then need to make zero by subtraction. - if (leading_numerator_coeff) - { - // Determine shift necesarry to bring significant coefficients in alignment. - size_t denominator_shift = numerator_coeffs - denominator_coeffs; - - // Determine quotient's coefficient, which is scalar that makes - // denominator's leading coefficient one multiplied by leading - // coefficient of denominator (which when subtracted will zero - // out the topmost denominator coefficient). - uint64_t "ient_coeff = quotient[denominator_shift]; - temp_quotient = multiply_uint_uint_mod( - monic_denominator_scalar, leading_numerator_coeff, modulus); - quotient_coeff = temp_quotient; - - // Subtract numerator and quotient*denominator (shifted by denominator_shift). - for (size_t denominator_coeff_index = 0; - denominator_coeff_index < denominator_coeffs; denominator_coeff_index++) - { - // Multiply denominator's coefficient by quotient. - uint64_t denominator_coeff = denominator[denominator_coeff_index]; - subtrahend = multiply_uint_uint_mod(temp_quotient, denominator_coeff, modulus); - - // Subtract numerator with resulting product, appropriately shifted by denominator shift. - uint64_t &numerator_coeff = numerator[denominator_coeff_index + denominator_shift]; - numerator_coeff = sub_uint_uint_mod(numerator_coeff, subtrahend, modulus); - } - } - - // Top numerator coefficient must now be zero, so adjust coefficient count. - numerator_coeffs--; - } - } - - void apply_galois(const uint64_t *input, int coeff_count_power, - uint64_t galois_elt, const SmallModulus &modulus, uint64_t *result) - { -#ifdef SEAL_DEBUG - if (input == nullptr) - { - throw invalid_argument("input"); - } - if (result == nullptr) - { - throw invalid_argument("result"); - } - if (input == result) - { - throw invalid_argument("result cannot point to the same value as input"); - } - if (coeff_count_power < get_power_of_two(SEAL_POLY_MOD_DEGREE_MIN) || - coeff_count_power > get_power_of_two(SEAL_POLY_MOD_DEGREE_MAX)) - { - throw invalid_argument("coeff_count_power"); - } - // Verify coprime conditions. - if (!(galois_elt & 1) || - (galois_elt >= 2 * (uint64_t(1) << coeff_count_power))) - { - throw invalid_argument("Galois element is not valid"); - } - if (modulus.is_zero()) - { - throw invalid_argument("modulus"); - } -#endif - const uint64_t modulus_value = modulus.value(); - uint64_t coeff_count_minus_one = (uint64_t(1) << coeff_count_power) - 1; - for (uint64_t i = 0; i <= coeff_count_minus_one; i++) - { - uint64_t index_raw = i * galois_elt; - uint64_t index = index_raw & coeff_count_minus_one; - uint64_t result_value = *input++; - if ((index_raw >> coeff_count_power) & 1) - { - // Explicit inline - //result[index] = negate_uint_mod(result[index], modulus); - int64_t non_zero = (result_value != 0); - result_value = (modulus_value - result_value) & - static_cast(-non_zero); - } - result[index] = result_value; - } - } - - void apply_galois_ntt(const uint64_t *input, int coeff_count_power, - uint64_t galois_elt, uint64_t *result) - { -#ifdef SEAL_DEBUG - if (input == nullptr) - { - throw invalid_argument("input"); - } - if (result == nullptr) - { - throw invalid_argument("result"); - } - if (input == result) - { - throw invalid_argument("result cannot point to the same value as input"); - } - if (coeff_count_power <= 0) - { - throw invalid_argument("coeff_count_power"); - } - // Verify coprime conditions. - if (!(galois_elt & 1) || - (galois_elt >= 2 * (uint64_t(1) << coeff_count_power))) - { - throw invalid_argument("Galois element is not valid"); - } -#endif - size_t coeff_count = size_t(1) << coeff_count_power; - uint64_t m_minus_one = 2 * coeff_count - 1; - for (size_t i = 0; i < coeff_count; i++) - { - uint64_t reversed = reverse_bits(i, coeff_count_power); - uint64_t index_raw = galois_elt * (2 * reversed + 1); - index_raw &= m_minus_one; - uint64_t index = reverse_bits((index_raw - 1) >> 1, coeff_count_power); - result[i] = input[index]; - } - } - - void dyadic_product_coeffmod(const uint64_t *operand1, - const uint64_t *operand2, size_t coeff_count, - const SmallModulus &modulus, uint64_t *result) - { -#ifdef SEAL_DEBUG - if (operand1 == nullptr) - { - throw invalid_argument("operand1"); - } - if (operand2 == nullptr) - { - throw invalid_argument("operand2"); - } - if (result == nullptr) - { - throw invalid_argument("result"); - } - if (coeff_count == 0) - { - throw invalid_argument("coeff_count"); - } - if (modulus.is_zero()) - { - throw invalid_argument("modulus"); - } -#endif - // Explicit inline - //for (int i = 0; i < coeff_count; i++) - //{ - // *result++ = multiply_uint_uint_mod(*operand1++, *operand2++, modulus); - //} - const uint64_t modulus_value = modulus.value(); - const uint64_t const_ratio_0 = modulus.const_ratio()[0]; - const uint64_t const_ratio_1 = modulus.const_ratio()[1]; - for (; coeff_count--; operand1++, operand2++, result++) - { - // Reduces z using base 2^64 Barrett reduction - unsigned long long z[2], tmp1, tmp2[2], tmp3, carry; - multiply_uint64(*operand1, *operand2, z); - - // Multiply input and const_ratio - // Round 1 - multiply_uint64_hw64(z[0], const_ratio_0, &carry); - multiply_uint64(z[0], const_ratio_1, tmp2); - tmp3 = tmp2[1] + add_uint64(tmp2[0], carry, &tmp1); - - // Round 2 - multiply_uint64(z[1], const_ratio_0, tmp2); - carry = tmp2[1] + add_uint64(tmp1, tmp2[0], &tmp1); - - // This is all we care about - tmp1 = z[1] * const_ratio_1 + tmp3 + carry; - - // Barrett subtraction - tmp3 = z[0] - tmp1 * modulus_value; - - // Claim: One more subtraction is enough - *result = tmp3 - (modulus_value & static_cast( - -static_cast(tmp3 >= modulus_value))); - } - } - - uint64_t poly_infty_norm_coeffmod(const uint64_t *operand, - size_t coeff_count, const SmallModulus &modulus) - { -#ifdef SEAL_DEBUG - if (operand == nullptr && coeff_count > 0) - { - throw invalid_argument("operand"); - } - if (modulus.is_zero()) - { - throw invalid_argument("modulus"); - } -#endif - // Construct negative threshold (first negative modulus value) to compute absolute values of coeffs. - uint64_t modulus_neg_threshold = (modulus.value() + 1) >> 1; - - // Mod out the poly coefficients and choose a symmetric representative from - // [-modulus,modulus). Keep track of the max. - uint64_t result = 0; - for (size_t coeff_index = 0; coeff_index < coeff_count; coeff_index++) - { - uint64_t poly_coeff = operand[coeff_index] % modulus.value(); - if (poly_coeff >= modulus_neg_threshold) - { - poly_coeff = modulus.value() - poly_coeff; - } - if (poly_coeff > result) - { - result = poly_coeff; - } - } - return result; - } - - bool try_invert_poly_coeffmod(const uint64_t *operand, const uint64_t *poly_modulus, - size_t coeff_count, const SmallModulus &modulus, uint64_t *result, MemoryPool &pool) - { -#ifdef SEAL_DEBUG - if (operand == nullptr) - { - throw invalid_argument("operand"); - } - if (poly_modulus == nullptr) - { - throw invalid_argument("poly_modulus"); - } - if (coeff_count == 0) - { - throw invalid_argument("coeff_count"); - } - if (result == nullptr) - { - throw invalid_argument("result"); - } - if (get_significant_uint64_count_uint(operand, coeff_count) >= - get_significant_uint64_count_uint(poly_modulus, coeff_count)) - { - throw out_of_range("operand"); - } - if (modulus.is_zero()) - { - throw invalid_argument("modulus"); - } -#endif - // Cannot invert 0 poly. - if (is_zero_poly(operand, coeff_count, size_t(1))) - { - return false; - } - - // Construct a mutable copy of operand and modulus, with numerator being modulus - // and operand being denominator. Notice that degree(numerator) >= degree(denominator). - auto numerator_anchor(allocate_uint(coeff_count, pool)); - uint64_t *numerator = numerator_anchor.get(); - set_uint_uint(poly_modulus, coeff_count, numerator); - auto denominator_anchor(allocate_uint(coeff_count, pool)); - uint64_t *denominator = denominator_anchor.get(); - set_uint_uint(operand, coeff_count, denominator); - - // Determine most significant coefficients of each. - size_t numerator_coeffs = get_significant_coeff_count_poly( - numerator, coeff_count, size_t(1)); - size_t denominator_coeffs = get_significant_coeff_count_poly( - denominator, coeff_count, size_t(1)); - - // Create poly to store quotient. - auto quotient(allocate_uint(coeff_count, pool)); - - // Create scalar to store value that makes denominator monic. - uint64_t monic_denominator_scalar; - - // Create temporary scalars used during calculation of quotient. - // Both are purposely twice as wide to store intermediate product prior to modulo operation. - uint64_t temp_quotient; - uint64_t subtrahend; - - // Create three polynomials to store inverse. - // Initialize invert_prior to 0 and invert_curr to 1. - auto invert_prior_anchor(allocate_uint(coeff_count, pool)); - uint64_t *invert_prior = invert_prior_anchor.get(); - set_zero_uint(coeff_count, invert_prior); - auto invert_curr_anchor(allocate_uint(coeff_count, pool)); - uint64_t *invert_curr = invert_curr_anchor.get(); - set_zero_uint(coeff_count, invert_curr); - invert_curr[0] = 1; - auto invert_next_anchor(allocate_uint(coeff_count, pool)); - uint64_t *invert_next = invert_next_anchor.get(); - - // Perform extended Euclidean algorithm. - while (true) - { - // NOTE: degree(numerator) >= degree(denominator). - - // Determine scalar necessary to make denominator monic. - uint64_t leading_denominator_coeff = - denominator[denominator_coeffs - 1]; - if (!try_invert_uint_mod(leading_denominator_coeff, modulus, - monic_denominator_scalar)) - { - throw invalid_argument("modulus is not coprime with leading denominator coefficient"); - } - - // Clear quotient. - set_zero_uint(coeff_count, quotient.get()); - - // Perform coefficient-wise division algorithm. - while (numerator_coeffs >= denominator_coeffs) - { - // Determine leading numerator coefficient. - uint64_t leading_numerator_coeff = numerator[numerator_coeffs - 1]; - - // If leading numerator coefficient is not zero, then need to make zero by subtraction. - if (leading_numerator_coeff) - { - // Determine shift necessary to bring significant coefficients in alignment. - size_t denominator_shift = numerator_coeffs - denominator_coeffs; - - // Determine quotient's coefficient, which is scalar that makes - // denominator's leading coefficient one multiplied by leading - // coefficient of denominator (which when subtracted will zero - // out the topmost denominator coefficient). - uint64_t "ient_coeff = quotient[denominator_shift]; - temp_quotient = multiply_uint_uint_mod( - monic_denominator_scalar, leading_numerator_coeff, modulus); - quotient_coeff = temp_quotient; - - // Subtract numerator and quotient*denominator (shifted by denominator_shift). - for (size_t denominator_coeff_index = 0; - denominator_coeff_index < denominator_coeffs; - denominator_coeff_index++) - { - // Multiply denominator's coefficient by quotient. - uint64_t denominator_coeff = denominator[denominator_coeff_index]; - subtrahend = multiply_uint_uint_mod(temp_quotient, denominator_coeff, modulus); - - // Subtract numerator with resulting product, appropriately shifted by - // denominator shift. - uint64_t &numerator_coeff = numerator[denominator_coeff_index + denominator_shift]; - numerator_coeff = sub_uint_uint_mod(numerator_coeff, subtrahend, modulus); - } - } - - // Top numerator coefficient must now be zero, so adjust coefficient count. - numerator_coeffs--; - } - - // Double check that numerator coefficients is correct because possible - // other coefficients are zero. - numerator_coeffs = get_significant_coeff_count_poly( - numerator, coeff_count, size_t(1)); - - // We are done if numerator is zero. - if (numerator_coeffs == 0) - { - break; - } - - // Integrate quotient with invert coefficients. - // Calculate: invert_next = invert_prior + -quotient * invert_curr - multiply_truncate_poly_poly_coeffmod(quotient.get(), invert_curr, - coeff_count, modulus, invert_next); - sub_poly_poly_coeffmod(invert_prior, invert_next, coeff_count, - modulus, invert_next); - - // Swap prior and curr, and then curr and next. - swap(invert_prior, invert_curr); - swap(invert_curr, invert_next); - - // Swap numerator and denominator. - swap(numerator, denominator); - swap(numerator_coeffs, denominator_coeffs); - } - - // Polynomial is invertible only if denominator is just a scalar. - if (denominator_coeffs != 1) - { - return false; - } - - // Determine scalar necessary to make denominator monic. - uint64_t leading_denominator_coeff = denominator[0]; - if (!try_invert_uint_mod(leading_denominator_coeff, modulus, - monic_denominator_scalar)) - { - throw invalid_argument("modulus is not coprime with leading denominator coefficient"); - } - - // Multiply inverse by scalar and done. - multiply_poly_scalar_coeffmod(invert_curr, coeff_count, - monic_denominator_scalar, modulus, result); - return true; - } - - void negacyclic_shift_poly_coeffmod(const uint64_t *operand, - size_t coeff_count, size_t shift, const SmallModulus &modulus, - uint64_t *result) - { -#ifdef SEAL_DEBUG - if (operand == nullptr) - { - throw invalid_argument("operand"); - } - if (result == nullptr) - { - throw invalid_argument("result"); - } - if (operand == result) - { - throw invalid_argument("result cannot point to the same value as operand"); - } - if (modulus.is_zero()) - { - throw invalid_argument("modulus"); - } - if (util::get_power_of_two(static_cast(coeff_count)) < 0) - { - throw invalid_argument("coeff_count"); - } - if (shift >= coeff_count) - { - throw invalid_argument("shift"); - } -#endif - // Nothing to do - if (shift == 0) - { - set_uint_uint(operand, coeff_count, result); - return; - } - - uint64_t index_raw = shift; - uint64_t coeff_count_mod_mask = static_cast(coeff_count) - 1; - for (size_t i = 0; i < coeff_count; i++, operand++, index_raw++) - { - uint64_t index = index_raw & coeff_count_mod_mask; - if (!(index_raw & static_cast(coeff_count)) || !*operand) - { - result[index] = *operand; - } - else - { - result[index] = modulus.value() - *operand; - } - } - } - } -} diff --git a/SEAL/native/src/seal/util/polyarithsmallmod.h b/SEAL/native/src/seal/util/polyarithsmallmod.h deleted file mode 100644 index fab78e5..0000000 --- a/SEAL/native/src/seal/util/polyarithsmallmod.h +++ /dev/null @@ -1,256 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include -#include "seal/smallmodulus.h" -#include "seal/util/common.h" -#include "seal/util/polycore.h" -#include "seal/util/uintarithsmallmod.h" -#include "seal/util/pointer.h" - -namespace seal -{ - namespace util - { - inline void modulo_poly_coeffs(const std::uint64_t *poly, - std::size_t coeff_count, const SmallModulus &modulus, - std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (poly == nullptr && coeff_count > 0) - { - throw std::invalid_argument("poly"); - } - if (result == nullptr && coeff_count > 0) - { - throw std::invalid_argument("result"); - } - if (modulus.is_zero()) - { - throw std::invalid_argument("modulus"); - } -#endif - std::transform(poly, poly + coeff_count, result, - [&](auto coeff) { - uint64_t temp[2]{ coeff, 0 }; - return barrett_reduce_128(temp, modulus); }); - } - - inline void modulo_poly_coeffs_63(const std::uint64_t *poly, - std::size_t coeff_count, const SmallModulus &modulus, - std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (poly == nullptr && coeff_count > 0) - { - throw std::invalid_argument("poly"); - } - if (result == nullptr && coeff_count > 0) - { - throw std::invalid_argument("result"); - } - if (modulus.is_zero()) - { - throw std::invalid_argument("modulus"); - } -#endif - // This function is the fastest for reducing polynomial coefficients, - // but requires that the input coefficients are at most 63 bits, unlike - // modulo_poly_coeffs that allows also 64-bit coefficients. - std::transform(poly, poly + coeff_count, result, - [&](auto coeff) { - return barrett_reduce_63(coeff, modulus); }); - } - - inline void negate_poly_coeffmod(const std::uint64_t *poly, - std::size_t coeff_count, const SmallModulus &modulus, - std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (poly == nullptr && coeff_count > 0) - { - throw std::invalid_argument("poly"); - } - if (modulus.is_zero()) - { - throw std::invalid_argument("modulus"); - } - if (result == nullptr && coeff_count > 0) - { - throw std::invalid_argument("result"); - } -#endif - const uint64_t modulus_value = modulus.value(); - for (; coeff_count--; poly++, result++) - { - // Explicit inline - //*result = negate_uint_mod(*poly, modulus); -#ifdef SEAL_DEBUG - if (*poly >= modulus_value) - { - throw std::out_of_range("poly"); - } -#endif - std::int64_t non_zero = (*poly != 0); - *result = (modulus_value - *poly) & - static_cast(-non_zero); - } - } - - inline void add_poly_poly_coeffmod(const std::uint64_t *operand1, - const std::uint64_t *operand2, std::size_t coeff_count, - const SmallModulus &modulus, std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (operand1 == nullptr && coeff_count > 0) - { - throw std::invalid_argument("operand1"); - } - if (operand2 == nullptr && coeff_count > 0) - { - throw std::invalid_argument("operand2"); - } - if (modulus.is_zero()) - { - throw std::invalid_argument("modulus"); - } - if (result == nullptr && coeff_count > 0) - { - throw std::invalid_argument("result"); - } -#endif - const uint64_t modulus_value = modulus.value(); - for (; coeff_count--; result++, operand1++, operand2++) - { - // Explicit inline - //result[i] = add_uint_uint_mod(operand1[i], operand2[i], modulus); -#ifdef SEAL_DEBUG - if (*operand1 >= modulus_value) - { - throw std::invalid_argument("operand1"); - } - if (*operand2 >= modulus_value) - { - throw std::invalid_argument("operand2"); - } -#endif - std::uint64_t sum = *operand1 + *operand2; - *result = sum - (modulus_value & static_cast( - -static_cast(sum >= modulus_value))); - } - } - - inline void sub_poly_poly_coeffmod(const std::uint64_t *operand1, - const std::uint64_t *operand2, std::size_t coeff_count, - const SmallModulus &modulus, std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (operand1 == nullptr && coeff_count > 0) - { - throw std::invalid_argument("operand1"); - } - if (operand2 == nullptr && coeff_count > 0) - { - throw std::invalid_argument("operand2"); - } - if (modulus.is_zero()) - { - throw std::invalid_argument("modulus"); - } - if (result == nullptr && coeff_count > 0) - { - throw std::invalid_argument("result"); - } -#endif - const uint64_t modulus_value = modulus.value(); - for (; coeff_count--; result++, operand1++, operand2++) - { -#ifdef SEAL_DEBUG - if (*operand1 >= modulus_value) - { - throw std::out_of_range("operand1"); - } - if (*operand2 >= modulus_value) - { - throw std::out_of_range("operand2"); - } -#endif - unsigned long long temp_result; - std::int64_t borrow = sub_uint64(*operand1, *operand2, &temp_result); - *result = temp_result + (modulus_value & static_cast(-borrow)); - } - } - - void multiply_poly_scalar_coeffmod(const std::uint64_t *poly, - std::size_t coeff_count, std::uint64_t scalar, const SmallModulus &modulus, - std::uint64_t *result); - - void multiply_poly_poly_coeffmod(const std::uint64_t *operand1, - std::size_t operand1_coeff_count, const std::uint64_t *operand2, - std::size_t operand2_coeff_count, const SmallModulus &modulus, - std::size_t result_coeff_count, std::uint64_t *result); - - void multiply_poly_poly_coeffmod(const std::uint64_t *operand1, - const std::uint64_t *operand2, std::size_t coeff_count, - const SmallModulus &modulus, std::uint64_t *result); - - inline void multiply_truncate_poly_poly_coeffmod( - const std::uint64_t *operand1, const std::uint64_t *operand2, - std::size_t coeff_count, const SmallModulus &modulus, std::uint64_t *result) - { - multiply_poly_poly_coeffmod(operand1, coeff_count, operand2, coeff_count, - modulus, coeff_count, result); - } - - void divide_poly_poly_coeffmod_inplace(std::uint64_t *numerator, - const std::uint64_t *denominator, std::size_t coeff_count, - const SmallModulus &modulus, std::uint64_t *quotient); - - inline void divide_poly_poly_coeffmod(const std::uint64_t *numerator, - const std::uint64_t *denominator, std::size_t coeff_count, - const SmallModulus &modulus, std::uint64_t *quotient, - std::uint64_t *remainder) - { - set_uint_uint(numerator, coeff_count, remainder); - divide_poly_poly_coeffmod_inplace(remainder, denominator, coeff_count, - modulus, quotient); - } - - void apply_galois(const std::uint64_t *input, int coeff_count_power, - std::uint64_t galois_elt, const SmallModulus &modulus, std::uint64_t *result); - - void apply_galois_ntt(const std::uint64_t *input, int coeff_count_power, - std::uint64_t galois_elt, std::uint64_t *result); - - void dyadic_product_coeffmod(const std::uint64_t *operand1, - const std::uint64_t *operand2, std::size_t coeff_count, - const SmallModulus &modulus, std::uint64_t *result); - - std::uint64_t poly_infty_norm_coeffmod(const std::uint64_t *operand, - std::size_t coeff_count, const SmallModulus &modulus); - - bool try_invert_poly_coeffmod(const std::uint64_t *operand, - const std::uint64_t *poly_modulus, std::size_t coeff_count, - const SmallModulus &modulus, std::uint64_t *result, MemoryPool &pool); - - void negacyclic_shift_poly_coeffmod(const std::uint64_t *operand, - std::size_t coeff_count, std::size_t shift, const SmallModulus &modulus, - std::uint64_t *result); - - inline void negacyclic_multiply_poly_mono_coeffmod( - const std::uint64_t *operand, std::size_t coeff_count, - std::uint64_t mono_coeff, std::size_t mono_exponent, - const SmallModulus &modulus, std::uint64_t *result, MemoryPool &pool) - { - auto temp(util::allocate_uint(coeff_count, pool)); - multiply_poly_scalar_coeffmod( - operand, coeff_count, mono_coeff, modulus, temp.get()); - negacyclic_shift_poly_coeffmod(temp.get(), coeff_count, mono_exponent, - modulus, result); - } - } -} diff --git a/SEAL/native/src/seal/util/polycore.h b/SEAL/native/src/seal/util/polycore.h deleted file mode 100644 index db180fc..0000000 --- a/SEAL/native/src/seal/util/polycore.h +++ /dev/null @@ -1,375 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include "seal/util/common.h" -#include "seal/util/uintcore.h" -#include "seal/util/pointer.h" - -namespace seal -{ - namespace util - { - SEAL_NODISCARD inline std::string poly_to_hex_string( - const std::uint64_t *value, std::size_t coeff_count, - std::size_t coeff_uint64_count) - { -#ifdef SEAL_DEBUG - if (coeff_uint64_count && coeff_count && !value) - { - throw std::invalid_argument("value"); - } -#endif - std::ostringstream result; - bool empty = true; - value += util::mul_safe(coeff_count - 1, coeff_uint64_count); - while (coeff_count--) - { - if (is_zero_uint(value, coeff_uint64_count)) - { - value -= coeff_uint64_count; - continue; - } - if (!empty) - { - result << " + "; - } - result << uint_to_hex_string(value, coeff_uint64_count); - if (coeff_count) - { - result << "x^" << coeff_count; - } - empty = false; - value -= coeff_uint64_count; - } - if (empty) - { - result << "0"; - } - return result.str(); - } - - SEAL_NODISCARD inline std::string poly_to_dec_string( - const std::uint64_t *value, std::size_t coeff_count, - std::size_t coeff_uint64_count, MemoryPool &pool) - { -#ifdef SEAL_DEBUG - if (coeff_uint64_count && coeff_count && !value) - { - throw std::invalid_argument("value"); - } -#endif - std::ostringstream result; - bool empty = true; - value += coeff_count - 1; - while (coeff_count--) - { - if (is_zero_uint(value, coeff_uint64_count)) - { - value -= coeff_uint64_count; - continue; - } - if (!empty) - { - result << " + "; - } - result << uint_to_dec_string(value, coeff_uint64_count, pool); - if (coeff_count) - { - result << "x^" << coeff_count; - } - empty = false; - value -= coeff_uint64_count; - } - if (empty) - { - result << "0"; - } - return result.str(); - } - - SEAL_NODISCARD inline auto allocate_poly(std::size_t coeff_count, - std::size_t coeff_uint64_count, MemoryPool &pool) - { - return allocate_uint( - util::mul_safe(coeff_count, coeff_uint64_count), pool); - } - - inline void set_zero_poly(std::size_t coeff_count, - std::size_t coeff_uint64_count, std::uint64_t* result) - { -#ifdef SEAL_DEBUG - if (!result && coeff_count && coeff_uint64_count) - { - throw std::invalid_argument("result"); - } -#endif - set_zero_uint(util::mul_safe(coeff_count, coeff_uint64_count), result); - } - - SEAL_NODISCARD inline auto allocate_zero_poly( - std::size_t coeff_count, std::size_t coeff_uint64_count, - MemoryPool &pool) - { - return allocate_zero_uint( - util::mul_safe(coeff_count, coeff_uint64_count), pool); - } - - SEAL_NODISCARD inline std::uint64_t *get_poly_coeff( - std::uint64_t *poly, std::size_t coeff_index, - std::size_t coeff_uint64_count) - { -#ifdef SEAL_DEBUG - if (!poly) - { - throw std::invalid_argument("poly"); - } -#endif - return poly + util::mul_safe(coeff_index, coeff_uint64_count); - } - - SEAL_NODISCARD inline const std::uint64_t *get_poly_coeff( - const std::uint64_t *poly, std::size_t coeff_index, - std::size_t coeff_uint64_count) - { -#ifdef SEAL_DEBUG - if (!poly) - { - throw std::invalid_argument("poly"); - } -#endif - return poly + util::mul_safe(coeff_index, coeff_uint64_count); - } - - inline void set_poly_poly(const std::uint64_t *poly, - std::size_t coeff_count, std::size_t coeff_uint64_count, - std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!poly && coeff_count && coeff_uint64_count) - { - throw std::invalid_argument("poly"); - } - if (!result && coeff_count && coeff_uint64_count) - { - throw std::invalid_argument("result"); - } -#endif - set_uint_uint(poly, - util::mul_safe(coeff_count, coeff_uint64_count), result); - } - - SEAL_NODISCARD inline bool is_zero_poly( - const std::uint64_t *poly, std::size_t coeff_count, - std::size_t coeff_uint64_count) - { -#ifdef SEAL_DEBUG - if (!poly && coeff_count && coeff_uint64_count) - { - throw std::invalid_argument("poly"); - } -#endif - return is_zero_uint(poly, - util::mul_safe(coeff_count, coeff_uint64_count)); - } - - SEAL_NODISCARD inline bool is_equal_poly_poly( - const std::uint64_t *operand1, const std::uint64_t *operand2, - std::size_t coeff_count, std::size_t coeff_uint64_count) - { -#ifdef SEAL_DEBUG - if (!operand1 && coeff_count && coeff_uint64_count) - { - throw std::invalid_argument("operand1"); - } - if (!operand2 && coeff_count && coeff_uint64_count) - { - throw std::invalid_argument("operand2"); - } -#endif - return is_equal_uint_uint(operand1, operand2, - util::mul_safe(coeff_count, coeff_uint64_count)); - } - - inline void set_poly_poly(const std::uint64_t *poly, std::size_t poly_coeff_count, - std::size_t poly_coeff_uint64_count, std::size_t result_coeff_count, - std::size_t result_coeff_uint64_count, std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!poly && poly_coeff_count && poly_coeff_uint64_count) - { - throw std::invalid_argument("poly"); - } - if (!result && result_coeff_count && result_coeff_uint64_count) - { - throw std::invalid_argument("result"); - } -#endif - if (!result_coeff_uint64_count || !result_coeff_count) - { - return; - } - - std::size_t min_coeff_count = std::min(poly_coeff_count, result_coeff_count); - for (std::size_t i = 0; i < min_coeff_count; i++, - poly += poly_coeff_uint64_count, result += result_coeff_uint64_count) - { - set_uint_uint(poly, poly_coeff_uint64_count, result_coeff_uint64_count, result); - } - set_zero_uint(util::mul_safe( - result_coeff_count - min_coeff_count, result_coeff_uint64_count), result); - } - - SEAL_NODISCARD inline bool is_one_zero_one_poly( - const std::uint64_t *poly, std::size_t coeff_count, - std::size_t coeff_uint64_count) - { -#ifdef SEAL_DEBUG - if (!poly && coeff_count && coeff_uint64_count) - { - throw std::invalid_argument("poly"); - } -#endif - if (coeff_count == 0 || coeff_uint64_count == 0) - { - return false; - } - if (!is_equal_uint(get_poly_coeff(poly, 0, coeff_uint64_count), - coeff_uint64_count, 1)) - { - return false; - } - if (!is_equal_uint(get_poly_coeff(poly, coeff_count - 1, coeff_uint64_count), - coeff_uint64_count, 1)) - { - return false; - } - if (coeff_count > 2 && - !is_zero_poly(poly + coeff_uint64_count, - coeff_count - 2, coeff_uint64_count)) - { - return false; - } - return true; - } - - SEAL_NODISCARD inline std::size_t get_significant_coeff_count_poly( - const std::uint64_t *poly, std::size_t coeff_count, - std::size_t coeff_uint64_count) - { -#ifdef SEAL_DEBUG - if (!poly && coeff_count && coeff_uint64_count) - { - throw std::invalid_argument("poly"); - } -#endif - if (coeff_count == 0) - { - return 0; - } - - poly += util::mul_safe(coeff_count - 1, coeff_uint64_count); - for (std::size_t i = coeff_count; i; i--) - { - if (!is_zero_uint(poly, coeff_uint64_count)) - { - return i; - } - poly -= coeff_uint64_count; - } - return 0; - } - - SEAL_NODISCARD inline auto duplicate_poly_if_needed( - const std::uint64_t *poly, std::size_t coeff_count, - std::size_t coeff_uint64_count, std::size_t new_coeff_count, - std::size_t new_coeff_uint64_count, bool force, MemoryPool &pool) - { -#ifdef SEAL_DEBUG - if (!poly && coeff_count && coeff_uint64_count) - { - throw std::invalid_argument("poly"); - } -#endif - if (!force && coeff_count >= new_coeff_count && - coeff_uint64_count == new_coeff_uint64_count) - { - return ConstPointer::Aliasing(poly); - } - auto allocation(allocate_poly( - new_coeff_count, new_coeff_uint64_count, pool)); - set_poly_poly(poly, coeff_count, coeff_uint64_count, new_coeff_count, - new_coeff_uint64_count, allocation.get()); - return ConstPointer(std::move(allocation)); - } - - SEAL_NODISCARD inline bool are_poly_coefficients_less_than( - const std::uint64_t *poly, std::size_t coeff_count, - std::size_t coeff_uint64_count, const std::uint64_t *compare, - std::size_t compare_uint64_count) - { -#ifdef SEAL_DEBUG - if (!poly && coeff_count && coeff_uint64_count) - { - throw std::invalid_argument("poly"); - } - if (!compare && compare_uint64_count > 0) - { - throw std::invalid_argument("compare"); - } -#endif - if (coeff_count == 0) - { - return true; - } - if (compare_uint64_count == 0) - { - return false; - } - if (coeff_uint64_count == 0) - { - return true; - } - for (; coeff_count--; poly += coeff_uint64_count) - { - if (compare_uint_uint(poly, coeff_uint64_count, compare, - compare_uint64_count) >= 0) - { - return false; - } - } - return true; - } - - SEAL_NODISCARD inline bool are_poly_coefficients_less_than( - const std::uint64_t *poly, std::size_t coeff_count, - std::uint64_t compare) - { -#ifdef SEAL_DEBUG - if (!poly && coeff_count) - { - throw std::invalid_argument("poly"); - } -#endif - if (coeff_count == 0) - { - return true; - } - for (; coeff_count--; poly++) - { - if (*poly >= compare) - { - return false; - } - } - return true; - } - } -} diff --git a/SEAL/native/src/seal/util/rlwe.cpp b/SEAL/native/src/seal/util/rlwe.cpp deleted file mode 100644 index 088142e..0000000 --- a/SEAL/native/src/seal/util/rlwe.cpp +++ /dev/null @@ -1,298 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "seal/randomtostd.h" -#include "seal/util/rlwe.h" -#include "seal/util/common.h" -#include "seal/util/clipnormal.h" -#include "seal/util/polycore.h" -#include "seal/util/smallntt.h" -#include "seal/util/polyarithsmallmod.h" -#include "seal/util/globals.h" - -using namespace std; - -namespace seal -{ - namespace util - { - void sample_poly_ternary( - uint64_t *poly, - shared_ptr random, - const EncryptionParameters &parms) - { - auto coeff_modulus = parms.coeff_modulus(); - size_t coeff_mod_count = coeff_modulus.size(); - size_t coeff_count = parms.poly_modulus_degree(); - RandomToStandardAdapter engine(random); - uniform_int_distribution dist(-1, 1); - - for (size_t i = 0; i < coeff_count; i++) - { - int rand_index = dist(engine); - if (rand_index == 1) - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - poly[i + j * coeff_count] = 1; - } - } - else if (rand_index == -1) - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - poly[i + j * coeff_count] = coeff_modulus[j].value() - 1; - } - } - else - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - poly[i + j * coeff_count] = 0; - } - } - } - } - - void sample_poly_normal( - uint64_t *poly, - shared_ptr random, - const EncryptionParameters &parms) - { - auto coeff_modulus = parms.coeff_modulus(); - size_t coeff_mod_count = coeff_modulus.size(); - size_t coeff_count = parms.poly_modulus_degree(); - - if (are_close(global_variables::noise_max_deviation, 0.0)) - { - set_zero_poly(coeff_count, coeff_mod_count, poly); - return; - } - - RandomToStandardAdapter engine(random); - ClippedNormalDistribution dist( - 0, global_variables::noise_standard_deviation, - global_variables::noise_max_deviation); - for (size_t i = 0; i < coeff_count; i++) - { - int64_t noise = static_cast(dist(engine)); - if (noise > 0) - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - poly[i + j * coeff_count] = static_cast(noise); - } - } - else if (noise < 0) - { - noise = -noise; - for (size_t j = 0; j < coeff_mod_count; j++) - { - poly[i + j * coeff_count] = coeff_modulus[j].value() - - static_cast(noise); - } - } - else - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - poly[i + j * coeff_count] = 0; - } - } - } - } - - void sample_poly_uniform( - uint64_t *poly, - shared_ptr random, - const EncryptionParameters &parms) - { - // Extract encryption parameters. - auto coeff_modulus = parms.coeff_modulus(); - size_t coeff_mod_count = coeff_modulus.size(); - size_t coeff_count = parms.poly_modulus_degree(); - - // Set up source of randomness that produces 32 bit random things. - RandomToStandardAdapter engine(random); - - // We sample numbers up to 2^63-1 to use barrett_reduce_63 - constexpr uint64_t max_random = - numeric_limits::max() & uint64_t(0x7FFFFFFFFFFFFFFF); - for (size_t j = 0; j < coeff_mod_count; j++) - { - auto &modulus = coeff_modulus[j]; - uint64_t max_multiple = max_random - max_random % modulus.value(); - for (size_t i = 0; i < coeff_count; i++) - { - // This ensures uniform distribution. - uint64_t rand; - do - { - rand = (static_cast(engine()) << 31) | - (static_cast(engine() >> 1)); - } - while (rand >= max_multiple); - poly[i + j * coeff_count] = barrett_reduce_63(rand, modulus); - } - } - } - - void encrypt_zero_asymmetric( - const PublicKey &public_key, - shared_ptr context, - parms_id_type parms_id, - shared_ptr random, - bool is_ntt_form, - Ciphertext &destination, - MemoryPoolHandle pool) - { - if (!pool) - { - throw invalid_argument("pool is uninitialized"); - } - if (public_key.parms_id() != context->key_parms_id()) - { - throw invalid_argument("key_parms_id mismatch"); - } - - auto &context_data = *context->get_context_data(parms_id); - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_mod_count = coeff_modulus.size(); - size_t coeff_count = parms.poly_modulus_degree(); - auto &small_ntt_tables = context_data.small_ntt_tables(); - size_t encrypted_size = public_key.data().size(); - - if (encrypted_size < 2) - { - throw invalid_argument("public_key has less than 2 parts"); - } - - // Make destination have right size and parms_id - // Ciphertext (c_0,c_1, ...) - destination.resize(context, parms_id, encrypted_size); - destination.is_ntt_form() = is_ntt_form; - destination.scale() = 1.0; - - // c[j] = public_key[j] * u + e[j] where e[j] <-- chi, u <-- R_3. - - // Generate u <-- R_3 - auto u(allocate_poly(coeff_count, coeff_mod_count, pool)); - sample_poly_ternary(u.get(), random, parms); - - // c[j] = u * public_key[j] - for (size_t i = 0; i < coeff_mod_count; i++) - { - ntt_negacyclic_harvey( - u.get() + i * coeff_count, - small_ntt_tables[i]); - for (size_t j = 0; j < encrypted_size; j++) - { - dyadic_product_coeffmod( - u.get() + i * coeff_count, - public_key.data().data(j) + i * coeff_count, - coeff_count, - coeff_modulus[i], - destination.data(j) + i * coeff_count); - - // addition with e_0, e_1 is in non-NTT form. - if (!is_ntt_form) - { - inverse_ntt_negacyclic_harvey( - destination.data(j) + i * coeff_count, - small_ntt_tables[i]); - } - } - } - - // Generate e_j <-- chi. - // c[j] = public_key[j] * u + e[j] - for (size_t j = 0; j < encrypted_size; j++) - { - sample_poly_normal(u.get(), random, parms); - for (size_t i = 0; i < coeff_mod_count; i++) - { - // addition with e_0, e_1 is in NTT form. - if (is_ntt_form) - { - ntt_negacyclic_harvey( - u.get() + i * coeff_count, - small_ntt_tables[i]); - } - add_poly_poly_coeffmod( - u.get() + i * coeff_count, - destination.data(j) + i * coeff_count, - coeff_count, - coeff_modulus[i], - destination.data(j) + i * coeff_count); - } - } - } - - void encrypt_zero_symmetric( - const SecretKey &secret_key, - shared_ptr context, - parms_id_type parms_id, - shared_ptr random, - bool is_ntt_form, - Ciphertext &destination, - MemoryPoolHandle pool) - { - if (!pool) - { - throw invalid_argument("pool is uninitialized"); - } - if (secret_key.parms_id() != context->key_parms_id()) - { - throw invalid_argument("key_parms_id mismatch"); - } - auto &context_data = *context->get_context_data(parms_id); - auto &parms = context_data.parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_mod_count = coeff_modulus.size(); - size_t coeff_count = parms.poly_modulus_degree(); - auto &small_ntt_tables = context_data.small_ntt_tables(); - size_t encrypted_size = 2; - - destination.resize(context, parms_id, encrypted_size); - destination.is_ntt_form() = is_ntt_form; - destination.scale() = 1.0; - - // Generate ciphertext: (c[0], c[1]) = ([-(as+e)]_q, a) - - // Sample a uniformly at random - // Set c[1] = a (we sample the NTT form directly) - sample_poly_uniform(destination.data(1), random, parms); - - // Sample e <-- chi - auto noise(allocate_poly(coeff_count, coeff_mod_count, pool)); - sample_poly_normal(noise.get(), random, parms); - - // calculate -(a*s + e) (mod q) and store in c[0] - for (size_t i = 0; i < coeff_mod_count; i++) - { - // Transform the noise e into NTT representation. - ntt_negacyclic_harvey( - noise.get() + i * coeff_count, - small_ntt_tables[i]); - dyadic_product_coeffmod( - secret_key.data().data() + i * coeff_count, - destination.data(1) + i * coeff_count, - coeff_count, - coeff_modulus[i], - destination.data() + i * coeff_count); - add_poly_poly_coeffmod( - noise.get() + i * coeff_count, - destination.data() + i * coeff_count, - coeff_count, - coeff_modulus[i], - destination.data() + i * coeff_count); - negate_poly_coeffmod( - destination.data() + i * coeff_count, - coeff_count, - coeff_modulus[i], - destination.data() + i * coeff_count); - } - } - } -} diff --git a/SEAL/native/src/seal/util/rlwe.h b/SEAL/native/src/seal/util/rlwe.h deleted file mode 100644 index fb9f0d7..0000000 --- a/SEAL/native/src/seal/util/rlwe.h +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include "seal/randomgen.h" -#include "seal/encryptionparams.h" -#include "seal/publickey.h" -#include "seal/secretkey.h" -#include "seal/publickey.h" -#include "seal/ciphertext.h" -#include "seal/context.h" - -namespace seal -{ - namespace util - { - void sample_poly_ternary( - std::uint64_t *poly, - std::shared_ptr random, - const EncryptionParameters &parms); - - void sample_poly_normal( - std::uint64_t *poly, - std::shared_ptr random, - const EncryptionParameters &parms); - - void sample_poly_uniform( - std::uint64_t *poly, - std::shared_ptr random, - const EncryptionParameters &parms); - - void encrypt_zero_asymmetric( - const PublicKey &public_key, - std::shared_ptr context, - parms_id_type parms_id, - std::shared_ptr random, - bool is_ntt_form, - Ciphertext &destination, - MemoryPoolHandle pool); - - void encrypt_zero_symmetric( - const SecretKey &secret_key, - std::shared_ptr context, - parms_id_type parms_id, - std::shared_ptr random, - bool is_ntt_form, - Ciphertext &destination, - MemoryPoolHandle pool); - } -} \ No newline at end of file diff --git a/SEAL/native/src/seal/util/smallntt.cpp b/SEAL/native/src/seal/util/smallntt.cpp deleted file mode 100644 index 43a0110..0000000 --- a/SEAL/native/src/seal/util/smallntt.cpp +++ /dev/null @@ -1,345 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "seal/util/smallntt.h" -#include "seal/util/polyarith.h" -#include "seal/util/uintarith.h" -#include "seal/smallmodulus.h" -#include "seal/util/uintarithsmallmod.h" -#include "seal/util/defines.h" -#include - -using namespace std; - -namespace seal -{ - namespace util - { - SmallNTTTables::SmallNTTTables(int coeff_count_power, - const SmallModulus &modulus, MemoryPoolHandle pool) : - pool_(move(pool)) - { -#ifdef SEAL_DEBUG - if (!pool_) - { - throw invalid_argument("pool is uninitialized"); - } -#endif - if (!generate(coeff_count_power, modulus)) - { - // Generation failed; probably modulus wasn't prime. - // It is necessary to check generated() after creating - // this class. - } - } - - void SmallNTTTables::reset() - { - generated_ = false; - modulus_ = SmallModulus(); - root_ = 0; - root_powers_.release(); - scaled_root_powers_.release(); - inv_root_powers_.release(); - scaled_inv_root_powers_.release(); - inv_root_powers_div_two_.release(); - scaled_inv_root_powers_div_two_.release(); - inv_degree_modulo_ = 0; - coeff_count_power_ = 0; - coeff_count_ = 0; - } - - bool SmallNTTTables::generate(int coeff_count_power, - const SmallModulus &modulus) - { - reset(); - - if ((coeff_count_power < get_power_of_two(SEAL_POLY_MOD_DEGREE_MIN)) || - coeff_count_power > get_power_of_two(SEAL_POLY_MOD_DEGREE_MAX)) - { - throw invalid_argument("coeff_count_power out of range"); - } - - coeff_count_power_ = coeff_count_power; - coeff_count_ = size_t(1) << coeff_count_power_; - - // Allocate memory for the tables - root_powers_ = allocate_uint(coeff_count_, pool_); - inv_root_powers_ = allocate_uint(coeff_count_, pool_); - scaled_root_powers_ = allocate_uint(coeff_count_, pool_); - scaled_inv_root_powers_ = allocate_uint(coeff_count_, pool_); - inv_root_powers_div_two_ = allocate_uint(coeff_count_, pool_); - scaled_inv_root_powers_div_two_ = allocate_uint(coeff_count_, pool_); - modulus_ = modulus; - - // We defer parameter checking to try_minimal_primitive_root(...) - if (!try_minimal_primitive_root(2 * coeff_count_, modulus_, root_)) - { - reset(); - return false; - } - - uint64_t inverse_root; - if (!try_invert_uint_mod(root_, modulus_, inverse_root)) - { - reset(); - return false; - } - - // Populate the tables storing (scaled version of) powers of root - // mod q in bit-scrambled order. - ntt_powers_of_primitive_root(root_, root_powers_.get()); - ntt_scale_powers_of_primitive_root(root_powers_.get(), - scaled_root_powers_.get()); - - // Populate the tables storing (scaled version of) powers of - // (root)^{-1} mod q in bit-scrambled order. - ntt_powers_of_primitive_root(inverse_root, inv_root_powers_.get()); - ntt_scale_powers_of_primitive_root(inv_root_powers_.get(), - scaled_inv_root_powers_.get()); - - // Populate the tables storing (scaled version of ) 2 times - // powers of roots^-1 mod q in bit-scrambled order. - for (size_t i = 0; i < coeff_count_; i++) - { - inv_root_powers_div_two_[i] = - div2_uint_mod(inv_root_powers_[i], modulus_); - } - ntt_scale_powers_of_primitive_root(inv_root_powers_div_two_.get(), - scaled_inv_root_powers_div_two_.get()); - - // Last compute n^(-1) modulo q. - uint64_t degree_uint = static_cast(coeff_count_); - generated_ = try_invert_uint_mod(degree_uint, modulus_, inv_degree_modulo_); - - if (!generated_) - { - reset(); - return false; - } - return true; - } - - void SmallNTTTables::ntt_powers_of_primitive_root(uint64_t root, - uint64_t *destination) const - { - uint64_t *destination_start = destination; - *destination_start = 1; - for (size_t i = 1; i < coeff_count_; i++) - { - uint64_t *next_destination = - destination_start + reverse_bits(i, coeff_count_power_); - *next_destination = - multiply_uint_uint_mod(*destination, root, modulus_); - destination = next_destination; - } - } - - // compute floor ( input * beta /q ), where beta is a 64k power of 2 - // and 0 < q < beta. - void SmallNTTTables::ntt_scale_powers_of_primitive_root( - const uint64_t *input, uint64_t *destination) const - { - for (size_t i = 0; i < coeff_count_; i++, input++, destination++) - { - uint64_t wide_quotient[2]{ 0, 0 }; - uint64_t wide_coeff[2]{ 0, *input }; - divide_uint128_uint64_inplace(wide_coeff, modulus_.value(), wide_quotient); - *destination = wide_quotient[0]; - } - } - - /** - This function computes in-place the negacyclic NTT. The input is - a polynomial a of degree n in R_q, where n is assumed to be a power of - 2 and q is a prime such that q = 1 (mod 2n). - - The output is a vector A such that the following hold: - A[j] = a(psi**(2*bit_reverse(j) + 1)), 0 <= j < n. - - For details, see Michael Naehrig and Patrick Longa. - */ - void ntt_negacyclic_harvey_lazy(uint64_t *operand, - const SmallNTTTables &tables) - { - uint64_t modulus = tables.modulus().value(); - uint64_t two_times_modulus = modulus * 2; - - // Return the NTT in scrambled order - size_t n = size_t(1) << tables.coeff_count_power(); - size_t t = n >> 1; - for (size_t m = 1; m < n; m <<= 1) - { - if (t >= 4) - { - for (size_t i = 0; i < m; i++) - { - size_t j1 = 2 * i * t; - size_t j2 = j1 + t; - const uint64_t W = tables.get_from_root_powers(m + i); - const uint64_t Wprime = tables.get_from_scaled_root_powers(m + i); - - uint64_t *X = operand + j1; - uint64_t *Y = X + t; - uint64_t currX; - unsigned long long Q; - for (size_t j = j1; j < j2; j += 4) - { - currX = *X - (two_times_modulus & static_cast(-static_cast(*X >= two_times_modulus))); - multiply_uint64_hw64(Wprime, *Y, &Q); - Q = *Y * W - Q * modulus; - *X++ = currX + Q; - *Y++ = currX + (two_times_modulus - Q); - - currX = *X - (two_times_modulus & static_cast(-static_cast(*X >= two_times_modulus))); - multiply_uint64_hw64(Wprime, *Y, &Q); - Q = *Y * W - Q * modulus; - *X++ = currX + Q; - *Y++ = currX + (two_times_modulus - Q); - - currX = *X - (two_times_modulus & static_cast(-static_cast(*X >= two_times_modulus))); - multiply_uint64_hw64(Wprime, *Y, &Q); - Q = *Y * W - Q * modulus; - *X++ = currX + Q; - *Y++ = currX + (two_times_modulus - Q); - - currX = *X - (two_times_modulus & static_cast(-static_cast(*X >= two_times_modulus))); - multiply_uint64_hw64(Wprime, *Y, &Q); - Q = *Y * W - Q * modulus; - *X++ = currX + Q; - *Y++ = currX + (two_times_modulus - Q); - } - } - } - else - { - for (size_t i = 0; i < m; i++) - { - size_t j1 = 2 * i * t; - size_t j2 = j1 + t; - const uint64_t W = tables.get_from_root_powers(m + i); - const uint64_t Wprime = tables.get_from_scaled_root_powers(m + i); - - uint64_t *X = operand + j1; - uint64_t *Y = X + t; - uint64_t currX; - unsigned long long Q; - for (size_t j = j1; j < j2; j++) - { - // The Harvey butterfly: assume X, Y in [0, 2p), and return X', Y' in [0, 2p). - // X', Y' = X + WY, X - WY (mod p). - currX = *X - (two_times_modulus & static_cast(-static_cast(*X >= two_times_modulus))); - multiply_uint64_hw64(Wprime, *Y, &Q); - Q = W * *Y - Q * modulus; - *X++ = currX + Q; - *Y++ = currX + (two_times_modulus - Q); - } - } - } - t >>= 1; - } - } - - // Inverse negacyclic NTT using Harvey's butterfly. (See Patrick Longa and Michael Naehrig). - void inverse_ntt_negacyclic_harvey_lazy(uint64_t *operand, const SmallNTTTables &tables) - { - uint64_t modulus = tables.modulus().value(); - uint64_t two_times_modulus = modulus * 2; - - // return the bit-reversed order of NTT. - size_t n = size_t(1) << tables.coeff_count_power(); - size_t t = 1; - - for (size_t m = n; m > 1; m >>= 1) - { - size_t j1 = 0; - size_t h = m >> 1; - if (t >= 4) - { - for (size_t i = 0; i < h; i++) - { - size_t j2 = j1 + t; - // Need the powers of phi^{-1} in bit-reversed order - const uint64_t W = tables.get_from_inv_root_powers_div_two(h + i); - const uint64_t Wprime = tables.get_from_scaled_inv_root_powers_div_two(h + i); - - uint64_t *U = operand + j1; - uint64_t *V = U + t; - uint64_t currU; - uint64_t T; - unsigned long long H; - for (size_t j = j1; j < j2; j += 4) - { - T = two_times_modulus - *V + *U; - currU = *U + *V - (two_times_modulus & static_cast(-static_cast((*U << 1) >= T))); - *U++ = (currU + (modulus & static_cast(-static_cast(T & 1)))) >> 1; - multiply_uint64_hw64(Wprime, T, &H); - *V++ = T * W - H * modulus; - - T = two_times_modulus - *V + *U; - currU = *U + *V - (two_times_modulus & static_cast(-static_cast((*U << 1) >= T))); - *U++ = (currU + (modulus & static_cast(-static_cast(T & 1)))) >> 1; - multiply_uint64_hw64(Wprime, T, &H); - *V++ = T * W - H * modulus; - - T = two_times_modulus - *V + *U; - currU = *U + *V - (two_times_modulus & static_cast(-static_cast((*U << 1) >= T))); - *U++ = (currU + (modulus & static_cast(-static_cast(T & 1)))) >> 1; - multiply_uint64_hw64(Wprime, T, &H); - *V++ = T * W - H * modulus; - - T = two_times_modulus - *V + *U; - currU = *U + *V - (two_times_modulus & static_cast(-static_cast((*U << 1) >= T))); - *U++ = (currU + (modulus & static_cast(-static_cast(T & 1)))) >> 1; - multiply_uint64_hw64(Wprime, T, &H); - *V++ = T * W - H * modulus; - } - j1 += (t << 1); - } - } - else - { - for (size_t i = 0; i < h; i++) - { - size_t j2 = j1 + t; - // Need the powers of phi^{-1} in bit-reversed order - const uint64_t W = tables.get_from_inv_root_powers_div_two(h + i); - const uint64_t Wprime = tables.get_from_scaled_inv_root_powers_div_two(h + i); - - uint64_t *U = operand + j1; - uint64_t *V = U + t; - uint64_t currU; - uint64_t T; - unsigned long long H; - for (size_t j = j1; j < j2; j++) - { - // U = x[i], V = x[i+m] - - // Compute U - V + 2q - T = two_times_modulus - *V + *U; - - // Cleverly check whether currU + currV >= two_times_modulus - currU = *U + *V - (two_times_modulus & static_cast(-static_cast((*U << 1) >= T))); - - // Need to make it so that div2_uint_mod takes values that are > q. - //div2_uint_mod(U, modulusptr, coeff_uint64_count, U); - // We use also the fact that parity of currU is same as parity of T. - // Since our modulus is always so small that currU + masked_modulus < 2^64, - // we never need to worry about wrapping around when adding masked_modulus. - //uint64_t masked_modulus = modulus & static_cast(-static_cast(T & 1)); - //uint64_t carry = add_uint64(currU, masked_modulus, 0, &currU); - //currU += modulus & static_cast(-static_cast(T & 1)); - *U++ = (currU + (modulus & static_cast(-static_cast(T & 1)))) >> 1; - - multiply_uint64_hw64(Wprime, T, &H); - // effectively, the next two multiply perform multiply modulo beta = 2**wordsize. - *V++ = W * T - H * modulus; - } - j1 += (t << 1); - } - } - t <<= 1; - } - } - } -} diff --git a/SEAL/native/src/seal/util/smallntt.h b/SEAL/native/src/seal/util/smallntt.h deleted file mode 100644 index 422672c..0000000 --- a/SEAL/native/src/seal/util/smallntt.h +++ /dev/null @@ -1,279 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include "seal/util/pointer.h" -#include "seal/memorymanager.h" -#include "seal/smallmodulus.h" - -namespace seal -{ - namespace util - { - class SmallNTTTables - { - public: - SmallNTTTables(MemoryPoolHandle pool = MemoryManager::GetPool()) : - pool_(std::move(pool)) - { -#ifdef SEAL_DEBUG - if (!pool_) - { - throw std::invalid_argument("pool is uninitialized"); - } -#endif - } - - SmallNTTTables(int coeff_count_power, const SmallModulus &modulus, - MemoryPoolHandle pool = MemoryManager::GetPool()); - - SEAL_NODISCARD inline bool is_generated() const - { - return generated_; - } - - bool generate(int coeff_count_power, const SmallModulus &modulus); - - void reset(); - - SEAL_NODISCARD inline std::uint64_t get_root() const - { -#ifdef SEAL_DEBUG - if (!generated_) - { - throw std::logic_error("tables are not generated"); - } -#endif - return root_; - } - - SEAL_NODISCARD inline auto get_from_root_powers( - std::size_t index) const -> std::uint64_t - { -#ifdef SEAL_DEBUG - if (index >= coeff_count_) - { - throw std::out_of_range("index"); - } - if (!generated_) - { - throw std::logic_error("tables are not generated"); - } -#endif - return root_powers_[index]; - } - - SEAL_NODISCARD inline auto get_from_scaled_root_powers( - std::size_t index) const -> std::uint64_t - { -#ifdef SEAL_DEBUG - if (index >= coeff_count_) - { - throw std::out_of_range("index"); - } - if (!generated_) - { - throw std::logic_error("tables are not generated"); - } -#endif - return scaled_root_powers_[index]; - } - - SEAL_NODISCARD inline auto get_from_inv_root_powers( - std::size_t index) const -> std::uint64_t - { -#ifdef SEAL_DEBUG - if (index >= coeff_count_) - { - throw std::out_of_range("index"); - } - if (!generated_) - { - throw std::logic_error("tables are not generated"); - } -#endif - return inv_root_powers_[index]; - } - - SEAL_NODISCARD inline auto get_from_scaled_inv_root_powers( - std::size_t index) const -> std::uint64_t - { -#ifdef SEAL_DEBUG - if (index >= coeff_count_) - { - throw std::out_of_range("index"); - } - if (!generated_) - { - throw std::logic_error("tables are not generated"); - } -#endif - return scaled_inv_root_powers_[index]; - } - - SEAL_NODISCARD inline auto get_from_inv_root_powers_div_two( - std::size_t index) const -> std::uint64_t - { -#ifdef SEAL_DEBUG - if (index >= coeff_count_) - { - throw std::out_of_range("index"); - } - if (!generated_) - { - throw std::logic_error("tables are not generated"); - } -#endif - return inv_root_powers_div_two_[index]; - } - - SEAL_NODISCARD inline auto get_from_scaled_inv_root_powers_div_two( - std::size_t index) const -> std::uint64_t - { -#ifdef SEAL_DEBUG - if (index >= coeff_count_) - { - throw std::out_of_range("index"); - } - if (!generated_) - { - throw std::logic_error("tables are not generated"); - } -#endif - return scaled_inv_root_powers_div_two_[index]; - } - - SEAL_NODISCARD inline auto get_inv_degree_modulo() const - -> const std::uint64_t* - { -#ifdef SEAL_DEBUG - if (!generated_) - { - throw std::logic_error("tables are not generated"); - } -#endif - return &inv_degree_modulo_; - } - - SEAL_NODISCARD inline const SmallModulus &modulus() const - { - return modulus_; - } - - SEAL_NODISCARD inline int coeff_count_power() const - { - return coeff_count_power_; - } - - SEAL_NODISCARD inline std::size_t coeff_count() const - { - return coeff_count_; - } - - private: - SmallNTTTables(const SmallNTTTables ©) = delete; - - SmallNTTTables(SmallNTTTables &&source) = delete; - - SmallNTTTables &operator =(const SmallNTTTables &assign) = delete; - - SmallNTTTables &operator =(SmallNTTTables &&assign) = delete; - - // Computed bit-scrambled vector of first 1 << coeff_count_power powers - // of a primitive root. - void ntt_powers_of_primitive_root(std::uint64_t root, - std::uint64_t *destination) const; - - // Scales the elements of a vector returned by powers_of_primitive_root(...) - // by word_size/modulus and rounds down. - void ntt_scale_powers_of_primitive_root(const std::uint64_t *input, - std::uint64_t *destination) const; - - MemoryPoolHandle pool_; - - bool generated_ = false; - - std::uint64_t root_ = 0; - - // Size coeff_count_ - Pointer root_powers_; - - // Size coeff_count_ - Pointer scaled_root_powers_; - - // Size coeff_count_ - Pointer inv_root_powers_div_two_; - - // Size coeff_count_ - Pointer scaled_inv_root_powers_div_two_; - - int coeff_count_power_ = 0; - - std::size_t coeff_count_ = 0; - - SmallModulus modulus_; - - // Size coeff_count_ - Pointer inv_root_powers_; - - // Size coeff_count_ - Pointer scaled_inv_root_powers_; - - std::uint64_t inv_degree_modulo_ = 0; - - }; - - void ntt_negacyclic_harvey_lazy(std::uint64_t *operand, - const SmallNTTTables &tables); - - inline void ntt_negacyclic_harvey(std::uint64_t *operand, - const SmallNTTTables &tables) - { - ntt_negacyclic_harvey_lazy(operand, tables); - - // Finally maybe we need to reduce every coefficient modulo q, but we - // know that they are in the range [0, 4q). - // Since word size is controlled this is fast. - std::uint64_t modulus = tables.modulus().value(); - std::uint64_t two_times_modulus = modulus * 2; - std::size_t n = std::size_t(1) << tables.coeff_count_power(); - - for (; n--; operand++) - { - if (*operand >= two_times_modulus) - { - *operand -= two_times_modulus; - } - if (*operand >= modulus) - { - *operand -= modulus; - } - } - } - - void inverse_ntt_negacyclic_harvey_lazy(std::uint64_t *operand, - const SmallNTTTables &tables); - - inline void inverse_ntt_negacyclic_harvey(std::uint64_t *operand, - const SmallNTTTables &tables) - { - inverse_ntt_negacyclic_harvey_lazy(operand, tables); - - std::uint64_t modulus = tables.modulus().value(); - std::size_t n = std::size_t(1) << tables.coeff_count_power(); - - // Final adjustments; compute a[j] = a[j] * n^{-1} mod q. - // We incorporated the final adjustment in the butterfly. Only need - // to reduce here. - for (; n--; operand++) - { - if (*operand >= modulus) - { - *operand -= modulus; - } - } - } - } -} diff --git a/SEAL/native/src/seal/util/uintarith.cpp b/SEAL/native/src/seal/util/uintarith.cpp deleted file mode 100644 index aacaf52..0000000 --- a/SEAL/native/src/seal/util/uintarith.cpp +++ /dev/null @@ -1,725 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "seal/util/uintcore.h" -#include "seal/util/uintarith.h" -#include "seal/util/common.h" -#include -#include -#include - -using namespace std; - -namespace seal -{ - namespace util - { - void multiply_uint_uint(const uint64_t *operand1, - size_t operand1_uint64_count, const uint64_t *operand2, - size_t operand2_uint64_count, size_t result_uint64_count, - uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!operand1 && operand1_uint64_count > 0) - { - throw invalid_argument("operand1"); - } - if (!operand2 && operand2_uint64_count > 0) - { - throw invalid_argument("operand2"); - } - if (!result_uint64_count) - { - throw invalid_argument("result_uint64_count"); - } - if (!result) - { - throw invalid_argument("result"); - } - if (result != nullptr && (operand1 == result || operand2 == result)) - { - throw invalid_argument("result cannot point to the same value as operand1 or operand2"); - } -#endif - // Handle fast cases. - if (!operand1_uint64_count || !operand2_uint64_count) - { - // If either operand is 0, then result is 0. - set_zero_uint(result_uint64_count, result); - return; - } - if (result_uint64_count == 1) - { - *result = *operand1 * *operand2; - return; - } - - // In some cases these improve performance. - operand1_uint64_count = get_significant_uint64_count_uint( - operand1, operand1_uint64_count); - operand2_uint64_count = get_significant_uint64_count_uint( - operand2, operand2_uint64_count); - - // More fast cases - if (operand1_uint64_count == 1) - { - multiply_uint_uint64(operand2, operand2_uint64_count, - *operand1, result_uint64_count, result); - return; - } - if (operand2_uint64_count == 1) - { - multiply_uint_uint64(operand1, operand1_uint64_count, - *operand2, result_uint64_count, result); - return; - } - - // Clear out result. - set_zero_uint(result_uint64_count, result); - - // Multiply operand1 and operand2. - size_t operand1_index_max = min(operand1_uint64_count, - result_uint64_count); - for (size_t operand1_index = 0; - operand1_index < operand1_index_max; operand1_index++) - { - const uint64_t *inner_operand2 = operand2; - uint64_t *inner_result = result++; - uint64_t carry = 0; - size_t operand2_index = 0; - size_t operand2_index_max = min(operand2_uint64_count, - result_uint64_count - operand1_index); - for (; operand2_index < operand2_index_max; operand2_index++) - { - // Perform 64-bit multiplication of operand1 and operand2 - unsigned long long temp_result[2]; - multiply_uint64(*operand1, *inner_operand2++, temp_result); - carry = temp_result[1] + add_uint64(temp_result[0], carry, 0, temp_result); - unsigned long long temp; - carry += add_uint64(*inner_result, temp_result[0], 0, &temp); - *inner_result++ = temp; - } - - // Write carry if there is room in result - if (operand1_index + operand2_index_max < result_uint64_count) - { - *inner_result = carry; - } - - operand1++; - } - } - - void multiply_uint_uint64(const uint64_t *operand1, - size_t operand1_uint64_count, uint64_t operand2, - size_t result_uint64_count, uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!operand1 && operand1_uint64_count > 0) - { - throw invalid_argument("operand1"); - } - if (!result_uint64_count) - { - throw invalid_argument("result_uint64_count"); - } - if (!result) - { - throw invalid_argument("result"); - } - if (result != nullptr && operand1 == result) - { - throw invalid_argument("result cannot point to the same value as operand1"); - } -#endif - // Handle fast cases. - if (!operand1_uint64_count || !operand2) - { - // If either operand is 0, then result is 0. - set_zero_uint(result_uint64_count, result); - return; - } - if (result_uint64_count == 1) - { - *result = *operand1 * operand2; - return; - } - - // More fast cases - //if (result_uint64_count == 2 && operand1_uint64_count > 1) - //{ - // unsigned long long temp_result; - // multiply_uint64(*operand1, operand2, &temp_result); - // *result = temp_result; - // *(result + 1) += *(operand1 + 1) * operand2; - // return; - //} - - // Clear out result. - set_zero_uint(result_uint64_count, result); - - // Multiply operand1 and operand2. - unsigned long long carry = 0; - size_t operand1_index_max = min(operand1_uint64_count, - result_uint64_count); - for (size_t operand1_index = 0; - operand1_index < operand1_index_max; operand1_index++) - { - unsigned long long temp_result[2]; - multiply_uint64(*operand1++, operand2, temp_result); - unsigned long long temp; - carry = temp_result[1] + add_uint64(temp_result[0], carry, 0, &temp); - *result++ = temp; - } - - // Write carry if there is room in result - if (operand1_index_max < result_uint64_count) - { - *result = carry; - } - } - - void divide_uint_uint_inplace(uint64_t *numerator, - const uint64_t *denominator, size_t uint64_count, - uint64_t *quotient, MemoryPool &pool) - { -#ifdef SEAL_DEBUG - if (!numerator && uint64_count > 0) - { - throw invalid_argument("numerator"); - } - if (!denominator && uint64_count > 0) - { - throw invalid_argument("denominator"); - } - if (!quotient && uint64_count > 0) - { - throw invalid_argument("quotient"); - } - if (is_zero_uint(denominator, uint64_count) && uint64_count > 0) - { - throw invalid_argument("denominator"); - } - if (quotient && (numerator == quotient || denominator == quotient)) - { - throw invalid_argument("quotient cannot point to same value as numerator or denominator"); - } -#endif - if (!uint64_count) - { - return; - } - - // Clear quotient. Set it to zero. - set_zero_uint(uint64_count, quotient); - - // Determine significant bits in numerator and denominator. - int numerator_bits = - get_significant_bit_count_uint(numerator, uint64_count); - int denominator_bits = - get_significant_bit_count_uint(denominator, uint64_count); - - // If numerator has fewer bits than denominator, then done. - if (numerator_bits < denominator_bits) - { - return; - } - - // Only perform computation up to last non-zero uint64s. - uint64_count = safe_cast( - divide_round_up(numerator_bits, bits_per_uint64)); - - // Handle fast case. - if (uint64_count == 1) - { - *quotient = *numerator / *denominator; - *numerator -= *quotient * *denominator; - return; - } - - auto alloc_anchor(allocate_uint(uint64_count << 1, pool)); - - // Create temporary space to store mutable copy of denominator. - uint64_t *shifted_denominator = alloc_anchor.get(); - - // Create temporary space to store difference calculation. - uint64_t *difference = shifted_denominator + uint64_count; - - // Shift denominator to bring MSB in alignment with MSB of numerator. - int denominator_shift = numerator_bits - denominator_bits; - left_shift_uint(denominator, denominator_shift, uint64_count, - shifted_denominator); - denominator_bits += denominator_shift; - - // Perform bit-wise division algorithm. - int remaining_shifts = denominator_shift; - while (numerator_bits == denominator_bits) - { - // NOTE: MSBs of numerator and denominator are aligned. - - // Even though MSB of numerator and denominator are aligned, - // still possible numerator < shifted_denominator. - if (sub_uint_uint(numerator, shifted_denominator, - uint64_count, difference)) - { - // numerator < shifted_denominator and MSBs are aligned, - // so current quotient bit is zero and next one is definitely one. - if (remaining_shifts == 0) - { - // No shifts remain and numerator < denominator so done. - break; - } - - // Effectively shift numerator left by 1 by instead adding - // numerator to difference (to prevent overflow in numerator). - add_uint_uint(difference, numerator, uint64_count, difference); - - // Adjust quotient and remaining shifts as a result of - // shifting numerator. - left_shift_uint(quotient, 1, uint64_count, quotient); - remaining_shifts--; - } - // Difference is the new numerator with denominator subtracted. - - // Update quotient to reflect subtraction. - quotient[0] |= 1; - - // Determine amount to shift numerator to bring MSB in alignment - // with denominator. - numerator_bits = get_significant_bit_count_uint(difference, uint64_count); - int numerator_shift = denominator_bits - numerator_bits; - if (numerator_shift > remaining_shifts) - { - // Clip the maximum shift to determine only the integer - // (as opposed to fractional) bits. - numerator_shift = remaining_shifts; - } - - // Shift and update numerator. - if (numerator_bits > 0) - { - left_shift_uint(difference, numerator_shift, uint64_count, numerator); - numerator_bits += numerator_shift; - } - else - { - // Difference is zero so no need to shift, just set to zero. - set_zero_uint(uint64_count, numerator); - } - - // Adjust quotient and remaining shifts as a result of shifting numerator. - left_shift_uint(quotient, numerator_shift, uint64_count, quotient); - remaining_shifts -= numerator_shift; - } - - // Correct numerator (which is also the remainder) for shifting of - // denominator, unless it is just zero. - if (numerator_bits > 0) - { - right_shift_uint(numerator, denominator_shift, uint64_count, numerator); - } - } - - void divide_uint128_uint64_inplace_generic(uint64_t *numerator, - uint64_t denominator, uint64_t *quotient) - { -#ifdef SEAL_DEBUG - if (!numerator) - { - throw invalid_argument("numerator"); - } - if (denominator == 0) - { - throw invalid_argument("denominator"); - } - if (!quotient) - { - throw invalid_argument("quotient"); - } - if (numerator == quotient) - { - throw invalid_argument("quotient cannot point to same value as numerator"); - } -#endif - // We expect 129-bit input - constexpr size_t uint64_count = 2; - - // Clear quotient. Set it to zero. - quotient[0] = 0; - quotient[1] = 0; - - // Determine significant bits in numerator and denominator. - int numerator_bits = get_significant_bit_count_uint(numerator, uint64_count); - int denominator_bits = get_significant_bit_count(denominator); - - // If numerator has fewer bits than denominator, then done. - if (numerator_bits < denominator_bits) - { - return; - } - - // Create temporary space to store mutable copy of denominator. - uint64_t shifted_denominator[uint64_count]{ denominator, 0 }; - - // Create temporary space to store difference calculation. - uint64_t difference[uint64_count]{ 0, 0 }; - - // Shift denominator to bring MSB in alignment with MSB of numerator. - int denominator_shift = numerator_bits - denominator_bits; - - left_shift_uint128(shifted_denominator, denominator_shift, shifted_denominator); - denominator_bits += denominator_shift; - - // Perform bit-wise division algorithm. - int remaining_shifts = denominator_shift; - while (numerator_bits == denominator_bits) - { - // NOTE: MSBs of numerator and denominator are aligned. - - // Even though MSB of numerator and denominator are aligned, - // still possible numerator < shifted_denominator. - if (sub_uint_uint(numerator, shifted_denominator, uint64_count, difference)) - { - // numerator < shifted_denominator and MSBs are aligned, - // so current quotient bit is zero and next one is definitely one. - if (remaining_shifts == 0) - { - // No shifts remain and numerator < denominator so done. - break; - } - - // Effectively shift numerator left by 1 by instead adding - // numerator to difference (to prevent overflow in numerator). - add_uint_uint(difference, numerator, uint64_count, difference); - - // Adjust quotient and remaining shifts as a result of shifting numerator. - quotient[1] = (quotient[1] << 1) | (quotient[0] >> (bits_per_uint64 - 1)); - quotient[0] <<= 1; - remaining_shifts--; - } - // Difference is the new numerator with denominator subtracted. - - // Determine amount to shift numerator to bring MSB in alignment - // with denominator. - numerator_bits = get_significant_bit_count_uint(difference, uint64_count); - - // Clip the maximum shift to determine only the integer - // (as opposed to fractional) bits. - int numerator_shift = min(denominator_bits - numerator_bits, remaining_shifts); - - // Shift and update numerator. - // This may be faster; first set to zero and then update if needed - - // Difference is zero so no need to shift, just set to zero. - numerator[0] = 0; - numerator[1] = 0; - - if (numerator_bits > 0) - { - left_shift_uint128(difference, numerator_shift, numerator); - numerator_bits += numerator_shift; - } - - // Update quotient to reflect subtraction. - quotient[0] |= 1; - - // Adjust quotient and remaining shifts as a result of shifting numerator. - left_shift_uint128(quotient, numerator_shift, quotient); - remaining_shifts -= numerator_shift; - } - - // Correct numerator (which is also the remainder) for shifting of - // denominator, unless it is just zero. - if (numerator_bits > 0) - { - right_shift_uint128(numerator, denominator_shift, numerator); - } - } - - void divide_uint192_uint64_inplace(uint64_t *numerator, - uint64_t denominator, uint64_t *quotient) - { -#ifdef SEAL_DEBUG - if (!numerator) - { - throw invalid_argument("numerator"); - } - if (denominator == 0) - { - throw invalid_argument("denominator"); - } - if (!quotient) - { - throw invalid_argument("quotient"); - } - if (numerator == quotient) - { - throw invalid_argument("quotient cannot point to same value as numerator"); - } -#endif - // We expect 192-bit input - size_t uint64_count = 3; - - // Clear quotient. Set it to zero. - quotient[0] = 0; - quotient[1] = 0; - quotient[2] = 0; - - // Determine significant bits in numerator and denominator. - int numerator_bits = get_significant_bit_count_uint(numerator, uint64_count); - int denominator_bits = get_significant_bit_count(denominator); - - // If numerator has fewer bits than denominator, then done. - if (numerator_bits < denominator_bits) - { - return; - } - - // Only perform computation up to last non-zero uint64s. - uint64_count = safe_cast( - divide_round_up(numerator_bits, bits_per_uint64)); - - // Handle fast case. - if (uint64_count == 1) - { - *quotient = *numerator / denominator; - *numerator -= *quotient * denominator; - return; - } - - // Create temporary space to store mutable copy of denominator. - vector shifted_denominator(uint64_count, 0); - shifted_denominator[0] = denominator; - - // Create temporary space to store difference calculation. - vector difference(uint64_count); - - // Shift denominator to bring MSB in alignment with MSB of numerator. - int denominator_shift = numerator_bits - denominator_bits; - - left_shift_uint192(shifted_denominator.data(), denominator_shift, - shifted_denominator.data()); - denominator_bits += denominator_shift; - - // Perform bit-wise division algorithm. - int remaining_shifts = denominator_shift; - while (numerator_bits == denominator_bits) - { - // NOTE: MSBs of numerator and denominator are aligned. - - // Even though MSB of numerator and denominator are aligned, - // still possible numerator < shifted_denominator. - if (sub_uint_uint(numerator, shifted_denominator.data(), - uint64_count, difference.data())) - { - // numerator < shifted_denominator and MSBs are aligned, - // so current quotient bit is zero and next one is definitely one. - if (remaining_shifts == 0) - { - // No shifts remain and numerator < denominator so done. - break; - } - - // Effectively shift numerator left by 1 by instead adding - // numerator to difference (to prevent overflow in numerator). - add_uint_uint(difference.data(), numerator, uint64_count, difference.data()); - - // Adjust quotient and remaining shifts as a result of shifting numerator. - left_shift_uint192(quotient, 1, quotient); - remaining_shifts--; - } - // Difference is the new numerator with denominator subtracted. - - // Update quotient to reflect subtraction. - quotient[0] |= 1; - - // Determine amount to shift numerator to bring MSB in alignment with denominator. - numerator_bits = get_significant_bit_count_uint(difference.data(), uint64_count); - int numerator_shift = denominator_bits - numerator_bits; - if (numerator_shift > remaining_shifts) - { - // Clip the maximum shift to determine only the integer - // (as opposed to fractional) bits. - numerator_shift = remaining_shifts; - } - - // Shift and update numerator. - if (numerator_bits > 0) - { - left_shift_uint192(difference.data(), numerator_shift, numerator); - numerator_bits += numerator_shift; - } - else - { - // Difference is zero so no need to shift, just set to zero. - set_zero_uint(uint64_count, numerator); - } - - // Adjust quotient and remaining shifts as a result of shifting numerator. - left_shift_uint192(quotient, numerator_shift, quotient); - remaining_shifts -= numerator_shift; - } - - // Correct numerator (which is also the remainder) for shifting of - // denominator, unless it is just zero. - if (numerator_bits > 0) - { - right_shift_uint192(numerator, denominator_shift, numerator); - } - } - - void exponentiate_uint(const uint64_t *operand, - size_t operand_uint64_count, const uint64_t *exponent, - size_t exponent_uint64_count, size_t result_uint64_count, - uint64_t *result, MemoryPool &pool) - { -#ifdef SEAL_DEBUG - if (!operand) - { - throw invalid_argument("operand"); - } - if (!operand_uint64_count) - { - throw invalid_argument("operand_uint64_count"); - } - if (!exponent) - { - throw invalid_argument("exponent"); - } - if (!exponent_uint64_count) - { - throw invalid_argument("exponent_uint64_count"); - } - if (!result) - { - throw invalid_argument("result"); - } - if (!result_uint64_count) - { - throw invalid_argument("result_uint64_count"); - } -#endif - // Fast cases - if (is_zero_uint(exponent, exponent_uint64_count)) - { - set_uint(1, result_uint64_count, result); - return; - } - if (is_equal_uint(exponent, exponent_uint64_count, 1)) - { - set_uint_uint(operand, operand_uint64_count, result_uint64_count, result); - return; - } - - // Need to make a copy of exponent - auto exponent_copy(allocate_uint(exponent_uint64_count, pool)); - set_uint_uint(exponent, exponent_uint64_count, exponent_copy.get()); - - // Perform binary exponentiation. - auto big_alloc(allocate_uint( - result_uint64_count + result_uint64_count + result_uint64_count, pool)); - - uint64_t *powerptr = big_alloc.get(); - uint64_t *productptr = powerptr + result_uint64_count; - uint64_t *intermediateptr = productptr + result_uint64_count; - - set_uint_uint(operand, operand_uint64_count, result_uint64_count, powerptr); - set_uint(1, result_uint64_count, intermediateptr); - - // Initially: power = operand and intermediate = 1, product is not initialized. - while (true) - { - if ((*exponent_copy.get() % 2) == 1) - { - multiply_truncate_uint_uint(powerptr, intermediateptr, - result_uint64_count, productptr); - swap(productptr, intermediateptr); - } - right_shift_uint(exponent_copy.get(), 1, exponent_uint64_count, - exponent_copy.get()); - if (is_zero_uint(exponent_copy.get(), exponent_uint64_count)) - { - break; - } - multiply_truncate_uint_uint(powerptr, powerptr, result_uint64_count, - productptr); - swap(productptr, powerptr); - } - set_uint_uint(intermediateptr, result_uint64_count, result); - } - - uint64_t exponentiate_uint64_safe(uint64_t operand, uint64_t exponent) - { - // Fast cases - if (exponent == 0) - { - return 1; - } - if (exponent == 1) - { - return operand; - } - - // Perform binary exponentiation. - uint64_t power = operand; - uint64_t product = 0; - uint64_t intermediate = 1; - - // Initially: power = operand and intermediate = 1, product irrelevant. - while (true) - { - if (exponent & 1) - { - product = mul_safe(power, intermediate); - swap(product, intermediate); - } - exponent >>= 1; - if (exponent == 0) - { - break; - } - product = mul_safe(power, power); - swap(product, power); - } - - return intermediate; - } - - uint64_t exponentiate_uint64(uint64_t operand, uint64_t exponent) - { - // Fast cases - if (exponent == 0) - { - return 1; - } - if (exponent == 1) - { - return operand; - } - - // Perform binary exponentiation. - uint64_t power = operand; - uint64_t product = 0; - uint64_t intermediate = 1; - - // Initially: power = operand and intermediate = 1, product irrelevant. - while (true) - { - if (exponent & 1) - { - product = power * intermediate; - swap(product, intermediate); - } - exponent >>= 1; - if (exponent == 0) - { - break; - } - product = power * power; - swap(product, power); - } - - return intermediate; - } - } -} diff --git a/SEAL/native/src/seal/util/uintarith.h b/SEAL/native/src/seal/util/uintarith.h deleted file mode 100644 index ee93cdb..0000000 --- a/SEAL/native/src/seal/util/uintarith.h +++ /dev/null @@ -1,963 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include -#include -#include "seal/util/common.h" -#include "seal/util/uintcore.h" -#include "seal/util/pointer.h" -#include "seal/util/defines.h" - -namespace seal -{ - namespace util - { - template>> - SEAL_NODISCARD inline unsigned char add_uint64_generic( - T operand1, S operand2, unsigned char carry, - unsigned long long *result) - { -#ifdef SEAL_DEBUG - if (!result) - { - throw std::invalid_argument("result cannot be null"); - } -#endif - operand1 += operand2; - *result = operand1 + carry; - return (operand1 < operand2) || (~operand1 < carry); - } - - template>> - SEAL_NODISCARD inline unsigned char add_uint64( - T operand1, S operand2, unsigned char carry, - unsigned long long *result) - { - return SEAL_ADD_CARRY_UINT64(operand1, operand2, carry, result); - } - - template>> - SEAL_NODISCARD inline unsigned char add_uint64( - T operand1, S operand2, R *result) - { - *result = operand1 + operand2; - return static_cast(*result < operand1); - } - - inline unsigned char add_uint_uint( - const std::uint64_t *operand1, std::size_t operand1_uint64_count, - const std::uint64_t *operand2, std::size_t operand2_uint64_count, - unsigned char carry, - std::size_t result_uint64_count, std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!operand1_uint64_count) - { - throw std::invalid_argument("operand1_uint64_count"); - } - if (!operand2_uint64_count) - { - throw std::invalid_argument("operand2_uint64_count"); - } - if (!result_uint64_count) - { - throw std::invalid_argument("result_uint64_count"); - } - if (!operand1) - { - throw std::invalid_argument("operand1"); - } - if (!operand2) - { - throw std::invalid_argument("operand2"); - } - if (!result) - { - throw std::invalid_argument("result"); - } -#endif - for (std::size_t i = 0; i < result_uint64_count; i++) - { - unsigned long long temp_result; - carry = add_uint64( - (i < operand1_uint64_count) ? *operand1++ : 0, - (i < operand2_uint64_count) ? *operand2++ : 0, - carry, &temp_result); - *result++ = temp_result; - } - return carry; - } - - inline unsigned char add_uint_uint( - const std::uint64_t *operand1, const std::uint64_t *operand2, - std::size_t uint64_count, std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!uint64_count) - { - throw std::invalid_argument("uint64_count"); - } - if (!operand1) - { - throw std::invalid_argument("operand1"); - } - if (!operand2) - { - throw std::invalid_argument("operand2"); - } - if (!result) - { - throw std::invalid_argument("result"); - } -#endif - // Unroll first iteration of loop. We assume uint64_count > 0. - unsigned char carry = add_uint64(*operand1++, *operand2++, result++); - - // Do the rest - for(; --uint64_count; operand1++, operand2++, result++) - { - unsigned long long temp_result; - carry = add_uint64(*operand1, *operand2, carry, &temp_result); - *result = temp_result; - } - return carry; - } - - inline unsigned char add_uint_uint64( - const std::uint64_t *operand1, std::uint64_t operand2, - std::size_t uint64_count, std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!uint64_count) - { - throw std::invalid_argument("uint64_count"); - } - if (!operand1) - { - throw std::invalid_argument("operand1"); - } - if (!result) - { - throw std::invalid_argument("result"); - } -#endif - // Unroll first iteration of loop. We assume uint64_count > 0. - unsigned char carry = add_uint64(*operand1++, operand2, result++); - - // Do the rest - for(; --uint64_count; operand1++, result++) - { - unsigned long long temp_result; - carry = add_uint64(*operand1, std::uint64_t(0), carry, &temp_result); - *result = temp_result; - } - return carry; - } - - template>> - SEAL_NODISCARD inline unsigned char sub_uint64_generic( - T operand1, S operand2, - unsigned char borrow, unsigned long long *result) - { -#ifdef SEAL_DEBUG - if (!result) - { - throw std::invalid_argument("result cannot be null"); - } -#endif - auto diff = operand1 - operand2; - *result = diff - (borrow != 0); - return (diff > operand1) || (diff < borrow); - } - - template>> - SEAL_NODISCARD inline unsigned char sub_uint64( - T operand1, S operand2, - unsigned char borrow, unsigned long long *result) - { - return SEAL_SUB_BORROW_UINT64(operand1, operand2, borrow, result); - } - - template>> - SEAL_NODISCARD inline unsigned char sub_uint64( - T operand1, S operand2, R *result) - { - *result = operand1 - operand2; - return static_cast(operand2 > operand1); - } - - inline unsigned char sub_uint_uint( - const std::uint64_t *operand1, std::size_t operand1_uint64_count, - const std::uint64_t *operand2, std::size_t operand2_uint64_count, - unsigned char borrow, - std::size_t result_uint64_count, std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!result_uint64_count) - { - throw std::invalid_argument("result_uint64_count"); - } - if (!result) - { - throw std::invalid_argument("result"); - } -#endif - for (std::size_t i = 0; i < result_uint64_count; - i++, operand1++, operand2++, result++) - { - unsigned long long temp_result; - borrow = sub_uint64((i < operand1_uint64_count) ? *operand1 : 0, - (i < operand2_uint64_count) ? *operand2 : 0, borrow, &temp_result); - *result = temp_result; - } - return borrow; - } - - inline unsigned char sub_uint_uint( - const std::uint64_t *operand1, const std::uint64_t *operand2, - std::size_t uint64_count, std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!uint64_count) - { - throw std::invalid_argument("uint64_count"); - } - if (!operand1) - { - throw std::invalid_argument("operand1"); - } - if (!operand2) - { - throw std::invalid_argument("operand2"); - } - if (!result) - { - throw std::invalid_argument("result"); - } -#endif - // Unroll first iteration of loop. We assume uint64_count > 0. - unsigned char borrow = sub_uint64(*operand1++, *operand2++, result++); - - // Do the rest - for(; --uint64_count; operand1++, operand2++, result++) - { - unsigned long long temp_result; - borrow = sub_uint64(*operand1, *operand2, borrow, &temp_result); - *result = temp_result; - } - return borrow; - } - - inline unsigned char sub_uint_uint64( - const std::uint64_t *operand1, std::uint64_t operand2, - std::size_t uint64_count, std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!uint64_count) - { - throw std::invalid_argument("uint64_count"); - } - if (!operand1) - { - throw std::invalid_argument("operand1"); - } - if (!result) - { - throw std::invalid_argument("result"); - } -#endif - // Unroll first iteration of loop. We assume uint64_count > 0. - unsigned char borrow = sub_uint64(*operand1++, operand2, result++); - - // Do the rest - for(; --uint64_count; operand1++, operand2++, result++) - { - unsigned long long temp_result; - borrow = sub_uint64(*operand1, std::uint64_t(0), borrow, &temp_result); - *result = temp_result; - } - return borrow; - } - - inline unsigned char increment_uint( - const std::uint64_t *operand, std::size_t uint64_count, - std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!operand) - { - throw std::invalid_argument("operand"); - } - if (!uint64_count) - { - throw std::invalid_argument("uint64_count"); - } - if (!result) - { - throw std::invalid_argument("result"); - } -#endif - return add_uint_uint64(operand, 1, uint64_count, result); - } - - inline unsigned char decrement_uint( - const std::uint64_t *operand, std::size_t uint64_count, - std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!operand && uint64_count > 0) - { - throw std::invalid_argument("operand"); - } - if (!result && uint64_count > 0) - { - throw std::invalid_argument("result"); - } -#endif - return sub_uint_uint64(operand, 1, uint64_count, result); - } - - inline void negate_uint( - const std::uint64_t *operand, std::size_t uint64_count, - std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!operand) - { - throw std::invalid_argument("operand"); - } - if (!uint64_count) - { - throw std::invalid_argument("uint64_count"); - } - if (!result) - { - throw std::invalid_argument("result"); - } -#endif - // Negation is equivalent to inverting bits and adding 1. - unsigned char carry = add_uint64(~*operand++, std::uint64_t(1), result++); - for(; --uint64_count; operand++, result++) - { - unsigned long long temp_result; - carry = add_uint64( - ~*operand, std::uint64_t(0), carry, &temp_result); - *result = temp_result; - } - } - - inline void left_shift_uint(const std::uint64_t *operand, - int shift_amount, std::size_t uint64_count, std::uint64_t *result) - { - const std::size_t bits_per_uint64_sz = - static_cast(bits_per_uint64); -#ifdef SEAL_DEBUG - if (!operand) - { - throw std::invalid_argument("operand"); - } - if (shift_amount < 0 || - unsigned_geq(shift_amount, - mul_safe(uint64_count, bits_per_uint64_sz))) - { - throw std::invalid_argument("shift_amount"); - } - if (!uint64_count) - { - throw std::invalid_argument("uint64_count"); - } - if (!result) - { - throw std::invalid_argument("result"); - } -#endif - // How many words to shift - std::size_t uint64_shift_amount = - static_cast(shift_amount) / bits_per_uint64_sz; - - // Shift words - for (std::size_t i = 0; i < uint64_count - uint64_shift_amount; i++) - { - result[uint64_count - i - 1] = operand[uint64_count - i - 1 - uint64_shift_amount]; - } - for (std::size_t i = uint64_count - uint64_shift_amount; i < uint64_count; i++) - { - result[uint64_count - i - 1] = 0; - } - - // How many bits to shift in addition - std::size_t bit_shift_amount = static_cast(shift_amount) - - (uint64_shift_amount * bits_per_uint64_sz); - - if (bit_shift_amount) - { - std::size_t neg_bit_shift_amount = bits_per_uint64_sz - bit_shift_amount; - - for (std::size_t i = uint64_count - 1; i > 0; i--) - { - result[i] = (result[i] << bit_shift_amount) | - (result[i - 1] >> neg_bit_shift_amount); - } - result[0] = result[0] << bit_shift_amount; - } - } - - inline void right_shift_uint(const std::uint64_t *operand, - int shift_amount, std::size_t uint64_count, std::uint64_t *result) - { - const std::size_t bits_per_uint64_sz = - static_cast(bits_per_uint64); -#ifdef SEAL_DEBUG - if (!operand) - { - throw std::invalid_argument("operand"); - } - if (shift_amount < 0 || - unsigned_geq(shift_amount, - mul_safe(uint64_count, bits_per_uint64_sz))) - { - throw std::invalid_argument("shift_amount"); - } - if (!uint64_count) - { - throw std::invalid_argument("uint64_count"); - } - if (!result) - { - throw std::invalid_argument("result"); - } -#endif - // How many words to shift - std::size_t uint64_shift_amount = - static_cast(shift_amount) / bits_per_uint64_sz; - - // Shift words - for (std::size_t i = 0; i < uint64_count - uint64_shift_amount; i++) - { - result[i] = operand[i + uint64_shift_amount]; - } - for (std::size_t i = uint64_count - uint64_shift_amount; i < uint64_count; i++) - { - result[i] = 0; - } - - // How many bits to shift in addition - std::size_t bit_shift_amount = static_cast(shift_amount) - - (uint64_shift_amount * bits_per_uint64_sz); - - if (bit_shift_amount) - { - std::size_t neg_bit_shift_amount = bits_per_uint64_sz - bit_shift_amount; - - for (std::size_t i = 0; i < uint64_count - 1; i++) - { - result[i] = (result[i] >> bit_shift_amount) | - (result[i + 1] << neg_bit_shift_amount); - } - result[uint64_count - 1] = result[uint64_count - 1] >> bit_shift_amount; - } - } - - inline void left_shift_uint128( - const std::uint64_t *operand, int shift_amount, std::uint64_t *result) - { - const std::size_t bits_per_uint64_sz = - static_cast(bits_per_uint64); -#ifdef SEAL_DEBUG - if (!operand) - { - throw std::invalid_argument("operand"); - } - if (shift_amount < 0 || - unsigned_geq(shift_amount, 2 * bits_per_uint64_sz)) - { - throw std::invalid_argument("shift_amount"); - } - if (!result) - { - throw std::invalid_argument("result"); - } -#endif - const std::size_t shift_amount_sz = - static_cast(shift_amount); - - // Early return - if (shift_amount_sz & bits_per_uint64_sz) - { - result[1] = operand[0]; - result[0] = 0; - } - else - { - result[1] = operand[1]; - result[0] = operand[0]; - } - - // How many bits to shift in addition to word shift - std::size_t bit_shift_amount = shift_amount_sz & (bits_per_uint64_sz - 1); - - // Do we have a word shift - if (bit_shift_amount) - { - std::size_t neg_bit_shift_amount = bits_per_uint64_sz - bit_shift_amount; - - // Warning: if bit_shift_amount == 0 this is incorrect - result[1] = (result[1] << bit_shift_amount) | - (result[0] >> neg_bit_shift_amount); - result[0] = result[0] << bit_shift_amount; - } - } - - inline void right_shift_uint128( - const std::uint64_t *operand, int shift_amount, std::uint64_t *result) - { - const std::size_t bits_per_uint64_sz = - static_cast(bits_per_uint64); -#ifdef SEAL_DEBUG - if (!operand) - { - throw std::invalid_argument("operand"); - } - if (shift_amount < 0 || - unsigned_geq(shift_amount, 2 * bits_per_uint64_sz)) - { - throw std::invalid_argument("shift_amount"); - } - if (!result) - { - throw std::invalid_argument("result"); - } -#endif - const std::size_t shift_amount_sz = - static_cast(shift_amount); - - if (shift_amount_sz & bits_per_uint64_sz) - { - result[0] = operand[1]; - result[1] = 0; - } - else - { - result[1] = operand[1]; - result[0] = operand[0]; - } - - // How many bits to shift in addition to word shift - std::size_t bit_shift_amount = shift_amount_sz & (bits_per_uint64_sz - 1); - - if (bit_shift_amount) - { - std::size_t neg_bit_shift_amount = bits_per_uint64_sz - bit_shift_amount; - - // Warning: if bit_shift_amount == 0 this is incorrect - result[0] = (result[0] >> bit_shift_amount) | - (result[1] << neg_bit_shift_amount); - result[1] = result[1] >> bit_shift_amount; - } - } - - inline void left_shift_uint192( - const std::uint64_t *operand, int shift_amount, std::uint64_t *result) - { - const std::size_t bits_per_uint64_sz = - static_cast(bits_per_uint64); -#ifdef SEAL_DEBUG - if (!operand) - { - throw std::invalid_argument("operand"); - } - if (shift_amount < 0 || - unsigned_geq(shift_amount, 3 * bits_per_uint64_sz)) - { - throw std::invalid_argument("shift_amount"); - } - if (!result) - { - throw std::invalid_argument("result"); - } -#endif - const std::size_t shift_amount_sz = - static_cast(shift_amount); - - if (shift_amount_sz & (bits_per_uint64_sz << 1)) - { - result[2] = operand[0]; - result[1] = 0; - result[0] = 0; - } - else if (shift_amount_sz & bits_per_uint64_sz) - { - result[2] = operand[1]; - result[1] = operand[0]; - result[0] = 0; - } - else - { - result[2] = operand[2]; - result[1] = operand[1]; - result[0] = operand[0]; - } - - // How many bits to shift in addition to word shift - std::size_t bit_shift_amount = shift_amount_sz & (bits_per_uint64_sz - 1); - - if (bit_shift_amount) - { - std::size_t neg_bit_shift_amount = bits_per_uint64_sz - bit_shift_amount; - - // Warning: if bit_shift_amount == 0 this is incorrect - result[2] = (result[2] << bit_shift_amount) | - (result[1] >> neg_bit_shift_amount); - result[1] = (result[1] << bit_shift_amount) | - (result[0] >> neg_bit_shift_amount); - result[0] = result[0] << bit_shift_amount; - } - } - - inline void right_shift_uint192( - const std::uint64_t *operand, int shift_amount, std::uint64_t *result) - { - const std::size_t bits_per_uint64_sz = - static_cast(bits_per_uint64); -#ifdef SEAL_DEBUG - if (!operand) - { - throw std::invalid_argument("operand"); - } - if (shift_amount < 0 || - unsigned_geq(shift_amount, 3 * bits_per_uint64_sz)) - { - throw std::invalid_argument("shift_amount"); - } - if (!result) - { - throw std::invalid_argument("result"); - } -#endif - const std::size_t shift_amount_sz = - static_cast(shift_amount); - - if (shift_amount_sz & (bits_per_uint64_sz << 1)) - { - result[0] = operand[2]; - result[1] = 0; - result[2] = 0; - } - else if (shift_amount_sz & bits_per_uint64_sz) - { - result[0] = operand[1]; - result[1] = operand[2]; - result[2] = 0; - } - else - { - result[2] = operand[2]; - result[1] = operand[1]; - result[0] = operand[0]; - } - - // How many bits to shift in addition to word shift - std::size_t bit_shift_amount = shift_amount_sz & (bits_per_uint64_sz - 1); - - if (bit_shift_amount) - { - std::size_t neg_bit_shift_amount = bits_per_uint64_sz - bit_shift_amount; - - // Warning: if bit_shift_amount == 0 this is incorrect - result[0] = (result[0] >> bit_shift_amount) | - (result[1] << neg_bit_shift_amount); - result[1] = (result[1] >> bit_shift_amount) | - (result[2] << neg_bit_shift_amount); - result[2] = result[2] >> bit_shift_amount; - } - } - - inline void half_round_up_uint( - const std::uint64_t *operand, std::size_t uint64_count, - std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!operand && uint64_count > 0) - { - throw std::invalid_argument("operand"); - } - if (!result && uint64_count > 0) - { - throw std::invalid_argument("result"); - } -#endif - if (!uint64_count) - { - return; - } - // Set result to (operand + 1) / 2. To prevent overflowing operand, right shift - // and then increment result if low-bit of operand was set. - bool low_bit_set = operand[0] & 1; - - for (std::size_t i = 0; i < uint64_count - 1; i++) - { - result[i] = (operand[i] >> 1) | (operand[i + 1] << (bits_per_uint64 - 1)); - } - result[uint64_count - 1] = operand[uint64_count - 1] >> 1; - - if (low_bit_set) - { - increment_uint(result, uint64_count, result); - } - } - - inline void not_uint( - const std::uint64_t *operand, std::size_t uint64_count, - std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!operand && uint64_count > 0) - { - throw std::invalid_argument("operand"); - } - if (!result && uint64_count > 0) - { - throw std::invalid_argument("result"); - } -#endif - for (; uint64_count--; result++, operand++) - { - *result = ~*operand; - } - } - - inline void and_uint_uint( - const std::uint64_t *operand1, const std::uint64_t *operand2, - std::size_t uint64_count, std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!operand1 && uint64_count > 0) - { - throw std::invalid_argument("operand1"); - } - if (!operand2 && uint64_count > 0) - { - throw std::invalid_argument("operand2"); - } - if (!result && uint64_count > 0) - { - throw std::invalid_argument("result"); - } -#endif - for (; uint64_count--; result++, operand1++, operand2++) - { - *result = *operand1 & *operand2; - } - } - - inline void or_uint_uint( - const std::uint64_t *operand1, const std::uint64_t *operand2, - std::size_t uint64_count, std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!operand1 && uint64_count > 0) - { - throw std::invalid_argument("operand1"); - } - if (!operand2 && uint64_count > 0) - { - throw std::invalid_argument("operand2"); - } - if (!result && uint64_count > 0) - { - throw std::invalid_argument("result"); - } -#endif - for (; uint64_count--; result++, operand1++, operand2++) - { - *result = *operand1 | *operand2; - } - } - - inline void xor_uint_uint( - const std::uint64_t *operand1, const std::uint64_t *operand2, - std::size_t uint64_count, std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!operand1 && uint64_count > 0) - { - throw std::invalid_argument("operand1"); - } - if (!operand2 && uint64_count > 0) - { - throw std::invalid_argument("operand2"); - } - if (!result && uint64_count > 0) - { - throw std::invalid_argument("result"); - } -#endif - for (; uint64_count--; result++, operand1++, operand2++) - { - *result = *operand1 ^ *operand2; - } - } - - template>> - inline void multiply_uint64_generic( - T operand1, S operand2, unsigned long long *result128) - { -#ifdef SEAL_DEBUG - if (!result128) - { - throw std::invalid_argument("result128 cannot be null"); - } -#endif - auto operand1_coeff_right = operand1 & 0x00000000FFFFFFFF; - auto operand2_coeff_right = operand2 & 0x00000000FFFFFFFF; - operand1 >>= 32; - operand2 >>= 32; - - auto middle1 = operand1 * operand2_coeff_right; - T middle; - auto left = operand1 * operand2 + (static_cast(add_uint64( - middle1, operand2 * operand1_coeff_right, &middle)) << 32); - auto right = operand1_coeff_right * operand2_coeff_right; - auto temp_sum = (right >> 32) + (middle & 0x00000000FFFFFFFF); - - result128[1] = static_cast( - left + (middle >> 32) + (temp_sum >> 32)); - result128[0] = static_cast( - (temp_sum << 32) | (right & 0x00000000FFFFFFFF)); - } - - template>> - inline void multiply_uint64( - T operand1, S operand2, unsigned long long *result128) - { - SEAL_MULTIPLY_UINT64(operand1, operand2, result128); - } - - template>> - inline void multiply_uint64_hw64_generic( - T operand1, S operand2, unsigned long long *hw64) - { -#ifdef SEAL_DEBUG - if (!hw64) - { - throw std::invalid_argument("hw64 cannot be null"); - } -#endif - auto operand1_coeff_right = operand1 & 0x00000000FFFFFFFF; - auto operand2_coeff_right = operand2 & 0x00000000FFFFFFFF; - operand1 >>= 32; - operand2 >>= 32; - - auto middle1 = operand1 * operand2_coeff_right; - T middle; - auto left = operand1 * operand2 + (static_cast(add_uint64( - middle1, operand2 * operand1_coeff_right, &middle)) << 32); - auto right = operand1_coeff_right * operand2_coeff_right; - auto temp_sum = (right >> 32) + (middle & 0x00000000FFFFFFFF); - - *hw64 = static_cast( - left + (middle >> 32) + (temp_sum >> 32)); - } - - template>> - inline void multiply_uint64_hw64( - T operand1, S operand2, unsigned long long *hw64) - { - SEAL_MULTIPLY_UINT64_HW64(operand1, operand2, hw64); - } - - void multiply_uint_uint( - const std::uint64_t *operand1, std::size_t operand1_uint64_count, - const std::uint64_t *operand2, std::size_t operand2_uint64_count, - std::size_t result_uint64_count, std::uint64_t *result); - - inline void multiply_uint_uint( - const std::uint64_t *operand1, const std::uint64_t *operand2, - std::size_t uint64_count, std::uint64_t *result) - { - multiply_uint_uint(operand1, uint64_count, operand2, uint64_count, - uint64_count * 2, result); - } - - void multiply_uint_uint64( - const std::uint64_t *operand1, std::size_t operand1_uint64_count, - std::uint64_t operand2, std::size_t result_uint64_count, - std::uint64_t *result); - - inline void multiply_truncate_uint_uint( - const std::uint64_t *operand1, const std::uint64_t *operand2, - std::size_t uint64_count, std::uint64_t *result) - { - multiply_uint_uint(operand1, uint64_count, operand2, uint64_count, - uint64_count, result); - } - - void divide_uint_uint_inplace( - std::uint64_t *numerator, const std::uint64_t *denominator, - std::size_t uint64_count, std::uint64_t *quotient, MemoryPool &pool); - - inline void divide_uint_uint( - const std::uint64_t *numerator, const std::uint64_t *denominator, - std::size_t uint64_count, std::uint64_t *quotient, - std::uint64_t *remainder, MemoryPool &pool) - { - set_uint_uint(numerator, uint64_count, remainder); - divide_uint_uint_inplace(remainder, denominator, uint64_count, quotient, pool); - } - - void divide_uint128_uint64_inplace_generic( - std::uint64_t *numerator, std::uint64_t denominator, - std::uint64_t *quotient); - - inline void divide_uint128_uint64_inplace( - std::uint64_t *numerator, std::uint64_t denominator, - std::uint64_t *quotient) - { -#ifdef SEAL_DEBUG - if (!numerator) - { - throw std::invalid_argument("numerator"); - } - if (denominator == 0) - { - throw std::invalid_argument("denominator"); - } - if (!quotient) - { - throw std::invalid_argument("quotient"); - } - if (numerator == quotient) - { - throw std::invalid_argument("quotient cannot point to same value as numerator"); - } -#endif - SEAL_DIVIDE_UINT128_UINT64(numerator, denominator, quotient); - } - - void divide_uint128_uint64_inplace( - std::uint64_t *numerator, std::uint64_t denominator, - std::uint64_t *quotient); - - void divide_uint192_uint64_inplace( - std::uint64_t *numerator, std::uint64_t denominator, - std::uint64_t *quotient); - - void exponentiate_uint( - const std::uint64_t *operand, std::size_t operand_uint64_count, - const std::uint64_t *exponent, std::size_t exponent_uint64_count, - std::size_t result_uint64_count, std::uint64_t *result, - MemoryPool &pool); - - SEAL_NODISCARD std::uint64_t exponentiate_uint64_safe( - std::uint64_t operand, std::uint64_t exponent); - - SEAL_NODISCARD std::uint64_t exponentiate_uint64( - std::uint64_t operand, std::uint64_t exponent); - } -} diff --git a/SEAL/native/src/seal/util/uintarithmod.cpp b/SEAL/native/src/seal/util/uintarithmod.cpp deleted file mode 100644 index cb45852..0000000 --- a/SEAL/native/src/seal/util/uintarithmod.cpp +++ /dev/null @@ -1,248 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "seal/util/uintcore.h" -#include "seal/util/uintarith.h" -#include "seal/util/uintarithmod.h" -#include "seal/util/common.h" - -using namespace std; - -namespace seal -{ - namespace util - { - bool try_invert_uint_mod(const uint64_t *operand, - const uint64_t *modulus, size_t uint64_count, - uint64_t *result, MemoryPool &pool) - { -#ifdef SEAL_DEBUG - if (!operand) - { - throw invalid_argument("operand"); - } - if (!modulus) - { - throw invalid_argument("modulus"); - } - if (!uint64_count) - { - throw invalid_argument("uint64_count"); - } - if (!result) - { - throw invalid_argument("result"); - } - if (is_greater_than_or_equal_uint_uint(operand, modulus, uint64_count)) - { - throw out_of_range("operand"); - } -#endif - // Cannot invert 0. - int bit_count = get_significant_bit_count_uint(operand, uint64_count); - if (bit_count == 0) - { - return false; - } - - // If it is 1, then its invert is itself. - if (bit_count == 1) - { - set_uint(1, uint64_count, result); - return true; - } - - auto alloc_anchor(allocate_uint(7 * uint64_count, pool)); - - // Construct a mutable copy of operand and modulus, with numerator being modulus - // and operand being denominator. Notice that numerator > denominator. - uint64_t *numerator = alloc_anchor.get(); - set_uint_uint(modulus, uint64_count, numerator); - - uint64_t *denominator = numerator + uint64_count; - set_uint_uint(operand, uint64_count, denominator); - - // Create space to store difference. - uint64_t *difference = denominator + uint64_count; - - // Determine highest bit index of each. - int numerator_bits = get_significant_bit_count_uint(numerator, uint64_count); - int denominator_bits = get_significant_bit_count_uint(denominator, uint64_count); - - // Create space to store quotient. - uint64_t *quotient = difference + uint64_count; - - // Create three sign/magnitude values to store coefficients. - // Initialize invert_prior to +0 and invert_curr to +1. - uint64_t *invert_prior = quotient + uint64_count; - set_zero_uint(uint64_count, invert_prior); - bool invert_prior_positive = true; - - uint64_t *invert_curr = invert_prior + uint64_count; - set_uint(1, uint64_count, invert_curr); - bool invert_curr_positive = true; - - uint64_t *invert_next = invert_curr + uint64_count; - bool invert_next_positive = true; - - // Perform extended Euclidean algorithm. - while (true) - { - // NOTE: Numerator is > denominator. - - // Only perform computation up to last non-zero uint64s. - size_t division_uint64_count = static_cast( - divide_round_up(numerator_bits, bits_per_uint64)); - - // Shift denominator to bring MSB in alignment with MSB of numerator. - int denominator_shift = numerator_bits - denominator_bits; - left_shift_uint(denominator, denominator_shift, - division_uint64_count, denominator); - denominator_bits += denominator_shift; - - // Clear quotient. - set_zero_uint(uint64_count, quotient); - - // Perform bit-wise division algorithm. - int remaining_shifts = denominator_shift; - while (numerator_bits == denominator_bits) - { - // NOTE: MSBs of numerator and denominator are aligned. - - // Even though MSB of numerator and denominator are aligned, - // still possible numerator < denominator. - if (sub_uint_uint(numerator, denominator, - division_uint64_count, difference)) - { - // numerator < denominator and MSBs are aligned, so current - // quotient bit is zero and next one is definitely one. - if (remaining_shifts == 0) - { - // No shifts remain and numerator < denominator so done. - break; - } - - // Effectively shift numerator left by 1 by instead adding - // numerator to difference (to prevent overflow in numerator). - add_uint_uint(difference, numerator, division_uint64_count, difference); - - // Adjust quotient and remaining shifts as a result of shifting numerator. - left_shift_uint(quotient, 1, division_uint64_count, quotient); - remaining_shifts--; - } - // Difference is the new numerator with denominator subtracted. - - // Update quotient to reflect subtraction. - *quotient |= 1; - - // Determine amount to shift numerator to bring MSB in alignment - // with denominator. - numerator_bits = - get_significant_bit_count_uint(difference, division_uint64_count); - int numerator_shift = denominator_bits - numerator_bits; - if (numerator_shift > remaining_shifts) - { - // Clip the maximum shift to determine only the integer - // (as opposed to fractional) bits. - numerator_shift = remaining_shifts; - } - - // Shift and update numerator. - if (numerator_bits > 0) - { - left_shift_uint(difference, numerator_shift, - division_uint64_count, numerator); - numerator_bits += numerator_shift; - } - else - { - // Difference is zero so no need to shift, just set to zero. - set_zero_uint(division_uint64_count, numerator); - } - - // Adjust quotient and remaining shifts as a result of - // shifting numerator. - left_shift_uint(quotient, numerator_shift, - division_uint64_count, quotient); - remaining_shifts -= numerator_shift; - } - - // Correct for shifting of denominator. - right_shift_uint(denominator, denominator_shift, - division_uint64_count, denominator); - denominator_bits -= denominator_shift; - - // We are done if remainder (which is stored in numerator) is zero. - if (numerator_bits == 0) - { - break; - } - - // Correct for shifting of denominator. - right_shift_uint(numerator, denominator_shift, - division_uint64_count, numerator); - numerator_bits -= denominator_shift; - - // Integrate quotient with invert coefficients. - // Calculate: invert_prior + -quotient * invert_curr - multiply_truncate_uint_uint(quotient, invert_curr, - uint64_count, invert_next); - invert_next_positive = !invert_curr_positive; - if (invert_prior_positive == invert_next_positive) - { - // If both sides of add have same sign, then simple add and - // do not need to worry about overflow due to known limits - // on the coefficients proved in the euclidean algorithm. - add_uint_uint(invert_prior, invert_next, uint64_count, invert_next); - } - else - { - // If both sides of add have opposite sign, then subtract - // and check for overflow. - uint64_t borrow = sub_uint_uint(invert_prior, invert_next, - uint64_count, invert_next); - if (borrow == 0) - { - // No borrow means |invert_prior| >= |invert_next|, - // so sign is same as invert_prior. - invert_next_positive = invert_prior_positive; - } - else - { - // Borrow means |invert prior| < |invert_next|, - // so sign is opposite of invert_prior. - invert_next_positive = !invert_prior_positive; - negate_uint(invert_next, uint64_count, invert_next); - } - } - - // Swap prior and curr, and then curr and next. - swap(invert_prior, invert_curr); - swap(invert_prior_positive, invert_curr_positive); - swap(invert_curr, invert_next); - swap(invert_curr_positive, invert_next_positive); - - // Swap numerator and denominator using pointer swings. - swap(numerator, denominator); - swap(numerator_bits, denominator_bits); - } - - if (!is_equal_uint(denominator, uint64_count, 1)) - { - // GCD is not one, so unable to find inverse. - return false; - } - - // Correct coefficient if negative by modulo. - if (!invert_curr_positive && !is_zero_uint(invert_curr, uint64_count)) - { - sub_uint_uint(modulus, invert_curr, uint64_count, invert_curr); - invert_curr_positive = true; - } - - // Set result. - set_uint_uint(invert_curr, uint64_count, result); - return true; - } - } -} diff --git a/SEAL/native/src/seal/util/uintarithmod.h b/SEAL/native/src/seal/util/uintarithmod.h deleted file mode 100644 index 8598ede..0000000 --- a/SEAL/native/src/seal/util/uintarithmod.h +++ /dev/null @@ -1,269 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include "seal/util/uintcore.h" -#include "seal/util/uintarith.h" -#include "seal/util/pointer.h" - -namespace seal -{ - namespace util - { - inline void increment_uint_mod( - const std::uint64_t *operand, const std::uint64_t *modulus, - std::size_t uint64_count, std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!operand) - { - throw std::invalid_argument("operand"); - } - if (!modulus) - { - throw std::invalid_argument("modulus"); - } - if (!uint64_count) - { - throw std::invalid_argument("uint64_count"); - } - if (!result) - { - throw std::invalid_argument("result"); - } - if (is_greater_than_or_equal_uint_uint(operand, modulus, uint64_count)) - { - throw std::invalid_argument("operand"); - } - if (modulus == result) - { - throw std::invalid_argument("result cannot point to the same value as modulus"); - } -#endif - unsigned char carry = increment_uint(operand, uint64_count, result); - if (carry || - is_greater_than_or_equal_uint_uint(result, modulus, uint64_count)) - { - sub_uint_uint(result, modulus, uint64_count, result); - } - } - - inline void decrement_uint_mod( - const std::uint64_t *operand, const std::uint64_t *modulus, - std::size_t uint64_count, std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!operand) - { - throw std::invalid_argument("operand"); - } - if (!modulus) - { - throw std::invalid_argument("modulus"); - } - if (!uint64_count) - { - throw std::invalid_argument("uint64_count"); - } - if (!result) - { - throw std::invalid_argument("result"); - } - if (is_greater_than_or_equal_uint_uint(operand, modulus, uint64_count)) - { - throw std::invalid_argument("operand"); - } - if (modulus == result) - { - throw std::invalid_argument("result cannot point to the same value as modulus"); - } -#endif - if (decrement_uint(operand, uint64_count, result)) - { - add_uint_uint(result, modulus, uint64_count, result); - } - } - - inline void negate_uint_mod( - const std::uint64_t *operand, const std::uint64_t *modulus, - std::size_t uint64_count, std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!operand) - { - throw std::invalid_argument("operand"); - } - if (!modulus) - { - throw std::invalid_argument("modulus"); - } - if (!uint64_count) - { - throw std::invalid_argument("uint64_count"); - } - if (!result) - { - throw std::invalid_argument("result"); - } - if (is_greater_than_or_equal_uint_uint(operand, modulus, uint64_count)) - { - throw std::invalid_argument("operand"); - } -#endif - if (is_zero_uint(operand, uint64_count)) - { - // Negation of zero is zero. - set_zero_uint(uint64_count, result); - } - else - { - // Otherwise, we know operand > 0 and < modulus so subtract modulus - operand. - sub_uint_uint(modulus, operand, uint64_count, result); - } - } - - inline void div2_uint_mod( - const std::uint64_t *operand, const std::uint64_t *modulus, - std::size_t uint64_count, std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!operand) - { - throw std::invalid_argument("operand"); - } - if (!modulus) - { - throw std::invalid_argument("modulus"); - } - if (!uint64_count) - { - throw std::invalid_argument("uint64_count"); - } - if (!result) - { - throw std::invalid_argument("result"); - } - if (!is_bit_set_uint(modulus, uint64_count, 0)) - { - throw std::invalid_argument("modulus"); - } - if (is_greater_than_or_equal_uint_uint(operand, modulus, uint64_count)) - { - throw std::invalid_argument("operand"); - } -#endif - if (*operand & 1) - { - unsigned char carry = add_uint_uint(operand, modulus, uint64_count, result); - right_shift_uint(result, 1, uint64_count, result); - if (carry) - { - set_bit_uint(result, uint64_count, - static_cast(uint64_count) * bits_per_uint64 - 1); - } - } - else - { - right_shift_uint(operand, 1, uint64_count, result); - } - } - - inline void add_uint_uint_mod( - const std::uint64_t *operand1, const std::uint64_t *operand2, - const std::uint64_t *modulus, std::size_t uint64_count, - std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!operand1) - { - throw std::invalid_argument("operand1"); - } - if (!operand2) - { - throw std::invalid_argument("operand2"); - } - if (!modulus) - { - throw std::invalid_argument("modulus"); - } - if (!uint64_count) - { - throw std::invalid_argument("uint64_count"); - } - if (!result) - { - throw std::invalid_argument("result"); - } - if (is_greater_than_or_equal_uint_uint(operand1, modulus, uint64_count)) - { - throw std::invalid_argument("operand1"); - } - if (is_greater_than_or_equal_uint_uint(operand2, modulus, uint64_count)) - { - throw std::invalid_argument("operand2"); - } - if (modulus == result) - { - throw std::invalid_argument("result cannot point to the same value as modulus"); - } -#endif - unsigned char carry = add_uint_uint(operand1, operand2, uint64_count, result); - if (carry || - is_greater_than_or_equal_uint_uint(result, modulus, uint64_count)) - { - sub_uint_uint(result, modulus, uint64_count, result); - } - } - - inline void sub_uint_uint_mod( - const std::uint64_t *operand1, const std::uint64_t *operand2, - const std::uint64_t *modulus, std::size_t uint64_count, - std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!operand1) - { - throw std::invalid_argument("operand1"); - } - if (!operand2) - { - throw std::invalid_argument("operand2"); - } - if (!modulus) - { - throw std::invalid_argument("modulus"); - } - if (!uint64_count) - { - throw std::invalid_argument("uint64_count"); - } - if (!result) - { - throw std::invalid_argument("result"); - } - if (is_greater_than_or_equal_uint_uint(operand1, modulus, uint64_count)) - { - throw std::invalid_argument("operand1"); - } - if (is_greater_than_or_equal_uint_uint(operand2, modulus, uint64_count)) - { - throw std::invalid_argument("operand2"); - } - if (modulus == result) - { - throw std::invalid_argument("result cannot point to the same value as modulus"); - } -#endif - if (sub_uint_uint(operand1, operand2, uint64_count, result)) - { - add_uint_uint(result, modulus, uint64_count, result); - } - } - - bool try_invert_uint_mod( - const std::uint64_t *operand, const std::uint64_t *modulus, - std::size_t uint64_count, std::uint64_t *result, MemoryPool &pool); - } -} diff --git a/SEAL/native/src/seal/util/uintarithsmallmod.cpp b/SEAL/native/src/seal/util/uintarithsmallmod.cpp deleted file mode 100644 index 0f420a2..0000000 --- a/SEAL/native/src/seal/util/uintarithsmallmod.cpp +++ /dev/null @@ -1,266 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "seal/util/uintcore.h" -#include "seal/util/uintarith.h" -#include "seal/util/uintarithmod.h" -#include "seal/util/uintarithsmallmod.h" -#include - -using namespace std; - -namespace seal -{ - namespace util - { - bool is_primitive_root(uint64_t root, uint64_t degree, - const SmallModulus &modulus) - { -#ifdef SEAL_DEBUG - if (modulus.bit_count() < 2) - { - throw invalid_argument("modulus"); - } - if (root >= modulus.value()) - { - throw out_of_range("operand"); - } - if (get_power_of_two(degree) < 1) - { - throw invalid_argument("degree must be a power of two and at least two"); - } -#endif - if (root == 0) - { - return false; - } - - // We check if root is a degree-th root of unity in integers modulo - // modulus, where degree is a power of two. - // It suffices to check that root^(degree/2) is -1 modulo modulus. - return exponentiate_uint_mod( - root, degree >> 1, modulus) == (modulus.value() - 1); - } - - bool try_primitive_root(uint64_t degree, const SmallModulus &modulus, - uint64_t &destination) - { -#ifdef SEAL_DEBUG - if (modulus.bit_count() < 2) - { - throw invalid_argument("modulus"); - } - if (get_power_of_two(degree) < 1) - { - throw invalid_argument("degree must be a power of two and at least two"); - } -#endif - // We need to divide modulus-1 by degree to get the size of the - // quotient group - uint64_t size_entire_group = modulus.value() - 1; - - // Compute size of quotient group - uint64_t size_quotient_group = size_entire_group / degree; - - // size_entire_group must be divisible by degree, or otherwise the - // primitive root does not exist in integers modulo modulus - if (size_entire_group - size_quotient_group * degree != 0) - { - return false; - } - - // For randomness - random_device rd; - - int attempt_counter = 0; - int attempt_counter_max = 100; - do - { - attempt_counter++; - - // Set destination to be a random number modulo modulus - destination = (static_cast(rd()) << 32) | - static_cast(rd()); - destination %= modulus.value(); - - // Raise the random number to power the size of the quotient - // to get rid of irrelevant part - destination = exponentiate_uint_mod( - destination, size_quotient_group, modulus); - } while (!is_primitive_root(destination, degree, modulus) && - (attempt_counter < attempt_counter_max)); - - return is_primitive_root(destination, degree, modulus); - } - - bool try_minimal_primitive_root(uint64_t degree, - const SmallModulus &modulus, uint64_t &destination) - { - uint64_t root; - if (!try_primitive_root(degree, modulus, root)) - { - return false; - } - uint64_t generator_sq = multiply_uint_uint_mod(root, root, modulus); - uint64_t current_generator = root; - - // destination is going to always contain the smallest generator found - for (size_t i = 0; i < degree; i++) - { - // If our current generator is strictly smaller than destination, - // update - if (current_generator < root) - { - root = current_generator; - } - - // Then move on to the next generator - current_generator = multiply_uint_uint_mod( - current_generator, generator_sq, modulus); - } - - destination = root; - return true; - } - - uint64_t exponentiate_uint_mod(uint64_t operand, uint64_t exponent, - const SmallModulus &modulus) - { -#ifdef SEAL_DEBUG - if (modulus.is_zero()) - { - throw invalid_argument("modulus"); - } - if (operand >= modulus.value()) - { - throw invalid_argument("operand"); - } -#endif - // Fast cases - if (exponent == 0) - { - // Result is supposed to be only one digit - return 1; - } - - if (exponent == 1) - { - return operand; - } - - // Perform binary exponentiation. - uint64_t power = operand; - uint64_t product = 0; - uint64_t intermediate = 1; - - // Initially: power = operand and intermediate = 1, product is irrelevant. - while (true) - { - if (exponent & 1) - { - product = multiply_uint_uint_mod(power, intermediate, modulus); - swap(product, intermediate); - } - exponent >>= 1; - if (exponent == 0) - { - break; - } - product = multiply_uint_uint_mod(power, power, modulus); - swap(product, power); - } - return intermediate; - } - - void divide_uint_uint_mod_inplace(uint64_t *numerator, - const SmallModulus &modulus, size_t uint64_count, - uint64_t *quotient, MemoryPool &pool) - { - // Handle base cases - if (uint64_count == 2) - { - divide_uint128_uint64_inplace(numerator, modulus.value(), quotient); - return; - } - else if(uint64_count == 1) - { - *numerator = *numerator % modulus.value(); - *quotient = *numerator / modulus.value(); - return; - } - else - { - // If uint64_count > 2. - // x = numerator = x1 * 2^128 + x2. - // 2^128 = A*value + B. - - auto x1_alloc(allocate_uint(uint64_count - 2 , pool)); - uint64_t *x1 = x1_alloc.get(); - uint64_t x2[2]; - auto quot_alloc(allocate_uint(uint64_count, pool)); - uint64_t *quot = quot_alloc.get(); - auto rem_alloc(allocate_uint(uint64_count, pool)); - uint64_t *rem = rem_alloc.get(); - set_uint_uint(numerator + 2, uint64_count - 2, x1); - set_uint_uint(numerator, 2, x2); // x2 = (num) % 2^128. - - multiply_uint_uint(x1, uint64_count - 2, &modulus.const_ratio()[0], 2, - uint64_count, quot); // x1*A. - multiply_uint_uint64(x1, uint64_count - 2, modulus.const_ratio()[2], - uint64_count - 1, rem); // x1*B - add_uint_uint(rem, uint64_count - 1, x2, 2, 0, uint64_count, rem); // x1*B + x2; - - size_t remainder_uint64_count = get_significant_uint64_count_uint(rem, uint64_count); - divide_uint_uint_mod_inplace(rem, modulus, remainder_uint64_count, quotient, pool); - add_uint_uint(quotient, quot, uint64_count, quotient); - *numerator = rem[0]; - - return; - } - } - - uint64_t steps_to_galois_elt(int steps, size_t coeff_count) - { - uint32_t n = safe_cast(coeff_count); - uint32_t m32 = mul_safe(n, uint32_t(2)); - uint64_t m = static_cast(m32); - - if (steps == 0) - { - return m - 1; - } - else - { - // Extract sign of steps. When steps is positive, the rotation - // is to the left; when steps is negative, it is to the right. - bool sign = steps < 0; - uint32_t pos_steps = safe_cast(abs(steps)); - - if (pos_steps >= (n >> 1)) - { - throw invalid_argument("step count too large"); - } - - pos_steps &= m32 - 1; - if (sign) - { - steps = safe_cast(n >> 1) - safe_cast(pos_steps); - } - else - { - steps = safe_cast(pos_steps); - } - - // Construct Galois element for row rotation - uint64_t gen = 3; - uint64_t galois_elt = 1; - while(steps--) - { - galois_elt *= gen; - galois_elt &= m - 1; - } - return galois_elt; - } - } - } -} diff --git a/SEAL/native/src/seal/util/uintarithsmallmod.h b/SEAL/native/src/seal/util/uintarithsmallmod.h deleted file mode 100644 index e43578e..0000000 --- a/SEAL/native/src/seal/util/uintarithsmallmod.h +++ /dev/null @@ -1,327 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include "seal/smallmodulus.h" -#include "seal/util/defines.h" -#include "seal/util/pointer.h" -#include "seal/util/numth.h" -#include "seal/util/uintarith.h" - -namespace seal -{ - namespace util - { - SEAL_NODISCARD inline std::uint64_t increment_uint_mod( - std::uint64_t operand, const SmallModulus &modulus) - { -#ifdef SEAL_DEBUG - if (modulus.is_zero()) - { - throw std::invalid_argument("modulus"); - } - if (operand >= modulus.value()) - { - throw std::out_of_range("operand"); - } -#endif - operand++; - return operand - (modulus.value() & static_cast( - -static_cast(operand >= modulus.value()))); - } - - SEAL_NODISCARD inline std::uint64_t decrement_uint_mod( - std::uint64_t operand, const SmallModulus &modulus) - { -#ifdef SEAL_DEBUG - if (modulus.is_zero()) - { - throw std::invalid_argument("modulus"); - } - if (operand >= modulus.value()) - { - throw std::out_of_range("operand"); - } -#endif - std::int64_t carry = (operand == 0); - return operand - 1 + (modulus.value() & - static_cast(-carry)); - } - - SEAL_NODISCARD inline std::uint64_t negate_uint_mod( - std::uint64_t operand, const SmallModulus &modulus) - { -#ifdef SEAL_DEBUG - if (modulus.is_zero()) - { - throw std::invalid_argument("modulus"); - } - if (operand >= modulus.value()) - { - throw std::out_of_range("operand"); - } -#endif - std::int64_t non_zero = (operand != 0); - return (modulus.value() - operand) - & static_cast(-non_zero); - } - - SEAL_NODISCARD inline std::uint64_t div2_uint_mod( - std::uint64_t operand, const SmallModulus &modulus) - { -#ifdef SEAL_DEBUG - if (modulus.is_zero()) - { - throw std::invalid_argument("modulus"); - } - if (operand >= modulus.value()) - { - throw std::out_of_range("operand"); - } -#endif - if (operand & 1) - { - unsigned long long temp; - int64_t carry = add_uint64(operand, modulus.value(), 0, &temp); - operand = temp >> 1; - if (carry) - { - return operand | (std::uint64_t(1) << (bits_per_uint64 - 1)); - } - return operand; - } - return operand >> 1; - } - - SEAL_NODISCARD inline std::uint64_t add_uint_uint_mod( - std::uint64_t operand1, std::uint64_t operand2, - const SmallModulus &modulus) - { -#ifdef SEAL_DEBUG - if (modulus.is_zero()) - { - throw std::invalid_argument("modulus"); - } - if (operand1 >= modulus.value()) - { - throw std::out_of_range("operand1"); - } - if (operand2 >= modulus.value()) - { - throw std::out_of_range("operand2"); - } -#endif - // Sum of operands modulo SmallModulus can never wrap around 2^64 - operand1 += operand2; - return operand1 - (modulus.value() & static_cast( - -static_cast(operand1 >= modulus.value()))); - } - - SEAL_NODISCARD inline std::uint64_t sub_uint_uint_mod( - std::uint64_t operand1, std::uint64_t operand2, - const SmallModulus &modulus) - { -#ifdef SEAL_DEBUG - if (modulus.is_zero()) - { - throw std::invalid_argument("modulus"); - } - - if (operand1 >= modulus.value()) - { - throw std::out_of_range("operand1"); - } - if (operand2 >= modulus.value()) - { - throw std::out_of_range("operand2"); - } -#endif - unsigned long long temp; - std::int64_t borrow = SEAL_SUB_BORROW_UINT64(operand1, operand2, 0, &temp); - return static_cast(temp) + - (modulus.value() & static_cast(-borrow)); - } - - template>> - SEAL_NODISCARD inline std::uint64_t barrett_reduce_128( - const T *input, const SmallModulus &modulus) - { -#ifdef SEAL_DEBUG - if (!input) - { - throw std::invalid_argument("input"); - } - if (modulus.is_zero()) - { - throw std::invalid_argument("modulus"); - } -#endif - // Reduces input using base 2^64 Barrett reduction - // input allocation size must be 128 bits - - unsigned long long tmp1, tmp2[2], tmp3, carry; - const std::uint64_t *const_ratio = modulus.const_ratio().data(); - - // Multiply input and const_ratio - // Round 1 - multiply_uint64_hw64(input[0], const_ratio[0], &carry); - - multiply_uint64(input[0], const_ratio[1], tmp2); - tmp3 = tmp2[1] + add_uint64(tmp2[0], carry, 0, &tmp1); - - // Round 2 - multiply_uint64(input[1], const_ratio[0], tmp2); - carry = tmp2[1] + add_uint64(tmp1, tmp2[0], 0, &tmp1); - - // This is all we care about - tmp1 = input[1] * const_ratio[1] + tmp3 + carry; - - // Barrett subtraction - tmp3 = input[0] - tmp1 * modulus.value(); - - // One more subtraction is enough - return static_cast(tmp3) - - (modulus.value() & static_cast( - -static_cast(tmp3 >= modulus.value()))); - } - - template>> - SEAL_NODISCARD inline std::uint64_t barrett_reduce_63( - T input, const SmallModulus &modulus) - { -#ifdef SEAL_DEBUG - if (modulus.is_zero()) - { - throw std::invalid_argument("modulus"); - } - if (input >> 63) - { - throw std::invalid_argument("input"); - } -#endif - // Reduces input using base 2^64 Barrett reduction - // input must be at most 63 bits - - unsigned long long tmp[2]; - const std::uint64_t *const_ratio = modulus.const_ratio().data(); - multiply_uint64(input, const_ratio[1], tmp); - - // Barrett subtraction - tmp[0] = input - tmp[1] * modulus.value(); - - // One more subtraction is enough - return static_cast(tmp[0]) - - (modulus.value() & static_cast( - -static_cast(tmp[0] >= modulus.value()))); - } - - SEAL_NODISCARD inline std::uint64_t multiply_uint_uint_mod( - std::uint64_t operand1, std::uint64_t operand2, - const SmallModulus &modulus) - { -#ifdef SEAL_DEBUG - if (modulus.is_zero()) - { - throw std::invalid_argument("modulus"); - } -#endif - unsigned long long z[2]; - multiply_uint64(operand1, operand2, z); - return barrett_reduce_128(z, modulus); - } - - inline void modulo_uint_inplace( - std::uint64_t *value, std::size_t value_uint64_count, - const SmallModulus &modulus) - { -#ifdef SEAL_DEBUG - if (!value && value_uint64_count > 0) - { - throw std::invalid_argument("value"); - } -#endif - if (value_uint64_count == 1) - { - value[0] %= modulus.value(); - return; - } - - // Starting from the top, reduce always 128-bit blocks - for (std::size_t i = value_uint64_count - 1; i--; ) - { - value[i] = barrett_reduce_128(value + i, modulus); - value[i + 1] = 0; - } - } - - SEAL_NODISCARD inline std::uint64_t modulo_uint( - const std::uint64_t *value, std::size_t value_uint64_count, - const SmallModulus &modulus, MemoryPool &pool) - { -#ifdef SEAL_DEBUG - if (!value && value_uint64_count) - { - throw std::invalid_argument("value"); - } - if (!value_uint64_count) - { - throw std::invalid_argument("value_uint64_count"); - } -#endif - if (value_uint64_count == 1) - { - // If value < modulus no operation is needed - return *value % modulus.value(); - } - - auto value_copy(allocate_uint(value_uint64_count, pool)); - set_uint_uint(value, value_uint64_count, value_copy.get()); - - // Starting from the top, reduce always 128-bit blocks - for (std::size_t i = value_uint64_count - 1; i--; ) - { - value_copy[i] = barrett_reduce_128(value_copy.get() + i, modulus); - } - - return value_copy[0]; - } - - inline bool try_invert_uint_mod( - std::uint64_t operand, const SmallModulus &modulus, - std::uint64_t &result) - { - return try_mod_inverse(operand, modulus.value(), result); - } - - bool is_primitive_root( - std::uint64_t root, std::uint64_t degree, - const SmallModulus &prime_modulus); - - // Try to find a primitive degree-th root of unity modulo small prime - // modulus, where degree must be a power of two. - bool try_primitive_root( - std::uint64_t degree, const SmallModulus &prime_modulus, - std::uint64_t &destination); - - // Try to find the smallest (as integer) primitive degree-th root of - // unity modulo small prime modulus, where degree must be a power of two. - bool try_minimal_primitive_root( - std::uint64_t degree, const SmallModulus &prime_modulus, - std::uint64_t &destination); - - SEAL_NODISCARD std::uint64_t exponentiate_uint_mod( - std::uint64_t operand, std::uint64_t exponent, - const SmallModulus &modulus); - - void divide_uint_uint_mod_inplace( - std::uint64_t *numerator, const SmallModulus &modulus, - std::size_t uint64_count, std::uint64_t *quotient, - MemoryPool &pool); - - SEAL_NODISCARD std::uint64_t steps_to_galois_elt( - int steps, std::size_t coeff_count); - } -} diff --git a/SEAL/native/src/seal/util/uintcore.cpp b/SEAL/native/src/seal/util/uintcore.cpp deleted file mode 100644 index 62cbbab..0000000 --- a/SEAL/native/src/seal/util/uintcore.cpp +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "seal/util/common.h" -#include "seal/util/uintcore.h" -#include "seal/util/uintarith.h" -#include -#include - -using namespace std; - -namespace seal -{ - namespace util - { - string uint_to_hex_string(const uint64_t *value, size_t uint64_count) - { -#ifdef SEAL_DEBUG - if (uint64_count && !value) - { - throw invalid_argument("value"); - } -#endif - // Start with a string with a zero for each nibble in the array. - size_t num_nibbles = - mul_safe(uint64_count, static_cast(nibbles_per_uint64)); - string output(num_nibbles, '0'); - - // Iterate through each uint64 in array and set string with correct nibbles in hex. - size_t nibble_index = num_nibbles; - size_t leftmost_non_zero_pos = num_nibbles; - for (size_t i = 0; i < uint64_count; i++) - { - uint64_t part = *value++; - - // Iterate through each nibble in the current uint64. - for (size_t j = 0; j < nibbles_per_uint64; j++) - { - size_t nibble = safe_cast(part & uint64_t(0x0F)); - size_t pos = --nibble_index; - if (nibble != 0) - { - // If nibble is not zero, then update string and save this pos to determine - // number of leading zeros. - output[pos] = nibble_to_upper_hex(static_cast(nibble)); - leftmost_non_zero_pos = pos; - } - part >>= 4; - } - } - - // Trim string to remove leading zeros. - output.erase(0, leftmost_non_zero_pos); - - // Return 0 if nothing remains. - if (output.empty()) - { - return string("0"); - } - - return output; - } - - string uint_to_dec_string(const uint64_t *value, - size_t uint64_count, MemoryPool &pool) - { -#ifdef SEAL_DEBUG - if (uint64_count && !value) - { - throw invalid_argument("value"); - } -#endif - if (!uint64_count) - { - return string("0"); - } - auto remainder(allocate_uint(uint64_count, pool)); - auto quotient(allocate_uint(uint64_count, pool)); - auto base(allocate_uint(uint64_count, pool)); - uint64_t *remainderptr = remainder.get(); - uint64_t *quotientptr = quotient.get(); - uint64_t *baseptr = base.get(); - set_uint(10, uint64_count, baseptr); - set_uint_uint(value, uint64_count, remainderptr); - string output; - while (!is_zero_uint(remainderptr, uint64_count)) - { - divide_uint_uint_inplace(remainderptr, baseptr, - uint64_count, quotientptr, pool); - char digit = static_cast( - remainderptr[0] + static_cast('0')); - output += digit; - swap(remainderptr, quotientptr); - } - reverse(output.begin(), output.end()); - - // Return 0 if nothing remains. - if (output.empty()) - { - return string("0"); - } - - return output; - } - } -} diff --git a/SEAL/native/src/seal/util/uintcore.h b/SEAL/native/src/seal/util/uintcore.h deleted file mode 100644 index 6bc0335..0000000 --- a/SEAL/native/src/seal/util/uintcore.h +++ /dev/null @@ -1,698 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include -#include -#include -#include -#include -#include "seal/util/common.h" -#include "seal/util/pointer.h" -#include "seal/util/defines.h" - -namespace seal -{ - namespace util - { - SEAL_NODISCARD std::string uint_to_hex_string( - const std::uint64_t *value, std::size_t uint64_count); - - SEAL_NODISCARD std::string uint_to_dec_string( - const std::uint64_t *value, std::size_t uint64_count, - MemoryPool &pool); - - inline void hex_string_to_uint(const char *hex_string, - int char_count, std::size_t uint64_count, std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!hex_string && char_count > 0) - { - throw std::invalid_argument("hex_string"); - } - if (uint64_count && !result) - { - throw std::invalid_argument("result"); - } - if (unsigned_gt(get_hex_string_bit_count(hex_string, char_count), - mul_safe(uint64_count, static_cast(bits_per_uint64)))) - { - throw std::invalid_argument("hex_string"); - } -#endif - const char *hex_string_ptr = hex_string + char_count; - for (std::size_t uint64_index = 0; - uint64_index < uint64_count; uint64_index++) - { - std::uint64_t value = 0; - for (int bit_index = 0; bit_index < bits_per_uint64; - bit_index += bits_per_nibble) - { - if (hex_string_ptr == hex_string) - { - break; - } - char hex = *--hex_string_ptr; - int nibble = hex_to_nibble(hex); - if (nibble == -1) - { - throw std::invalid_argument("hex_value"); - } - value |= static_cast(nibble) << bit_index; - } - result[uint64_index] = value; - } - } - - SEAL_NODISCARD inline auto allocate_uint( - std::size_t uint64_count, MemoryPool &pool) - { - return allocate(uint64_count, pool); - } - - inline void set_zero_uint(std::size_t uint64_count, std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!result && uint64_count) - { - throw std::invalid_argument("result"); - } -#endif - std::fill_n(result, uint64_count, std::uint64_t(0)); - } - - SEAL_NODISCARD inline auto allocate_zero_uint( - std::size_t uint64_count, MemoryPool &pool) - { - auto result(allocate_uint(uint64_count, pool)); - set_zero_uint(uint64_count, result.get()); - return result; - - // The following looks better but seems to yield worse results. - // return allocate(uint64_count, pool, std::uint64_t(0)); - } - - inline void set_uint( - std::uint64_t value, std::size_t uint64_count, std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!uint64_count) - { - throw std::invalid_argument("uint64_count"); - } - if (!result) - { - throw std::invalid_argument("result"); - } -#endif - *result++ = value; - for (; --uint64_count; result++) - { - *result = 0; - } - } - - inline void set_uint_uint(const std::uint64_t *value, - std::size_t uint64_count, std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!value && uint64_count) - { - throw std::invalid_argument("value"); - } - if (!result && uint64_count) - { - throw std::invalid_argument("result"); - } -#endif - if ((value == result) || !uint64_count) - { - return; - } - std::copy_n(value, uint64_count, result); - } - - SEAL_NODISCARD inline bool is_zero_uint( - const std::uint64_t *value, std::size_t uint64_count) - { -#ifdef SEAL_DEBUG - if (!value && uint64_count) - { - throw std::invalid_argument("value"); - } -#endif - return std::all_of(value, value + uint64_count, - [](auto coeff) -> bool { return !coeff; }); - } - - SEAL_NODISCARD inline bool is_equal_uint( - const std::uint64_t *value, std::size_t uint64_count, - std::uint64_t scalar) - { -#ifdef SEAL_DEBUG - if (!value) - { - throw std::invalid_argument("value"); - } - if (!uint64_count) - { - throw std::invalid_argument("uint64_count"); - } -#endif - if (*value++ != scalar) - { - return false; - } - return std::all_of(value, value + uint64_count - 1, - [](auto coeff) -> bool { return !coeff; }); - } - - SEAL_NODISCARD inline bool is_high_bit_set_uint( - const std::uint64_t *value, std::size_t uint64_count) - { -#ifdef SEAL_DEBUG - if (!value) - { - throw std::invalid_argument("value"); - } - if (!uint64_count) - { - throw std::invalid_argument("uint64_count"); - } -#endif - return (value[uint64_count - 1] >> (bits_per_uint64 - 1)) != 0; - } -#ifndef SEAL_USE_MAYBE_UNUSED -#if (SEAL_COMPILER == SEAL_COMPILER_GCC) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#elif (SEAL_COMPILER == SEAL_COMPILER_CLANG) -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wunused-parameter" -#endif -#endif - SEAL_NODISCARD inline bool is_bit_set_uint( - const std::uint64_t *value, - std::size_t uint64_count SEAL_MAYBE_UNUSED, int bit_index) - { -#ifdef SEAL_DEBUG - if (!value) - { - throw std::invalid_argument("value"); - } - if (!uint64_count) - { - throw std::invalid_argument("uint64_count"); - } - if (bit_index < 0 || - static_cast(bit_index) >= - static_cast(uint64_count) * bits_per_uint64) - { - throw std::invalid_argument("bit_index"); - } -#endif - int uint64_index = bit_index / bits_per_uint64; - int sub_bit_index = bit_index - uint64_index * bits_per_uint64; - return ((value[static_cast(uint64_index)] - >> sub_bit_index) & 1) != 0; - } - - inline void set_bit_uint(std::uint64_t *value, - std::size_t uint64_count SEAL_MAYBE_UNUSED, int bit_index) - { -#ifdef SEAL_DEBUG - if (!value) - { - throw std::invalid_argument("value"); - } - if (!uint64_count) - { - throw std::invalid_argument("uint64_count"); - } - if (bit_index < 0 || - static_cast(bit_index) >= - static_cast(uint64_count) * bits_per_uint64) - { - throw std::invalid_argument("bit_index"); - } -#endif - int uint64_index = bit_index / bits_per_uint64; - int sub_bit_index = bit_index % bits_per_uint64; - value[static_cast(uint64_index)] |= - std::uint64_t(1) << sub_bit_index; - } -#ifndef SEAL_USE_MAYBE_UNUSED -#if (SEAL_COMPILER == SEAL_COMPILER_GCC) -#pragma GCC diagnostic pop -#elif (SEAL_COMPILER == SEAL_COMPILER_CLANG) -#pragma clang diagnostic pop -#endif -#endif - SEAL_NODISCARD inline int get_significant_bit_count_uint( - const std::uint64_t *value, std::size_t uint64_count) - { -#ifdef SEAL_DEBUG - if (!value && uint64_count) - { - throw std::invalid_argument("value"); - } - if (!uint64_count) - { - throw std::invalid_argument("uint64_count"); - } -#endif - if (!uint64_count) - { - return 0; - } - - value += uint64_count - 1; - for (; *value == 0 && uint64_count > 1; uint64_count--) - { - value--; - } - - return static_cast(uint64_count - 1) * bits_per_uint64 + - get_significant_bit_count(*value); - } - - SEAL_NODISCARD inline std::size_t get_significant_uint64_count_uint( - const std::uint64_t *value, std::size_t uint64_count) - { -#ifdef SEAL_DEBUG - if (!value && uint64_count) - { - throw std::invalid_argument("value"); - } - if (!uint64_count) - { - throw std::invalid_argument("uint64_count"); - } -#endif - value += uint64_count - 1; - for (; uint64_count && !*value; uint64_count--) - { - value--; - } - - return uint64_count; - } - - SEAL_NODISCARD inline std::size_t get_nonzero_uint64_count_uint( - const std::uint64_t *value, std::size_t uint64_count) - { -#ifdef SEAL_DEBUG - if (!value && uint64_count) - { - throw std::invalid_argument("value"); - } - if (!uint64_count) - { - throw std::invalid_argument("uint64_count"); - } -#endif - std::size_t nonzero_count = uint64_count; - - value += uint64_count - 1; - for (; uint64_count; uint64_count--) - { - if (*value-- == 0) - { - nonzero_count--; - } - } - - return nonzero_count; - } - - inline void set_uint_uint(const std::uint64_t *value, - std::size_t value_uint64_count, - std::size_t result_uint64_count, std::uint64_t *result) - { -#ifdef SEAL_DEBUG - if (!value && value_uint64_count) - { - throw std::invalid_argument("value"); - } - if (!result && result_uint64_count) - { - throw std::invalid_argument("result"); - } -#endif - if (value == result || !value_uint64_count) - { - // Fast path to handle self assignment. - std::fill(result + value_uint64_count, - result + result_uint64_count, std::uint64_t(0)); - } - else - { - std::size_t min_uint64_count = - std::min(value_uint64_count, result_uint64_count); - std::copy_n(value, min_uint64_count, result); - std::fill(result + min_uint64_count, - result + result_uint64_count, std::uint64_t(0)); - } - } - - SEAL_NODISCARD inline int get_power_of_two(std::uint64_t value) - { - if (value == 0 || (value & (value - 1)) != 0) - { - return -1; - } - - unsigned long result = 0; - SEAL_MSB_INDEX_UINT64(&result, value); - return static_cast(result); - } - - SEAL_NODISCARD inline int get_power_of_two_minus_one( - std::uint64_t value) - { - if (value == 0xFFFFFFFFFFFFFFFF) - { - return bits_per_uint64; - } - return get_power_of_two(value + 1); - } - - SEAL_NODISCARD inline int get_power_of_two_uint( - const std::uint64_t *operand, std::size_t uint64_count) - { -#ifdef SEAL_DEBUG - if (!operand && uint64_count) - { - throw std::invalid_argument("operand"); - } -#endif - operand += uint64_count; - int long_index = safe_cast(uint64_count), local_result = -1; - for (; (long_index >= 1) && (local_result == -1); long_index--) - { - operand--; - local_result = get_power_of_two(*operand); - } - - // If local_result != -1, we've found a power-of-two highest order block, - // in which case need to check that rest are zero. - // If local_result == -1, operand is not power of two. - if (local_result == -1) - { - return -1; - } - - int zeros = 1; - for (int j = long_index; j >= 1; j--) - { - zeros &= (*--operand == 0); - } - - return add_safe(mul_safe(zeros, - add_safe(local_result, - mul_safe(long_index, bits_per_uint64))), zeros, -1); - } - - SEAL_NODISCARD inline int get_power_of_two_minus_one_uint( - const std::uint64_t *operand, std::size_t uint64_count) - { -#ifdef SEAL_DEBUG - if (!operand && uint64_count) - { - throw std::invalid_argument("operand"); - } - if (unsigned_geq(uint64_count, std::numeric_limits::max())) - { - throw std::invalid_argument("uint64_count"); - } -#endif - operand += uint64_count; - int long_index = safe_cast(uint64_count), local_result = 0; - for (; (long_index >= 1) && (local_result == 0); long_index--) - { - operand--; - local_result = get_power_of_two_minus_one(*operand); - } - - // If local_result != -1, we've found a power-of-two-minus-one highest - // order block, in which case need to check that rest are ~0. - // If local_result == -1, operand is not power of two minus one. - if (local_result == -1) - { - return -1; - } - - int ones = 1; - for (int j = long_index; j >= 1; j--) - { - ones &= (~*--operand == 0); - } - - return add_safe(mul_safe(ones, - add_safe(local_result, - mul_safe(long_index, bits_per_uint64))), ones, -1); - } - - inline void filter_highbits_uint(std::uint64_t *operand, - std::size_t uint64_count, int bit_count) - { - std::size_t bits_per_uint64_sz = - static_cast(bits_per_uint64); -#ifdef SEAL_DEBUG - if (!operand && uint64_count) - { - throw std::invalid_argument("operand"); - } - if (bit_count < 0 || unsigned_gt(bit_count, - mul_safe(uint64_count, bits_per_uint64_sz))) - { - throw std::invalid_argument("bit_count"); - } -#endif - if (unsigned_eq(bit_count, mul_safe(uint64_count, bits_per_uint64_sz))) - { - return; - } - int uint64_index = bit_count / bits_per_uint64; - int subbit_index = bit_count - uint64_index * bits_per_uint64; - operand += uint64_index; - *operand++ &= (std::uint64_t(1) << subbit_index) - 1; - for (int long_index = uint64_index + 1; - unsigned_lt(long_index, uint64_count); long_index++) - { - *operand++ = 0; - } - } - - SEAL_NODISCARD inline auto duplicate_uint_if_needed( - const std::uint64_t *input, std::size_t uint64_count, - std::size_t new_uint64_count, bool force, MemoryPool &pool) - { -#ifdef SEAL_DEBUG - if (!input && uint64_count) - { - throw std::invalid_argument("uint"); - } -#endif - if (!force && uint64_count >= new_uint64_count) - { - return ConstPointer::Aliasing(input); - } - - auto allocation(allocate(new_uint64_count, pool)); - set_uint_uint(input, uint64_count, new_uint64_count, allocation.get()); - return ConstPointer(std::move(allocation)); - } - - SEAL_NODISCARD inline int compare_uint_uint( - const std::uint64_t *operand1, const std::uint64_t *operand2, - std::size_t uint64_count) - { -#ifdef SEAL_DEBUG - if (!operand1 && uint64_count) - { - throw std::invalid_argument("operand1"); - } - if (!operand2 && uint64_count) - { - throw std::invalid_argument("operand2"); - } -#endif - int result = 0; - operand1 += uint64_count - 1; - operand2 += uint64_count - 1; - - for (; (result == 0) && uint64_count--; operand1--, operand2--) - { - result = (*operand1 > *operand2) - (*operand1 < *operand2); - } - return result; - } - - SEAL_NODISCARD inline int compare_uint_uint( - const std::uint64_t *operand1, std::size_t operand1_uint64_count, - const std::uint64_t *operand2, std::size_t operand2_uint64_count) - { -#ifdef SEAL_DEBUG - if (!operand1 && operand1_uint64_count) - { - throw std::invalid_argument("operand1"); - } - if (!operand2 && operand2_uint64_count) - { - throw std::invalid_argument("operand2"); - } -#endif - int result = 0; - operand1 += operand1_uint64_count - 1; - operand2 += operand2_uint64_count - 1; - - std::size_t min_uint64_count = - std::min(operand1_uint64_count, operand2_uint64_count); - - operand1_uint64_count -= min_uint64_count; - for (; (result == 0) && operand1_uint64_count--; operand1--) - { - result = (*operand1 > 0); - } - - operand2_uint64_count -= min_uint64_count; - for (; (result == 0) && operand2_uint64_count--; operand2--) - { - result = -(*operand2 > 0); - } - - for (; (result == 0) && min_uint64_count--; operand1--, operand2--) - { - result = (*operand1 > *operand2) - (*operand1 < *operand2); - } - return result; - } - - SEAL_NODISCARD inline bool is_greater_than_uint_uint( - const std::uint64_t *operand1, const std::uint64_t *operand2, - std::size_t uint64_count) - { - return compare_uint_uint(operand1, operand2, uint64_count) > 0; - } - - SEAL_NODISCARD inline bool is_greater_than_or_equal_uint_uint( - const std::uint64_t *operand1, const std::uint64_t *operand2, - std::size_t uint64_count) - { - return compare_uint_uint(operand1, operand2, uint64_count) >= 0; - } - - SEAL_NODISCARD inline bool is_less_than_uint_uint( - const std::uint64_t *operand1, const std::uint64_t *operand2, - std::size_t uint64_count) - { - return compare_uint_uint(operand1, operand2, uint64_count) < 0; - } - - SEAL_NODISCARD inline bool is_less_than_or_equal_uint_uint( - const std::uint64_t *operand1, const std::uint64_t *operand2, - std::size_t uint64_count) - { - return compare_uint_uint(operand1, operand2, uint64_count) <= 0; - } - - SEAL_NODISCARD inline bool is_equal_uint_uint( - const std::uint64_t *operand1, const std::uint64_t *operand2, - std::size_t uint64_count) - { - return compare_uint_uint(operand1, operand2, uint64_count) == 0; - } - - SEAL_NODISCARD inline bool is_not_equal_uint_uint( - const std::uint64_t *operand1, const std::uint64_t *operand2, - std::size_t uint64_count) - { - return compare_uint_uint(operand1, operand2, uint64_count) != 0; - } - - SEAL_NODISCARD inline bool is_greater_than_uint_uint( - const std::uint64_t *operand1, std::size_t operand1_uint64_count, - const std::uint64_t *operand2, std::size_t operand2_uint64_count) - { - return compare_uint_uint(operand1, operand1_uint64_count, operand2, - operand2_uint64_count) > 0; - } - - SEAL_NODISCARD inline bool is_greater_than_or_equal_uint_uint( - const std::uint64_t *operand1, std::size_t operand1_uint64_count, - const std::uint64_t *operand2, std::size_t operand2_uint64_count) - { - return compare_uint_uint(operand1, operand1_uint64_count, operand2, - operand2_uint64_count) >= 0; - } - - SEAL_NODISCARD inline bool is_less_than_uint_uint( - const std::uint64_t *operand1, std::size_t operand1_uint64_count, - const std::uint64_t *operand2, std::size_t operand2_uint64_count) - { - return compare_uint_uint(operand1, operand1_uint64_count, operand2, - operand2_uint64_count) < 0; - } - - SEAL_NODISCARD inline bool is_less_than_or_equal_uint_uint( - const std::uint64_t *operand1, std::size_t operand1_uint64_count, - const std::uint64_t *operand2, std::size_t operand2_uint64_count) - { - return compare_uint_uint(operand1, operand1_uint64_count, operand2, - operand2_uint64_count) <= 0; - } - - SEAL_NODISCARD inline bool is_equal_uint_uint( - const std::uint64_t *operand1, std::size_t operand1_uint64_count, - const std::uint64_t *operand2, std::size_t operand2_uint64_count) - { - return compare_uint_uint(operand1, operand1_uint64_count, operand2, - operand2_uint64_count) == 0; - } - - SEAL_NODISCARD inline bool is_not_equal_uint_uint( - const std::uint64_t *operand1, std::size_t operand1_uint64_count, - const std::uint64_t *operand2, std::size_t operand2_uint64_count) - { - return compare_uint_uint(operand1, operand1_uint64_count, operand2, - operand2_uint64_count) != 0; - } - - SEAL_NODISCARD inline std::uint64_t hamming_weight( - std::uint64_t value) - { - std::uint64_t res = 0; - while (value) - { - res++; - value &= value - 1; - } - return res; - } - - SEAL_NODISCARD inline std::uint64_t hamming_weight_split( - std::uint64_t value) - { - std::uint64_t hwx = hamming_weight(value); - std::uint64_t target = (hwx + 1) >> 1; - std::uint64_t now = 0; - std::uint64_t result = 0; - - for (int i = 0; i < bits_per_uint64; i++) - { - std::uint64_t xbit = value & 1; - value = value >> 1; - now += xbit; - result += (xbit << i); - - if (now >= target) - { - break; - } - } - return result; - } - } -} diff --git a/SEAL/native/src/seal/valcheck.cpp b/SEAL/native/src/seal/valcheck.cpp deleted file mode 100644 index 4e4d128..0000000 --- a/SEAL/native/src/seal/valcheck.cpp +++ /dev/null @@ -1,426 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "seal/valcheck.h" -#include "seal/util/defines.h" -#include "seal/util/common.h" -#include "seal/plaintext.h" -#include "seal/ciphertext.h" -#include "seal/secretkey.h" -#include "seal/publickey.h" -#include "seal/kswitchkeys.h" -#include "seal/relinkeys.h" -#include "seal/galoiskeys.h" - -using namespace std; -using namespace seal::util; - -namespace seal -{ - bool is_metadata_valid_for( - const Plaintext &in, - shared_ptr context) - { - // Verify parameters - if (!context || !context->parameters_set()) - { - return false; - } - - if (in.is_ntt_form()) - { - - // Are the parameters valid for given plaintext? This check is slightly - // non-trivial because we need to consider both the case where key_parms_id - // equals first_parms_id, and cases where they are different. - auto context_data_ptr = context->get_context_data(in.parms_id()); - if (!context_data_ptr || - context_data_ptr->chain_index() > context->first_context_data()->chain_index()) - { - return false; - } - - auto &parms = context_data_ptr->parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t poly_modulus_degree = parms.poly_modulus_degree(); - if (mul_safe(coeff_modulus.size(), poly_modulus_degree) != in.coeff_count()) - { - return false; - } - } - else - { - auto &parms = context->first_context_data()->parms(); - if (parms.scheme() != scheme_type::BFV) - { - return false; - } - - size_t poly_modulus_degree = parms.poly_modulus_degree(); - if (in.coeff_count() > poly_modulus_degree) - { - return false; - } - } - - return true; - } - - bool is_metadata_valid_for( - const Ciphertext &in, - shared_ptr context) - { - // Verify parameters - if (!context || !context->parameters_set()) - { - return false; - } - - // Are the parameters valid for given ciphertext? This check is slightly - // non-trivial because we need to consider both the case where key_parms_id - // equals first_parms_id, and cases where they are different. - auto context_data_ptr = context->get_context_data(in.parms_id()); - if (!context_data_ptr || - context_data_ptr->chain_index() > context->first_context_data()->chain_index()) - { - return false; - } - - // Check that the metadata matches - auto &coeff_modulus = context_data_ptr->parms().coeff_modulus(); - size_t poly_modulus_degree = context_data_ptr->parms().poly_modulus_degree(); - if ((coeff_modulus.size() != in.coeff_mod_count()) || - (poly_modulus_degree != in.poly_modulus_degree())) - { - return false; - } - - // Check that size is either 0 or within right bounds - auto size = in.size(); - if ((size < SEAL_CIPHERTEXT_SIZE_MIN && size != 0) || - size > SEAL_CIPHERTEXT_SIZE_MAX) - { - return false; - } - - return true; - } - - bool is_metadata_valid_for( - const SecretKey &in, - shared_ptr context) - { - // Verify parameters - if (!context || !context->parameters_set()) - { - return false; - } - - // Are the parameters valid for given secret key? - if (in.parms_id() != context->key_parms_id()) - { - return false; - } - - auto context_data_ptr = context->key_context_data(); - auto &parms = context_data_ptr->parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t poly_modulus_degree = parms.poly_modulus_degree(); - if (mul_safe(coeff_modulus.size(), poly_modulus_degree) != in.data().coeff_count()) - { - return false; - } - - return true; - } - - bool is_metadata_valid_for( - const PublicKey &in, - shared_ptr context) - { - // Verify parameters - if (!context || !context->parameters_set()) - { - return false; - } - - // Are the parameters valid for given public key? - if (in.parms_id() != context->key_parms_id() || !in.data().is_ntt_form()) - { - return false; - } - - // Check that the metadata matches - auto context_data_ptr = context->key_context_data(); - auto &coeff_modulus = context_data_ptr->parms().coeff_modulus(); - size_t poly_modulus_degree = context_data_ptr->parms().poly_modulus_degree(); - if ((coeff_modulus.size() != in.data().coeff_mod_count()) || - (poly_modulus_degree != in.data().poly_modulus_degree())) - { - return false; - } - - // Check that size is right; for public key it should be exactly 2 - if (in.data().size() != SEAL_CIPHERTEXT_SIZE_MIN) - { - return false; - } - - return true; - } - - bool is_metadata_valid_for( - const KSwitchKeys &in, - shared_ptr context) - { - // Verify parameters - if (!context || !context->parameters_set()) - { - return false; - } - - // Are the parameters valid for given relinearization keys? - if (in.parms_id() != context->key_parms_id()) - { - return false; - } - - for (auto &a : in.data()) - { - for (auto &b : a) - { - // Check that b is a valid public key (metadata only); this also - // checks that its parms_id matches key_parms_id. - if (!is_metadata_valid_for(b, context)) - { - return false; - } - } - } - - return true; - } - - bool is_metadata_valid_for( - const RelinKeys &in, - shared_ptr context) - { - // Check that the size is within bounds. - bool size_check = !in.size() || - (in.size() <= SEAL_CIPHERTEXT_SIZE_MAX - 2 && - in.size() >= SEAL_CIPHERTEXT_SIZE_MIN - 2); - return is_metadata_valid_for( - static_cast(in), move(context)) && size_check; - } - - bool is_metadata_valid_for( - const GaloisKeys &in, - shared_ptr context) - { - return is_metadata_valid_for( - static_cast(in), move(context)); - } - - bool is_valid_for( - const Plaintext &in, - shared_ptr context) - { - // Check metadata - if (!is_metadata_valid_for(in, context)) - { - return false; - } - - // Check the data - if (in.is_ntt_form()) - { - auto context_data_ptr = context->get_context_data(in.parms_id()); - auto &parms = context_data_ptr->parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_mod_count = coeff_modulus.size(); - - const Plaintext::pt_coeff_type *ptr = in.data(); - for (size_t j = 0; j < coeff_mod_count; j++) - { - uint64_t modulus = coeff_modulus[j].value(); - size_t poly_modulus_degree = parms.poly_modulus_degree(); - for (; poly_modulus_degree--; ptr++) - { - if (*ptr >= modulus) - { - return false; - } - } - } - } - else - { - auto &parms = context->first_context_data()->parms(); - uint64_t modulus = parms.plain_modulus().value(); - const Plaintext::pt_coeff_type *ptr = in.data(); - auto size = in.coeff_count(); - for (size_t k = 0; k < size; k++, ptr++) - { - if (*ptr >= modulus) - { - return false; - } - } - } - - return true; - } - - bool is_valid_for( - const Ciphertext &in, - shared_ptr context) - { - // Check metadata - if (!is_metadata_valid_for(in, context)) - { - return false; - } - - // Check the data - auto context_data_ptr = context->get_context_data(in.parms_id()); - const auto &coeff_modulus = context_data_ptr->parms().coeff_modulus(); - size_t coeff_mod_count = coeff_modulus.size(); - - const Ciphertext::ct_coeff_type *ptr = in.data(); - auto size = in.size(); - - for (size_t i = 0; i < size; i++) - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - uint64_t modulus = coeff_modulus[j].value(); - auto poly_modulus_degree = in.poly_modulus_degree(); - for (; poly_modulus_degree--; ptr++) - { - if (*ptr >= modulus) - { - return false; - } - } - } - } - - return true; - } - - bool is_valid_for( - const SecretKey &in, - shared_ptr context) - { - // Check metadata - if (!is_metadata_valid_for(in, context)) - { - return false; - } - - // Check the data - auto context_data_ptr = context->key_context_data(); - auto &parms = context_data_ptr->parms(); - auto &coeff_modulus = parms.coeff_modulus(); - size_t coeff_mod_count = coeff_modulus.size(); - - const Plaintext::pt_coeff_type *ptr = in.data().data(); - for (size_t j = 0; j < coeff_mod_count; j++) - { - uint64_t modulus = coeff_modulus[j].value(); - size_t poly_modulus_degree = parms.poly_modulus_degree(); - for (; poly_modulus_degree--; ptr++) - { - if (*ptr >= modulus) - { - return false; - } - } - } - - return true; - } - - bool is_valid_for( - const PublicKey &in, - shared_ptr context) - { - // Check metadata - if (!is_metadata_valid_for(in, context)) - { - return false; - } - - // Check the data - auto context_data_ptr = context->key_context_data(); - const auto &coeff_modulus = context_data_ptr->parms().coeff_modulus(); - size_t coeff_mod_count = coeff_modulus.size(); - - const Ciphertext::ct_coeff_type *ptr = in.data().data(); - auto size = in.data().size(); - - for (size_t i = 0; i < size; i++) - { - for (size_t j = 0; j < coeff_mod_count; j++) - { - uint64_t modulus = coeff_modulus[j].value(); - auto poly_modulus_degree = in.data().poly_modulus_degree(); - for (; poly_modulus_degree--; ptr++) - { - if (*ptr >= modulus) - { - return false; - } - } - } - } - - return true; - } - - bool is_valid_for( - const KSwitchKeys &in, - shared_ptr context) - { - // Verify parameters - if (!context || !context->parameters_set()) - { - return false; - } - - // Are the parameters valid for given relinearization keys? - if (in.parms_id() != context->key_parms_id()) - { - return false; - } - - for (auto &a : in.data()) - { - for (auto &b : a) - { - // Check that b is a valid public key; this also checks that its - // parms_id matches key_parms_id. - if (!is_valid_for(b, context)) - { - return false; - } - } - } - - return true; - } - - bool is_valid_for( - const RelinKeys &in, - shared_ptr context) - { - return is_valid_for(static_cast(in), move(context)); - } - - bool is_valid_for( - const GaloisKeys &in, - shared_ptr context) - { - return is_valid_for(static_cast(in), move(context)); - } -} \ No newline at end of file diff --git a/SEAL/native/src/seal/valcheck.h b/SEAL/native/src/seal/valcheck.h deleted file mode 100644 index 9249998..0000000 --- a/SEAL/native/src/seal/valcheck.h +++ /dev/null @@ -1,208 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#pragma once - -#include "seal/util/defines.h" -#include "seal/context.h" -#include - -namespace seal -{ - class Plaintext; - class Ciphertext; - class SecretKey; - class PublicKey; - class KSwitchKeys; - class RelinKeys; - class GaloisKeys; - - /** - Check whether the given plaintext is valid for a given SEALContext. If the - given SEALContext is not set, the encryption parameters are invalid, or the - plaintext data does not match the SEALContext, this function returns false. - Otherwise, returns true. This function only checks the metadata and not the - plaintext data itself. - - @param[in] in The plaintext to check - @param[in] context The SEALContext - */ - SEAL_NODISCARD bool is_metadata_valid_for( - const Plaintext &in, - std::shared_ptr context); - - /** - Check whether the given ciphertext is valid for a given SEALContext. If the - given SEALContext is not set, the encryption parameters are invalid, or the - ciphertext data does not match the SEALContext, this function returns false. - Otherwise, returns true. This function only checks the metadata and not the - ciphertext data itself. - - @param[in] in The ciphertext to check - @param[in] context The SEALContext - */ - SEAL_NODISCARD bool is_metadata_valid_for( - const Ciphertext &in, - std::shared_ptr context); - - /** - Check whether the given secret key is valid for a given SEALContext. If the - given SEALContext is not set, the encryption parameters are invalid, or the - secret key data does not match the SEALContext, this function returns false. - Otherwise, returns true. This function only checks the metadata and not the - secret key data itself. - - @param[in] in The secret key to check - @param[in] context The SEALContext - */ - SEAL_NODISCARD bool is_metadata_valid_for( - const SecretKey &in, - std::shared_ptr context); - - /** - Check whether the given public key is valid for a given SEALContext. If the - given SEALContext is not set, the encryption parameters are invalid, or the - public key data does not match the SEALContext, this function returns false. - Otherwise, returns true. This function only checks the metadata and not the - public key data itself. - - @param[in] in The public key to check - @param[in] context The SEALContext - */ - SEAL_NODISCARD bool is_metadata_valid_for( - const PublicKey &in, - std::shared_ptr context); - - /** - Check whether the given KSwitchKeys is valid for a given SEALContext. If the - given SEALContext is not set, the encryption parameters are invalid, or the - KSwitchKeys data does not match the SEALContext, this function returns false. - Otherwise, returns true. This function only checks the metadata and not the - KSwitchKeys data itself. - - @param[in] in The KSwitchKeys to check - @param[in] context The SEALContext - */ - SEAL_NODISCARD bool is_metadata_valid_for( - const KSwitchKeys &in, - std::shared_ptr context); - - /** - Check whether the given RelinKeys is valid for a given SEALContext. If the - given SEALContext is not set, the encryption parameters are invalid, or the - RelinKeys data does not match the SEALContext, this function returns false. - Otherwise, returns true. This function only checks the metadata and not the - RelinKeys data itself. - - @param[in] in The RelinKeys to check - @param[in] context The SEALContext - */ - SEAL_NODISCARD bool is_metadata_valid_for( - const RelinKeys &in, - std::shared_ptr context); - - /** - Check whether the given GaloisKeys is valid for a given SEALContext. If the - given SEALContext is not set, the encryption parameters are invalid, or the - GaloisKeys data does not match the SEALContext, this function returns false. - Otherwise, returns true. This function only checks the metadata and not the - GaloisKeys data itself. - - @param[in] in The RelinKeys to check - @param[in] context The SEALContext - */ - SEAL_NODISCARD bool is_metadata_valid_for( - const GaloisKeys &in, - std::shared_ptr context); - - /** - Check whether the given plaintext is valid for a given SEALContext. If the - given SEALContext is not set, the encryption parameters are invalid, or the - plaintext data does not match the SEALContext, this function returns false. - Otherwise, returns true. - - @param[in] in The plaintext to check - @param[in] context The SEALContext - */ - SEAL_NODISCARD bool is_valid_for( - const Plaintext &in, - std::shared_ptr context); - - /** - Check whether the given ciphertext is valid for a given SEALContext. If the - given SEALContext is not set, the encryption parameters are invalid, or the - ciphertext data does not match the SEALContext, this function returns false. - Otherwise, returns true. - - @param[in] in The ciphertext to check - @param[in] context The SEALContext - */ - SEAL_NODISCARD bool is_valid_for( - const Ciphertext &in, - std::shared_ptr context); - - /** - Check whether the given secret key is valid for a given SEALContext. If the - given SEALContext is not set, the encryption parameters are invalid, or the - secret key data does not match the SEALContext, this function returns false. - Otherwise, returns true. - - @param[in] in The secret key to check - @param[in] context The SEALContext - */ - SEAL_NODISCARD bool is_valid_for( - const SecretKey &in, - std::shared_ptr context); - - /** - Check whether the given public key is valid for a given SEALContext. If the - given SEALContext is not set, the encryption parameters are invalid, or the - public key data does not match the SEALContext, this function returns false. - Otherwise, returns true. - - @param[in] in The public key to check - @param[in] context The SEALContext - */ - SEAL_NODISCARD bool is_valid_for( - const PublicKey &in, - std::shared_ptr context); - - /** - Check whether the given KSwitchKeys is valid for a given SEALContext. If - the given SEALContext is not set, the encryption parameters are invalid, - or the KSwitchKeys data does not match the SEALContext, this function returns - false. Otherwise, returns true. - - @param[in] in The KSwitchKeys to check - @param[in] context The SEALContext - */ - SEAL_NODISCARD bool is_valid_for( - const KSwitchKeys &in, - std::shared_ptr context); - - /** - Check whether the given RelinKeys is valid for a given SEALContext. If the - given SEALContext is not set, the encryption parameters are invalid, or the - RelinKeys data does not match the SEALContext, this function returns false. - Otherwise, returns true. - - @param[in] in The RelinKeys to check - @param[in] context The SEALContext - */ - SEAL_NODISCARD bool is_valid_for( - const RelinKeys &in, - std::shared_ptr context); - - /** - Check whether the given GaloisKeys is valid for a given SEALContext. If the - given SEALContext is not set, the encryption parameters are invalid, or the - GaloisKeys data does not match the SEALContext, this function returns false. - Otherwise, returns true. - - @param[in] in The GaloisKeys to check - @param[in] context The SEALContext - */ - SEAL_NODISCARD bool is_valid_for( - const GaloisKeys &in, - std::shared_ptr context); -} diff --git a/SEAL/native/tests/CMakeLists.txt b/SEAL/native/tests/CMakeLists.txt deleted file mode 100644 index 279115c..0000000 --- a/SEAL/native/tests/CMakeLists.txt +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT license. - -cmake_minimum_required(VERSION 3.10) - -project(SEALTest LANGUAGES CXX) - -# Executable will be in ../bin -set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_SOURCE_DIR}/../bin) - -add_executable(sealtest seal/testrunner.cpp) - -# Import Microsoft SEAL -find_package(SEAL 3.3.2 EXACT REQUIRED) - -# Add source files -add_subdirectory(seal) - -# Only build GTest -option(BUILD_GMOCK OFF) -option(INSTALL_GTEST OFF) -mark_as_advanced(BUILD_GMOCK INSTALL_GTEST) - -# Add GTest -set(GTEST_DIR "thirdparty/googletest") -if(NOT EXISTS ${GTEST_DIR}/CMakeLists.txt) - message(FATAL_ERROR "Could not find `${GTEST_DIR}/CMakeLists.txt`. Run `git submodule update --init` and retry.") -endif() -add_subdirectory(${GTEST_DIR}) - -# Link Microsoft SEAL and GTest -target_link_libraries(sealtest SEAL::seal gtest) diff --git a/SEAL/native/tests/SEALTest.vcxproj b/SEAL/native/tests/SEALTest.vcxproj deleted file mode 100644 index 699cb45..0000000 --- a/SEAL/native/tests/SEALTest.vcxproj +++ /dev/null @@ -1,141 +0,0 @@ - - - - - Debug - x64 - - - Release - x64 - - - - {0345DC4D-EFE3-460E-AB7E-AA6E05BB8DFF} - Win32Proj - Application - v141 - Unicode - 10.0.16299.0 - - - - - - - - - $(ProjectDir)..\bin\$(Platform)\$(Configuration)\ - $(ProjectDir)obj\$(Platform)\$(Configuration)\ - sealtest - - - $(ProjectDir)..\bin\$(Platform)\$(Configuration)\ - $(ProjectDir)obj\$(Platform)\$(Configuration)\ - sealtest - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - NotUsing - Disabled - X64;_SILENCE_TR1_NAMESPACE_DEPRECATION_WARNING;_DEBUG;_CONSOLE;%(PreprocessorDefinitions) - false - EnableFastChecks - MultiThreadedDebugDLL - Level3 - $(SolutionDir)/native/src;%(AdditionalIncludeDirectories) - Guard - stdcpp17 - /Zc:__cplusplus %(AdditionalOptions) - true - - - true - Console - seal.lib;%(AdditionalDependencies) - $(ProjectDir)..\lib\$(Platform)\$(Configuration);%(AdditionalLibraryDirectories) - - - - - NotUsing - X64;_SILENCE_TR1_NAMESPACE_DEPRECATION_WARNING;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) - MultiThreadedDLL - Level3 - ProgramDatabase - $(SolutionDir)/native/src;%(AdditionalIncludeDirectories) - Guard - stdcpp17 - /Zc:__cplusplus %(AdditionalOptions) - true - - - true - Console - true - true - seal.lib;%(AdditionalDependencies) - $(ProjectDir)..\lib\$(Platform)\$(Configuration);%(AdditionalLibraryDirectories) - - - - - This project references NuGet package(s) that are missing on this computer. Use NuGet Package Restore to download them. For more information, see http://go.microsoft.com/fwlink/?LinkID=322105. The missing file is {0}. - - - - \ No newline at end of file diff --git a/SEAL/native/tests/SEALTest.vcxproj.filters b/SEAL/native/tests/SEALTest.vcxproj.filters deleted file mode 100644 index 2ea8445..0000000 --- a/SEAL/native/tests/SEALTest.vcxproj.filters +++ /dev/null @@ -1,165 +0,0 @@ - - - - - {4FC737F1-C7A5-4376-A066-2A32D752A2FF} - cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx - - - {93995380-89BD-4b04-88EB-625FBE52EBFB} - h;hh;hpp;hxx;hm;inl;inc;xsd - - - {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} - rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms - - - {6c39d93e-a64a-44b3-95ca-ba22fd03ea17} - cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx - - - {c2fb4c08-c2b6-492e-872b-05c104d883e6} - - - {ba659d9a-0bce-409b-8429-a60c1b9506ee} - - - {86489695-28b4-44f1-8f0d-2637ecd43cb3} - - - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files\util - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - Source Files - - - - - Other - - - Other\seal - - - Other\seal\util - - - - - Other - - - \ No newline at end of file diff --git a/SEAL/native/tests/packages.config b/SEAL/native/tests/packages.config deleted file mode 100644 index 3c6fe17..0000000 --- a/SEAL/native/tests/packages.config +++ /dev/null @@ -1,4 +0,0 @@ - - - - \ No newline at end of file diff --git a/SEAL/native/tests/seal/CMakeLists.txt b/SEAL/native/tests/seal/CMakeLists.txt deleted file mode 100644 index f7c4a0e..0000000 --- a/SEAL/native/tests/seal/CMakeLists.txt +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT license. - -target_sources(sealtest - PRIVATE - ${CMAKE_CURRENT_LIST_DIR}/batchencoder.cpp - ${CMAKE_CURRENT_LIST_DIR}/biguint.cpp - ${CMAKE_CURRENT_LIST_DIR}/ciphertext.cpp - ${CMAKE_CURRENT_LIST_DIR}/ckks.cpp - ${CMAKE_CURRENT_LIST_DIR}/context.cpp - ${CMAKE_CURRENT_LIST_DIR}/intencoder.cpp - ${CMAKE_CURRENT_LIST_DIR}/encryptionparams.cpp - ${CMAKE_CURRENT_LIST_DIR}/encryptor.cpp - ${CMAKE_CURRENT_LIST_DIR}/evaluator.cpp - ${CMAKE_CURRENT_LIST_DIR}/galoiskeys.cpp - ${CMAKE_CURRENT_LIST_DIR}/intarray.cpp - ${CMAKE_CURRENT_LIST_DIR}/keygenerator.cpp - ${CMAKE_CURRENT_LIST_DIR}/memorymanager.cpp - ${CMAKE_CURRENT_LIST_DIR}/modulus.cpp - ${CMAKE_CURRENT_LIST_DIR}/plaintext.cpp - ${CMAKE_CURRENT_LIST_DIR}/publickey.cpp - ${CMAKE_CURRENT_LIST_DIR}/randomgen.cpp - ${CMAKE_CURRENT_LIST_DIR}/randomtostd.cpp - ${CMAKE_CURRENT_LIST_DIR}/relinkeys.cpp - ${CMAKE_CURRENT_LIST_DIR}/secretkey.cpp - ${CMAKE_CURRENT_LIST_DIR}/smallmodulus.cpp -) - -add_subdirectory(util) diff --git a/SEAL/native/tests/seal/baseconverter.cpp b/SEAL/native/tests/seal/baseconverter.cpp deleted file mode 100644 index d0649c5..0000000 --- a/SEAL/native/tests/seal/baseconverter.cpp +++ /dev/null @@ -1,552 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "CppUnitTest.h" -#include -#include "util/mempool.h" -#include "util/uintcore.h" -#include "memorypoolhandle.h" -#include "smallmodulus.h" -#include "util/BaseConverter.h" -#include "util/uintarith.h" -#include "util/uintarithsmallmod.h" -#include "util/uintarithmod.h" -#include "primes.h" - -using namespace Microsoft::VisualStudio::CppUnitTestFramework; -using namespace seal::util; -using namespace seal; -using namespace std; - -namespace SEALTest -{ - namespace util - { - TEST_CLASS(BaseConverterClass) - { - public: - TEST_METHOD(BaseConverterConstructor) - { - MemoryPoolMT &pool = *MemoryPoolMT::default_pool(); - vector coeff_base; - vector aux_base; - SmallModulus mtilda = small_mods[10]; - SmallModulus msk = small_mods[11]; - SmallModulus plain_t = small_mods[9]; - int coeff_base_count = 4; - int aux_base_count = 4; - - for (int i = 0; i < coeff_base_count; ++i) - { - coeff_base.push_back(small_mods[i]); - aux_base.push_back(small_mods[i + coeff_base_count]); - } - - BaseConverter BaseConverter(coeff_base, 4, plain_t); - Assert::IsTrue(BaseConverter.is_generated()); - } - - TEST_METHOD(FastBConverter) - { - { - MemoryPoolMT &pool = *MemoryPoolMT::default_pool(); - vector coeff_base; - vector aux_base; - SmallModulus plain_t = small_mods[9]; - int coeff_base_count = 2; - int aux_base_count = 2; - - for (int i = 0; i < coeff_base_count; ++i) - { - coeff_base.push_back(small_mods[i]); - aux_base.push_back(small_mods[i + coeff_base_count + 2]); - } - - BaseConverter BaseConverter(coeff_base, 1, plain_t); - Pointer input(allocate_uint(2, pool)); - Pointer output(allocate_uint(3, pool)); - - // the composed input is 0xffffffffffffff00ffffffffffffff - - input[0] = 4395513236581707780; - input[1] = 4395513390924464132; - - - output[0] = 0xFFFFFFFFFFFFFFFF; - output[1] = 0xFFFFFFFFFFFFFFFF; - output[2] = 0; - - Assert::IsTrue(BaseConverter.fastbconv(input.get(), output.get())); - Assert::AreEqual(static_cast(3116074317392112723), output[0]); - Assert::AreEqual(static_cast(1254200639185090240), output[1]); - Assert::AreEqual(static_cast(3528328721557038672), output[2]); - } - - { - MemoryPoolMT &pool = *MemoryPoolMT::default_pool(); - vector coeff_base; - vector aux_base; - SmallModulus mtilda = small_mods[10]; - SmallModulus msk = small_mods[11]; - SmallModulus plain_t = small_mods[9]; - int coeff_base_count = 2; - int aux_base_count = 2; - - for (int i = 0; i < coeff_base_count; ++i) - { - coeff_base.push_back(small_mods[i]); - aux_base.push_back(small_mods[i + coeff_base_count + 2]); - } - BaseConverter BaseConverter(coeff_base, 4, plain_t); - Pointer input(allocate_uint(8, pool)); - Pointer output(allocate_uint(12, pool)); - - // the composed input is 0xffffffffffffff00ffffffffffffff for all coeffs - // mod q1 - input[0] = 4395513236581707780; // cons - input[1] = 4395513236581707780; // x - input[2] = 4395513236581707780; // x^2 - input[3] = 4395513236581707780; // x^3 - - //mod q2 - input[4] = 4395513390924464132; - input[5] = 4395513390924464132; - input[6] = 4395513390924464132; - input[7] = 4395513390924464132; - - output[0] = 0xFFFFFFFFFFFFFFFF; - output[1] = 0xFFFFFFFFFFFFFFFF; - output[2] = 0; - - Assert::IsTrue(BaseConverter.fastbconv(input.get(), output.get())); - Assert::AreEqual(static_cast(3116074317392112723), output[0]); - Assert::AreEqual(static_cast(3116074317392112723), output[1]); - Assert::AreEqual(static_cast(3116074317392112723), output[2]); - Assert::AreEqual(static_cast(3116074317392112723), output[3]); - - Assert::AreEqual(static_cast(1254200639185090240), output[4]); - Assert::AreEqual(static_cast(1254200639185090240), output[5]); - Assert::AreEqual(static_cast(1254200639185090240), output[6]); - Assert::AreEqual(static_cast(1254200639185090240), output[7]); - - Assert::AreEqual(static_cast(3528328721557038672), output[8]); - Assert::AreEqual(static_cast(3528328721557038672), output[9]); - Assert::AreEqual(static_cast(3528328721557038672), output[10]); - Assert::AreEqual(static_cast(3528328721557038672), output[11]); - } - } - - TEST_METHOD(FastBConvSK) - { - { - MemoryPoolMT &pool = *MemoryPoolMT::default_pool(); - vector coeff_base; - vector aux_base; - SmallModulus mtilda = small_mods[10]; - SmallModulus msk = small_mods[4]; - SmallModulus plain_t = small_mods[9]; - - int coeff_base_count = 2; - int aux_base_count = 2; - for (int i = 0; i < coeff_base_count; ++i) - { - coeff_base.push_back(small_mods[i]); - aux_base.push_back(small_mods[i + coeff_base_count]); - } - - BaseConverter BaseConverter(coeff_base, 1, plain_t); - Pointer input(allocate_uint(3, pool)); - Pointer output(allocate_uint(2, pool)); - - // The composed input is 0xffffffffffffff00ffffffffffffff - - input[0] = 4395583330278772740; - input[1] = 4396634741790752772; - input[2] = 4396375252835237892; // mod msk - - output[0] = 0xFFFFFFFFFFFFFFF; - output[1] = 0xFFFFFFFFFFFFFFF; - - Assert::IsTrue(BaseConverter.fastbconv_sk(input.get(), output.get())); - Assert::AreEqual(static_cast(2494482839790051254), output[0]); - Assert::AreEqual(static_cast(218180408843610743), output[1]); - } - - { - MemoryPoolMT &pool = *MemoryPoolMT::default_pool(); - vector coeff_base; - vector aux_base; - SmallModulus mtilda = small_mods[10]; - SmallModulus msk = small_mods[4]; - SmallModulus plain_t = small_mods[9]; - - int coeff_base_count = 2; - int aux_base_count = 2; - for (int i = 0; i < coeff_base_count; ++i) - { - coeff_base.push_back(small_mods[i]); - aux_base.push_back(small_mods[i + coeff_base_count]); - } - - BaseConverter BaseConverter(coeff_base, 4, plain_t); - Pointer input(allocate_uint(12, pool)); - Pointer output(allocate_uint(8, pool)); - - // The composed input is 0xffffffffffffff00ffffffffffffff - - input[0] = 4395583330278772740; // cons - input[1] = 4395583330278772740; // x - input[2] = 4395583330278772740; // x^2 - input[3] = 4395583330278772740; // x^3 - - input[4] = 4396634741790752772; - input[5] = 4396634741790752772; - input[6] = 4396634741790752772; - input[7] = 4396634741790752772; - - input[8] = 4396375252835237892; // mod msk - input[9] = 4396375252835237892; - input[10] = 4396375252835237892; - input[11] = 4396375252835237892; - - output[0] = 0xFFFFFFFFFFFFFFF; - output[1] = 0xFFFFFFFFFFFFFFF; - - Assert::IsTrue(BaseConverter.fastbconv_sk(input.get(), output.get())); - Assert::AreEqual(static_cast(2494482839790051254), output[0]); //mod q1 - Assert::AreEqual(static_cast(2494482839790051254), output[1]); - Assert::AreEqual(static_cast(2494482839790051254), output[2]); - Assert::AreEqual(static_cast(2494482839790051254), output[3]); - - Assert::AreEqual(static_cast(218180408843610743), output[4]); //mod q2 - Assert::AreEqual(static_cast(218180408843610743), output[5]); - Assert::AreEqual(static_cast(218180408843610743), output[6]); - Assert::AreEqual(static_cast(218180408843610743), output[7]); - } - - } - - TEST_METHOD(MontRq) - { - { - MemoryPoolMT &pool = *MemoryPoolMT::default_pool(); - vector coeff_base; - vector aux_base; - SmallModulus mtilda = small_mods[5]; - SmallModulus msk = small_mods[4]; - SmallModulus plain_t = small_mods[9]; - - int coeff_base_count = 2; - int aux_base_count = 2; - for (int i = 0; i < coeff_base_count; ++i) - { - coeff_base.push_back(small_mods[i]); - aux_base.push_back(small_mods[i + coeff_base_count]); - } - - BaseConverter BaseConverter(coeff_base, 1, plain_t); - Pointer input(allocate_uint(4, pool)); - Pointer output(allocate_uint(3, pool)); - - // The composed input is 0xffffffffffffff00ffffffffffffff - - input[0] = 4395583330278772740; // mod m1 - input[1] = 4396634741790752772; // mod m2 - input[2] = 4396375252835237892; // mod msk - input[3] = 4396146554501595140; // mod m_tilde - - output[0] = 0xfffffffff; - output[1] = 0x00fffffff; - output[2] = 0; - - Assert::IsTrue(BaseConverter.mont_rq(input.get(), output.get())); - Assert::AreEqual(static_cast(1412154008057360306), output[0]); - Assert::AreEqual(static_cast(3215947095329058299), output[1]); - Assert::AreEqual(static_cast(1636465626706639696), output[2]); - } - - { - MemoryPoolMT &pool = *MemoryPoolMT::default_pool(); - vector coeff_base; - vector aux_base; - SmallModulus mtilda = small_mods[5]; - SmallModulus msk = small_mods[4]; - SmallModulus plain_t = small_mods[9]; - - int coeff_base_count = 2; - int aux_base_count = 2; - for (int i = 0; i < coeff_base_count; ++i) - { - coeff_base.push_back(small_mods[i]); - aux_base.push_back(small_mods[i + coeff_base_count]); - } - - BaseConverter BaseConverter(coeff_base, 3, plain_t); - Pointer input(allocate_uint(12, pool)); - Pointer output(allocate_uint(9, pool)); - - // The composed input is 0xffffffffffffff00ffffffffffffff for all coeffs - - input[0] = 4395583330278772740; // cons mod m1 - input[1] = 4395583330278772740; // x mod m1 - input[2] = 4395583330278772740; // x^2 mod m1 - - input[3] = 4396634741790752772; // cons mod m2 - input[4] = 4396634741790752772; // x mod m2 - input[5] = 4396634741790752772; // x^2 mod m2 - - input[6] = 4396375252835237892; // cons mod msk - input[7] = 4396375252835237892; // x mod msk - input[8] = 4396375252835237892; // x^2 mod msk - - input[9] = 4396146554501595140; // cons mod m_tilde - input[10] = 4396146554501595140; // x mod m_tilde - input[11] = 4396146554501595140; // x^2 mod m_tilde - - output[0] = 0xfffffffff; - output[1] = 0x00fffffff; - output[2] = 0; - - Assert::IsTrue(BaseConverter.mont_rq(input.get(), output.get())); - Assert::AreEqual(static_cast(1412154008057360306), output[0]); - Assert::AreEqual(static_cast(1412154008057360306), output[1]); - Assert::AreEqual(static_cast(1412154008057360306), output[2]); - - Assert::AreEqual(static_cast(3215947095329058299), output[3]); - Assert::AreEqual(static_cast(3215947095329058299), output[4]); - Assert::AreEqual(static_cast(3215947095329058299), output[5]); - - Assert::AreEqual(static_cast(1636465626706639696), output[6]); - Assert::AreEqual(static_cast(1636465626706639696), output[7]); - Assert::AreEqual(static_cast(1636465626706639696), output[8]); - } - } - - TEST_METHOD(FastFloor) - { - { - MemoryPoolMT &pool = *MemoryPoolMT::default_pool(); - vector coeff_base; - vector aux_base; - SmallModulus mtilda = small_mods[5]; - SmallModulus msk = small_mods[4]; - SmallModulus plain_t = small_mods[9]; - - int coeff_base_count = 2; - int aux_base_count = 2; - for (int i = 0; i < coeff_base_count; ++i) - { - coeff_base.push_back(small_mods[i]); - aux_base.push_back(small_mods[i + coeff_base_count]); - } - - BaseConverter BaseConverter(coeff_base, 1, plain_t); - Pointer input(allocate_uint(5, pool)); - Pointer output(allocate_uint(3, pool)); - - // The composed input is 0xffffffffffffff00ffffffffffffff - - input[0] = 4395513236581707780; // mod q1 - input[1] = 4395513390924464132; // mod q2 - input[2] = 4395583330278772740; // mod m1 - input[3] = 4396634741790752772; // mod m2 - input[4] = 4396375252835237892; // mod msk - - output[0] = 0xfffffffff; - output[1] = 0x00fffffff; - output[2] = 0; - - Assert::IsTrue(BaseConverter.fast_floor(input.get(), output.get())); - - // The result for all moduli is equal to -1 since the composed input is small - // Assert::AreEqual(static_cast(4611686018393899008), output[0]); - // Assert::AreEqual(static_cast(4611686018293432320), output[1]); - // Assert::AreEqual(static_cast(4611686018309947392), output[2]); - - // The composed input is 0xffffffffffffff00ffffffffffffff00ff - - input[0] = 17574536613119; // mod q1 - input[1] = 10132675570633983; // mod q2 - input[2] = 3113399115422302529; // mod m1 - input[3] = 1298513899176416785; // mod m2 - input[4] = 3518991311999157564; // mod msk - - output[0] = 0xfffffffff; - output[1] = 0x00fffffff; - output[2] = 0; - - // Since input > q1*q2, the result should be floor(x/(q1*q2)) - alpha (alpha = {0 or 1}) - Assert::IsTrue(BaseConverter.fast_floor(input.get(), output.get())); - Assert::AreEqual(static_cast(0xfff), output[0]); - Assert::AreEqual(static_cast(0xfff), output[1]); - Assert::AreEqual(static_cast(0xfff), output[2]); - - // The composed input is 0xffffffffffffff00ffffffffffffff00ffff - - input[0] = 4499081372958719; // mod q1 - input[1] = 2593964946082299903; // mod q2 - input[2] = 4013821342825660755; // mod m1 - input[3] = 457963018288239031; // mod m2 - input[4] = 1691919900291185724; // mod msk - - output[0] = 0xfffffffff; - output[1] = 0x00fffffff; - output[2] = 0; - - // Since input > q1*q2, the result should be floor(x/(q1*q2)) - alpha (alpha = {0 or 1}) - Assert::IsTrue(BaseConverter.fast_floor(input.get(), output.get())); - Assert::AreEqual(static_cast(0xfffff), output[0]); - Assert::AreEqual(static_cast(0xfffff), output[1]); - Assert::AreEqual(static_cast(0xfffff), output[2]); - } - - { - MemoryPoolMT &pool = *MemoryPoolMT::default_pool(); - vector coeff_base; - vector aux_base; - SmallModulus plain_t = small_mods[9]; - - int coeff_base_count = 2; - int aux_base_count = 2; - for (int i = 0; i < coeff_base_count; ++i) - { - coeff_base.push_back(small_mods[i]); - } - - BaseConverter BaseConverter(coeff_base, 2, plain_t); - Pointer input(allocate_uint(10, pool)); - Pointer output(allocate_uint(6, pool)); - - input[0] = 4499081372958719; // mod q1 - input[1] = 4499081372958719; // mod q1 - - input[2] = 2593964946082299903; // mod q2 - input[3] = 2593964946082299903; // mod q2 - - input[4] = 4013821342825660755; // mod m1 - input[5] = 4013821342825660755; // mod m1 - - input[6] = 457963018288239031; // mod m2 - input[7] = 457963018288239031; // mod m2 - - input[8] = 1691919900291185724; // mod msk - input[9] = 1691919900291185724; // mod msk - - output[0] = 0xfffffffff; - output[1] = 0x00fffffff; - output[2] = 0; - - // Since input > q1*q2, the result should be floor(x/(q1*q2)) - alpha (alpha = {0 or 1}) - Assert::IsTrue(BaseConverter.fast_floor(input.get(), output.get())); - Assert::AreEqual(static_cast(0xfffff), output[0]); - Assert::AreEqual(static_cast(0xfffff), output[1]); - - Assert::AreEqual(static_cast(0xfffff), output[2]); - Assert::AreEqual(static_cast(0xfffff), output[3]); - - Assert::AreEqual(static_cast(0xfffff), output[4]); - Assert::AreEqual(static_cast(0xfffff), output[5]); - } - - } - - TEST_METHOD(FastBConver_mtilde) - { - MemoryPoolMT &pool = *MemoryPoolMT::default_pool(); - vector coeff_base; - vector aux_base; - SmallModulus mtilda = small_mods[5]; - SmallModulus msk = small_mods[4]; - SmallModulus plain_t = small_mods[9]; - - int coeff_base_count = 2; - int aux_base_count = 2; - for (int i = 0; i < coeff_base_count; ++i) - { - coeff_base.push_back(small_mods[i]); - aux_base.push_back(small_mods[i + coeff_base_count]); - } - - BaseConverter BaseConverter(coeff_base, 3, plain_t); - Pointer input(allocate_uint(6, pool)); - Pointer output(allocate_uint(12, pool)); - - // The composed input is 0xffffffffffffff00ffffffffffffff for all coeffs - - input[0] = 4395513236581707780; // cons mod q1 - input[1] = 4395513236581707780; // x mod q1 - input[2] = 4395513236581707780; // x^2 mod q1 - - input[3] = 4395513390924464132; // cons mod q2 - input[4] = 4395513390924464132; // x mod q2 - input[5] = 4395513390924464132; // x^2 mod q2 - - output[0] = 0xffffffff; - output[1] = 0; - output[2] = 0xffffff; - output[3] = 0xffffff; - - Assert::IsTrue(BaseConverter.fastbconv_mtilde(input.get(), output.get())); - Assert::AreEqual(static_cast(3116074317392112723), output[0]);//mod m1 - Assert::AreEqual(static_cast(3116074317392112723), output[1]); - Assert::AreEqual(static_cast(3116074317392112723), output[2]); - - Assert::AreEqual(static_cast(1254200639185090240), output[3]);//mod m2 - Assert::AreEqual(static_cast(1254200639185090240), output[4]); - Assert::AreEqual(static_cast(1254200639185090240), output[5]); - - Assert::AreEqual(static_cast(3528328721557038672), output[6]);//mod msk - Assert::AreEqual(static_cast(3528328721557038672), output[7]); - Assert::AreEqual(static_cast(3528328721557038672), output[8]); - - Assert::AreEqual(static_cast(849325434816160659), output[9]);//mod m_tilde - Assert::AreEqual(static_cast(849325434816160659), output[10]); - Assert::AreEqual(static_cast(849325434816160659), output[11]); - } - - TEST_METHOD(FastBConvert_plain_gamma) - { - MemoryPoolMT &pool = *MemoryPoolMT::default_pool(); - vector coeff_base; - vector aux_base; - SmallModulus plain_t = small_mods[9]; - - int coeff_base_count = 2; - int aux_base_count = 2; - for (int i = 0; i < coeff_base_count; ++i) - { - coeff_base.push_back(small_mods[i]); - aux_base.push_back(small_mods[i + coeff_base_count]); - } - - BaseConverter BaseConverter(coeff_base, 3, plain_t); - Pointer input(allocate_uint(6, pool)); - Pointer output(allocate_uint(6, pool)); - - // The composed input is 0xffffffffffffff00ffffffffffffff for all coeffs - - input[0] = 4395513236581707780; // cons mod q1 - input[1] = 4395513236581707780; // x mod q1 - input[2] = 4395513236581707780; // x^2 mod q1 - - input[3] = 4395513390924464132; // cons mod q2 - input[4] = 4395513390924464132; // x mod q2 - input[5] = 4395513390924464132; // x^2 mod q2 - - output[0] = 0xffffffff; - output[1] = 0; - output[2] = 0xffffff; - output[3] = 0xffffff; - - Assert::IsTrue(BaseConverter.fastbconv_plain_gamma(input.get(), output.get())); - Assert::AreEqual(static_cast(1950841694949736435), output[0]);//mod plain modulus - Assert::AreEqual(static_cast(1950841694949736435), output[1]); - Assert::AreEqual(static_cast(1950841694949736435), output[2]); - - Assert::AreEqual(static_cast(3744510248429639755), output[3]);//mod gamma - Assert::AreEqual(static_cast(3744510248429639755), output[4]); - Assert::AreEqual(static_cast(3744510248429639755), output[5]); - } - }; - } -} diff --git a/SEAL/native/tests/seal/batchencoder.cpp b/SEAL/native/tests/seal/batchencoder.cpp deleted file mode 100644 index 09b3550..0000000 --- a/SEAL/native/tests/seal/batchencoder.cpp +++ /dev/null @@ -1,177 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/batchencoder.h" -#include "seal/context.h" -#include "seal/keygenerator.h" -#include "seal/modulus.h" -#include -#include - -using namespace seal; -using namespace seal::util; -using namespace std; - -namespace SEALTest -{ - TEST(BatchEncoderTest, BatchUnbatchUIntVector) - { - EncryptionParameters parms(scheme_type::BFV); - parms.set_poly_modulus_degree(64); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 60 })); - parms.set_plain_modulus(257); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - ASSERT_TRUE(context->first_context_data()->qualifiers().using_batching); - - BatchEncoder batch_encoder(context); - ASSERT_EQ(64ULL, batch_encoder.slot_count()); - vector plain_vec; - for (size_t i = 0; i < batch_encoder.slot_count(); i++) - { - plain_vec.push_back(i); - } - - Plaintext plain; - batch_encoder.encode(plain_vec, plain); - vector plain_vec2; - batch_encoder.decode(plain, plain_vec2); - ASSERT_TRUE(plain_vec == plain_vec2); - - for (size_t i = 0; i < batch_encoder.slot_count(); i++) - { - plain_vec[i] = 5; - } - batch_encoder.encode(plain_vec, plain); - ASSERT_TRUE(plain.to_string() == "5"); - batch_encoder.decode(plain, plain_vec2); - ASSERT_TRUE(plain_vec == plain_vec2); - - vector short_plain_vec; - for (size_t i = 0; i < 20; i++) - { - short_plain_vec.push_back(i); - } - batch_encoder.encode(short_plain_vec, plain); - vector short_plain_vec2; - batch_encoder.decode(plain, short_plain_vec2); - ASSERT_EQ(20ULL, short_plain_vec.size()); - ASSERT_EQ(64ULL, short_plain_vec2.size()); - for (size_t i = 0; i < 20; i++) - { - ASSERT_EQ(short_plain_vec[i], short_plain_vec2[i]); - } - for (size_t i = 20; i < batch_encoder.slot_count(); i++) - { - ASSERT_EQ(0ULL, short_plain_vec2[i]); - } - } - - TEST(BatchEncoderTest, BatchUnbatchIntVector) - { - EncryptionParameters parms(scheme_type::BFV); - parms.set_poly_modulus_degree(64); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 60 })); - parms.set_plain_modulus(257); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - ASSERT_TRUE(context->first_context_data()->qualifiers().using_batching); - - BatchEncoder batch_encoder(context); - ASSERT_EQ(64ULL, batch_encoder.slot_count()); - vector plain_vec; - for (size_t i = 0; i < batch_encoder.slot_count(); i++) - { - plain_vec.push_back(static_cast(i * (1 - 2 * (i % 2)))); - } - - Plaintext plain; - batch_encoder.encode(plain_vec, plain); - vector plain_vec2; - batch_encoder.decode(plain, plain_vec2); - ASSERT_TRUE(plain_vec == plain_vec2); - - for (size_t i = 0; i < batch_encoder.slot_count(); i++) - { - plain_vec[i] = -5; - } - batch_encoder.encode(plain_vec, plain); - ASSERT_TRUE(plain.to_string() == "FC"); - batch_encoder.decode(plain, plain_vec2); - ASSERT_TRUE(plain_vec == plain_vec2); - - vector short_plain_vec; - for (int i = 0; i < 20; i++) - { - short_plain_vec.push_back(static_cast(i * (1 - 2 * (i % 2)))); - } - batch_encoder.encode(short_plain_vec, plain); - vector short_plain_vec2; - batch_encoder.decode(plain, short_plain_vec2); - ASSERT_EQ(20ULL, short_plain_vec.size()); - ASSERT_EQ(64ULL, short_plain_vec2.size()); - for (size_t i = 0; i < 20; i++) - { - ASSERT_EQ(short_plain_vec[i], short_plain_vec2[i]); - } - for (size_t i = 20; i < batch_encoder.slot_count(); i++) - { - ASSERT_EQ(0ULL, short_plain_vec2[i]); - } - } - - TEST(BatchEncoderTest, BatchUnbatchPlaintext) - { - EncryptionParameters parms(scheme_type::BFV); - parms.set_poly_modulus_degree(64); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 60 })); - parms.set_plain_modulus(257); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - ASSERT_TRUE(context->first_context_data()->qualifiers().using_batching); - - BatchEncoder batch_encoder(context); - ASSERT_EQ(64ULL, batch_encoder.slot_count()); - Plaintext plain(batch_encoder.slot_count()); - for (size_t i = 0; i < batch_encoder.slot_count(); i++) - { - plain[i] = i; - } - - batch_encoder.encode(plain); - batch_encoder.decode(plain); - for (size_t i = 0; i < batch_encoder.slot_count(); i++) - { - ASSERT_TRUE(plain[i] == i); - } - - for (size_t i = 0; i < batch_encoder.slot_count(); i++) - { - plain[i] = 5; - } - batch_encoder.encode(plain); - ASSERT_TRUE(plain.to_string() == "5"); - batch_encoder.decode(plain); - for (size_t i = 0; i < batch_encoder.slot_count(); i++) - { - ASSERT_EQ(5ULL, plain[i]); - } - - Plaintext short_plain(20); - for (size_t i = 0; i < 20; i++) - { - short_plain[i] = i; - } - batch_encoder.encode(short_plain); - batch_encoder.decode(short_plain); - for (size_t i = 0; i < 20; i++) - { - ASSERT_TRUE(short_plain[i] == i); - } - for (size_t i = 20; i < batch_encoder.slot_count(); i++) - { - ASSERT_TRUE(short_plain[i] == 0); - } - } -} diff --git a/SEAL/native/tests/seal/biguint.cpp b/SEAL/native/tests/seal/biguint.cpp deleted file mode 100644 index c8903bd..0000000 --- a/SEAL/native/tests/seal/biguint.cpp +++ /dev/null @@ -1,378 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/biguint.h" -#include "seal/util/defines.h" - -using namespace seal; -using namespace std; - -namespace SEALTest -{ - TEST(BigUnsignedInt, EmptyBigUInt) - { - BigUInt uint; - ASSERT_EQ(0, uint.bit_count()); - ASSERT_TRUE(nullptr == uint.data()); - ASSERT_EQ(0ULL, uint.byte_count()); - ASSERT_EQ(0ULL, uint.uint64_count()); - ASSERT_EQ(0, uint.significant_bit_count()); - ASSERT_TRUE("0" == uint.to_string()); - ASSERT_TRUE(uint.is_zero()); - ASSERT_FALSE(uint.is_alias()); - uint.set_zero(); - - BigUInt uint2; - ASSERT_TRUE(uint == uint2); - ASSERT_FALSE(uint != uint2); - - uint.resize(1); - ASSERT_EQ(1, uint.bit_count()); - ASSERT_TRUE(nullptr != uint.data()); - ASSERT_FALSE(uint.is_alias()); - - uint.resize(0); - ASSERT_EQ(0, uint.bit_count()); - ASSERT_TRUE(nullptr == uint.data()); - ASSERT_FALSE(uint.is_alias()); - } - - TEST(BigUnsignedInt, BigUInt64Bits) - { - BigUInt uint(64); - ASSERT_EQ(64, uint.bit_count()); - ASSERT_TRUE(nullptr != uint.data()); - ASSERT_EQ(8ULL, uint.byte_count()); - ASSERT_EQ(1ULL, uint.uint64_count()); - ASSERT_EQ(0, uint.significant_bit_count()); - ASSERT_TRUE("0" == uint.to_string()); - ASSERT_TRUE(uint.is_zero()); - ASSERT_EQ(static_cast(0), *uint.data()); - ASSERT_TRUE(SEAL_BYTE(0) == uint[0]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[1]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[2]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[3]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[4]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[5]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[6]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[7]); - - uint = "1"; - ASSERT_EQ(1, uint.significant_bit_count()); - ASSERT_TRUE("1" == uint.to_string()); - ASSERT_FALSE(uint.is_zero()); - ASSERT_EQ(1ULL, *uint.data()); - ASSERT_TRUE(SEAL_BYTE(1) == uint[0]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[1]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[2]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[3]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[4]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[5]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[6]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[7]); - uint.set_zero(); - ASSERT_TRUE(uint.is_zero()); - ASSERT_EQ(static_cast(0), *uint.data()); - - uint = "7FFFFFFFFFFFFFFF"; - ASSERT_EQ(63, uint.significant_bit_count()); - ASSERT_TRUE("7FFFFFFFFFFFFFFF" == uint.to_string()); - ASSERT_EQ(static_cast(0x7FFFFFFFFFFFFFFF), *uint.data()); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[0]); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[1]); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[2]); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[3]); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[4]); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[5]); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[6]); - ASSERT_TRUE(SEAL_BYTE(0x7F) == uint[7]); - ASSERT_FALSE(uint.is_zero()); - - uint = "FFFFFFFFFFFFFFFF"; - ASSERT_EQ(64, uint.significant_bit_count()); - ASSERT_TRUE("FFFFFFFFFFFFFFFF" == uint.to_string()); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), *uint.data()); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[0]); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[1]); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[2]); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[3]); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[4]); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[5]); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[6]); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[7]); - ASSERT_FALSE(uint.is_zero()); - - uint = 0x8001; - ASSERT_EQ(16, uint.significant_bit_count()); - ASSERT_TRUE("8001" == uint.to_string()); - ASSERT_EQ(static_cast(0x8001), *uint.data()); - ASSERT_TRUE(SEAL_BYTE(0x01) == uint[0]); - ASSERT_TRUE(SEAL_BYTE(0x80) == uint[1]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[2]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[3]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[4]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[5]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[6]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[7]); - } - - TEST(BigUnsignedInt, BigUInt99Bits) - { - BigUInt uint(99); - ASSERT_EQ(99, uint.bit_count()); - ASSERT_TRUE(nullptr != uint.data()); - ASSERT_EQ(13ULL, uint.byte_count()); - ASSERT_EQ(2ULL, uint.uint64_count()); - ASSERT_EQ(0, uint.significant_bit_count()); - ASSERT_TRUE("0" == uint.to_string()); - ASSERT_TRUE(uint.is_zero()); - ASSERT_EQ(static_cast(0), uint.data()[0]); - ASSERT_EQ(static_cast(0), uint.data()[1]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[0]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[1]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[2]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[3]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[4]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[5]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[6]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[7]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[8]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[9]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[10]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[11]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[12]); - - uint = "1"; - ASSERT_EQ(1, uint.significant_bit_count()); - ASSERT_TRUE("1" == uint.to_string()); - ASSERT_FALSE(uint.is_zero()); - ASSERT_EQ(1ULL, uint.data()[0]); - ASSERT_EQ(static_cast(0), uint.data()[1]); - ASSERT_TRUE(SEAL_BYTE(1) == uint[0]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[1]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[2]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[3]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[4]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[5]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[6]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[7]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[8]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[9]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[10]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[11]); - ASSERT_TRUE(SEAL_BYTE(0) == uint[12]); - uint.set_zero(); - ASSERT_TRUE(uint.is_zero()); - ASSERT_EQ(static_cast(0), uint.data()[0]); - ASSERT_EQ(static_cast(0), uint.data()[1]); - - uint = "7FFFFFFFFFFFFFFFFFFFFFFFF"; - ASSERT_EQ(99, uint.significant_bit_count()); - ASSERT_TRUE("7FFFFFFFFFFFFFFFFFFFFFFFF" == uint.to_string()); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), uint.data()[0]); - ASSERT_EQ(static_cast(0x7FFFFFFFF), uint.data()[1]); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[0]); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[1]); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[2]); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[3]); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[4]); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[5]); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[6]); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[7]); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[8]); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[9]); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[10]); - ASSERT_TRUE(SEAL_BYTE(0xFF) == uint[11]); - ASSERT_TRUE(SEAL_BYTE(0x07) == uint[12]); - ASSERT_FALSE(uint.is_zero()); - uint.set_zero(); - ASSERT_TRUE(uint.is_zero()); - ASSERT_EQ(static_cast(0), uint.data()[0]); - ASSERT_EQ(static_cast(0), uint.data()[1]); - - uint = "4000000000000000000000000"; - ASSERT_EQ(99, uint.significant_bit_count()); - ASSERT_TRUE("4000000000000000000000000" == uint.to_string()); - ASSERT_EQ(static_cast(0x0000000000000000), uint.data()[0]); - ASSERT_EQ(static_cast(0x400000000), uint.data()[1]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[0]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[1]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[2]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[3]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[4]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[5]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[6]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[7]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[8]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[9]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[10]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[11]); - ASSERT_TRUE(SEAL_BYTE(0x04) == uint[12]); - ASSERT_FALSE(uint.is_zero()); - - uint = 0x8001; - ASSERT_EQ(16, uint.significant_bit_count()); - ASSERT_TRUE("8001" == uint.to_string()); - ASSERT_EQ(static_cast(0x8001), uint.data()[0]); - ASSERT_EQ(static_cast(0), uint.data()[1]); - ASSERT_TRUE(SEAL_BYTE(0x01) == uint[0]); - ASSERT_TRUE(SEAL_BYTE(0x80) == uint[1]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[2]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[3]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[4]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[5]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[6]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[7]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[8]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[9]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[10]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[11]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[12]); - - BigUInt uint2("123"); - ASSERT_FALSE(uint == uint2); - ASSERT_FALSE(uint2 == uint); - ASSERT_TRUE(uint != uint2); - ASSERT_TRUE(uint2 != uint); - - uint = uint2; - ASSERT_TRUE(uint == uint2); - ASSERT_FALSE(uint != uint2); - ASSERT_EQ(9, uint.significant_bit_count()); - ASSERT_TRUE("123" == uint.to_string()); - ASSERT_EQ(static_cast(0x123), uint.data()[0]); - ASSERT_EQ(static_cast(0), uint.data()[1]); - ASSERT_TRUE(SEAL_BYTE(0x23) == uint[0]); - ASSERT_TRUE(SEAL_BYTE(0x01) == uint[1]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[2]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[3]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[4]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[5]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[6]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[7]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[8]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[9]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[10]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[11]); - ASSERT_TRUE(SEAL_BYTE(0x00) == uint[12]); - - uint.resize(8); - ASSERT_EQ(8, uint.bit_count()); - ASSERT_EQ(1ULL, uint.uint64_count()); - ASSERT_TRUE("23" == uint.to_string()); - - uint.resize(100); - ASSERT_EQ(100, uint.bit_count()); - ASSERT_EQ(2ULL, uint.uint64_count()); - ASSERT_TRUE("23" == uint.to_string()); - - uint.resize(0); - ASSERT_EQ(0, uint.bit_count()); - ASSERT_EQ(0ULL, uint.uint64_count()); - ASSERT_TRUE(nullptr == uint.data()); - } - - TEST(BigUnsignedInt, SaveLoadUInt) - { - stringstream stream; - - BigUInt value; - BigUInt value2("100"); - value.save(stream); - value2.load(stream); - ASSERT_TRUE(value == value2); - - value = "123"; - value.save(stream); - value2.load(stream); - ASSERT_TRUE(value == value2); - - value = "FFFFFFFFFFFFFFFFFFFFFFFFFF"; - value.save(stream); - value2.load(stream); - ASSERT_TRUE(value == value2); - - value = "0"; - value.save(stream); - value2.load(stream); - ASSERT_TRUE(value == value2); - } - - TEST(BigUnsignedInt, DuplicateTo) - { - BigUInt original(123); - original = 56789; - - BigUInt target; - - original.duplicate_to(target); - ASSERT_EQ(target.bit_count(), original.bit_count()); - ASSERT_TRUE(target == original); - } - - TEST(BigUnsignedInt, DuplicateFrom) - { - BigUInt original(123); - original = 56789; - - BigUInt target; - - target.duplicate_from(original); - ASSERT_EQ(target.bit_count(), original.bit_count()); - ASSERT_TRUE(target == original); - } - - TEST(BigUnsignedInt, BigUIntCopyMoveAssign) - { - { - BigUInt p1("123"); - BigUInt p2("456"); - BigUInt p3; - - p1.operator =(p2); - p3.operator =(p1); - ASSERT_TRUE(p1 == p2); - ASSERT_TRUE(p3 == p1); - } - { - BigUInt p1("123"); - BigUInt p2("456"); - BigUInt p3; - BigUInt p4(p2); - - p1.operator =(move(p2)); - p3.operator =(move(p1)); - ASSERT_TRUE(p3 == p4); - ASSERT_TRUE(p1 == p2); - ASSERT_TRUE(p3 == p1); - } - { - uint64_t p1_anchor = 123; - uint64_t p2_anchor = 456; - BigUInt p1(64, &p1_anchor); - BigUInt p2(64, &p2_anchor); - BigUInt p3; - - p1.operator =(p2); - p3.operator =(p1); - ASSERT_TRUE(p1 == p2); - ASSERT_TRUE(p3 == p1); - } - { - uint64_t p1_anchor = 123; - uint64_t p2_anchor = 456; - BigUInt p1(64, &p1_anchor); - BigUInt p2(64, &p2_anchor); - BigUInt p3; - BigUInt p4(p2); - - p1.operator =(move(p2)); - p3.operator =(move(p1)); - ASSERT_TRUE(p3 == p4); - ASSERT_TRUE(p2 == 456); - ASSERT_TRUE(p1 == 456); - ASSERT_TRUE(p3 == 456); - } - } -} diff --git a/SEAL/native/tests/seal/ciphertext.cpp b/SEAL/native/tests/seal/ciphertext.cpp deleted file mode 100644 index adf4a5b..0000000 --- a/SEAL/native/tests/seal/ciphertext.cpp +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/ciphertext.h" -#include "seal/context.h" -#include "seal/keygenerator.h" -#include "seal/encryptor.h" -#include "seal/memorymanager.h" -#include "seal/modulus.h" - -using namespace seal; -using namespace seal::util; -using namespace std; - -namespace SEALTest -{ - TEST(CiphertextTest, CiphertextBasics) - { - EncryptionParameters parms(scheme_type::BFV); - - parms.set_poly_modulus_degree(2); - parms.set_coeff_modulus(CoeffModulus::Create(2, { 30 })); - parms.set_plain_modulus(2); - auto context = SEALContext::Create(parms, false, sec_level_type::none); - - Ciphertext ctxt(context); - ctxt.reserve(10); - ASSERT_EQ(0ULL, ctxt.size()); - ASSERT_EQ(0ULL, ctxt.uint64_count()); - ASSERT_EQ(10ULL * 2 * 1, ctxt.uint64_count_capacity()); - ASSERT_EQ(2ULL, ctxt.poly_modulus_degree()); - ASSERT_TRUE(ctxt.parms_id() == context->first_parms_id()); - ASSERT_FALSE(ctxt.is_ntt_form()); - const uint64_t *ptr = ctxt.data(); - - ctxt.reserve(5); - ASSERT_EQ(0ULL, ctxt.size()); - ASSERT_EQ(0ULL, ctxt.uint64_count()); - ASSERT_EQ(5ULL * 2 * 1, ctxt.uint64_count_capacity()); - ASSERT_EQ(2ULL, ctxt.poly_modulus_degree()); - ASSERT_TRUE(ptr != ctxt.data()); - ASSERT_TRUE(ctxt.parms_id() == context->first_parms_id()); - ptr = ctxt.data(); - - ctxt.reserve(10); - ASSERT_EQ(0ULL, ctxt.size()); - ASSERT_EQ(0ULL, ctxt.uint64_count()); - ASSERT_EQ(10ULL * 2 * 1, ctxt.uint64_count_capacity()); - ASSERT_EQ(2ULL, ctxt.poly_modulus_degree()); - ASSERT_TRUE(ptr != ctxt.data()); - ASSERT_TRUE(ctxt.parms_id() == context->first_parms_id()); - ASSERT_FALSE(ctxt.is_ntt_form()); - ptr = ctxt.data(); - - ctxt.reserve(2); - ASSERT_EQ(0ULL, ctxt.size()); - ASSERT_EQ(2ULL * 2 * 1, ctxt.uint64_count_capacity()); - ASSERT_EQ(0ULL, ctxt.uint64_count()); - ASSERT_EQ(2ULL, ctxt.poly_modulus_degree()); - ASSERT_TRUE(ptr != ctxt.data()); - ASSERT_TRUE(ctxt.parms_id() == context->first_parms_id()); - ASSERT_FALSE(ctxt.is_ntt_form()); - ptr = ctxt.data(); - - ctxt.reserve(5); - ASSERT_EQ(0ULL, ctxt.size()); - ASSERT_EQ(5ULL * 2 * 1, ctxt.uint64_count_capacity()); - ASSERT_EQ(0ULL, ctxt.uint64_count()); - ASSERT_EQ(2ULL, ctxt.poly_modulus_degree()); - ASSERT_TRUE(ptr != ctxt.data()); - ASSERT_TRUE(ctxt.parms_id() == context->first_parms_id()); - ASSERT_FALSE(ctxt.is_ntt_form()); - - Ciphertext ctxt2{ ctxt }; - ASSERT_EQ(ctxt.coeff_mod_count(), ctxt2.coeff_mod_count()); - ASSERT_EQ(ctxt.is_ntt_form(), ctxt2.is_ntt_form()); - ASSERT_EQ(ctxt.poly_modulus_degree(), ctxt2.poly_modulus_degree()); - ASSERT_TRUE(ctxt.parms_id() == ctxt2.parms_id()); - ASSERT_EQ(ctxt.poly_modulus_degree(), ctxt2.poly_modulus_degree()); - ASSERT_EQ(ctxt.size(), ctxt2.size()); - - Ciphertext ctxt3; - ctxt3 = ctxt; - ASSERT_EQ(ctxt.coeff_mod_count(), ctxt3.coeff_mod_count()); - ASSERT_EQ(ctxt.poly_modulus_degree(), ctxt3.poly_modulus_degree()); - ASSERT_EQ(ctxt.is_ntt_form(), ctxt3.is_ntt_form()); - ASSERT_TRUE(ctxt.parms_id() == ctxt3.parms_id()); - ASSERT_EQ(ctxt.poly_modulus_degree(), ctxt3.poly_modulus_degree()); - ASSERT_EQ(ctxt.size(), ctxt3.size()); - } - - TEST(CiphertextTest, SaveLoadCiphertext) - { - stringstream stream; - EncryptionParameters parms(scheme_type::BFV); - parms.set_poly_modulus_degree(2); - parms.set_coeff_modulus(CoeffModulus::Create(2, { 30 })); - parms.set_plain_modulus(2); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - - Ciphertext ctxt(context); - Ciphertext ctxt2; - ctxt.save(stream); - ctxt2.load(context, stream); - ASSERT_TRUE(ctxt.parms_id() == ctxt2.parms_id()); - ASSERT_FALSE(ctxt.is_ntt_form()); - ASSERT_FALSE(ctxt2.is_ntt_form()); - - parms.set_poly_modulus_degree(1024); - parms.set_coeff_modulus(CoeffModulus::BFVDefault(1024)); - parms.set_plain_modulus(0xF0F0); - context = SEALContext::Create(parms, false); - KeyGenerator keygen(context); - Encryptor encryptor(context, keygen.public_key()); - encryptor.encrypt(Plaintext("Ax^10 + 9x^9 + 8x^8 + 7x^7 + 6x^6 + 5x^5 + 4x^4 + 3x^3 + 2x^2 + 1"), ctxt); - ctxt.save(stream); - ctxt2.load(context, stream); - ASSERT_TRUE(ctxt.parms_id() == ctxt2.parms_id()); - ASSERT_FALSE(ctxt.is_ntt_form()); - ASSERT_FALSE(ctxt2.is_ntt_form()); - ASSERT_TRUE(is_equal_uint_uint(ctxt.data(), ctxt2.data(), - parms.poly_modulus_degree() * parms.coeff_modulus().size() * 2)); - ASSERT_TRUE(ctxt.data() != ctxt2.data()); - } -} diff --git a/SEAL/native/tests/seal/ckks.cpp b/SEAL/native/tests/seal/ckks.cpp deleted file mode 100644 index d06b5d1..0000000 --- a/SEAL/native/tests/seal/ckks.cpp +++ /dev/null @@ -1,346 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/ckks.h" -#include "seal/context.h" -#include "seal/keygenerator.h" -#include "seal/modulus.h" -#include -#include - -using namespace seal; -using namespace seal::util; -using namespace std; - -namespace SEALTest -{ - TEST(CKKSEncoderTest, CKKSEncoderEncodeVectorDecodeTest) - { - EncryptionParameters parms(scheme_type::CKKS); - { - uint32_t slots = 32; - parms.set_poly_modulus_degree(2 * slots); - parms.set_coeff_modulus(CoeffModulus::Create(2 * slots, { 40, 40, 40, 40 })); - auto context = SEALContext::Create(parms, false, sec_level_type::none); - - std::vector> values(slots); - - for (size_t i = 0; i < slots; i++) - { - std::complex value(0.0, 0.0); - values[i] = value; - } - - CKKSEncoder encoder(context); - double delta = (1ULL << 16); - Plaintext plain; - encoder.encode(values, context->first_parms_id(), delta, plain); - std::vector> result; - encoder.decode(plain, result); - - for (size_t i = 0; i < slots; ++i) - { - auto tmp = abs(values[i].real() - result[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - { - uint32_t slots = 32; - parms.set_poly_modulus_degree(2 * slots); - parms.set_coeff_modulus(CoeffModulus::Create(2 * slots, { 60, 60, 60, 60 })); - auto context = SEALContext::Create(parms, false, sec_level_type::none); - - std::vector> values(slots); - - srand(static_cast(time(NULL))); - int data_bound = (1 << 30); - - for (size_t i = 0; i < slots; i++) - { - std::complex value(static_cast(rand() % data_bound), 0); - values[i] = value; - } - - CKKSEncoder encoder(context); - double delta = (1ULL << 40); - Plaintext plain; - encoder.encode(values, context->first_parms_id(), delta, plain); - std::vector> result; - encoder.decode(plain, result); - - for (size_t i = 0; i < slots; ++i) - { - auto tmp = abs(values[i].real() - result[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - { - uint32_t slots = 64; - parms.set_poly_modulus_degree(2 * slots); - parms.set_coeff_modulus(CoeffModulus::Create(2 * slots, { 60, 60, 60 })); - auto context = SEALContext::Create(parms, false, sec_level_type::none); - - std::vector> values(slots); - - srand(static_cast(time(NULL))); - int data_bound = (1 << 30); - - for (size_t i = 0; i < slots; i++) - { - std::complex value(static_cast(rand() % data_bound), 0); - values[i] = value; - } - - CKKSEncoder encoder(context); - double delta = (1ULL << 40); - Plaintext plain; - encoder.encode(values, context->first_parms_id(), delta, plain); - std::vector> result; - encoder.decode(plain, result); - - for (size_t i = 0; i < slots; ++i) - { - auto tmp = abs(values[i].real() - result[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - { - uint32_t slots = 64; - parms.set_poly_modulus_degree(2 * slots); - parms.set_coeff_modulus(CoeffModulus::Create(2 * slots, { 30, 30, 30, 30, 30 })); - auto context = SEALContext::Create(parms, false, sec_level_type::none); - - std::vector> values(slots); - - srand(static_cast(time(NULL))); - int data_bound = (1 << 30); - - for (size_t i = 0; i < slots; i++) - { - std::complex value(static_cast(rand() % data_bound), 0); - values[i] = value; - } - - CKKSEncoder encoder(context); - double delta = (1ULL << 40); - Plaintext plain; - encoder.encode(values, context->first_parms_id(), delta, plain); - std::vector> result; - encoder.decode(plain, result); - - for (size_t i = 0; i < slots; ++i) - { - auto tmp = abs(values[i].real() - result[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - { - uint32_t slots = 32; - parms.set_poly_modulus_degree(128); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 30, 30, 30, 30, 30 })); - auto context = SEALContext::Create(parms, false, sec_level_type::none); - - std::vector> values(slots); - - srand(static_cast(time(NULL))); - int data_bound = (1 << 30); - - for (size_t i = 0; i < slots; i++) - { - std::complex value(static_cast(rand() % data_bound), 0); - values[i] = value; - } - - CKKSEncoder encoder(context); - double delta = (1ULL << 40); - Plaintext plain; - encoder.encode(values, context->first_parms_id(), delta, plain); - std::vector> result; - encoder.decode(plain, result); - - for (size_t i = 0; i < slots; ++i) - { - auto tmp = abs(values[i].real() - result[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - { - // Many primes - uint32_t slots = 32; - parms.set_poly_modulus_degree(128); - parms.set_coeff_modulus(CoeffModulus::Create(128, { - 30, 30, 30, 30, 30, 30, - 30, 30, 30, 30, 30, 30, - 30, 30, 30, 30, 30, 30, 30 })); - auto context = SEALContext::Create(parms, false, sec_level_type::none); - - std::vector> values(slots); - - srand(static_cast(time(NULL))); - int data_bound = (1 << 30); - - for (size_t i = 0; i < slots; i++) - { - std::complex value(static_cast(rand() % data_bound), 0); - values[i] = value; - } - - CKKSEncoder encoder(context); - double delta = (1ULL << 40); - Plaintext plain; - encoder.encode(values, context->first_parms_id(), delta, plain); - std::vector> result; - encoder.decode(plain, result); - - for (size_t i = 0; i < slots; ++i) - { - auto tmp = abs(values[i].real() - result[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - { - uint32_t slots = 64; - parms.set_poly_modulus_degree(2 * slots); - parms.set_coeff_modulus(CoeffModulus::Create(2 * slots, { 40, 40, 40, 40, 40 })); - auto context = SEALContext::Create(parms, false, sec_level_type::none); - - std::vector> values(slots); - - srand(static_cast(time(NULL))); - int data_bound = (1 << 20); - - for (size_t i = 0; i < slots; i++) - { - std::complex value(static_cast(rand() % data_bound), 0); - values[i] = value; - } - - CKKSEncoder encoder(context); - { - // Use a very large scale - double delta = pow(2.0, 110); - Plaintext plain; - encoder.encode(values, context->first_parms_id(), delta, plain); - std::vector> result; - encoder.decode(plain, result); - - for (size_t i = 0; i < slots; ++i) - { - auto tmp = abs(values[i].real() - result[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - { - // Use a scale over 128 bits - double delta = pow(2.0, 130); - Plaintext plain; - encoder.encode(values, context->first_parms_id(), delta, plain); - std::vector> result; - encoder.decode(plain, result); - - for (size_t i = 0; i < slots; ++i) - { - auto tmp = abs(values[i].real() - result[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - } - - TEST(CKKSEncoderTest, CKKSEncoderEncodeSingleDecodeTest) - { - EncryptionParameters parms(scheme_type::CKKS); - { - uint32_t slots = 16; - parms.set_poly_modulus_degree(64); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 40, 40, 40, 40 })); - auto context = SEALContext::Create(parms, false, sec_level_type::none); - CKKSEncoder encoder(context); - - srand(static_cast(time(NULL))); - int data_bound = (1 << 30); - double delta = (1ULL << 16); - Plaintext plain; - std::vector> result; - - for (int iRun = 0; iRun < 50; iRun++) - { - double value = static_cast(rand() % data_bound); - encoder.encode(value, context->first_parms_id(), delta, plain); - encoder.decode(plain, result); - - for (size_t i = 0; i < slots; ++i) - { - auto tmp = abs(value - result[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - { - uint32_t slots = 32; - parms.set_poly_modulus_degree(slots * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slots * 2, { 40, 40, 40, 40 })); - auto context = SEALContext::Create(parms, false, sec_level_type::none); - CKKSEncoder encoder(context); - - srand(static_cast(time(NULL))); - { - int data_bound = (1 << 30); - Plaintext plain; - std::vector> result; - - for (int iRun = 0; iRun < 50; iRun++) - { - int value = static_cast(rand() % data_bound); - encoder.encode(value, context->first_parms_id(), plain); - encoder.decode(plain, result); - - for (size_t i = 0; i < slots; ++i) - { - auto tmp = abs(value - result[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - { - // Use a very large scale - int data_bound = (1 << 20); - Plaintext plain; - std::vector> result; - - for (int iRun = 0; iRun < 50; iRun++) - { - int value = static_cast(rand() % data_bound); - encoder.encode(value, context->first_parms_id(), plain); - encoder.decode(plain, result); - - for (size_t i = 0; i < slots; ++i) - { - auto tmp = abs(value - result[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - { - // Use a scale over 128 bits - int data_bound = (1 << 20); - Plaintext plain; - std::vector> result; - - for (int iRun = 0; iRun < 50; iRun++) - { - int value = static_cast(rand() % data_bound); - encoder.encode(value, context->first_parms_id(), plain); - encoder.decode(plain, result); - - for (size_t i = 0; i < slots; ++i) - { - auto tmp = abs(value - result[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - } - } -} diff --git a/SEAL/native/tests/seal/context.cpp b/SEAL/native/tests/seal/context.cpp deleted file mode 100644 index 7bf3f60..0000000 --- a/SEAL/native/tests/seal/context.cpp +++ /dev/null @@ -1,380 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/context.h" -#include "seal/modulus.h" - -using namespace seal; -using namespace std; - -namespace SEALTest -{ - TEST(ContextTest, ContextConstructor) - { - // Nothing set - auto scheme = scheme_type::BFV; - EncryptionParameters parms(scheme); - { - auto context = SEALContext::Create(parms, false, sec_level_type::none); - auto qualifiers = context->first_context_data()->qualifiers(); - ASSERT_FALSE(qualifiers.parameters_set); - ASSERT_FALSE(qualifiers.using_fft); - ASSERT_FALSE(qualifiers.using_ntt); - ASSERT_FALSE(qualifiers.using_batching); - ASSERT_FALSE(qualifiers.using_fast_plain_lift); - ASSERT_FALSE(qualifiers.using_descending_modulus_chain); - ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); - ASSERT_FALSE(context->using_keyswitching()); - } - - // Not relatively prime coeff moduli - parms.set_poly_modulus_degree(4); - parms.set_coeff_modulus({ 2, 30 }); - parms.set_plain_modulus(2); - parms.set_random_generator(UniformRandomGeneratorFactory::default_factory()); - { - auto context = SEALContext::Create(parms, false, sec_level_type::none); - auto qualifiers = context->first_context_data()->qualifiers(); - ASSERT_FALSE(qualifiers.parameters_set); - ASSERT_FALSE(qualifiers.using_fft); - ASSERT_FALSE(qualifiers.using_ntt); - ASSERT_FALSE(qualifiers.using_batching); - ASSERT_FALSE(qualifiers.using_fast_plain_lift); - ASSERT_FALSE(qualifiers.using_descending_modulus_chain); - ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); - ASSERT_FALSE(context->using_keyswitching()); - } - - // Plain modulus not relatively prime to coeff moduli - parms.set_poly_modulus_degree(4); - parms.set_coeff_modulus({ 17, 41 }); - parms.set_plain_modulus(34); - parms.set_random_generator(UniformRandomGeneratorFactory::default_factory()); - { - auto context = SEALContext::Create(parms, false, sec_level_type::none); - auto qualifiers = context->first_context_data()->qualifiers(); - ASSERT_FALSE(qualifiers.parameters_set); - ASSERT_TRUE(qualifiers.using_fft); - ASSERT_TRUE(qualifiers.using_ntt); - ASSERT_FALSE(qualifiers.using_batching); - ASSERT_FALSE(qualifiers.using_fast_plain_lift); - ASSERT_FALSE(qualifiers.using_descending_modulus_chain); - ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); - ASSERT_FALSE(context->using_keyswitching()); - } - - // Plain modulus not smaller than product of coeff moduli - parms.set_poly_modulus_degree(4); - parms.set_coeff_modulus({ 2 }); - parms.set_plain_modulus(3); - parms.set_random_generator(UniformRandomGeneratorFactory::default_factory()); - { - auto context = SEALContext::Create(parms, false, sec_level_type::none); - ASSERT_EQ(2ULL, *context->first_context_data()->total_coeff_modulus()); - auto qualifiers = context->first_context_data()->qualifiers(); - ASSERT_FALSE(qualifiers.parameters_set); - ASSERT_TRUE(qualifiers.using_fft); - ASSERT_FALSE(qualifiers.using_ntt); - ASSERT_FALSE(qualifiers.using_batching); - ASSERT_FALSE(qualifiers.using_fast_plain_lift); - ASSERT_FALSE(qualifiers.using_descending_modulus_chain); - ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); - ASSERT_FALSE(context->using_keyswitching()); - } - - // FFT poly but not NTT modulus - parms.set_poly_modulus_degree(4); - parms.set_coeff_modulus({ 3 }); - parms.set_plain_modulus(2); - parms.set_random_generator(UniformRandomGeneratorFactory::default_factory()); - { - auto context = SEALContext::Create(parms, false, sec_level_type::none); - ASSERT_EQ(3ULL, *context->first_context_data()->total_coeff_modulus()); - auto qualifiers = context->first_context_data()->qualifiers(); - ASSERT_FALSE(qualifiers.parameters_set); - ASSERT_TRUE(qualifiers.using_fft); - ASSERT_FALSE(qualifiers.using_ntt); - ASSERT_FALSE(qualifiers.using_batching); - ASSERT_FALSE(qualifiers.using_fast_plain_lift); - ASSERT_FALSE(qualifiers.using_descending_modulus_chain); - ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); - ASSERT_FALSE(context->using_keyswitching()); - } - - // Parameters OK; no fast plain lift - parms.set_poly_modulus_degree(4); - parms.set_coeff_modulus({ 17, 41 }); - parms.set_plain_modulus(18); - parms.set_random_generator(UniformRandomGeneratorFactory::default_factory()); - { - auto context = SEALContext::Create(parms, false, sec_level_type::none); - ASSERT_EQ(697ULL, *context->first_context_data()->total_coeff_modulus()); - auto qualifiers = context->first_context_data()->qualifiers(); - ASSERT_TRUE(qualifiers.parameters_set); - ASSERT_TRUE(qualifiers.using_fft); - ASSERT_TRUE(qualifiers.using_ntt); - ASSERT_FALSE(qualifiers.using_batching); - ASSERT_FALSE(qualifiers.using_fast_plain_lift); - ASSERT_FALSE(qualifiers.using_descending_modulus_chain); - ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); - ASSERT_FALSE(context->using_keyswitching()); - } - - // Parameters OK; fast plain lift - parms.set_poly_modulus_degree(4); - parms.set_coeff_modulus({ 17, 41 }); - parms.set_plain_modulus(16); - parms.set_random_generator(UniformRandomGeneratorFactory::default_factory()); - { - auto context = SEALContext::Create(parms, false, sec_level_type::none); - ASSERT_EQ(17ULL, *context->first_context_data()->total_coeff_modulus()); - ASSERT_EQ(697ULL, *context->key_context_data()->total_coeff_modulus()); - auto qualifiers = context->first_context_data()->qualifiers(); - auto key_qualifiers = context->key_context_data()->qualifiers(); - ASSERT_TRUE(qualifiers.parameters_set); - ASSERT_TRUE(qualifiers.using_fft); - ASSERT_TRUE(qualifiers.using_ntt); - ASSERT_FALSE(qualifiers.using_batching); - ASSERT_TRUE(qualifiers.using_fast_plain_lift); - ASSERT_FALSE(key_qualifiers.using_descending_modulus_chain); - ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); - ASSERT_TRUE(context->using_keyswitching()); - } - - // Parameters OK; no batching due to non-prime plain modulus - parms.set_poly_modulus_degree(4); - parms.set_coeff_modulus({ 17, 41 }); - parms.set_plain_modulus(49); - parms.set_random_generator(UniformRandomGeneratorFactory::default_factory()); - { - auto context = SEALContext::Create(parms, false, sec_level_type::none); - ASSERT_EQ(697ULL, *context->first_context_data()->total_coeff_modulus()); - auto qualifiers = context->first_context_data()->qualifiers(); - ASSERT_TRUE(qualifiers.parameters_set); - ASSERT_TRUE(qualifiers.using_fft); - ASSERT_TRUE(qualifiers.using_ntt); - ASSERT_FALSE(qualifiers.using_batching); - ASSERT_FALSE(qualifiers.using_fast_plain_lift); - ASSERT_FALSE(qualifiers.using_descending_modulus_chain); - ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); - ASSERT_FALSE(context->using_keyswitching()); - } - - // Parameters OK; batching enabled - parms.set_poly_modulus_degree(4); - parms.set_coeff_modulus({ 17, 41 }); - parms.set_plain_modulus(73); - parms.set_random_generator(UniformRandomGeneratorFactory::default_factory()); - { - auto context = SEALContext::Create(parms, false, sec_level_type::none); - ASSERT_EQ(697ULL, *context->first_context_data()->total_coeff_modulus()); - auto qualifiers = context->first_context_data()->qualifiers(); - ASSERT_TRUE(qualifiers.parameters_set); - ASSERT_TRUE(qualifiers.using_fft); - ASSERT_TRUE(qualifiers.using_ntt); - ASSERT_TRUE(qualifiers.using_batching); - ASSERT_FALSE(qualifiers.using_fast_plain_lift); - ASSERT_FALSE(qualifiers.using_descending_modulus_chain); - ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); - ASSERT_FALSE(context->using_keyswitching()); - } - - // Parameters OK; batching and fast plain lift enabled - parms.set_poly_modulus_degree(4); - parms.set_coeff_modulus({ 137, 193 }); - parms.set_plain_modulus(73); - parms.set_random_generator(UniformRandomGeneratorFactory::default_factory()); - { - auto context = SEALContext::Create(parms, false, sec_level_type::none); - ASSERT_EQ(137ULL, *context->first_context_data()->total_coeff_modulus()); - ASSERT_EQ(26441ULL, *context->key_context_data()->total_coeff_modulus()); - auto qualifiers = context->first_context_data()->qualifiers(); - auto key_qualifiers = context->key_context_data()->qualifiers(); - ASSERT_TRUE(qualifiers.parameters_set); - ASSERT_TRUE(qualifiers.using_fft); - ASSERT_TRUE(qualifiers.using_ntt); - ASSERT_TRUE(qualifiers.using_batching); - ASSERT_TRUE(qualifiers.using_fast_plain_lift); - ASSERT_FALSE(key_qualifiers.using_descending_modulus_chain); - ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); - ASSERT_TRUE(context->using_keyswitching()); - } - - // Parameters OK; batching and fast plain lift enabled; nullptr RNG - parms.set_poly_modulus_degree(4); - parms.set_coeff_modulus({ 137, 193 }); - parms.set_plain_modulus(73); - parms.set_random_generator(nullptr); - { - auto context = SEALContext::Create(parms, false, sec_level_type::none); - ASSERT_EQ(137ULL, *context->first_context_data()->total_coeff_modulus()); - ASSERT_EQ(26441ULL, *context->key_context_data()->total_coeff_modulus()); - auto qualifiers = context->first_context_data()->qualifiers(); - auto key_qualifiers = context->key_context_data()->qualifiers(); - ASSERT_TRUE(qualifiers.parameters_set); - ASSERT_TRUE(qualifiers.using_fft); - ASSERT_TRUE(qualifiers.using_ntt); - ASSERT_TRUE(qualifiers.using_batching); - ASSERT_TRUE(qualifiers.using_fast_plain_lift); - ASSERT_FALSE(key_qualifiers.using_descending_modulus_chain); - ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); - ASSERT_TRUE(context->using_keyswitching()); - } - - // Parameters not OK due to too small poly_modulus_degree and enforce_hes - parms.set_poly_modulus_degree(4); - parms.set_coeff_modulus({ 137, 193 }); - parms.set_plain_modulus(73); - parms.set_random_generator(nullptr); - { - auto context = SEALContext::Create(parms, false, sec_level_type::tc128); - auto qualifiers = context->first_context_data()->qualifiers(); - ASSERT_FALSE(qualifiers.parameters_set); - ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); - ASSERT_FALSE(context->using_keyswitching()); - } - - // Parameters not OK due to too large coeff_modulus and enforce_hes - parms.set_poly_modulus_degree(2048); - parms.set_coeff_modulus(CoeffModulus::BFVDefault(4096, sec_level_type::tc128)); - parms.set_plain_modulus(73); - parms.set_random_generator(nullptr); - { - auto context = SEALContext::Create(parms, false, sec_level_type::tc128); - auto qualifiers = context->first_context_data()->qualifiers(); - ASSERT_FALSE(qualifiers.parameters_set); - ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); - ASSERT_FALSE(context->using_keyswitching()); - } - - // Parameters OK; descending modulus chain - parms.set_poly_modulus_degree(4096); - parms.set_coeff_modulus({ 0xffffee001, 0xffffc4001 }); - parms.set_plain_modulus(73); - { - auto context = SEALContext::Create(parms, false, sec_level_type::tc128); - auto qualifiers = context->first_context_data()->qualifiers(); - ASSERT_TRUE(qualifiers.parameters_set); - ASSERT_TRUE(qualifiers.using_fft); - ASSERT_TRUE(qualifiers.using_ntt); - ASSERT_FALSE(qualifiers.using_batching); - ASSERT_TRUE(qualifiers.using_fast_plain_lift); - ASSERT_TRUE(qualifiers.using_descending_modulus_chain); - ASSERT_EQ(sec_level_type::tc128, qualifiers.sec_level); - ASSERT_TRUE(context->using_keyswitching()); - } - - // Parameters OK; no standard security - parms.set_poly_modulus_degree(2048); - parms.set_coeff_modulus({ 0x1ffffe0001, 0xffffee001, 0xffffc4001 }); - parms.set_plain_modulus(73); - { - auto context = SEALContext::Create(parms, false, sec_level_type::none); - auto qualifiers = context->first_context_data()->qualifiers(); - auto key_qualifiers = context->key_context_data()->qualifiers(); - ASSERT_TRUE(qualifiers.parameters_set); - ASSERT_TRUE(qualifiers.using_fft); - ASSERT_TRUE(qualifiers.using_ntt); - ASSERT_FALSE(qualifiers.using_batching); - ASSERT_TRUE(qualifiers.using_fast_plain_lift); - ASSERT_TRUE(key_qualifiers.using_descending_modulus_chain); - ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); - ASSERT_TRUE(context->using_keyswitching()); - } - - // Parameters OK; using batching; no keyswitching - parms.set_poly_modulus_degree(2048); - parms.set_coeff_modulus(CoeffModulus::Create(2048, { 40 })); - parms.set_plain_modulus(65537); - { - auto context = SEALContext::Create(parms, false, sec_level_type::none); - auto qualifiers = context->first_context_data()->qualifiers(); - ASSERT_TRUE(qualifiers.parameters_set); - ASSERT_TRUE(qualifiers.using_fft); - ASSERT_TRUE(qualifiers.using_ntt); - ASSERT_TRUE(qualifiers.using_batching); - ASSERT_TRUE(qualifiers.using_fast_plain_lift); - ASSERT_TRUE(qualifiers.using_descending_modulus_chain); - ASSERT_EQ(sec_level_type::none, qualifiers.sec_level); - ASSERT_FALSE(context->using_keyswitching()); - } - } - - TEST(ContextTest, ModulusChainExpansion) - { - { - EncryptionParameters parms(scheme_type::BFV); - parms.set_poly_modulus_degree(4); - parms.set_coeff_modulus({ 41, 137, 193, 65537 }); - parms.set_plain_modulus(73); - auto context = SEALContext::Create(parms, true, sec_level_type::none); - auto context_data = context->key_context_data(); - ASSERT_EQ(size_t(2), context_data->chain_index()); - ASSERT_EQ(71047416497ULL, *context_data->total_coeff_modulus()); - ASSERT_FALSE(!!context_data->prev_context_data()); - ASSERT_EQ(context_data->parms_id(), context->key_parms_id()); - auto prev_context_data = context_data; - context_data = context_data->next_context_data(); - ASSERT_EQ(size_t(1), context_data->chain_index()); - ASSERT_EQ(1084081ULL, *context_data->total_coeff_modulus()); - ASSERT_EQ(context_data->prev_context_data()->parms_id(), - prev_context_data->parms_id()); - prev_context_data = context_data; - context_data = context_data->next_context_data(); - ASSERT_EQ(size_t(0), context_data->chain_index()); - ASSERT_EQ(5617ULL, *context_data->total_coeff_modulus()); - ASSERT_EQ(context_data->prev_context_data()->parms_id(), - prev_context_data->parms_id()); - ASSERT_FALSE(!!context_data->next_context_data()); - ASSERT_EQ(context_data->parms_id(), context->last_parms_id()); - - context = SEALContext::Create(parms, false, sec_level_type::none); - ASSERT_EQ(size_t(1), context->key_context_data()->chain_index()); - ASSERT_EQ(size_t(0), context->first_context_data()->chain_index()); - ASSERT_EQ(71047416497ULL, *context->key_context_data()->total_coeff_modulus()); - ASSERT_EQ(1084081ULL, *context->first_context_data()->total_coeff_modulus()); - ASSERT_FALSE(!!context->first_context_data()->next_context_data()); - ASSERT_TRUE(!!context->first_context_data()->prev_context_data()); - } - { - EncryptionParameters parms(scheme_type::CKKS); - parms.set_poly_modulus_degree(4); - parms.set_coeff_modulus({ 41, 137, 193, 65537 }); - auto context = SEALContext::Create(parms, true, sec_level_type::none); - auto context_data = context->key_context_data(); - ASSERT_EQ(size_t(3), context_data->chain_index()); - ASSERT_EQ(71047416497ULL, *context_data->total_coeff_modulus()); - ASSERT_FALSE(!!context_data->prev_context_data()); - ASSERT_EQ(context_data->parms_id(), context->key_parms_id()); - auto prev_context_data = context_data; - context_data = context_data->next_context_data(); - ASSERT_EQ(size_t(2), context_data->chain_index()); - ASSERT_EQ(1084081ULL, *context_data->total_coeff_modulus()); - ASSERT_EQ(context_data->prev_context_data()->parms_id(), - prev_context_data->parms_id()); - prev_context_data = context_data; - context_data = context_data->next_context_data(); - ASSERT_EQ(size_t(1), context_data->chain_index()); - ASSERT_EQ(5617ULL, *context_data->total_coeff_modulus()); - ASSERT_EQ(context_data->prev_context_data()->parms_id(), - prev_context_data->parms_id()); - prev_context_data = context_data; - context_data = context_data->next_context_data(); - ASSERT_EQ(size_t(0), context_data->chain_index()); - ASSERT_EQ(41ULL, *context_data->total_coeff_modulus()); - ASSERT_EQ(context_data->prev_context_data()->parms_id(), - prev_context_data->parms_id()); - ASSERT_FALSE(!!context_data->next_context_data()); - ASSERT_EQ(context_data->parms_id(), context->last_parms_id()); - - context = SEALContext::Create(parms, false, sec_level_type::none); - ASSERT_EQ(size_t(1), context->key_context_data()->chain_index()); - ASSERT_EQ(size_t(0), context->first_context_data()->chain_index()); - ASSERT_EQ(71047416497ULL, *context->key_context_data()->total_coeff_modulus()); - ASSERT_EQ(1084081ULL, *context->first_context_data()->total_coeff_modulus()); - ASSERT_FALSE(!!context->first_context_data()->next_context_data()); - ASSERT_TRUE(!!context->first_context_data()->prev_context_data()); - } - } -} diff --git a/SEAL/native/tests/seal/encryptionparams.cpp b/SEAL/native/tests/seal/encryptionparams.cpp deleted file mode 100644 index 887a671..0000000 --- a/SEAL/native/tests/seal/encryptionparams.cpp +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/encryptionparams.h" -#include "seal/modulus.h" -#include "seal/util/numth.h" - -using namespace seal; -using namespace std; - -namespace SEALTest -{ - TEST(EncryptionParametersTest, EncryptionParametersSet) - { - auto encryption_parameters_test = [](scheme_type scheme) - { - EncryptionParameters parms(scheme); - parms.set_coeff_modulus({ 2, 3 }); - if (scheme == scheme_type::BFV) - parms.set_plain_modulus(2); - parms.set_poly_modulus_degree(2); - parms.set_random_generator(UniformRandomGeneratorFactory::default_factory()); - - ASSERT_TRUE(scheme == parms.scheme()); - ASSERT_TRUE(parms.coeff_modulus()[0] == 2); - ASSERT_TRUE(parms.coeff_modulus()[1] == 3); - if (scheme == scheme_type::BFV) - { - ASSERT_TRUE(parms.plain_modulus().value() == 2); - } - else if (scheme == scheme_type::CKKS) - { - ASSERT_TRUE(parms.plain_modulus().value() == 0); - } - ASSERT_TRUE(parms.poly_modulus_degree() == 2); - ASSERT_TRUE(parms.random_generator() == UniformRandomGeneratorFactory::default_factory()); - - parms.set_coeff_modulus(CoeffModulus::Create(2, { 30, 40, 50 })); - if (scheme == scheme_type::BFV) - parms.set_plain_modulus(2); - parms.set_poly_modulus_degree(128); - parms.set_random_generator(UniformRandomGeneratorFactory::default_factory()); - - ASSERT_TRUE(util::is_prime(parms.coeff_modulus()[0])); - ASSERT_TRUE(util::is_prime(parms.coeff_modulus()[1])); - ASSERT_TRUE(util::is_prime(parms.coeff_modulus()[2])); - - if (scheme == scheme_type::BFV) - { - ASSERT_TRUE(parms.plain_modulus().value() == 2); - } - else if (scheme == scheme_type::CKKS) - { - ASSERT_TRUE(parms.plain_modulus().value() == 0); - } - ASSERT_TRUE(parms.poly_modulus_degree() == 128); - ASSERT_TRUE(parms.random_generator() == UniformRandomGeneratorFactory::default_factory()); - }; - encryption_parameters_test(scheme_type::BFV); - encryption_parameters_test(scheme_type::CKKS); - } - - TEST(EncryptionParametersTest, EncryptionParametersCompare) - { - auto scheme = scheme_type::BFV; - EncryptionParameters parms1(scheme); - parms1.set_coeff_modulus(CoeffModulus::Create(64, { 30 })); - if (scheme == scheme_type::BFV) - parms1.set_plain_modulus(1 << 6); - parms1.set_poly_modulus_degree(64); - parms1.set_random_generator(UniformRandomGeneratorFactory::default_factory()); - - EncryptionParameters parms2(parms1); - ASSERT_TRUE(parms1 == parms2); - - EncryptionParameters parms3(scheme); - parms3 = parms2; - ASSERT_TRUE(parms3 == parms2); - parms3.set_coeff_modulus(CoeffModulus::Create(64, { 32 })); - ASSERT_FALSE(parms3 == parms2); - - parms3 = parms2; - ASSERT_TRUE(parms3 == parms2); - parms3.set_coeff_modulus(CoeffModulus::Create(64, { 30, 30 })); - ASSERT_FALSE(parms3 == parms2); - - parms3 = parms2; - parms3.set_poly_modulus_degree(128); - ASSERT_FALSE(parms3 == parms1); - - parms3 = parms2; - if (scheme == scheme_type::BFV) - parms3.set_plain_modulus((1 << 6) + 1); - ASSERT_FALSE(parms3 == parms2); - - parms3 = parms2; - ASSERT_TRUE(parms3 == parms2); - - parms3 = parms2; - parms3.set_random_generator(nullptr); - ASSERT_TRUE(parms3 == parms2); - - parms3 = parms2; - parms3.set_poly_modulus_degree(128); - parms3.set_poly_modulus_degree(64); - ASSERT_TRUE(parms3 == parms1); - - parms3 = parms2; - parms3.set_coeff_modulus({ 2 }); - parms3.set_coeff_modulus(CoeffModulus::Create(64, { 50 })); - parms3.set_coeff_modulus(parms2.coeff_modulus()); - ASSERT_TRUE(parms3 == parms2); - } - - TEST(EncryptionParametersTest, EncryptionParametersSaveLoad) - { - stringstream stream; - - auto scheme = scheme_type::BFV; - EncryptionParameters parms(scheme); - EncryptionParameters parms2(scheme); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 30 })); - if (scheme == scheme_type::BFV) - parms.set_plain_modulus(1 << 6); - parms.set_poly_modulus_degree(64); - EncryptionParameters::Save(parms, stream); - parms2 = EncryptionParameters::Load(stream); - ASSERT_TRUE(parms.scheme() == parms2.scheme()); - ASSERT_TRUE(parms.coeff_modulus() == parms2.coeff_modulus()); - ASSERT_TRUE(parms.plain_modulus() == parms2.plain_modulus()); - ASSERT_TRUE(parms.poly_modulus_degree() == parms2.poly_modulus_degree()); - ASSERT_TRUE(parms == parms2); - - parms.set_coeff_modulus(CoeffModulus::Create(64, { 30, 60, 60 })); - - if (scheme == scheme_type::BFV) - parms.set_plain_modulus(1 << 30); - parms.set_poly_modulus_degree(256); - - EncryptionParameters::Save(parms, stream); - parms2 = EncryptionParameters::Load(stream); - ASSERT_TRUE(parms.scheme() == parms2.scheme()); - ASSERT_TRUE(parms.coeff_modulus() == parms2.coeff_modulus()); - ASSERT_TRUE(parms.plain_modulus() == parms2.plain_modulus()); - ASSERT_TRUE(parms.poly_modulus_degree() == parms2.poly_modulus_degree()); - ASSERT_TRUE(parms == parms2); - } -} diff --git a/SEAL/native/tests/seal/encryptor.cpp b/SEAL/native/tests/seal/encryptor.cpp deleted file mode 100644 index 9e5f0d9..0000000 --- a/SEAL/native/tests/seal/encryptor.cpp +++ /dev/null @@ -1,482 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/context.h" -#include "seal/encryptor.h" -#include "seal/decryptor.h" -#include "seal/keygenerator.h" -#include "seal/batchencoder.h" -#include "seal/ckks.h" -#include "seal/intencoder.h" -#include "seal/modulus.h" -#include -#include -#include - -using namespace seal; -using namespace std; - -namespace SEALTest -{ - TEST(EncryptorTest, BFVEncryptDecrypt) - { - EncryptionParameters parms(scheme_type::BFV); - SmallModulus plain_modulus(1 << 6); - parms.set_plain_modulus(plain_modulus); - { - parms.set_poly_modulus_degree(64); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 40 })); - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - IntegerEncoder encoder(context); - - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - - Ciphertext encrypted; - Plaintext plain; - encryptor.encrypt(encoder.encode(0x12345678), encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(0x12345678ULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(0), encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(0ULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(1), encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(1ULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(2), encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(2ULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(static_cast(0x7FFFFFFFFFFFFFFD)), encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(0x7FFFFFFFFFFFFFFDULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(static_cast(0x7FFFFFFFFFFFFFFE)), encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(0x7FFFFFFFFFFFFFFEULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(static_cast(0x7FFFFFFFFFFFFFFF)), encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(0x7FFFFFFFFFFFFFFFULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(314159265), encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(314159265ULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - } - { - parms.set_poly_modulus_degree(128); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40 })); - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - IntegerEncoder encoder(context); - - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - - Ciphertext encrypted; - Plaintext plain; - encryptor.encrypt(encoder.encode(0x12345678), encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(0x12345678ULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(0), encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(0ULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(1), encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(1ULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(2), encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(2ULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(static_cast(0x7FFFFFFFFFFFFFFD)), encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(0x7FFFFFFFFFFFFFFDULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(static_cast(0x7FFFFFFFFFFFFFFE)), encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(0x7FFFFFFFFFFFFFFEULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(static_cast(0x7FFFFFFFFFFFFFFF)), encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(0x7FFFFFFFFFFFFFFFULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(314159265), encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(314159265ULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - } - - { - parms.set_poly_modulus_degree(256); - parms.set_coeff_modulus(CoeffModulus::Create(256, { 40, 40, 40 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - IntegerEncoder encoder(context); - - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - - Ciphertext encrypted; - Plaintext plain; - encryptor.encrypt(encoder.encode(0x12345678), encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(0x12345678ULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(0), encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(0ULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(1), encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(1ULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(2), encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(2ULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(static_cast(0x7FFFFFFFFFFFFFFD)), encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(0x7FFFFFFFFFFFFFFDULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(static_cast(0x7FFFFFFFFFFFFFFE)), encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(0x7FFFFFFFFFFFFFFEULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(static_cast(0x7FFFFFFFFFFFFFFF)), encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(0x7FFFFFFFFFFFFFFFULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(314159265), encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(314159265ULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - } - } - - TEST(EncryptorTest, BFVEncryptZeroDecrypt) - { - EncryptionParameters parms(scheme_type::BFV); - SmallModulus plain_modulus(1 << 6); - parms.set_plain_modulus(plain_modulus); - parms.set_poly_modulus_degree(64); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 40, 40, 40 })); - auto context = SEALContext::Create(parms, true, sec_level_type::none); - KeyGenerator keygen(context); - - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - - Ciphertext ct; - encryptor.encrypt_zero(ct); - ASSERT_FALSE(ct.is_ntt_form()); - ASSERT_FALSE(ct.is_transparent()); - ASSERT_DOUBLE_EQ(ct.scale(), 1.0); - Plaintext pt; - decryptor.decrypt(ct, pt); - ASSERT_TRUE(pt.is_zero()); - - parms_id_type next_parms = context->first_context_data()->next_context_data()->parms_id(); - encryptor.encrypt_zero(next_parms, ct); - ASSERT_FALSE(ct.is_ntt_form()); - ASSERT_FALSE(ct.is_transparent()); - ASSERT_DOUBLE_EQ(ct.scale(), 1.0); - ASSERT_EQ(ct.parms_id(), next_parms); - decryptor.decrypt(ct, pt); - ASSERT_TRUE(pt.is_zero()); - } - - TEST(EncryptorTest, CKKSEncryptZeroDecrypt) - { - EncryptionParameters parms(scheme_type::CKKS); - parms.set_poly_modulus_degree(64); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 40, 40, 40 })); - - auto context = SEALContext::Create(parms, true, sec_level_type::none); - KeyGenerator keygen(context); - - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - CKKSEncoder encoder(context); - - Ciphertext ct; - encryptor.encrypt_zero(ct); - ASSERT_FALSE(ct.is_transparent()); - ASSERT_TRUE(ct.is_ntt_form()); - ASSERT_DOUBLE_EQ(ct.scale(), 1.0); - ct.scale() = std::pow(2.0, 20); - Plaintext pt; - decryptor.decrypt(ct, pt); - - std::vector> res; - encoder.decode(pt, res); - for (auto val : res) - { - ASSERT_NEAR(val.real(), 0.0, 0.01); - ASSERT_NEAR(val.imag(), 0.0, 0.01); - } - - parms_id_type next_parms = context->first_context_data()->next_context_data()->parms_id(); - encryptor.encrypt_zero(next_parms, ct); - ASSERT_FALSE(ct.is_transparent()); - ASSERT_TRUE(ct.is_ntt_form()); - ASSERT_DOUBLE_EQ(ct.scale(), 1.0); - ct.scale() = std::pow(2.0, 20); - ASSERT_EQ(ct.parms_id(), next_parms); - decryptor.decrypt(ct, pt); - ASSERT_EQ(pt.parms_id(), next_parms); - - encoder.decode(pt, res); - for (auto val : res) - { - ASSERT_NEAR(val.real(), 0.0, 0.01); - ASSERT_NEAR(val.imag(), 0.0, 0.01); - } - } - - TEST(EncryptorTest, CKKSEncryptDecrypt) - { - EncryptionParameters parms(scheme_type::CKKS); - { - //input consists of ones - size_t slot_size = 32; - parms.set_poly_modulus_degree(2 * slot_size); - parms.set_coeff_modulus(CoeffModulus::Create(2 * slot_size, { 40, 40, 40, 40 })); - - auto context = SEALContext::Create(parms, true, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - - Ciphertext encrypted; - Plaintext plain; - Plaintext plainRes; - - std::vector> input(slot_size, 1.0); - std::vector> output(slot_size); - const double delta = static_cast(1 << 16); - - encoder.encode(input, context->first_parms_id(), delta, plain); - encryptor.encrypt(plain, encrypted); - - //check correctness of encryption - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - decryptor.decrypt(encrypted, plainRes); - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(input[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - { - //input consists of zeros - size_t slot_size = 32; - parms.set_poly_modulus_degree(2 * slot_size); - parms.set_coeff_modulus(CoeffModulus::Create(2 * slot_size, { 40, 40, 40, 40 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - - Ciphertext encrypted; - Plaintext plain; - Plaintext plainRes; - - std::vector> input(slot_size, 0.0); - std::vector> output(slot_size); - const double delta = static_cast(1 << 16); - - encoder.encode(input, context->first_parms_id(), delta, plain); - encryptor.encrypt(plain, encrypted); - - //check correctness of encryption - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - decryptor.decrypt(encrypted, plainRes); - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(input[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - { - // Input is a random mix of positive and negative integers - size_t slot_size = 64; - parms.set_poly_modulus_degree(2 * slot_size); - parms.set_coeff_modulus(CoeffModulus::Create(2 * slot_size, { 60, 60, 60 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - - Ciphertext encrypted; - Plaintext plain; - Plaintext plainRes; - - std::vector> input(slot_size); - std::vector> output(slot_size); - - srand(static_cast(time(NULL))); - int input_bound = 1 << 30; - const double delta = static_cast(1ULL << 50); - - for (int round = 0; round < 100; round++) - { - for (size_t i = 0; i < slot_size; i++) - { - input[i] = pow(-1.0, rand() % 2) * static_cast(rand() % input_bound); - } - - encoder.encode(input, context->first_parms_id(), delta, plain); - encryptor.encrypt(plain, encrypted); - - //check correctness of encryption - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - decryptor.decrypt(encrypted, plainRes); - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(input[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - { - // Input is a random mix of positive and negative integers - size_t slot_size = 32; - parms.set_poly_modulus_degree(128); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 60, 60, 60 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - - Ciphertext encrypted; - Plaintext plain; - Plaintext plainRes; - - std::vector> input(slot_size); - std::vector> output(slot_size); - - srand(static_cast(time(NULL))); - int input_bound = 1 << 30; - const double delta = static_cast(1ULL << 60); - - for (int round = 0; round < 100; round++) - { - for (size_t i = 0; i < slot_size; i++) - { - input[i] = pow(-1.0, rand() % 2) * static_cast(rand() % input_bound); - } - - encoder.encode(input, context->first_parms_id(), delta, plain); - encryptor.encrypt(plain, encrypted); - - //check correctness of encryption - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - decryptor.decrypt(encrypted, plainRes); - encoder.decode(plain, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(input[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - { - // Encrypt at lower level - size_t slot_size = 32; - parms.set_poly_modulus_degree(2 * slot_size); - parms.set_coeff_modulus(CoeffModulus::Create(2 * slot_size, { 40, 40, 40, 40 })); - - auto context = SEALContext::Create(parms, true, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - - Ciphertext encrypted; - Plaintext plain; - Plaintext plainRes; - - std::vector> input(slot_size, 1.0); - std::vector> output(slot_size); - const double delta = static_cast(1 << 16); - - auto first_context_data = context->first_context_data(); - ASSERT_NE(nullptr, first_context_data.get()); - auto second_context_data = first_context_data->next_context_data(); - ASSERT_NE(nullptr, second_context_data.get()); - auto second_parms_id = second_context_data->parms_id(); - - encoder.encode(input, second_parms_id, delta, plain); - encryptor.encrypt(plain, encrypted); - - // Check correctness of encryption - ASSERT_TRUE(encrypted.parms_id() == second_parms_id); - - decryptor.decrypt(encrypted, plainRes); - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(input[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } -} diff --git a/SEAL/native/tests/seal/evaluator.cpp b/SEAL/native/tests/seal/evaluator.cpp deleted file mode 100644 index 517338b..0000000 --- a/SEAL/native/tests/seal/evaluator.cpp +++ /dev/null @@ -1,3967 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/context.h" -#include "seal/encryptor.h" -#include "seal/decryptor.h" -#include "seal/evaluator.h" -#include "seal/keygenerator.h" -#include "seal/batchencoder.h" -#include "seal/ckks.h" -#include "seal/intencoder.h" -#include "seal/modulus.h" -#include -#include -#include -#include - -using namespace seal; -using namespace std; - -namespace SEALTest -{ - TEST(EvaluatorTest, BFVEncryptNegateDecrypt) - { - EncryptionParameters parms(scheme_type::BFV); - SmallModulus plain_modulus(1 << 6); - parms.set_poly_modulus_degree(64); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 40 })); - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - IntegerEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - - Ciphertext encrypted; - encryptor.encrypt(encoder.encode(0x12345678), encrypted); - evaluator.negate_inplace(encrypted); - Plaintext plain; - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(static_cast(-0x12345678), encoder.decode_int32(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(0), encrypted); - evaluator.negate_inplace(encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(static_cast(0), encoder.decode_int32(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(1), encrypted); - evaluator.negate_inplace(encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(static_cast(-1), encoder.decode_int32(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(-1), encrypted); - evaluator.negate_inplace(encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(static_cast(1), encoder.decode_int32(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(2), encrypted); - evaluator.negate_inplace(encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(static_cast(-2), encoder.decode_int32(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(-5), encrypted); - evaluator.negate_inplace(encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(static_cast(5), encoder.decode_int32(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - } - - TEST(EvaluatorTest, BFVEncryptAddDecrypt) - { - EncryptionParameters parms(scheme_type::BFV); - SmallModulus plain_modulus(1 << 6); - parms.set_poly_modulus_degree(64); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 40 })); - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - IntegerEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - - Ciphertext encrypted1; - encryptor.encrypt(encoder.encode(0x12345678), encrypted1); - Ciphertext encrypted2; - encryptor.encrypt(encoder.encode(0x54321), encrypted2); - evaluator.add_inplace(encrypted1, encrypted2); - Plaintext plain; - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(0x12399999), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(0), encrypted1); - encryptor.encrypt(encoder.encode(0), encrypted2); - evaluator.add_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(0), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(0), encrypted1); - encryptor.encrypt(encoder.encode(5), encrypted2); - evaluator.add_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(5), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(5), encrypted1); - encryptor.encrypt(encoder.encode(-3), encrypted2); - evaluator.add_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(2), encoder.decode_int32(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(-7), encrypted1); - encryptor.encrypt(encoder.encode(2), encrypted2); - evaluator.add_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(-5), encoder.decode_int32(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - Plaintext plain1("2x^2 + 1x^1 + 3"); - Plaintext plain2("3x^3 + 4x^2 + 5x^1 + 6"); - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - evaluator.add_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_TRUE(plain.to_string() == "3x^3 + 6x^2 + 6x^1 + 9"); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - plain1 = "3x^5 + 1x^4 + 4x^3 + 1"; - plain2 = "5x^2 + 9x^1 + 2"; - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - evaluator.add_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_TRUE(plain.to_string() == "3x^5 + 1x^4 + 4x^3 + 5x^2 + 9x^1 + 3"); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - } - - TEST(EvaluatorTest, CKKSEncryptAddDecrypt) - { - EncryptionParameters parms(scheme_type::CKKS); - { - //adding two zero vectors - size_t slot_size = 32; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 30, 30, 30, 30, 30 })); - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - - Ciphertext encrypted; - Plaintext plain; - Plaintext plainRes; - - std::vector> input(slot_size, 0.0); - std::vector> output(slot_size); - const double delta = static_cast(1 << 16); - encoder.encode(input, context->first_parms_id(), delta, plain); - - encryptor.encrypt(plain, encrypted); - evaluator.add_inplace(encrypted, encrypted); - - //check correctness of encryption - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - decryptor.decrypt(encrypted, plainRes); - - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(input[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - { - //adding two random vectors 100 times - size_t slot_size = 32; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 60 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - - Ciphertext encrypted1; - Ciphertext encrypted2; - Plaintext plain1; - Plaintext plain2; - Plaintext plainRes; - - std::vector> input1(slot_size, 0.0); - std::vector> input2(slot_size, 0.0); - std::vector> expected(slot_size, 0.0); - std::vector> output(slot_size); - - int data_bound = (1 << 30); - const double delta = static_cast(1 << 16); - - srand(static_cast(time(NULL))); - - for (int expCount = 0; expCount < 100; expCount++) - { - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - input2[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] + input2[i]; - } - - encoder.encode(input1, context->first_parms_id(), delta, plain1); - encoder.encode(input2, context->first_parms_id(), delta, plain2); - - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - evaluator.add_inplace(encrypted1, encrypted2); - - //check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - decryptor.decrypt(encrypted1, plainRes); - - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - - } - { - //adding two random vectors 100 times - size_t slot_size = 8; - parms.set_poly_modulus_degree(64); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60, 60 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - - Ciphertext encrypted1; - Ciphertext encrypted2; - Plaintext plain1; - Plaintext plain2; - Plaintext plainRes; - - std::vector> input1(slot_size, 0.0); - std::vector> input2(slot_size, 0.0); - std::vector> expected(slot_size, 0.0); - std::vector> output(slot_size); - - int data_bound = (1 << 30); - const double delta = static_cast(1 << 16); - - srand(static_cast(time(NULL))); - - for (int expCount = 0; expCount < 100; expCount++) - { - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - input2[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] + input2[i]; - } - - encoder.encode(input1, context->first_parms_id(), delta, plain1); - encoder.encode(input2, context->first_parms_id(), delta, plain2); - - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - evaluator.add_inplace(encrypted1, encrypted2); - - //check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - decryptor.decrypt(encrypted1, plainRes); - - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - - } - } - TEST(EvaluatorTest, CKKSEncryptAddPlainDecrypt) - { - EncryptionParameters parms(scheme_type::CKKS); - { - //adding two zero vectors - size_t slot_size = 32; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 30, 30, 30, 30, 30 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - - Ciphertext encrypted; - Plaintext plain; - Plaintext plainRes; - - std::vector> input(slot_size, 0.0); - std::vector> output(slot_size); - const double delta = static_cast(1 << 16); - encoder.encode(input, context->first_parms_id(), delta, plain); - - encryptor.encrypt(plain, encrypted); - evaluator.add_plain_inplace(encrypted, plain); - - //check correctness of encryption - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - decryptor.decrypt(encrypted, plainRes); - - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(input[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - { - //adding two random vectors 50 times - size_t slot_size = 32; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 60 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - - Ciphertext encrypted1; - Plaintext plain1; - Plaintext plain2; - Plaintext plainRes; - - std::vector> input1(slot_size, 0.0); - std::vector> input2(slot_size, 0.0); - std::vector> expected(slot_size, 0.0); - std::vector> output(slot_size); - - int data_bound = (1 << 8); - const double delta = static_cast(1ULL << 16); - - srand(static_cast(time(NULL))); - - for (int expCount = 0; expCount < 50; expCount++) - { - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - input2[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] + input2[i]; - } - - encoder.encode(input1, context->first_parms_id(), delta, plain1); - encoder.encode(input2, context->first_parms_id(), delta, plain2); - - encryptor.encrypt(plain1, encrypted1); - evaluator.add_plain_inplace(encrypted1, plain2); - - //check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - decryptor.decrypt(encrypted1, plainRes); - - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - - } - { - //adding two random vectors 50 times - size_t slot_size = 32; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 60 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - - Ciphertext encrypted1; - Plaintext plain1; - Plaintext plain2; - Plaintext plainRes; - - std::vector> input1(slot_size, 0.0); - double input2; - std::vector> expected(slot_size, 0.0); - std::vector> output(slot_size); - - int data_bound = (1 << 8); - const double delta = static_cast(1ULL << 16); - - srand(static_cast(time(NULL))); - - for (int expCount = 0; expCount < 50; expCount++) - { - input2 = static_cast(rand() % (data_bound*data_bound))/data_bound; - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] + input2; - } - - encoder.encode(input1, context->first_parms_id(), delta, plain1); - encoder.encode(input2, context->first_parms_id(), delta, plain2); - - encryptor.encrypt(plain1, encrypted1); - evaluator.add_plain_inplace(encrypted1, plain2); - - //check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - decryptor.decrypt(encrypted1, plainRes); - - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - - } - { - //adding two random vectors 50 times - size_t slot_size = 8; - parms.set_poly_modulus_degree(64); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60, 60 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - - Ciphertext encrypted1; - Plaintext plain1; - Plaintext plain2; - Plaintext plainRes; - - std::vector> input1(slot_size, 0.0); - double input2; - std::vector> expected(slot_size, 0.0); - std::vector> output(slot_size); - - int data_bound = (1 << 8); - const double delta = static_cast(1ULL << 16); - - srand(static_cast(time(NULL))); - - for (int expCount = 0; expCount < 50; expCount++) - { - input2 = static_cast(rand() % (data_bound*data_bound)) / data_bound; - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] + input2; - } - - encoder.encode(input1, context->first_parms_id(), delta, plain1); - encoder.encode(input2, context->first_parms_id(), delta, plain2); - - encryptor.encrypt(plain1, encrypted1); - evaluator.add_plain_inplace(encrypted1, plain2); - - //check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - decryptor.decrypt(encrypted1, plainRes); - - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - - } - } - - TEST(EvaluatorTest, CKKSEncryptSubPlainDecrypt) - { - EncryptionParameters parms(scheme_type::CKKS); - { - //adding two zero vectors - size_t slot_size = 32; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 30, 30, 30, 30, 30 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - - Ciphertext encrypted; - Plaintext plain; - Plaintext plainRes; - - std::vector> input(slot_size, 0.0); - std::vector> output(slot_size); - const double delta = static_cast(1 << 16); - encoder.encode(input, context->first_parms_id(), delta, plain); - - encryptor.encrypt(plain, encrypted); - evaluator.add_plain_inplace(encrypted, plain); - - //check correctness of encryption - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - decryptor.decrypt(encrypted, plainRes); - - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(input[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - { - //adding two random vectors 100 times - size_t slot_size = 32; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 60 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - - Ciphertext encrypted1; - Plaintext plain1; - Plaintext plain2; - Plaintext plainRes; - - std::vector> input1(slot_size, 0.0); - std::vector> input2(slot_size, 0.0); - std::vector> expected(slot_size, 0.0); - std::vector> output(slot_size); - - int data_bound = (1 << 8); - const double delta = static_cast(1ULL << 16); - - srand(static_cast(time(NULL))); - - for (int expCount = 0; expCount < 100; expCount++) - { - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - input2[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] - input2[i]; - } - - encoder.encode(input1, context->first_parms_id(), delta, plain1); - encoder.encode(input2, context->first_parms_id(), delta, plain2); - - encryptor.encrypt(plain1, encrypted1); - evaluator.sub_plain_inplace(encrypted1, plain2); - - //check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - decryptor.decrypt(encrypted1, plainRes); - - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - - } - { - //adding two random vectors 100 times - size_t slot_size = 8; - parms.set_poly_modulus_degree(64); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60, 60 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - - Ciphertext encrypted1; - Plaintext plain1; - Plaintext plain2; - Plaintext plainRes; - - std::vector> input1(slot_size, 0.0); - std::vector> input2(slot_size, 0.0); - std::vector> expected(slot_size, 0.0); - std::vector> output(slot_size); - - int data_bound = (1 << 8); - const double delta = static_cast(1ULL << 16); - - srand(static_cast(time(NULL))); - - for (int expCount = 0; expCount < 100; expCount++) - { - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - input2[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] - input2[i]; - } - - encoder.encode(input1, context->first_parms_id(), delta, plain1); - encoder.encode(input2, context->first_parms_id(), delta, plain2); - - encryptor.encrypt(plain1, encrypted1); - evaluator.sub_plain_inplace(encrypted1, plain2); - - //check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - decryptor.decrypt(encrypted1, plainRes); - - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - - } - } - - TEST(EvaluatorTest, BFVEncryptSubDecrypt) - { - EncryptionParameters parms(scheme_type::BFV); - SmallModulus plain_modulus(1 << 6); - parms.set_poly_modulus_degree(64); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 40 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - IntegerEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - - Ciphertext encrypted1; - encryptor.encrypt(encoder.encode(0x12345678), encrypted1); - Ciphertext encrypted2; - encryptor.encrypt(encoder.encode(0x54321), encrypted2); - evaluator.sub_inplace(encrypted1, encrypted2); - Plaintext plain; - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(0x122F1357), encoder.decode_int32(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(0), encrypted1); - encryptor.encrypt(encoder.encode(0), encrypted2); - evaluator.sub_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(0), encoder.decode_int32(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(0), encrypted1); - encryptor.encrypt(encoder.encode(5), encrypted2); - evaluator.sub_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(-5), encoder.decode_int32(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(5), encrypted1); - encryptor.encrypt(encoder.encode(-3), encrypted2); - evaluator.sub_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(8), encoder.decode_int32(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(-7), encrypted1); - encryptor.encrypt(encoder.encode(2), encrypted2); - evaluator.sub_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(-9), encoder.decode_int32(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - } - - TEST(EvaluatorTest, BFVEncryptAddPlainDecrypt) - { - EncryptionParameters parms(scheme_type::BFV); - SmallModulus plain_modulus(1 << 6); - parms.set_poly_modulus_degree(64); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 40 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - IntegerEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - - Ciphertext encrypted1; - Ciphertext encrypted2; - Plaintext plain; - encryptor.encrypt(encoder.encode(0x12345678), encrypted1); - plain = encoder.encode(0x54321); - evaluator.add_plain_inplace(encrypted1, plain); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(0x12399999), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(0), encrypted1); - plain = encoder.encode(0); - evaluator.add_plain_inplace(encrypted1, plain); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(0), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(0), encrypted1); - plain = encoder.encode(5); - evaluator.add_plain_inplace(encrypted1, plain); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(5), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(5), encrypted1); - plain = encoder.encode(-3); - evaluator.add_plain_inplace(encrypted1, plain); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(2), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(-7), encrypted1); - plain = encoder.encode(7); - evaluator.add_plain_inplace(encrypted1, plain); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(0), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - } - - TEST(EvaluatorTest, BFVEncryptSubPlainDecrypt) - { - EncryptionParameters parms(scheme_type::BFV); - SmallModulus plain_modulus(1 << 6); - parms.set_poly_modulus_degree(64); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 40 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - IntegerEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - - Ciphertext encrypted1; - Plaintext plain; - encryptor.encrypt(encoder.encode(0x12345678), encrypted1); - plain = encoder.encode(0x54321); - evaluator.sub_plain_inplace(encrypted1, plain); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(0x122F1357), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(0), encrypted1); - plain = encoder.encode(0); - evaluator.sub_plain_inplace(encrypted1, plain); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(0), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(0), encrypted1); - plain = encoder.encode(5); - evaluator.sub_plain_inplace(encrypted1, plain); - decryptor.decrypt(encrypted1, plain); - ASSERT_TRUE(static_cast(-5) == encoder.decode_int64(plain)); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(5), encrypted1); - plain = encoder.encode(-3); - evaluator.sub_plain_inplace(encrypted1, plain); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(8), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(-7), encrypted1); - plain = encoder.encode(2); - evaluator.sub_plain_inplace(encrypted1, plain); - decryptor.decrypt(encrypted1, plain); - ASSERT_TRUE(static_cast(-9) == encoder.decode_int64(plain)); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - } - - TEST(EvaluatorTest, BFVEncryptMultiplyPlainDecrypt) - { - { - EncryptionParameters parms(scheme_type::BFV); - SmallModulus plain_modulus(1 << 6); - parms.set_poly_modulus_degree(64); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 40 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - IntegerEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - - Ciphertext encrypted; - Plaintext plain; - encryptor.encrypt(encoder.encode(0x12345678), encrypted); - plain = encoder.encode(0x54321); - evaluator.multiply_plain_inplace(encrypted, plain); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(static_cast(0x5FCBBBB88D78), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(0), encrypted); - plain = encoder.encode(5); - evaluator.multiply_plain_inplace(encrypted, plain); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(static_cast(0), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(7), encrypted); - plain = encoder.encode(4); - evaluator.multiply_plain_inplace(encrypted, plain); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(static_cast(28), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(7), encrypted); - plain = encoder.encode(2); - evaluator.multiply_plain_inplace(encrypted, plain); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(static_cast(14), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(7), encrypted); - plain = encoder.encode(1); - evaluator.multiply_plain_inplace(encrypted, plain); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(static_cast(7), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(5), encrypted); - plain = encoder.encode(-3); - evaluator.multiply_plain_inplace(encrypted, plain); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(static_cast(-15) == encoder.decode_int64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(-7), encrypted); - plain = encoder.encode(2); - evaluator.multiply_plain_inplace(encrypted, plain); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(static_cast(-14) == encoder.decode_int64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - } - { - EncryptionParameters parms(scheme_type::BFV); - SmallModulus plain_modulus((1ULL << 20) - 1); - parms.set_poly_modulus_degree(64); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 30, 60, 60 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - IntegerEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - - Ciphertext encrypted; - Plaintext plain; - encryptor.encrypt(encoder.encode(0x12345678), encrypted); - plain = "1"; - evaluator.multiply_plain_inplace(encrypted, plain); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(static_cast(0x12345678), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - plain = "5"; - evaluator.multiply_plain_inplace(encrypted, plain); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(static_cast(0x5B05B058), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - } - { - EncryptionParameters parms(scheme_type::BFV); - SmallModulus plain_modulus((1ULL << 40) - 1); - parms.set_poly_modulus_degree(64); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 30, 60, 60 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - IntegerEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - - Ciphertext encrypted; - Plaintext plain; - encryptor.encrypt(encoder.encode(0x12345678), encrypted); - plain = "1"; - evaluator.multiply_plain_inplace(encrypted, plain); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(static_cast(0x12345678), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - plain = "5"; - evaluator.multiply_plain_inplace(encrypted, plain); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(static_cast(0x5B05B058), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - } - } - - TEST(EvaluatorTest, BFVEncryptMultiplyDecrypt) - { - { - EncryptionParameters parms(scheme_type::BFV); - SmallModulus plain_modulus(1 << 6); - parms.set_poly_modulus_degree(64); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 40 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - IntegerEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - - Ciphertext encrypted1; - Ciphertext encrypted2; - Plaintext plain; - encryptor.encrypt(encoder.encode(0x12345678), encrypted1); - encryptor.encrypt(encoder.encode(0x54321), encrypted2); - evaluator.multiply_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(0x5FCBBBB88D78), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(0), encrypted1); - encryptor.encrypt(encoder.encode(0), encrypted2); - evaluator.multiply_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(0), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(0), encrypted1); - encryptor.encrypt(encoder.encode(5), encrypted2); - evaluator.multiply_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(0), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(7), encrypted1); - encryptor.encrypt(encoder.encode(1), encrypted2); - evaluator.multiply_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(7), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(5), encrypted1); - encryptor.encrypt(encoder.encode(-3), encrypted2); - evaluator.multiply_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_TRUE(static_cast(-15) == encoder.decode_int64(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(0x10000), encrypted1); - encryptor.encrypt(encoder.encode(0x100), encrypted2); - evaluator.multiply_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(0x1000000), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - } - { - EncryptionParameters parms(scheme_type::BFV); - SmallModulus plain_modulus((1ULL << 60) - 1); - parms.set_poly_modulus_degree(64); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60, 60, 60 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - IntegerEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - - Ciphertext encrypted1; - Ciphertext encrypted2; - Plaintext plain; - encryptor.encrypt(encoder.encode(0x12345678), encrypted1); - encryptor.encrypt(encoder.encode(0x54321), encrypted2); - evaluator.multiply_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(0x5FCBBBB88D78), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(0), encrypted1); - encryptor.encrypt(encoder.encode(0), encrypted2); - evaluator.multiply_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(0), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(0), encrypted1); - encryptor.encrypt(encoder.encode(5), encrypted2); - evaluator.multiply_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(0), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(7), encrypted1); - encryptor.encrypt(encoder.encode(1), encrypted2); - evaluator.multiply_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(7), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(5), encrypted1); - encryptor.encrypt(encoder.encode(-3), encrypted2); - evaluator.multiply_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_TRUE(static_cast(-15) == encoder.decode_int64(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(0x10000), encrypted1); - encryptor.encrypt(encoder.encode(0x100), encrypted2); - evaluator.multiply_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(0x1000000), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - } - { - EncryptionParameters parms(scheme_type::BFV); - SmallModulus plain_modulus(1 << 6); - parms.set_poly_modulus_degree(128); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - IntegerEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - - Ciphertext encrypted1; - Ciphertext encrypted2; - Plaintext plain; - encryptor.encrypt(encoder.encode(0x12345678), encrypted1); - encryptor.encrypt(encoder.encode(0x54321), encrypted2); - evaluator.multiply_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(0x5FCBBBB88D78), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(0), encrypted1); - encryptor.encrypt(encoder.encode(0), encrypted2); - evaluator.multiply_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(0), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(0), encrypted1); - encryptor.encrypt(encoder.encode(5), encrypted2); - evaluator.multiply_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(0), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(7), encrypted1); - encryptor.encrypt(encoder.encode(1), encrypted2); - evaluator.multiply_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(7), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(5), encrypted1); - encryptor.encrypt(encoder.encode(-3), encrypted2); - evaluator.multiply_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_TRUE(static_cast(-15) == encoder.decode_int64(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(0x10000), encrypted1); - encryptor.encrypt(encoder.encode(0x100), encrypted2); - evaluator.multiply_inplace(encrypted1, encrypted2); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(0x1000000), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted2.parms_id() == encrypted1.parms_id()); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - } - { - EncryptionParameters parms(scheme_type::BFV); - SmallModulus plain_modulus(1 << 8); - parms.set_poly_modulus_degree(128); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - IntegerEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - - Ciphertext encrypted1; - Plaintext plain; - encryptor.encrypt(encoder.encode(123), encrypted1); - evaluator.multiply(encrypted1, encrypted1, encrypted1); - evaluator.multiply(encrypted1, encrypted1, encrypted1); - decryptor.decrypt(encrypted1, plain); - ASSERT_EQ(static_cast(228886641), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - } - } - -#include "seal/randomgen.h" - TEST(EvaluatorTest, BFVRelinearize) - { - EncryptionParameters parms(scheme_type::BFV); - SmallModulus plain_modulus(1 << 6); - parms.set_poly_modulus_degree(128); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40, 40 })); - - auto context = SEALContext::Create(parms, true, sec_level_type::none); - KeyGenerator keygen(context); - RelinKeys rlk = keygen.relin_keys(); - - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - - Ciphertext encrypted(context); - Ciphertext encrypted2(context); - - Plaintext plain; - Plaintext plain2; - - plain = 0; - encryptor.encrypt(plain, encrypted); - evaluator.square_inplace(encrypted); - evaluator.relinearize_inplace(encrypted, rlk); - decryptor.decrypt(encrypted, plain2); - ASSERT_TRUE(plain == plain2); - - encryptor.encrypt(plain, encrypted); - evaluator.square_inplace(encrypted); - evaluator.relinearize_inplace(encrypted, rlk); - evaluator.square_inplace(encrypted); - evaluator.relinearize_inplace(encrypted, rlk); - decryptor.decrypt(encrypted, plain2); - ASSERT_TRUE(plain == plain2); - - plain = "1x^10 + 2"; - encryptor.encrypt(plain, encrypted); - evaluator.square_inplace(encrypted); - evaluator.relinearize_inplace(encrypted, rlk); - decryptor.decrypt(encrypted, plain2); - ASSERT_TRUE(plain2.to_string() == "1x^20 + 4x^10 + 4"); - - encryptor.encrypt(plain, encrypted); - evaluator.square_inplace(encrypted); - evaluator.relinearize_inplace(encrypted, rlk); - evaluator.square_inplace(encrypted); - evaluator.relinearize_inplace(encrypted, rlk); - decryptor.decrypt(encrypted, plain2); - ASSERT_TRUE(plain2.to_string() == "1x^40 + 8x^30 + 18x^20 + 20x^10 + 10"); - - // Relinearization with modulus switching - plain = "1x^10 + 2"; - encryptor.encrypt(plain, encrypted); - evaluator.square_inplace(encrypted); - evaluator.relinearize_inplace(encrypted, rlk); - evaluator.mod_switch_to_next_inplace(encrypted); - decryptor.decrypt(encrypted, plain2); - ASSERT_TRUE(plain2.to_string() == "1x^20 + 4x^10 + 4"); - - encryptor.encrypt(plain, encrypted); - evaluator.square_inplace(encrypted); - evaluator.relinearize_inplace(encrypted, rlk); - evaluator.mod_switch_to_next_inplace(encrypted); - evaluator.square_inplace(encrypted); - evaluator.relinearize_inplace(encrypted, rlk); - evaluator.mod_switch_to_next_inplace(encrypted); - decryptor.decrypt(encrypted, plain2); - ASSERT_TRUE(plain2.to_string() == "1x^40 + 8x^30 + 18x^20 + 20x^10 + 10"); - } - - TEST(EvaluatorTest, CKKSEncryptNaiveMultiplyDecrypt) - { - EncryptionParameters parms(scheme_type::CKKS); - { - //multiplying two zero vectors - size_t slot_size = 32; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 30, 30, 30, 30 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - - Ciphertext encrypted; - Plaintext plain; - Plaintext plainRes; - - std::vector> input(slot_size, 0.0); - std::vector> output(slot_size); - const double delta = static_cast(1 << 30); - encoder.encode(input, context->first_parms_id(), delta, plain); - - encryptor.encrypt(plain, encrypted); - evaluator.multiply_inplace(encrypted, encrypted); - - //check correctness of encryption - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - decryptor.decrypt(encrypted, plainRes); - encoder.decode(plainRes, output); - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(input[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - { - //multiplying two random vectors - size_t slot_size = 32; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 60 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - - Ciphertext encrypted1; - Ciphertext encrypted2; - Plaintext plain1; - Plaintext plain2; - Plaintext plainRes; - - std::vector> input1(slot_size, 0.0); - std::vector> input2(slot_size, 0.0); - std::vector> expected(slot_size, 0.0); - std::vector> output(slot_size); - const double delta = static_cast(1ULL << 40); - - int data_bound = (1 << 10); - srand(static_cast(time(NULL))); - - for (int round = 0; round < 100; round++) - { - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - input2[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] * input2[i]; - } - encoder.encode(input1, context->first_parms_id(), delta, plain1); - encoder.encode(input2, context->first_parms_id(), delta, plain2); - - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - evaluator.multiply_inplace(encrypted1, encrypted2); - - //check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - decryptor.decrypt(encrypted1, plainRes); - - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - { - //multiplying two random vectors - size_t slot_size = 16; - parms.set_poly_modulus_degree(64); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60, 60 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - - Ciphertext encrypted1; - Ciphertext encrypted2; - Plaintext plain1; - Plaintext plain2; - Plaintext plainRes; - - std::vector> input1(slot_size, 0.0); - std::vector> input2(slot_size, 0.0); - std::vector> expected(slot_size, 0.0); - std::vector> output(slot_size); - const double delta = static_cast(1ULL << 40); - - int data_bound = (1 << 10); - srand(static_cast(time(NULL))); - - for (int round = 0; round < 100; round++) - { - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - input2[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] * input2[i]; - } - encoder.encode(input1, context->first_parms_id(), delta, plain1); - encoder.encode(input2, context->first_parms_id(), delta, plain2); - - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - evaluator.multiply_inplace(encrypted1, encrypted2); - - //check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - decryptor.decrypt(encrypted1, plainRes); - - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - } - - TEST(EvaluatorTest, CKKSEncryptMultiplyByNumberDecrypt) - { - EncryptionParameters parms(scheme_type::CKKS); - { - //multiplying two random vectors by an integer - size_t slot_size = 32; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 40 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - - Ciphertext encrypted1; - Plaintext plain1; - Plaintext plain2; - Plaintext plainRes; - - std::vector> input1(slot_size, 0.0); - int64_t input2; - std::vector> expected(slot_size, 0.0); - - int data_bound = (1 << 10); - srand(static_cast(time(NULL))); - - for (int iExp = 0; iExp < 50; iExp++) - { - input2 = max(rand() % data_bound, 1); - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] * static_cast(input2); - } - - std::vector> output(slot_size); - const double delta = static_cast(1ULL << 40); - encoder.encode(input1, context->first_parms_id(), delta, plain1); - encoder.encode(input2, context->first_parms_id(), plain2); - - encryptor.encrypt(plain1, encrypted1); - evaluator.multiply_plain_inplace(encrypted1, plain2); - - //check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - decryptor.decrypt(encrypted1, plainRes); - - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - { - //multiplying two random vectors by an integer - size_t slot_size = 8; - parms.set_poly_modulus_degree(64); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - - Ciphertext encrypted1; - Plaintext plain1; - Plaintext plain2; - Plaintext plainRes; - - std::vector> input1(slot_size, 0.0); - int64_t input2; - std::vector> expected(slot_size, 0.0); - - int data_bound = (1 << 10); - srand(static_cast(time(NULL))); - - for (int iExp = 0; iExp < 50; iExp++) - { - input2 = max(rand() % data_bound, 1); - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] * static_cast(input2); - } - - std::vector> output(slot_size); - const double delta = static_cast(1ULL << 40); - encoder.encode(input1, context->first_parms_id(), delta, plain1); - encoder.encode(input2, context->first_parms_id(), plain2); - - encryptor.encrypt(plain1, encrypted1); - evaluator.multiply_plain_inplace(encrypted1, plain2); - - //check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - decryptor.decrypt(encrypted1, plainRes); - - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - { - //multiplying two random vectors by a double - size_t slot_size = 32; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 60 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - - Ciphertext encrypted1; - Plaintext plain1; - Plaintext plain2; - Plaintext plainRes; - - std::vector> input1(slot_size, 0.0); - double input2; - std::vector> expected(slot_size, 0.0); - std::vector> output(slot_size); - - int data_bound = (1 << 10); - srand(static_cast(time(NULL))); - - for (int iExp = 0; iExp < 50; iExp++) - { - input2 = static_cast(rand() % (data_bound*data_bound)) - /static_cast(data_bound); - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] * input2; - } - - const double delta = static_cast(1ULL << 40); - encoder.encode(input1, context->first_parms_id(), delta, plain1); - encoder.encode(input2, context->first_parms_id(), delta, plain2); - - encryptor.encrypt(plain1, encrypted1); - evaluator.multiply_plain_inplace(encrypted1, plain2); - - //check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - decryptor.decrypt(encrypted1, plainRes); - - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - { - //multiplying two random vectors by a double - size_t slot_size = 16; - parms.set_poly_modulus_degree(64); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 60, 60 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - - Ciphertext encrypted1; - Plaintext plain1; - Plaintext plain2; - Plaintext plainRes; - - std::vector> input1(slot_size, 2.1); - double input2; - std::vector> expected(slot_size, 2.1); - std::vector> output(slot_size); - - int data_bound = (1 << 10); - srand(static_cast(time(NULL))); - - for (int iExp = 0; iExp < 50; iExp++) - { - input2 = static_cast(rand() % (data_bound*data_bound)) - / static_cast(data_bound); - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] * input2; - } - - const double delta = static_cast(1ULL << 40); - encoder.encode(input1, context->first_parms_id(), delta, plain1); - encoder.encode(input2, context->first_parms_id(), delta, plain2); - - encryptor.encrypt(plain1, encrypted1); - evaluator.multiply_plain_inplace(encrypted1, plain2); - - //check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - - decryptor.decrypt(encrypted1, plainRes); - - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - } - - TEST(EvaluatorTest, CKKSEncryptMultiplyRelinDecrypt) - { - EncryptionParameters parms(scheme_type::CKKS); - { - //multiplying two random vectors 50 times - size_t slot_size = 32; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 60 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - RelinKeys rlk = keygen.relin_keys(); - - Ciphertext encrypted1; - Ciphertext encrypted2; - Ciphertext encryptedRes; - Plaintext plain1; - Plaintext plain2; - Plaintext plainRes; - - std::vector> input1(slot_size, 0.0); - std::vector> input2(slot_size, 0.0); - std::vector> expected(slot_size, 0.0); - int data_bound = 1 << 10; - - for (int round = 0; round < 50; round++) - { - srand(static_cast(time(NULL))); - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - input2[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] * input2[i]; - } - - std::vector> output(slot_size); - const double delta = static_cast(1ULL << 40); - encoder.encode(input1, context->first_parms_id(), delta, plain1); - encoder.encode(input2, context->first_parms_id(), delta, plain2); - - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - - //check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - //check correctness of encryption - ASSERT_TRUE(encrypted2.parms_id() == context->first_parms_id()); - - evaluator.multiply_inplace(encrypted1, encrypted2); - evaluator.relinearize_inplace(encrypted1, rlk); - - decryptor.decrypt(encrypted1, plainRes); - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - { - //multiplying two random vectors 50 times - size_t slot_size = 32; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 30, 30, 30 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - RelinKeys rlk = keygen.relin_keys(); - - Ciphertext encrypted1; - Ciphertext encrypted2; - Ciphertext encryptedRes; - Plaintext plain1; - Plaintext plain2; - Plaintext plainRes; - - std::vector> input1(slot_size, 0.0); - std::vector> input2(slot_size, 0.0); - std::vector> expected(slot_size, 0.0); - int data_bound = 1 << 10; - - for (int round = 0; round < 50; round++) - { - srand(static_cast(time(NULL))); - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - input2[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] * input2[i]; - } - - std::vector> output(slot_size); - const double delta = static_cast(1ULL << 40); - encoder.encode(input1, context->first_parms_id(), delta, plain1); - encoder.encode(input2, context->first_parms_id(), delta, plain2); - - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - - //check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - //check correctness of encryption - ASSERT_TRUE(encrypted2.parms_id() == context->first_parms_id()); - - evaluator.multiply_inplace(encrypted1, encrypted2); - evaluator.relinearize_inplace(encrypted1, rlk); - - decryptor.decrypt(encrypted1, plainRes); - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - { - //multiplying two random vectors 50 times - size_t slot_size = 2; - parms.set_poly_modulus_degree(8); - parms.set_coeff_modulus(CoeffModulus::Create(8, { 60, 30, 30, 30 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - RelinKeys rlk = keygen.relin_keys(); - - Ciphertext encrypted1; - Ciphertext encrypted2; - Ciphertext encryptedRes; - Plaintext plain1; - Plaintext plain2; - Plaintext plainRes; - - std::vector> input1(slot_size, 0.0); - std::vector> input2(slot_size, 0.0); - std::vector> expected(slot_size, 0.0); - std::vector> output(slot_size); - int data_bound = 1 << 10; - const double delta = static_cast(1ULL << 40); - - for (int round = 0; round < 50; round++) - { - srand(static_cast(time(NULL))); - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - input2[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] * input2[i]; - } - - encoder.encode(input1, context->first_parms_id(), delta, plain1); - encoder.encode(input2, context->first_parms_id(), delta, plain2); - - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - - //check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - //check correctness of encryption - ASSERT_TRUE(encrypted2.parms_id() == context->first_parms_id()); - - evaluator.multiply_inplace(encrypted1, encrypted2); - //evaluator.relinearize_inplace(encrypted1, rlk); - - decryptor.decrypt(encrypted1, plainRes); - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - } - - TEST(EvaluatorTest, CKKSEncryptSquareRelinDecrypt) - { - EncryptionParameters parms(scheme_type::CKKS); - { - //squaring two random vectors 100 times - size_t slot_size = 32; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 60, 60 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - RelinKeys rlk = keygen.relin_keys(); - - Ciphertext encrypted; - Plaintext plain; - Plaintext plainRes; - - std::vector> input(slot_size, 0.0); - std::vector> expected(slot_size, 0.0); - - int data_bound = 1 << 7; - srand(static_cast(time(NULL))); - - for (int round = 0; round < 100; round++) - { - for (size_t i = 0; i < slot_size; i++) - { - input[i] = static_cast(rand() % data_bound); - expected[i] = input[i] * input[i]; - } - - std::vector> output(slot_size); - const double delta = static_cast(1ULL << 40); - encoder.encode(input, context->first_parms_id(), delta, plain); - - encryptor.encrypt(plain, encrypted); - - //check correctness of encryption - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - //evaluator.square_inplace(encrypted); - evaluator.multiply_inplace(encrypted, encrypted); - evaluator.relinearize_inplace(encrypted, rlk); - - decryptor.decrypt(encrypted, plainRes); - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - { - //squaring two random vectors 100 times - size_t slot_size = 32; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 60, 30, 30, 30 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - RelinKeys rlk = keygen.relin_keys(); - - Ciphertext encrypted; - Plaintext plain; - Plaintext plainRes; - - std::vector> input(slot_size, 0.0); - std::vector> expected(slot_size, 0.0); - - int data_bound = 1 << 7; - srand(static_cast(time(NULL))); - - for (int round = 0; round < 100; round++) - { - for (size_t i = 0; i < slot_size; i++) - { - input[i] = static_cast(rand() % data_bound); - expected[i] = input[i] * input[i]; - } - - std::vector> output(slot_size); - const double delta = static_cast(1ULL << 40); - encoder.encode(input, context->first_parms_id(), delta, plain); - - encryptor.encrypt(plain, encrypted); - - //check correctness of encryption - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - //evaluator.square_inplace(encrypted); - evaluator.multiply_inplace(encrypted, encrypted); - evaluator.relinearize_inplace(encrypted, rlk); - - decryptor.decrypt(encrypted, plainRes); - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - { - //squaring two random vectors 100 times - size_t slot_size = 16; - parms.set_poly_modulus_degree(64); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 60, 30, 30, 30 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - RelinKeys rlk = keygen.relin_keys(); - - Ciphertext encrypted; - Plaintext plain; - Plaintext plainRes; - - std::vector> input(slot_size, 0.0); - std::vector> expected(slot_size, 0.0); - - int data_bound = 1 << 7; - srand(static_cast(time(NULL))); - - for (int round = 0; round < 100; round++) - { - for (size_t i = 0; i < slot_size; i++) - { - input[i] = static_cast(rand() % data_bound); - expected[i] = input[i] * input[i]; - } - - std::vector> output(slot_size); - const double delta = static_cast(1ULL << 40); - encoder.encode(input, context->first_parms_id(), delta, plain); - - encryptor.encrypt(plain, encrypted); - - //check correctness of encryption - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - //evaluator.square_inplace(encrypted); - evaluator.multiply_inplace(encrypted, encrypted); - evaluator.relinearize_inplace(encrypted, rlk); - - decryptor.decrypt(encrypted, plainRes); - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - } - - TEST(EvaluatorTest, CKKSEncryptMultiplyRelinRescaleDecrypt) - { - EncryptionParameters parms(scheme_type::CKKS); - { - //multiplying two random vectors 100 times - size_t slot_size = 64; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, - { 30, 30, 30, 30, 30, 30 })); - - auto context = SEALContext::Create(parms, true, sec_level_type::none); - auto next_parms_id = context->first_context_data()-> - next_context_data()->parms_id(); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - RelinKeys rlk = keygen.relin_keys(); - - Ciphertext encrypted1; - Ciphertext encrypted2; - Ciphertext encryptedRes; - Plaintext plain1; - Plaintext plain2; - Plaintext plainRes; - - std::vector> input1(slot_size, 0.0); - std::vector> input2(slot_size, 0.0); - std::vector> expected(slot_size, 0.0); - - for (int round = 0; round < 100; round++) - { - int data_bound = 1 << 7; - srand(static_cast(time(NULL))); - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - input2[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] * input2[i]; - } - - std::vector> output(slot_size); - double delta = static_cast(1ULL << 40); - encoder.encode(input1, context->first_parms_id(), delta, plain1); - encoder.encode(input2, context->first_parms_id(), delta, plain2); - - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - - //check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - //check correctness of encryption - ASSERT_TRUE(encrypted2.parms_id() == context->first_parms_id()); - - evaluator.multiply_inplace(encrypted1, encrypted2); - evaluator.relinearize_inplace(encrypted1, rlk); - evaluator.rescale_to_next_inplace(encrypted1); - - //check correctness of modulo switching - ASSERT_TRUE(encrypted1.parms_id() == next_parms_id); - - decryptor.decrypt(encrypted1, plainRes); - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - { - //multiplying two random vectors 100 times - size_t slot_size = 16; - parms.set_poly_modulus_degree(128); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 30, 30, 30, 30, 30 })); - - auto context = SEALContext::Create(parms, true, sec_level_type::none); - auto next_parms_id = context->first_context_data()-> - next_context_data()->parms_id(); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - RelinKeys rlk = keygen.relin_keys(); - - Ciphertext encrypted1; - Ciphertext encrypted2; - Ciphertext encryptedRes; - Plaintext plain1; - Plaintext plain2; - Plaintext plainRes; - - std::vector> input1(slot_size, 0.0); - std::vector> input2(slot_size, 0.0); - std::vector> expected(slot_size, 0.0); - - for (int round = 0; round < 100; round++) - { - int data_bound = 1 << 7; - srand(static_cast(time(NULL))); - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - input2[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] * input2[i]; - } - - std::vector> output(slot_size); - double delta = static_cast(1ULL << 40); - encoder.encode(input1, context->first_parms_id(), delta, plain1); - encoder.encode(input2, context->first_parms_id(), delta, plain2); - - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - - //check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - //check correctness of encryption - ASSERT_TRUE(encrypted2.parms_id() == context->first_parms_id()); - - evaluator.multiply_inplace(encrypted1, encrypted2); - evaluator.relinearize_inplace(encrypted1, rlk); - evaluator.rescale_to_next_inplace(encrypted1); - - //check correctness of modulo switching - ASSERT_TRUE(encrypted1.parms_id() == next_parms_id); - - decryptor.decrypt(encrypted1, plainRes); - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - { - //multiplying two random vectors 100 times - size_t slot_size = 16; - parms.set_poly_modulus_degree(128); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 60, 60, 60, 60, 60 })); - - auto context = SEALContext::Create(parms, true, sec_level_type::none); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - RelinKeys rlk = keygen.relin_keys(); - - Ciphertext encrypted1; - Ciphertext encrypted2; - Ciphertext encryptedRes; - Plaintext plain1; - Plaintext plain2; - Plaintext plainRes; - - std::vector> input1(slot_size, 0.0); - std::vector> input2(slot_size, 0.0); - std::vector> expected(slot_size, 0.0); - - for (int round = 0; round < 100; round++) - { - int data_bound = 1 << 7; - srand(static_cast(time(NULL))); - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - input2[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] * input2[i] * input2[i]; - } - - std::vector> output(slot_size); - double delta = static_cast(1ULL << 60); - encoder.encode(input1, context->first_parms_id(), delta, plain1); - encoder.encode(input2, context->first_parms_id(), delta, plain2); - - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - - //check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - //check correctness of encryption - ASSERT_TRUE(encrypted2.parms_id() == context->first_parms_id()); - - evaluator.multiply_inplace(encrypted1, encrypted2); - evaluator.relinearize_inplace(encrypted1, rlk); - evaluator.multiply_inplace(encrypted1, encrypted2); - evaluator.relinearize_inplace(encrypted1, rlk); - - // Scale down by two levels - auto target_parms = context->first_context_data() - ->next_context_data()->next_context_data()->parms_id(); - evaluator.rescale_to_inplace(encrypted1, target_parms); - - //check correctness of modulo switching - ASSERT_TRUE(encrypted1.parms_id() == target_parms); - - decryptor.decrypt(encrypted1, plainRes); - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - - // Test with inverted order: rescale then relin - for (int round = 0; round < 100; round++) - { - int data_bound = 1 << 7; - srand(static_cast(time(NULL))); - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - input2[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] * input2[i] * input2[i]; - } - - std::vector> output(slot_size); - double delta = static_cast(1ULL << 50); - encoder.encode(input1, context->first_parms_id(), delta, plain1); - encoder.encode(input2, context->first_parms_id(), delta, plain2); - - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - - //check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - //check correctness of encryption - ASSERT_TRUE(encrypted2.parms_id() == context->first_parms_id()); - - evaluator.multiply_inplace(encrypted1, encrypted2); - evaluator.relinearize_inplace(encrypted1, rlk); - evaluator.multiply_inplace(encrypted1, encrypted2); - - // Scale down by two levels - auto target_parms = context->first_context_data() - ->next_context_data()->next_context_data()->parms_id(); - evaluator.rescale_to_inplace(encrypted1, target_parms); - - // Relinearize now - evaluator.relinearize_inplace(encrypted1, rlk); - - //check correctness of modulo switching - ASSERT_TRUE(encrypted1.parms_id() == target_parms); - - decryptor.decrypt(encrypted1, plainRes); - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - } - - TEST(EvaluatorTest, CKKSEncryptSquareRelinRescaleDecrypt) - { - EncryptionParameters parms(scheme_type::CKKS); - { - //squaring two random vectors 100 times - size_t slot_size = 64; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 50, 50, 50 })); - - auto context = SEALContext::Create(parms, true, sec_level_type::none); - auto next_parms_id = context->first_context_data()-> - next_context_data()->parms_id(); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - RelinKeys rlk = keygen.relin_keys(); - - Ciphertext encrypted; - Plaintext plain; - Plaintext plainRes; - - std::vector> input(slot_size, 0.0); - std::vector> output(slot_size); - std::vector> expected(slot_size, 0.0); - int data_bound = 1 << 8; - - for (int round = 0; round < 100; round++) - { - srand(static_cast(time(NULL))); - for (size_t i = 0; i < slot_size; i++) - { - input[i] = static_cast(rand() % data_bound); - expected[i] = input[i] * input[i]; - } - - double delta = static_cast(1ULL << 40); - encoder.encode(input, context->first_parms_id(), delta, plain); - - encryptor.encrypt(plain, encrypted); - - //check correctness of encryption - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - evaluator.square_inplace(encrypted); - evaluator.relinearize_inplace(encrypted, rlk); - evaluator.rescale_to_next_inplace(encrypted); - - //check correctness of modulo switching - ASSERT_TRUE(encrypted.parms_id() == next_parms_id); - - decryptor.decrypt(encrypted, plainRes); - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - { - //squaring two random vectors 100 times - size_t slot_size = 16; - parms.set_poly_modulus_degree(128); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 50, 50, 50 })); - - auto context = SEALContext::Create(parms, true, sec_level_type::none); - auto next_parms_id = context->first_context_data()-> - next_context_data()->parms_id(); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - RelinKeys rlk = keygen.relin_keys(); - - Ciphertext encrypted; - Plaintext plain; - Plaintext plainRes; - - std::vector> input(slot_size, 0.0); - std::vector> output(slot_size); - std::vector> expected(slot_size, 0.0); - int data_bound = 1 << 8; - - for (int round = 0; round < 100; round++) - { - srand(static_cast(time(NULL))); - for (size_t i = 0; i < slot_size; i++) - { - input[i] = static_cast(rand() % data_bound); - expected[i] = input[i] * input[i]; - } - - double delta = static_cast(1ULL << 40); - encoder.encode(input, context->first_parms_id(), delta, plain); - - encryptor.encrypt(plain, encrypted); - - //check correctness of encryption - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - evaluator.square_inplace(encrypted); - evaluator.relinearize_inplace(encrypted, rlk); - evaluator.rescale_to_next_inplace(encrypted); - - //check correctness of modulo switching - ASSERT_TRUE(encrypted.parms_id() == next_parms_id); - - decryptor.decrypt(encrypted, plainRes); - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - } - TEST(EvaluatorTest, CKKSEncryptModSwitchDecrypt) - { - EncryptionParameters parms(scheme_type::CKKS); - { - //modulo switching without rescaling for random vectors - size_t slot_size = 64; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create( - slot_size * 2, { 60, 60, 60, 60, 60 })); - - auto context = SEALContext::Create(parms, true, sec_level_type::none); - auto next_parms_id = context->first_context_data()-> - next_context_data()->parms_id(); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - - int data_bound = 1 << 30; - srand(static_cast(time(NULL))); - - std::vector> input(slot_size, 0.0); - std::vector> output(slot_size); - - Ciphertext encrypted; - Plaintext plain; - Plaintext plainRes; - - for (int round = 0; round < 100; round++) - { - for (size_t i = 0; i < slot_size; i++) - { - input[i] = static_cast(rand() % data_bound); - } - - double delta = static_cast(1ULL << 40); - encoder.encode(input, context->first_parms_id(), delta, plain); - - encryptor.encrypt(plain, encrypted); - - //check correctness of encryption - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - evaluator.mod_switch_to_next_inplace(encrypted); - - //check correctness of modulo switching - ASSERT_TRUE(encrypted.parms_id() == next_parms_id); - - decryptor.decrypt(encrypted, plainRes); - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(input[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - { - //modulo switching without rescaling for random vectors - size_t slot_size = 32; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create( - slot_size * 2, { 40, 40, 40, 40, 40 })); - - auto context = SEALContext::Create(parms, true, sec_level_type::none); - auto next_parms_id = context->first_context_data()-> - next_context_data()->parms_id(); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - - int data_bound = 1 << 30; - srand(static_cast(time(NULL))); - - std::vector> input(slot_size, 0.0); - std::vector> output(slot_size); - - Ciphertext encrypted; - Plaintext plain; - Plaintext plainRes; - - for (int round = 0; round < 100; round++) - { - for (size_t i = 0; i < slot_size; i++) - { - input[i] = static_cast(rand() % data_bound); - } - - double delta = static_cast(1ULL << 40); - encoder.encode(input, context->first_parms_id(), delta, plain); - - encryptor.encrypt(plain, encrypted); - - //check correctness of encryption - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - evaluator.mod_switch_to_next_inplace(encrypted); - - //check correctness of modulo switching - ASSERT_TRUE(encrypted.parms_id() == next_parms_id); - - decryptor.decrypt(encrypted, plainRes); - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(input[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - { - //modulo switching without rescaling for random vectors - size_t slot_size = 32; - parms.set_poly_modulus_degree(128); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40, 40, 40 })); - - auto context = SEALContext::Create(parms, true, sec_level_type::none); - auto next_parms_id = context->first_context_data()-> - next_context_data()->parms_id(); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - - int data_bound = 1 << 30; - srand(static_cast(time(NULL))); - - std::vector> input(slot_size, 0.0); - std::vector> output(slot_size); - - Ciphertext encrypted; - Plaintext plain; - Plaintext plainRes; - - for (int round = 0; round < 100; round++) - { - for (size_t i = 0; i < slot_size; i++) - { - input[i] = static_cast(rand() % data_bound); - } - - double delta = static_cast(1ULL << 40); - encoder.encode(input, context->first_parms_id(), delta, plain); - - encryptor.encrypt(plain, encrypted); - - //check correctness of encryption - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - evaluator.mod_switch_to_next_inplace(encrypted); - - //check correctness of modulo switching - ASSERT_TRUE(encrypted.parms_id() == next_parms_id); - - decryptor.decrypt(encrypted, plainRes); - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(input[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - } - TEST(EvaluatorTest, CKKSEncryptMultiplyRelinRescaleModSwitchAddDecrypt) - { - EncryptionParameters parms(scheme_type::CKKS); - { - //multiplication and addition without rescaling for random vectors - size_t slot_size = 64; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 50, 50, 50 })); - - auto context = SEALContext::Create(parms, true, sec_level_type::none); - auto next_parms_id = context->first_context_data()-> - next_context_data()->parms_id(); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - RelinKeys rlk = keygen.relin_keys(); - - Ciphertext encrypted1; - Ciphertext encrypted2; - Ciphertext encrypted3; - Plaintext plain1; - Plaintext plain2; - Plaintext plain3; - Plaintext plainRes; - - std::vector> input1(slot_size, 0.0); - std::vector> input2(slot_size, 0.0); - std::vector> input3(slot_size, 0.0); - std::vector> expected(slot_size, 0.0); - - for (int round = 0; round < 100; round++) - { - int data_bound = 1 << 8; - srand(static_cast(time(NULL))); - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - input2[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] * input2[i] + input3[i]; - } - - std::vector> output(slot_size); - double delta = static_cast(1ULL << 40); - encoder.encode(input1, context->first_parms_id(), delta, plain1); - encoder.encode(input2, context->first_parms_id(), delta, plain2); - encoder.encode(input3, context->first_parms_id(), delta * delta, plain3); - - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - encryptor.encrypt(plain3, encrypted3); - - //check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - //check correctness of encryption - ASSERT_TRUE(encrypted2.parms_id() == context->first_parms_id()); - //check correctness of encryption - ASSERT_TRUE(encrypted3.parms_id() == context->first_parms_id()); - - //enc1*enc2 - evaluator.multiply_inplace(encrypted1, encrypted2); - evaluator.relinearize_inplace(encrypted1, rlk); - evaluator.rescale_to_next_inplace(encrypted1); - - //check correctness of modulo switching with rescaling - ASSERT_TRUE(encrypted1.parms_id() == next_parms_id); - - //move enc3 to the level of enc1 * enc2 - evaluator.rescale_to_inplace(encrypted3, next_parms_id); - - //enc1*enc2 + enc3 - evaluator.add_inplace(encrypted1, encrypted3); - - //decryption - decryptor.decrypt(encrypted1, plainRes); - //decoding - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - { - //multiplication and addition without rescaling for random vectors - size_t slot_size = 16; - parms.set_poly_modulus_degree(128); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 50, 50, 50 })); - - auto context = SEALContext::Create(parms, true, sec_level_type::none); - auto next_parms_id = context->first_context_data()-> - next_context_data()->parms_id(); - KeyGenerator keygen(context); - - CKKSEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Decryptor decryptor(context, keygen.secret_key()); - Evaluator evaluator(context); - RelinKeys rlk = keygen.relin_keys(); - - Ciphertext encrypted1; - Ciphertext encrypted2; - Ciphertext encrypted3; - Plaintext plain1; - Plaintext plain2; - Plaintext plain3; - Plaintext plainRes; - - std::vector> input1(slot_size, 0.0); - std::vector> input2(slot_size, 0.0); - std::vector> input3(slot_size, 0.0); - std::vector> expected(slot_size, 0.0); - std::vector> output(slot_size); - - for (int round = 0; round < 100; round++) - { - int data_bound = 1 << 8; - srand(static_cast(time(NULL))); - for (size_t i = 0; i < slot_size; i++) - { - input1[i] = static_cast(rand() % data_bound); - input2[i] = static_cast(rand() % data_bound); - expected[i] = input1[i] * input2[i] + input3[i]; - } - - double delta = static_cast(1ULL << 40); - encoder.encode(input1, context->first_parms_id(), delta, plain1); - encoder.encode(input2, context->first_parms_id(), delta, plain2); - encoder.encode(input3, context->first_parms_id(), delta * delta, plain3); - - encryptor.encrypt(plain1, encrypted1); - encryptor.encrypt(plain2, encrypted2); - encryptor.encrypt(plain3, encrypted3); - - //check correctness of encryption - ASSERT_TRUE(encrypted1.parms_id() == context->first_parms_id()); - //check correctness of encryption - ASSERT_TRUE(encrypted2.parms_id() == context->first_parms_id()); - //check correctness of encryption - ASSERT_TRUE(encrypted3.parms_id() == context->first_parms_id()); - - //enc1*enc2 - evaluator.multiply_inplace(encrypted1, encrypted2); - evaluator.relinearize_inplace(encrypted1, rlk); - evaluator.rescale_to_next_inplace(encrypted1); - - //check correctness of modulo switching with rescaling - ASSERT_TRUE(encrypted1.parms_id() == next_parms_id); - - //move enc3 to the level of enc1 * enc2 - evaluator.rescale_to_inplace(encrypted3, next_parms_id); - - //enc1*enc2 + enc3 - evaluator.add_inplace(encrypted1, encrypted3); - - //decryption - decryptor.decrypt(encrypted1, plainRes); - //decoding - encoder.decode(plainRes, output); - - for (size_t i = 0; i < slot_size; i++) - { - auto tmp = abs(expected[i].real() - output[i].real()); - ASSERT_TRUE(tmp < 0.5); - } - } - } - } - TEST(EvaluatorTest, CKKSEncryptRotateDecrypt) - { - EncryptionParameters parms(scheme_type::CKKS); - { - // maximal number of slots - size_t slot_size = 4; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 40, 40, 40, 40 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - GaloisKeys glk = keygen.galois_keys(); - - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - CKKSEncoder encoder(context); - const double delta = static_cast(1ULL << 30); - - Ciphertext encrypted; - Plaintext plain; - - vector> input{ - std::complex(1, 1), - std::complex(2, 2), - std::complex(3, 3), - std::complex(4, 4) - }; - input.resize(slot_size); - - vector> output(slot_size, 0); - - encoder.encode(input, context->first_parms_id(), delta, plain); - int shift = 1; - encryptor.encrypt(plain, encrypted); - evaluator.rotate_vector_inplace(encrypted, shift, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].real(), round(output[i].real())); - ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].imag(), round(output[i].imag())); - } - - encoder.encode(input, context->first_parms_id(), delta, plain); - shift = 2; - encryptor.encrypt(plain, encrypted); - evaluator.rotate_vector_inplace(encrypted, shift, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].real(), round(output[i].real())); - ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].imag(), round(output[i].imag())); - } - - encoder.encode(input, context->first_parms_id(), delta, plain); - shift = 3; - encryptor.encrypt(plain, encrypted); - evaluator.rotate_vector_inplace(encrypted, shift, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].real(), round(output[i].real())); - ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].imag(), round(output[i].imag())); - } - - encoder.encode(input, context->first_parms_id(), delta, plain); - encryptor.encrypt(plain, encrypted); - evaluator.complex_conjugate_inplace(encrypted, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(input[i].real(), round(output[i].real())); - ASSERT_EQ(-input[i].imag(), round(output[i].imag())); - } - } - { - size_t slot_size = 32; - parms.set_poly_modulus_degree(64); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 40, 40, 40, 40 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - GaloisKeys glk = keygen.galois_keys(); - - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - CKKSEncoder encoder(context); - const double delta = static_cast(1ULL << 30); - - Ciphertext encrypted; - Plaintext plain; - - vector> input{ - std::complex(1, 1), - std::complex(2, 2), - std::complex(3, 3), - std::complex(4, 4) - }; - input.resize(slot_size); - - vector> output(slot_size, 0); - - encoder.encode(input, context->first_parms_id(), delta, plain); - int shift = 1; - encryptor.encrypt(plain, encrypted); - evaluator.rotate_vector_inplace(encrypted, shift, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < input.size(); i++) - { - ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].real()), round(output[i].real())); - ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].imag()), round(output[i].imag())); - } - - encoder.encode(input, context->first_parms_id(), delta, plain); - shift = 2; - encryptor.encrypt(plain, encrypted); - evaluator.rotate_vector_inplace(encrypted, shift, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].real()), round(output[i].real())); - ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].imag()), round(output[i].imag())); - } - - encoder.encode(input, context->first_parms_id(), delta, plain); - shift = 3; - encryptor.encrypt(plain, encrypted); - evaluator.rotate_vector_inplace(encrypted, shift, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].real()), round(output[i].real())); - ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].imag()), round(output[i].imag())); - } - - encoder.encode(input, context->first_parms_id(), delta, plain); - encryptor.encrypt(plain, encrypted); - evaluator.complex_conjugate_inplace(encrypted, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(round(input[i].real()), round(output[i].real())); - ASSERT_EQ(round(-input[i].imag()), round(output[i].imag())); - } - } - } - - TEST(EvaluatorTest, CKKSEncryptRescaleRotateDecrypt) - { - EncryptionParameters parms(scheme_type::CKKS); - { - // maximal number of slots - size_t slot_size = 4; - parms.set_poly_modulus_degree(slot_size * 2); - parms.set_coeff_modulus(CoeffModulus::Create(slot_size * 2, { 40, 40, 40, 40 })); - - auto context = SEALContext::Create(parms, true, sec_level_type::none); - KeyGenerator keygen(context); - GaloisKeys glk = keygen.galois_keys(); - - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - CKKSEncoder encoder(context); - const double delta = std::pow(2.0, 70); - - Ciphertext encrypted; - Plaintext plain; - - vector> input{ - std::complex(1, 1), - std::complex(2, 2), - std::complex(3, 3), - std::complex(4, 4) - }; - input.resize(slot_size); - - vector> output(slot_size, 0); - - encoder.encode(input, context->first_parms_id(), delta, plain); - int shift = 1; - encryptor.encrypt(plain, encrypted); - evaluator.rescale_to_next_inplace(encrypted); - evaluator.rotate_vector_inplace(encrypted, shift, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].real(), round(output[i].real())); - ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].imag(), round(output[i].imag())); - } - - encoder.encode(input, context->first_parms_id(), delta, plain); - shift = 2; - encryptor.encrypt(plain, encrypted); - evaluator.rescale_to_next_inplace(encrypted); - evaluator.rotate_vector_inplace(encrypted, shift, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].real(), round(output[i].real())); - ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].imag(), round(output[i].imag())); - } - - encoder.encode(input, context->first_parms_id(), delta, plain); - shift = 3; - encryptor.encrypt(plain, encrypted); - evaluator.rescale_to_next_inplace(encrypted); - evaluator.rotate_vector_inplace(encrypted, shift, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].real(), round(output[i].real())); - ASSERT_EQ(input[(i + static_cast(shift)) % slot_size].imag(), round(output[i].imag())); - } - - encoder.encode(input, context->first_parms_id(), delta, plain); - encryptor.encrypt(plain, encrypted); - evaluator.rescale_to_next_inplace(encrypted); - evaluator.complex_conjugate_inplace(encrypted, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(input[i].real(), round(output[i].real())); - ASSERT_EQ(-input[i].imag(), round(output[i].imag())); - } - } - { - size_t slot_size = 32; - parms.set_poly_modulus_degree(64); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 40, 40, 40, 40 })); - - auto context = SEALContext::Create(parms, true, sec_level_type::none); - KeyGenerator keygen(context); - GaloisKeys glk = keygen.galois_keys(); - - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - CKKSEncoder encoder(context); - const double delta = std::pow(2, 70); - - Ciphertext encrypted; - Plaintext plain; - - vector> input{ - std::complex(1, 1), - std::complex(2, 2), - std::complex(3, 3), - std::complex(4, 4) - }; - input.resize(slot_size); - - vector> output(slot_size, 0); - - encoder.encode(input, context->first_parms_id(), delta, plain); - int shift = 1; - encryptor.encrypt(plain, encrypted); - evaluator.rescale_to_next_inplace(encrypted); - evaluator.rotate_vector_inplace(encrypted, shift, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].real()), round(output[i].real())); - ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].imag()), round(output[i].imag())); - } - - encoder.encode(input, context->first_parms_id(), delta, plain); - shift = 2; - encryptor.encrypt(plain, encrypted); - evaluator.rescale_to_next_inplace(encrypted); - evaluator.rotate_vector_inplace(encrypted, shift, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].real()), round(output[i].real())); - ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].imag()), round(output[i].imag())); - } - - encoder.encode(input, context->first_parms_id(), delta, plain); - shift = 3; - encryptor.encrypt(plain, encrypted); - evaluator.rescale_to_next_inplace(encrypted); - evaluator.rotate_vector_inplace(encrypted, shift, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].real()), round(output[i].real())); - ASSERT_EQ(round(input[(i + static_cast(shift)) % slot_size].imag()), round(output[i].imag())); - } - - encoder.encode(input, context->first_parms_id(), delta, plain); - encryptor.encrypt(plain, encrypted); - evaluator.rescale_to_next_inplace(encrypted); - evaluator.complex_conjugate_inplace(encrypted, glk); - decryptor.decrypt(encrypted, plain); - encoder.decode(plain, output); - for (size_t i = 0; i < slot_size; i++) - { - ASSERT_EQ(round(input[i].real()), round(output[i].real())); - ASSERT_EQ(round(-input[i].imag()), round(output[i].imag())); - } - } - } - - TEST(EvaluatorTest, BFVEncryptSquareDecrypt) - { - EncryptionParameters parms(scheme_type::BFV); - SmallModulus plain_modulus(1 << 8); - parms.set_poly_modulus_degree(128); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - IntegerEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - - Ciphertext encrypted; - Plaintext plain; - encryptor.encrypt(encoder.encode(1), encrypted); - evaluator.square_inplace(encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(1ULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(0), encrypted); - evaluator.square_inplace(encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(0ULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(-5), encrypted); - evaluator.square_inplace(encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(25ULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(-1), encrypted); - evaluator.square_inplace(encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(1ULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(123), encrypted); - evaluator.square_inplace(encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(15129ULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(0x10000), encrypted); - evaluator.square_inplace(encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(0x100000000ULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(123), encrypted); - evaluator.square_inplace(encrypted); - evaluator.square_inplace(encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(228886641ULL, encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - } - - TEST(EvaluatorTest, BFVEncryptMultiplyManyDecrypt) - { - EncryptionParameters parms(scheme_type::BFV); - SmallModulus plain_modulus(1 << 6); - parms.set_poly_modulus_degree(128); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - IntegerEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - RelinKeys rlk = keygen.relin_keys(); - - Ciphertext encrypted1, encrypted2, encrypted3, encrypted4, product; - Plaintext plain; - encryptor.encrypt(encoder.encode(5), encrypted1); - encryptor.encrypt(encoder.encode(6), encrypted2); - encryptor.encrypt(encoder.encode(7), encrypted3); - vector encrypteds{ encrypted1, encrypted2, encrypted3 }; - evaluator.multiply_many(encrypteds, rlk, product); - ASSERT_EQ(3, encrypteds.size()); - decryptor.decrypt(product, plain); - ASSERT_EQ(static_cast(210), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted1.parms_id() == product.parms_id()); - ASSERT_TRUE(encrypted2.parms_id() == product.parms_id()); - ASSERT_TRUE(encrypted3.parms_id() == product.parms_id()); - ASSERT_TRUE(product.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(-9), encrypted1); - encryptor.encrypt(encoder.encode(-17), encrypted2); - encrypteds = { encrypted1, encrypted2 }; - evaluator.multiply_many(encrypteds, rlk, product); - ASSERT_EQ(2, encrypteds.size()); - decryptor.decrypt(product, plain); - ASSERT_EQ(static_cast(153), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted1.parms_id() == product.parms_id()); - ASSERT_TRUE(encrypted2.parms_id() == product.parms_id()); - ASSERT_TRUE(product.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(2), encrypted1); - encryptor.encrypt(encoder.encode(-31), encrypted2); - encryptor.encrypt(encoder.encode(7), encrypted3); - encrypteds = { encrypted1, encrypted2, encrypted3 }; - evaluator.multiply_many(encrypteds, rlk, product); - ASSERT_EQ(3, encrypteds.size()); - decryptor.decrypt(product, plain); - ASSERT_TRUE(static_cast(-434) == encoder.decode_int64(plain)); - ASSERT_TRUE(encrypted1.parms_id() == product.parms_id()); - ASSERT_TRUE(encrypted2.parms_id() == product.parms_id()); - ASSERT_TRUE(encrypted3.parms_id() == product.parms_id()); - ASSERT_TRUE(product.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(1), encrypted1); - encryptor.encrypt(encoder.encode(-1), encrypted2); - encryptor.encrypt(encoder.encode(1), encrypted3); - encryptor.encrypt(encoder.encode(-1), encrypted4); - encrypteds = { encrypted1, encrypted2, encrypted3, encrypted4 }; - evaluator.multiply_many(encrypteds, rlk, product); - ASSERT_EQ(4, encrypteds.size()); - decryptor.decrypt(product, plain); - ASSERT_EQ(static_cast(1), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted1.parms_id() == product.parms_id()); - ASSERT_TRUE(encrypted2.parms_id() == product.parms_id()); - ASSERT_TRUE(encrypted3.parms_id() == product.parms_id()); - ASSERT_TRUE(encrypted4.parms_id() == product.parms_id()); - ASSERT_TRUE(product.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(98765), encrypted1); - encryptor.encrypt(encoder.encode(0), encrypted2); - encryptor.encrypt(encoder.encode(12345), encrypted3); - encryptor.encrypt(encoder.encode(34567), encrypted4); - encrypteds = { encrypted1, encrypted2, encrypted3, encrypted4 }; - evaluator.multiply_many(encrypteds, rlk, product); - ASSERT_EQ(4, encrypteds.size()); - decryptor.decrypt(product, plain); - ASSERT_EQ(static_cast(0), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted1.parms_id() == product.parms_id()); - ASSERT_TRUE(encrypted2.parms_id() == product.parms_id()); - ASSERT_TRUE(encrypted3.parms_id() == product.parms_id()); - ASSERT_TRUE(encrypted4.parms_id() == product.parms_id()); - ASSERT_TRUE(product.parms_id() == context->first_parms_id()); - } - - TEST(EvaluatorTest, BFVEncryptExponentiateDecrypt) - { - EncryptionParameters parms(scheme_type::BFV); - SmallModulus plain_modulus(1 << 6); - parms.set_poly_modulus_degree(128); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - IntegerEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - RelinKeys rlk = keygen.relin_keys(); - - Ciphertext encrypted; - Plaintext plain; - encryptor.encrypt(encoder.encode(5), encrypted); - evaluator.exponentiate_inplace(encrypted, 1, rlk); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(static_cast(5), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(7), encrypted); - evaluator.exponentiate_inplace(encrypted, 2, rlk); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(static_cast(49), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(-7), encrypted); - evaluator.exponentiate_inplace(encrypted, 3, rlk); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(static_cast(-343) == encoder.decode_int64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(0x100), encrypted); - evaluator.exponentiate_inplace(encrypted, 4, rlk); - decryptor.decrypt(encrypted, plain); - ASSERT_EQ(static_cast(0x100000000), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - } - - TEST(EvaluatorTest, BFVEncryptAddManyDecrypt) - { - EncryptionParameters parms(scheme_type::BFV); - SmallModulus plain_modulus(1 << 6); - parms.set_poly_modulus_degree(128); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - IntegerEncoder encoder(context); - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - - Ciphertext encrypted1, encrypted2, encrypted3, encrypted4, sum; - Plaintext plain; - encryptor.encrypt(encoder.encode(5), encrypted1); - encryptor.encrypt(encoder.encode(6), encrypted2); - encryptor.encrypt(encoder.encode(7), encrypted3); - vector encrypteds = { encrypted1, encrypted2, encrypted3 }; - evaluator.add_many(encrypteds, sum); - decryptor.decrypt(sum, plain); - ASSERT_EQ(static_cast(18), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id()); - ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id()); - ASSERT_TRUE(encrypted3.parms_id() == sum.parms_id()); - ASSERT_TRUE(sum.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(-9), encrypted1); - encryptor.encrypt(encoder.encode(-17), encrypted2); - encrypteds = { encrypted1, encrypted2, }; - evaluator.add_many(encrypteds, sum); - decryptor.decrypt(sum, plain); - ASSERT_TRUE(static_cast(-26) == encoder.decode_int64(plain)); - ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id()); - ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id()); - ASSERT_TRUE(sum.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(2), encrypted1); - encryptor.encrypt(encoder.encode(-31), encrypted2); - encryptor.encrypt(encoder.encode(7), encrypted3); - encrypteds = { encrypted1, encrypted2, encrypted3 }; - evaluator.add_many(encrypteds, sum); - decryptor.decrypt(sum, plain); - ASSERT_TRUE(static_cast(-22) == encoder.decode_int64(plain)); - ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id()); - ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id()); - ASSERT_TRUE(encrypted3.parms_id() == sum.parms_id()); - ASSERT_TRUE(sum.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(1), encrypted1); - encryptor.encrypt(encoder.encode(-1), encrypted2); - encryptor.encrypt(encoder.encode(1), encrypted3); - encryptor.encrypt(encoder.encode(-1), encrypted4); - encrypteds = { encrypted1, encrypted2, encrypted3, encrypted4 }; - evaluator.add_many(encrypteds, sum); - decryptor.decrypt(sum, plain); - ASSERT_EQ(static_cast(0), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id()); - ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id()); - ASSERT_TRUE(encrypted3.parms_id() == sum.parms_id()); - ASSERT_TRUE(encrypted4.parms_id() == sum.parms_id()); - ASSERT_TRUE(sum.parms_id() == context->first_parms_id()); - - encryptor.encrypt(encoder.encode(98765), encrypted1); - encryptor.encrypt(encoder.encode(0), encrypted2); - encryptor.encrypt(encoder.encode(12345), encrypted3); - encryptor.encrypt(encoder.encode(34567), encrypted4); - encrypteds = { encrypted1, encrypted2, encrypted3, encrypted4 }; - evaluator.add_many(encrypteds, sum); - decryptor.decrypt(sum, plain); - ASSERT_EQ(static_cast(145677), encoder.decode_uint64(plain)); - ASSERT_TRUE(encrypted1.parms_id() == sum.parms_id()); - ASSERT_TRUE(encrypted2.parms_id() == sum.parms_id()); - ASSERT_TRUE(encrypted3.parms_id() == sum.parms_id()); - ASSERT_TRUE(encrypted4.parms_id() == sum.parms_id()); - ASSERT_TRUE(sum.parms_id() == context->first_parms_id()); - } - - TEST(EvaluatorTest, TransformPlainToNTT) - { - EncryptionParameters parms(scheme_type::BFV); - SmallModulus plain_modulus(1 << 6); - parms.set_poly_modulus_degree(128); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40, 40 })); - - auto context = SEALContext::Create(parms, true, sec_level_type::none); - KeyGenerator keygen(context); - - Evaluator evaluator(context); - Plaintext plain("0"); - ASSERT_FALSE(plain.is_ntt_form()); - evaluator.transform_to_ntt_inplace(plain, context->first_parms_id()); - ASSERT_TRUE(plain.is_zero()); - ASSERT_TRUE(plain.is_ntt_form()); - ASSERT_TRUE(plain.parms_id() == context->first_parms_id()); - - plain.release(); - plain = "0"; - ASSERT_FALSE(plain.is_ntt_form()); - auto next_parms_id = context->first_context_data()-> - next_context_data()->parms_id(); - evaluator.transform_to_ntt_inplace(plain, next_parms_id); - ASSERT_TRUE(plain.is_zero()); - ASSERT_TRUE(plain.is_ntt_form()); - ASSERT_TRUE(plain.parms_id() == next_parms_id); - - plain.release(); - plain = "1"; - ASSERT_FALSE(plain.is_ntt_form()); - evaluator.transform_to_ntt_inplace(plain, context->first_parms_id()); - for (size_t i = 0; i < 256; i++) - { - ASSERT_TRUE(plain[i] == uint64_t(1)); - } - ASSERT_TRUE(plain.is_ntt_form()); - ASSERT_TRUE(plain.parms_id() == context->first_parms_id()); - - plain.release(); - plain = "1"; - ASSERT_FALSE(plain.is_ntt_form()); - evaluator.transform_to_ntt_inplace(plain, next_parms_id); - for (size_t i = 0; i < 128; i++) - { - ASSERT_TRUE(plain[i] == uint64_t(1)); - } - ASSERT_TRUE(plain.is_ntt_form()); - ASSERT_TRUE(plain.parms_id() == next_parms_id); - - plain.release(); - plain = "2"; - ASSERT_FALSE(plain.is_ntt_form()); - evaluator.transform_to_ntt_inplace(plain, context->first_parms_id()); - for (size_t i = 0; i < 256; i++) - { - ASSERT_TRUE(plain[i] == uint64_t(2)); - } - ASSERT_TRUE(plain.is_ntt_form()); - ASSERT_TRUE(plain.parms_id() == context->first_parms_id()); - - plain.release(); - plain = "2"; - evaluator.transform_to_ntt_inplace(plain, next_parms_id); - for (size_t i = 0; i < 128; i++) - { - ASSERT_TRUE(plain[i] == uint64_t(2)); - } - ASSERT_TRUE(plain.is_ntt_form()); - ASSERT_TRUE(plain.parms_id() == next_parms_id); - } - - TEST(EvaluatorTest, TransformEncryptedToFromNTT) - { - EncryptionParameters parms(scheme_type::BFV); - SmallModulus plain_modulus(1 << 6); - parms.set_poly_modulus_degree(128); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - - Plaintext plain; - Ciphertext encrypted; - plain = "0"; - encryptor.encrypt(plain, encrypted); - evaluator.transform_to_ntt_inplace(encrypted); - evaluator.transform_from_ntt_inplace(encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(plain.to_string() == "0"); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - plain = "1"; - encryptor.encrypt(plain, encrypted); - evaluator.transform_to_ntt_inplace(encrypted); - evaluator.transform_from_ntt_inplace(encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(plain.to_string() == "1"); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - plain = "Fx^10 + Ex^9 + Dx^8 + Cx^7 + Bx^6 + Ax^5 + 1x^4 + 2x^3 + 3x^2 + 4x^1 + 5"; - encryptor.encrypt(plain, encrypted); - evaluator.transform_to_ntt_inplace(encrypted); - evaluator.transform_from_ntt_inplace(encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(plain.to_string() == "Fx^10 + Ex^9 + Dx^8 + Cx^7 + Bx^6 + Ax^5 + 1x^4 + 2x^3 + 3x^2 + 4x^1 + 5"); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - } - - TEST(EvaluatorTest, BFVEncryptMultiplyPlainNTTDecrypt) - { - EncryptionParameters parms(scheme_type::BFV); - SmallModulus plain_modulus(1 << 6); - parms.set_poly_modulus_degree(128); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 40, 40 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - - Plaintext plain; - Plaintext plain_multiplier; - Ciphertext encrypted; - - plain = 0; - encryptor.encrypt(plain, encrypted); - evaluator.transform_to_ntt_inplace(encrypted); - plain_multiplier = 1; - evaluator.transform_to_ntt_inplace(plain_multiplier, context->first_parms_id()); - evaluator.multiply_plain_inplace(encrypted, plain_multiplier); - evaluator.transform_from_ntt_inplace(encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(plain.to_string() == "0"); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - plain = 2; - encryptor.encrypt(plain, encrypted); - evaluator.transform_to_ntt_inplace(encrypted); - plain_multiplier.release(); - plain_multiplier = 3; - evaluator.transform_to_ntt_inplace(plain_multiplier, context->first_parms_id()); - evaluator.multiply_plain_inplace(encrypted, plain_multiplier); - evaluator.transform_from_ntt_inplace(encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(plain.to_string() == "6"); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - plain = 1; - encryptor.encrypt(plain, encrypted); - evaluator.transform_to_ntt_inplace(encrypted); - plain_multiplier.release(); - plain_multiplier = "Fx^10 + Ex^9 + Dx^8 + Cx^7 + Bx^6 + Ax^5 + 1x^4 + 2x^3 + 3x^2 + 4x^1 + 5"; - evaluator.transform_to_ntt_inplace(plain_multiplier, context->first_parms_id()); - evaluator.multiply_plain_inplace(encrypted, plain_multiplier); - evaluator.transform_from_ntt_inplace(encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(plain.to_string() == "Fx^10 + Ex^9 + Dx^8 + Cx^7 + Bx^6 + Ax^5 + 1x^4 + 2x^3 + 3x^2 + 4x^1 + 5"); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - - plain = "1x^20"; - encryptor.encrypt(plain, encrypted); - evaluator.transform_to_ntt_inplace(encrypted); - plain_multiplier.release(); - plain_multiplier = "Fx^10 + Ex^9 + Dx^8 + Cx^7 + Bx^6 + Ax^5 + 1x^4 + 2x^3 + 3x^2 + 4x^1 + 5"; - evaluator.transform_to_ntt_inplace(plain_multiplier, context->first_parms_id()); - evaluator.multiply_plain_inplace(encrypted, plain_multiplier); - evaluator.transform_from_ntt_inplace(encrypted); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(plain.to_string() == "Fx^30 + Ex^29 + Dx^28 + Cx^27 + Bx^26 + Ax^25 + 1x^24 + 2x^23 + 3x^22 + 4x^21 + 5x^20"); - ASSERT_TRUE(encrypted.parms_id() == context->first_parms_id()); - } - - TEST(EvaluatorTest, BFVEncryptApplyGaloisDecrypt) - { - EncryptionParameters parms(scheme_type::BFV); - SmallModulus plain_modulus(257); - parms.set_poly_modulus_degree(8); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(8, { 40, 40 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - GaloisKeys glk = keygen.galois_keys(vector{ 1, 3, 5, 15 }); - - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - - Plaintext plain("1"); - Ciphertext encrypted; - encryptor.encrypt(plain, encrypted); - evaluator.apply_galois_inplace(encrypted, 1, glk); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE("1" == plain.to_string()); - evaluator.apply_galois_inplace(encrypted, 3, glk); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE("1" == plain.to_string()); - evaluator.apply_galois_inplace(encrypted, 5, glk); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE("1" == plain.to_string()); - evaluator.apply_galois_inplace(encrypted, 15, glk); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE("1" == plain.to_string()); - - plain = "1x^1"; - encryptor.encrypt(plain, encrypted); - evaluator.apply_galois_inplace(encrypted, 1, glk); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE("1x^1" == plain.to_string()); - evaluator.apply_galois_inplace(encrypted, 3, glk); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE("1x^3" == plain.to_string()); - evaluator.apply_galois_inplace(encrypted, 5, glk); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE("100x^7" == plain.to_string()); - evaluator.apply_galois_inplace(encrypted, 15, glk); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE("1x^1" == plain.to_string()); - - plain = "1x^2"; - encryptor.encrypt(plain, encrypted); - evaluator.apply_galois_inplace(encrypted, 1, glk); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE("1x^2" == plain.to_string()); - evaluator.apply_galois_inplace(encrypted, 3, glk); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE("1x^6" == plain.to_string()); - evaluator.apply_galois_inplace(encrypted, 5, glk); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE("100x^6" == plain.to_string()); - evaluator.apply_galois_inplace(encrypted, 15, glk); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE("1x^2" == plain.to_string()); - - plain = "1x^3 + 2x^2 + 1x^1 + 1"; - encryptor.encrypt(plain, encrypted); - evaluator.apply_galois_inplace(encrypted, 1, glk); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE("1x^3 + 2x^2 + 1x^1 + 1" == plain.to_string()); - evaluator.apply_galois_inplace(encrypted, 3, glk); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE("2x^6 + 1x^3 + 100x^1 + 1" == plain.to_string()); - evaluator.apply_galois_inplace(encrypted, 5, glk); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE("100x^7 + FFx^6 + 100x^5 + 1" == plain.to_string()); - evaluator.apply_galois_inplace(encrypted, 15, glk); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE("1x^3 + 2x^2 + 1x^1 + 1" == plain.to_string()); - } - - TEST(EvaluatorTest, BFVEncryptRotateMatrixDecrypt) - { - EncryptionParameters parms(scheme_type::BFV); - SmallModulus plain_modulus(257); - parms.set_poly_modulus_degree(8); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(8, { 40, 40 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - GaloisKeys glk = keygen.galois_keys(); - - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - BatchEncoder batch_encoder(context); - - Plaintext plain; - vector plain_vec{ - 1, 2, 3, 4, - 5, 6, 7, 8 - }; - batch_encoder.encode(plain_vec, plain); - Ciphertext encrypted; - encryptor.encrypt(plain, encrypted); - - evaluator.rotate_columns_inplace(encrypted, glk); - decryptor.decrypt(encrypted, plain); - batch_encoder.decode(plain, plain_vec); - ASSERT_TRUE((plain_vec == vector{ - 5, 6, 7, 8, - 1, 2, 3, 4 - })); - - evaluator.rotate_rows_inplace(encrypted, -1, glk); - decryptor.decrypt(encrypted, plain); - batch_encoder.decode(plain, plain_vec); - ASSERT_TRUE((plain_vec == vector{ - 8, 5, 6, 7, - 4, 1, 2, 3 - })); - - evaluator.rotate_rows_inplace(encrypted, 2, glk); - decryptor.decrypt(encrypted, plain); - batch_encoder.decode(plain, plain_vec); - ASSERT_TRUE((plain_vec == vector{ - 6, 7, 8, 5, - 2, 3, 4, 1 - })); - - evaluator.rotate_columns_inplace(encrypted, glk); - decryptor.decrypt(encrypted, plain); - batch_encoder.decode(plain, plain_vec); - ASSERT_TRUE((plain_vec == vector{ - 2, 3, 4, 1, - 6, 7, 8, 5 - })); - - evaluator.rotate_rows_inplace(encrypted, 0, glk); - decryptor.decrypt(encrypted, plain); - batch_encoder.decode(plain, plain_vec); - ASSERT_TRUE((plain_vec == vector{ - 2, 3, 4, 1, - 6, 7, 8, 5 - })); - } - TEST(EvaluatorTest, BFVEncryptModSwitchToNextDecrypt) - { - // the common parameters: the plaintext and the polynomial moduli - SmallModulus plain_modulus(1 << 6); - - // the parameters and the context of the higher level - EncryptionParameters parms(scheme_type::BFV); - parms.set_poly_modulus_degree(128); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 30, 30, 30, 30 })); - - auto context = SEALContext::Create(parms, true, sec_level_type::none); - KeyGenerator keygen(context); - SecretKey secret_key = keygen.secret_key(); - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - auto parms_id = context->first_parms_id(); - - Ciphertext encrypted(context); - Ciphertext encryptedRes; - Plaintext plain; - - plain = 0; - encryptor.encrypt(plain, encrypted); - evaluator.mod_switch_to_next(encrypted, encryptedRes); - decryptor.decrypt(encryptedRes, plain); - parms_id = context->get_context_data(parms_id)-> - next_context_data()->parms_id(); - ASSERT_TRUE(encryptedRes.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "0"); - - evaluator.mod_switch_to_next_inplace(encryptedRes); - decryptor.decrypt(encryptedRes, plain); - parms_id = context->get_context_data(parms_id)-> - next_context_data()->parms_id(); - ASSERT_TRUE(encryptedRes.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "0"); - - parms_id = context->first_parms_id(); - plain = 1; - encryptor.encrypt(plain, encrypted); - evaluator.mod_switch_to_next(encrypted, encryptedRes); - decryptor.decrypt(encryptedRes, plain); - parms_id = context->get_context_data(parms_id)-> - next_context_data()->parms_id(); - ASSERT_TRUE(encryptedRes.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "1"); - - evaluator.mod_switch_to_next_inplace(encryptedRes); - decryptor.decrypt(encryptedRes, plain); - parms_id = context->get_context_data(parms_id)-> - next_context_data()->parms_id(); - ASSERT_TRUE(encryptedRes.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "1"); - - parms_id = context->first_parms_id(); - plain = "1x^127"; - encryptor.encrypt(plain, encrypted); - evaluator.mod_switch_to_next(encrypted, encryptedRes); - decryptor.decrypt(encryptedRes, plain); - parms_id = context->get_context_data(parms_id)-> - next_context_data()->parms_id(); - ASSERT_TRUE(encryptedRes.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "1x^127"); - - evaluator.mod_switch_to_next_inplace(encryptedRes); - decryptor.decrypt(encryptedRes, plain); - parms_id = context->get_context_data(parms_id)-> - next_context_data()->parms_id(); - ASSERT_TRUE(encryptedRes.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "1x^127"); - - parms_id = context->first_parms_id(); - plain = "5x^64 + Ax^5"; - encryptor.encrypt(plain, encrypted); - evaluator.mod_switch_to_next(encrypted, encryptedRes); - decryptor.decrypt(encryptedRes, plain); - parms_id = context->get_context_data(parms_id)-> - next_context_data()->parms_id(); - ASSERT_TRUE(encryptedRes.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5"); - - evaluator.mod_switch_to_next_inplace(encryptedRes); - decryptor.decrypt(encryptedRes, plain); - parms_id = context->get_context_data(parms_id)-> - next_context_data()->parms_id(); - ASSERT_TRUE(encryptedRes.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5"); - } - - TEST(EvaluatorTest, BFVEncryptModSwitchToDecrypt) - { - // the common parameters: the plaintext and the polynomial moduli - SmallModulus plain_modulus(1 << 6); - - // the parameters and the context of the higher level - EncryptionParameters parms(scheme_type::BFV); - parms.set_poly_modulus_degree(128); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 30, 30, 30, 30 })); - - auto context = SEALContext::Create(parms, true, sec_level_type::none); - KeyGenerator keygen(context); - SecretKey secret_key = keygen.secret_key(); - Encryptor encryptor(context, keygen.public_key()); - Evaluator evaluator(context); - Decryptor decryptor(context, keygen.secret_key()); - auto parms_id = context->first_parms_id(); - - Ciphertext encrypted(context); - Plaintext plain; - - plain = 0; - encryptor.encrypt(plain, encrypted); - evaluator.mod_switch_to_inplace(encrypted, parms_id); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(encrypted.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "0"); - - parms_id = context->get_context_data(parms_id)-> - next_context_data()->parms_id(); - encryptor.encrypt(plain, encrypted); - evaluator.mod_switch_to_inplace(encrypted, parms_id); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(encrypted.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "0"); - - parms_id = context->get_context_data(parms_id)-> - next_context_data()->parms_id(); - encryptor.encrypt(plain, encrypted); - evaluator.mod_switch_to_inplace(encrypted, parms_id); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(encrypted.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "0"); - - parms_id = context->first_parms_id(); - encryptor.encrypt(plain, encrypted); - parms_id = context->get_context_data(parms_id)-> - next_context_data()-> - next_context_data()->parms_id(); - evaluator.mod_switch_to_inplace(encrypted, parms_id); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(encrypted.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "0"); - - parms_id = context->first_parms_id(); - plain = 1; - encryptor.encrypt(plain, encrypted); - evaluator.mod_switch_to_inplace(encrypted, parms_id); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(encrypted.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "1"); - - parms_id = context->get_context_data(parms_id)-> - next_context_data()->parms_id(); - encryptor.encrypt(plain, encrypted); - evaluator.mod_switch_to_inplace(encrypted, parms_id); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(encrypted.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "1"); - - parms_id = context->get_context_data(parms_id)-> - next_context_data()->parms_id(); - encryptor.encrypt(plain, encrypted); - evaluator.mod_switch_to_inplace(encrypted, parms_id); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(encrypted.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "1"); - - parms_id = context->first_parms_id(); - encryptor.encrypt(plain, encrypted); - parms_id = context->get_context_data(parms_id)-> - next_context_data()-> - next_context_data()->parms_id(); - evaluator.mod_switch_to_inplace(encrypted, parms_id); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(encrypted.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "1"); - - parms_id = context->first_parms_id(); - plain = "1x^127"; - encryptor.encrypt(plain, encrypted); - evaluator.mod_switch_to_inplace(encrypted, parms_id); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(encrypted.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "1x^127"); - - parms_id = context->get_context_data(parms_id)-> - next_context_data()->parms_id(); - encryptor.encrypt(plain, encrypted); - evaluator.mod_switch_to_inplace(encrypted, parms_id); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(encrypted.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "1x^127"); - - parms_id = context->get_context_data(parms_id)-> - next_context_data()->parms_id(); - encryptor.encrypt(plain, encrypted); - evaluator.mod_switch_to_inplace(encrypted, parms_id); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(encrypted.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "1x^127"); - - parms_id = context->first_parms_id(); - encryptor.encrypt(plain, encrypted); - parms_id = context->get_context_data(parms_id)-> - next_context_data()-> - next_context_data()->parms_id(); - evaluator.mod_switch_to_inplace(encrypted, parms_id); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(encrypted.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "1x^127"); - - parms_id = context->first_parms_id(); - plain = "5x^64 + Ax^5"; - encryptor.encrypt(plain, encrypted); - evaluator.mod_switch_to_inplace(encrypted, parms_id); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(encrypted.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5"); - - parms_id = context->get_context_data(parms_id)-> - next_context_data()->parms_id(); - encryptor.encrypt(plain, encrypted); - evaluator.mod_switch_to_inplace(encrypted, parms_id); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(encrypted.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5"); - - parms_id = context->get_context_data(parms_id)-> - next_context_data()->parms_id(); - encryptor.encrypt(plain, encrypted); - evaluator.mod_switch_to_inplace(encrypted, parms_id); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(encrypted.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5"); - - parms_id = context->first_parms_id(); - encryptor.encrypt(plain, encrypted); - parms_id = context->get_context_data(parms_id)-> - next_context_data()-> - next_context_data()->parms_id(); - evaluator.mod_switch_to_inplace(encrypted, parms_id); - decryptor.decrypt(encrypted, plain); - ASSERT_TRUE(encrypted.parms_id() == parms_id); - ASSERT_TRUE(plain.to_string() == "5x^64 + Ax^5"); - } -} diff --git a/SEAL/native/tests/seal/galoiskeys.cpp b/SEAL/native/tests/seal/galoiskeys.cpp deleted file mode 100644 index 2a38048..0000000 --- a/SEAL/native/tests/seal/galoiskeys.cpp +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/galoiskeys.h" -#include "seal/context.h" -#include "seal/keygenerator.h" -#include "seal/util/uintcore.h" -#include "seal/modulus.h" -#include - -using namespace seal; -using namespace seal::util; -using namespace std; - -namespace SEALTest -{ - TEST(GaloisKeysTest, GaloisKeysSaveLoad) - { - stringstream stream; - { - EncryptionParameters parms(scheme_type::BFV); - parms.set_poly_modulus_degree(64); - parms.set_plain_modulus(65537); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 60 })); - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - GaloisKeys keys; - GaloisKeys test_keys; - keys.save(stream); - test_keys.unsafe_load(stream); - ASSERT_EQ(keys.data().size(), test_keys.data().size()); - ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); - ASSERT_EQ(0ULL, keys.data().size()); - - keys = keygen.galois_keys(); - keys.save(stream); - test_keys.load(context, stream); - ASSERT_EQ(keys.data().size(), test_keys.data().size()); - ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); - for (size_t j = 0; j < test_keys.data().size(); j++) - { - for (size_t i = 0; i < test_keys.data()[j].size(); i++) - { - ASSERT_EQ(keys.data()[j][i].data().size(), test_keys.data()[j][i].data().size()); - ASSERT_EQ(keys.data()[j][i].data().uint64_count(), test_keys.data()[j][i].data().uint64_count()); - ASSERT_TRUE(is_equal_uint_uint(keys.data()[j][i].data().data(), test_keys.data()[j][i].data().data(), keys.data()[j][i].data().uint64_count())); - } - } - ASSERT_EQ(64ULL, keys.data().size()); - } - { - EncryptionParameters parms(scheme_type::BFV); - parms.set_poly_modulus_degree(256); - parms.set_plain_modulus(65537); - parms.set_coeff_modulus(CoeffModulus::Create(256, { 60, 50 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - GaloisKeys keys; - GaloisKeys test_keys; - keys.save(stream); - test_keys.unsafe_load(stream); - ASSERT_EQ(keys.data().size(), test_keys.data().size()); - ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); - ASSERT_EQ(0ULL, keys.data().size()); - - keys = keygen.galois_keys(); - keys.save(stream); - test_keys.load(context, stream); - ASSERT_EQ(keys.data().size(), test_keys.data().size()); - ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); - for (size_t j = 0; j < test_keys.data().size(); j++) - { - for (size_t i = 0; i < test_keys.data()[j].size(); i++) - { - ASSERT_EQ(keys.data()[j][i].data().size(), test_keys.data()[j][i].data().size()); - ASSERT_EQ(keys.data()[j][i].data().uint64_count(), test_keys.data()[j][i].data().uint64_count()); - ASSERT_TRUE(is_equal_uint_uint(keys.data()[j][i].data().data(), test_keys.data()[j][i].data().data(), keys.data()[j][i].data().uint64_count())); - } - } - ASSERT_EQ(256ULL, keys.data().size()); - } - } -} \ No newline at end of file diff --git a/SEAL/native/tests/seal/intarray.cpp b/SEAL/native/tests/seal/intarray.cpp deleted file mode 100644 index 24b3b4b..0000000 --- a/SEAL/native/tests/seal/intarray.cpp +++ /dev/null @@ -1,180 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/intarray.h" -#include "seal/memorymanager.h" -#include - -using namespace seal; -using namespace std; - -namespace SEALTest -{ - TEST(IntArrayTest, IntArrayBasics) - { - { - auto pool = MemoryPoolHandle::New(); - MemoryManager::SwitchProfile(new MMProfFixed(pool)); - IntArray arr; - ASSERT_TRUE(arr.begin() == nullptr); - ASSERT_TRUE(arr.end() == nullptr); - ASSERT_EQ(0ULL, arr.size()); - ASSERT_EQ(0ULL, arr.capacity()); - ASSERT_TRUE(arr.empty()); - - arr.resize(1); - ASSERT_FALSE(arr.begin() == nullptr); - ASSERT_FALSE(arr.end() == nullptr); - ASSERT_FALSE(arr.begin() == arr.end()); - ASSERT_EQ(1, static_cast(arr.end() - arr.begin())); - ASSERT_EQ(1ULL, arr.size()); - ASSERT_EQ(1ULL, arr.capacity()); - ASSERT_FALSE(arr.empty()); - ASSERT_EQ(0, arr[0]); - arr.at(0) = 1; - ASSERT_EQ(1, arr[0]); - ASSERT_EQ(4, static_cast(pool.alloc_byte_count())); - - arr.reserve(6); - ASSERT_FALSE(arr.begin() == nullptr); - ASSERT_FALSE(arr.end() == nullptr); - ASSERT_FALSE(arr.begin() == arr.end()); - ASSERT_EQ(1, static_cast(arr.end() - arr.begin())); - ASSERT_EQ(1ULL, arr.size()); - ASSERT_EQ(6ULL, arr.capacity()); - ASSERT_FALSE(arr.empty()); - ASSERT_EQ(1, arr[0]); - ASSERT_EQ(28, static_cast(pool.alloc_byte_count())); - - arr.resize(4); - ASSERT_FALSE(arr.begin() == nullptr); - ASSERT_FALSE(arr.end() == nullptr); - ASSERT_FALSE(arr.begin() == arr.end()); - ASSERT_EQ(4, static_cast(arr.end() - arr.begin())); - ASSERT_EQ(4ULL, arr.size()); - ASSERT_EQ(6ULL, arr.capacity()); - ASSERT_FALSE(arr.empty()); - arr.at(0) = 0; - arr.at(1) = 1; - arr.at(2) = 2; - arr.at(3) = 3; - ASSERT_EQ(0, arr[0]); - ASSERT_EQ(1, arr[1]); - ASSERT_EQ(2, arr[2]); - ASSERT_EQ(3, arr[3]); - ASSERT_EQ(28, static_cast(pool.alloc_byte_count())); - - arr.shrink_to_fit(); - ASSERT_FALSE(arr.begin() == nullptr); - ASSERT_FALSE(arr.end() == nullptr); - ASSERT_FALSE(arr.begin() == arr.end()); - ASSERT_EQ(4, static_cast(arr.end() - arr.begin())); - ASSERT_EQ(4ULL, arr.size()); - ASSERT_EQ(4ULL, arr.capacity()); - ASSERT_FALSE(arr.empty()); - ASSERT_EQ(0, arr[0]); - ASSERT_EQ(1, arr[1]); - ASSERT_EQ(2, arr[2]); - ASSERT_EQ(3, arr[3]); - ASSERT_EQ(44, static_cast(pool.alloc_byte_count())); - } - { - auto pool = MemoryPoolHandle::New(); - MemoryManager::SwitchProfile(new MMProfFixed(pool)); - IntArray arr; - ASSERT_TRUE(arr.begin() == nullptr); - ASSERT_TRUE(arr.end() == nullptr); - ASSERT_EQ(0ULL, arr.size()); - ASSERT_EQ(0ULL, arr.capacity()); - ASSERT_TRUE(arr.empty()); - - arr.resize(1); - ASSERT_FALSE(arr.begin() == nullptr); - ASSERT_FALSE(arr.end() == nullptr); - ASSERT_FALSE(arr.begin() == arr.end()); - ASSERT_EQ(1, static_cast(arr.end() - arr.begin())); - ASSERT_EQ(1ULL, arr.size()); - ASSERT_EQ(1ULL, arr.capacity()); - ASSERT_FALSE(arr.empty()); - ASSERT_EQ(0ULL, arr[0]); - arr.at(0) = 1; - ASSERT_EQ(1ULL, arr[0]); - ASSERT_EQ(8, static_cast(pool.alloc_byte_count())); - - arr.reserve(6); - ASSERT_FALSE(arr.begin() == nullptr); - ASSERT_FALSE(arr.end() == nullptr); - ASSERT_FALSE(arr.begin() == arr.end()); - ASSERT_EQ(1, static_cast(arr.end() - arr.begin())); - ASSERT_EQ(1ULL, arr.size()); - ASSERT_EQ(6ULL, arr.capacity()); - ASSERT_FALSE(arr.empty()); - ASSERT_EQ(1ULL, arr[0]); - ASSERT_EQ(56, static_cast(pool.alloc_byte_count())); - - arr.resize(4); - ASSERT_FALSE(arr.begin() == nullptr); - ASSERT_FALSE(arr.end() == nullptr); - ASSERT_FALSE(arr.begin() == arr.end()); - ASSERT_EQ(4, static_cast(arr.end() - arr.begin())); - ASSERT_EQ(4ULL, arr.size()); - ASSERT_EQ(6ULL, arr.capacity()); - ASSERT_FALSE(arr.empty()); - arr.at(0) = 0; - arr.at(1) = 1; - arr.at(2) = 2; - arr.at(3) = 3; - ASSERT_EQ(0ULL, arr[0]); - ASSERT_EQ(1ULL, arr[1]); - ASSERT_EQ(2ULL, arr[2]); - ASSERT_EQ(3ULL, arr[3]); - ASSERT_EQ(56, static_cast(pool.alloc_byte_count())); - - arr.shrink_to_fit(); - ASSERT_FALSE(arr.begin() == nullptr); - ASSERT_FALSE(arr.end() == nullptr); - ASSERT_FALSE(arr.begin() == arr.end()); - ASSERT_EQ(4, static_cast(arr.end() - arr.begin())); - ASSERT_EQ(4ULL, arr.size()); - ASSERT_EQ(4ULL, arr.capacity()); - ASSERT_FALSE(arr.empty()); - ASSERT_EQ(0ULL, arr[0]); - ASSERT_EQ(1ULL, arr[1]); - ASSERT_EQ(2ULL, arr[2]); - ASSERT_EQ(3ULL, arr[3]); - ASSERT_EQ(88, static_cast(pool.alloc_byte_count())); - } - } - - TEST(IntArrayTest, SaveLoadIntArray) - { - IntArray arr(6, 4); - arr.at(0) = 0; - arr.at(1) = 1; - arr.at(2) = 2; - arr.at(3) = 3; - stringstream ss; - arr.save(ss); - IntArray arr2; - arr2.load(ss); - - ASSERT_EQ(arr.size(), arr2.size()); - ASSERT_EQ(arr.size(), arr2.capacity()); - ASSERT_EQ(arr[0], arr2[0]); - ASSERT_EQ(arr[1], arr2[1]); - ASSERT_EQ(arr[2], arr2[2]); - ASSERT_EQ(arr[3], arr2[3]); - - arr.resize(2); - arr[0] = 5; - arr[1] = 6; - arr.save(ss); - arr2.load(ss); - - ASSERT_EQ(arr.size(), arr2.size()); - ASSERT_EQ(4ULL, arr2.capacity()); - ASSERT_EQ(arr[0], arr2[0]); - ASSERT_EQ(arr[1], arr2[1]); - } -} diff --git a/SEAL/native/tests/seal/intencoder.cpp b/SEAL/native/tests/seal/intencoder.cpp deleted file mode 100644 index 89427d6..0000000 --- a/SEAL/native/tests/seal/intencoder.cpp +++ /dev/null @@ -1,438 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/intencoder.h" -#include "seal/context.h" -#include -#include - -using namespace seal; -using namespace std; - -namespace SEALTest -{ - TEST(Encoder, IntEncodeDecodeBigUInt) - { - SmallModulus modulus(0xFFFFFFFFFFFFFFF); - EncryptionParameters parms(scheme_type::BFV); - parms.set_plain_modulus(modulus); - auto context = SEALContext::Create(parms); - IntegerEncoder encoder(context); - - BigUInt value(64); - value = "0"; - Plaintext poly = encoder.encode(value); - ASSERT_EQ(0ULL, poly.significant_coeff_count()); - ASSERT_TRUE(poly.is_zero()); - ASSERT_TRUE(value == encoder.decode_biguint(poly)); - - value = "1"; - Plaintext poly1 = encoder.encode(value); - ASSERT_EQ(1ULL, poly1.coeff_count()); - ASSERT_TRUE("1" == poly1.to_string()); - ASSERT_TRUE(value == encoder.decode_biguint(poly1)); - - value = "2"; - Plaintext poly2 = encoder.encode(value); - ASSERT_EQ(2ULL, poly2.coeff_count()); - ASSERT_TRUE("1x^1" == poly2.to_string()); - ASSERT_TRUE(value == encoder.decode_biguint(poly2)); - - value = "3"; - Plaintext poly3 = encoder.encode(value); - ASSERT_EQ(2ULL, poly3.coeff_count()); - ASSERT_TRUE("1x^1 + 1" == poly3.to_string()); - ASSERT_TRUE(value == encoder.decode_biguint(poly3)); - - value = "FFFFFFFFFFFFFFFF"; - Plaintext poly4 = encoder.encode(value); - ASSERT_EQ(64ULL, poly4.coeff_count()); - for (size_t i = 0; i < 64; ++i) - { - ASSERT_TRUE(poly4[i] == 1); - } - ASSERT_TRUE(value == encoder.decode_biguint(poly4)); - - value = "80F02"; - Plaintext poly5 = encoder.encode(value); - ASSERT_EQ(20ULL, poly5.coeff_count()); - for (size_t i = 0; i < 20; ++i) - { - if (i == 19 || (i >= 8 && i <= 11) || i == 1) - { - ASSERT_TRUE(poly5[i] == 1); - } - else - { - ASSERT_TRUE(poly5[i] == 0); - } - } - ASSERT_TRUE(value == encoder.decode_biguint(poly5)); - - Plaintext poly6(3); - poly6[0] = 1; - poly6[1] = 500; - poly6[2] = 1023; - value = 1 + 500 * 2 + 1023 * 4; - ASSERT_TRUE(value == encoder.decode_biguint(poly6)); - - modulus = 1024; - parms.set_plain_modulus(modulus); - auto context2 = SEALContext::Create(parms); - IntegerEncoder encoder2(context2); - Plaintext poly7(4); - poly7[0] = 1023; // -1 (*1) - poly7[1] = 512; // -512 (*2) - poly7[2] = 511; // 511 (*4) - poly7[3] = 1; // 1 (*8) - value = -1 + -512 * 2 + 511 * 4 + 1 * 8; - ASSERT_TRUE(value == encoder2.decode_biguint(poly7)); - } - - TEST(Encoder, IntEncodeDecodeUInt64) - { - SmallModulus modulus(0xFFFFFFFFFFFFFFF); - EncryptionParameters parms(scheme_type::BFV); - parms.set_plain_modulus(modulus); - auto context = SEALContext::Create(parms); - IntegerEncoder encoder(context); - - Plaintext poly = encoder.encode(static_cast(0)); - ASSERT_EQ(0ULL, poly.significant_coeff_count()); - ASSERT_TRUE(poly.is_zero()); - ASSERT_EQ(static_cast(0), encoder.decode_uint64(poly)); - - Plaintext poly1 = encoder.encode(1u); - ASSERT_EQ(1ULL, poly1.coeff_count()); - ASSERT_TRUE("1" == poly1.to_string()); - ASSERT_EQ(1ULL, encoder.decode_uint64(poly1)); - - Plaintext poly2 = encoder.encode(static_cast(2)); - ASSERT_EQ(2ULL, poly2.coeff_count()); - ASSERT_TRUE("1x^1" == poly2.to_string()); - ASSERT_EQ(static_cast(2), encoder.decode_uint64(poly2)); - - Plaintext poly3 = encoder.encode(static_cast(3)); - ASSERT_EQ(2ULL, poly3.coeff_count()); - ASSERT_TRUE("1x^1 + 1" == poly3.to_string()); - ASSERT_EQ(static_cast(3), encoder.decode_uint64(poly3)); - - Plaintext poly4 = encoder.encode(static_cast(0xFFFFFFFFFFFFFFFF)); - ASSERT_EQ(64ULL, poly4.coeff_count()); - for (size_t i = 0; i < 64; ++i) - { - ASSERT_TRUE(poly4[i] == 1); - } - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), encoder.decode_uint64(poly4)); - - Plaintext poly5 = encoder.encode(static_cast(0x80F02)); - ASSERT_EQ(20ULL, poly5.coeff_count()); - for (size_t i = 0; i < 20; ++i) - { - if (i == 19 || (i >= 8 && i <= 11) || i == 1) - { - ASSERT_TRUE(poly5[i] == 1); - } - else - { - ASSERT_TRUE(poly5[i] == 0); - } - } - ASSERT_EQ(static_cast(0x80F02), encoder.decode_uint64(poly5)); - - Plaintext poly6(3); - poly6[0] = 1; - poly6[1] = 500; - poly6[2] = 1023; - ASSERT_EQ(static_cast(1 + 500 * 2 + 1023 * 4), encoder.decode_uint64(poly6)); - - modulus = 1024; - parms.set_plain_modulus(modulus); - auto context2 = SEALContext::Create(parms); - IntegerEncoder encoder2(context2); - Plaintext poly7(4); - poly7[0] = 1023; // -1 (*1) - poly7[1] = 512; // -512 (*2) - poly7[2] = 511; // 511 (*4) - poly7[3] = 1; // 1 (*8) - ASSERT_EQ(static_cast(-1 + -512 * 2 + 511 * 4 + 1 * 8), encoder2.decode_uint64(poly7)); - } - - TEST(Encoder, IntEncodeDecodeUInt32) - { - SmallModulus modulus(0xFFFFFFFFFFFFFFF); - EncryptionParameters parms(scheme_type::BFV); - parms.set_plain_modulus(modulus); - auto context = SEALContext::Create(parms); - IntegerEncoder encoder(context); - - Plaintext poly = encoder.encode(static_cast(0)); - ASSERT_EQ(0ULL, poly.significant_coeff_count()); - ASSERT_TRUE(poly.is_zero()); - ASSERT_EQ(static_cast(0), encoder.decode_uint32(poly)); - - Plaintext poly1 = encoder.encode(static_cast(1)); - ASSERT_EQ(1ULL, poly1.significant_coeff_count()); - ASSERT_TRUE("1" == poly1.to_string()); - ASSERT_EQ(static_cast(1), encoder.decode_uint32(poly1)); - - Plaintext poly2 = encoder.encode(static_cast(2)); - ASSERT_EQ(2ULL, poly2.significant_coeff_count()); - ASSERT_TRUE("1x^1" == poly2.to_string()); - ASSERT_EQ(static_cast(2), encoder.decode_uint32(poly2)); - - Plaintext poly3 = encoder.encode(static_cast(3)); - ASSERT_EQ(2ULL, poly3.significant_coeff_count()); - ASSERT_TRUE("1x^1 + 1" == poly3.to_string()); - ASSERT_EQ(static_cast(3), encoder.decode_uint32(poly3)); - - Plaintext poly4 = encoder.encode(static_cast(0xFFFFFFFF)); - ASSERT_EQ(32ULL, poly4.significant_coeff_count()); - for (size_t i = 0; i < 32; ++i) - { - ASSERT_TRUE(1 == poly4[i]); - } - ASSERT_EQ(static_cast(0xFFFFFFFF), encoder.decode_uint32(poly4)); - - Plaintext poly5 = encoder.encode(static_cast(0x80F02)); - ASSERT_EQ(20ULL, poly5.significant_coeff_count()); - for (size_t i = 0; i < 20; ++i) - { - if (i == 19 || (i >= 8 && i <= 11) || i == 1) - { - ASSERT_TRUE(1 == poly5[i]); - } - else - { - ASSERT_TRUE(poly5[i] == 0); - } - } - ASSERT_EQ(static_cast(0x80F02), encoder.decode_uint32(poly5)); - - Plaintext poly6(3); - poly6[0] = 1; - poly6[1] = 500; - poly6[2] = 1023; - ASSERT_EQ(static_cast(1 + 500 * 2 + 1023 * 4), encoder.decode_uint32(poly6)); - - modulus = 1024; - parms.set_plain_modulus(modulus); - auto context2 = SEALContext::Create(parms); - IntegerEncoder encoder2(context2); - Plaintext poly7(4); - poly7[0] = 1023; // -1 (*1) - poly7[1] = 512; // -512 (*2) - poly7[2] = 511; // 511 (*4) - poly7[3] = 1; // 1 (*8) - ASSERT_EQ(static_cast(-1 + -512 * 2 + 511 * 4 + 1 * 8), encoder2.decode_uint32(poly7)); - } - - TEST(Encoder, IntEncodeDecodeInt64) - { - SmallModulus modulus(0x7FFFFFFFFFFFF); - EncryptionParameters parms(scheme_type::BFV); - parms.set_plain_modulus(modulus); - auto context = SEALContext::Create(parms); - IntegerEncoder encoder(context); - - Plaintext poly = encoder.encode(static_cast(0)); - ASSERT_EQ(0ULL, poly.significant_coeff_count()); - ASSERT_TRUE(poly.is_zero()); - ASSERT_EQ(static_cast(0), static_cast(encoder.decode_int64(poly))); - - Plaintext poly1 = encoder.encode(static_cast(1)); - ASSERT_EQ(1ULL, poly1.significant_coeff_count()); - ASSERT_TRUE("1" == poly1.to_string()); - ASSERT_EQ(1ULL, static_cast(encoder.decode_int64(poly1))); - - Plaintext poly2 = encoder.encode(static_cast(2)); - ASSERT_EQ(2ULL, poly2.significant_coeff_count()); - ASSERT_TRUE("1x^1" == poly2.to_string()); - ASSERT_EQ(static_cast(2), static_cast(encoder.decode_int64(poly2))); - - Plaintext poly3 = encoder.encode(static_cast(3)); - ASSERT_EQ(2ULL, poly3.significant_coeff_count()); - ASSERT_TRUE("1x^1 + 1" == poly3.to_string()); - ASSERT_EQ(static_cast(3), static_cast(encoder.decode_int64(poly3))); - - Plaintext poly4 = encoder.encode(static_cast(-1)); - ASSERT_EQ(1ULL, poly4.significant_coeff_count()); - ASSERT_TRUE("7FFFFFFFFFFFE" == poly4.to_string()); - ASSERT_EQ(static_cast(-1), static_cast(encoder.decode_int64(poly4))); - - Plaintext poly5 = encoder.encode(static_cast(-2)); - ASSERT_EQ(2ULL, poly5.significant_coeff_count()); - ASSERT_TRUE("7FFFFFFFFFFFEx^1" == poly5.to_string()); - ASSERT_EQ(static_cast(-2), static_cast(encoder.decode_int64(poly5))); - - Plaintext poly6 = encoder.encode(static_cast(-3)); - ASSERT_EQ(2ULL, poly6.significant_coeff_count()); - ASSERT_TRUE("7FFFFFFFFFFFEx^1 + 7FFFFFFFFFFFE" == poly6.to_string()); - ASSERT_EQ(static_cast(-3), static_cast(encoder.decode_int64(poly6))); - - Plaintext poly7 = encoder.encode(static_cast(0x7FFFFFFFFFFFF)); - ASSERT_EQ(51ULL, poly7.significant_coeff_count()); - for (size_t i = 0; i < 51; ++i) - { - ASSERT_TRUE(1 == poly7[i]); - } - ASSERT_EQ(static_cast(0x7FFFFFFFFFFFF), static_cast(encoder.decode_int64(poly7))); - - Plaintext poly8 = encoder.encode(static_cast(0x8000000000000)); - ASSERT_EQ(52ULL, poly8.significant_coeff_count()); - ASSERT_TRUE(poly8[51] == 1); - for (size_t i = 0; i < 51; ++i) - { - ASSERT_TRUE(poly8[i] == 0); - } - ASSERT_EQ(static_cast(0x8000000000000), static_cast(encoder.decode_int64(poly8))); - - Plaintext poly9 = encoder.encode(static_cast(0x80F02)); - ASSERT_EQ(20ULL, poly9.significant_coeff_count()); - for (size_t i = 0; i < 20; ++i) - { - if (i == 19 || (i >= 8 && i <= 11) || i == 1) - { - ASSERT_TRUE(1 == poly9[i]); - } - else - { - ASSERT_TRUE(poly9[i] == 0); - } - } - ASSERT_EQ(static_cast(0x80F02), static_cast(encoder.decode_int64(poly9))); - - Plaintext poly10 = encoder.encode(static_cast(-1073)); - ASSERT_EQ(11ULL, poly10.significant_coeff_count()); - ASSERT_TRUE(0x7FFFFFFFFFFFE == poly10[10]); - ASSERT_TRUE(poly10[9] == 0); - ASSERT_TRUE(poly10[8] == 0); - ASSERT_TRUE(poly10[7] == 0); - ASSERT_TRUE(poly10[6] == 0); - ASSERT_TRUE(0x7FFFFFFFFFFFE == poly10[5]); - ASSERT_TRUE(0x7FFFFFFFFFFFE == poly10[4]); - ASSERT_TRUE(poly10[3] == 0); - ASSERT_TRUE(poly10[2] == 0); - ASSERT_TRUE(poly10[1] == 0); - ASSERT_TRUE(0x7FFFFFFFFFFFE == poly10[0]); - ASSERT_EQ(static_cast(-1073), static_cast(encoder.decode_int64(poly10))); - - modulus = 0xFFFF; - parms.set_plain_modulus(modulus); - auto context2 = SEALContext::Create(parms); - IntegerEncoder encoder2(context2); - Plaintext poly11(6); - poly11[0] = 1; - poly11[1] = 0xFFFE; // -1 - poly11[2] = 0xFFFD; // -2 - poly11[3] = 0x8000; // -32767 - poly11[4] = 0x7FFF; // 32767 - poly11[5] = 0x7FFE; // 32766 - ASSERT_EQ(static_cast(1 + -1 * 2 + -2 * 4 + -32767 * 8 + 32767 * 16 + 32766 * 32), static_cast(encoder2.decode_int64(poly11))); - } - - TEST(Encoder, IntEncodeDecodeInt32) - { - SmallModulus modulus(0x7FFFFFFFFFFFFF); - EncryptionParameters parms(scheme_type::BFV); - parms.set_plain_modulus(modulus); - auto context = SEALContext::Create(parms); - IntegerEncoder encoder(context); - - Plaintext poly = encoder.encode(static_cast(0)); - ASSERT_EQ(0ULL, poly.significant_coeff_count()); - ASSERT_TRUE(poly.is_zero()); - ASSERT_EQ(static_cast(0), encoder.decode_int32(poly)); - - Plaintext poly1 = encoder.encode(static_cast(1)); - ASSERT_EQ(1ULL, poly1.significant_coeff_count()); - ASSERT_TRUE("1" == poly1.to_string()); - ASSERT_EQ(static_cast(1), encoder.decode_int32(poly1)); - - Plaintext poly2 = encoder.encode(static_cast(2)); - ASSERT_EQ(2ULL, poly2.significant_coeff_count()); - ASSERT_TRUE("1x^1" == poly2.to_string()); - ASSERT_EQ(static_cast(2), encoder.decode_int32(poly2)); - - Plaintext poly3 = encoder.encode(static_cast(3)); - ASSERT_EQ(2ULL, poly3.significant_coeff_count()); - ASSERT_TRUE("1x^1 + 1" == poly3.to_string()); - ASSERT_EQ(static_cast(3), encoder.decode_int32(poly3)); - - Plaintext poly4 = encoder.encode(static_cast(-1)); - ASSERT_EQ(1ULL, poly4.significant_coeff_count()); - ASSERT_TRUE("7FFFFFFFFFFFFE" == poly4.to_string()); - ASSERT_EQ(static_cast(-1), encoder.decode_int32(poly4)); - - Plaintext poly5 = encoder.encode(static_cast(-2)); - ASSERT_EQ(2ULL, poly5.significant_coeff_count()); - ASSERT_TRUE("7FFFFFFFFFFFFEx^1" == poly5.to_string()); - ASSERT_EQ(static_cast(-2), encoder.decode_int32(poly5)); - - Plaintext poly6 = encoder.encode(static_cast(-3)); - ASSERT_EQ(2ULL, poly6.significant_coeff_count()); - ASSERT_TRUE("7FFFFFFFFFFFFEx^1 + 7FFFFFFFFFFFFE" == poly6.to_string()); - ASSERT_EQ(static_cast(-3), encoder.decode_int32(poly6)); - - Plaintext poly7 = encoder.encode(static_cast(0x7FFFFFFF)); - ASSERT_EQ(31ULL, poly7.significant_coeff_count()); - for (size_t i = 0; i < 31; ++i) - { - ASSERT_TRUE(1 == poly7[i]); - } - ASSERT_EQ(static_cast(0x7FFFFFFF), encoder.decode_int32(poly7)); - - Plaintext poly8 = encoder.encode(static_cast(0x80000000)); - ASSERT_EQ(32ULL, poly8.significant_coeff_count()); - ASSERT_TRUE(0x7FFFFFFFFFFFFE == poly8[31]); - for (size_t i = 0; i < 31; ++i) - { - ASSERT_TRUE(poly8[i] == 0); - } - ASSERT_EQ(static_cast(0x80000000), encoder.decode_int32(poly8)); - - Plaintext poly9 = encoder.encode(static_cast(0x80F02)); - ASSERT_EQ(20ULL, poly9.significant_coeff_count()); - for (size_t i = 0; i < 20; ++i) - { - if (i == 19 || (i >= 8 && i <= 11) || i == 1) - { - ASSERT_TRUE(1 == poly9[i]); - } - else - { - ASSERT_TRUE(poly9[i] == 0); - } - } - ASSERT_EQ(static_cast(0x80F02), encoder.decode_int32(poly9)); - - Plaintext poly10 = encoder.encode(static_cast(-1073)); - ASSERT_EQ(11ULL, poly10.significant_coeff_count()); - ASSERT_TRUE(0x7FFFFFFFFFFFFE == poly10[10]); - ASSERT_TRUE(poly10[9] == 0); - ASSERT_TRUE(poly10[8] == 0); - ASSERT_TRUE(poly10[7] == 0); - ASSERT_TRUE(poly10[6] == 0); - ASSERT_TRUE(0x7FFFFFFFFFFFFE == poly10[5]); - ASSERT_TRUE(0x7FFFFFFFFFFFFE == poly10[4]); - ASSERT_TRUE(poly10[3] == 0); - ASSERT_TRUE(poly10[2] == 0); - ASSERT_TRUE(poly10[1] == 0); - ASSERT_TRUE(0x7FFFFFFFFFFFFE == poly10[0]); - ASSERT_EQ(static_cast(-1073), encoder.decode_int32(poly10)); - - modulus = 0xFFFF; - parms.set_plain_modulus(modulus); - auto context2 = SEALContext::Create(parms); - IntegerEncoder encoder2(context2); - Plaintext poly11(6); - poly11[0] = 1; - poly11[1] = 0xFFFE; // -1 - poly11[2] = 0xFFFD; // -2 - poly11[3] = 0x8000; // -32767 - poly11[4] = 0x7FFF; // 32767 - poly11[5] = 0x7FFE; // 32766 - ASSERT_EQ(static_cast(1 + -1 * 2 + -2 * 4 + -32767 * 8 + 32767 * 16 + 32766 * 32), encoder2.decode_int32(poly11)); - } -} diff --git a/SEAL/native/tests/seal/keygenerator.cpp b/SEAL/native/tests/seal/keygenerator.cpp deleted file mode 100644 index 6af58c7..0000000 --- a/SEAL/native/tests/seal/keygenerator.cpp +++ /dev/null @@ -1,323 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/context.h" -#include "seal/keygenerator.h" -#include "seal/util/polycore.h" -#include "seal/encryptor.h" -#include "seal/decryptor.h" -#include "seal/evaluator.h" - -using namespace seal; -using namespace seal::util; -using namespace std; - -namespace SEALTest -{ - TEST(KeyGeneratorTest, BFVKeyGeneration) - { - EncryptionParameters parms(scheme_type::BFV); - { - parms.set_poly_modulus_degree(64); - parms.set_plain_modulus(1 << 6); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 60 })); - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - RelinKeys evk = keygen.relin_keys(); - ASSERT_TRUE(evk.parms_id() == context->key_parms_id()); - ASSERT_EQ(1ULL, evk.key(2).size()); - for (size_t j = 0; j < evk.size(); j++) - { - for (size_t i = 0; i < evk.key(j + 2).size(); i++) - { - for (size_t k = 0; k < evk.key(j + 2)[i].data().size(); k++) - { - ASSERT_FALSE(is_zero_poly(evk.key(j + 2)[i].data().data(k), evk.key(j + 2)[i].data().poly_modulus_degree(), evk.key(j + 2)[i].data().coeff_mod_count())); - } - } - } - - GaloisKeys galks = keygen.galois_keys(); - ASSERT_TRUE(galks.parms_id() == context->key_parms_id()); - ASSERT_EQ(1ULL, galks.key(3).size()); - ASSERT_EQ(10ULL, galks.size()); - - galks = keygen.galois_keys(vector{ 1, 3, 5, 7 }); - ASSERT_TRUE(galks.parms_id() == context->key_parms_id()); - ASSERT_TRUE(galks.has_key(1)); - ASSERT_TRUE(galks.has_key(3)); - ASSERT_TRUE(galks.has_key(5)); - ASSERT_TRUE(galks.has_key(7)); - ASSERT_FALSE(galks.has_key(9)); - ASSERT_FALSE(galks.has_key(127)); - ASSERT_EQ(1ULL, galks.key(1).size()); - ASSERT_EQ(1ULL, galks.key(3).size()); - ASSERT_EQ(1ULL, galks.key(5).size()); - ASSERT_EQ(1ULL, galks.key(7).size()); - ASSERT_EQ(4ULL, galks.size()); - - galks = keygen.galois_keys(vector{ 1 }); - ASSERT_TRUE(galks.parms_id() == context->key_parms_id()); - ASSERT_TRUE(galks.has_key(1)); - ASSERT_FALSE(galks.has_key(3)); - ASSERT_FALSE(galks.has_key(127)); - ASSERT_EQ(1ULL, galks.key(1).size()); - ASSERT_EQ(1ULL, galks.size()); - - galks = keygen.galois_keys(vector{ 127 }); - ASSERT_TRUE(galks.parms_id() == context->key_parms_id()); - ASSERT_FALSE(galks.has_key(1)); - ASSERT_TRUE(galks.has_key(127)); - ASSERT_EQ(1ULL, galks.key(127).size()); - ASSERT_EQ(1ULL, galks.size()); - } - { - parms.set_poly_modulus_degree(256); - parms.set_plain_modulus(1 << 6); - parms.set_coeff_modulus(CoeffModulus::Create(256, { 60, 30, 30 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - RelinKeys evk = keygen.relin_keys(); - ASSERT_TRUE(evk.parms_id() == context->key_parms_id()); - ASSERT_EQ(2ULL, evk.key(2).size()); - for (size_t j = 0; j < evk.size(); j++) - { - for (size_t i = 0; i < evk.key(j + 2).size(); i++) - { - for (size_t k = 0; k < evk.key(j + 2)[i].data().size(); k++) - { - ASSERT_FALSE(is_zero_poly(evk.key(j + 2)[i].data().data(k), evk.key(j + 2)[i].data().poly_modulus_degree(), evk.key(j + 2)[i].data().coeff_mod_count())); - } - } - } - - GaloisKeys galks = keygen.galois_keys(); - ASSERT_TRUE(galks.parms_id() == context->key_parms_id()); - ASSERT_EQ(2ULL, galks.key(3).size()); - ASSERT_EQ(14ULL, galks.size()); - - galks = keygen.galois_keys(vector{ 1, 3, 5, 7 }); - ASSERT_TRUE(galks.parms_id() == context->key_parms_id()); - ASSERT_TRUE(galks.has_key(1)); - ASSERT_TRUE(galks.has_key(3)); - ASSERT_TRUE(galks.has_key(5)); - ASSERT_TRUE(galks.has_key(7)); - ASSERT_FALSE(galks.has_key(9)); - ASSERT_FALSE(galks.has_key(511)); - ASSERT_EQ(2ULL, galks.key(1).size()); - ASSERT_EQ(2ULL, galks.key(3).size()); - ASSERT_EQ(2ULL, galks.key(5).size()); - ASSERT_EQ(2ULL, galks.key(7).size()); - ASSERT_EQ(4ULL, galks.size()); - - galks = keygen.galois_keys(vector{ 1 }); - ASSERT_TRUE(galks.parms_id() == context->key_parms_id()); - ASSERT_TRUE(galks.has_key(1)); - ASSERT_FALSE(galks.has_key(3)); - ASSERT_FALSE(galks.has_key(511)); - ASSERT_EQ(2ULL, galks.key(1).size()); - ASSERT_EQ(1ULL, galks.size()); - - galks = keygen.galois_keys(vector{ 511 }); - ASSERT_TRUE(galks.parms_id() == context->key_parms_id()); - ASSERT_FALSE(galks.has_key(1)); - ASSERT_TRUE(galks.has_key(511)); - ASSERT_EQ(2ULL, galks.key(511).size()); - ASSERT_EQ(1ULL, galks.size()); - } - } - - TEST(KeyGeneratorTest, CKKSKeyGeneration) - { - EncryptionParameters parms(scheme_type::CKKS); - { - parms.set_poly_modulus_degree(64); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 60 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - RelinKeys evk = keygen.relin_keys(); - ASSERT_TRUE(evk.parms_id() == context->key_parms_id()); - ASSERT_EQ(1ULL, evk.key(2).size()); - for (size_t j = 0; j < evk.size(); j++) - { - for (size_t i = 0; i < evk.key(j + 2).size(); i++) - { - for (size_t k = 0; k < evk.key(j + 2)[i].data().size(); k++) - { - ASSERT_FALSE(is_zero_poly(evk.key(j + 2)[i].data().data(k), evk.key(j + 2)[i].data().poly_modulus_degree(), evk.key(j + 2)[i].data().coeff_mod_count())); - } - } - } - - GaloisKeys galks = keygen.galois_keys(); - ASSERT_TRUE(galks.parms_id() == context->key_parms_id()); - ASSERT_EQ(1ULL, galks.key(3).size()); - ASSERT_EQ(10ULL, galks.size()); - - galks = keygen.galois_keys(vector{ 1, 3, 5, 7 }); - ASSERT_TRUE(galks.parms_id() == context->key_parms_id()); - ASSERT_TRUE(galks.has_key(1)); - ASSERT_TRUE(galks.has_key(3)); - ASSERT_TRUE(galks.has_key(5)); - ASSERT_TRUE(galks.has_key(7)); - ASSERT_FALSE(galks.has_key(9)); - ASSERT_FALSE(galks.has_key(127)); - ASSERT_EQ(1ULL, galks.key(1).size()); - ASSERT_EQ(1ULL, galks.key(3).size()); - ASSERT_EQ(1ULL, galks.key(5).size()); - ASSERT_EQ(1ULL, galks.key(7).size()); - ASSERT_EQ(4ULL, galks.size()); - - galks = keygen.galois_keys(vector{ 1 }); - ASSERT_TRUE(galks.parms_id() == context->key_parms_id()); - ASSERT_TRUE(galks.has_key(1)); - ASSERT_FALSE(galks.has_key(3)); - ASSERT_FALSE(galks.has_key(127)); - ASSERT_EQ(1ULL, galks.key(1).size()); - ASSERT_EQ(1ULL, galks.size()); - - galks = keygen.galois_keys(vector{ 127 }); - ASSERT_TRUE(galks.parms_id() == context->key_parms_id()); - ASSERT_FALSE(galks.has_key(1)); - ASSERT_TRUE(galks.has_key(127)); - ASSERT_EQ(1ULL, galks.key(127).size()); - ASSERT_EQ(1ULL, galks.size()); - } - { - parms.set_poly_modulus_degree(256); - parms.set_coeff_modulus(CoeffModulus::Create(256, { 60, 30, 30 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - RelinKeys evk = keygen.relin_keys(); - ASSERT_TRUE(evk.parms_id() == context->key_parms_id()); - ASSERT_EQ(2ULL, evk.key(2).size()); - for (size_t j = 0; j < evk.size(); j++) - { - for (size_t i = 0; i < evk.key(j + 2).size(); i++) - { - for (size_t k = 0; k < evk.key(j + 2)[i].data().size(); k++) - { - ASSERT_FALSE(is_zero_poly(evk.key(j + 2)[i].data().data(k), evk.key(j + 2)[i].data().poly_modulus_degree(), evk.key(j + 2)[i].data().coeff_mod_count())); - } - } - } - - GaloisKeys galks = keygen.galois_keys(); - ASSERT_TRUE(galks.parms_id() == context->key_parms_id()); - ASSERT_EQ(2ULL, galks.key(3).size()); - ASSERT_EQ(14ULL, galks.size()); - - galks = keygen.galois_keys(vector{ 1, 3, 5, 7 }); - ASSERT_TRUE(galks.parms_id() == context->key_parms_id()); - ASSERT_TRUE(galks.has_key(1)); - ASSERT_TRUE(galks.has_key(3)); - ASSERT_TRUE(galks.has_key(5)); - ASSERT_TRUE(galks.has_key(7)); - ASSERT_FALSE(galks.has_key(9)); - ASSERT_FALSE(galks.has_key(511)); - ASSERT_EQ(2ULL, galks.key(1).size()); - ASSERT_EQ(2ULL, galks.key(3).size()); - ASSERT_EQ(2ULL, galks.key(5).size()); - ASSERT_EQ(2ULL, galks.key(7).size()); - ASSERT_EQ(4ULL, galks.size()); - - galks = keygen.galois_keys(vector{ 1 }); - ASSERT_TRUE(galks.parms_id() == context->key_parms_id()); - ASSERT_TRUE(galks.has_key(1)); - ASSERT_FALSE(galks.has_key(3)); - ASSERT_FALSE(galks.has_key(511)); - ASSERT_EQ(2ULL, galks.key(1).size()); - ASSERT_EQ(1ULL, galks.size()); - - galks = keygen.galois_keys(vector{ 511 }); - ASSERT_TRUE(galks.parms_id() == context->key_parms_id()); - ASSERT_FALSE(galks.has_key(1)); - ASSERT_TRUE(galks.has_key(511)); - ASSERT_EQ(2ULL, galks.key(511).size()); - ASSERT_EQ(1ULL, galks.size()); - } - } - - TEST(KeyGeneratorTest, Constructors) - { - EncryptionParameters parms(scheme_type::BFV); - parms.set_poly_modulus_degree(128); - parms.set_plain_modulus(1 << 6); - parms.set_coeff_modulus(CoeffModulus::Create(128, { 60, 50, 40 })); - auto context = SEALContext::Create(parms, false, sec_level_type::none); - Evaluator evaluator(context); - - KeyGenerator keygen(context); - auto pk = keygen.public_key(); - auto sk = keygen.secret_key(); - RelinKeys rlk = keygen.relin_keys(); - GaloisKeys galk = keygen.galois_keys(); - - ASSERT_TRUE(is_valid_for(rlk, context)); - ASSERT_TRUE(is_valid_for(galk, context)); - - Encryptor encryptor(context, pk); - Decryptor decryptor(context, sk); - Plaintext pt("1x^2 + 2"), ptres; - Ciphertext ct; - encryptor.encrypt(pt, ct); - evaluator.square_inplace(ct); - evaluator.relinearize_inplace(ct, rlk); - decryptor.decrypt(ct, ptres); - ASSERT_EQ("1x^4 + 4x^2 + 4", ptres.to_string()); - - KeyGenerator keygen2(context, sk); - auto sk2 = keygen.secret_key(); - auto pk2 = keygen2.public_key(); - ASSERT_EQ(sk2.data(), sk.data()); - - RelinKeys rlk2 = keygen2.relin_keys(); - GaloisKeys galk2 = keygen2.galois_keys(); - - ASSERT_TRUE(is_valid_for(rlk2, context)); - ASSERT_TRUE(is_valid_for(galk2, context)); - - Encryptor encryptor2(context, pk2); - Decryptor decryptor2(context, sk2); - pt = "1x^2 + 2"; - ptres.set_zero(); - encryptor.encrypt(pt, ct); - evaluator.square_inplace(ct); - evaluator.relinearize_inplace(ct, rlk2); - decryptor.decrypt(ct, ptres); - ASSERT_EQ("1x^4 + 4x^2 + 4", ptres.to_string()); - - KeyGenerator keygen3(context, sk2, pk2); - auto sk3 = keygen3.secret_key(); - auto pk3 = keygen3.public_key(); - ASSERT_EQ(sk3.data(), sk2.data()); - for (size_t i = 0; i < pk3.data().uint64_count(); i++) - { - ASSERT_EQ(pk3.data().data()[i], pk2.data().data()[i]); - } - - RelinKeys rlk3 = keygen3.relin_keys(); - GaloisKeys galk3 = keygen3.galois_keys(); - - ASSERT_TRUE(is_valid_for(rlk3, context)); - ASSERT_TRUE(is_valid_for(galk3, context)); - - Encryptor encryptor3(context, pk3); - Decryptor decryptor3(context, sk3); - pt = "1x^2 + 2"; - ptres.set_zero(); - encryptor.encrypt(pt, ct); - evaluator.square_inplace(ct); - evaluator.relinearize_inplace(ct, rlk3); - decryptor.decrypt(ct, ptres); - ASSERT_EQ("1x^4 + 4x^2 + 4", ptres.to_string()); - } -} \ No newline at end of file diff --git a/SEAL/native/tests/seal/memorymanager.cpp b/SEAL/native/tests/seal/memorymanager.cpp deleted file mode 100644 index 08fb4c1..0000000 --- a/SEAL/native/tests/seal/memorymanager.cpp +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/util/pointer.h" -#include "seal/memorymanager.h" -#include "seal/intarray.h" -#include "seal/util/uintcore.h" - -using namespace seal; -using namespace seal::util; -using namespace std; - -namespace SEALTest -{ - TEST(MemoryPoolHandleTest, MemoryPoolHandleConstructAssign) - { - MemoryPoolHandle pool; - ASSERT_FALSE(pool); - pool = MemoryPoolHandle::Global(); - ASSERT_TRUE(&static_cast(pool) == global_variables::global_memory_pool.get()); - pool = MemoryPoolHandle::New(); - ASSERT_FALSE(&pool.operator seal::util::MemoryPool &() == global_variables::global_memory_pool.get()); - MemoryPoolHandle pool2 = MemoryPoolHandle::New(); - ASSERT_FALSE(pool == pool2); - - pool = pool2; - ASSERT_TRUE(pool == pool2); - pool = MemoryPoolHandle::Global(); - ASSERT_FALSE(pool == pool2); - pool2 = MemoryPoolHandle::Global(); - ASSERT_TRUE(pool == pool2); - } - - TEST(MemoryPoolHandleTest, MemoryPoolHandleAllocate) - { - MemoryPoolHandle pool = MemoryPoolHandle::New(); - ASSERT_TRUE(0LL == pool.alloc_byte_count()); - { - auto ptr(allocate_uint(5, pool)); - ASSERT_TRUE(5LL * bytes_per_uint64 == pool.alloc_byte_count()); - } - - pool = MemoryPoolHandle::New(); - ASSERT_TRUE(0LL * bytes_per_uint64 == pool.alloc_byte_count()); - { - auto ptr(allocate_uint(5, pool)); - ASSERT_TRUE(5LL * bytes_per_uint64 == pool.alloc_byte_count()); - - ptr = allocate_uint(8, pool); - ASSERT_TRUE(13LL * bytes_per_uint64 == pool.alloc_byte_count()); - - auto ptr2(allocate_uint(2, pool)); - ASSERT_TRUE(15LL * bytes_per_uint64 == pool.alloc_byte_count()); - } - } - - TEST(MemoryPoolHandleTest, UseCount) - { - MemoryPoolHandle pool = MemoryPoolHandle::New(); - ASSERT_EQ(1L, pool.use_count()); - { - IntArray arr(pool); - ASSERT_EQ(2L, pool.use_count()); - IntArray arr2(pool); - ASSERT_EQ(3L, pool.use_count()); - } - ASSERT_EQ(1L, pool.use_count()); - } -} \ No newline at end of file diff --git a/SEAL/native/tests/seal/modulus.cpp b/SEAL/native/tests/seal/modulus.cpp deleted file mode 100644 index 1c42de9..0000000 --- a/SEAL/native/tests/seal/modulus.cpp +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include -#include "gtest/gtest.h" -#include "seal/modulus.h" -#include "seal/util/uintcore.h" - -using namespace seal; -using namespace seal::util; -using namespace std; - -namespace SEALTest -{ - TEST(CoeffModTest, CustomExceptionTest) - { - // Too small poly_modulus_degree - ASSERT_THROW(auto modulus = CoeffModulus::Create(1, { 2 }), invalid_argument); - - // Too large poly_modulus_degree - ASSERT_THROW(auto modulus = CoeffModulus::Create(65536, { 30 }), invalid_argument); - - // Invalid poly_modulus_degree - ASSERT_THROW(auto modulus = CoeffModulus::Create(1023, { 20 }), invalid_argument); - - // Invalid bit-size - ASSERT_THROW(auto modulus = CoeffModulus::Create(2048, { 0 }), invalid_argument); - ASSERT_THROW(auto modulus = CoeffModulus::Create(2048, { -30 }), invalid_argument); - ASSERT_THROW(auto modulus = CoeffModulus::Create(2048, { 30, -30 }), invalid_argument); - - // Too small primes requested - ASSERT_THROW(auto modulus = CoeffModulus::Create(2, { 2 }), logic_error); - ASSERT_THROW(auto modulus = CoeffModulus::Create(2, { 3, 3, 3 }), logic_error); - ASSERT_THROW(auto modulus = CoeffModulus::Create(1024, { 8 }), logic_error); - } - - TEST(CoeffModTest, CustomTest) - { - auto cm = CoeffModulus::Create(2, { }); - ASSERT_EQ(0, cm.size()); - - cm = CoeffModulus::Create(2, { 3 }); - ASSERT_EQ(1, cm.size()); - ASSERT_EQ(uint64_t(5), cm[0].value()); - - cm = CoeffModulus::Create(2, { 3, 4 }); - ASSERT_EQ(2, cm.size()); - ASSERT_EQ(uint64_t(5), cm[0].value()); - ASSERT_EQ(uint64_t(13), cm[1].value()); - - cm = CoeffModulus::Create(2, { 3, 5, 4, 5 }); - ASSERT_EQ(4, cm.size()); - ASSERT_EQ(uint64_t(5), cm[0].value()); - ASSERT_EQ(uint64_t(17), cm[1].value()); - ASSERT_EQ(uint64_t(13), cm[2].value()); - ASSERT_EQ(uint64_t(29), cm[3].value()); - - cm = CoeffModulus::Create(32, { 30, 40, 30, 30, 40 }); - ASSERT_EQ(5, cm.size()); - ASSERT_EQ(30, get_significant_bit_count(cm[0].value())); - ASSERT_EQ(40, get_significant_bit_count(cm[1].value())); - ASSERT_EQ(30, get_significant_bit_count(cm[2].value())); - ASSERT_EQ(30, get_significant_bit_count(cm[3].value())); - ASSERT_EQ(40, get_significant_bit_count(cm[4].value())); - ASSERT_EQ(1ULL, cm[0].value() % 64); - ASSERT_EQ(1ULL, cm[1].value() % 64); - ASSERT_EQ(1ULL, cm[2].value() % 64); - ASSERT_EQ(1ULL, cm[3].value() % 64); - ASSERT_EQ(1ULL, cm[4].value() % 64); - } -} diff --git a/SEAL/native/tests/seal/plaintext.cpp b/SEAL/native/tests/seal/plaintext.cpp deleted file mode 100644 index 9e4f821..0000000 --- a/SEAL/native/tests/seal/plaintext.cpp +++ /dev/null @@ -1,158 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include -#include "gtest/gtest.h" -#include "seal/plaintext.h" -#include "seal/evaluator.h" -#include "seal/context.h" -#include "seal/memorymanager.h" -#include "seal/modulus.h" -#include "seal/ckks.h" - -using namespace seal; -using namespace seal::util; -using namespace std; - -namespace SEALTest -{ - TEST(PlaintextTest, PlaintextBasics) - { - Plaintext plain(2); - ASSERT_EQ(2ULL, plain.capacity()); - ASSERT_EQ(2ULL, plain.coeff_count()); - ASSERT_EQ(0ULL, plain.significant_coeff_count()); - ASSERT_EQ(0ULL, plain.nonzero_coeff_count()); - ASSERT_FALSE(plain.is_ntt_form()); - - plain[0] = 1; - plain[1] = 2; - - plain.reserve(10); - ASSERT_EQ(10ULL, plain.capacity()); - ASSERT_EQ(2ULL, plain.coeff_count()); - ASSERT_EQ(2ULL, plain.significant_coeff_count()); - ASSERT_EQ(2ULL, plain.nonzero_coeff_count()); - ASSERT_EQ(1ULL, plain[0]); - ASSERT_EQ(2ULL, plain[1]); - ASSERT_FALSE(plain.is_ntt_form()); - - plain.resize(5); - ASSERT_EQ(10ULL, plain.capacity()); - ASSERT_EQ(5ULL, plain.coeff_count()); - ASSERT_EQ(2ULL, plain.significant_coeff_count()); - ASSERT_EQ(2ULL, plain.nonzero_coeff_count()); - ASSERT_EQ(1ULL, plain[0]); - ASSERT_EQ(2ULL, plain[1]); - ASSERT_EQ(0ULL, plain[2]); - ASSERT_EQ(0ULL, plain[3]); - ASSERT_EQ(0ULL, plain[4]); - ASSERT_FALSE(plain.is_ntt_form()); - - Plaintext plain2; - plain2.resize(15); - ASSERT_EQ(15ULL, plain2.capacity()); - ASSERT_EQ(15ULL, plain2.coeff_count()); - ASSERT_EQ(0ULL, plain2.significant_coeff_count()); - ASSERT_EQ(0ULL, plain2.nonzero_coeff_count()); - ASSERT_FALSE(plain.is_ntt_form()); - - plain2 = plain; - ASSERT_EQ(15ULL, plain2.capacity()); - ASSERT_EQ(5ULL, plain2.coeff_count()); - ASSERT_EQ(2ULL, plain2.significant_coeff_count()); - ASSERT_EQ(2ULL, plain2.nonzero_coeff_count()); - ASSERT_EQ(1ULL, plain2[0]); - ASSERT_EQ(2ULL, plain2[1]); - ASSERT_EQ(0ULL, plain2[2]); - ASSERT_EQ(0ULL, plain2[3]); - ASSERT_EQ(0ULL, plain2[4]); - ASSERT_FALSE(plain.is_ntt_form()); - - plain.parms_id() = { 1ULL, 2ULL, 3ULL, 4ULL }; - ASSERT_TRUE(plain.is_ntt_form()); - plain2 = plain; - ASSERT_TRUE(plain == plain2); - plain2.parms_id() = parms_id_zero; - ASSERT_FALSE(plain2.is_ntt_form()); - ASSERT_FALSE(plain == plain2); - plain2.parms_id() = { 1ULL, 2ULL, 3ULL, 5ULL }; - ASSERT_FALSE(plain == plain2); - } - - TEST(PlaintextTest, SaveLoadPlaintext) - { - stringstream stream; - - Plaintext plain; - Plaintext plain2; - plain.save(stream); - plain2.unsafe_load(stream); - ASSERT_TRUE(plain.data() == plain2.data()); - ASSERT_TRUE(plain2.data() == nullptr); - ASSERT_EQ(0ULL, plain2.capacity()); - ASSERT_EQ(0ULL, plain2.coeff_count()); - ASSERT_FALSE(plain2.is_ntt_form()); - - plain.reserve(20); - plain.resize(5); - plain[0] = 1; - plain[1] = 2; - plain[2] = 3; - plain.save(stream); - plain2.unsafe_load(stream); - ASSERT_TRUE(plain.data() != plain2.data()); - ASSERT_EQ(5ULL, plain2.capacity()); - ASSERT_EQ(5ULL, plain2.coeff_count()); - ASSERT_EQ(1ULL, plain2[0]); - ASSERT_EQ(2ULL, plain2[1]); - ASSERT_EQ(3ULL, plain2[2]); - ASSERT_EQ(0ULL, plain2[3]); - ASSERT_EQ(0ULL, plain2[4]); - ASSERT_FALSE(plain2.is_ntt_form()); - - plain.parms_id() = { 1, 2, 3, 4 }; - plain.save(stream); - plain2.unsafe_load(stream); - ASSERT_TRUE(plain2.is_ntt_form()); - ASSERT_TRUE(plain2.parms_id() == plain.parms_id()); - - { - EncryptionParameters parms(scheme_type::BFV); - parms.set_poly_modulus_degree(64); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 30, 30 })); - - parms.set_plain_modulus(65537); - auto context = SEALContext::Create(parms, false, sec_level_type::none); - - plain.parms_id() = parms_id_zero; - plain = "1x^63 + 2x^62 + Fx^32 + Ax^9 + 1x^1 + 1"; - plain.save(stream); - plain2.load(context, stream); - ASSERT_TRUE(plain.data() != plain2.data()); - ASSERT_FALSE(plain2.is_ntt_form()); - - Evaluator evaluator(context); - evaluator.transform_to_ntt_inplace(plain, context->first_parms_id()); - plain.save(stream); - plain2.load(context, stream); - ASSERT_TRUE(plain.data() != plain2.data()); - ASSERT_TRUE(plain2.is_ntt_form()); - } - { - EncryptionParameters parms(scheme_type::CKKS); - parms.set_poly_modulus_degree(64); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 30, 30 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - CKKSEncoder encoder(context); - - encoder.encode(vector{ 0.1, 2.3, 34.4 }, pow(2.0, 20), plain); - ASSERT_TRUE(plain.is_ntt_form()); - plain.save(stream); - plain2.load(context, stream); - ASSERT_TRUE(plain.data() != plain2.data()); - ASSERT_TRUE(plain2.is_ntt_form()); - } - } -} diff --git a/SEAL/native/tests/seal/publickey.cpp b/SEAL/native/tests/seal/publickey.cpp deleted file mode 100644 index bec3a89..0000000 --- a/SEAL/native/tests/seal/publickey.cpp +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/publickey.h" -#include "seal/context.h" -#include "seal/modulus.h" -#include "seal/keygenerator.h" - -using namespace seal; -using namespace std; - -namespace SEALTest -{ - TEST(PublicKeyTest, SaveLoadPublicKey) - { - stringstream stream; - { - EncryptionParameters parms(scheme_type::BFV); - parms.set_poly_modulus_degree(64); - parms.set_plain_modulus(1 << 6); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 60 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - PublicKey pk = keygen.public_key(); - ASSERT_TRUE(pk.parms_id() == context->key_parms_id()); - pk.save(stream); - - PublicKey pk2; - pk2.load(context, stream); - - ASSERT_EQ(pk.data().uint64_count(), pk2.data().uint64_count()); - for (size_t i = 0; i < pk.data().uint64_count(); i++) - { - ASSERT_EQ(pk.data().data()[i], pk2.data().data()[i]); - } - ASSERT_TRUE(pk.parms_id() == pk2.parms_id()); - } - { - EncryptionParameters parms(scheme_type::BFV); - parms.set_poly_modulus_degree(256); - parms.set_plain_modulus(1 << 20); - parms.set_coeff_modulus(CoeffModulus::Create(256, { 30, 40 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - PublicKey pk = keygen.public_key(); - ASSERT_TRUE(pk.parms_id() == context->key_parms_id()); - pk.save(stream); - - PublicKey pk2; - pk2.load(context, stream); - - ASSERT_EQ(pk.data().uint64_count(), pk2.data().uint64_count()); - for (size_t i = 0; i < pk.data().uint64_count(); i++) - { - ASSERT_EQ(pk.data().data()[i], pk2.data().data()[i]); - } - ASSERT_TRUE(pk.parms_id() == pk2.parms_id()); - } - } -} \ No newline at end of file diff --git a/SEAL/native/tests/seal/randomgen.cpp b/SEAL/native/tests/seal/randomgen.cpp deleted file mode 100644 index 38624b6..0000000 --- a/SEAL/native/tests/seal/randomgen.cpp +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/randomgen.h" -#include "seal/keygenerator.h" -#include -#include -#include - -using namespace seal; -using namespace std; - -namespace SEALTest -{ - namespace - { - class CustomRandomEngine : public UniformRandomGenerator - { - public: - CustomRandomEngine() - { - } - - uint32_t generate() override - { - count_++; - return static_cast(engine_()); - } - - static int count() - { - return count_; - } - - private: - default_random_engine engine_; - - static int count_; - }; - - class CustomRandomEngineFactory : public UniformRandomGeneratorFactory - { - public: - shared_ptr create() override - { - return shared_ptr(new CustomRandomEngine()); - } - }; - - int CustomRandomEngine::count_ = 0; - } - - TEST(RandomGenerator, UniformRandomCreateDefault) - { - shared_ptr generator(UniformRandomGeneratorFactory::default_factory()->create()); - bool lower_half = false; - bool upper_half = false; - bool even = false; - bool odd = false; - for (int i = 0; i < 10; ++i) - { - uint32_t value = generator->generate(); - if (value < UINT32_MAX / 2) - { - lower_half = true; - } - else - { - upper_half = true; - } - if ((value % 2) == 0) - { - even = true; - } - else - { - odd = true; - } - } - ASSERT_TRUE(lower_half); - ASSERT_TRUE(upper_half); - ASSERT_TRUE(even); - ASSERT_TRUE(odd); - } - - TEST(RandomGenerator, StandardRandomAdapterGenerate) - { - StandardRandomAdapter generator; - bool lower_half = false; - bool upper_half = false; - bool even = false; - bool odd = false; - for (int i = 0; i < 10; ++i) - { - uint32_t value = generator.generate(); - if (value < UINT32_MAX / 2) - { - lower_half = true; - } - else - { - upper_half = true; - } - if ((value % 2) == 0) - { - even = true; - } - else - { - odd = true; - } - } - ASSERT_TRUE(lower_half); - ASSERT_TRUE(upper_half); - ASSERT_TRUE(even); - ASSERT_TRUE(odd); - } - - TEST(RandomGenerator, CustomRandomGenerator) - { - shared_ptr factory(new CustomRandomEngineFactory); - - EncryptionParameters parms(scheme_type::BFV); - uint64_t coeff_modulus; - SmallModulus plain_modulus; - coeff_modulus = 0xFFFFFFFFC001; - plain_modulus = 1 << 6; - parms.set_poly_modulus_degree(64); - parms.set_plain_modulus(plain_modulus); - parms.set_coeff_modulus({ coeff_modulus }); - parms.set_random_generator(factory); - auto context = SEALContext::Create(parms, false, sec_level_type::none); - - ASSERT_EQ(0, CustomRandomEngine::count()); - - KeyGenerator keygen(context); - - ASSERT_NE(0, CustomRandomEngine::count()); - } -} diff --git a/SEAL/native/tests/seal/randomtostd.cpp b/SEAL/native/tests/seal/randomtostd.cpp deleted file mode 100644 index 09c6185..0000000 --- a/SEAL/native/tests/seal/randomtostd.cpp +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/randomgen.h" -#include "seal/randomtostd.h" -#include -#include - -using namespace seal; -using namespace std; - -namespace SEALTest -{ - TEST(RandomToStandard, RandomToStandardGenerate) - { - shared_ptr generator(UniformRandomGeneratorFactory::default_factory()->create()); - RandomToStandardAdapter rand(generator); - ASSERT_TRUE(rand.generator() == generator); - ASSERT_EQ(static_cast(0), rand.min()); - ASSERT_EQ(static_cast(UINT32_MAX), rand.max()); - bool lower_half = false; - bool upper_half = false; - bool even = false; - bool odd = false; - for (int i = 0; i < 50; i++) - { - uint32_t value = rand(); - if (value < UINT32_MAX / 2) - { - lower_half = true; - } - else - { - upper_half = true; - } - if ((value % 2) == 0) - { - even = true; - } - else - { - odd = true; - } - } - ASSERT_TRUE(lower_half); - ASSERT_TRUE(upper_half); - ASSERT_TRUE(even); - ASSERT_TRUE(odd); - } -} diff --git a/SEAL/native/tests/seal/relinkeys.cpp b/SEAL/native/tests/seal/relinkeys.cpp deleted file mode 100644 index 7764f8e..0000000 --- a/SEAL/native/tests/seal/relinkeys.cpp +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/relinkeys.h" -#include "seal/context.h" -#include "seal/keygenerator.h" -#include "seal/util/uintcore.h" -#include "seal/modulus.h" - -using namespace seal; -using namespace seal::util; -using namespace std; - -namespace SEALTest -{ - TEST(RelinKeysTest, RelinKeysSaveLoad) - { - stringstream stream; - { - EncryptionParameters parms(scheme_type::BFV); - parms.set_poly_modulus_degree(64); - parms.set_plain_modulus(1 << 6); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 60 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - RelinKeys keys; - RelinKeys test_keys; - keys = keygen.relin_keys(); - keys.save(stream); - test_keys.load(context, stream); - ASSERT_EQ(keys.size(), test_keys.size()); - ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); - for (size_t j = 0; j < test_keys.size(); j++) - { - for (size_t i = 0; i < test_keys.key(j + 2).size(); i++) - { - ASSERT_EQ(keys.key(j + 2)[i].data().size(), test_keys.key(j + 2)[i].data().size()); - ASSERT_EQ(keys.key(j + 2)[i].data().uint64_count(), test_keys.key(j + 2)[i].data().uint64_count()); - ASSERT_TRUE(is_equal_uint_uint(keys.key(j + 2)[i].data().data(), test_keys.key(j + 2)[i].data().data(), keys.key(j + 2)[i].data().uint64_count())); - } - } - } - { - EncryptionParameters parms(scheme_type::BFV); - parms.set_poly_modulus_degree(256); - parms.set_plain_modulus(1 << 6); - parms.set_coeff_modulus(CoeffModulus::Create(256, { 60, 50 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - RelinKeys keys; - RelinKeys test_keys; - keys = keygen.relin_keys(); - keys.save(stream); - test_keys.load(context, stream); - ASSERT_EQ(keys.size(), test_keys.size()); - ASSERT_TRUE(keys.parms_id() == test_keys.parms_id()); - for (size_t j = 0; j < test_keys.size(); j++) - { - for (size_t i = 0; i < test_keys.key(j + 2).size(); i++) - { - ASSERT_EQ(keys.key(j + 2)[i].data().size(), test_keys.key(j + 2)[i].data().size()); - ASSERT_EQ(keys.key(j + 2)[i].data().uint64_count(), test_keys.key(j + 2)[i].data().uint64_count()); - ASSERT_TRUE(is_equal_uint_uint(keys.key(j + 2)[i].data().data(), test_keys.key(j + 2)[i].data().data(), keys.key(j + 2)[i].data().uint64_count())); - } - } - } - } -} \ No newline at end of file diff --git a/SEAL/native/tests/seal/secretkey.cpp b/SEAL/native/tests/seal/secretkey.cpp deleted file mode 100644 index 88a4e3a..0000000 --- a/SEAL/native/tests/seal/secretkey.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/secretkey.h" -#include "seal/context.h" -#include "seal/keygenerator.h" -#include "seal/modulus.h" - -using namespace seal; -using namespace std; - -namespace SEALTest -{ - TEST(SecretKeyTest, SaveLoadSecretKey) - { - stringstream stream; - { - EncryptionParameters parms(scheme_type::BFV); - parms.set_poly_modulus_degree(64); - parms.set_plain_modulus(1 << 6); - parms.set_coeff_modulus(CoeffModulus::Create(64, { 60 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - SecretKey sk = keygen.secret_key(); - ASSERT_TRUE(sk.parms_id() == context->key_parms_id()); - sk.save(stream); - - SecretKey sk2; - sk2.load(context, stream); - - ASSERT_TRUE(sk.data() == sk2.data()); - ASSERT_TRUE(sk.parms_id() == sk2.parms_id()); - } - { - EncryptionParameters parms(scheme_type::BFV); - parms.set_poly_modulus_degree(256); - parms.set_plain_modulus(1 << 20); - parms.set_coeff_modulus(CoeffModulus::Create(256, { 30, 40 })); - - auto context = SEALContext::Create(parms, false, sec_level_type::none); - KeyGenerator keygen(context); - - SecretKey sk = keygen.secret_key(); - ASSERT_TRUE(sk.parms_id() == context->key_parms_id()); - sk.save(stream); - - SecretKey sk2; - sk2.load(context, stream); - - ASSERT_TRUE(sk.data() == sk2.data()); - ASSERT_TRUE(sk.parms_id() == sk2.parms_id()); - } - } -} \ No newline at end of file diff --git a/SEAL/native/tests/seal/smallmodulus.cpp b/SEAL/native/tests/seal/smallmodulus.cpp deleted file mode 100644 index 50ed116..0000000 --- a/SEAL/native/tests/seal/smallmodulus.cpp +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/smallmodulus.h" - -using namespace seal; -using namespace std; - -namespace SEALTest -{ - TEST(SmallModulusTest, CreateSmallModulus) - { - SmallModulus mod; - ASSERT_TRUE(mod.is_zero()); - ASSERT_EQ(0ULL, mod.value()); - ASSERT_EQ(0, mod.bit_count()); - ASSERT_EQ(1ULL, mod.uint64_count()); - ASSERT_EQ(0ULL, mod.const_ratio()[0]); - ASSERT_EQ(0ULL, mod.const_ratio()[1]); - ASSERT_EQ(0ULL, mod.const_ratio()[2]); - ASSERT_FALSE(mod.is_prime()); - - mod = 3; - ASSERT_FALSE(mod.is_zero()); - ASSERT_EQ(3ULL, mod.value()); - ASSERT_EQ(2, mod.bit_count()); - ASSERT_EQ(1ULL, mod.uint64_count()); - ASSERT_EQ(6148914691236517205ULL, mod.const_ratio()[0]); - ASSERT_EQ(6148914691236517205ULL, mod.const_ratio()[1]); - ASSERT_EQ(1ULL, mod.const_ratio()[2]); - ASSERT_TRUE(mod.is_prime()); - - - SmallModulus mod2(2); - SmallModulus mod3(3); - ASSERT_TRUE(mod != mod2); - ASSERT_TRUE(mod == mod3); - - mod = 0; - ASSERT_TRUE(mod.is_zero()); - ASSERT_EQ(0ULL, mod.value()); - ASSERT_EQ(0, mod.bit_count()); - ASSERT_EQ(1ULL, mod.uint64_count()); - ASSERT_EQ(0ULL, mod.const_ratio()[0]); - ASSERT_EQ(0ULL, mod.const_ratio()[1]); - ASSERT_EQ(0ULL, mod.const_ratio()[2]); - - mod = 0xF00000F00000F; - ASSERT_FALSE(mod.is_zero()); - ASSERT_EQ(0xF00000F00000FULL, mod.value()); - ASSERT_EQ(52, mod.bit_count()); - ASSERT_EQ(1ULL, mod.uint64_count()); - ASSERT_EQ(1224979098644774929ULL, mod.const_ratio()[0]); - ASSERT_EQ(4369ULL, mod.const_ratio()[1]); - ASSERT_EQ(281470698520321ULL, mod.const_ratio()[2]); - ASSERT_FALSE(mod.is_prime()); - - mod = 0xF00000F000079; - ASSERT_FALSE(mod.is_zero()); - ASSERT_EQ(0xF00000F000079ULL, mod.value()); - ASSERT_EQ(52, mod.bit_count()); - ASSERT_EQ(1ULL, mod.uint64_count()); - ASSERT_EQ(1224979096621368355ULL, mod.const_ratio()[0]); - ASSERT_EQ(4369ULL, mod.const_ratio()[1]); - ASSERT_EQ(1144844808538997ULL, mod.const_ratio()[2]); - ASSERT_TRUE(mod.is_prime()); - } - - TEST(SmallModulusTest, CompareSmallModulus) - { - SmallModulus sm0; - SmallModulus sm2(2); - SmallModulus sm5(5); - ASSERT_FALSE(sm0 < sm0); - ASSERT_TRUE(sm0 == sm0); - ASSERT_TRUE(sm0 <= sm0); - ASSERT_TRUE(sm0 >= sm0); - ASSERT_FALSE(sm0 > sm0); - - ASSERT_FALSE(sm5 < sm5); - ASSERT_TRUE(sm5 == sm5); - ASSERT_TRUE(sm5 <= sm5); - ASSERT_TRUE(sm5 >= sm5); - ASSERT_FALSE(sm5 > sm5); - - ASSERT_FALSE(sm5 < sm2); - ASSERT_FALSE(sm5 == sm2); - ASSERT_FALSE(sm5 <= sm2); - ASSERT_TRUE(sm5 >= sm2); - ASSERT_TRUE(sm5 > sm2); - - ASSERT_TRUE(sm5 < 6); - ASSERT_FALSE(sm5 == 6); - ASSERT_TRUE(sm5 <= 6); - ASSERT_FALSE(sm5 >= 6); - ASSERT_FALSE(sm5 > 6); - } - - TEST(SmallModulusTest, SaveLoadSmallModulus) - { - stringstream stream; - - SmallModulus mod; - mod.save(stream); - - SmallModulus mod2; - mod2.load(stream); - ASSERT_EQ(mod2.value(), mod.value()); - ASSERT_EQ(mod2.bit_count(), mod.bit_count()); - ASSERT_EQ(mod2.uint64_count(), mod.uint64_count()); - ASSERT_EQ(mod2.const_ratio()[0], mod.const_ratio()[0]); - ASSERT_EQ(mod2.const_ratio()[1], mod.const_ratio()[1]); - ASSERT_EQ(mod2.const_ratio()[2], mod.const_ratio()[2]); - ASSERT_EQ(mod2.is_prime(), mod.is_prime()); - - mod = 3; - mod.save(stream); - mod2.load(stream); - ASSERT_EQ(mod2.value(), mod.value()); - ASSERT_EQ(mod2.bit_count(), mod.bit_count()); - ASSERT_EQ(mod2.uint64_count(), mod.uint64_count()); - ASSERT_EQ(mod2.const_ratio()[0], mod.const_ratio()[0]); - ASSERT_EQ(mod2.const_ratio()[1], mod.const_ratio()[1]); - ASSERT_EQ(mod2.const_ratio()[2], mod.const_ratio()[2]); - ASSERT_EQ(mod2.is_prime(), mod.is_prime()); - - mod = 0xF00000F00000F; - mod.save(stream); - mod2.load(stream); - ASSERT_EQ(mod2.value(), mod.value()); - ASSERT_EQ(mod2.bit_count(), mod.bit_count()); - ASSERT_EQ(mod2.uint64_count(), mod.uint64_count()); - ASSERT_EQ(mod2.const_ratio()[0], mod.const_ratio()[0]); - ASSERT_EQ(mod2.const_ratio()[1], mod.const_ratio()[1]); - ASSERT_EQ(mod2.const_ratio()[2], mod.const_ratio()[2]); - ASSERT_EQ(mod2.is_prime(), mod.is_prime()); - - mod = 0xF00000F000079; - mod.save(stream); - mod2.load(stream); - ASSERT_EQ(mod2.value(), mod.value()); - ASSERT_EQ(mod2.bit_count(), mod.bit_count()); - ASSERT_EQ(mod2.uint64_count(), mod.uint64_count()); - ASSERT_EQ(mod2.const_ratio()[0], mod.const_ratio()[0]); - ASSERT_EQ(mod2.const_ratio()[1], mod.const_ratio()[1]); - ASSERT_EQ(mod2.const_ratio()[2], mod.const_ratio()[2]); - ASSERT_EQ(mod2.is_prime(), mod.is_prime()); - } -} \ No newline at end of file diff --git a/SEAL/native/tests/seal/testrunner.cpp b/SEAL/native/tests/seal/testrunner.cpp deleted file mode 100644 index 3868498..0000000 --- a/SEAL/native/tests/seal/testrunner.cpp +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" - -/** -Main entry point for Google Test unit tests. -*/ -int main(int argc, char** argv) -{ - testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} \ No newline at end of file diff --git a/SEAL/native/tests/seal/util/CMakeLists.txt b/SEAL/native/tests/seal/util/CMakeLists.txt deleted file mode 100644 index bc37ac6..0000000 --- a/SEAL/native/tests/seal/util/CMakeLists.txt +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT license. - -target_sources(sealtest - PRIVATE - ${CMAKE_CURRENT_LIST_DIR}/clipnormal.cpp - ${CMAKE_CURRENT_LIST_DIR}/common.cpp - ${CMAKE_CURRENT_LIST_DIR}/hash.cpp - ${CMAKE_CURRENT_LIST_DIR}/locks.cpp - ${CMAKE_CURRENT_LIST_DIR}/mempool.cpp - ${CMAKE_CURRENT_LIST_DIR}/numth.cpp - ${CMAKE_CURRENT_LIST_DIR}/polyarith.cpp - ${CMAKE_CURRENT_LIST_DIR}/polyarithmod.cpp - ${CMAKE_CURRENT_LIST_DIR}/polyarithsmallmod.cpp - ${CMAKE_CURRENT_LIST_DIR}/polycore.cpp - ${CMAKE_CURRENT_LIST_DIR}/smallntt.cpp - ${CMAKE_CURRENT_LIST_DIR}/stringtouint64.cpp - ${CMAKE_CURRENT_LIST_DIR}/uint64tostring.cpp - ${CMAKE_CURRENT_LIST_DIR}/uintarith.cpp - ${CMAKE_CURRENT_LIST_DIR}/uintarithmod.cpp - ${CMAKE_CURRENT_LIST_DIR}/uintarithsmallmod.cpp - ${CMAKE_CURRENT_LIST_DIR}/uintcore.cpp -) diff --git a/SEAL/native/tests/seal/util/clipnormal.cpp b/SEAL/native/tests/seal/util/clipnormal.cpp deleted file mode 100644 index 7716745..0000000 --- a/SEAL/native/tests/seal/util/clipnormal.cpp +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/randomgen.h" -#include "seal/randomtostd.h" -#include "seal/util/clipnormal.h" -#include -#include - -using namespace seal::util; -using namespace seal; -using namespace std; - -namespace SEALTest -{ - namespace util - { - TEST(ClipNormal, ClipNormalGenerate) - { - shared_ptr generator(UniformRandomGeneratorFactory::default_factory()->create()); - RandomToStandardAdapter rand(generator); - ClippedNormalDistribution dist(50.0, 10.0, 20.0); - - ASSERT_EQ(50.0, dist.mean()); - ASSERT_EQ(10.0, dist.standard_deviation()); - ASSERT_EQ(20.0, dist.max_deviation()); - ASSERT_EQ(30.0, dist.min()); - ASSERT_EQ(70.0, dist.max()); - double average = 0; - double stddev = 0; - for (int i = 0; i < 100; ++i) - { - double value = dist(rand); - average += value; - stddev += (value - 50.0) * (value - 50.0); - ASSERT_TRUE(value >= 30.0 && value <= 70.0); - } - average /= 100; - stddev /= 100; - stddev = sqrt(stddev); - ASSERT_TRUE(average >= 40.0 && average <= 60.0); - ASSERT_TRUE(stddev >= 5.0 && stddev <= 15.0); - } - } -} diff --git a/SEAL/native/tests/seal/util/common.cpp b/SEAL/native/tests/seal/util/common.cpp deleted file mode 100644 index 17d6281..0000000 --- a/SEAL/native/tests/seal/util/common.cpp +++ /dev/null @@ -1,326 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/util/common.h" -#include - -using namespace seal; -using namespace seal::util; -using namespace std; - -namespace SEALTest -{ - namespace util - { - TEST(Common, Constants) - { - ASSERT_EQ(4, bits_per_nibble); - ASSERT_EQ(8, bits_per_byte); - ASSERT_EQ(4, bytes_per_uint32); - ASSERT_EQ(8, bytes_per_uint64); - ASSERT_EQ(32, bits_per_uint32); - ASSERT_EQ(64, bits_per_uint64); - ASSERT_EQ(2, nibbles_per_byte); - ASSERT_EQ(2, uint32_per_uint64); - ASSERT_EQ(16, nibbles_per_uint64); - ASSERT_EQ(static_cast(INT64_MAX) + 1, uint64_high_bit); - } - - TEST(Common, UnsignedComparisons) - { - int pos_i = 5; - int neg_i = -5; - unsigned pos_u = 6; - signed pos_s = 6; - unsigned char pos_uc = 1; - char neg_c = -1; - char pos_c = 1; - unsigned char pos_uc_max = 0xFF; - unsigned long long pos_ull = 1; - unsigned long long pos_ull_max = 0xFFFFFFFFFFFFFFFF; - long long neg_ull = -1; - - ASSERT_TRUE(unsigned_eq(pos_i, pos_i)); - ASSERT_FALSE(unsigned_eq(pos_i, neg_i)); - ASSERT_TRUE(unsigned_gt(pos_u, pos_i)); - ASSERT_TRUE(unsigned_lt(pos_i, neg_i)); - ASSERT_TRUE(unsigned_geq(pos_u, pos_s)); - ASSERT_TRUE(unsigned_gt(neg_c, pos_c)); - ASSERT_TRUE(unsigned_geq(neg_c, pos_c)); - ASSERT_FALSE(unsigned_eq(neg_c, pos_c)); - ASSERT_FALSE(unsigned_gt(pos_u, neg_c)); - ASSERT_TRUE(unsigned_eq(pos_uc, pos_c)); - ASSERT_TRUE(unsigned_geq(pos_uc, pos_c)); - ASSERT_TRUE(unsigned_leq(pos_uc, pos_c)); - ASSERT_TRUE(unsigned_lt(pos_uc_max, neg_c)); - ASSERT_TRUE(unsigned_eq(neg_c, pos_ull_max)); - ASSERT_TRUE(unsigned_eq(neg_ull, pos_ull_max)); - ASSERT_FALSE(unsigned_lt(neg_ull, pos_ull_max)); - ASSERT_TRUE(unsigned_lt(pos_ull, pos_ull_max)); - } - - TEST(Common, SafeArithmetic) - { - int pos_i = 5; - int neg_i = -5; - unsigned pos_u = 6; - unsigned char pos_uc_max = 0xFF; - unsigned long long pos_ull_max = 0xFFFFFFFFFFFFFFFF; - long long neg_ull = -1; - unsigned long long res_ul; - long long res_l; - - ASSERT_EQ(25, mul_safe(pos_i, pos_i)); - ASSERT_EQ(25, mul_safe(neg_i, neg_i)); - ASSERT_EQ(10, add_safe(pos_i, pos_i)); - ASSERT_EQ(-10, add_safe(neg_i, neg_i)); - ASSERT_EQ(0, add_safe(pos_i, neg_i)); - ASSERT_EQ(0, add_safe(neg_i, pos_i)); - ASSERT_EQ(10, sub_safe(pos_i, neg_i)); - ASSERT_EQ(-10, sub_safe(neg_i, pos_i)); - ASSERT_EQ(unsigned(0), sub_safe(pos_u, pos_u)); - ASSERT_THROW(res_ul = sub_safe(unsigned(0), pos_u), out_of_range); - ASSERT_THROW(res_ul = sub_safe(unsigned(4), pos_u), out_of_range); - ASSERT_THROW(res_ul = add_safe(pos_uc_max, (unsigned char)1), out_of_range); - ASSERT_TRUE(pos_uc_max == add_safe(pos_uc_max, (unsigned char)0)); - ASSERT_THROW(res_ul = mul_safe(pos_ull_max, pos_ull_max), out_of_range); - ASSERT_EQ(0ULL, mul_safe(0ULL, pos_ull_max)); - ASSERT_TRUE((long long)1 == mul_safe(neg_ull, neg_ull)); - ASSERT_THROW(res_ul = mul_safe(pos_uc_max, pos_uc_max), out_of_range); - ASSERT_EQ(15, add_safe(pos_i, -pos_i, pos_i, pos_i, pos_i)); - ASSERT_EQ(6, add_safe(0, -pos_i, pos_i, 1, pos_i)); - ASSERT_EQ(0, mul_safe(pos_i, pos_i, pos_i, 0, pos_i)); - ASSERT_EQ(625, mul_safe(pos_i, pos_i, pos_i, pos_i)); - ASSERT_THROW(res_l = mul_safe( - pos_i, pos_i, pos_i, pos_i, pos_i, pos_i, pos_i, - pos_i, pos_i, pos_i, pos_i, pos_i, pos_i, pos_i), out_of_range); - } - - TEST(Common, FitsIn) - { - int neg_i = -5; - signed pos_s = 6; - unsigned char pos_uc = 1; - unsigned char pos_uc_max = 0xFF; - float f = 1.234f; - double d = -1234; - - ASSERT_TRUE(fits_in(pos_s)); - ASSERT_TRUE(fits_in(pos_uc)); - ASSERT_FALSE(fits_in(neg_i)); - ASSERT_FALSE(fits_in(pos_uc_max)); - ASSERT_TRUE(fits_in(d)); - ASSERT_TRUE(fits_in(f)); - ASSERT_TRUE(fits_in(d)); - ASSERT_TRUE(fits_in(f)); - ASSERT_FALSE(fits_in(d)); - } - - TEST(Common, DivideRoundUp) - { - ASSERT_EQ(0, divide_round_up(0, 4)); - ASSERT_EQ(1, divide_round_up(1, 4)); - ASSERT_EQ(1, divide_round_up(2, 4)); - ASSERT_EQ(1, divide_round_up(3, 4)); - ASSERT_EQ(1, divide_round_up(4, 4)); - ASSERT_EQ(2, divide_round_up(5, 4)); - ASSERT_EQ(2, divide_round_up(6, 4)); - ASSERT_EQ(2, divide_round_up(7, 4)); - ASSERT_EQ(2, divide_round_up(8, 4)); - ASSERT_EQ(3, divide_round_up(9, 4)); - ASSERT_EQ(3, divide_round_up(12, 4)); - ASSERT_EQ(4, divide_round_up(13, 4)); - } - - TEST(Common, GetUInt64Byte) - { - uint64_t number[2]; - number[0] = 0x3456789ABCDEF121; - number[1] = 0x23456789ABCDEF12; - ASSERT_TRUE(SEAL_BYTE(0x21) == *get_uint64_byte(number, 0)); - ASSERT_TRUE(SEAL_BYTE(0xF1) == *get_uint64_byte(number, 1)); - ASSERT_TRUE(SEAL_BYTE(0xDE) == *get_uint64_byte(number, 2)); - ASSERT_TRUE(SEAL_BYTE(0xBC) == *get_uint64_byte(number, 3)); - ASSERT_TRUE(SEAL_BYTE(0x9A) == *get_uint64_byte(number, 4)); - ASSERT_TRUE(SEAL_BYTE(0x78) == *get_uint64_byte(number, 5)); - ASSERT_TRUE(SEAL_BYTE(0x56) == *get_uint64_byte(number, 6)); - ASSERT_TRUE(SEAL_BYTE(0x34) == *get_uint64_byte(number, 7)); - ASSERT_TRUE(SEAL_BYTE(0x12) == *get_uint64_byte(number, 8)); - ASSERT_TRUE(SEAL_BYTE(0xEF) == *get_uint64_byte(number, 9)); - ASSERT_TRUE(SEAL_BYTE(0xCD) == *get_uint64_byte(number, 10)); - ASSERT_TRUE(SEAL_BYTE(0xAB) == *get_uint64_byte(number, 11)); - ASSERT_TRUE(SEAL_BYTE(0x89) == *get_uint64_byte(number, 12)); - ASSERT_TRUE(SEAL_BYTE(0x67) == *get_uint64_byte(number, 13)); - ASSERT_TRUE(SEAL_BYTE(0x45) == *get_uint64_byte(number, 14)); - ASSERT_TRUE(SEAL_BYTE(0x23) == *get_uint64_byte(number, 15)); - } - - template - void ReverseBits32Helper() - { - ASSERT_EQ(static_cast(0), reverse_bits(static_cast(0))); - ASSERT_EQ(static_cast(0x80000000), reverse_bits(static_cast(1))); - ASSERT_EQ(static_cast(0x40000000), reverse_bits(static_cast(2))); - ASSERT_EQ(static_cast(0xC0000000), reverse_bits(static_cast(3))); - ASSERT_EQ(static_cast(0x00010000), reverse_bits(static_cast(0x00008000))); - ASSERT_EQ(static_cast(0xFFFF0000), reverse_bits(static_cast(0x0000FFFF))); - ASSERT_EQ(static_cast(0x0000FFFF), reverse_bits(static_cast(0xFFFF0000))); - ASSERT_EQ(static_cast(0x00008000), reverse_bits(static_cast(0x00010000))); - ASSERT_EQ(static_cast(3), reverse_bits(static_cast(0xC0000000))); - ASSERT_EQ(static_cast(2), reverse_bits(static_cast(0x40000000))); - ASSERT_EQ(static_cast(1), reverse_bits(static_cast(0x80000000))); - ASSERT_EQ(static_cast(0xFFFFFFFF), reverse_bits(static_cast(0xFFFFFFFF))); - - // Reversing a 0-bit item should return 0 - ASSERT_EQ(static_cast(0), reverse_bits(static_cast(0xFFFFFFFF), 0)); - - // Reversing a 32-bit item returns is same as normal reverse - ASSERT_EQ(static_cast(0), reverse_bits(static_cast(0), 32)); - ASSERT_EQ(static_cast(0x80000000), reverse_bits(static_cast(1), 32)); - ASSERT_EQ(static_cast(0x40000000), reverse_bits(static_cast(2), 32)); - ASSERT_EQ(static_cast(0xC0000000), reverse_bits(static_cast(3), 32)); - ASSERT_EQ(static_cast(0x00010000), reverse_bits(static_cast(0x00008000), 32)); - ASSERT_EQ(static_cast(0xFFFF0000), reverse_bits(static_cast(0x0000FFFF), 32)); - ASSERT_EQ(static_cast(0x0000FFFF), reverse_bits(static_cast(0xFFFF0000), 32)); - ASSERT_EQ(static_cast(0x00008000), reverse_bits(static_cast(0x00010000), 32)); - ASSERT_EQ(static_cast(3), reverse_bits(static_cast(0xC0000000), 32)); - ASSERT_EQ(static_cast(2), reverse_bits(static_cast(0x40000000), 32)); - ASSERT_EQ(static_cast(1), reverse_bits(static_cast(0x80000000), 32)); - ASSERT_EQ(static_cast(0xFFFFFFFF), reverse_bits(static_cast(0xFFFFFFFF), 32)); - - // 16-bit reversal - ASSERT_EQ(static_cast(0), reverse_bits(static_cast(0), 16)); - ASSERT_EQ(static_cast(0x00008000), reverse_bits(static_cast(1), 16)); - ASSERT_EQ(static_cast(0x00004000), reverse_bits(static_cast(2), 16)); - ASSERT_EQ(static_cast(0x0000C000), reverse_bits(static_cast(3), 16)); - ASSERT_EQ(static_cast(0x00000001), reverse_bits(static_cast(0x00008000), 16)); - ASSERT_EQ(static_cast(0x0000FFFF), reverse_bits(static_cast(0x0000FFFF), 16)); - ASSERT_EQ(static_cast(0x00000000), reverse_bits(static_cast(0xFFFF0000), 16)); - ASSERT_EQ(static_cast(0x00000000), reverse_bits(static_cast(0x00010000), 16)); - ASSERT_EQ(static_cast(3), reverse_bits(static_cast(0x0000C000), 16)); - ASSERT_EQ(static_cast(2), reverse_bits(static_cast(0x00004000), 16)); - ASSERT_EQ(static_cast(1), reverse_bits(static_cast(0x00008000), 16)); - ASSERT_EQ(static_cast(0x0000FFFF), reverse_bits(static_cast(0xFFFFFFFF), 16)); - } - - TEST(Common, ReverseBits32) - { - ReverseBits32Helper(); - - // Other types -#ifdef SEAL_USE_IF_CONSTEXPR - SEAL_IF_CONSTEXPR (sizeof(unsigned) == 4) - ReverseBits32Helper(); - - SEAL_IF_CONSTEXPR (sizeof(unsigned long) == 4) - ReverseBits32Helper(); - - SEAL_IF_CONSTEXPR (sizeof(unsigned long long) == 4) - ReverseBits32Helper(); - - SEAL_IF_CONSTEXPR (sizeof(size_t) == 4) - ReverseBits32Helper(); -#endif - } - - template - void ReverseBits64Helper() - { - ASSERT_EQ(0ULL, reverse_bits(0ULL)); - ASSERT_EQ(1ULL << 63, reverse_bits(1ULL)); - ASSERT_EQ(1ULL << 32, reverse_bits(1ULL << 31)); - ASSERT_EQ(0xFFFFULL << 32, reverse_bits(0xFFFFULL << 16)); - ASSERT_EQ(0x0000FFFFFFFF0000ULL, reverse_bits(0x0000FFFFFFFF0000ULL)); - ASSERT_EQ(0x0000FFFF0000FFFFULL, reverse_bits(0xFFFF0000FFFF0000ULL)); - - ASSERT_EQ(0ULL, reverse_bits(0ULL, 0)); - ASSERT_EQ(0ULL, reverse_bits(0ULL, 1)); - ASSERT_EQ(0ULL, reverse_bits(0ULL, 32)); - ASSERT_EQ(0ULL, reverse_bits(0ULL, 64)); - - ASSERT_EQ(0ULL, reverse_bits(1ULL, 0)); - ASSERT_EQ(1ULL, reverse_bits(1ULL, 1)); - ASSERT_EQ(1ULL << 31, reverse_bits(1ULL, 32)); - ASSERT_EQ(1ULL << 63, reverse_bits(1ULL, 64)); - - ASSERT_EQ(0ULL, reverse_bits(1ULL << 31, 0)); - ASSERT_EQ(0ULL, reverse_bits(1ULL << 31, 1)); - ASSERT_EQ(1ULL, reverse_bits(1ULL << 31, 32)); - ASSERT_EQ(1ULL << 32, reverse_bits(1ULL << 31, 64)); - - ASSERT_EQ(0ULL, reverse_bits(0xFFFFULL << 16, 0)); - ASSERT_EQ(0ULL, reverse_bits(0xFFFFULL << 16, 1)); - ASSERT_EQ(0xFFFFULL, reverse_bits(0xFFFFULL << 16, 32)); - ASSERT_EQ(0xFFFFULL << 32, reverse_bits(0xFFFFULL << 16, 64)); - - ASSERT_EQ(0ULL, reverse_bits(0x0000FFFFFFFF0000ULL, 0)); - ASSERT_EQ(0ULL, reverse_bits(0x0000FFFFFFFF0000ULL, 1)); - ASSERT_EQ(0xFFFFULL, reverse_bits(0x0000FFFFFFFF0000ULL, 32)); - ASSERT_EQ(0x0000FFFFFFFF0000ULL, reverse_bits(0x0000FFFFFFFF0000ULL, 64)); - - ASSERT_EQ(0ULL, reverse_bits(0xFFFF0000FFFF0000ULL, 0)); - ASSERT_EQ(0ULL, reverse_bits(0xFFFF0000FFFF0000ULL, 1)); - ASSERT_EQ(0xFFFFULL, reverse_bits(0xFFFF0000FFFF0000ULL, 32)); - ASSERT_EQ(0x0000FFFF0000FFFFULL, reverse_bits(0xFFFF0000FFFF0000ULL, 64)); - } - - TEST(Common, ReverseBits64) - { - ReverseBits64Helper(); - - // Other types -#ifdef SEAL_USE_IF_CONSTEXPR - SEAL_IF_CONSTEXPR (sizeof(unsigned) == 8) - ReverseBits64Helper(); - - SEAL_IF_CONSTEXPR (sizeof(unsigned long) == 8) - ReverseBits64Helper(); - - SEAL_IF_CONSTEXPR (sizeof(unsigned long long) == 8) - ReverseBits64Helper(); - - SEAL_IF_CONSTEXPR (sizeof(size_t) == 8) - ReverseBits64Helper(); -#endif - } - - TEST(Common, GetSignificantBitCount) - { - ASSERT_EQ(0, get_significant_bit_count(0)); - ASSERT_EQ(1, get_significant_bit_count(1)); - ASSERT_EQ(2, get_significant_bit_count(2)); - ASSERT_EQ(2, get_significant_bit_count(3)); - ASSERT_EQ(3, get_significant_bit_count(4)); - ASSERT_EQ(3, get_significant_bit_count(5)); - ASSERT_EQ(3, get_significant_bit_count(6)); - ASSERT_EQ(3, get_significant_bit_count(7)); - ASSERT_EQ(4, get_significant_bit_count(8)); - ASSERT_EQ(63, get_significant_bit_count(0x7000000000000000)); - ASSERT_EQ(63, get_significant_bit_count(0x7FFFFFFFFFFFFFFF)); - ASSERT_EQ(64, get_significant_bit_count(0x8000000000000000)); - ASSERT_EQ(64, get_significant_bit_count(0xFFFFFFFFFFFFFFFF)); - } - - TEST(Common, GetMSBIndexGeneric) - { - unsigned long result; - get_msb_index_generic(&result, 1); - ASSERT_EQ(static_cast(0), result); - get_msb_index_generic(&result, 2); - ASSERT_EQ(static_cast(1), result); - get_msb_index_generic(&result, 3); - ASSERT_EQ(static_cast(1), result); - get_msb_index_generic(&result, 4); - ASSERT_EQ(static_cast(2), result); - get_msb_index_generic(&result, 16); - ASSERT_EQ(static_cast(4), result); - get_msb_index_generic(&result, 0xFFFFFFFF); - ASSERT_EQ(static_cast(31), result); - get_msb_index_generic(&result, 0x100000000); - ASSERT_EQ(static_cast(32), result); - get_msb_index_generic(&result, 0xFFFFFFFFFFFFFFFF); - ASSERT_EQ(static_cast(63), result); - } - } -} diff --git a/SEAL/native/tests/seal/util/hash.cpp b/SEAL/native/tests/seal/util/hash.cpp deleted file mode 100644 index f5cb3b1..0000000 --- a/SEAL/native/tests/seal/util/hash.cpp +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/util/hash.h" -#include - -using namespace seal::util; -using namespace std; - -namespace SEALTest -{ - namespace util - { - TEST(HashTest, SHA3Hash) - { - uint64_t input[3]{ 0, 0, 0 }; - HashFunction::sha3_block_type hash1, hash2; - HashFunction::sha3_hash(0, hash1); - - HashFunction::sha3_hash(input, 0, hash2); - ASSERT_TRUE(hash1 != hash2); - - HashFunction::sha3_hash(input, 1, hash2); - ASSERT_TRUE(hash1 == hash2); - - HashFunction::sha3_hash(input, 2, hash2); - ASSERT_TRUE(hash1 != hash2); - - HashFunction::sha3_hash(0x123456, hash1); - HashFunction::sha3_hash(0x023456, hash2); - ASSERT_TRUE(hash1 != hash2); - - input[0] = 0x123456; - input[1] = 1; - HashFunction::sha3_hash(0x123456, hash1); - HashFunction::sha3_hash(input, 2, hash2); - ASSERT_TRUE(hash1 != hash2); - } - } -} diff --git a/SEAL/native/tests/seal/util/locks.cpp b/SEAL/native/tests/seal/util/locks.cpp deleted file mode 100644 index fac6127..0000000 --- a/SEAL/native/tests/seal/util/locks.cpp +++ /dev/null @@ -1,309 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/util/locks.h" -#include -#include - -using namespace seal::util; -using namespace std; - -namespace SEALTest -{ - namespace util - { - class Reader - { - public: - Reader(ReaderWriterLocker &locker) : locker_(locker), locked_(false), trying_(false) - { - } - - bool is_locked() const - { - return locked_; - } - - bool is_trying_to_lock() const - { - return trying_; - } - - void acquire_read() - { - trying_ = true; - lock_ = locker_.acquire_read(); - locked_ = true; - trying_ = false; - } - - void release() - { - lock_.unlock(); - locked_ = false; - } - - void wait_until_trying() - { - while (!trying_); - } - - void wait_until_locked() - { - while (!locked_); - } - - private: - ReaderWriterLocker &locker_; - - ReaderLock lock_; - - volatile bool locked_; - - volatile bool trying_; - }; - - class Writer - { - public: - Writer(ReaderWriterLocker &locker) : locker_(locker), locked_(false), trying_(false) - { - } - - bool is_locked() const - { - return locked_; - } - - bool is_trying_to_lock() const - { - return trying_; - } - - void acquire_write() - { - trying_ = true; - lock_ = locker_.acquire_write(); - locked_ = true; - trying_ = false; - } - - void release() - { - lock_.unlock(); - locked_ = false; - } - - void wait_until_trying() - { - while (!trying_); - } - - void wait_until_locked() - { - while (!locked_); - } - - void wait_until_unlocked() - { - while(locked_); - } - - private: - ReaderWriterLocker &locker_; - - WriterLock lock_; - - volatile bool locked_; - - volatile bool trying_; - }; - - TEST(ReaderWriterLockerTests, ReaderWriterLockNonBlocking) - { - ReaderWriterLocker locker; - - WriterLock writeLock = locker.acquire_write(); - ASSERT_TRUE(writeLock.owns_lock()); - writeLock.unlock(); - ASSERT_FALSE(writeLock.owns_lock()); - - ReaderLock readLock = locker.acquire_read(); - ASSERT_TRUE(readLock.owns_lock()); - readLock.unlock(); - - ReaderLock readLock2 = locker.acquire_read(); - ASSERT_TRUE(readLock2.owns_lock()); - ASSERT_FALSE(readLock.owns_lock()); - readLock2.unlock(); - ASSERT_FALSE(readLock2.owns_lock()); - - readLock = locker.try_acquire_read(); - ASSERT_TRUE(readLock.owns_lock()); - writeLock = locker.try_acquire_write(); - ASSERT_FALSE(writeLock.owns_lock()); - - readLock2 = locker.try_acquire_read(); - ASSERT_TRUE(readLock2.owns_lock()); - writeLock = locker.try_acquire_write(); - ASSERT_FALSE(writeLock.owns_lock()); - - readLock.unlock(); - writeLock = locker.try_acquire_write(); - ASSERT_FALSE(writeLock.owns_lock()); - - readLock2.unlock(); - writeLock = locker.try_acquire_write(); - ASSERT_TRUE(writeLock.owns_lock()); - - WriterLock writeLock2 = locker.try_acquire_write(); - - ASSERT_FALSE(writeLock2.owns_lock()); - readLock2 = locker.try_acquire_read(); - ASSERT_FALSE(readLock2.owns_lock()); - - writeLock.unlock(); - - writeLock2 = locker.try_acquire_write(); - ASSERT_TRUE(writeLock2.owns_lock()); - readLock2 = locker.try_acquire_read(); - ASSERT_FALSE(readLock2.owns_lock()); - - writeLock2.unlock(); - } - - TEST(ReaderWriterLockerTests, ReaderWriterLockBlocking) - { - ReaderWriterLocker locker; - - Reader *reader1 = new Reader(locker); - Reader *reader2 = new Reader(locker); - Writer *writer1 = new Writer(locker); - Writer *writer2 = new Writer(locker); - - ASSERT_FALSE(reader1->is_locked()); - ASSERT_FALSE(reader2->is_locked()); - ASSERT_FALSE(writer1->is_locked()); - ASSERT_FALSE(writer2->is_locked()); - - reader1->acquire_read(); - ASSERT_TRUE(reader1->is_locked()); - ASSERT_FALSE(reader2->is_locked()); - reader2->acquire_read(); - ASSERT_TRUE(reader1->is_locked()); - ASSERT_TRUE(reader2->is_locked()); - - atomic should_unlock1{ false }; - atomic should_unlock2{ false }; - - thread writer1_thread([&] { - writer1->acquire_write(); - while (!should_unlock1) - { - this_thread::sleep_for(10ms); - } - writer1->release(); - }); - - writer1->wait_until_trying(); - ASSERT_TRUE(writer1->is_trying_to_lock()); - ASSERT_FALSE(writer1->is_locked()); - - reader2->release(); - ASSERT_TRUE(reader1->is_locked()); - ASSERT_FALSE(reader2->is_locked()); - ASSERT_TRUE(writer1->is_trying_to_lock()); - ASSERT_FALSE(writer1->is_locked()); - - thread writer2_thread([&] { - writer2->acquire_write(); - while (!should_unlock2) - { - this_thread::sleep_for(10ms); - } - writer2->release(); - }); - - writer2->wait_until_trying(); - ASSERT_TRUE(writer1->is_trying_to_lock()); - ASSERT_FALSE(writer1->is_locked()); - ASSERT_TRUE(writer2->is_trying_to_lock()); - ASSERT_FALSE(writer2->is_locked()); - - reader1->release(); - ASSERT_FALSE(reader1->is_locked()); - - while (writer1->is_trying_to_lock() && writer2->is_trying_to_lock()); - - Writer *winner; - Writer *waiting; - atomic* should_unlock_winner; - atomic* should_unlock_waiting; - - if (writer1->is_locked()) - { - winner = writer1; - waiting = writer2; - should_unlock_winner = &should_unlock1; - should_unlock_waiting = &should_unlock2; - } - else - { - winner = writer2; - waiting = writer1; - should_unlock_winner = &should_unlock2; - should_unlock_waiting = &should_unlock1; - } - - ASSERT_TRUE(winner->is_locked()); - ASSERT_FALSE(waiting->is_locked()); - - *should_unlock_winner = true; - winner->wait_until_unlocked(); - ASSERT_FALSE(winner->is_locked()); - - waiting->wait_until_locked(); - ASSERT_TRUE(waiting->is_locked()); - - thread reader1_thread(&Reader::acquire_read, reader1); - reader1->wait_until_trying(); - ASSERT_TRUE(reader1->is_trying_to_lock()); - ASSERT_FALSE(reader1->is_locked()); - - thread reader2_thread(&Reader::acquire_read, reader2); - reader2->wait_until_trying(); - ASSERT_TRUE(reader2->is_trying_to_lock()); - ASSERT_FALSE(reader2->is_locked()); - - *should_unlock_waiting = true; - - reader1->wait_until_locked(); - reader2->wait_until_locked(); - ASSERT_TRUE(reader1->is_locked()); - ASSERT_TRUE(reader2->is_locked()); - - reader1->release(); - reader2->release(); - - ASSERT_FALSE(reader1->is_locked()); - ASSERT_FALSE(reader2->is_locked()); - ASSERT_FALSE(writer1->is_locked()); - ASSERT_FALSE(reader2->is_locked()); - ASSERT_FALSE(reader1->is_trying_to_lock()); - ASSERT_FALSE(reader2->is_trying_to_lock()); - ASSERT_FALSE(writer1->is_trying_to_lock()); - ASSERT_FALSE(reader2->is_trying_to_lock()); - - writer1_thread.join(); - writer2_thread.join(); - reader1_thread.join(); - reader2_thread.join(); - - delete reader1; - delete reader2; - delete writer1; - delete writer2; - } - } -} diff --git a/SEAL/native/tests/seal/util/mempool.cpp b/SEAL/native/tests/seal/util/mempool.cpp deleted file mode 100644 index 617b4b5..0000000 --- a/SEAL/native/tests/seal/util/mempool.cpp +++ /dev/null @@ -1,674 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/util/mempool.h" -#include "seal/util/pointer.h" -#include "seal/util/common.h" -#include - -using namespace seal; -using namespace seal::util; -using namespace std; - -namespace SEALTest -{ - namespace util - { - TEST(MemoryPoolTests, TestMemoryPoolMT) - { - { - MemoryPoolMT pool; - ASSERT_TRUE(0LL == pool.pool_count()); - - Pointer pointer{ pool.get_for_byte_count(bytes_per_uint64 * 0) }; - ASSERT_FALSE(pointer.is_set()); - pointer.release(); - ASSERT_TRUE(0LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(bytes_per_uint64 * 2); - uint64_t *allocation1 = pointer.get(); - ASSERT_TRUE(pointer.is_set()); - pointer.release(); - ASSERT_FALSE(pointer.is_set()); - ASSERT_TRUE(1LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(bytes_per_uint64 * 2); - ASSERT_TRUE(allocation1 == pointer.get()); - ASSERT_TRUE(pointer.is_set()); - pointer.release(); - ASSERT_FALSE(pointer.is_set()); - ASSERT_TRUE(1LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(bytes_per_uint64 * 1); - ASSERT_FALSE(allocation1 == pointer.get()); - ASSERT_TRUE(pointer.is_set()); - pointer.release(); - ASSERT_FALSE(pointer.is_set()); - ASSERT_TRUE(2LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(bytes_per_uint64 * 2); - ASSERT_TRUE(allocation1 == pointer.get()); - Pointer pointer2 = pool.get_for_byte_count(bytes_per_uint64 * 2); - uint64_t *allocation2 = pointer2.get(); - ASSERT_FALSE(allocation2 == pointer.get()); - ASSERT_TRUE(pointer.is_set()); - pointer.release(); - pointer2.release(); - ASSERT_TRUE(2LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(bytes_per_uint64 * 2); - ASSERT_TRUE(allocation2 == pointer.get()); - pointer2 = pool.get_for_byte_count(bytes_per_uint64 * 2); - ASSERT_TRUE(allocation1 == pointer2.get()); - Pointer pointer3 = pool.get_for_byte_count(bytes_per_uint64 * 1); - pointer.release(); - pointer2.release(); - pointer3.release(); - ASSERT_TRUE(2LL == pool.pool_count()); - - Pointer pointer4 = pool.get_for_byte_count(1); - Pointer pointer5 = pool.get_for_byte_count(2); - Pointer pointer6 = pool.get_for_byte_count(1); - pointer4.release(); - pointer5.release(); - pointer6.release(); - ASSERT_TRUE(4LL == pool.pool_count()); - } - { - MemoryPoolMT pool; - ASSERT_TRUE(0LL == pool.pool_count()); - - Pointer pointer{ pool.get_for_byte_count(4 * 0) }; - ASSERT_FALSE(pointer.is_set()); - pointer.release(); - ASSERT_TRUE(0LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(4 * 2); - int *allocation1 = pointer.get(); - ASSERT_TRUE(pointer.is_set()); - pointer.release(); - ASSERT_FALSE(pointer.is_set()); - ASSERT_TRUE(1LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(4 * 2); - ASSERT_TRUE(allocation1 == pointer.get()); - ASSERT_TRUE(pointer.is_set()); - pointer.release(); - ASSERT_FALSE(pointer.is_set()); - ASSERT_TRUE(1LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(4 * 1); - ASSERT_FALSE(allocation1 == pointer.get()); - ASSERT_TRUE(pointer.is_set()); - pointer.release(); - ASSERT_FALSE(pointer.is_set()); - ASSERT_TRUE(2LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(4 * 2); - ASSERT_TRUE(allocation1 == pointer.get()); - Pointer pointer2 = pool.get_for_byte_count(4 * 2); - int *allocation2 = pointer2.get(); - ASSERT_FALSE(allocation2 == pointer.get()); - ASSERT_TRUE(pointer.is_set()); - pointer.release(); - pointer2.release(); - ASSERT_TRUE(2LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(4 * 2); - ASSERT_TRUE(allocation2 == pointer.get()); - pointer2 = pool.get_for_byte_count(4 * 2); - ASSERT_TRUE(allocation1 == pointer2.get()); - Pointer pointer3 = pool.get_for_byte_count(4 * 1); - pointer.release(); - pointer2.release(); - pointer3.release(); - ASSERT_TRUE(2LL == pool.pool_count()); - - Pointer pointer4 = pool.get_for_byte_count(1); - Pointer pointer5 = pool.get_for_byte_count(2); - Pointer pointer6 = pool.get_for_byte_count(1); - pointer4.release(); - pointer5.release(); - pointer6.release(); - ASSERT_TRUE(4LL == pool.pool_count()); - } - { - MemoryPoolMT pool; - ASSERT_TRUE(0LL == pool.pool_count()); - - Pointer pointer = pool.get_for_byte_count(0); - ASSERT_FALSE(pointer.is_set()); - pointer.release(); - ASSERT_TRUE(0LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(2); - SEAL_BYTE *allocation1 = pointer.get(); - ASSERT_TRUE(pointer.is_set()); - pointer.release(); - ASSERT_FALSE(pointer.is_set()); - ASSERT_TRUE(1LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(2); - ASSERT_TRUE(allocation1 == pointer.get()); - ASSERT_TRUE(pointer.is_set()); - pointer.release(); - ASSERT_FALSE(pointer.is_set()); - ASSERT_TRUE(1LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(1); - ASSERT_FALSE(allocation1 == pointer.get()); - ASSERT_TRUE(pointer.is_set()); - pointer.release(); - ASSERT_FALSE(pointer.is_set()); - ASSERT_TRUE(2LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(2); - ASSERT_TRUE(allocation1 == pointer.get()); - Pointer pointer2 = pool.get_for_byte_count(2); - SEAL_BYTE *allocation2 = pointer2.get(); - ASSERT_FALSE(allocation2 == pointer.get()); - ASSERT_TRUE(pointer.is_set()); - pointer.release(); - pointer2.release(); - ASSERT_TRUE(2LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(2); - ASSERT_TRUE(allocation2 == pointer.get()); - pointer2 = pool.get_for_byte_count(2); - ASSERT_TRUE(allocation1 == pointer2.get()); - Pointer pointer3 = pool.get_for_byte_count(1); - pointer.release(); - pointer2.release(); - pointer3.release(); - ASSERT_TRUE(2LL == pool.pool_count()); - } - } - - TEST(MemoryPoolTests, PointerTestsMT) - { - MemoryPool &pool = *global_variables::global_memory_pool; - { - Pointer p1; - ASSERT_FALSE(p1.is_set()); - ASSERT_TRUE(p1.get() == nullptr); - - p1 = pool.get_for_byte_count(bytes_per_uint64 * 1); - uint64_t *allocation1 = p1.get(); - ASSERT_TRUE(p1.is_set()); - ASSERT_TRUE(p1.get() != nullptr); - - p1.release(); - ASSERT_FALSE(p1.is_set()); - ASSERT_TRUE(p1.get() == nullptr); - - p1 = pool.get_for_byte_count(bytes_per_uint64 * 1); - ASSERT_TRUE(p1.is_set()); - ASSERT_TRUE(p1.get() == allocation1); - - Pointer p2; - p2.acquire(p1); - ASSERT_FALSE(p1.is_set()); - ASSERT_TRUE(p2.is_set()); - ASSERT_TRUE(p2.get() == allocation1); - - ConstPointer cp2; - cp2.acquire(p2); - ASSERT_FALSE(p2.is_set()); - ASSERT_TRUE(cp2.is_set()); - ASSERT_TRUE(cp2.get() == allocation1); - cp2.release(); - - Pointer p3 = pool.get_for_byte_count(bytes_per_uint64 * 1); - ASSERT_TRUE(p3.is_set()); - ASSERT_TRUE(p3.get() == allocation1); - - Pointer p4 = pool.get_for_byte_count(bytes_per_uint64 * 2); - ASSERT_TRUE(p4.is_set()); - uint64_t *allocation2 = p4.get(); - swap(p3, p4); - ASSERT_TRUE(p3.is_set()); - ASSERT_TRUE(p3.get() == allocation2); - ASSERT_TRUE(p4.is_set()); - ASSERT_TRUE(p4.get() == allocation1); - p3.release(); - p4.release(); - } - { - Pointer p1; - ASSERT_FALSE(p1.is_set()); - ASSERT_TRUE(p1.get() == nullptr); - - p1 = pool.get_for_byte_count(bytes_per_uint64 * 1); - SEAL_BYTE *allocation1 = p1.get(); - ASSERT_TRUE(p1.is_set()); - ASSERT_TRUE(p1.get() != nullptr); - - p1.release(); - ASSERT_FALSE(p1.is_set()); - ASSERT_TRUE(p1.get() == nullptr); - - p1 = pool.get_for_byte_count(bytes_per_uint64 * 1); - ASSERT_TRUE(p1.is_set()); - ASSERT_TRUE(p1.get() == allocation1); - - Pointer p2; - p2.acquire(p1); - ASSERT_FALSE(p1.is_set()); - ASSERT_TRUE(p2.is_set()); - ASSERT_TRUE(p2.get() == allocation1); - - ConstPointer cp2; - cp2.acquire(p2); - ASSERT_FALSE(p2.is_set()); - ASSERT_TRUE(cp2.is_set()); - ASSERT_TRUE(cp2.get() == allocation1); - cp2.release(); - - Pointer p3 = pool.get_for_byte_count(bytes_per_uint64 * 1); - ASSERT_TRUE(p3.is_set()); - ASSERT_TRUE(p3.get() == allocation1); - - Pointer p4 = pool.get_for_byte_count(bytes_per_uint64 * 2); - ASSERT_TRUE(p4.is_set()); - SEAL_BYTE *allocation2 = p4.get(); - swap(p3, p4); - ASSERT_TRUE(p3.is_set()); - ASSERT_TRUE(p3.get() == allocation2); - ASSERT_TRUE(p4.is_set()); - ASSERT_TRUE(p4.get() == allocation1); - p3.release(); - p4.release(); - } - } - - TEST(MemoryPoolTests, DuplicateIfNeededMT) - { - { - unique_ptr allocation(new uint64_t[2]); - allocation[0] = 0x1234567812345678; - allocation[1] = 0x8765432187654321; - - MemoryPoolMT pool; - Pointer p1 = duplicate_if_needed(allocation.get(), 2, false, pool); - ASSERT_TRUE(p1.is_set()); - ASSERT_TRUE(p1.get() == allocation.get()); - ASSERT_TRUE(0LL == pool.pool_count()); - - p1 = duplicate_if_needed(allocation.get(), 2, true, pool); - ASSERT_TRUE(p1.is_set()); - ASSERT_FALSE(p1.get() == allocation.get()); - ASSERT_TRUE(1LL == pool.pool_count()); - ASSERT_TRUE(p1.get()[0] == 0x1234567812345678); - ASSERT_TRUE(p1.get()[1] == 0x8765432187654321); - p1.release(); - } - { - unique_ptr allocation(new int64_t[2]); - allocation[0] = 0x234567812345678; - allocation[1] = 0x765432187654321; - - MemoryPoolMT pool; - Pointer p1 = duplicate_if_needed(allocation.get(), 2, false, pool); - ASSERT_TRUE(p1.is_set()); - ASSERT_TRUE(p1.get() == allocation.get()); - ASSERT_TRUE(0LL == pool.pool_count()); - - p1 = duplicate_if_needed(allocation.get(), 2, true, pool); - ASSERT_TRUE(p1.is_set()); - ASSERT_FALSE(p1.get() == allocation.get()); - ASSERT_TRUE(1LL == pool.pool_count()); - ASSERT_TRUE(p1.get()[0] == 0x234567812345678); - ASSERT_TRUE(p1.get()[1] == 0x765432187654321); - p1.release(); - } - { - unique_ptr allocation(new int[2]); - allocation[0] = 0x123; - allocation[1] = 0x876; - - MemoryPoolMT pool; - Pointer p1 = duplicate_if_needed(allocation.get(), 2, false, pool); - ASSERT_TRUE(p1.is_set()); - ASSERT_TRUE(p1.get() == allocation.get()); - ASSERT_TRUE(0LL == pool.pool_count()); - - p1 = duplicate_if_needed(allocation.get(), 2, true, pool); - ASSERT_TRUE(p1.is_set()); - ASSERT_FALSE(p1.get() == allocation.get()); - ASSERT_TRUE(1LL == pool.pool_count()); - ASSERT_TRUE(p1.get()[0] == 0x123); - ASSERT_TRUE(p1.get()[1] == 0x876); - p1.release(); - } - } - - TEST(MemoryPoolTests, TestMemoryPoolST) - { - { - MemoryPoolST pool; - ASSERT_TRUE(0LL == pool.pool_count()); - - Pointer pointer{ pool.get_for_byte_count(bytes_per_uint64 * 0) }; - ASSERT_FALSE(pointer.is_set()); - pointer.release(); - ASSERT_TRUE(0LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(bytes_per_uint64 * 2); - uint64_t *allocation1 = pointer.get(); - ASSERT_TRUE(pointer.is_set()); - pointer.release(); - ASSERT_FALSE(pointer.is_set()); - ASSERT_TRUE(1LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(bytes_per_uint64 * 2); - ASSERT_TRUE(allocation1 == pointer.get()); - ASSERT_TRUE(pointer.is_set()); - pointer.release(); - ASSERT_FALSE(pointer.is_set()); - ASSERT_TRUE(1LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(bytes_per_uint64 * 1); - ASSERT_FALSE(allocation1 == pointer.get()); - ASSERT_TRUE(pointer.is_set()); - pointer.release(); - ASSERT_FALSE(pointer.is_set()); - ASSERT_TRUE(2LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(bytes_per_uint64 * 2); - ASSERT_TRUE(allocation1 == pointer.get()); - Pointer pointer2 = pool.get_for_byte_count(bytes_per_uint64 * 2); - uint64_t *allocation2 = pointer2.get(); - ASSERT_FALSE(allocation2 == pointer.get()); - ASSERT_TRUE(pointer.is_set()); - pointer.release(); - pointer2.release(); - ASSERT_TRUE(2LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(bytes_per_uint64 * 2); - ASSERT_TRUE(allocation2 == pointer.get()); - pointer2 = pool.get_for_byte_count(bytes_per_uint64 * 2); - ASSERT_TRUE(allocation1 == pointer2.get()); - Pointer pointer3 = pool.get_for_byte_count(bytes_per_uint64 * 1); - pointer.release(); - pointer2.release(); - pointer3.release(); - ASSERT_TRUE(2LL == pool.pool_count()); - - Pointer pointer4 = pool.get_for_byte_count(1); - Pointer pointer5 = pool.get_for_byte_count(2); - Pointer pointer6 = pool.get_for_byte_count(1); - pointer4.release(); - pointer5.release(); - pointer6.release(); - ASSERT_TRUE(4LL == pool.pool_count()); - } - { - MemoryPoolST pool; - ASSERT_TRUE(0LL == pool.pool_count()); - - Pointer pointer{ pool.get_for_byte_count(4 * 0) }; - ASSERT_FALSE(pointer.is_set()); - pointer.release(); - ASSERT_TRUE(0LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(4 * 2); - int *allocation1 = pointer.get(); - ASSERT_TRUE(pointer.is_set()); - pointer.release(); - ASSERT_FALSE(pointer.is_set()); - ASSERT_TRUE(1LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(4 * 2); - ASSERT_TRUE(allocation1 == pointer.get()); - ASSERT_TRUE(pointer.is_set()); - pointer.release(); - ASSERT_FALSE(pointer.is_set()); - ASSERT_TRUE(1LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(4 * 1); - ASSERT_FALSE(allocation1 == pointer.get()); - ASSERT_TRUE(pointer.is_set()); - pointer.release(); - ASSERT_FALSE(pointer.is_set()); - ASSERT_TRUE(2LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(4 * 2); - ASSERT_TRUE(allocation1 == pointer.get()); - Pointer pointer2 = pool.get_for_byte_count(4 * 2); - int *allocation2 = pointer2.get(); - ASSERT_FALSE(allocation2 == pointer.get()); - ASSERT_TRUE(pointer.is_set()); - pointer.release(); - pointer2.release(); - ASSERT_TRUE(2LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(4 * 2); - ASSERT_TRUE(allocation2 == pointer.get()); - pointer2 = pool.get_for_byte_count(4 * 2); - ASSERT_TRUE(allocation1 == pointer2.get()); - Pointer pointer3 = pool.get_for_byte_count(4 * 1); - pointer.release(); - pointer2.release(); - pointer3.release(); - ASSERT_TRUE(2LL == pool.pool_count()); - - Pointer pointer4 = pool.get_for_byte_count(1); - Pointer pointer5 = pool.get_for_byte_count(2); - Pointer pointer6 = pool.get_for_byte_count(1); - pointer4.release(); - pointer5.release(); - pointer6.release(); - ASSERT_TRUE(4LL == pool.pool_count()); - } - { - MemoryPoolST pool; - ASSERT_TRUE(0LL == pool.pool_count()); - - Pointer pointer = pool.get_for_byte_count(0); - ASSERT_FALSE(pointer.is_set()); - pointer.release(); - ASSERT_TRUE(0LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(2); - SEAL_BYTE *allocation1 = pointer.get(); - ASSERT_TRUE(pointer.is_set()); - pointer.release(); - ASSERT_FALSE(pointer.is_set()); - ASSERT_TRUE(1LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(2); - ASSERT_TRUE(allocation1 == pointer.get()); - ASSERT_TRUE(pointer.is_set()); - pointer.release(); - ASSERT_FALSE(pointer.is_set()); - ASSERT_TRUE(1LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(1); - ASSERT_FALSE(allocation1 == pointer.get()); - ASSERT_TRUE(pointer.is_set()); - pointer.release(); - ASSERT_FALSE(pointer.is_set()); - ASSERT_TRUE(2LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(2); - ASSERT_TRUE(allocation1 == pointer.get()); - Pointer pointer2 = pool.get_for_byte_count(2); - SEAL_BYTE *allocation2 = pointer2.get(); - ASSERT_FALSE(allocation2 == pointer.get()); - ASSERT_TRUE(pointer.is_set()); - pointer.release(); - pointer2.release(); - ASSERT_TRUE(2LL == pool.pool_count()); - - pointer = pool.get_for_byte_count(2); - ASSERT_TRUE(allocation2 == pointer.get()); - pointer2 = pool.get_for_byte_count(2); - ASSERT_TRUE(allocation1 == pointer2.get()); - Pointer pointer3 = pool.get_for_byte_count(1); - pointer.release(); - pointer2.release(); - pointer3.release(); - ASSERT_TRUE(2LL == pool.pool_count()); - } - } - - TEST(MemoryPoolTests, PointerTestsST) - { - MemoryPoolST pool; - { - Pointer p1; - ASSERT_FALSE(p1.is_set()); - ASSERT_TRUE(p1.get() == nullptr); - - p1 = pool.get_for_byte_count(bytes_per_uint64 * 1); - uint64_t *allocation1 = p1.get(); - ASSERT_TRUE(p1.is_set()); - ASSERT_TRUE(p1.get() != nullptr); - - p1.release(); - ASSERT_FALSE(p1.is_set()); - ASSERT_TRUE(p1.get() == nullptr); - - p1 = pool.get_for_byte_count(bytes_per_uint64 * 1); - ASSERT_TRUE(p1.is_set()); - ASSERT_TRUE(p1.get() == allocation1); - - Pointer p2; - p2.acquire(p1); - ASSERT_FALSE(p1.is_set()); - ASSERT_TRUE(p2.is_set()); - ASSERT_TRUE(p2.get() == allocation1); - - ConstPointer cp2; - cp2.acquire(p2); - ASSERT_FALSE(p2.is_set()); - ASSERT_TRUE(cp2.is_set()); - ASSERT_TRUE(cp2.get() == allocation1); - cp2.release(); - - Pointer p3 = pool.get_for_byte_count(bytes_per_uint64 * 1); - ASSERT_TRUE(p3.is_set()); - ASSERT_TRUE(p3.get() == allocation1); - - Pointer p4 = pool.get_for_byte_count(bytes_per_uint64 * 2); - ASSERT_TRUE(p4.is_set()); - uint64_t *allocation2 = p4.get(); - swap(p3, p4); - ASSERT_TRUE(p3.is_set()); - ASSERT_TRUE(p3.get() == allocation2); - ASSERT_TRUE(p4.is_set()); - ASSERT_TRUE(p4.get() == allocation1); - p3.release(); - p4.release(); - } - { - Pointer p1; - ASSERT_FALSE(p1.is_set()); - ASSERT_TRUE(p1.get() == nullptr); - - p1 = pool.get_for_byte_count(bytes_per_uint64 * 1); - SEAL_BYTE *allocation1 = p1.get(); - ASSERT_TRUE(p1.is_set()); - ASSERT_TRUE(p1.get() != nullptr); - - p1.release(); - ASSERT_FALSE(p1.is_set()); - ASSERT_TRUE(p1.get() == nullptr); - - p1 = pool.get_for_byte_count(bytes_per_uint64 * 1); - ASSERT_TRUE(p1.is_set()); - ASSERT_TRUE(p1.get() == allocation1); - - Pointer p2; - p2.acquire(p1); - ASSERT_FALSE(p1.is_set()); - ASSERT_TRUE(p2.is_set()); - ASSERT_TRUE(p2.get() == allocation1); - - ConstPointer cp2; - cp2.acquire(p2); - ASSERT_FALSE(p2.is_set()); - ASSERT_TRUE(cp2.is_set()); - ASSERT_TRUE(cp2.get() == allocation1); - cp2.release(); - - Pointer p3 = pool.get_for_byte_count(bytes_per_uint64 * 1); - ASSERT_TRUE(p3.is_set()); - ASSERT_TRUE(p3.get() == allocation1); - - Pointer p4 = pool.get_for_byte_count(bytes_per_uint64 * 2); - ASSERT_TRUE(p4.is_set()); - SEAL_BYTE *allocation2 = p4.get(); - swap(p3, p4); - ASSERT_TRUE(p3.is_set()); - ASSERT_TRUE(p3.get() == allocation2); - ASSERT_TRUE(p4.is_set()); - ASSERT_TRUE(p4.get() == allocation1); - p3.release(); - p4.release(); - } - } - - TEST(MemoryPoolTests, DuplicateIfNeededST) - { - { - unique_ptr allocation(new uint64_t[2]); - allocation[0] = 0x1234567812345678; - allocation[1] = 0x8765432187654321; - - MemoryPoolST pool; - Pointer p1 = duplicate_if_needed(allocation.get(), 2, false, pool); - ASSERT_TRUE(p1.is_set()); - ASSERT_TRUE(p1.get() == allocation.get()); - ASSERT_TRUE(0LL == pool.pool_count()); - - p1 = duplicate_if_needed(allocation.get(), 2, true, pool); - ASSERT_TRUE(p1.is_set()); - ASSERT_FALSE(p1.get() == allocation.get()); - ASSERT_TRUE(1LL == pool.pool_count()); - ASSERT_TRUE(p1.get()[0] == 0x1234567812345678); - ASSERT_TRUE(p1.get()[1] == 0x8765432187654321); - p1.release(); - } - { - unique_ptr allocation(new int64_t[2]); - allocation[0] = 0x234567812345678; - allocation[1] = 0x765432187654321; - - MemoryPoolST pool; - Pointer p1 = duplicate_if_needed(allocation.get(), 2, false, pool); - ASSERT_TRUE(p1.is_set()); - ASSERT_TRUE(p1.get() == allocation.get()); - ASSERT_TRUE(0LL == pool.pool_count()); - - p1 = duplicate_if_needed(allocation.get(), 2, true, pool); - ASSERT_TRUE(p1.is_set()); - ASSERT_FALSE(p1.get() == allocation.get()); - ASSERT_TRUE(1LL == pool.pool_count()); - ASSERT_TRUE(p1.get()[0] == 0x234567812345678); - ASSERT_TRUE(p1.get()[1] == 0x765432187654321); - p1.release(); - } - { - unique_ptr allocation(new int[2]); - allocation[0] = 0x123; - allocation[1] = 0x876; - - MemoryPoolST pool; - Pointer p1 = duplicate_if_needed(allocation.get(), 2, false, pool); - ASSERT_TRUE(p1.is_set()); - ASSERT_TRUE(p1.get() == allocation.get()); - ASSERT_TRUE(0LL == pool.pool_count()); - - p1 = duplicate_if_needed(allocation.get(), 2, true, pool); - ASSERT_TRUE(p1.is_set()); - ASSERT_FALSE(p1.get() == allocation.get()); - ASSERT_TRUE(1LL == pool.pool_count()); - ASSERT_TRUE(p1.get()[0] == 0x123); - ASSERT_TRUE(p1.get()[1] == 0x876); - p1.release(); - } - } - } -} \ No newline at end of file diff --git a/SEAL/native/tests/seal/util/numth.cpp b/SEAL/native/tests/seal/util/numth.cpp deleted file mode 100644 index 145f9ad..0000000 --- a/SEAL/native/tests/seal/util/numth.cpp +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/util/numth.h" -#include - -using namespace seal::util; -using namespace std; - -namespace SEALTest -{ - namespace util - { - TEST(NumberTheoryTest, GCD) - { - ASSERT_EQ(1ULL, gcd(1, 1)); - ASSERT_EQ(1ULL, gcd(2, 1)); - ASSERT_EQ(1ULL, gcd(1, 2)); - ASSERT_EQ(2ULL, gcd(2, 2)); - ASSERT_EQ(3ULL, gcd(6, 15)); - ASSERT_EQ(3ULL, gcd(15, 6)); - ASSERT_EQ(1ULL, gcd(7, 15)); - ASSERT_EQ(1ULL, gcd(15, 7)); - ASSERT_EQ(1ULL, gcd(7, 15)); - ASSERT_EQ(3ULL, gcd(11112, 44445)); - } - - TEST(NumberTheoryTest, ExtendedGCD) - { - tuple result; - - // Corner case behavior - result = xgcd(7, 7); - ASSERT_TRUE(result == make_tuple<>(7, 0, 1)); - result = xgcd(2, 2); - ASSERT_TRUE(result == make_tuple<>(2, 0, 1)); - - result = xgcd(1, 1); - ASSERT_TRUE(result == make_tuple<>(1, 0, 1)); - result = xgcd(1, 2); - ASSERT_TRUE(result == make_tuple<>(1, 1, 0)); - result = xgcd(5, 6); - ASSERT_TRUE(result == make_tuple<>(1, -1, 1)); - result = xgcd(13, 19); - ASSERT_TRUE(result == make_tuple<>(1, 3, -2)); - result = xgcd(14, 21); - ASSERT_TRUE(result == make_tuple<>(7, -1, 1)); - - result = xgcd(2, 1); - ASSERT_TRUE(result == make_tuple<>(1, 0, 1)); - result = xgcd(6, 5); - ASSERT_TRUE(result == make_tuple<>(1, 1, -1)); - result = xgcd(19, 13); - ASSERT_TRUE(result == make_tuple<>(1, -2, 3)); - result = xgcd(21, 14); - ASSERT_TRUE(result == make_tuple<>(7, 1, -1)); - } - - TEST(NumberTheoryTest, TryModInverse) - { - uint64_t input, modulus, result; - - input = 1, modulus = 2; - ASSERT_TRUE(try_mod_inverse(input, modulus, result)); - ASSERT_EQ(result, 1ULL); - - input = 2, modulus = 2; - ASSERT_FALSE(try_mod_inverse(input, modulus, result)); - - input = 3, modulus = 2; - ASSERT_TRUE(try_mod_inverse(input, modulus, result)); - ASSERT_EQ(result, 1ULL); - - input = 0xFFFFFF, modulus = 2; - ASSERT_TRUE(try_mod_inverse(input, modulus, result)); - ASSERT_EQ(result, 1ULL); - - input = 0xFFFFFE, modulus = 2; - ASSERT_FALSE(try_mod_inverse(input, modulus, result)); - - input = 12345, modulus = 3; - ASSERT_FALSE(try_mod_inverse(input, modulus, result)); - - input = 5, modulus = 19; - ASSERT_TRUE(try_mod_inverse(input, modulus, result)); - ASSERT_EQ(result, 4ULL); - - input = 4, modulus = 19; - ASSERT_TRUE(try_mod_inverse(input, modulus, result)); - ASSERT_EQ(result, 5ULL); - } - - TEST(NumberTheoryTest, IsPrime) - { - ASSERT_FALSE(is_prime(0)); - ASSERT_TRUE(is_prime(2)); - ASSERT_TRUE(is_prime(3)); - ASSERT_FALSE(is_prime(4)); - ASSERT_TRUE(is_prime(5)); - ASSERT_FALSE(is_prime(221)); - ASSERT_TRUE(is_prime(65537)); - ASSERT_FALSE(is_prime(65536)); - ASSERT_TRUE(is_prime(59399)); - ASSERT_TRUE(is_prime(72307)); - ASSERT_FALSE(is_prime(72307ULL * 59399ULL)); - ASSERT_TRUE(is_prime(36893488147419103ULL)); - ASSERT_FALSE(is_prime(36893488147419107ULL)); - } - } -} \ No newline at end of file diff --git a/SEAL/native/tests/seal/util/polyarith.cpp b/SEAL/native/tests/seal/util/polyarith.cpp deleted file mode 100644 index 76e49bf..0000000 --- a/SEAL/native/tests/seal/util/polyarith.cpp +++ /dev/null @@ -1,505 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/util/uintcore.h" -#include "seal/util/polyarith.h" -#include - -using namespace seal::util; -using namespace std; -using namespace seal; - -namespace SEALTest -{ - namespace util - { - TEST(PolyArith, RightShiftPolyCoeffs) - { - right_shift_poly_coeffs(nullptr, 0, 0, 0, nullptr); - right_shift_poly_coeffs(nullptr, 0, 0, 1, nullptr); - - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_zero_poly(3, 2, pool)); - ptr[0] = 2; - ptr[1] = 4; - ptr[2] = 8; - right_shift_poly_coeffs(ptr.get(), 3, 1, 0, ptr.get()); - ASSERT_EQ(2ULL, ptr[0]); - ASSERT_EQ(4ULL, ptr[1]); - ASSERT_EQ(8ULL, ptr[2]); - - right_shift_poly_coeffs(ptr.get(), 3, 1, 1, ptr.get()); - ASSERT_EQ(1ULL, ptr[0]); - ASSERT_EQ(2ULL, ptr[1]); - ASSERT_EQ(4ULL, ptr[2]); - - right_shift_poly_coeffs(ptr.get(), 3, 1, 1, ptr.get()); - ASSERT_EQ(0ULL, ptr[0]); - ASSERT_EQ(1ULL, ptr[1]); - ASSERT_EQ(2ULL, ptr[2]); - - ptr[0] = 3; - ptr[1] = 5; - ptr[2] = 9; - right_shift_poly_coeffs(ptr.get(), 3, 1, 2, ptr.get()); - ASSERT_EQ(0ULL, ptr[0]); - ASSERT_EQ(1ULL, ptr[1]); - ASSERT_EQ(2ULL, ptr[2]); - - ptr[0] = 3; - ptr[1] = 5; - ptr[2] = 9; - right_shift_poly_coeffs(ptr.get(), 3, 1, 4, ptr.get()); - ASSERT_EQ(0ULL, ptr[0]); - ASSERT_EQ(0ULL, ptr[1]); - ASSERT_EQ(0ULL, ptr[2]); - - ptr[0] = 1; - ptr[1] = 1; - ptr[2] = 1; - right_shift_poly_coeffs(ptr.get(), 1, 2, 64, ptr.get()); - ASSERT_EQ(1ULL, ptr[0]); - ASSERT_EQ(0ULL, ptr[1]); - ASSERT_EQ(1ULL, ptr[2]); - - ptr[0] = 3; - ptr[1] = 5; - ptr[2] = 9; - right_shift_poly_coeffs(ptr.get(), 1, 3, 128, ptr.get()); - ASSERT_EQ(9ULL, ptr[0]); - ASSERT_EQ(0ULL, ptr[1]); - ASSERT_EQ(0ULL, ptr[2]); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - ptr[2] = 0xFFFFFFFFFFFFFFFF; - right_shift_poly_coeffs(ptr.get(), 1, 3, 191, ptr.get()); - ASSERT_EQ(1ULL, ptr[0]); - ASSERT_EQ(0ULL, ptr[1]); - ASSERT_EQ(0ULL, ptr[2]); - } - - TEST(PolyArith, NegatePoly) - { - negate_poly(nullptr, 0, 0, nullptr); - - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_zero_poly(3, 2, pool)); - ptr[0] = 2; - ptr[2] = 3; - ptr[4] = 4; - negate_poly(ptr.get(), 3, 2, ptr.get()); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[1]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFD), ptr[2]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[3]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFC), ptr[4]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[5]); - } - - TEST(PolyArith, AddPolyPoly) - { - add_poly_poly(nullptr, nullptr, 0, 0, nullptr); - - MemoryPool &pool = *global_variables::global_memory_pool; - auto poly1(allocate_zero_poly(3, 2, pool)); - auto poly2(allocate_zero_poly(3, 2, pool)); - - poly1[0] = 0; - poly1[1] = 0xFFFFFFFFFFFFFFFF; - poly1[2] = 1; - poly1[3] = 0; - poly1[4] = 0xFFFFFFFFFFFFFFFF; - poly1[5] = 1; - poly2[0] = 1; - poly2[1] = 1; - poly2[2] = 1; - poly2[3] = 1; - poly2[4] = 0xFFFFFFFFFFFFFFFF; - poly2[5] = 1; - add_poly_poly(poly1.get(), poly2.get(), 3, 2, poly1.get()); - ASSERT_EQ(static_cast(1), poly1[0]); - ASSERT_EQ(static_cast(0), poly1[1]); - ASSERT_EQ(static_cast(2), poly1[2]); - ASSERT_EQ(static_cast(1), poly1[3]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), poly1[4]); - ASSERT_EQ(static_cast(3), poly1[5]); - - poly1[0] = 2; - poly1[1] = 0; - poly1[2] = 3; - poly1[3] = 0; - poly1[4] = 0xFFFFFFFFFFFFFFFF; - poly1[5] = 0xFFFFFFFFFFFFFFFF; - poly2[0] = 5; - poly2[1] = 0; - poly2[2] = 6; - poly2[3] = 0; - poly2[4] = 0xFFFFFFFFFFFFFFFF; - poly2[5] = 0xFFFFFFFFFFFFFFFF; - add_poly_poly(poly1.get(), poly2.get(), 3, 2, poly1.get()); - ASSERT_EQ(static_cast(7), poly1[0]); - ASSERT_EQ(static_cast(0), poly1[1]); - ASSERT_EQ(static_cast(9), poly1[2]); - ASSERT_EQ(static_cast(0), poly1[3]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), poly1[4]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), poly1[5]); - } - - TEST(PolyArith, SubPolyPoly) - { - sub_poly_poly(nullptr, nullptr, 0, 0, nullptr); - - MemoryPool &pool = *global_variables::global_memory_pool; - auto poly1(allocate_zero_poly(3, 2, pool)); - auto poly2(allocate_zero_poly(3, 2, pool)); - - poly1[0] = 0; - poly1[1] = 0xFFFFFFFFFFFFFFFF; - poly1[2] = 1; - poly1[3] = 0; - poly1[4] = 0xFFFFFFFFFFFFFFFF; - poly1[5] = 1; - poly2[0] = 1; - poly2[1] = 1; - poly2[2] = 1; - poly2[3] = 1; - poly2[4] = 0xFFFFFFFFFFFFFFFF; - poly2[5] = 1; - sub_poly_poly(poly1.get(), poly2.get(), 6, 1, poly1.get()); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), poly1[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), poly1[1]); - ASSERT_EQ(static_cast(0), poly1[2]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), poly1[3]); - ASSERT_EQ(static_cast(0), poly1[4]); - ASSERT_EQ(static_cast(0), poly1[5]); - - poly1[0] = 5; - poly1[1] = 0; - poly1[2] = 6; - poly1[3] = 0; - poly1[4] = 0xFFFFFFFFFFFFFFFF; - poly1[5] = 0xFFFFFFFFFFFFFFFF; - poly2[0] = 2; - poly2[1] = 0; - poly2[2] = 8; - poly2[3] = 0; - poly2[4] = 0xFFFFFFFFFFFFFFFE; - poly2[5] = 0xFFFFFFFFFFFFFFFF; - sub_poly_poly(poly1.get(), poly2.get(), 3, 2, poly1.get()); - ASSERT_EQ(static_cast(3), poly1[0]); - ASSERT_EQ(static_cast(0), poly1[1]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), poly1[2]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), poly1[3]); - ASSERT_EQ(1ULL, poly1[4]); - ASSERT_EQ(static_cast(0), poly1[5]); - } - - TEST(PolyArith, MultiplyPolyPoly) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto poly1(allocate_zero_poly(3, 2, pool)); - auto poly2(allocate_zero_poly(3, 2, pool)); - auto result(allocate_zero_poly(5, 2, pool)); - poly1[0] = 1; - poly1[2] = 2; - poly1[4] = 3; - poly2[0] = 2; - poly2[2] = 3; - poly2[4] = 4; - multiply_poly_poly(poly1.get(), 3, 2, poly2.get(), 3, 2, 5, 2, result.get(), pool); - ASSERT_EQ(static_cast(2), result[0]); - ASSERT_EQ(static_cast(0), result[1]); - ASSERT_EQ(static_cast(7), result[2]); - ASSERT_EQ(static_cast(0), result[3]); - ASSERT_EQ(static_cast(16), result[4]); - ASSERT_EQ(static_cast(0), result[5]); - ASSERT_EQ(static_cast(17), result[6]); - ASSERT_EQ(static_cast(0), result[7]); - ASSERT_EQ(static_cast(12), result[8]); - ASSERT_EQ(static_cast(0), result[9]); - - poly2[0] = 2; - poly2[1] = 3; - multiply_poly_poly(poly1.get(), 3, 2, poly2.get(), 2, 1, 5, 2, result.get(), pool); - ASSERT_EQ(static_cast(2), result[0]); - ASSERT_EQ(static_cast(0), result[1]); - ASSERT_EQ(static_cast(7), result[2]); - ASSERT_EQ(static_cast(0), result[3]); - ASSERT_EQ(static_cast(12), result[4]); - ASSERT_EQ(static_cast(0), result[5]); - ASSERT_EQ(static_cast(9), result[6]); - ASSERT_EQ(static_cast(0), result[7]); - ASSERT_EQ(static_cast(0), result[8]); - ASSERT_EQ(static_cast(0), result[9]); - - multiply_poly_poly(poly1.get(), 3, 2, poly2.get(), 2, 1, 5, 1, result.get(), pool); - ASSERT_EQ(static_cast(2), result[0]); - ASSERT_EQ(static_cast(7), result[1]); - ASSERT_EQ(static_cast(12), result[2]); - ASSERT_EQ(static_cast(9), result[3]); - ASSERT_EQ(static_cast(0), result[4]); - } - - TEST(PolyArith, PolyInftyNorm) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto poly(allocate_zero_poly(10, 1, pool)); - uint64_t result[2]; - - poly[0] = 1, poly[1] = 0, poly[2] = 1, poly[3] = 0, poly[4] = 0; - poly[5] = 4, poly[6] = 0xB, poly[7] = 0xA, poly[8] = 5, poly[9] = 2; - poly_infty_norm(poly.get(), 10, 1, result); - ASSERT_EQ(result[0], 0xBULL); - - poly[0] = 2, poly[1] = 0, poly[2] = 1, poly[3] = 0, poly[4] = 0; - poly[5] = 0xF7, poly[6] = 0xFE, poly[7] = 0xCF, poly[8] = 0xCA, poly[9] = 0xAB; - poly_infty_norm(poly.get(), 10, 1, result); - ASSERT_EQ(result[0], 0xFEULL); - - poly[0] = 2, poly[1] = 0, poly[2] = 1, poly[3] = 0, poly[4] = 0; - poly[5] = 0xABCDEF, poly[6] = 0xABCDE, poly[7] = 0xABCD, poly[8] = 0xABC, poly[9] = 0xAB; - poly_infty_norm(poly.get(), 10, 1, result); - ASSERT_EQ(result[0], 0xABCDEFULL); - - poly[0] = 6, poly[1] = 5, poly[2] = 4, poly[3] = 3, poly[4] = 2; - poly[5] = 1, poly[6] = 0; - poly_infty_norm(poly.get(), 6, 1, result); - ASSERT_EQ(result[0], 6ULL); - - poly[0] = 1, poly[1] = 0, poly[2] = 1, poly[3] = 0, poly[4] = 0; - poly[5] = 4, poly[6] = 0xB, poly[7] = 0xA, poly[8] = 5, poly[9] = 2; - poly_infty_norm(poly.get(), 5, 2, result); - ASSERT_EQ(result[0], 0xBULL); - ASSERT_EQ(result[1], 0xAULL); - - poly[0] = 2, poly[1] = 0, poly[2] = 1, poly[3] = 0, poly[4] = 0; - poly[5] = 0xF7, poly[6] = 0xFE, poly[7] = 0xCF, poly[8] = 0xCA, poly[9] = 0xAB; - poly_infty_norm(poly.get(), 5, 2, result); - ASSERT_EQ(result[0], 0x0ULL); - ASSERT_EQ(result[1], 0xF7ULL); - - poly[0] = 2, poly[1] = 0, poly[2] = 1, poly[3] = 0, poly[4] = 0; - poly[5] = 0xABCDEF, poly[6] = 0xABCDE, poly[7] = 0xABCD, poly[8] = 0xABC, poly[9] = 0xAB; - poly_infty_norm(poly.get(), 5, 2, result); - ASSERT_EQ(result[0], 0ULL); - ASSERT_EQ(result[1], 0xABCDEFULL); - - poly[0] = 6, poly[1] = 5, poly[2] = 4, poly[3] = 3, poly[4] = 2; - poly[5] = 1, poly[6] = 0; - poly_infty_norm(poly.get(), 3, 2, result); - ASSERT_EQ(result[0], 6ULL); - ASSERT_EQ(result[1], 5ULL); - } - - TEST(PolyArith, PolyEvalPoly) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto poly1(allocate_zero_poly(4, 1, pool)); - auto poly2(allocate_zero_poly(4, 1, pool)); - auto poly3(allocate_zero_poly(8, 1, pool)); - - poly_eval_poly(poly1.get(), 4, 1, poly2.get(), 4, 1, 8, 1, poly3.get(), pool); - ASSERT_EQ(poly3[0], 0ULL); - ASSERT_EQ(poly3[1], 0ULL); - ASSERT_EQ(poly3[2], 0ULL); - ASSERT_EQ(poly3[3], 0ULL); - ASSERT_EQ(poly3[4], 0ULL); - ASSERT_EQ(poly3[5], 0ULL); - ASSERT_EQ(poly3[6], 0ULL); - ASSERT_EQ(poly3[7], 0ULL); - - poly1[0] = 1; - poly_eval_poly(poly1.get(), 4, 1, poly2.get(), 4, 1, 8, 1, poly3.get(), pool); - ASSERT_EQ(poly3[0], 1ULL); - ASSERT_EQ(poly3[1], 0ULL); - ASSERT_EQ(poly3[2], 0ULL); - ASSERT_EQ(poly3[3], 0ULL); - ASSERT_EQ(poly3[4], 0ULL); - ASSERT_EQ(poly3[5], 0ULL); - ASSERT_EQ(poly3[6], 0ULL); - ASSERT_EQ(poly3[7], 0ULL); - - poly1[0] = 2; - poly2[0] = 1; - poly_eval_poly(poly1.get(), 4, 1, poly2.get(), 4, 1, 8, 1, poly3.get(), pool); - ASSERT_EQ(poly3[0], 2ULL); - ASSERT_EQ(poly3[1], 0ULL); - ASSERT_EQ(poly3[2], 0ULL); - ASSERT_EQ(poly3[3], 0ULL); - ASSERT_EQ(poly3[4], 0ULL); - ASSERT_EQ(poly3[5], 0ULL); - ASSERT_EQ(poly3[6], 0ULL); - ASSERT_EQ(poly3[7], 0ULL); - - poly1[0] = 1; - poly1[1] = 1; - poly2[0] = 1; - poly_eval_poly(poly1.get(), 4, 1, poly2.get(), 4, 1, 8, 1, poly3.get(), pool); - ASSERT_EQ(poly3[0], 2ULL); - ASSERT_EQ(poly3[1], 0ULL); - ASSERT_EQ(poly3[2], 0ULL); - ASSERT_EQ(poly3[3], 0ULL); - ASSERT_EQ(poly3[4], 0ULL); - ASSERT_EQ(poly3[5], 0ULL); - ASSERT_EQ(poly3[6], 0ULL); - ASSERT_EQ(poly3[7], 0ULL); - - poly1[0] = 1; - poly1[1] = 1; - poly2[0] = 2; - poly2[1] = 0; - poly2[2] = 1; - poly_eval_poly(poly1.get(), 4, 1, poly2.get(), 4, 1, 8, 1, poly3.get(), pool); - ASSERT_EQ(poly3[0], 3ULL); - ASSERT_EQ(poly3[1], 0ULL); - ASSERT_EQ(poly3[2], 1ULL); - ASSERT_EQ(poly3[3], 0ULL); - ASSERT_EQ(poly3[4], 0ULL); - ASSERT_EQ(poly3[5], 0ULL); - ASSERT_EQ(poly3[6], 0ULL); - ASSERT_EQ(poly3[7], 0ULL); - - poly1[0] = 2; - poly1[1] = 0; - poly1[2] = 1; - poly2[0] = 1; - poly2[1] = 1; - poly2[2] = 0; - poly_eval_poly(poly1.get(), 4, 1, poly2.get(), 4, 1, 8, 1, poly3.get(), pool); - ASSERT_EQ(poly3[0], 3ULL); - ASSERT_EQ(poly3[1], 2ULL); - ASSERT_EQ(poly3[2], 1ULL); - ASSERT_EQ(poly3[3], 0ULL); - ASSERT_EQ(poly3[4], 0ULL); - ASSERT_EQ(poly3[5], 0ULL); - ASSERT_EQ(poly3[6], 0ULL); - ASSERT_EQ(poly3[7], 0ULL); - - poly1[0] = 0; - poly1[1] = 0; - poly1[2] = 0; - poly1[3] = 1; - poly2[0] = 2; - poly2[1] = 0; - poly2[2] = 0; - poly2[3] = 0; - poly_eval_poly(poly1.get(), 4, 1, poly2.get(), 4, 1, 8, 1, poly3.get(), pool); - ASSERT_EQ(poly3[0], 8ULL); - ASSERT_EQ(poly3[1], 0ULL); - ASSERT_EQ(poly3[2], 0ULL); - ASSERT_EQ(poly3[3], 0ULL); - ASSERT_EQ(poly3[4], 0ULL); - ASSERT_EQ(poly3[5], 0ULL); - ASSERT_EQ(poly3[6], 0ULL); - ASSERT_EQ(poly3[7], 0ULL); - - poly1[0] = 0; - poly1[1] = 0; - poly1[2] = 0; - poly1[3] = 1; - poly2[0] = 0; - poly2[1] = 0; - poly2[2] = 2; - poly2[3] = 0; - poly_eval_poly(poly1.get(), 4, 1, poly2.get(), 4, 1, 8, 1, poly3.get(), pool); - ASSERT_EQ(poly3[0], 0ULL); - ASSERT_EQ(poly3[1], 0ULL); - ASSERT_EQ(poly3[2], 0ULL); - ASSERT_EQ(poly3[3], 0ULL); - ASSERT_EQ(poly3[4], 0ULL); - ASSERT_EQ(poly3[5], 0ULL); - ASSERT_EQ(poly3[6], 8ULL); - ASSERT_EQ(poly3[7], 0ULL); - } - - TEST(PolyArith, ExponentiatePoly) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto poly1(allocate_zero_poly(4, 1, pool)); - auto poly2(allocate_zero_poly(12, 1, pool)); - - uint64_t exponent = 1; - exponentiate_poly(poly1.get(), 4, 1, &exponent, 1, 12, 1, poly2.get(), pool); - ASSERT_EQ(poly2[0], 0ULL); - ASSERT_EQ(poly2[1], 0ULL); - ASSERT_EQ(poly2[2], 0ULL); - ASSERT_EQ(poly2[3], 0ULL); - ASSERT_EQ(poly2[4], 0ULL); - ASSERT_EQ(poly2[5], 0ULL); - ASSERT_EQ(poly2[6], 0ULL); - ASSERT_EQ(poly2[7], 0ULL); - ASSERT_EQ(poly2[8], 0ULL); - ASSERT_EQ(poly2[9], 0ULL); - ASSERT_EQ(poly2[10], 0ULL); - ASSERT_EQ(poly2[11], 0ULL); - - exponent = 0; - exponentiate_poly(poly1.get(), 4, 1, &exponent, 1, 12, 1, poly2.get(), pool); - ASSERT_EQ(poly2[0], 1ULL); - ASSERT_EQ(poly2[1], 0ULL); - ASSERT_EQ(poly2[2], 0ULL); - ASSERT_EQ(poly2[3], 0ULL); - ASSERT_EQ(poly2[4], 0ULL); - ASSERT_EQ(poly2[5], 0ULL); - ASSERT_EQ(poly2[6], 0ULL); - ASSERT_EQ(poly2[7], 0ULL); - ASSERT_EQ(poly2[8], 0ULL); - ASSERT_EQ(poly2[9], 0ULL); - ASSERT_EQ(poly2[10], 0ULL); - ASSERT_EQ(poly2[11], 0ULL); - - exponent = 3; - poly1[1] = 2; - exponentiate_poly(poly1.get(), 4, 1, &exponent, 1, 12, 1, poly2.get(), pool); - ASSERT_EQ(poly2[0], 0ULL); - ASSERT_EQ(poly2[1], 0ULL); - ASSERT_EQ(poly2[2], 0ULL); - ASSERT_EQ(poly2[3], 8ULL); - ASSERT_EQ(poly2[4], 0ULL); - ASSERT_EQ(poly2[5], 0ULL); - ASSERT_EQ(poly2[6], 0ULL); - ASSERT_EQ(poly2[7], 0ULL); - ASSERT_EQ(poly2[8], 0ULL); - ASSERT_EQ(poly2[9], 0ULL); - ASSERT_EQ(poly2[10], 0ULL); - ASSERT_EQ(poly2[11], 0ULL); - - exponent = 3; - poly1[0] = 1; - poly1[1] = 1; - exponentiate_poly(poly1.get(), 4, 1, &exponent, 1, 12, 1, poly2.get(), pool); - ASSERT_EQ(poly2[0], 1ULL); - ASSERT_EQ(poly2[1], 3ULL); - ASSERT_EQ(poly2[2], 3ULL); - ASSERT_EQ(poly2[3], 1ULL); - ASSERT_EQ(poly2[4], 0ULL); - ASSERT_EQ(poly2[5], 0ULL); - ASSERT_EQ(poly2[6], 0ULL); - ASSERT_EQ(poly2[7], 0ULL); - ASSERT_EQ(poly2[8], 0ULL); - ASSERT_EQ(poly2[9], 0ULL); - ASSERT_EQ(poly2[10], 0ULL); - ASSERT_EQ(poly2[11], 0ULL); - - exponent = 5; - poly1[0] = 0; - poly1[1] = 0; - poly1[2] = 2; - exponentiate_poly(poly1.get(), 4, 1, &exponent, 1, 12, 1, poly2.get(), pool); - ASSERT_EQ(poly2[0], 0ULL); - ASSERT_EQ(poly2[1], 0ULL); - ASSERT_EQ(poly2[2], 0ULL); - ASSERT_EQ(poly2[3], 0ULL); - ASSERT_EQ(poly2[4], 0ULL); - ASSERT_EQ(poly2[5], 0ULL); - ASSERT_EQ(poly2[6], 0ULL); - ASSERT_EQ(poly2[7], 0ULL); - ASSERT_EQ(poly2[8], 0ULL); - ASSERT_EQ(poly2[9], 0ULL); - ASSERT_EQ(poly2[10], 32ULL); - ASSERT_EQ(poly2[11], 0ULL); - } - } -} diff --git a/SEAL/native/tests/seal/util/polyarithmod.cpp b/SEAL/native/tests/seal/util/polyarithmod.cpp deleted file mode 100644 index 3ad719b..0000000 --- a/SEAL/native/tests/seal/util/polyarithmod.cpp +++ /dev/null @@ -1,96 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/util/uintcore.h" -#include "seal/util/polycore.h" -#include "seal/util/polyarithmod.h" -#include - -using namespace seal; -using namespace seal::util; -using namespace std; - -namespace SEALTest -{ - namespace util - { - TEST(PolyArithMod, NegatePolyCoeffMod) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto poly(allocate_zero_poly(3, 2, pool)); - auto modulus(allocate_uint(2, pool)); - poly[0] = 2; - poly[2] = 3; - poly[4] = 4; - modulus[0] = 15; - modulus[1] = 0; - negate_poly_coeffmod(poly.get(), 3, modulus.get(), 2, poly.get()); - ASSERT_EQ(static_cast(13), poly[0]); - ASSERT_EQ(static_cast(0), poly[1]); - ASSERT_EQ(static_cast(12), poly[2]); - ASSERT_EQ(static_cast(0), poly[3]); - ASSERT_EQ(static_cast(11), poly[4]); - ASSERT_EQ(static_cast(0), poly[5]); - - poly[0] = 2; - poly[2] = 3; - poly[4] = 4; - modulus[0] = 0xFFFFFFFFFFFFFFFF; - modulus[1] = 0xFFFFFFFFFFFFFFFF; - negate_poly_coeffmod(poly.get(), 3, modulus.get(), 2, poly.get()); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFD), poly[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), poly[1]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFC), poly[2]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), poly[3]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFB), poly[4]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), poly[5]); - } - - TEST(PolyArithMod, AddPolyPolyCoeffMod) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto poly1(allocate_zero_poly(3, 2, pool)); - auto poly2(allocate_zero_poly(3, 2, pool)); - auto modulus(allocate_uint(2, pool)); - poly1[0] = 1; - poly1[2] = 3; - poly1[4] = 4; - poly2[0] = 1; - poly2[2] = 2; - poly2[4] = 4; - modulus[0] = 5; - modulus[1] = 0; - add_poly_poly_coeffmod(poly1.get(), poly2.get(), 3, modulus.get(), 2, poly1.get()); - ASSERT_EQ(static_cast(2), poly1[0]); - ASSERT_EQ(static_cast(0), poly1[1]); - ASSERT_EQ(static_cast(0), poly1[2]); - ASSERT_EQ(static_cast(0), poly1[3]); - ASSERT_EQ(static_cast(3), poly1[4]); - ASSERT_EQ(static_cast(0), poly1[5]); - } - - TEST(PolyArithMod, SubPolyPolyCoeffMod) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto poly1(allocate_zero_poly(3, 2, pool)); - auto poly2(allocate_zero_poly(3, 2, pool)); - auto modulus(allocate_uint(2, pool)); - poly1[0] = 4; - poly1[2] = 3; - poly1[4] = 2; - poly2[0] = 2; - poly2[2] = 3; - poly2[4] = 4; - modulus[0] = 5; - modulus[1] = 0; - sub_poly_poly_coeffmod(poly1.get(), poly2.get(), 3, modulus.get(), 2, poly1.get()); - ASSERT_EQ(static_cast(2), poly1[0]); - ASSERT_EQ(static_cast(0), poly1[1]); - ASSERT_EQ(static_cast(0), poly1[2]); - ASSERT_EQ(static_cast(0), poly1[3]); - ASSERT_EQ(static_cast(3), poly1[4]); - ASSERT_EQ(static_cast(0), poly1[5]); - } - } -} diff --git a/SEAL/native/tests/seal/util/polyarithsmallmod.cpp b/SEAL/native/tests/seal/util/polyarithsmallmod.cpp deleted file mode 100644 index 9c1251d..0000000 --- a/SEAL/native/tests/seal/util/polyarithsmallmod.cpp +++ /dev/null @@ -1,413 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/util/uintcore.h" -#include "seal/util/polycore.h" -#include "seal/util/polyarithsmallmod.h" -#include -#include - -using namespace seal; -using namespace seal::util; -using namespace std; - -namespace SEALTest -{ - namespace util - { - TEST(PolyArithSmallMod, SmallModuloPolyCoeffs) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto poly(allocate_zero_poly(3, 1, pool)); - auto modulus(allocate_uint(2, pool)); - poly[0] = 2; - poly[1] = 15; - poly[2] = 77; - SmallModulus mod(15); - modulo_poly_coeffs(poly.get(), 3, mod, poly.get()); - ASSERT_EQ(2ULL, poly[0]); - ASSERT_EQ(0ULL, poly[1]); - ASSERT_EQ(2ULL, poly[2]); - } - - TEST(PolyArithSmallMod, NegatePolyCoeffSmallMod) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto poly(allocate_zero_poly(3, 1, pool)); - poly[0] = 2; - poly[1] = 3; - poly[2] = 4; - SmallModulus mod(15); - negate_poly_coeffmod(poly.get(), 3, mod, poly.get()); - ASSERT_EQ(static_cast(13), poly[0]); - ASSERT_EQ(static_cast(12), poly[1]); - ASSERT_EQ(static_cast(11), poly[2]); - - poly[0] = 2; - poly[1] = 3; - poly[2] = 4; - mod = 0xFFFFFFFFFFFFFFULL; - negate_poly_coeffmod(poly.get(), 3, mod, poly.get()); - ASSERT_EQ(0xFFFFFFFFFFFFFDULL, poly[0]); - ASSERT_EQ(0xFFFFFFFFFFFFFCULL, poly[1]); - ASSERT_EQ(0xFFFFFFFFFFFFFBULL, poly[2]); - } - - TEST(PolyArithSmallMod, AddPolyPolyCoeffSmallMod) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto poly1(allocate_zero_poly(3, 1, pool)); - auto poly2(allocate_zero_poly(3, 1, pool)); - poly1[0] = 1; - poly1[1] = 3; - poly1[2] = 4; - poly2[0] = 1; - poly2[1] = 2; - poly2[2] = 4; - SmallModulus mod(5); - add_poly_poly_coeffmod(poly1.get(), poly2.get(), 3, mod, poly1.get()); - ASSERT_EQ(2ULL, poly1[0]); - ASSERT_EQ(0ULL, poly1[1]); - ASSERT_EQ(3ULL, poly1[2]); - } - - TEST(PolyArithSmallMod, SubPolyPolyCoeffSmallMod) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto poly1(allocate_zero_poly(3, 1, pool)); - auto poly2(allocate_zero_poly(3, 1, pool)); - poly1[0] = 4; - poly1[1] = 3; - poly1[2] = 2; - poly2[0] = 2; - poly2[1] = 3; - poly2[2] = 4; - SmallModulus mod(5); - sub_poly_poly_coeffmod(poly1.get(), poly2.get(), 3, mod, poly1.get()); - ASSERT_EQ(2ULL, poly1[0]); - ASSERT_EQ(0ULL, poly1[1]); - ASSERT_EQ(3ULL, poly1[2]); - } - - TEST(PolyArithSmallMod, MultiplyPolyScalarCoeffSmallMod) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto poly(allocate_zero_poly(3, 1, pool)); - poly[0] = 1; - poly[1] = 3; - poly[2] = 4; - uint64_t scalar = 3; - SmallModulus mod(5); - multiply_poly_scalar_coeffmod(poly.get(), 3, scalar, mod, poly.get()); - ASSERT_EQ(3ULL, poly[0]); - ASSERT_EQ(4ULL, poly[1]); - ASSERT_EQ(2ULL, poly[2]); - } - - TEST(PolyArithSmallMod, MultiplyPolyMonoCoeffSmallMod) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto poly1(allocate_zero_poly(4, 1, pool)); - poly1[0] = 1; - poly1[1] = 3; - poly1[2] = 4; - poly1[3] = 2; - uint64_t mono_coeff = 3; - auto result(allocate_zero_poly(4, 1, pool)); - SmallModulus mod(5); - - size_t mono_exponent = 0; - negacyclic_multiply_poly_mono_coeffmod(poly1.get(), 1, mono_coeff, mono_exponent, mod, result.get(), pool); - ASSERT_EQ(3ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - ASSERT_EQ(0ULL, result[2]); - ASSERT_EQ(0ULL, result[3]); - - negacyclic_multiply_poly_mono_coeffmod(poly1.get(), 2, mono_coeff, mono_exponent, mod, result.get(), pool); - ASSERT_EQ(3ULL, result[0]); - ASSERT_EQ(4ULL, result[1]); - ASSERT_EQ(0ULL, result[2]); - ASSERT_EQ(0ULL, result[3]); - - mono_exponent = 1; - negacyclic_multiply_poly_mono_coeffmod(poly1.get(), 2, mono_coeff, mono_exponent, mod, result.get(), pool); - ASSERT_EQ(1ULL, result[0]); - ASSERT_EQ(3ULL, result[1]); - ASSERT_EQ(0ULL, result[2]); - ASSERT_EQ(0ULL, result[3]); - - negacyclic_multiply_poly_mono_coeffmod(poly1.get(), 4, mono_coeff, mono_exponent, mod, result.get(), pool); - ASSERT_EQ(4ULL, result[0]); - ASSERT_EQ(3ULL, result[1]); - ASSERT_EQ(4ULL, result[2]); - ASSERT_EQ(2ULL, result[3]); - - mono_coeff = 1; - negacyclic_multiply_poly_mono_coeffmod(poly1.get(), 4, mono_coeff, mono_exponent, mod, result.get(), pool); - ASSERT_EQ(3ULL, result[0]); - ASSERT_EQ(1ULL, result[1]); - ASSERT_EQ(3ULL, result[2]); - ASSERT_EQ(4ULL, result[3]); - - mono_coeff = 4; - mono_exponent = 3; - negacyclic_multiply_poly_mono_coeffmod(poly1.get(), 4, mono_coeff, mono_exponent, mod, result.get(), pool); - ASSERT_EQ(3ULL, result[0]); - ASSERT_EQ(4ULL, result[1]); - ASSERT_EQ(2ULL, result[2]); - ASSERT_EQ(4ULL, result[3]); - - mono_coeff = 1; - mono_exponent = 0; - negacyclic_multiply_poly_mono_coeffmod(poly1.get(), 4, mono_coeff, mono_exponent, mod, result.get(), pool); - ASSERT_EQ(1ULL, result[0]); - ASSERT_EQ(3ULL, result[1]); - ASSERT_EQ(4ULL, result[2]); - ASSERT_EQ(2ULL, result[3]); - } - - TEST(PolyArithSmallMod, MultiplyPolyPolyCoeffSmallMod) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto poly1(allocate_zero_poly(3, 1, pool)); - auto poly2(allocate_zero_poly(3, 1, pool)); - auto result(allocate_zero_poly(5, 1, pool)); - poly1[0] = 1; - poly1[1] = 2; - poly1[2] = 3; - poly2[0] = 2; - poly2[1] = 3; - poly2[2] = 4; - SmallModulus mod(5); - multiply_poly_poly_coeffmod(poly1.get(), 3, poly2.get(), 3, mod, 5, result.get()); - ASSERT_EQ(2ULL, result[0]); - ASSERT_EQ(2ULL, result[1]); - ASSERT_EQ(1ULL, result[2]); - ASSERT_EQ(2ULL, result[3]); - ASSERT_EQ(2ULL, result[4]); - - poly2[0] = 2; - poly2[1] = 3; - multiply_poly_poly_coeffmod(poly1.get(), 3, poly2.get(), 2, mod, 5, result.get()); - ASSERT_EQ(2ULL, result[0]); - ASSERT_EQ(2ULL, result[1]); - ASSERT_EQ(2ULL, result[2]); - ASSERT_EQ(4ULL, result[3]); - ASSERT_EQ(0ULL, result[4]); - } - - TEST(PolyArithSmallMod, DividePolyPolyCoeffSmallMod) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto poly1(allocate_zero_poly(5, 1, pool)); - auto poly2(allocate_zero_poly(5, 1, pool)); - auto result(allocate_zero_poly(5, 1, pool)); - auto quotient(allocate_zero_poly(5, 1, pool)); - SmallModulus mod(5); - - poly1[0] = 2; - poly1[1] = 2; - poly2[0] = 2; - poly2[1] = 3; - poly2[2] = 4; - - divide_poly_poly_coeffmod_inplace(poly1.get(), poly2.get(), 5, mod, result.get()); - ASSERT_EQ(2ULL, poly1[0]); - ASSERT_EQ(2ULL, poly1[1]); - ASSERT_EQ(0ULL, poly1[2]); - ASSERT_EQ(0ULL, poly1[3]); - ASSERT_EQ(0ULL, poly1[4]); - ASSERT_EQ(0ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - ASSERT_EQ(0ULL, result[2]); - ASSERT_EQ(0ULL, result[3]); - ASSERT_EQ(0ULL, result[4]); - - poly1[0] = 2; - poly1[1] = 2; - poly1[2] = 1; - poly1[3] = 2; - poly1[4] = 2; - poly2[0] = 4; - poly2[1] = 3; - poly2[2] = 2; - - divide_poly_poly_coeffmod(poly1.get(), poly2.get(), 5, mod, quotient.get(), result.get()); - ASSERT_EQ(0ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - ASSERT_EQ(0ULL, result[2]); - ASSERT_EQ(0ULL, result[3]); - ASSERT_EQ(0ULL, result[4]); - ASSERT_EQ(3ULL, quotient[0]); - ASSERT_EQ(2ULL, quotient[1]); - ASSERT_EQ(1ULL, quotient[2]); - ASSERT_EQ(0ULL, quotient[3]); - ASSERT_EQ(0ULL, quotient[4]); - } - - TEST(PolyArithSmallMod, DyadicProductCoeffSmallMod) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto poly1(allocate_zero_poly(3, 1, pool)); - auto poly2(allocate_zero_poly(3, 1, pool)); - auto result(allocate_zero_poly(3, 1, pool)); - SmallModulus mod(13); - - poly1[0] = 1; - poly1[1] = 1; - poly1[2] = 1; - poly2[0] = 2; - poly2[1] = 3; - poly2[2] = 4; - - dyadic_product_coeffmod(poly1.get(), poly2.get(), 3, mod, result.get()); - ASSERT_EQ(2ULL, result[0]); - ASSERT_EQ(3ULL, result[1]); - ASSERT_EQ(4ULL, result[2]); - - poly1[0] = 0; - poly1[1] = 0; - poly1[2] = 0; - poly2[0] = 2; - poly2[1] = 3; - poly2[2] = 4; - - dyadic_product_coeffmod(poly1.get(), poly2.get(), 3, mod, result.get()); - ASSERT_EQ(0ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - ASSERT_EQ(0ULL, result[2]); - - poly1[0] = 3; - poly1[1] = 5; - poly1[2] = 8; - poly2[0] = 2; - poly2[1] = 3; - poly2[2] = 4; - - dyadic_product_coeffmod(poly1.get(), poly2.get(), 3, mod, result.get()); - ASSERT_EQ(6ULL, result[0]); - ASSERT_EQ(2ULL, result[1]); - ASSERT_EQ(6ULL, result[2]); - } - - TEST(PolyArithSmallMod, TryInvertPolyCoeffSmallMod) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto poly(allocate_zero_poly(4, 1, pool)); - auto polymod(allocate_zero_poly(4, 1, pool)); - auto result(allocate_zero_poly(4, 1, pool)); - SmallModulus mod(5); - - polymod[0] = 4; - polymod[1] = 3; - polymod[2] = 0; - polymod[3] = 2; - - ASSERT_FALSE(try_invert_poly_coeffmod(poly.get(), polymod.get(), 4, mod, result.get(), pool)); - - poly[0] = 1; - ASSERT_TRUE(try_invert_poly_coeffmod(poly.get(), polymod.get(), 4, mod, result.get(), pool)); - ASSERT_EQ(1ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - ASSERT_EQ(0ULL, result[2]); - ASSERT_EQ(0ULL, result[3]); - - poly[0] = 1; - poly[1] = 2; - poly[2] = 3; - ASSERT_TRUE(try_invert_poly_coeffmod(poly.get(), polymod.get(), 4, mod, result.get(), pool)); - ASSERT_EQ(4ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - ASSERT_EQ(2ULL, result[2]); - ASSERT_EQ(0ULL, result[3]); - } - - TEST(PolyArithSmallMod, PolyInftyNormCoeffSmallMod) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto poly(allocate_zero_poly(4, 1, pool)); - SmallModulus mod(10); - - poly[0] = 0; - poly[1] = 1; - poly[2] = 2; - poly[3] = 3; - ASSERT_EQ(0x3ULL, poly_infty_norm_coeffmod(poly.get(), 4, mod)); - - poly[0] = 0; - poly[1] = 1; - poly[2] = 2; - poly[3] = 8; - ASSERT_EQ(0x2ULL, poly_infty_norm_coeffmod(poly.get(), 4, mod)); - } - - TEST(PolyArithSmallMod, NegacyclicShiftPolyCoeffSmallMod) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto poly(allocate_zero_poly(4, 1, pool)); - auto result(allocate_zero_poly(4, 1, pool)); - - SmallModulus mod(10); - size_t coeff_count = 4; - - negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 0, mod, result.get()); - ASSERT_EQ(0ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - ASSERT_EQ(0ULL, result[2]); - ASSERT_EQ(0ULL, result[3]); - negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 1, mod, result.get()); - ASSERT_EQ(0ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - ASSERT_EQ(0ULL, result[2]); - ASSERT_EQ(0ULL, result[3]); - negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 2, mod, result.get()); - ASSERT_EQ(0ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - ASSERT_EQ(0ULL, result[2]); - ASSERT_EQ(0ULL, result[3]); - negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 3, mod, result.get()); - ASSERT_EQ(0ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - ASSERT_EQ(0ULL, result[2]); - ASSERT_EQ(0ULL, result[3]); - - poly[0] = 1; - poly[1] = 2; - poly[2] = 3; - poly[3] = 4; - negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 0, mod, result.get()); - ASSERT_EQ(1ULL, result[0]); - ASSERT_EQ(2ULL, result[1]); - ASSERT_EQ(3ULL, result[2]); - ASSERT_EQ(4ULL, result[3]); - negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 1, mod, result.get()); - ASSERT_EQ(6ULL, result[0]); - ASSERT_EQ(1ULL, result[1]); - ASSERT_EQ(2ULL, result[2]); - ASSERT_EQ(3ULL, result[3]); - negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 2, mod, result.get()); - ASSERT_EQ(7ULL, result[0]); - ASSERT_EQ(6ULL, result[1]); - ASSERT_EQ(1ULL, result[2]); - ASSERT_EQ(2ULL, result[3]); - negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 3, mod, result.get()); - ASSERT_EQ(8ULL, result[0]); - ASSERT_EQ(7ULL, result[1]); - ASSERT_EQ(6ULL, result[2]); - ASSERT_EQ(1ULL, result[3]); - - poly[0] = 1; - poly[1] = 2; - poly[2] = 3; - poly[3] = 4; - coeff_count = 2; - negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 1, mod, result.get()); - negacyclic_shift_poly_coeffmod(poly.get() + 2, coeff_count, 1, mod, result.get() + 2); - ASSERT_EQ(8ULL, result[0]); - ASSERT_EQ(1ULL, result[1]); - ASSERT_EQ(6ULL, result[2]); - ASSERT_EQ(3ULL, result[3]); - } - } -} diff --git a/SEAL/native/tests/seal/util/polycore.cpp b/SEAL/native/tests/seal/util/polycore.cpp deleted file mode 100644 index 55e595a..0000000 --- a/SEAL/native/tests/seal/util/polycore.cpp +++ /dev/null @@ -1,288 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/util/polycore.h" -#include "seal/util/uintarith.h" -#include - -using namespace seal::util; -using namespace std; - -namespace SEALTest -{ - namespace util - { - TEST(PolyCore, AllocatePoly) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_poly(0, 0, pool)); - ASSERT_TRUE(nullptr == ptr.get()); - - ptr = allocate_poly(1, 0, pool); - ASSERT_TRUE(nullptr == ptr.get()); - - ptr = allocate_poly(0, 1, pool); - ASSERT_TRUE(nullptr == ptr.get()); - - ptr = allocate_poly(1, 1, pool); - ASSERT_TRUE(nullptr != ptr.get()); - - ptr = allocate_poly(2, 1, pool); - ASSERT_TRUE(nullptr != ptr.get()); - } - - TEST(PolyCore, SetZeroPoly) - { - set_zero_poly(0, 0, nullptr); - - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_poly(1, 1, pool)); - ptr[0] = 0x1234567812345678; - set_zero_poly(1, 1, ptr.get()); - ASSERT_EQ(static_cast(0), ptr[0]); - - ptr = allocate_poly(2, 3, pool); - for (size_t i = 0; i < 6; ++i) - { - ptr[i] = 0x1234567812345678; - } - set_zero_poly(2, 3, ptr.get()); - for (size_t i = 0; i < 6; ++i) - { - ASSERT_EQ(static_cast(0), ptr[i]); - } - } - - TEST(PolyCore, AllocateZeroPoly) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_zero_poly(0, 0, pool)); - ASSERT_TRUE(nullptr == ptr.get()); - - ptr = allocate_zero_poly(1, 1, pool); - ASSERT_TRUE(nullptr != ptr.get()); - ASSERT_EQ(static_cast(0), ptr[0]); - - ptr = allocate_zero_poly(2, 3, pool); - ASSERT_TRUE(nullptr != ptr.get()); - for (size_t i = 0; i < 6; ++i) - { - ASSERT_EQ(static_cast(0), ptr[i]); - } - } - - TEST(PolyCore, GetPolyCoeff) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_zero_poly(2, 3, pool)); - *get_poly_coeff(ptr.get(), 0, 3) = 1; - *get_poly_coeff(ptr.get(), 1, 3) = 2; - ASSERT_EQ(1ULL, ptr[0]); - ASSERT_EQ(static_cast(2), ptr[3]); - ASSERT_EQ(1ULL, *get_poly_coeff(ptr.get(), 0, 3)); - ASSERT_EQ(static_cast(2), *get_poly_coeff(ptr.get(), 1, 3)); - } - - TEST(PolyCore, SetPolyPoly) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr1(allocate_poly(2, 3, pool)); - auto ptr2(allocate_zero_poly(2, 3, pool)); - for (size_t i = 0; i < 6; ++i) - { - ptr1[i] = static_cast(i + 1); - } - set_poly_poly(ptr1.get(), 2, 3, ptr2.get()); - for (size_t i = 0; i < 6; ++i) - { - ASSERT_EQ(static_cast(i + 1), ptr2[i]); - } - - set_poly_poly(ptr1.get(), 2, 3, ptr1.get()); - for (size_t i = 0; i < 6; ++i) - { - ASSERT_EQ(static_cast(i + 1), ptr2[i]); - } - - ptr2 = allocate_poly(3, 4, pool); - for (size_t i = 0; i < 12; ++i) - { - ptr2[i] = 1ULL; - } - set_poly_poly(ptr1.get(), 2, 3, 3, 4, ptr2.get()); - ASSERT_EQ(1ULL, ptr2[0]); - ASSERT_EQ(static_cast(2), ptr2[1]); - ASSERT_EQ(static_cast(3), ptr2[2]); - ASSERT_EQ(static_cast(0), ptr2[3]); - ASSERT_EQ(static_cast(4), ptr2[4]); - ASSERT_EQ(static_cast(5), ptr2[5]); - ASSERT_EQ(static_cast(6), ptr2[6]); - ASSERT_EQ(static_cast(0), ptr2[7]); - ASSERT_EQ(static_cast(0), ptr2[8]); - ASSERT_EQ(static_cast(0), ptr2[9]); - ASSERT_EQ(static_cast(0), ptr2[10]); - ASSERT_EQ(static_cast(0), ptr2[11]); - - ptr2 = allocate_poly(1, 2, pool); - ptr2[0] = 1; - ptr2[1] = 1; - set_poly_poly(ptr1.get(), 2, 3, 1, 2, ptr2.get()); - ASSERT_EQ(1ULL, ptr2[0]); - ASSERT_EQ(static_cast(2), ptr2[1]); - } - - TEST(PolyCore, IsZeroPoly) - { - ASSERT_TRUE(is_zero_poly(nullptr, 0, 0)); - - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_poly(2, 3, pool)); - for (size_t i = 0; i < 6; ++i) - { - ptr[i] = 0; - } - ASSERT_TRUE(is_zero_poly(ptr.get(), 2, 3)); - for (size_t i = 0; i < 6; ++i) - { - ptr[i] = 1; - ASSERT_FALSE(is_zero_poly(ptr.get(), 2, 3)); - ptr[i] = 0; - } - } - - TEST(PolyCore, IsEqualPolyPoly) - { - ASSERT_TRUE(is_equal_poly_poly(nullptr, nullptr, 0, 0)); - - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr1(allocate_poly(2, 3, pool)); - auto ptr2(allocate_poly(2, 3, pool)); - for (size_t i = 0; i < 6; ++i) - { - ptr2[i] = ptr1[i] = static_cast(i + 1); - } - ASSERT_TRUE(is_equal_poly_poly(ptr1.get(), ptr2.get(), 2, 3)); - for (size_t i = 0; i < 6; ++i) - { - ptr2[i]--; - ASSERT_FALSE(is_equal_poly_poly(ptr1.get(), ptr2.get(), 2, 3)); - ptr2[i]++; - } - } - - TEST(PolyCore, IsOneZeroOnePoly) - { - ASSERT_FALSE(is_one_zero_one_poly(nullptr, 0, 0)); - - MemoryPool &pool = *global_variables::global_memory_pool; - auto poly(allocate_zero_poly(4, 2, pool)); - ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 0, 2)); - ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 1, 2)); - ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 2, 2)); - ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 3, 2)); - - poly[0] = 2; - ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 1, 2)); - ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 2, 2)); - - poly[0] = 1; - ASSERT_TRUE(is_one_zero_one_poly(poly.get(), 1, 2)); - ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 2, 2)); - - poly[2] = 2; - ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 2, 2)); - ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 3, 2)); - - poly[2] = 1; - ASSERT_TRUE(is_one_zero_one_poly(poly.get(), 2, 2)); - ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 3, 2)); - - poly[4] = 1; - ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 3, 2)); - ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 4, 2)); - - poly[2] = 0; - ASSERT_TRUE(is_one_zero_one_poly(poly.get(), 3, 2)); - ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 4, 2)); - - poly[6] = 2; - ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 4, 2)); - - poly[6] = 1; - ASSERT_FALSE(is_one_zero_one_poly(poly.get(), 4, 2)); - - poly[4] = 0; - ASSERT_TRUE(is_one_zero_one_poly(poly.get(), 4, 2)); - } - - TEST(PolyCore, GetSignificantCoeffCountPoly) - { - ASSERT_EQ(0ULL, get_significant_coeff_count_poly(nullptr, 0, 0)); - - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_zero_poly(3, 2, pool)); - ASSERT_EQ(0ULL, get_significant_coeff_count_poly(ptr.get(), 3, 2)); - ptr[0] = 1; - ASSERT_EQ(1ULL, get_significant_coeff_count_poly(ptr.get(), 3, 2)); - ptr[1] = 1; - ASSERT_EQ(1ULL, get_significant_coeff_count_poly(ptr.get(), 3, 2)); - ptr[4] = 1; - ASSERT_EQ(3ULL, get_significant_coeff_count_poly(ptr.get(), 3, 2)); - ptr[4] = 0; - ptr[5] = 1; - ASSERT_EQ(3ULL, get_significant_coeff_count_poly(ptr.get(), 3, 2)); - } - - TEST(PolyCore, DuplicatePolyIfNeeded) - { - ASSERT_EQ(0ULL, get_significant_coeff_count_poly(nullptr, 0, 0)); - - MemoryPool &pool = *global_variables::global_memory_pool; - auto poly(allocate_poly(3, 2, pool)); - for (size_t i = 0; i < 6; i++) - { - poly[i] = i + 1; - } - - auto ptr = duplicate_poly_if_needed(poly.get(), 3, 2, 3, 2, false, pool); - ASSERT_TRUE(ptr.get() == poly.get()); - ptr = duplicate_poly_if_needed(poly.get(), 3, 2, 2, 2, false, pool); - ASSERT_TRUE(ptr.get() == poly.get()); - ptr = duplicate_poly_if_needed(poly.get(), 3, 2, 2, 3, false, pool); - ASSERT_TRUE(ptr.get() != poly.get()); - ASSERT_EQ(1ULL, ptr[0]); - ASSERT_EQ(static_cast(2), ptr[1]); - ASSERT_EQ(static_cast(0), ptr[2]); - ASSERT_EQ(static_cast(3), ptr[3]); - ASSERT_EQ(static_cast(4), ptr[4]); - ASSERT_EQ(static_cast(0), ptr[5]); - - ptr = duplicate_poly_if_needed(poly.get(), 3, 2, 3, 2, true, pool); - ASSERT_TRUE(ptr.get() != poly.get()); - for (size_t i = 0; i < 6; i++) - { - ASSERT_EQ(static_cast(i + 1), ptr[i]); - } - } - - TEST(PolyCore, ArePolyCoeffsLessThan) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto poly(allocate_zero_poly(3, 2, pool)); - poly[0] = 3; - poly[2] = 5; - poly[4] = 4; - - auto max(allocate_uint(1, pool)); - max[0] = 1; - ASSERT_FALSE(are_poly_coefficients_less_than(poly.get(), 3, 2, max.get(), 1)); - max[0] = 5; - ASSERT_FALSE(are_poly_coefficients_less_than(poly.get(), 3, 2, max.get(), 1)); - max[0] = 6; - ASSERT_TRUE(are_poly_coefficients_less_than(poly.get(), 3, 2, max.get(), 1)); - max[0] = 10; - ASSERT_TRUE(are_poly_coefficients_less_than(poly.get(), 3, 2, max.get(), 1)); - } - } -} diff --git a/SEAL/native/tests/seal/util/smallntt.cpp b/SEAL/native/tests/seal/util/smallntt.cpp deleted file mode 100644 index 21ee985..0000000 --- a/SEAL/native/tests/seal/util/smallntt.cpp +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/util/mempool.h" -#include "seal/util/uintcore.h" -#include "seal/util/polycore.h" -#include "seal/util/smallntt.h" -#include "seal/util/numth.h" -#include -#include -#include - -using namespace seal; -using namespace seal::util; -using namespace std; - -namespace SEALTest -{ - namespace util - { - TEST(SmallNTTTablesTest, SmallNTTBasics) - { - MemoryPoolHandle pool = MemoryPoolHandle::Global(); - SmallNTTTables tables; - int coeff_count_power = 1; - SmallModulus modulus(get_prime(uint64_t(1) << coeff_count_power, 60)); - ASSERT_TRUE(tables.generate(coeff_count_power, modulus)); - ASSERT_EQ(2ULL, tables.coeff_count()); - ASSERT_EQ(1, tables.coeff_count_power()); - - coeff_count_power = 2; - modulus = get_prime(uint64_t(1) << coeff_count_power, 50); - ASSERT_TRUE(tables.generate(coeff_count_power, modulus)); - ASSERT_EQ(4ULL, tables.coeff_count()); - ASSERT_EQ(2, tables.coeff_count_power()); - - coeff_count_power = 10; - modulus = get_prime(uint64_t(1) << coeff_count_power, 40); - ASSERT_TRUE(tables.generate(coeff_count_power, modulus)); - ASSERT_EQ(1024ULL, tables.coeff_count()); - ASSERT_EQ(10, tables.coeff_count_power()); - } - - TEST(SmallNTTTablesTest, SmallNTTPrimitiveRootsTest) - { - MemoryPoolHandle pool = MemoryPoolHandle::Global(); - SmallNTTTables tables; - - int coeff_count_power = 1; - SmallModulus modulus(0xffffffffffc0001ULL); - ASSERT_TRUE(tables.generate(coeff_count_power, modulus)); - ASSERT_EQ(1ULL, tables.get_from_root_powers(0)); - ASSERT_EQ(288794978602139552ULL, tables.get_from_root_powers(1)); - uint64_t inv; - try_mod_inverse(tables.get_from_root_powers(1), modulus.value(), inv); - ASSERT_EQ(inv, tables.get_from_inv_root_powers(1)); - - coeff_count_power = 2; - ASSERT_TRUE(tables.generate(coeff_count_power, modulus)); - ASSERT_EQ(1ULL, tables.get_from_root_powers(0)); - ASSERT_EQ(288794978602139552ULL, tables.get_from_root_powers(1)); - ASSERT_EQ(178930308976060547ULL, tables.get_from_root_powers(2)); - ASSERT_EQ(748001537669050592ULL, tables.get_from_root_powers(3)); - } - - TEST(SmallNTTTablesTest, NegacyclicSmallNTTTest) - { - MemoryPoolHandle pool = MemoryPoolHandle::Global(); - SmallNTTTables tables; - - int coeff_count_power = 1; - SmallModulus modulus(0xffffffffffc0001ULL); - ASSERT_TRUE(tables.generate(coeff_count_power, modulus)); - auto poly(allocate_poly(2, 1, pool)); - poly[0] = 0; - poly[1] = 0; - ntt_negacyclic_harvey(poly.get(), tables); - ASSERT_EQ(0ULL, poly[0]); - ASSERT_EQ(0ULL, poly[1]); - - poly[0] = 1; - poly[1] = 0; - ntt_negacyclic_harvey(poly.get(), tables); - ASSERT_EQ(1ULL, poly[0]); - ASSERT_EQ(1ULL, poly[1]); - - poly[0] = 1; - poly[1] = 1; - ntt_negacyclic_harvey(poly.get(), tables); - ASSERT_EQ(288794978602139553ULL, poly[0]); - ASSERT_EQ(864126526004445282ULL, poly[1]); - } - - TEST(SmallNTTTablesTest, InverseNegacyclicSmallNTTTest) - { - MemoryPoolHandle pool = MemoryPoolHandle::Global(); - SmallNTTTables tables; - - int coeff_count_power = 3; - SmallModulus modulus(0xffffffffffc0001ULL); - ASSERT_TRUE(tables.generate(coeff_count_power, modulus)); - auto poly(allocate_zero_poly(800, 1, pool)); - auto temp(allocate_zero_poly(800, 1, pool)); - - inverse_ntt_negacyclic_harvey(poly.get(), tables); - for (size_t i = 0; i < 800; i++) - { - ASSERT_EQ(0ULL, poly[i]); - } - - random_device rd; - for (size_t i = 0; i < 800; i++) - { - poly[i] = static_cast(rd()) % modulus.value(); - temp[i] = poly[i]; - } - - ntt_negacyclic_harvey(poly.get(), tables); - inverse_ntt_negacyclic_harvey(poly.get(), tables); - for (size_t i = 0; i < 800; i++) - { - ASSERT_EQ(temp[i], poly[i]); - } - } - } -} diff --git a/SEAL/native/tests/seal/util/stringtouint64.cpp b/SEAL/native/tests/seal/util/stringtouint64.cpp deleted file mode 100644 index 0999f65..0000000 --- a/SEAL/native/tests/seal/util/stringtouint64.cpp +++ /dev/null @@ -1,271 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/util/common.h" -#include "seal/util/uintcore.h" -#include -#include - -using namespace seal::util; -using namespace std; - -namespace SEALTest -{ - namespace util - { - TEST(StringToUInt64, IsHexCharTest) - { - ASSERT_TRUE(is_hex_char('0')); - ASSERT_TRUE(is_hex_char('1')); - ASSERT_TRUE(is_hex_char('2')); - ASSERT_TRUE(is_hex_char('3')); - ASSERT_TRUE(is_hex_char('4')); - ASSERT_TRUE(is_hex_char('5')); - ASSERT_TRUE(is_hex_char('6')); - ASSERT_TRUE(is_hex_char('7')); - ASSERT_TRUE(is_hex_char('8')); - ASSERT_TRUE(is_hex_char('9')); - ASSERT_TRUE(is_hex_char('A')); - ASSERT_TRUE(is_hex_char('B')); - ASSERT_TRUE(is_hex_char('C')); - ASSERT_TRUE(is_hex_char('D')); - ASSERT_TRUE(is_hex_char('E')); - ASSERT_TRUE(is_hex_char('F')); - ASSERT_TRUE(is_hex_char('a')); - ASSERT_TRUE(is_hex_char('b')); - ASSERT_TRUE(is_hex_char('c')); - ASSERT_TRUE(is_hex_char('d')); - ASSERT_TRUE(is_hex_char('e')); - ASSERT_TRUE(is_hex_char('f')); - - ASSERT_FALSE(is_hex_char('/')); - ASSERT_FALSE(is_hex_char(' ')); - ASSERT_FALSE(is_hex_char('+')); - ASSERT_FALSE(is_hex_char('\\')); - ASSERT_FALSE(is_hex_char('G')); - ASSERT_FALSE(is_hex_char('g')); - ASSERT_FALSE(is_hex_char('Z')); - ASSERT_FALSE(is_hex_char('Z')); - } - - TEST(StringToUInt64, HexToNibbleTest) - { - ASSERT_EQ(0, hex_to_nibble('0')); - ASSERT_EQ(1, hex_to_nibble('1')); - ASSERT_EQ(2, hex_to_nibble('2')); - ASSERT_EQ(3, hex_to_nibble('3')); - ASSERT_EQ(4, hex_to_nibble('4')); - ASSERT_EQ(5, hex_to_nibble('5')); - ASSERT_EQ(6, hex_to_nibble('6')); - ASSERT_EQ(7, hex_to_nibble('7')); - ASSERT_EQ(8, hex_to_nibble('8')); - ASSERT_EQ(9, hex_to_nibble('9')); - ASSERT_EQ(10, hex_to_nibble('A')); - ASSERT_EQ(11, hex_to_nibble('B')); - ASSERT_EQ(12, hex_to_nibble('C')); - ASSERT_EQ(13, hex_to_nibble('D')); - ASSERT_EQ(14, hex_to_nibble('E')); - ASSERT_EQ(15, hex_to_nibble('F')); - ASSERT_EQ(10, hex_to_nibble('a')); - ASSERT_EQ(11, hex_to_nibble('b')); - ASSERT_EQ(12, hex_to_nibble('c')); - ASSERT_EQ(13, hex_to_nibble('d')); - ASSERT_EQ(14, hex_to_nibble('e')); - ASSERT_EQ(15, hex_to_nibble('f')); - } - - TEST(StringToUInt64, GetHexStringBitCount) - { - ASSERT_EQ(0, get_hex_string_bit_count(nullptr, 0)); - ASSERT_EQ(0, get_hex_string_bit_count("0", 1)); - ASSERT_EQ(0, get_hex_string_bit_count("000000000", 9)); - ASSERT_EQ(1, get_hex_string_bit_count("1", 1)); - ASSERT_EQ(1, get_hex_string_bit_count("00001", 5)); - ASSERT_EQ(2, get_hex_string_bit_count("2", 1)); - ASSERT_EQ(2, get_hex_string_bit_count("00002", 5)); - ASSERT_EQ(2, get_hex_string_bit_count("3", 1)); - ASSERT_EQ(2, get_hex_string_bit_count("0003", 4)); - ASSERT_EQ(3, get_hex_string_bit_count("4", 1)); - ASSERT_EQ(3, get_hex_string_bit_count("5", 1)); - ASSERT_EQ(3, get_hex_string_bit_count("6", 1)); - ASSERT_EQ(3, get_hex_string_bit_count("7", 1)); - ASSERT_EQ(4, get_hex_string_bit_count("8", 1)); - ASSERT_EQ(4, get_hex_string_bit_count("9", 1)); - ASSERT_EQ(4, get_hex_string_bit_count("A", 1)); - ASSERT_EQ(4, get_hex_string_bit_count("B", 1)); - ASSERT_EQ(4, get_hex_string_bit_count("C", 1)); - ASSERT_EQ(4, get_hex_string_bit_count("D", 1)); - ASSERT_EQ(4, get_hex_string_bit_count("E", 1)); - ASSERT_EQ(4, get_hex_string_bit_count("F", 1)); - ASSERT_EQ(5, get_hex_string_bit_count("10", 2)); - ASSERT_EQ(5, get_hex_string_bit_count("00010", 5)); - ASSERT_EQ(5, get_hex_string_bit_count("11", 2)); - ASSERT_EQ(5, get_hex_string_bit_count("1F", 2)); - ASSERT_EQ(6, get_hex_string_bit_count("20", 2)); - ASSERT_EQ(6, get_hex_string_bit_count("2F", 2)); - ASSERT_EQ(7, get_hex_string_bit_count("7F", 2)); - ASSERT_EQ(7, get_hex_string_bit_count("0007F", 5)); - ASSERT_EQ(8, get_hex_string_bit_count("80", 2)); - ASSERT_EQ(8, get_hex_string_bit_count("FF", 2)); - ASSERT_EQ(8, get_hex_string_bit_count("00FF", 4)); - ASSERT_EQ(9, get_hex_string_bit_count("100", 3)); - ASSERT_EQ(9, get_hex_string_bit_count("000100", 6)); - ASSERT_EQ(22, get_hex_string_bit_count("200000", 6)); - ASSERT_EQ(35, get_hex_string_bit_count("7FFF30001", 9)); - - ASSERT_EQ(15, get_hex_string_bit_count("7FFF30001", 4)); - ASSERT_EQ(3, get_hex_string_bit_count("7FFF30001", 1)); - ASSERT_EQ(0, get_hex_string_bit_count("7FFF30001", 0)); - } - - TEST(StringToUInt64, HexStringToUInt64) - { - uint64_t correct[3]; - uint64_t parsed[3]; - - correct[0] = 0; - correct[1] = 0; - correct[2] = 0; - parsed[0] = 0x123; - parsed[1] = 0x123; - parsed[2] = 0x123; - hex_string_to_uint("0", 1, 3, parsed); - ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); - parsed[0] = 0x123; - parsed[1] = 0x123; - parsed[2] = 0x123; - hex_string_to_uint("0", 1, 1, parsed); - ASSERT_EQ(0, memcmp(correct, parsed, 1 * sizeof(uint64_t))); - parsed[0] = 0x123; - parsed[1] = 0x123; - parsed[2] = 0x123; - hex_string_to_uint(nullptr, 0, 3, parsed); - ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); - - correct[0] = 1; - parsed[0] = 0x123; - parsed[1] = 0x123; - parsed[2] = 0x123; - hex_string_to_uint("1", 1, 3, parsed); - ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); - parsed[0] = 0x123; - parsed[1] = 0x123; - parsed[2] = 0x123; - hex_string_to_uint("01", 2, 3, parsed); - ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); - parsed[0] = 0x123; - parsed[1] = 0x123; - parsed[2] = 0x123; - hex_string_to_uint("001", 3, 1, parsed); - ASSERT_EQ(0, memcmp(correct, parsed, 1 * sizeof(uint64_t))); - - correct[0] = 0xF; - parsed[0] = 0x123; - parsed[1] = 0x123; - parsed[2] = 0x123; - hex_string_to_uint("F", 1, 3, parsed); - ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); - - correct[0] = 0x10; - parsed[0] = 0x123; - parsed[1] = 0x123; - parsed[2] = 0x123; - hex_string_to_uint("10", 2, 3, parsed); - ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); - parsed[0] = 0x123; - parsed[1] = 0x123; - parsed[2] = 0x123; - hex_string_to_uint("010", 3, 3, parsed); - ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); - - correct[0] = 0x100; - parsed[0] = 0x123; - parsed[1] = 0x123; - parsed[2] = 0x123; - hex_string_to_uint("100", 3, 3, parsed); - ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); - - correct[0] = 0x123; - parsed[0] = 0x123; - parsed[1] = 0x123; - parsed[2] = 0x123; - hex_string_to_uint("123", 3, 3, parsed); - ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); - parsed[0] = 0x123; - parsed[1] = 0x123; - parsed[2] = 0x123; - hex_string_to_uint("00000123", 8, 3, parsed); - ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); - - correct[0] = 0; - correct[1] = 1; - parsed[0] = 0x123; - parsed[1] = 0x123; - parsed[2] = 0x123; - hex_string_to_uint("10000000000000000", 17, 3, parsed); - ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); - - correct[0] = 0x1123456789ABCDEF; - correct[1] = 0x1; - parsed[0] = 0x123; - parsed[1] = 0x123; - parsed[2] = 0x123; - hex_string_to_uint("11123456789ABCDEF", 17, 3, parsed); - ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); - parsed[0] = 0x123; - parsed[1] = 0x123; - parsed[2] = 0x123; - hex_string_to_uint("000011123456789ABCDEF", 21, 3, parsed); - ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); - - correct[0] = 0x3456789ABCDEF123; - correct[1] = 0x23456789ABCDEF12; - correct[2] = 0x123456789ABCDEF1; - parsed[0] = 0x123; - parsed[1] = 0x123; - parsed[2] = 0x123; - hex_string_to_uint("123456789ABCDEF123456789ABCDEF123456789ABCDEF123", 48, 3, parsed); - ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); - - correct[0] = 0xFFFFFFFFFFFFFFFF; - correct[1] = 0xFFFFFFFFFFFFFFFF; - correct[2] = 0xFFFFFFFFFFFFFFFF; - parsed[0] = 0x123; - parsed[1] = 0x123; - parsed[2] = 0x123; - hex_string_to_uint("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", 48, 3, parsed); - ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); - - correct[0] = 0x100; - correct[1] = 0; - correct[2] = 0; - parsed[0] = 0x123; - parsed[1] = 0x123; - parsed[2] = 0x123; - hex_string_to_uint("100", 3, 3, parsed); - ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); - - correct[0] = 0x10; - parsed[0] = 0x123; - parsed[1] = 0x123; - parsed[2] = 0x123; - hex_string_to_uint("100", 2, 3, parsed); - ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); - - correct[0] = 0x1; - parsed[0] = 0x123; - parsed[1] = 0x123; - parsed[2] = 0x123; - hex_string_to_uint("100", 1, 3, parsed); - ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); - - correct[0] = 0; - parsed[0] = 0x123; - parsed[1] = 0x123; - parsed[2] = 0x123; - hex_string_to_uint("100", 0, 3, parsed); - ASSERT_EQ(0, memcmp(correct, parsed, 3 * sizeof(uint64_t))); - } - } -} diff --git a/SEAL/native/tests/seal/util/uint64tostring.cpp b/SEAL/native/tests/seal/util/uint64tostring.cpp deleted file mode 100644 index 1ce1337..0000000 --- a/SEAL/native/tests/seal/util/uint64tostring.cpp +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/util/common.h" -#include "seal/util/uintcore.h" -#include "seal/util/polycore.h" -#include "seal/util/mempool.h" -#include - -using namespace seal::util; -using namespace std; - -namespace SEALTest -{ - namespace util - { - TEST(UInt64ToString, NibbleToUpperHexTest) - { - ASSERT_EQ('0', nibble_to_upper_hex(0)); - ASSERT_EQ('1', nibble_to_upper_hex(1)); - ASSERT_EQ('2', nibble_to_upper_hex(2)); - ASSERT_EQ('3', nibble_to_upper_hex(3)); - ASSERT_EQ('4', nibble_to_upper_hex(4)); - ASSERT_EQ('5', nibble_to_upper_hex(5)); - ASSERT_EQ('6', nibble_to_upper_hex(6)); - ASSERT_EQ('7', nibble_to_upper_hex(7)); - ASSERT_EQ('8', nibble_to_upper_hex(8)); - ASSERT_EQ('9', nibble_to_upper_hex(9)); - ASSERT_EQ('A', nibble_to_upper_hex(10)); - ASSERT_EQ('B', nibble_to_upper_hex(11)); - ASSERT_EQ('C', nibble_to_upper_hex(12)); - ASSERT_EQ('D', nibble_to_upper_hex(13)); - ASSERT_EQ('E', nibble_to_upper_hex(14)); - ASSERT_EQ('F', nibble_to_upper_hex(15)); - } - - TEST(UInt64ToString, UInt64ToHexString) - { - uint64_t number[] = { 0, 0, 0 }; - string correct = "0"; - ASSERT_EQ(correct, uint_to_hex_string(number, 3)); - ASSERT_EQ(correct, uint_to_hex_string(number, 1)); - ASSERT_EQ(correct, uint_to_hex_string(number, 0)); - ASSERT_EQ(correct, uint_to_hex_string(nullptr, 0)); - - number[0] = 1; - correct = "1"; - ASSERT_EQ(correct, uint_to_hex_string(number, 3)); - ASSERT_EQ(correct, uint_to_hex_string(number, 1)); - - number[0] = 0xF; - correct = "F"; - ASSERT_EQ(correct, uint_to_hex_string(number, 3)); - - number[0] = 0x10; - correct = "10"; - ASSERT_EQ(correct, uint_to_hex_string(number, 3)); - - number[0] = 0x100; - correct = "100"; - ASSERT_EQ(correct, uint_to_hex_string(number, 3)); - - number[0] = 0x123; - correct = "123"; - ASSERT_EQ(correct, uint_to_hex_string(number, 3)); - - number[0] = 0; - number[1] = 1; - correct = "10000000000000000"; - ASSERT_EQ(correct, uint_to_hex_string(number, 3)); - - number[0] = 0x1123456789ABCDEF; - number[1] = 0x1; - correct = "11123456789ABCDEF"; - ASSERT_EQ(correct, uint_to_hex_string(number, 3)); - - number[0] = 0x3456789ABCDEF123; - number[1] = 0x23456789ABCDEF12; - number[2] = 0x123456789ABCDEF1; - correct = "123456789ABCDEF123456789ABCDEF123456789ABCDEF123"; - ASSERT_EQ(correct, uint_to_hex_string(number, 3)); - - number[0] = 0xFFFFFFFFFFFFFFFF; - number[1] = 0xFFFFFFFFFFFFFFFF; - number[2] = 0xFFFFFFFFFFFFFFFF; - correct = "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF"; - ASSERT_EQ(correct, uint_to_hex_string(number, 3)); - } - - TEST(UInt64ToString, UInt64ToDecString) - { - uint64_t number[] = { 0, 0, 0 }; - string correct = "0"; - MemoryPool &pool = *global_variables::global_memory_pool; - ASSERT_EQ(correct, uint_to_dec_string(number, 3, pool)); - ASSERT_EQ(correct, uint_to_dec_string(number, 1, pool)); - ASSERT_EQ(correct, uint_to_dec_string(number, 0, pool)); - ASSERT_EQ(correct, uint_to_dec_string(nullptr, 0, pool)); - - number[0] = 1; - correct = "1"; - ASSERT_EQ(correct, uint_to_dec_string(number, 3, pool)); - ASSERT_EQ(correct, uint_to_dec_string(number, 1, pool)); - - number[0] = 9; - correct = "9"; - ASSERT_EQ(correct, uint_to_dec_string(number, 3, pool)); - - number[0] = 10; - correct = "10"; - ASSERT_EQ(correct, uint_to_dec_string(number, 3, pool)); - - number[0] = 123; - correct = "123"; - ASSERT_EQ(correct, uint_to_dec_string(number, 3, pool)); - - number[0] = 987654321; - correct = "987654321"; - ASSERT_EQ(correct, uint_to_dec_string(number, 3, pool)); - - number[0] = 0; - number[1] = 1; - correct = "18446744073709551616"; - ASSERT_EQ(correct, uint_to_dec_string(number, 3, pool)); - } - - TEST(UInt64ToString, PolyToHexString) - { - uint64_t number[] = { 0, 0, 0, 0 }; - string correct = "0"; - ASSERT_EQ(correct, poly_to_hex_string(number, 0, 1)); - ASSERT_EQ(correct, poly_to_hex_string(number, 4, 0)); - ASSERT_EQ(correct, poly_to_hex_string(number, 1, 1)); - ASSERT_EQ(correct, poly_to_hex_string(number, 4, 1)); - ASSERT_EQ(correct, poly_to_hex_string(number, 2, 2)); - ASSERT_EQ(correct, poly_to_hex_string(number, 1, 4)); - ASSERT_EQ(correct, poly_to_hex_string(nullptr, 0, 0)); - - number[0] = 1; - correct = "1"; - ASSERT_EQ(correct, poly_to_hex_string(number, 4, 1)); - ASSERT_EQ(correct, poly_to_hex_string(number, 2, 2)); - ASSERT_EQ(correct, poly_to_hex_string(number, 1, 4)); - - number[0] = 0; - number[1] = 1; - correct = "1x^1"; - ASSERT_EQ(correct, poly_to_hex_string(number, 4, 1)); - correct = "10000000000000000"; - ASSERT_EQ(correct, poly_to_hex_string(number, 2, 2)); - ASSERT_EQ(correct, poly_to_hex_string(number, 1, 4)); - - number[0] = 1; - number[1] = 0; - number[2] = 0; - number[3] = 1; - correct = "1x^3 + 1"; - ASSERT_EQ(correct, poly_to_hex_string(number, 4, 1)); - correct = "10000000000000000x^1 + 1"; - ASSERT_EQ(correct, poly_to_hex_string(number, 2, 2)); - correct = "1000000000000000000000000000000000000000000000001"; - ASSERT_EQ(correct, poly_to_hex_string(number, 1, 4)); - - number[0] = 0xF00000000000000F; - number[1] = 0xF0F0F0F0F0F0F0F0; - number[2] = 0; - number[3] = 0; - correct = "F0F0F0F0F0F0F0F0x^1 + F00000000000000F"; - ASSERT_EQ(correct, poly_to_hex_string(number, 4, 1)); - correct = "F0F0F0F0F0F0F0F0F00000000000000F"; - - number[2] = 0xF0FF0F0FF0F0FF0F; - number[3] = 0xBABABABABABABABA; - correct = "BABABABABABABABAF0FF0F0FF0F0FF0Fx^1 + F0F0F0F0F0F0F0F0F00000000000000F"; - ASSERT_EQ(correct, poly_to_hex_string(number, 2, 2)); - correct = "BABABABABABABABAx^3 + F0FF0F0FF0F0FF0Fx^2 + F0F0F0F0F0F0F0F0x^1 + F00000000000000F"; - ASSERT_EQ(correct, poly_to_hex_string(number, 4, 1)); - } - } -} diff --git a/SEAL/native/tests/seal/util/uintarith.cpp b/SEAL/native/tests/seal/util/uintarith.cpp deleted file mode 100644 index 21ecac7..0000000 --- a/SEAL/native/tests/seal/util/uintarith.cpp +++ /dev/null @@ -1,1654 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/util/uintarith.h" -#include - -using namespace seal::util; -using namespace std; - -namespace SEALTest -{ - namespace util - { - TEST(UIntArith, AddUInt64Generic) - { - unsigned long long result; - ASSERT_FALSE(add_uint64_generic(0ULL, 0ULL, 0, &result)); - ASSERT_EQ(0ULL, result); - ASSERT_FALSE(add_uint64_generic(1ULL, 1ULL, 0, &result)); - ASSERT_EQ(2ULL, result); - ASSERT_FALSE(add_uint64_generic(1ULL, 0ULL, 1, &result)); - ASSERT_EQ(2ULL, result); - ASSERT_FALSE(add_uint64_generic(0ULL, 1ULL, 1, &result)); - ASSERT_EQ(2ULL, result); - ASSERT_FALSE(add_uint64_generic(1ULL, 1ULL, 1, &result)); - ASSERT_EQ(3ULL, result); - ASSERT_TRUE(add_uint64_generic(0xFFFFFFFFFFFFFFFFULL, 1ULL, 0, &result)); - ASSERT_EQ(0ULL, result); - ASSERT_TRUE(add_uint64_generic(1ULL, 0xFFFFFFFFFFFFFFFFULL, 0, &result)); - ASSERT_EQ(0ULL, result); - ASSERT_TRUE(add_uint64_generic(1ULL, 0xFFFFFFFFFFFFFFFFULL, 1, &result)); - ASSERT_EQ(1ULL, result); - ASSERT_TRUE(add_uint64_generic(2ULL, 0xFFFFFFFFFFFFFFFEULL, 0, &result)); - ASSERT_EQ(0ULL, result); - ASSERT_TRUE(add_uint64_generic(2ULL, 0xFFFFFFFFFFFFFFFEULL, 1, &result)); - ASSERT_EQ(1ULL, result); - ASSERT_FALSE(add_uint64_generic(0xF00F00F00F00F00FULL, 0x0FF0FF0FF0FF0FF0ULL, 0, &result)); - ASSERT_EQ(0xFFFFFFFFFFFFFFFFULL, result); - ASSERT_TRUE(add_uint64_generic(0xF00F00F00F00F00FULL, 0x0FF0FF0FF0FF0FF0ULL, 1, &result)); - ASSERT_EQ(0x0ULL, result); - } - -#if SEAL_COMPILER == SEAL_COMPILER_MSVC -#pragma optimize ("", off) -#elif SEAL_COMPILER == SEAL_COMPILER_GCC -#pragma GCC push_options -#pragma GCC optimize ("O0") -#elif SEAL_COMPILER == SEAL_COMPILER_CLANG -#pragma clang optimize off -#endif - - TEST(UIntArith, AddUInt64) - { - unsigned long long result; - ASSERT_FALSE(add_uint64(0ULL, 0ULL, 0, &result)); - ASSERT_EQ(0ULL, result); - ASSERT_FALSE(add_uint64(1ULL, 1ULL, 0, &result)); - ASSERT_EQ(2ULL, result); - ASSERT_FALSE(add_uint64(1ULL, 0ULL, 1, &result)); - ASSERT_EQ(2ULL, result); - ASSERT_FALSE(add_uint64(0ULL, 1ULL, 1, &result)); - ASSERT_EQ(2ULL, result); - ASSERT_FALSE(add_uint64(1ULL, 1ULL, 1, &result)); - ASSERT_EQ(3ULL, result); - ASSERT_TRUE(add_uint64(0xFFFFFFFFFFFFFFFFULL, 1ULL, 0, &result)); - ASSERT_EQ(0ULL, result); - ASSERT_TRUE(add_uint64(1ULL, 0xFFFFFFFFFFFFFFFFULL, 0, &result)); - ASSERT_EQ(0ULL, result); - ASSERT_TRUE(add_uint64(1ULL, 0xFFFFFFFFFFFFFFFFULL, 1, &result)); - ASSERT_EQ(1ULL, result); - ASSERT_TRUE(add_uint64(2ULL, 0xFFFFFFFFFFFFFFFEULL, 0, &result)); - ASSERT_EQ(0ULL, result); - ASSERT_TRUE(add_uint64(2ULL, 0xFFFFFFFFFFFFFFFEULL, 1, &result)); - ASSERT_EQ(1ULL, result); - ASSERT_FALSE(add_uint64(0xF00F00F00F00F00FULL, 0x0FF0FF0FF0FF0FF0ULL, 0, &result)); - ASSERT_EQ(0xFFFFFFFFFFFFFFFFULL, result); - ASSERT_TRUE(add_uint64(0xF00F00F00F00F00FULL, 0x0FF0FF0FF0FF0FF0ULL, 1, &result)); - ASSERT_EQ(0x0ULL, result); - } - -#if SEAL_COMPILER == SEAL_COMPILER_MSVC -#pragma optimize ("", on) -#elif SEAL_COMPILER == SEAL_COMPILER_GCC -#pragma GCC pop_options -#elif SEAL_COMPILER == SEAL_COMPILER_CLANG -#pragma clang optimize on -#endif - - TEST(UIntArith, SubUInt64Generic) - { - unsigned long long result; - ASSERT_FALSE(sub_uint64_generic(0ULL, 0ULL, 0, &result)); - ASSERT_EQ(0ULL, result); - ASSERT_FALSE(sub_uint64_generic(1ULL, 1ULL, 0, &result)); - ASSERT_EQ(0ULL, result); - ASSERT_FALSE(sub_uint64_generic(1ULL, 0ULL, 1, &result)); - ASSERT_EQ(0ULL, result); - ASSERT_TRUE(sub_uint64_generic(0ULL, 1ULL, 1, &result)); - ASSERT_EQ(0xFFFFFFFFFFFFFFFEULL, result); - ASSERT_TRUE(sub_uint64_generic(1ULL, 1ULL, 1, &result)); - ASSERT_EQ(0xFFFFFFFFFFFFFFFFULL, result); - ASSERT_FALSE(sub_uint64_generic(0xFFFFFFFFFFFFFFFFULL, 1ULL, 0, &result)); - ASSERT_EQ(0xFFFFFFFFFFFFFFFEULL, result); - ASSERT_TRUE(sub_uint64_generic(1ULL, 0xFFFFFFFFFFFFFFFFULL, 0, &result)); - ASSERT_EQ(2ULL, result); - ASSERT_TRUE(sub_uint64_generic(1ULL, 0xFFFFFFFFFFFFFFFFULL, 1, &result)); - ASSERT_EQ(1ULL, result); - ASSERT_TRUE(sub_uint64_generic(2ULL, 0xFFFFFFFFFFFFFFFEULL, 0, &result)); - ASSERT_EQ(4ULL, result); - ASSERT_TRUE(sub_uint64_generic(2ULL, 0xFFFFFFFFFFFFFFFEULL, 1, &result)); - ASSERT_EQ(3ULL, result); - ASSERT_FALSE(sub_uint64_generic(0xF00F00F00F00F00FULL, 0x0FF0FF0FF0FF0FF0ULL, 0, &result)); - ASSERT_EQ(0xE01E01E01E01E01FULL, result); - ASSERT_FALSE(sub_uint64_generic(0xF00F00F00F00F00FULL, 0x0FF0FF0FF0FF0FF0ULL, 1, &result)); - ASSERT_EQ(0xE01E01E01E01E01EULL, result); - } - - TEST(UIntArith, SubUInt64) - { - unsigned long long result; - ASSERT_FALSE(sub_uint64(0ULL, 0ULL, 0, &result)); - ASSERT_EQ(0ULL, result); - ASSERT_FALSE(sub_uint64(1ULL, 1ULL, 0, &result)); - ASSERT_EQ(0ULL, result); - ASSERT_FALSE(sub_uint64(1ULL, 0ULL, 1, &result)); - ASSERT_EQ(0ULL, result); - ASSERT_TRUE(sub_uint64(0ULL, 1ULL, 1, &result)); - ASSERT_EQ(0xFFFFFFFFFFFFFFFEULL, result); - ASSERT_TRUE(sub_uint64(1ULL, 1ULL, 1, &result)); - ASSERT_EQ(0xFFFFFFFFFFFFFFFFULL, result); - ASSERT_FALSE(sub_uint64(0xFFFFFFFFFFFFFFFFULL, 1ULL, 0, &result)); - ASSERT_EQ(0xFFFFFFFFFFFFFFFEULL, result); - ASSERT_TRUE(sub_uint64(1ULL, 0xFFFFFFFFFFFFFFFFULL, 0, &result)); - ASSERT_EQ(2ULL, result); - ASSERT_TRUE(sub_uint64(1ULL, 0xFFFFFFFFFFFFFFFFULL, 1, &result)); - ASSERT_EQ(1ULL, result); - ASSERT_TRUE(sub_uint64(2ULL, 0xFFFFFFFFFFFFFFFEULL, 0, &result)); - ASSERT_EQ(4ULL, result); - ASSERT_TRUE(sub_uint64(2ULL, 0xFFFFFFFFFFFFFFFEULL, 1, &result)); - ASSERT_EQ(3ULL, result); - ASSERT_FALSE(sub_uint64(0xF00F00F00F00F00FULL, 0x0FF0FF0FF0FF0FF0ULL, 0, &result)); - ASSERT_EQ(0xE01E01E01E01E01FULL, result); - ASSERT_FALSE(sub_uint64(0xF00F00F00F00F00FULL, 0x0FF0FF0FF0FF0FF0ULL, 1, &result)); - ASSERT_EQ(0xE01E01E01E01E01EULL, result); - } - - TEST(UIntArith, AddUIntUInt) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(2, pool)); - auto ptr2(allocate_uint(2, pool)); - auto ptr3(allocate_uint(2, pool)); - ptr[0] = 0; - ptr[1] = 0; - ptr2[0] = 0; - ptr2[1] = 0; - ptr3[0] = 0xFFFFFFFFFFFFFFFF; - ptr3[1] = 0xFFFFFFFFFFFFFFFF; - ASSERT_FALSE(add_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()) != 0); - ASSERT_EQ(static_cast(0), ptr3[0]); - ASSERT_EQ(static_cast(0), ptr3[1]); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - ptr2[0] = 0; - ptr2[1] = 0; - ptr3[0] = 0; - ptr3[1] = 0; - ASSERT_FALSE(add_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()) != 0); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); - - ptr[0] = 0xFFFFFFFFFFFFFFFE; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - ptr2[0] = 1; - ptr2[1] = 0; - ptr3[0] = 0; - ptr3[1] = 0; - ASSERT_FALSE(add_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()) != 0); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - ptr2[0] = 1; - ptr2[1] = 0; - ptr3[0] = 0xFFFFFFFFFFFFFFFF; - ptr3[1] = 0xFFFFFFFFFFFFFFFF; - ASSERT_TRUE(add_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()) != 0); - ASSERT_EQ(static_cast(0), ptr3[0]); - ASSERT_EQ(static_cast(0), ptr3[1]); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - ptr3[0] = 0; - ptr3[1] = 0; - - ASSERT_TRUE(add_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()) != 0); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr3[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); - ASSERT_TRUE(add_uint_uint(ptr.get(), ptr2.get(), 2, ptr.get()) != 0); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[1]); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0; - ptr2[0] = 1; - ptr2[1] = 0; - ptr3[0] = 0; - ptr3[1] = 0; - ASSERT_FALSE(add_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()) != 0); - ASSERT_EQ(static_cast(0), ptr3[0]); - ASSERT_EQ(1ULL, ptr3[1]); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 5; - ptr2[0] = 1; - ptr2[1] = 0; - ptr3[0] = 0; - ptr3[1] = 0; - ASSERT_FALSE(add_uint_uint(ptr.get(), 2, ptr2.get(), 1, false, 2, ptr3.get()) != 0); - ASSERT_EQ(static_cast(0), ptr3[0]); - ASSERT_EQ(static_cast(6), ptr3[1]); - ASSERT_FALSE(add_uint_uint(ptr.get(), 2, ptr2.get(), 1, true, 2, ptr3.get()) != 0); - ASSERT_EQ(1ULL, ptr3[0]); - ASSERT_EQ(static_cast(6), ptr3[1]); - } - - TEST(UIntArith, SubUIntUInt) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(2, pool)); - auto ptr2(allocate_uint(2, pool)); - auto ptr3(allocate_uint(2, pool)); - ptr[0] = 0; - ptr[1] = 0; - ptr2[0] = 0; - ptr2[1] = 0; - ptr3[0] = 0xFFFFFFFFFFFFFFFF; - ptr3[1] = 0xFFFFFFFFFFFFFFFF; - ASSERT_FALSE(sub_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()) != 0); - ASSERT_EQ(static_cast(0), ptr3[0]); - ASSERT_EQ(static_cast(0), ptr3[1]); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - ptr2[0] = 0; - ptr2[1] = 0; - ptr3[0] = 0; - ptr3[1] = 0; - ASSERT_FALSE(sub_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()) != 0); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - ptr2[0] = 1; - ptr2[1] = 0; - ptr3[0] = 0; - ptr3[1] = 0; - ASSERT_FALSE(sub_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()) != 0); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr3[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); - - ptr[0] = 0; - ptr[1] = 0; - ptr2[0] = 1; - ptr2[1] = 0; - ptr3[0] = 0; - ptr3[1] = 0; - ASSERT_TRUE(sub_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()) != 0); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); - ASSERT_TRUE(sub_uint_uint(ptr.get(), ptr2.get(), 2, ptr.get()) != 0); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[1]); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - ptr3[0] = 0; - ptr3[1] = 0; - ASSERT_FALSE(sub_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()) != 0); - ASSERT_EQ(static_cast(0), ptr3[0]); - ASSERT_EQ(static_cast(0), ptr3[1]); - ASSERT_FALSE(sub_uint_uint(ptr.get(), ptr2.get(), 2, ptr.get()) != 0); - ASSERT_EQ(static_cast(0), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - - ptr[0] = 0xFFFFFFFFFFFFFFFE; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - ptr3[0] = 0; - ptr3[1] = 0; - ASSERT_TRUE(sub_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()) != 0); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); - - ptr[0] = 0; - ptr[1] = 1; - ptr2[0] = 1; - ptr2[1] = 0; - ptr3[0] = 0; - ptr3[1] = 0; - ASSERT_FALSE(sub_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()) != 0); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[0]); - ASSERT_EQ(static_cast(0), ptr3[1]); - - ptr[0] = 0; - ptr[1] = 1; - ptr2[0] = 1; - ptr2[1] = 0; - ptr3[0] = 0; - ptr3[1] = 0; - ASSERT_FALSE(sub_uint_uint(ptr.get(), 2, ptr2.get(), 1, false, 2, ptr3.get()) != 0); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[0]); - ASSERT_EQ(static_cast(0), ptr3[1]); - ASSERT_FALSE(sub_uint_uint(ptr.get(), 2, ptr2.get(), 1, true, 2, ptr3.get()) != 0); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr3[0]); - ASSERT_EQ(static_cast(0), ptr3[1]); - } - - TEST(UIntArith, AddUIntUInt64) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(2, pool)); - auto ptr2(allocate_uint(2, pool)); - - ptr[0] = 0ULL; - ptr[1] = 0ULL; - ASSERT_FALSE(add_uint_uint64(ptr.get(), 0ULL, 2, ptr2.get())); - ASSERT_EQ(0ULL, ptr2[0]); - ASSERT_EQ(0ULL, ptr2[1]); - - ptr[0] = 0xFFFFFFFF00000000ULL; - ptr[1] = 0ULL; - ASSERT_FALSE(add_uint_uint64(ptr.get(), 0xFFFFFFFFULL, 2, ptr2.get())); - ASSERT_EQ(0xFFFFFFFFFFFFFFFFULL, ptr2[0]); - ASSERT_EQ(0ULL, ptr2[1]); - - ptr[0] = 0xFFFFFFFF00000000ULL; - ptr[1] = 0xFFFFFFFF00000000ULL; - ASSERT_FALSE(add_uint_uint64(ptr.get(), 0x100000000ULL, 2, ptr2.get())); - ASSERT_EQ(0ULL, ptr2[0]); - ASSERT_EQ(0xFFFFFFFF00000001ULL, ptr2[1]); - - ptr[0] = 0xFFFFFFFFFFFFFFFFULL; - ptr[1] = 0xFFFFFFFFFFFFFFFFULL; - ASSERT_TRUE(add_uint_uint64(ptr.get(), 1ULL, 2, ptr2.get())); - ASSERT_EQ(0ULL, ptr2[0]); - ASSERT_EQ(0ULL, ptr2[1]); - } - - TEST(UIntArith, SubUIntUInt64) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(2, pool)); - auto ptr2(allocate_uint(2, pool)); - - ptr[0] = 0ULL; - ptr[1] = 0ULL; - ASSERT_FALSE(sub_uint_uint64(ptr.get(), 0ULL, 2, ptr2.get())); - ASSERT_EQ(0ULL, ptr2[0]); - ASSERT_EQ(0ULL, ptr2[1]); - - ptr[0] = 0ULL; - ptr[1] = 0ULL; - ASSERT_TRUE(sub_uint_uint64(ptr.get(), 1ULL, 2, ptr2.get())); - ASSERT_EQ(0xFFFFFFFFFFFFFFFFULL, ptr2[0]); - ASSERT_EQ(0xFFFFFFFFFFFFFFFFULL, ptr2[1]); - - ptr[0] = 1ULL; - ptr[1] = 0ULL; - ASSERT_TRUE(sub_uint_uint64(ptr.get(), 2ULL, 2, ptr2.get())); - ASSERT_EQ(0xFFFFFFFFFFFFFFFFULL, ptr2[0]); - ASSERT_EQ(0xFFFFFFFFFFFFFFFFULL, ptr2[1]); - - ptr[0] = 0xFFFFFFFF00000000ULL; - ptr[1] = 0ULL; - ASSERT_FALSE(sub_uint_uint64(ptr.get(), 0xFFFFFFFFULL, 2, ptr2.get())); - ASSERT_EQ(0xFFFFFFFE00000001ULL, ptr2[0]); - ASSERT_EQ(0ULL, ptr2[1]); - - ptr[0] = 0xFFFFFFFF00000000ULL; - ptr[1] = 0xFFFFFFFF00000000ULL; - ASSERT_FALSE(sub_uint_uint64(ptr.get(), 0x100000000ULL, 2, ptr2.get())); - ASSERT_EQ(0xFFFFFFFE00000000ULL, ptr2[0]); - ASSERT_EQ(0xFFFFFFFF00000000ULL, ptr2[1]); - - ptr[0] = 0xFFFFFFFFFFFFFFFFULL; - ptr[1] = 0xFFFFFFFFFFFFFFFFULL; - ASSERT_FALSE(sub_uint_uint64(ptr.get(), 1ULL, 2, ptr2.get())); - ASSERT_EQ(0xFFFFFFFFFFFFFFFEULL, ptr2[0]); - ASSERT_EQ(0xFFFFFFFFFFFFFFFFULL, ptr2[1]); - } - - TEST(UIntArith, IncrementUInt) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr1(allocate_uint(2, pool)); - auto ptr2(allocate_uint(2, pool)); - ptr1[0] = 0; - ptr1[1] = 0; - ASSERT_FALSE(increment_uint(ptr1.get(), 2, ptr2.get()) != 0); - ASSERT_EQ(1ULL, ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - ASSERT_FALSE(increment_uint(ptr2.get(), 2, ptr1.get()) != 0); - ASSERT_EQ(static_cast(2), ptr1[0]); - ASSERT_EQ(static_cast(0), ptr1[1]); - - ptr1[0] = 0xFFFFFFFFFFFFFFFF; - ptr1[1] = 0; - ASSERT_FALSE(increment_uint(ptr1.get(), 2, ptr2.get()) != 0); - ASSERT_EQ(static_cast(0), ptr2[0]); - ASSERT_EQ(1ULL, ptr2[1]); - ASSERT_FALSE(increment_uint(ptr2.get(), 2, ptr1.get()) != 0); - ASSERT_EQ(1ULL, ptr1[0]); - ASSERT_EQ(1ULL, ptr1[1]); - - ptr1[0] = 0xFFFFFFFFFFFFFFFF; - ptr1[1] = 1; - ASSERT_FALSE(increment_uint(ptr1.get(), 2, ptr2.get()) != 0); - ASSERT_EQ(static_cast(0), ptr2[0]); - ASSERT_EQ(static_cast(2), ptr2[1]); - ASSERT_FALSE(increment_uint(ptr2.get(), 2, ptr1.get()) != 0); - ASSERT_EQ(1ULL, ptr1[0]); - ASSERT_EQ(static_cast(2), ptr1[1]); - - ptr1[0] = 0xFFFFFFFFFFFFFFFE; - ptr1[1] = 0xFFFFFFFFFFFFFFFF; - ASSERT_FALSE(increment_uint(ptr1.get(), 2, ptr2.get()) != 0); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr2[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr2[1]); - ASSERT_TRUE(increment_uint(ptr2.get(), 2, ptr1.get()) != 0); - ASSERT_EQ(static_cast(0), ptr1[0]); - ASSERT_EQ(static_cast(0), ptr1[1]); - ASSERT_FALSE(increment_uint(ptr1.get(), 2, ptr2.get()) != 0); - ASSERT_EQ(1ULL, ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - } - - TEST(UIntArith, DecrementUInt) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr1(allocate_uint(2, pool)); - auto ptr2(allocate_uint(2, pool)); - ptr1[0] = 2; - ptr1[1] = 2; - ASSERT_FALSE(decrement_uint(ptr1.get(), 2, ptr2.get()) != 0); - ASSERT_EQ(1ULL, ptr2[0]); - ASSERT_EQ(static_cast(2), ptr2[1]); - ASSERT_FALSE(decrement_uint(ptr2.get(), 2, ptr1.get()) != 0); - ASSERT_EQ(static_cast(0), ptr1[0]); - ASSERT_EQ(static_cast(2), ptr1[1]); - ASSERT_FALSE(decrement_uint(ptr1.get(), 2, ptr2.get()) != 0); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr2[0]); - ASSERT_EQ(1ULL, ptr2[1]); - ASSERT_FALSE(decrement_uint(ptr2.get(), 2, ptr1.get()) != 0); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr1[0]); - ASSERT_EQ(1ULL, ptr1[1]); - - ptr1[0] = 2; - ptr1[1] = 1; - ASSERT_FALSE(decrement_uint(ptr1.get(), 2, ptr2.get()) != 0); - ASSERT_EQ(1ULL, ptr2[0]); - ASSERT_EQ(1ULL, ptr2[1]); - ASSERT_FALSE(decrement_uint(ptr2.get(), 2, ptr1.get()) != 0); - ASSERT_EQ(static_cast(0), ptr1[0]); - ASSERT_EQ(1ULL, ptr1[1]); - ASSERT_FALSE(decrement_uint(ptr1.get(), 2, ptr2.get()) != 0); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - ASSERT_FALSE(decrement_uint(ptr2.get(), 2, ptr1.get()) != 0); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr1[0]); - ASSERT_EQ(static_cast(0), ptr1[1]); - - ptr1[0] = 2; - ptr1[1] = 0; - ASSERT_FALSE(decrement_uint(ptr1.get(), 2, ptr2.get()) != 0); - ASSERT_EQ(1ULL, ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - ASSERT_FALSE(decrement_uint(ptr2.get(), 2, ptr1.get()) != 0); - ASSERT_EQ(static_cast(0), ptr1[0]); - ASSERT_EQ(static_cast(0), ptr1[1]); - ASSERT_TRUE(decrement_uint(ptr1.get(), 2, ptr2.get()) != 0); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr2[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr2[1]); - ASSERT_FALSE(decrement_uint(ptr2.get(), 2, ptr1.get()) != 0); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr1[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr1[1]); - } - - TEST(UIntArith, NegateUInt) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(2, pool)); - ptr[0] = 0; - ptr[1] = 0; - negate_uint(ptr.get(), 2, ptr.get()); - ASSERT_EQ(static_cast(0), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - - ptr[0] = 1; - ptr[1] = 0; - negate_uint(ptr.get(), 2, ptr.get()); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[1]); - negate_uint(ptr.get(), 2, ptr.get()); - ASSERT_EQ(1ULL, ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - - ptr[0] = 2; - ptr[1] = 0; - negate_uint(ptr.get(), 2, ptr.get()); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[1]); - negate_uint(ptr.get(), 2, ptr.get()); - ASSERT_EQ(static_cast(2), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - - ptr[0] = 0; - ptr[1] = 1; - negate_uint(ptr.get(), 2, ptr.get()); - ASSERT_EQ(static_cast(0), ptr[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[1]); - negate_uint(ptr.get(), 2, ptr.get()); - ASSERT_EQ(static_cast(0), ptr[0]); - ASSERT_EQ(1ULL, ptr[1]); - - ptr[0] = 0; - ptr[1] = 2; - negate_uint(ptr.get(), 2, ptr.get()); - ASSERT_EQ(static_cast(0), ptr[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr[1]); - negate_uint(ptr.get(), 2, ptr.get()); - ASSERT_EQ(static_cast(0), ptr[0]); - ASSERT_EQ(static_cast(2), ptr[1]); - - ptr[0] = 1; - ptr[1] = 1; - negate_uint(ptr.get(), 2, ptr.get()); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr[1]); - negate_uint(ptr.get(), 2, ptr.get()); - ASSERT_EQ(1ULL, ptr[0]); - ASSERT_EQ(1ULL, ptr[1]); - } - - TEST(UIntArith, LeftShiftUInt) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(2, pool)); - auto ptr2(allocate_uint(2, pool)); - ptr[0] = 0; - ptr[1] = 0; - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - left_shift_uint(ptr.get(), 0, 2, ptr2.get()); - ASSERT_EQ(static_cast(0), ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - left_shift_uint(ptr.get(), 10, 2, ptr2.get()); - ASSERT_EQ(static_cast(0), ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - left_shift_uint(ptr.get(), 10, 2, ptr.get()); - ASSERT_EQ(static_cast(0), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - - ptr[0] = 0x5555555555555555; - ptr[1] = 0xAAAAAAAAAAAAAAAA; - left_shift_uint(ptr.get(), 0, 2, ptr2.get()); - ASSERT_EQ(static_cast(0x5555555555555555), ptr2[0]); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr2[1]); - left_shift_uint(ptr.get(), 0, 2, ptr.get()); - ASSERT_EQ(static_cast(0x5555555555555555), ptr[0]); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr[1]); - left_shift_uint(ptr.get(), 1, 2, ptr2.get()); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr2[0]); - ASSERT_EQ(static_cast(0x5555555555555554), ptr2[1]); - left_shift_uint(ptr.get(), 2, 2, ptr2.get()); - ASSERT_EQ(static_cast(0x5555555555555554), ptr2[0]); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAA9), ptr2[1]); - left_shift_uint(ptr.get(), 64, 2, ptr2.get()); - ASSERT_EQ(static_cast(0), ptr2[0]); - ASSERT_EQ(static_cast(0x5555555555555555), ptr2[1]); - left_shift_uint(ptr.get(), 65, 2, ptr2.get()); - ASSERT_EQ(static_cast(0), ptr2[0]); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr2[1]); - left_shift_uint(ptr.get(), 127, 2, ptr2.get()); - ASSERT_EQ(static_cast(0), ptr2[0]); - ASSERT_EQ(static_cast(0x8000000000000000), ptr2[1]); - - left_shift_uint(ptr.get(), 2, 2, ptr.get()); - ASSERT_EQ(static_cast(0x5555555555555554), ptr[0]); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAA9), ptr[1]); - left_shift_uint(ptr.get(), 64, 2, ptr.get()); - ASSERT_EQ(static_cast(0), ptr[0]); - ASSERT_EQ(static_cast(0x5555555555555554), ptr[1]); - } - - TEST(UIntArith, LeftShiftUInt128) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(2, pool)); - auto ptr2(allocate_uint(2, pool)); - ptr[0] = 0; - ptr[1] = 0; - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - left_shift_uint128(ptr.get(), 0, ptr2.get()); - ASSERT_EQ(static_cast(0), ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - left_shift_uint128(ptr.get(), 10, ptr2.get()); - ASSERT_EQ(static_cast(0), ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - left_shift_uint128(ptr.get(), 10, ptr.get()); - ASSERT_EQ(static_cast(0), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - - ptr[0] = 0x5555555555555555; - ptr[1] = 0xAAAAAAAAAAAAAAAA; - left_shift_uint128(ptr.get(), 0, ptr2.get()); - ASSERT_EQ(static_cast(0x5555555555555555), ptr2[0]); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr2[1]); - left_shift_uint128(ptr.get(), 0, ptr.get()); - ASSERT_EQ(static_cast(0x5555555555555555), ptr[0]); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr[1]); - left_shift_uint128(ptr.get(), 1, ptr2.get()); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr2[0]); - ASSERT_EQ(static_cast(0x5555555555555554), ptr2[1]); - left_shift_uint128(ptr.get(), 2, ptr2.get()); - ASSERT_EQ(static_cast(0x5555555555555554), ptr2[0]); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAA9), ptr2[1]); - left_shift_uint128(ptr.get(), 64, ptr2.get()); - ASSERT_EQ(static_cast(0), ptr2[0]); - ASSERT_EQ(static_cast(0x5555555555555555), ptr2[1]); - left_shift_uint128(ptr.get(), 65, ptr2.get()); - ASSERT_EQ(static_cast(0), ptr2[0]); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr2[1]); - left_shift_uint128(ptr.get(), 127, ptr2.get()); - ASSERT_EQ(static_cast(0), ptr2[0]); - ASSERT_EQ(static_cast(0x8000000000000000), ptr2[1]); - - left_shift_uint128(ptr.get(), 2, ptr.get()); - ASSERT_EQ(static_cast(0x5555555555555554), ptr[0]); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAA9), ptr[1]); - left_shift_uint128(ptr.get(), 64, ptr.get()); - ASSERT_EQ(static_cast(0), ptr[0]); - ASSERT_EQ(static_cast(0x5555555555555554), ptr[1]); - } - - TEST(UIntArith, LeftShiftUInt192) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(3, pool)); - auto ptr2(allocate_uint(3, pool)); - ptr[0] = 0; - ptr[1] = 0; - ptr[2] = 0; - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - ptr2[2] = 0xFFFFFFFFFFFFFFFF; - left_shift_uint192(ptr.get(), 0, ptr2.get()); - ASSERT_EQ(static_cast(0), ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - ASSERT_EQ(static_cast(0), ptr2[2]); - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - ptr2[2] = 0xFFFFFFFFFFFFFFFF; - left_shift_uint192(ptr.get(), 10, ptr2.get()); - ASSERT_EQ(static_cast(0), ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - ASSERT_EQ(static_cast(0), ptr2[2]); - left_shift_uint192(ptr.get(), 10, ptr.get()); - ASSERT_EQ(static_cast(0), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - ASSERT_EQ(static_cast(0), ptr[2]); - - ptr[0] = 0x5555555555555555; - ptr[1] = 0xAAAAAAAAAAAAAAAA; - ptr[2] = 0xCDCDCDCDCDCDCDCD; - left_shift_uint192(ptr.get(), 0, ptr2.get()); - ASSERT_EQ(static_cast(0x5555555555555555), ptr2[0]); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr2[1]); - ASSERT_EQ(static_cast(0xCDCDCDCDCDCDCDCD), ptr2[2]); - left_shift_uint192(ptr.get(), 0, ptr.get()); - ASSERT_EQ(static_cast(0x5555555555555555), ptr[0]); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr[1]); - ASSERT_EQ(static_cast(0xCDCDCDCDCDCDCDCD), ptr[2]); - left_shift_uint192(ptr.get(), 1, ptr2.get()); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr2[0]); - ASSERT_EQ(static_cast(0x5555555555555554), ptr2[1]); - ASSERT_EQ(static_cast(0x9B9B9B9B9B9B9B9B), ptr2[2]); - left_shift_uint192(ptr.get(), 2, ptr2.get()); - ASSERT_EQ(static_cast(0x5555555555555554), ptr2[0]); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAA9), ptr2[1]); - ASSERT_EQ(static_cast(0x3737373737373736), ptr2[2]); - left_shift_uint192(ptr.get(), 64, ptr2.get()); - ASSERT_EQ(static_cast(0), ptr2[0]); - ASSERT_EQ(static_cast(0x5555555555555555), ptr2[1]); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr2[2]); - left_shift_uint192(ptr.get(), 65, ptr2.get()); - ASSERT_EQ(static_cast(0), ptr2[0]); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr2[1]); - ASSERT_EQ(static_cast(0x5555555555555554), ptr2[2]); - left_shift_uint192(ptr.get(), 191, ptr2.get()); - ASSERT_EQ(static_cast(0), ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - ASSERT_EQ(static_cast(0x8000000000000000), ptr2[2]); - - left_shift_uint192(ptr.get(), 2, ptr.get()); - ASSERT_EQ(static_cast(0x5555555555555554), ptr[0]); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAA9), ptr[1]); - ASSERT_EQ(static_cast(0x3737373737373736), ptr[2]); - - left_shift_uint192(ptr.get(), 64, ptr.get()); - ASSERT_EQ(static_cast(0), ptr[0]); - ASSERT_EQ(static_cast(0x5555555555555554), ptr[1]); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAA9), ptr[2]); - } - - TEST(UIntArith, RightShiftUInt) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(2, pool)); - auto ptr2(allocate_uint(2, pool)); - ptr[0] = 0; - ptr[1] = 0; - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - right_shift_uint(ptr.get(), 0, 2, ptr2.get()); - ASSERT_EQ(static_cast(0), ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - right_shift_uint(ptr.get(), 10, 2, ptr2.get()); - ASSERT_EQ(static_cast(0), ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - right_shift_uint(ptr.get(), 10, 2, ptr.get()); - ASSERT_EQ(static_cast(0), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - - ptr[0] = 0x5555555555555555; - ptr[1] = 0xAAAAAAAAAAAAAAAA; - right_shift_uint(ptr.get(), 0, 2, ptr2.get()); - ASSERT_EQ(static_cast(0x5555555555555555), ptr2[0]); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr2[1]); - right_shift_uint(ptr.get(), 0, 2, ptr.get()); - ASSERT_EQ(static_cast(0x5555555555555555), ptr[0]); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr[1]); - right_shift_uint(ptr.get(), 1, 2, ptr2.get()); - ASSERT_EQ(static_cast(0x2AAAAAAAAAAAAAAA), ptr2[0]); - ASSERT_EQ(static_cast(0x5555555555555555), ptr2[1]); - right_shift_uint(ptr.get(), 2, 2, ptr2.get()); - ASSERT_EQ(static_cast(0x9555555555555555), ptr2[0]); - ASSERT_EQ(static_cast(0x2AAAAAAAAAAAAAAA), ptr2[1]); - right_shift_uint(ptr.get(), 64, 2, ptr2.get()); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - right_shift_uint(ptr.get(), 65, 2, ptr2.get()); - ASSERT_EQ(static_cast(0x5555555555555555), ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - right_shift_uint(ptr.get(), 127, 2, ptr2.get()); - ASSERT_EQ(1ULL, ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - - right_shift_uint(ptr.get(), 2, 2, ptr.get()); - ASSERT_EQ(static_cast(0x9555555555555555), ptr[0]); - ASSERT_EQ(static_cast(0x2AAAAAAAAAAAAAAA), ptr[1]); - right_shift_uint(ptr.get(), 64, 2, ptr.get()); - ASSERT_EQ(static_cast(0x2AAAAAAAAAAAAAAA), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - } - - TEST(UIntArith, RightShiftUInt128) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(2, pool)); - auto ptr2(allocate_uint(2, pool)); - ptr[0] = 0; - ptr[1] = 0; - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - right_shift_uint128(ptr.get(), 0, ptr2.get()); - ASSERT_EQ(static_cast(0), ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - right_shift_uint128(ptr.get(), 10, ptr2.get()); - ASSERT_EQ(static_cast(0), ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - right_shift_uint128(ptr.get(), 10, ptr.get()); - ASSERT_EQ(static_cast(0), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - - ptr[0] = 0x5555555555555555; - ptr[1] = 0xAAAAAAAAAAAAAAAA; - right_shift_uint128(ptr.get(), 0, ptr2.get()); - ASSERT_EQ(static_cast(0x5555555555555555), ptr2[0]); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr2[1]); - right_shift_uint128(ptr.get(), 0, ptr.get()); - ASSERT_EQ(static_cast(0x5555555555555555), ptr[0]); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr[1]); - right_shift_uint128(ptr.get(), 1, ptr2.get()); - ASSERT_EQ(static_cast(0x2AAAAAAAAAAAAAAA), ptr2[0]); - ASSERT_EQ(static_cast(0x5555555555555555), ptr2[1]); - right_shift_uint128(ptr.get(), 2, ptr2.get()); - ASSERT_EQ(static_cast(0x9555555555555555), ptr2[0]); - ASSERT_EQ(static_cast(0x2AAAAAAAAAAAAAAA), ptr2[1]); - right_shift_uint128(ptr.get(), 64, ptr2.get()); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - right_shift_uint128(ptr.get(), 65, ptr2.get()); - ASSERT_EQ(static_cast(0x5555555555555555), ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - right_shift_uint128(ptr.get(), 127, ptr2.get()); - ASSERT_EQ(1ULL, ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - - right_shift_uint128(ptr.get(), 2, ptr.get()); - ASSERT_EQ(static_cast(0x9555555555555555), ptr[0]); - ASSERT_EQ(static_cast(0x2AAAAAAAAAAAAAAA), ptr[1]); - right_shift_uint128(ptr.get(), 64, ptr.get()); - ASSERT_EQ(static_cast(0x2AAAAAAAAAAAAAAA), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - } - - TEST(UIntArith, RightShiftUInt192) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(3, pool)); - auto ptr2(allocate_uint(3, pool)); - ptr[0] = 0; - ptr[1] = 0; - ptr[2] = 0; - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - ptr2[2] = 0xFFFFFFFFFFFFFFFF; - right_shift_uint192(ptr.get(), 0, ptr2.get()); - ASSERT_EQ(static_cast(0), ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - ASSERT_EQ(static_cast(0), ptr2[2]); - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - ptr2[2] = 0xFFFFFFFFFFFFFFFF; - right_shift_uint192(ptr.get(), 10, ptr2.get()); - ASSERT_EQ(static_cast(0), ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - ASSERT_EQ(static_cast(0), ptr2[2]); - right_shift_uint192(ptr.get(), 10, ptr.get()); - ASSERT_EQ(static_cast(0), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - ASSERT_EQ(static_cast(0), ptr[2]); - - ptr[0] = 0x5555555555555555; - ptr[1] = 0xAAAAAAAAAAAAAAAA; - ptr[2] = 0xCDCDCDCDCDCDCDCD; - - right_shift_uint192(ptr.get(), 0, ptr2.get()); - ASSERT_EQ(static_cast(0x5555555555555555), ptr2[0]); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr2[1]); - ASSERT_EQ(static_cast(0xCDCDCDCDCDCDCDCD), ptr2[2]); - right_shift_uint192(ptr.get(), 0, ptr.get()); - ASSERT_EQ(static_cast(0x5555555555555555), ptr[0]); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr[1]); - ASSERT_EQ(static_cast(0xCDCDCDCDCDCDCDCD), ptr[2]); - right_shift_uint192(ptr.get(), 1, ptr2.get()); - ASSERT_EQ(static_cast(0x2AAAAAAAAAAAAAAA), ptr2[0]); - ASSERT_EQ(static_cast(0xD555555555555555), ptr2[1]); - ASSERT_EQ(static_cast(0x66E6E6E6E6E6E6E6), ptr2[2]); - right_shift_uint192(ptr.get(), 2, ptr2.get()); - ASSERT_EQ(static_cast(0x9555555555555555), ptr2[0]); - ASSERT_EQ(static_cast(0x6AAAAAAAAAAAAAAA), ptr2[1]); - ASSERT_EQ(static_cast(0x3373737373737373), ptr2[2]); - right_shift_uint192(ptr.get(), 64, ptr2.get()); - ASSERT_EQ(static_cast(0xAAAAAAAAAAAAAAAA), ptr2[0]); - ASSERT_EQ(static_cast(0xCDCDCDCDCDCDCDCD), ptr2[1]); - ASSERT_EQ(static_cast(0), ptr2[2]); - right_shift_uint192(ptr.get(), 65, ptr2.get()); - ASSERT_EQ(static_cast(0xD555555555555555), ptr2[0]); - ASSERT_EQ(static_cast(0x66E6E6E6E6E6E6E6), ptr2[1]); - ASSERT_EQ(static_cast(0), ptr2[2]); - right_shift_uint192(ptr.get(), 191, ptr2.get()); - ASSERT_EQ(1ULL, ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - ASSERT_EQ(static_cast(0), ptr2[2]); - - right_shift_uint192(ptr.get(), 2, ptr.get()); - ASSERT_EQ(static_cast(0x9555555555555555), ptr[0]); - ASSERT_EQ(static_cast(0x6AAAAAAAAAAAAAAA), ptr[1]); - ASSERT_EQ(static_cast(0x3373737373737373), ptr[2]); - right_shift_uint192(ptr.get(), 64, ptr.get()); - ASSERT_EQ(static_cast(0x6AAAAAAAAAAAAAAA), ptr[0]); - ASSERT_EQ(static_cast(0x3373737373737373), ptr[1]); - ASSERT_EQ(static_cast(0), ptr[2]); - } - - TEST(UIntArith, HalfRoundUpUInt) - { - half_round_up_uint(nullptr, 0, nullptr); - - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(2, pool)); - auto ptr2(allocate_uint(2, pool)); - ptr[0] = 0; - ptr[1] = 0; - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - half_round_up_uint(ptr.get(), 2, ptr2.get()); - ASSERT_EQ(static_cast(0), ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - half_round_up_uint(ptr.get(), 2, ptr.get()); - ASSERT_EQ(static_cast(0), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - - ptr[0] = 1; - ptr[1] = 0; - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - half_round_up_uint(ptr.get(), 2, ptr2.get()); - ASSERT_EQ(1ULL, ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - half_round_up_uint(ptr.get(), 2, ptr.get()); - ASSERT_EQ(1ULL, ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - - ptr[0] = 2; - ptr[1] = 0; - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - half_round_up_uint(ptr.get(), 2, ptr2.get()); - ASSERT_EQ(1ULL, ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - half_round_up_uint(ptr.get(), 2, ptr.get()); - ASSERT_EQ(1ULL, ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - - ptr[0] = 3; - ptr[1] = 0; - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - half_round_up_uint(ptr.get(), 2, ptr2.get()); - ASSERT_EQ(static_cast(2), ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - - ptr[0] = 4; - ptr[1] = 0; - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - half_round_up_uint(ptr.get(), 2, ptr2.get()); - ASSERT_EQ(static_cast(2), ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - half_round_up_uint(ptr.get(), 2, ptr2.get()); - ASSERT_EQ(static_cast(0), ptr2[0]); - ASSERT_EQ(static_cast(0x8000000000000000), ptr2[1]); - half_round_up_uint(ptr.get(), 2, ptr.get()); - ASSERT_EQ(static_cast(0), ptr[0]); - ASSERT_EQ(static_cast(0x8000000000000000), ptr[1]); - } - - TEST(UIntArith, NotUInt) - { - not_uint(nullptr, 0, nullptr); - - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(2, pool)); - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0; - not_uint(ptr.get(), 2, ptr.get()); - ASSERT_EQ(static_cast(0), ptr[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[1]); - - ptr[0] = 0xFFFFFFFF00000000; - ptr[1] = 0xFFFF0000FFFF0000; - not_uint(ptr.get(), 2, ptr.get()); - ASSERT_EQ(static_cast(0x00000000FFFFFFFF), ptr[0]); - ASSERT_EQ(static_cast(0x0000FFFF0000FFFF), ptr[1]); - } - - TEST(UIntArith, AndUIntUInt) - { - and_uint_uint(nullptr, nullptr, 0, nullptr); - - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(2, pool)); - auto ptr2(allocate_uint(2, pool)); - auto ptr3(allocate_uint(2, pool)); - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0; - ptr2[0] = 0; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - ptr3[0] = 0xFFFFFFFFFFFFFFFF; - ptr3[1] = 0xFFFFFFFFFFFFFFFF; - and_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()); - ASSERT_EQ(static_cast(0), ptr3[0]); - ASSERT_EQ(static_cast(0), ptr3[1]); - - ptr[0] = 0xFFFFFFFF00000000; - ptr[1] = 0xFFFF0000FFFF0000; - ptr2[0] = 0x0000FFFF0000FFFF; - ptr2[1] = 0xFF00FF00FF00FF00; - ptr3[0] = 0; - ptr3[1] = 0; - and_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()); - ASSERT_EQ(static_cast(0x0000FFFF00000000), ptr3[0]); - ASSERT_EQ(static_cast(0xFF000000FF000000), ptr3[1]); - and_uint_uint(ptr.get(), ptr2.get(), 2, ptr.get()); - ASSERT_EQ(static_cast(0x0000FFFF00000000), ptr[0]); - ASSERT_EQ(static_cast(0xFF000000FF000000), ptr[1]); - } - - TEST(UIntArith, OrUIntUInt) - { - or_uint_uint(nullptr, nullptr, 0, nullptr); - - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(2, pool)); - auto ptr2(allocate_uint(2, pool)); - auto ptr3(allocate_uint(2, pool)); - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0; - ptr2[0] = 0; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - ptr3[0] = 0; - ptr3[1] = 0; - or_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); - - ptr[0] = 0xFFFFFFFF00000000; - ptr[1] = 0xFFFF0000FFFF0000; - ptr2[0] = 0x0000FFFF0000FFFF; - ptr2[1] = 0xFF00FF00FF00FF00; - ptr3[0] = 0; - ptr3[1] = 0; - or_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()); - ASSERT_EQ(static_cast(0xFFFFFFFF0000FFFF), ptr3[0]); - ASSERT_EQ(static_cast(0xFFFFFF00FFFFFF00), ptr3[1]); - or_uint_uint(ptr.get(), ptr2.get(), 2, ptr.get()); - ASSERT_EQ(static_cast(0xFFFFFFFF0000FFFF), ptr[0]); - ASSERT_EQ(static_cast(0xFFFFFF00FFFFFF00), ptr[1]); - } - - TEST(UIntArith, XorUIntUInt) - { - xor_uint_uint(nullptr, nullptr, 0, nullptr); - - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(2, pool)); - auto ptr2(allocate_uint(2, pool)); - auto ptr3(allocate_uint(2, pool)); - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0; - ptr2[0] = 0; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - ptr3[0] = 0; - ptr3[1] = 0; - xor_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); - - ptr[0] = 0xFFFFFFFF00000000; - ptr[1] = 0xFFFF0000FFFF0000; - ptr2[0] = 0x0000FFFF0000FFFF; - ptr2[1] = 0xFF00FF00FF00FF00; - ptr3[0] = 0; - ptr3[1] = 0; - xor_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()); - ASSERT_EQ(static_cast(0xFFFF00000000FFFF), ptr3[0]); - ASSERT_EQ(static_cast(0x00FFFF0000FFFF00), ptr3[1]); - xor_uint_uint(ptr.get(), ptr2.get(), 2, ptr.get()); - ASSERT_EQ(static_cast(0xFFFF00000000FFFF), ptr[0]); - ASSERT_EQ(static_cast(0x00FFFF0000FFFF00), ptr[1]); - } - - TEST(UIntArith, MultiplyUInt64Generic) - { - unsigned long long result[2]; - - multiply_uint64_generic(0ULL, 0ULL, result); - ASSERT_EQ(0ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - multiply_uint64_generic(0ULL, 1ULL, result); - ASSERT_EQ(0ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - multiply_uint64_generic(1ULL, 0ULL, result); - ASSERT_EQ(0ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - multiply_uint64_generic(1ULL, 1ULL, result); - ASSERT_EQ(1ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - multiply_uint64_generic(0x100000000ULL, 0xFAFABABAULL, result); - ASSERT_EQ(0xFAFABABA00000000ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - multiply_uint64_generic(0x1000000000ULL, 0xFAFABABAULL, result); - ASSERT_EQ(0xAFABABA000000000ULL, result[0]); - ASSERT_EQ(0xFULL, result[1]); - multiply_uint64_generic(1111222233334444ULL, 5555666677778888ULL, result); - ASSERT_EQ(4140785562324247136ULL, result[0]); - ASSERT_EQ(334670460471ULL, result[1]); - } - - TEST(UIntArith, MultiplyUInt64) - { - unsigned long long result[2]; - - multiply_uint64(0ULL, 0ULL, result); - ASSERT_EQ(0ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - multiply_uint64(0ULL, 1ULL, result); - ASSERT_EQ(0ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - multiply_uint64(1ULL, 0ULL, result); - ASSERT_EQ(0ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - multiply_uint64(1ULL, 1ULL, result); - ASSERT_EQ(1ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - multiply_uint64(0x100000000ULL, 0xFAFABABAULL, result); - ASSERT_EQ(0xFAFABABA00000000ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - multiply_uint64(0x1000000000ULL, 0xFAFABABAULL, result); - ASSERT_EQ(0xAFABABA000000000ULL, result[0]); - ASSERT_EQ(0xFULL, result[1]); - multiply_uint64(1111222233334444ULL, 5555666677778888ULL, result); - ASSERT_EQ(4140785562324247136ULL, result[0]); - ASSERT_EQ(334670460471ULL, result[1]); - } - - TEST(UIntArith, MultiplyUInt64HW64Generic) - { - unsigned long long result; - - multiply_uint64_hw64_generic(0ULL, 0ULL, &result); - ASSERT_EQ(0ULL, result); - multiply_uint64_hw64_generic(0ULL, 1ULL, &result); - ASSERT_EQ(0ULL, result); - multiply_uint64_hw64_generic(1ULL, 0ULL, &result); - ASSERT_EQ(0ULL, result); - multiply_uint64_hw64_generic(1ULL, 1ULL, &result); - ASSERT_EQ(0ULL, result); - multiply_uint64_hw64_generic(0x100000000ULL, 0xFAFABABAULL, &result); - ASSERT_EQ(0ULL, result); - multiply_uint64_hw64_generic(0x1000000000ULL, 0xFAFABABAULL, &result); - ASSERT_EQ(0xFULL, result); - multiply_uint64_hw64_generic(1111222233334444ULL, 5555666677778888ULL, &result); - ASSERT_EQ(334670460471ULL, result); - } - - TEST(UIntArith, MultiplyUInt64HW64) - { - unsigned long long result; - - multiply_uint64_hw64(0ULL, 0ULL, &result); - ASSERT_EQ(0ULL, result); - multiply_uint64_hw64(0ULL, 1ULL, &result); - ASSERT_EQ(0ULL, result); - multiply_uint64_hw64(1ULL, 0ULL, &result); - ASSERT_EQ(0ULL, result); - multiply_uint64_hw64(1ULL, 1ULL, &result); - ASSERT_EQ(0ULL, result); - multiply_uint64_hw64(0x100000000ULL, 0xFAFABABAULL, &result); - ASSERT_EQ(0ULL, result); - multiply_uint64_hw64(0x1000000000ULL, 0xFAFABABAULL, &result); - ASSERT_EQ(0xFULL, result); - multiply_uint64_hw64(1111222233334444ULL, 5555666677778888ULL, &result); - ASSERT_EQ(334670460471ULL, result); - } - - TEST(UIntArith, MultiplyUIntUInt) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(2, pool)); - auto ptr2(allocate_uint(2, pool)); - auto ptr3(allocate_uint(4, pool)); - ptr[0] = 0; - ptr[1] = 0; - ptr2[0] = 0; - ptr2[1] = 0; - ptr3[0] = 0xFFFFFFFFFFFFFFFF; - ptr3[1] = 0xFFFFFFFFFFFFFFFF; - ptr3[2] = 0xFFFFFFFFFFFFFFFF; - ptr3[3] = 0xFFFFFFFFFFFFFFFF; - multiply_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()); - ASSERT_EQ(static_cast(0), ptr3[0]); - ASSERT_EQ(static_cast(0), ptr3[1]); - ASSERT_EQ(static_cast(0), ptr3[2]); - ASSERT_EQ(static_cast(0), ptr3[3]); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - ptr2[0] = 0; - ptr2[1] = 0; - ptr3[0] = 0xFFFFFFFFFFFFFFFF; - ptr3[1] = 0xFFFFFFFFFFFFFFFF; - ptr3[2] = 0xFFFFFFFFFFFFFFFF; - ptr3[3] = 0xFFFFFFFFFFFFFFFF; - multiply_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()); - ASSERT_EQ(static_cast(0), ptr3[0]); - ASSERT_EQ(static_cast(0), ptr3[1]); - ASSERT_EQ(static_cast(0), ptr3[2]); - ASSERT_EQ(static_cast(0), ptr3[3]); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - ptr2[0] = 1; - ptr2[1] = 0; - ptr3[0] = 0; - ptr3[1] = 0; - ptr3[2] = 0; - ptr3[3] = 0; - multiply_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); - ASSERT_EQ(static_cast(0), ptr3[2]); - ASSERT_EQ(static_cast(0), ptr3[3]); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - ptr2[0] = 0; - ptr2[1] = 1; - ptr3[0] = 0; - ptr3[1] = 0; - ptr3[2] = 0; - ptr3[3] = 0; - multiply_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()); - ASSERT_EQ(static_cast(0), ptr3[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[2]); - ASSERT_EQ(static_cast(0), ptr3[3]); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - ptr3[0] = 0; - ptr3[1] = 0; - ptr3[2] = 0; - ptr3[3] = 0; - multiply_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()); - ASSERT_EQ(1ULL, ptr3[0]); - ASSERT_EQ(static_cast(0), ptr3[1]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr3[2]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[3]); - - ptr[0] = 9756571004902751654ul; - ptr[1] = 731952007397389984; - ptr2[0] = 701538366196406307; - ptr2[1] = 1699883529753102283; - ptr3[0] = 0; - ptr3[1] = 0; - ptr3[2] = 0; - ptr3[3] = 0; - multiply_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()); - ASSERT_EQ(static_cast(9585656442714717618ul), ptr3[0]); - ASSERT_EQ(static_cast(1817697005049051848), ptr3[1]); - ASSERT_EQ(static_cast(14447416709120365380ul), ptr3[2]); - ASSERT_EQ(static_cast(67450014862939159), ptr3[3]); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - ptr3[0] = 0; - ptr3[1] = 0; - ptr3[2] = 0; - ptr3[3] = 0; - multiply_uint_uint(ptr.get(), 2, ptr2.get(), 1, 2, ptr3.get()); - ASSERT_EQ(1ULL, ptr3[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); - ASSERT_EQ(static_cast(0), ptr3[2]); - ASSERT_EQ(static_cast(0), ptr3[3]); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - ptr3[0] = 0; - ptr3[1] = 0; - ptr3[2] = 0; - ptr3[3] = 0; - multiply_uint_uint(ptr.get(), 2, ptr2.get(), 1, 3, ptr3.get()); - ASSERT_EQ(1ULL, ptr3[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr3[2]); - ASSERT_EQ(static_cast(0), ptr3[3]); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0; - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - ptr3[0] = 0; - ptr3[1] = 0; - ptr3[2] = 0; - ptr3[3] = 0; - multiply_truncate_uint_uint(ptr.get(), ptr2.get(), 2, ptr3.get()); - ASSERT_EQ(1ULL, ptr3[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr3[1]); - ASSERT_EQ(static_cast(0), ptr3[2]); - ASSERT_EQ(static_cast(0), ptr3[3]); - } - - TEST(UIntArith, MultiplyUIntUInt64) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(3, pool)); - auto result(allocate_uint(4, pool)); - - ptr[0] = 0; - ptr[1] = 0; - ptr[2] = 0; - multiply_uint_uint64(ptr.get(), 3, 0ULL, 4, result.get()); - ASSERT_EQ(0ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - ASSERT_EQ(0ULL, result[2]); - ASSERT_EQ(0ULL, result[3]); - - ptr[0] = 0xFFFFFFFFF; - ptr[1] = 0xAAAAAAAAA; - ptr[2] = 0x111111111; - multiply_uint_uint64(ptr.get(), 3, 0ULL, 4, result.get()); - ASSERT_EQ(0ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - ASSERT_EQ(0ULL, result[2]); - ASSERT_EQ(0ULL, result[3]); - - ptr[0] = 0xFFFFFFFFF; - ptr[1] = 0xAAAAAAAAA; - ptr[2] = 0x111111111; - multiply_uint_uint64(ptr.get(), 3, 1ULL, 4, result.get()); - ASSERT_EQ(0xFFFFFFFFFULL, result[0]); - ASSERT_EQ(0xAAAAAAAAAULL, result[1]); - ASSERT_EQ(0x111111111ULL, result[2]); - ASSERT_EQ(0ULL, result[3]); - - ptr[0] = 0xFFFFFFFFF; - ptr[1] = 0xAAAAAAAAA; - ptr[2] = 0x111111111; - multiply_uint_uint64(ptr.get(), 3, 0x10000ULL, 4, result.get()); - ASSERT_EQ(0xFFFFFFFFF0000ULL, result[0]); - ASSERT_EQ(0xAAAAAAAAA0000ULL, result[1]); - ASSERT_EQ(0x1111111110000ULL, result[2]); - ASSERT_EQ(0ULL, result[3]); - - ptr[0] = 0xFFFFFFFFF; - ptr[1] = 0xAAAAAAAAA; - ptr[2] = 0x111111111; - multiply_uint_uint64(ptr.get(), 3, 0x100000000ULL, 4, result.get()); - ASSERT_EQ(0xFFFFFFFF00000000ULL, result[0]); - ASSERT_EQ(0xAAAAAAAA0000000FULL, result[1]); - ASSERT_EQ(0x111111110000000AULL, result[2]); - ASSERT_EQ(1ULL, result[3]); - - ptr[0] = 5656565656565656ULL; - ptr[1] = 3434343434343434ULL; - ptr[2] = 1212121212121212ULL; - multiply_uint_uint64(ptr.get(), 3, 7878787878787878ULL, 4, result.get()); - ASSERT_EQ(8891370032116156560ULL, result[0]); - ASSERT_EQ(127835914414679452ULL, result[1]); - ASSERT_EQ(9811042505314082702ULL, result[2]); - ASSERT_EQ(517709026347ULL, result[3]); - } - - TEST(UIntArith, DivideUIntUInt) - { - MemoryPool &pool = *global_variables::global_memory_pool; - divide_uint_uint_inplace(nullptr, nullptr, 0, nullptr, pool); - divide_uint_uint(nullptr, nullptr, 0, nullptr, nullptr, pool); - - auto ptr(allocate_uint(4, pool)); - auto ptr2(allocate_uint(4, pool)); - auto ptr3(allocate_uint(4, pool)); - auto ptr4(allocate_uint(4, pool)); - ptr[0] = 0; - ptr[1] = 0; - ptr2[0] = 0; - ptr2[1] = 1; - ptr3[0] = 0xFFFFFFFFFFFFFFFF; - ptr3[1] = 0xFFFFFFFFFFFFFFFF; - divide_uint_uint_inplace(ptr.get(), ptr2.get(), 2, ptr3.get(), pool); - ASSERT_EQ(static_cast(0), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - ASSERT_EQ(static_cast(0), ptr3[0]); - ASSERT_EQ(static_cast(0), ptr3[1]); - - ptr[0] = 0; - ptr[1] = 0; - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - ptr3[0] = 0xFFFFFFFFFFFFFFFF; - ptr3[1] = 0xFFFFFFFFFFFFFFFF; - divide_uint_uint_inplace(ptr.get(), ptr2.get(), 2, ptr3.get(), pool); - ASSERT_EQ(static_cast(0), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - ASSERT_EQ(static_cast(0), ptr3[0]); - ASSERT_EQ(static_cast(0), ptr3[1]); - - ptr[0] = 0xFFFFFFFFFFFFFFFE; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - ptr3[0] = 0xFFFFFFFFFFFFFFFF; - ptr3[1] = 0xFFFFFFFFFFFFFFFF; - divide_uint_uint_inplace(ptr.get(), ptr2.get(), 2, ptr3.get(), pool); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), ptr[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[1]); - ASSERT_EQ(static_cast(0), ptr3[0]); - ASSERT_EQ(static_cast(0), ptr3[1]); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - ptr3[0] = 0xFFFFFFFFFFFFFFFF; - ptr3[1] = 0xFFFFFFFFFFFFFFFF; - divide_uint_uint_inplace(ptr.get(), ptr2.get(), 2, ptr3.get(), pool); - ASSERT_EQ(static_cast(0), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - ASSERT_EQ(1ULL, ptr3[0]); - ASSERT_EQ(static_cast(0), ptr3[1]); - - ptr[0] = 14; - ptr[1] = 0; - ptr2[0] = 3; - ptr2[1] = 0; - ptr3[0] = 0xFFFFFFFFFFFFFFFF; - ptr3[1] = 0xFFFFFFFFFFFFFFFF; - divide_uint_uint_inplace(ptr.get(), ptr2.get(), 2, ptr3.get(), pool); - ASSERT_EQ(static_cast(2), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - ASSERT_EQ(static_cast(4), ptr3[0]); - ASSERT_EQ(static_cast(0), ptr3[1]); - - ptr[0] = 9585656442714717620ul; - ptr[1] = 1817697005049051848; - ptr[2] = 14447416709120365380ul; - ptr[3] = 67450014862939159; - ptr2[0] = 701538366196406307; - ptr2[1] = 1699883529753102283; - ptr2[2] = 0; - ptr2[3] = 0; - ptr3[0] = 0xFFFFFFFFFFFFFFFF; - ptr3[1] = 0xFFFFFFFFFFFFFFFF; - ptr3[2] = 0xFFFFFFFFFFFFFFFF; - ptr3[3] = 0xFFFFFFFFFFFFFFFF; - ptr4[0] = 0xFFFFFFFFFFFFFFFF; - ptr4[1] = 0xFFFFFFFFFFFFFFFF; - ptr4[2] = 0xFFFFFFFFFFFFFFFF; - ptr4[3] = 0xFFFFFFFFFFFFFFFF; - divide_uint_uint(ptr.get(), ptr2.get(), 4, ptr3.get(), ptr4.get(), pool); - ASSERT_EQ(static_cast(2), ptr4[0]); - ASSERT_EQ(static_cast(0), ptr4[1]); - ASSERT_EQ(static_cast(0), ptr4[2]); - ASSERT_EQ(static_cast(0), ptr4[3]); - ASSERT_EQ(static_cast(9756571004902751654ul), ptr3[0]); - ASSERT_EQ(static_cast(731952007397389984), ptr3[1]); - ASSERT_EQ(static_cast(0), ptr3[2]); - ASSERT_EQ(static_cast(0), ptr3[3]); - - divide_uint_uint_inplace(ptr.get(), ptr2.get(), 4, ptr3.get(), pool); - ASSERT_EQ(static_cast(2), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - ASSERT_EQ(static_cast(0), ptr[2]); - ASSERT_EQ(static_cast(0), ptr[3]); - ASSERT_EQ(static_cast(9756571004902751654ul), ptr3[0]); - ASSERT_EQ(static_cast(731952007397389984), ptr3[1]); - ASSERT_EQ(static_cast(0), ptr3[2]); - ASSERT_EQ(static_cast(0), ptr3[3]); - } - - TEST(UIntArith, DivideUInt128UInt64) - { - uint64_t input[2]; - uint64_t quotient[2]; - - input[0] = 0; - input[1] = 0; - divide_uint128_uint64_inplace(input, 1ULL, quotient); - ASSERT_EQ(0ULL, input[0]); - ASSERT_EQ(0ULL, input[1]); - ASSERT_EQ(0ULL, quotient[0]); - ASSERT_EQ(0ULL, quotient[1]); - - input[0] = 1; - input[1] = 0; - divide_uint128_uint64_inplace(input, 1ULL, quotient); - ASSERT_EQ(0ULL, input[0]); - ASSERT_EQ(0ULL, input[1]); - ASSERT_EQ(1ULL, quotient[0]); - ASSERT_EQ(0ULL, quotient[1]); - - input[0] = 0x10101010; - input[1] = 0x2B2B2B2B; - divide_uint128_uint64_inplace(input, 0x1000ULL, quotient); - ASSERT_EQ(0x10ULL, input[0]); - ASSERT_EQ(0ULL, input[1]); - ASSERT_EQ(0xB2B0000000010101ULL, quotient[0]); - ASSERT_EQ(0x2B2B2ULL, quotient[1]); - - input[0] = 1212121212121212ULL; - input[1] = 3434343434343434ULL; - divide_uint128_uint64_inplace(input, 5656565656565656ULL, quotient); - ASSERT_EQ(5252525252525252ULL, input[0]); - ASSERT_EQ(0ULL, input[1]); - ASSERT_EQ(11199808901895084909ULL, quotient[0]); - ASSERT_EQ(0ULL, quotient[1]); - } - - TEST(UIntArith, DivideUInt192UInt64) - { - uint64_t input[3]; - uint64_t quotient[3]; - - input[0] = 0; - input[1] = 0; - input[2] = 0; - divide_uint192_uint64_inplace(input, 1ULL, quotient); - ASSERT_EQ(0ULL, input[0]); - ASSERT_EQ(0ULL, input[1]); - ASSERT_EQ(0ULL, input[2]); - ASSERT_EQ(0ULL, quotient[0]); - ASSERT_EQ(0ULL, quotient[1]); - ASSERT_EQ(0ULL, quotient[2]); - - input[0] = 1; - input[1] = 0; - input[2] = 0; - divide_uint192_uint64_inplace(input, 1ULL, quotient); - ASSERT_EQ(0ULL, input[0]); - ASSERT_EQ(0ULL, input[1]); - ASSERT_EQ(0ULL, input[2]); - ASSERT_EQ(1ULL, quotient[0]); - ASSERT_EQ(0ULL, quotient[1]); - ASSERT_EQ(0ULL, quotient[2]); - - input[0] = 0x10101010; - input[1] = 0x2B2B2B2B; - input[2] = 0xF1F1F1F1; - divide_uint192_uint64_inplace(input, 0x1000ULL, quotient); - ASSERT_EQ(0x10ULL, input[0]); - ASSERT_EQ(0ULL, input[1]); - ASSERT_EQ(0ULL, input[2]); - ASSERT_EQ(0xB2B0000000010101ULL, quotient[0]); - ASSERT_EQ(0x1F1000000002B2B2ULL, quotient[1]); - ASSERT_EQ(0xF1F1FULL, quotient[2]); - - input[0] = 1212121212121212ULL; - input[1] = 3434343434343434ULL; - input[2] = 5656565656565656ULL; - divide_uint192_uint64_inplace(input, 7878787878787878ULL, quotient); - ASSERT_EQ(7272727272727272ULL, input[0]); - ASSERT_EQ(0ULL, input[1]); - ASSERT_EQ(0ULL, input[2]); - ASSERT_EQ(17027763760347278414ULL, quotient[0]); - ASSERT_EQ(13243816258047883211ULL, quotient[1]); - ASSERT_EQ(0ULL, quotient[2]); - } - - TEST(UIntArith, ExponentiateUInt) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto input(allocate_zero_uint(2, pool)); - auto result(allocate_zero_uint(8, pool)); - - result[0] = 1, result[1] = 2, result[2] = 3, result[3] = 4; - result[4] = 5, result[5] = 6, result[6] = 7, result[7] = 8; - - uint64_t exponent[2]{ 0, 0 }; - - input[0] = 0xFFF; - input[1] = 0; - exponentiate_uint(input.get(), 2, exponent, 1, 1, result.get(), pool); - ASSERT_EQ(1ULL, result[0]); - ASSERT_EQ(2ULL, result[1]); - - exponentiate_uint(input.get(), 2, exponent, 1, 2, result.get(), pool); - ASSERT_EQ(1ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - - exponentiate_uint(input.get(), 1, exponent, 1, 4, result.get(), pool); - ASSERT_EQ(1ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - ASSERT_EQ(0ULL, result[2]); - ASSERT_EQ(0ULL, result[3]); - - input[0] = 123; - exponent[0] = 5; - exponentiate_uint(input.get(), 1, exponent, 2, 2, result.get(), pool); - ASSERT_EQ(28153056843ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - - input[0] = 1; - exponent[0] = 1; - exponent[1] = 1; - exponentiate_uint(input.get(), 1, exponent, 2, 2, result.get(), pool); - ASSERT_EQ(1ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - - input[0] = 0; - input[1] = 1; - exponent[0] = 7; - exponent[1] = 0; - exponentiate_uint(input.get(), 2, exponent, 2, 8, result.get(), pool); - ASSERT_EQ(0ULL, result[0]); - ASSERT_EQ(0ULL, result[1]); - ASSERT_EQ(0ULL, result[2]); - ASSERT_EQ(0ULL, result[3]); - ASSERT_EQ(0ULL, result[4]); - ASSERT_EQ(0ULL, result[5]); - ASSERT_EQ(0ULL, result[6]); - ASSERT_EQ(1ULL, result[7]); - - input[0] = 121212; - input[1] = 343434; - exponent[0] = 3; - exponent[1] = 0; - exponentiate_uint(input.get(), 2, exponent, 2, 8, result.get(), pool); - ASSERT_EQ(1780889000200128ULL, result[0]); - ASSERT_EQ(15137556501701088ULL, result[1]); - ASSERT_EQ(42889743421486416ULL, result[2]); - ASSERT_EQ(40506979898070504ULL, result[3]); - ASSERT_EQ(0ULL, result[4]); - ASSERT_EQ(0ULL, result[5]); - ASSERT_EQ(0ULL, result[6]); - ASSERT_EQ(0ULL, result[7]); - } - - TEST(UIntArith, ExponentiateUInt64) - { - ASSERT_EQ(0ULL, exponentiate_uint64(0ULL, 1ULL)); - ASSERT_EQ(1ULL, exponentiate_uint64(1ULL, 0ULL)); - ASSERT_EQ(0ULL, exponentiate_uint64(0ULL, 0xFFFFFFFFFFFFFFFFULL)); - ASSERT_EQ(1ULL, exponentiate_uint64(0xFFFFFFFFFFFFFFFFULL, 0ULL)); - ASSERT_EQ(25ULL, exponentiate_uint64(5ULL, 2ULL)); - ASSERT_EQ(32ULL, exponentiate_uint64(2ULL, 5ULL)); - ASSERT_EQ(0x1000000000000000ULL, exponentiate_uint64(0x10ULL, 15ULL)); - ASSERT_EQ(0ULL, exponentiate_uint64(0x10ULL, 16ULL)); - ASSERT_EQ(12389286314587456613ULL, exponentiate_uint64(123456789ULL, 13ULL)); - } - } -} diff --git a/SEAL/native/tests/seal/util/uintarithmod.cpp b/SEAL/native/tests/seal/util/uintarithmod.cpp deleted file mode 100644 index 5dc6d2c..0000000 --- a/SEAL/native/tests/seal/util/uintarithmod.cpp +++ /dev/null @@ -1,353 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/util/uintcore.h" -#include "seal/util/uintarithmod.h" -#include -#include - -using namespace seal::util; -using namespace std; - -namespace SEALTest -{ - namespace util - { - TEST(UIntArithMod, IncrementUIntMod) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto value(allocate_uint(2, pool)); - auto modulus(allocate_uint(2, pool)); - value[0] = 0; - value[1] = 0; - modulus[0] = 3; - modulus[1] = 0; - increment_uint_mod(value.get(), modulus.get(), 2, value.get()); - ASSERT_EQ(1ULL, value[0]); - ASSERT_EQ(static_cast(0), value[1]); - increment_uint_mod(value.get(), modulus.get(), 2, value.get()); - ASSERT_EQ(static_cast(2), value[0]); - ASSERT_EQ(static_cast(0), value[1]); - increment_uint_mod(value.get(), modulus.get(), 2, value.get()); - ASSERT_EQ(static_cast(0), value[0]); - ASSERT_EQ(static_cast(0), value[1]); - - value[0] = 0xFFFFFFFFFFFFFFFD; - value[1] = 0xFFFFFFFFFFFFFFFF; - modulus[0] = 0xFFFFFFFFFFFFFFFF; - modulus[1] = 0xFFFFFFFFFFFFFFFF; - increment_uint_mod(value.get(), modulus.get(), 2, value.get()); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), value[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), value[1]); - increment_uint_mod(value.get(), modulus.get(), 2, value.get()); - ASSERT_EQ(static_cast(0), value[0]); - ASSERT_EQ(static_cast(0), value[1]); - increment_uint_mod(value.get(), modulus.get(), 2, value.get()); - ASSERT_EQ(1ULL, value[0]); - ASSERT_EQ(static_cast(0), value[1]); - } - - TEST(UIntArithMod, DecrementUIntMod) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto value(allocate_uint(2, pool)); - auto modulus(allocate_uint(2, pool)); - value[0] = 2; - value[1] = 0; - modulus[0] = 3; - modulus[1] = 0; - decrement_uint_mod(value.get(), modulus.get(), 2, value.get()); - ASSERT_EQ(1ULL, value[0]); - ASSERT_EQ(static_cast(0), value[1]); - decrement_uint_mod(value.get(), modulus.get(), 2, value.get()); - ASSERT_EQ(static_cast(0), value[0]); - ASSERT_EQ(static_cast(0), value[1]); - decrement_uint_mod(value.get(), modulus.get(), 2, value.get()); - ASSERT_EQ(static_cast(2), value[0]); - ASSERT_EQ(static_cast(0), value[1]); - - value[0] = 1; - value[1] = 0; - modulus[0] = 0xFFFFFFFFFFFFFFFF; - modulus[1] = 0xFFFFFFFFFFFFFFFF; - decrement_uint_mod(value.get(), modulus.get(), 2, value.get()); - ASSERT_EQ(static_cast(0), value[0]); - ASSERT_EQ(static_cast(0), value[1]); - decrement_uint_mod(value.get(), modulus.get(), 2, value.get()); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFE), value[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), value[1]); - decrement_uint_mod(value.get(), modulus.get(), 2, value.get()); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFD), value[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), value[1]); - } - - TEST(UIntArithMod, NegateUIntMod) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto value(allocate_uint(2, pool)); - auto modulus(allocate_uint(2, pool)); - value[0] = 0; - value[1] = 0; - modulus[0] = 3; - modulus[1] = 0; - negate_uint_mod(value.get(), modulus.get(), 2, value.get()); - ASSERT_EQ(static_cast(0), value[0]); - ASSERT_EQ(static_cast(0), value[1]); - - value[0] = 1; - value[1] = 0; - modulus[0] = 3; - modulus[1] = 0; - negate_uint_mod(value.get(), modulus.get(), 2, value.get()); - ASSERT_EQ(static_cast(2), value[0]); - ASSERT_EQ(static_cast(0), value[1]); - negate_uint_mod(value.get(), modulus.get(), 2, value.get()); - ASSERT_EQ(1ULL, value[0]); - ASSERT_EQ(static_cast(0), value[1]); - - value[0] = 2; - value[1] = 0; - modulus[0] = 0xFFFFFFFFFFFFFFFF; - modulus[1] = 0xFFFFFFFFFFFFFFFF; - negate_uint_mod(value.get(), modulus.get(), 2, value.get()); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFD), value[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), value[1]); - negate_uint_mod(value.get(), modulus.get(), 2, value.get()); - ASSERT_EQ(static_cast(2), value[0]); - ASSERT_EQ(static_cast(0), value[1]); - } - - TEST(UIntArithMod, Div2UIntMod) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto value(allocate_uint(2, pool)); - auto modulus(allocate_uint(2, pool)); - value[0] = 0; - value[1] = 0; - modulus[0] = 3; - modulus[1] = 0; - div2_uint_mod(value.get(), modulus.get(), 2, value.get()); - ASSERT_EQ(0ULL, value[0]); - ASSERT_EQ(0ULL, value[1]); - - value[0] = 1; - value[1] = 0; - modulus[0] = 3; - modulus[1] = 0; - div2_uint_mod(value.get(), modulus.get(), 2, value.get()); - ASSERT_EQ(2ULL, value[0]); - ASSERT_EQ(0ULL, value[1]); - - value[0] = 8; - value[1] = 0; - modulus[0] = 17; - modulus[1] = 0; - div2_uint_mod(value.get(), modulus.get(), 2, value.get()); - ASSERT_EQ(4ULL, value[0]); - ASSERT_EQ(0ULL, value[1]); - - value[0] = 5; - value[1] = 0; - modulus[0] = 17; - modulus[1] = 0; - div2_uint_mod(value.get(), modulus.get(), 2, value.get()); - ASSERT_EQ(11ULL, value[0]); - ASSERT_EQ(0ULL, value[1]); - - value[0] = 1; - value[1] = 0; - modulus[0] = 0xFFFFFFFFFFFFFFFFULL; - modulus[1] = 0xFFFFFFFFFFFFFFFFULL; - div2_uint_mod(value.get(), modulus.get(), 2, value.get()); - ASSERT_EQ(0ULL, value[0]); - ASSERT_EQ(0x8000000000000000ULL, value[1]); - - value[0] = 3; - value[1] = 0; - modulus[0] = 0xFFFFFFFFFFFFFFFFULL; - modulus[1] = 0xFFFFFFFFFFFFFFFFULL; - div2_uint_mod(value.get(), modulus.get(), 2, value.get()); - ASSERT_EQ(1ULL, value[0]); - ASSERT_EQ(0x8000000000000000ULL, value[1]); - } - - TEST(UIntArithMod, AddUIntUIntMod) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto value1(allocate_uint(2, pool)); - auto value2(allocate_uint(2, pool)); - auto modulus(allocate_uint(2, pool)); - value1[0] = 0; - value1[1] = 0; - value2[0] = 0; - value2[1] = 0; - modulus[0] = 3; - modulus[1] = 0; - add_uint_uint_mod(value1.get(), value2.get(), modulus.get(), 2, value1.get()); - ASSERT_EQ(static_cast(0), value1[0]); - ASSERT_EQ(static_cast(0), value1[1]); - - value1[0] = 1; - value1[1] = 0; - value2[0] = 1; - value2[1] = 0; - modulus[0] = 3; - modulus[1] = 0; - add_uint_uint_mod(value1.get(), value2.get(), modulus.get(), 2, value1.get()); - ASSERT_EQ(static_cast(2), value1[0]); - ASSERT_EQ(static_cast(0), value1[1]); - - value1[0] = 1; - value1[1] = 0; - value2[0] = 2; - value2[1] = 0; - modulus[0] = 3; - modulus[1] = 0; - add_uint_uint_mod(value1.get(), value2.get(), modulus.get(), 2, value1.get()); - ASSERT_EQ(static_cast(0), value1[0]); - ASSERT_EQ(static_cast(0), value1[1]); - - value1[0] = 2; - value1[1] = 0; - value2[0] = 2; - value2[1] = 0; - modulus[0] = 3; - modulus[1] = 0; - add_uint_uint_mod(value1.get(), value2.get(), modulus.get(), 2, value1.get()); - ASSERT_EQ(1ULL, value1[0]); - ASSERT_EQ(static_cast(0), value1[1]); - - value1[0] = 0xFFFFFFFFFFFFFFFE; - value1[1] = 0xFFFFFFFFFFFFFFFF; - value2[0] = 0xFFFFFFFFFFFFFFFE; - value2[1] = 0xFFFFFFFFFFFFFFFF; - modulus[0] = 0xFFFFFFFFFFFFFFFF; - modulus[1] = 0xFFFFFFFFFFFFFFFF; - add_uint_uint_mod(value1.get(), value2.get(), modulus.get(), 2, value1.get()); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFD), value1[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), value1[1]); - } - - TEST(UIntArithMod, SubUIntUIntMod) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto value1(allocate_uint(2, pool)); - auto value2(allocate_uint(2, pool)); - auto modulus(allocate_uint(2, pool)); - value1[0] = 0; - value1[1] = 0; - value2[0] = 0; - value2[1] = 0; - modulus[0] = 3; - modulus[1] = 0; - sub_uint_uint_mod(value1.get(), value2.get(), modulus.get(), 2, value1.get()); - ASSERT_EQ(static_cast(0), value1[0]); - ASSERT_EQ(static_cast(0), value1[1]); - - value1[0] = 2; - value1[1] = 0; - value2[0] = 1; - value2[1] = 0; - modulus[0] = 3; - modulus[1] = 0; - sub_uint_uint_mod(value1.get(), value2.get(), modulus.get(), 2, value1.get()); - ASSERT_EQ(1ULL, value1[0]); - ASSERT_EQ(static_cast(0), value1[1]); - - value1[0] = 1; - value1[1] = 0; - value2[0] = 2; - value2[1] = 0; - modulus[0] = 3; - modulus[1] = 0; - sub_uint_uint_mod(value1.get(), value2.get(), modulus.get(), 2, value1.get()); - ASSERT_EQ(static_cast(2), value1[0]); - ASSERT_EQ(static_cast(0), value1[1]); - - value1[0] = 2; - value1[1] = 0; - value2[0] = 2; - value2[1] = 0; - modulus[0] = 3; - modulus[1] = 0; - sub_uint_uint_mod(value1.get(), value2.get(), modulus.get(), 2, value1.get()); - ASSERT_EQ(static_cast(0), value1[0]); - ASSERT_EQ(static_cast(0), value1[1]); - - value1[0] = 1; - value1[1] = 0; - value2[0] = 0xFFFFFFFFFFFFFFFE; - value2[1] = 0xFFFFFFFFFFFFFFFF; - modulus[0] = 0xFFFFFFFFFFFFFFFF; - modulus[1] = 0xFFFFFFFFFFFFFFFF; - sub_uint_uint_mod(value1.get(), value2.get(), modulus.get(), 2, value1.get()); - ASSERT_EQ(static_cast(2), value1[0]); - ASSERT_EQ(static_cast(0), value1[1]); - } - - TEST(UIntArithMod, TryInvertUIntMod) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto value(allocate_uint(2, pool)); - auto modulus(allocate_uint(2, pool)); - value[0] = 0; - value[1] = 0; - modulus[0] = 5; - modulus[1] = 0; - ASSERT_FALSE(try_invert_uint_mod(value.get(), modulus.get(), 2, value.get(), pool)); - - value[0] = 1; - value[1] = 0; - modulus[0] = 5; - modulus[1] = 0; - ASSERT_TRUE(try_invert_uint_mod(value.get(), modulus.get(), 2, value.get(), pool)); - ASSERT_EQ(1ULL, value[0]); - ASSERT_EQ(static_cast(0), value[1]); - - value[0] = 2; - value[1] = 0; - modulus[0] = 5; - modulus[1] = 0; - ASSERT_TRUE(try_invert_uint_mod(value.get(), modulus.get(), 2, value.get(), pool)); - ASSERT_EQ(static_cast(3), value[0]); - ASSERT_EQ(static_cast(0), value[1]); - - value[0] = 3; - value[1] = 0; - modulus[0] = 5; - modulus[1] = 0; - ASSERT_TRUE(try_invert_uint_mod(value.get(), modulus.get(), 2, value.get(), pool)); - ASSERT_EQ(static_cast(2), value[0]); - ASSERT_EQ(static_cast(0), value[1]); - - value[0] = 4; - value[1] = 0; - modulus[0] = 5; - modulus[1] = 0; - ASSERT_TRUE(try_invert_uint_mod(value.get(), modulus.get(), 2, value.get(), pool)); - ASSERT_EQ(static_cast(4), value[0]); - ASSERT_EQ(static_cast(0), value[1]); - - value[0] = 2; - value[1] = 0; - modulus[0] = 6; - modulus[1] = 0; - ASSERT_FALSE(try_invert_uint_mod(value.get(), modulus.get(), 2, value.get(), pool)); - - value[0] = 3; - value[1] = 0; - modulus[0] = 6; - modulus[1] = 0; - ASSERT_FALSE(try_invert_uint_mod(value.get(), modulus.get(), 2, value.get(), pool)); - - value[0] = 331975426; - value[1] = 0; - modulus[0] = 1351315121; - modulus[1] = 0; - ASSERT_TRUE(try_invert_uint_mod(value.get(), modulus.get(), 2, value.get(), pool)); - ASSERT_EQ(static_cast(1052541512), value[0]); - ASSERT_EQ(static_cast(0), value[1]); - } - } -} diff --git a/SEAL/native/tests/seal/util/uintarithsmallmod.cpp b/SEAL/native/tests/seal/util/uintarithsmallmod.cpp deleted file mode 100644 index 13ac0a2..0000000 --- a/SEAL/native/tests/seal/util/uintarithsmallmod.cpp +++ /dev/null @@ -1,390 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/util/uintcore.h" -#include "seal/util/uintarithsmallmod.h" -#include "seal/smallmodulus.h" -#include "seal/memorymanager.h" - -using namespace seal::util; -using namespace seal; -using namespace std; - -namespace SEALTest -{ - namespace util - { - TEST(UIntArithSmallMod, IncrementUIntSmallMod) - { - SmallModulus mod(2); - ASSERT_EQ(1ULL, increment_uint_mod(0, mod)); - ASSERT_EQ(0ULL, increment_uint_mod(1ULL, mod)); - - mod = 0x10000; - ASSERT_EQ(1ULL, increment_uint_mod(0, mod)); - ASSERT_EQ(2ULL, increment_uint_mod(1ULL, mod)); - ASSERT_EQ(0ULL, increment_uint_mod(0xFFFFULL, mod)); - - mod = 4611686018427289601ULL; - ASSERT_EQ(1ULL, increment_uint_mod(0, mod)); - ASSERT_EQ(0ULL, increment_uint_mod(4611686018427289600ULL, mod)); - ASSERT_EQ(1ULL, increment_uint_mod(0, mod)); - } - - TEST(UIntArithSmallMod, DecrementUIntSmallMod) - { - SmallModulus mod(2); - ASSERT_EQ(0ULL, decrement_uint_mod(1, mod)); - ASSERT_EQ(1ULL, decrement_uint_mod(0ULL, mod)); - - mod = 0x10000; - ASSERT_EQ(0ULL, decrement_uint_mod(1, mod)); - ASSERT_EQ(1ULL, decrement_uint_mod(2ULL, mod)); - ASSERT_EQ(0xFFFFULL, decrement_uint_mod(0ULL, mod)); - - mod = 4611686018427289601ULL; - ASSERT_EQ(0ULL, decrement_uint_mod(1, mod)); - ASSERT_EQ(4611686018427289600ULL, decrement_uint_mod(0ULL, mod)); - ASSERT_EQ(0ULL, decrement_uint_mod(1, mod)); - } - - TEST(UIntArithSmallMod, NegateUIntSmallMod) - { - SmallModulus mod(2); - ASSERT_EQ(0ULL, negate_uint_mod(0, mod)); - ASSERT_EQ(1ULL, negate_uint_mod(1, mod)); - - mod = 0xFFFFULL; - ASSERT_EQ(0ULL, negate_uint_mod(0, mod)); - ASSERT_EQ(0xFFFEULL, negate_uint_mod(1, mod)); - ASSERT_EQ(0x1ULL, negate_uint_mod(0xFFFEULL, mod)); - - mod = 0x10000ULL; - ASSERT_EQ(0ULL, negate_uint_mod(0, mod)); - ASSERT_EQ(0xFFFFULL, negate_uint_mod(1, mod)); - ASSERT_EQ(0x1ULL, negate_uint_mod(0xFFFFULL, mod)); - - mod = 4611686018427289601ULL; - ASSERT_EQ(0ULL, negate_uint_mod(0, mod)); - ASSERT_EQ(4611686018427289600ULL, negate_uint_mod(1, mod)); - } - - TEST(UIntArithSmallMod, Div2UIntSmallMod) - { - SmallModulus mod(3); - ASSERT_EQ(0ULL, div2_uint_mod(0ULL, mod)); - ASSERT_EQ(2ULL, div2_uint_mod(1ULL, mod)); - - mod = 17; - ASSERT_EQ(11ULL, div2_uint_mod(5ULL, mod)); - ASSERT_EQ(4ULL, div2_uint_mod(8ULL, mod)); - - mod = 0xFFFFFFFFFFFFFFFULL; - ASSERT_EQ(0x800000000000000ULL, div2_uint_mod(1ULL, mod)); - ASSERT_EQ(0x800000000000001ULL, div2_uint_mod(3ULL, mod)); - } - - TEST(UIntArithSmallMod, AddUIntSmallMod) - { - SmallModulus mod(2); - ASSERT_EQ(0ULL, add_uint_uint_mod(0, 0, mod)); - ASSERT_EQ(1ULL, add_uint_uint_mod(0, 1, mod)); - ASSERT_EQ(1ULL, add_uint_uint_mod(1, 0, mod)); - ASSERT_EQ(0ULL, add_uint_uint_mod(1, 1, mod)); - - mod = 10; - ASSERT_EQ(0ULL, add_uint_uint_mod(0, 0, mod)); - ASSERT_EQ(1ULL, add_uint_uint_mod(0, 1, mod)); - ASSERT_EQ(1ULL, add_uint_uint_mod(1, 0, mod)); - ASSERT_EQ(2ULL, add_uint_uint_mod(1, 1, mod)); - ASSERT_EQ(4ULL, add_uint_uint_mod(7, 7, mod)); - ASSERT_EQ(3ULL, add_uint_uint_mod(6, 7, mod)); - - mod = 4611686018427289601; - ASSERT_EQ(0ULL, add_uint_uint_mod(0, 0, mod)); - ASSERT_EQ(1ULL, add_uint_uint_mod(0, 1, mod)); - ASSERT_EQ(1ULL, add_uint_uint_mod(1, 0, mod)); - ASSERT_EQ(2ULL, add_uint_uint_mod(1, 1, mod)); - ASSERT_EQ(0ULL, add_uint_uint_mod(2305843009213644800ULL, 2305843009213644801ULL, mod)); - ASSERT_EQ(1ULL, add_uint_uint_mod(2305843009213644801ULL, 2305843009213644801ULL, mod)); - ASSERT_EQ(4611686018427289599ULL, add_uint_uint_mod(4611686018427289600ULL, 4611686018427289600ULL, mod)); - } - - TEST(UIntArithSmallMod, SubUIntSmallMod) - { - SmallModulus mod(2); - ASSERT_EQ(0ULL, sub_uint_uint_mod(0, 0, mod)); - ASSERT_EQ(1ULL, sub_uint_uint_mod(0, 1, mod)); - ASSERT_EQ(1ULL, sub_uint_uint_mod(1, 0, mod)); - ASSERT_EQ(0ULL, sub_uint_uint_mod(1, 1, mod)); - - mod = 10; - ASSERT_EQ(0ULL, sub_uint_uint_mod(0, 0, mod)); - ASSERT_EQ(9ULL, sub_uint_uint_mod(0, 1, mod)); - ASSERT_EQ(1ULL, sub_uint_uint_mod(1, 0, mod)); - ASSERT_EQ(0ULL, sub_uint_uint_mod(1, 1, mod)); - ASSERT_EQ(0ULL, sub_uint_uint_mod(7, 7, mod)); - ASSERT_EQ(9ULL, sub_uint_uint_mod(6, 7, mod)); - ASSERT_EQ(1ULL, sub_uint_uint_mod(7, 6, mod)); - - mod = 4611686018427289601ULL; - ASSERT_EQ(0ULL, sub_uint_uint_mod(0, 0, mod)); - ASSERT_EQ(4611686018427289600ULL, sub_uint_uint_mod(0, 1, mod)); - ASSERT_EQ(1ULL, sub_uint_uint_mod(1, 0, mod)); - ASSERT_EQ(0ULL, sub_uint_uint_mod(1, 1, mod)); - ASSERT_EQ(4611686018427289600ULL, sub_uint_uint_mod(2305843009213644800ULL, 2305843009213644801ULL, mod)); - ASSERT_EQ(1ULL, sub_uint_uint_mod(2305843009213644801ULL, 2305843009213644800ULL, mod)); - ASSERT_EQ(0ULL, sub_uint_uint_mod(2305843009213644801ULL, 2305843009213644801ULL, mod)); - ASSERT_EQ(0ULL, sub_uint_uint_mod(4611686018427289600ULL, 4611686018427289600ULL, mod)); - } - - TEST(UIntArithSmallMod, BarrettReduce128) - { - uint64_t input[2]; - - SmallModulus mod(2); - input[0] = 0; - input[1] = 0; - ASSERT_EQ(0ULL, barrett_reduce_128(input, mod)); - input[0] = 1; - input[1] = 0; - ASSERT_EQ(1ULL, barrett_reduce_128(input, mod)); - input[0] = 0xFFFFFFFFFFFFFFFFULL; - input[1] = 0xFFFFFFFFFFFFFFFFULL; - ASSERT_EQ(1ULL, barrett_reduce_128(input, mod)); - - mod = 3; - input[0] = 0; - input[1] = 0; - ASSERT_EQ(0ULL, barrett_reduce_128(input, mod)); - input[0] = 1; - input[1] = 0; - ASSERT_EQ(1ULL, barrett_reduce_128(input, mod)); - input[0] = 123; - input[1] = 456; - ASSERT_EQ(0ULL, barrett_reduce_128(input, mod)); - input[0] = 0xFFFFFFFFFFFFFFFFULL; - input[1] = 0xFFFFFFFFFFFFFFFFULL; - ASSERT_EQ(0ULL, barrett_reduce_128(input, mod)); - - mod = 13131313131313ULL; - input[0] = 0; - input[1] = 0; - ASSERT_EQ(0ULL, barrett_reduce_128(input, mod)); - input[0] = 1; - input[1] = 0; - ASSERT_EQ(1ULL, barrett_reduce_128(input, mod)); - input[0] = 123; - input[1] = 456; - ASSERT_EQ(8722750765283ULL, barrett_reduce_128(input, mod)); - input[0] = 24242424242424; - input[1] = 79797979797979; - ASSERT_EQ(1010101010101ULL, barrett_reduce_128(input, mod)); - } - - TEST(UIntArithSmallMod, MultiplyUIntUIntSmallMod) - { - SmallModulus mod(2); - ASSERT_EQ(0ULL, multiply_uint_uint_mod(0, 0, mod)); - ASSERT_EQ(0ULL, multiply_uint_uint_mod(0, 1, mod)); - ASSERT_EQ(0ULL, multiply_uint_uint_mod(1, 0, mod)); - ASSERT_EQ(1ULL, multiply_uint_uint_mod(1, 1, mod)); - - mod = 10; - ASSERT_EQ(0ULL, multiply_uint_uint_mod(0, 0, mod)); - ASSERT_EQ(0ULL, multiply_uint_uint_mod(0, 1, mod)); - ASSERT_EQ(0ULL, multiply_uint_uint_mod(1, 0, mod)); - ASSERT_EQ(1ULL, multiply_uint_uint_mod(1, 1, mod)); - ASSERT_EQ(9ULL, multiply_uint_uint_mod(7, 7, mod)); - ASSERT_EQ(2ULL, multiply_uint_uint_mod(6, 7, mod)); - ASSERT_EQ(2ULL, multiply_uint_uint_mod(7, 6, mod)); - - mod = 4611686018427289601ULL; - ASSERT_EQ(0ULL, multiply_uint_uint_mod(0, 0, mod)); - ASSERT_EQ(0ULL, multiply_uint_uint_mod(0, 1, mod)); - ASSERT_EQ(0ULL, multiply_uint_uint_mod(1, 0, mod)); - ASSERT_EQ(1ULL, multiply_uint_uint_mod(1, 1, mod)); - ASSERT_EQ(1152921504606822400ULL, multiply_uint_uint_mod(2305843009213644800ULL, 2305843009213644801ULL, mod)); - ASSERT_EQ(1152921504606822400ULL, multiply_uint_uint_mod(2305843009213644801ULL, 2305843009213644800ULL, mod)); - ASSERT_EQ(3458764513820467201ULL, multiply_uint_uint_mod(2305843009213644801ULL, 2305843009213644801ULL, mod)); - ASSERT_EQ(1ULL, multiply_uint_uint_mod(4611686018427289600ULL, 4611686018427289600ULL, mod)); - } - - TEST(UIntArithSmallMod, ModuloUIntSmallMod) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto value(allocate_uint(4, pool)); - - SmallModulus mod(2); - value[0] = 0; - value[1] = 0; - value[2] = 0; - modulo_uint_inplace(value.get(), 3, mod); - ASSERT_EQ(0ULL, value[0]); - ASSERT_EQ(0ULL, value[1]); - ASSERT_EQ(0ULL, value[2]); - - value[0] = 1; - value[1] = 0; - value[2] = 0; - modulo_uint_inplace(value.get(), 3, mod); - ASSERT_EQ(1ULL, value[0]); - ASSERT_EQ(0ULL, value[1]); - ASSERT_EQ(0ULL, value[2]); - - value[0] = 2; - value[1] = 0; - value[2] = 0; - modulo_uint_inplace(value.get(), 3, mod); - ASSERT_EQ(0ULL, value[0]); - ASSERT_EQ(0ULL, value[1]); - ASSERT_EQ(0ULL, value[2]); - - value[0] = 3; - value[1] = 0; - value[2] = 0; - modulo_uint_inplace(value.get(), 3, mod); - ASSERT_EQ(1ULL, value[0]); - ASSERT_EQ(0ULL, value[1]); - ASSERT_EQ(0ULL, value[2]); - - mod = 0xFFFF; - value[0] = 9585656442714717620ul; - value[1] = 1817697005049051848; - value[2] = 0; - modulo_uint_inplace(value.get(), 3, mod); - ASSERT_EQ(65143ULL, value[0]); - ASSERT_EQ(0ULL, value[1]); - ASSERT_EQ(0ULL, value[2]); - - mod = 0x1000; - value[0] = 9585656442714717620ul; - value[1] = 1817697005049051848; - value[2] = 0; - modulo_uint_inplace(value.get(), 3, mod); - ASSERT_EQ(0xDB4ULL, value[0]); - ASSERT_EQ(0ULL, value[1]); - ASSERT_EQ(0ULL, value[2]); - - mod = 0xFFFFFFFFC001ULL; - value[0] = 9585656442714717620ul; - value[1] = 1817697005049051848; - value[2] = 14447416709120365380ul; - value[3] = 67450014862939159; - modulo_uint_inplace(value.get(), 4, mod); - ASSERT_EQ(124510066632001ULL, value[0]); - ASSERT_EQ(0ULL, value[1]); - ASSERT_EQ(0ULL, value[2]); - ASSERT_EQ(0ULL, value[3]); - } - - TEST(UIntArithSmallMod, TryInvertUIntSmallMod) - { - uint64_t result; - SmallModulus mod(5); - ASSERT_FALSE(try_invert_uint_mod(0, mod, result)); - ASSERT_TRUE(try_invert_uint_mod(1, mod, result)); - ASSERT_EQ(1ULL, result); - ASSERT_TRUE(try_invert_uint_mod(2, mod, result)); - ASSERT_EQ(3ULL, result); - ASSERT_TRUE(try_invert_uint_mod(3, mod, result)); - ASSERT_EQ(2ULL, result); - ASSERT_TRUE(try_invert_uint_mod(4, mod, result)); - ASSERT_EQ(4ULL, result); - - mod = 6; - ASSERT_FALSE(try_invert_uint_mod(2, mod, result)); - ASSERT_FALSE(try_invert_uint_mod(3, mod, result)); - ASSERT_TRUE(try_invert_uint_mod(5, mod, result)); - ASSERT_EQ(5ULL, result); - - mod = 1351315121; - ASSERT_TRUE(try_invert_uint_mod(331975426, mod, result)); - ASSERT_EQ(1052541512ULL, result); - } - - TEST(UIntArithSmallMod, TryPrimitiveRootSmallMod) - { - uint64_t result; - SmallModulus mod(11); - - ASSERT_TRUE(try_primitive_root(2, mod, result)); - ASSERT_EQ(10ULL, result); - - mod = 29; - ASSERT_TRUE(try_primitive_root(2, mod, result)); - ASSERT_EQ(28ULL, result); - - vector corrects{ 12, 17 }; - ASSERT_TRUE(try_primitive_root(4, mod, result)); - ASSERT_TRUE(std::find(corrects.begin(), corrects.end(), result) != corrects.end()); - - mod = 1234565441; - ASSERT_TRUE(try_primitive_root(2, mod, result)); - ASSERT_EQ(1234565440ULL, result); - corrects = { 984839708, 273658408, 249725733, 960907033 }; - ASSERT_TRUE(try_primitive_root(8, mod, result)); - ASSERT_TRUE(std::find(corrects.begin(), corrects.end(), result) != corrects.end()); - } - - TEST(UIntArithSmallMod, IsPrimitiveRootSmallMod) - { - SmallModulus mod(11); - ASSERT_TRUE(is_primitive_root(10, 2, mod)); - ASSERT_FALSE(is_primitive_root(9, 2, mod)); - ASSERT_FALSE(is_primitive_root(10, 4, mod)); - - mod = 29; - ASSERT_TRUE(is_primitive_root(28, 2, mod)); - ASSERT_TRUE(is_primitive_root(12, 4, mod)); - ASSERT_FALSE(is_primitive_root(12, 2, mod)); - ASSERT_FALSE(is_primitive_root(12, 8, mod)); - - - mod = 1234565441ULL; - ASSERT_TRUE(is_primitive_root(1234565440ULL, 2, mod)); - ASSERT_TRUE(is_primitive_root(960907033ULL, 8, mod)); - ASSERT_TRUE(is_primitive_root(1180581915ULL, 16, mod)); - ASSERT_FALSE(is_primitive_root(1180581915ULL, 32, mod)); - ASSERT_FALSE(is_primitive_root(1180581915ULL, 8, mod)); - ASSERT_FALSE(is_primitive_root(1180581915ULL, 2, mod)); - } - - TEST(UIntArithSmallMod, TryMinimalPrimitiveRootSmallMod) - { - SmallModulus mod(11); - - uint64_t result; - ASSERT_TRUE(try_minimal_primitive_root(2, mod, result)); - ASSERT_EQ(10ULL, result); - - mod = 29; - ASSERT_TRUE(try_minimal_primitive_root(2, mod, result)); - ASSERT_EQ(28ULL, result); - ASSERT_TRUE(try_minimal_primitive_root(4, mod, result)); - ASSERT_EQ(12ULL, result); - - mod = 1234565441; - ASSERT_TRUE(try_minimal_primitive_root(2, mod, result)); - ASSERT_EQ(1234565440ULL, result); - ASSERT_TRUE(try_minimal_primitive_root(8, mod, result)); - ASSERT_EQ(249725733ULL, result); - } - - TEST(UIntArithSmallMod, ExponentiateUIntSmallMod) - { - SmallModulus mod(5); - ASSERT_EQ(1ULL, exponentiate_uint_mod(1, 0, mod)); - ASSERT_EQ(1ULL, exponentiate_uint_mod(1, 0xFFFFFFFFFFFFFFFFULL, mod)); - ASSERT_EQ(3ULL, exponentiate_uint_mod(2, 0xFFFFFFFFFFFFFFFFULL, mod)); - - mod = 0x1000000000000000ULL; - ASSERT_EQ(0ULL, exponentiate_uint_mod(2, 60, mod)); - ASSERT_EQ(0x800000000000000ULL, exponentiate_uint_mod(2, 59, mod)); - - mod = 131313131313; - ASSERT_EQ(39418477653ULL, exponentiate_uint_mod(2424242424, 16, mod)); - } - } -} diff --git a/SEAL/native/tests/seal/util/uintcore.cpp b/SEAL/native/tests/seal/util/uintcore.cpp deleted file mode 100644 index 6d8874e..0000000 --- a/SEAL/native/tests/seal/util/uintcore.cpp +++ /dev/null @@ -1,795 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT license. - -#include "gtest/gtest.h" -#include "seal/util/uintcore.h" -#include - -using namespace seal::util; -using namespace std; - -namespace SEALTest -{ - namespace util - { - TEST(UIntCore, AllocateUInt) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(0, pool)); - ASSERT_TRUE(nullptr == ptr.get()); - - ptr = allocate_uint(1, pool); - ASSERT_TRUE(nullptr != ptr.get()); - - ptr = allocate_uint(2, pool); - ASSERT_TRUE(nullptr != ptr.get()); - } - - TEST(UIntCore, SetZeroUInt) - { - set_zero_uint(0, nullptr); - - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(1, pool)); - ptr[0] = 0x1234567812345678; - set_zero_uint(1, ptr.get()); - ASSERT_EQ(static_cast(0), ptr[0]); - - ptr = allocate_uint(2, pool); - ptr[0] = 0x1234567812345678; - ptr[1] = 0x1234567812345678; - set_zero_uint(2, ptr.get()); - ASSERT_EQ(static_cast(0), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - } - - TEST(UIntCore, AllocateZeroUInt) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_zero_uint(0, pool)); - ASSERT_TRUE(nullptr == ptr.get()); - - ptr = allocate_zero_uint(1, pool); - ASSERT_TRUE(nullptr != ptr.get()); - ASSERT_EQ(static_cast(0), ptr[0]); - - ptr = allocate_zero_uint(2, pool); - ASSERT_TRUE(nullptr != ptr.get()); - ASSERT_EQ(static_cast(0), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - } - - TEST(UIntCore, SetUInt) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(1, pool)); - ptr[0] = 0xFFFFFFFFFFFFFFFF; - set_uint(1, 1, ptr.get()); - ASSERT_EQ(1ULL, ptr[0]); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - set_uint(0x1234567812345678, 1, ptr.get()); - ASSERT_EQ(static_cast(0x1234567812345678), ptr[0]); - - ptr = allocate_uint(2, pool); - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - set_uint(1, 2, ptr.get()); - ASSERT_EQ(1ULL, ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - set_uint(0x1234567812345678, 2, ptr.get()); - ASSERT_EQ(static_cast(0x1234567812345678), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - } - - TEST(UIntCore, SetUIntUInt) - { - set_uint_uint(nullptr, 0, nullptr); - - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr1(allocate_uint(1, pool)); - ptr1[0] = 0x1234567887654321; - auto ptr2(allocate_uint(1, pool)); - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - set_uint_uint(ptr1.get(), 1, ptr2.get()); - ASSERT_EQ(static_cast(0x1234567887654321), ptr2[0]); - - ptr1[0] = 0x1231231231231231; - set_uint_uint(ptr1.get(), 1, ptr1.get()); - ASSERT_EQ(static_cast(0x1231231231231231), ptr1[0]); - - ptr1 = allocate_uint(2, pool); - ptr2 = allocate_uint(2, pool); - ptr1[0] = 0x1234567887654321; - ptr1[1] = 0x8765432112345678; - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - set_uint_uint(ptr1.get(), 2, ptr2.get()); - ASSERT_EQ(static_cast(0x1234567887654321), ptr2[0]); - ASSERT_EQ(static_cast(0x8765432112345678), ptr2[1]); - - ptr1[0] = 0x1231231231231321; - ptr1[1] = 0x3213213213213211; - set_uint_uint(ptr1.get(), 2, ptr1.get()); - ASSERT_EQ(static_cast(0x1231231231231321), ptr1[0]); - ASSERT_EQ(static_cast(0x3213213213213211), ptr1[1]); - } - - TEST(UIntCore, SetUIntUInt2) - { - set_uint_uint(nullptr, 0, 0, nullptr); - - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr1(allocate_uint(1, pool)); - ptr1[0] = 0x1234567887654321; - set_uint_uint(nullptr, 0, 1, ptr1.get()); - ASSERT_EQ(static_cast(0), ptr1[0]); - - auto ptr2(allocate_uint(1, pool)); - ptr1[0] = 0x1234567887654321; - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - set_uint_uint(ptr1.get(), 1, 1, ptr2.get()); - ASSERT_EQ(static_cast(0x1234567887654321), ptr2[0]); - - ptr1[0] = 0x1231231231231231; - set_uint_uint(ptr1.get(), 1, 1, ptr1.get()); - ASSERT_EQ(static_cast(0x1231231231231231), ptr1[0]); - - ptr1 = allocate_uint(2, pool); - ptr2 = allocate_uint(2, pool); - ptr1[0] = 0x1234567887654321; - ptr1[1] = 0x8765432112345678; - set_uint_uint(nullptr, 0, 2, ptr1.get()); - ASSERT_EQ(static_cast(0), ptr1[0]); - ASSERT_EQ(static_cast(0), ptr1[1]); - - ptr1[0] = 0x1234567887654321; - ptr1[1] = 0x8765432112345678; - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - set_uint_uint(ptr1.get(), 1, 2, ptr2.get()); - ASSERT_EQ(static_cast(0x1234567887654321), ptr2[0]); - ASSERT_EQ(static_cast(0), ptr2[1]); - - ptr2[0] = 0xFFFFFFFFFFFFFFFF; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - set_uint_uint(ptr1.get(), 2, 2, ptr2.get()); - ASSERT_EQ(static_cast(0x1234567887654321), ptr2[0]); - ASSERT_EQ(static_cast(0x8765432112345678), ptr2[1]); - - ptr1[0] = 0x1231231231231321; - ptr1[1] = 0x3213213213213211; - set_uint_uint(ptr1.get(), 2, 2, ptr1.get()); - ASSERT_EQ(static_cast(0x1231231231231321), ptr1[0]); - ASSERT_EQ(static_cast(0x3213213213213211), ptr1[1]); - - set_uint_uint(ptr1.get(), 1, 2, ptr1.get()); - ASSERT_EQ(static_cast(0x1231231231231321), ptr1[0]); - ASSERT_EQ(static_cast(0), ptr1[1]); - } - - TEST(UIntCore, IsZeroUInt) - { - ASSERT_TRUE(is_zero_uint(nullptr, 0)); - - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(1, pool)); - ptr[0] = 1; - ASSERT_FALSE(is_zero_uint(ptr.get(), 1)); - ptr[0] = 0; - ASSERT_TRUE(is_zero_uint(ptr.get(), 1)); - - ptr = allocate_uint(2, pool); - ptr[0] = 0x8000000000000000; - ptr[1] = 0x8000000000000000; - ASSERT_FALSE(is_zero_uint(ptr.get(), 2)); - ptr[0] = 0; - ptr[1] = 0x8000000000000000; - ASSERT_FALSE(is_zero_uint(ptr.get(), 2)); - ptr[0] = 0x8000000000000000; - ptr[1] = 0; - ASSERT_FALSE(is_zero_uint(ptr.get(), 2)); - ptr[0] = 0; - ptr[1] = 0; - ASSERT_TRUE(is_zero_uint(ptr.get(), 2)); - } - - TEST(UIntCore, IsEqualUInt) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(1, pool)); - ptr[0] = 1; - ASSERT_TRUE(is_equal_uint(ptr.get(), 1, 1)); - ASSERT_FALSE(is_equal_uint(ptr.get(), 1, 0)); - ASSERT_FALSE(is_equal_uint(ptr.get(), 1, 2)); - - ptr = allocate_uint(2, pool); - ptr[0] = 1; - ptr[1] = 1; - ASSERT_FALSE(is_equal_uint(ptr.get(), 2, 1)); - ptr[0] = 1; - ptr[1] = 0; - ASSERT_TRUE(is_equal_uint(ptr.get(), 2, 1)); - ptr[0] = 0x1234567887654321; - ptr[1] = 0; - ASSERT_TRUE(is_equal_uint(ptr.get(), 2, 0x1234567887654321)); - ASSERT_FALSE(is_equal_uint(ptr.get(), 2, 0x2234567887654321)); - } - - TEST(UIntCore, IsBitSetUInt) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(2, pool)); - ptr[0] = 0; - ptr[1] = 0; - for (int i = 0; i < 128; ++i) - { - ASSERT_FALSE(is_bit_set_uint(ptr.get(), 2, i)); - } - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - for (int i = 0; i < 128; ++i) - { - ASSERT_TRUE(is_bit_set_uint(ptr.get(), 2, i)); - } - - ptr[0] = 0x0000000000000001; - ptr[1] = 0x8000000000000000; - for (int i = 0; i < 128; ++i) - { - if (i == 0 || i == 127) - { - ASSERT_TRUE(is_bit_set_uint(ptr.get(), 2, i)); - } - else - { - ASSERT_FALSE(is_bit_set_uint(ptr.get(), 2, i)); - } - } - } - - TEST(UIntCore, IsHighBitSetUInt) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(2, pool)); - ptr[0] = 0; - ptr[1] = 0; - ASSERT_FALSE(is_high_bit_set_uint(ptr.get(), 2)); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - ASSERT_TRUE(is_high_bit_set_uint(ptr.get(), 2)); - - ptr[0] = 0; - ptr[1] = 0x8000000000000000; - ASSERT_TRUE(is_high_bit_set_uint(ptr.get(), 2)); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0x7FFFFFFFFFFFFFFF; - ASSERT_FALSE(is_high_bit_set_uint(ptr.get(), 2)); - } - - TEST(UIntCore, SetBitUInt) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(2, pool)); - ptr[0] = 0; - ptr[1] = 0; - set_bit_uint(ptr.get(), 2, 0); - ASSERT_EQ(1ULL, ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - - set_bit_uint(ptr.get(), 2, 127); - ASSERT_EQ(1ULL, ptr[0]); - ASSERT_EQ(static_cast(0x8000000000000000), ptr[1]); - - set_bit_uint(ptr.get(), 2, 63); - ASSERT_EQ(static_cast(0x8000000000000001), ptr[0]); - ASSERT_EQ(static_cast(0x8000000000000000), ptr[1]); - - set_bit_uint(ptr.get(), 2, 64); - ASSERT_EQ(static_cast(0x8000000000000001), ptr[0]); - ASSERT_EQ(static_cast(0x8000000000000001), ptr[1]); - - set_bit_uint(ptr.get(), 2, 3); - ASSERT_EQ(static_cast(0x8000000000000009), ptr[0]); - ASSERT_EQ(static_cast(0x8000000000000001), ptr[1]); - } - - TEST(UIntCore, GetSignificantBitCountUInt) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(2, pool)); - ptr[0] = 0; - ptr[1] = 0; - ASSERT_EQ(0, get_significant_bit_count_uint(ptr.get(), 2)); - - ptr[0] = 1; - ptr[1] = 0; - ASSERT_EQ(1, get_significant_bit_count_uint(ptr.get(), 2)); - - ptr[0] = 2; - ptr[1] = 0; - ASSERT_EQ(2, get_significant_bit_count_uint(ptr.get(), 2)); - - ptr[0] = 3; - ptr[1] = 0; - ASSERT_EQ(2, get_significant_bit_count_uint(ptr.get(), 2)); - - ptr[0] = 29; - ptr[1] = 0; - ASSERT_EQ(5, get_significant_bit_count_uint(ptr.get(), 2)); - - ptr[0] = 4; - ptr[1] = 0; - ASSERT_EQ(3, get_significant_bit_count_uint(ptr.get(), 2)); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0; - ASSERT_EQ(64, get_significant_bit_count_uint(ptr.get(), 2)); - - ptr[0] = 0; - ptr[1] = 1; - ASSERT_EQ(65, get_significant_bit_count_uint(ptr.get(), 2)); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 1; - ASSERT_EQ(65, get_significant_bit_count_uint(ptr.get(), 2)); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0x7000000000000000; - ASSERT_EQ(127, get_significant_bit_count_uint(ptr.get(), 2)); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0x8000000000000000; - ASSERT_EQ(128, get_significant_bit_count_uint(ptr.get(), 2)); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - ASSERT_EQ(128, get_significant_bit_count_uint(ptr.get(), 2)); - } - - TEST(UIntCore, GetSignificantUInt64CountUInt) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(2, pool)); - ptr[0] = 0; - ptr[1] = 0; - ASSERT_EQ(0ULL, get_significant_uint64_count_uint(ptr.get(), 2)); - - ptr[0] = 1; - ptr[1] = 0; - ASSERT_EQ(1ULL, get_significant_uint64_count_uint(ptr.get(), 2)); - - ptr[0] = 2; - ptr[1] = 0; - ASSERT_EQ(1ULL, get_significant_uint64_count_uint(ptr.get(), 2)); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0; - ASSERT_EQ(1ULL, get_significant_uint64_count_uint(ptr.get(), 2)); - - ptr[0] = 0; - ptr[1] = 1; - ASSERT_EQ(2ULL, get_significant_uint64_count_uint(ptr.get(), 2)); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 1; - ASSERT_EQ(2ULL, get_significant_uint64_count_uint(ptr.get(), 2)); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0x8000000000000000; - ASSERT_EQ(2ULL, get_significant_uint64_count_uint(ptr.get(), 2)); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - ASSERT_EQ(2ULL, get_significant_uint64_count_uint(ptr.get(), 2)); - } - - TEST(UIntCore, GetNonzeroUInt64CountUInt) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(2, pool)); - ptr[0] = 0; - ptr[1] = 0; - ASSERT_EQ(0ULL, get_nonzero_uint64_count_uint(ptr.get(), 2)); - - ptr[0] = 1; - ptr[1] = 0; - ASSERT_EQ(1ULL, get_nonzero_uint64_count_uint(ptr.get(), 2)); - - ptr[0] = 2; - ptr[1] = 0; - ASSERT_EQ(1ULL, get_nonzero_uint64_count_uint(ptr.get(), 2)); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0; - ASSERT_EQ(1ULL, get_nonzero_uint64_count_uint(ptr.get(), 2)); - - ptr[0] = 0; - ptr[1] = 1; - ASSERT_EQ(1ULL, get_nonzero_uint64_count_uint(ptr.get(), 2)); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 1; - ASSERT_EQ(2ULL, get_nonzero_uint64_count_uint(ptr.get(), 2)); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0x8000000000000000; - ASSERT_EQ(2ULL, get_nonzero_uint64_count_uint(ptr.get(), 2)); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - ASSERT_EQ(2ULL, get_nonzero_uint64_count_uint(ptr.get(), 2)); - } - - TEST(UIntCore, GetPowerOfTwoUInt) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_zero_uint(2, pool)); - ASSERT_EQ(-1, get_power_of_two_uint(ptr.get(), 1)); - ASSERT_EQ(-1, get_power_of_two_uint(ptr.get(), 2)); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - ASSERT_EQ(-1, get_power_of_two_uint(ptr.get(), 1)); - ASSERT_EQ(-1, get_power_of_two_uint(ptr.get(), 2)); - - ptr[0] = 0x0000000000000001; - ptr[1] = 0x0000000000000000; - ASSERT_EQ(0, get_power_of_two_uint(ptr.get(), 1)); - ASSERT_EQ(0, get_power_of_two_uint(ptr.get(), 2)); - - ptr[0] = 0x0000000000000001; - ptr[1] = 0x8000000000000000; - ASSERT_EQ(-1, get_power_of_two_uint(ptr.get(), 2)); - - ptr[0] = 0x0000000000000000; - ptr[1] = 0x8000000000000000; - ASSERT_EQ(127, get_power_of_two_uint(ptr.get(), 2)); - - ptr[0] = 0x8000000000000000; - ptr[1] = 0x0000000000000000; - ASSERT_EQ(63, get_power_of_two_uint(ptr.get(), 2)); - - ptr[0] = 0x9000000000000000; - ptr[1] = 0x0000000000000000; - ASSERT_EQ(-1, get_power_of_two_uint(ptr.get(), 2)); - - ptr[0] = 0x8000000000000001; - ptr[1] = 0x0000000000000000; - ASSERT_EQ(-1, get_power_of_two_uint(ptr.get(), 2)); - - ptr[0] = 0x0000000000000000; - ptr[1] = 0x0000000000000001; - ASSERT_EQ(64, get_power_of_two_uint(ptr.get(), 2)); - } - - TEST(UIntCore, GetPowerOfTwoMinusOneUInt) - { - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_zero_uint(2, pool)); - ASSERT_EQ(0, get_power_of_two_minus_one_uint(ptr.get(), 1)); - ASSERT_EQ(0, get_power_of_two_minus_one_uint(ptr.get(), 2)); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - ASSERT_EQ(64, get_power_of_two_minus_one_uint(ptr.get(), 1)); - ASSERT_EQ(128, get_power_of_two_minus_one_uint(ptr.get(), 2)); - - ptr[0] = 0x0000000000000001; - ptr[1] = 0x0000000000000000; - ASSERT_EQ(1, get_power_of_two_minus_one_uint(ptr.get(), 1)); - ASSERT_EQ(1, get_power_of_two_minus_one_uint(ptr.get(), 2)); - - ptr[0] = 0x0000000000000001; - ptr[1] = 0x8000000000000000; - ASSERT_EQ(-1, get_power_of_two_minus_one_uint(ptr.get(), 2)); - - ptr[0] = 0x0000000000000000; - ptr[1] = 0x8000000000000000; - ASSERT_EQ(-1, get_power_of_two_minus_one_uint(ptr.get(), 2)); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0x7FFFFFFFFFFFFFFF; - ASSERT_EQ(127, get_power_of_two_minus_one_uint(ptr.get(), 2)); - - ptr[0] = 0xFFFFFFFFFFFFFFFE; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - ASSERT_EQ(-1, get_power_of_two_minus_one_uint(ptr.get(), 2)); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0x0000000000000000; - ASSERT_EQ(64, get_power_of_two_minus_one_uint(ptr.get(), 2)); - - ptr[0] = 0xFFFFFFFFFFFFFFFE; - ptr[1] = 0x0000000000000000; - ASSERT_EQ(-1, get_power_of_two_minus_one_uint(ptr.get(), 2)); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0x0000000000000001; - ASSERT_EQ(65, get_power_of_two_minus_one_uint(ptr.get(), 2)); - - ptr[0] = 0xFFFFFFFFFFFFFFFE; - ptr[1] = 0x0000000000000001; - ASSERT_EQ(-1, get_power_of_two_minus_one_uint(ptr.get(), 2)); - } - - TEST(UIntCore, FilterHighBitsUInt) - { - filter_highbits_uint(nullptr, 0, 0); - - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr(allocate_uint(2, pool)); - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - filter_highbits_uint(ptr.get(), 2, 0); - ASSERT_EQ(static_cast(0), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - - ptr[0] = 0xFFFFFFFFFFFFFFFF; - ptr[1] = 0xFFFFFFFFFFFFFFFF; - filter_highbits_uint(ptr.get(), 2, 128); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[0]); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[1]); - filter_highbits_uint(ptr.get(), 2, 127); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[0]); - ASSERT_EQ(static_cast(0x7FFFFFFFFFFFFFFF), ptr[1]); - filter_highbits_uint(ptr.get(), 2, 126); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[0]); - ASSERT_EQ(static_cast(0x3FFFFFFFFFFFFFFF), ptr[1]); - filter_highbits_uint(ptr.get(), 2, 64); - ASSERT_EQ(static_cast(0xFFFFFFFFFFFFFFFF), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - filter_highbits_uint(ptr.get(), 2, 63); - ASSERT_EQ(static_cast(0x7FFFFFFFFFFFFFFF), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - filter_highbits_uint(ptr.get(), 2, 2); - ASSERT_EQ(static_cast(0x3), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - filter_highbits_uint(ptr.get(), 2, 1); - ASSERT_EQ(static_cast(0x1), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - filter_highbits_uint(ptr.get(), 2, 0); - ASSERT_EQ(static_cast(0), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - - filter_highbits_uint(ptr.get(), 2, 128); - ASSERT_EQ(static_cast(0), ptr[0]); - ASSERT_EQ(static_cast(0), ptr[1]); - } - - TEST(UIntCore, CompareUIntUInt) - { - ASSERT_EQ(0, compare_uint_uint(nullptr, nullptr, 0)); - ASSERT_TRUE(is_equal_uint_uint(nullptr, nullptr, 0)); - ASSERT_FALSE(is_not_equal_uint_uint(nullptr, nullptr, 0)); - ASSERT_FALSE(is_greater_than_uint_uint(nullptr, nullptr, 0)); - ASSERT_FALSE(is_less_than_uint_uint(nullptr, nullptr, 0)); - ASSERT_TRUE(is_greater_than_or_equal_uint_uint(nullptr, nullptr, 0)); - ASSERT_TRUE(is_less_than_or_equal_uint_uint(nullptr, nullptr, 0)); - - MemoryPool &pool = *global_variables::global_memory_pool; - auto ptr1(allocate_uint(2, pool)); - auto ptr2(allocate_uint(2, pool)); - ptr1[0] = 0; - ptr1[1] = 0; - ptr2[0] = 0; - ptr2[1] = 0; - ASSERT_EQ(0, compare_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_TRUE(is_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_FALSE(is_not_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_FALSE(is_greater_than_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_FALSE(is_less_than_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_TRUE(is_greater_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_TRUE(is_less_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - - ptr1[0] = 0x1234567887654321; - ptr1[1] = 0x8765432112345678; - ptr2[0] = 0x1234567887654321; - ptr2[1] = 0x8765432112345678; - ASSERT_EQ(0, compare_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_TRUE(is_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_FALSE(is_not_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_FALSE(is_greater_than_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_FALSE(is_less_than_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_TRUE(is_greater_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_TRUE(is_less_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - - ptr1[0] = 1; - ptr1[1] = 0; - ptr2[0] = 2; - ptr2[1] = 0; - ASSERT_EQ(-1, compare_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_FALSE(is_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_TRUE(is_not_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_FALSE(is_greater_than_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_TRUE(is_less_than_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_FALSE(is_greater_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_TRUE(is_less_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - - ptr1[0] = 1; - ptr1[1] = 0xFFFFFFFFFFFFFFFF; - ptr2[0] = 2; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - ASSERT_EQ(-1, compare_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_FALSE(is_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_TRUE(is_not_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_FALSE(is_greater_than_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_TRUE(is_less_than_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_FALSE(is_greater_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_TRUE(is_less_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - - ptr1[0] = 0xFFFFFFFFFFFFFFFF; - ptr1[1] = 0x0000000000000001; - ptr2[0] = 0x0000000000000000; - ptr2[1] = 0x0000000000000002; - ASSERT_EQ(-1, compare_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_FALSE(is_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_TRUE(is_not_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_FALSE(is_greater_than_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_TRUE(is_less_than_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_FALSE(is_greater_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_TRUE(is_less_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - - ptr1[0] = 2; - ptr1[1] = 0; - ptr2[0] = 1; - ptr2[1] = 0; - ASSERT_EQ(1, compare_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_FALSE(is_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_TRUE(is_not_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_TRUE(is_greater_than_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_FALSE(is_less_than_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_TRUE(is_greater_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_FALSE(is_less_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - - ptr1[0] = 2; - ptr1[1] = 0xFFFFFFFFFFFFFFFF; - ptr2[0] = 1; - ptr2[1] = 0xFFFFFFFFFFFFFFFF; - ASSERT_EQ(1, compare_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_FALSE(is_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_TRUE(is_not_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_TRUE(is_greater_than_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_FALSE(is_less_than_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_TRUE(is_greater_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_FALSE(is_less_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - - ptr1[0] = 0xFFFFFFFFFFFFFFFF; - ptr1[1] = 0x0000000000000003; - ptr2[0] = 0x0000000000000000; - ptr2[1] = 0x0000000000000002; - ASSERT_EQ(1, compare_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_FALSE(is_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_TRUE(is_not_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_TRUE(is_greater_than_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_FALSE(is_less_than_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_TRUE(is_greater_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - ASSERT_FALSE(is_less_than_or_equal_uint_uint(ptr1.get(), ptr2.get(), 2)); - } - - TEST(UIntCore, GetPowerOfTwo) - { - ASSERT_EQ(-1, get_power_of_two(0)); - ASSERT_EQ(0, get_power_of_two(1)); - ASSERT_EQ(1, get_power_of_two(2)); - ASSERT_EQ(-1, get_power_of_two(3)); - ASSERT_EQ(2, get_power_of_two(4)); - ASSERT_EQ(-1, get_power_of_two(5)); - ASSERT_EQ(-1, get_power_of_two(6)); - ASSERT_EQ(-1, get_power_of_two(7)); - ASSERT_EQ(3, get_power_of_two(8)); - ASSERT_EQ(-1, get_power_of_two(15)); - ASSERT_EQ(4, get_power_of_two(16)); - ASSERT_EQ(-1, get_power_of_two(17)); - ASSERT_EQ(-1, get_power_of_two(255)); - ASSERT_EQ(8, get_power_of_two(256)); - ASSERT_EQ(-1, get_power_of_two(257)); - ASSERT_EQ(10, get_power_of_two(1 << 10)); - ASSERT_EQ(30, get_power_of_two(1 << 30)); - ASSERT_EQ(32, get_power_of_two(1ULL << 32)); - ASSERT_EQ(62, get_power_of_two(1ULL << 62)); - ASSERT_EQ(63, get_power_of_two(1ULL << 63)); - } - - TEST(UIntCore, GetPowerOfTwoMinusOne) - { - ASSERT_EQ(0, get_power_of_two_minus_one(0)); - ASSERT_EQ(1, get_power_of_two_minus_one(1)); - ASSERT_EQ(-1, get_power_of_two_minus_one(2)); - ASSERT_EQ(2, get_power_of_two_minus_one(3)); - ASSERT_EQ(-1, get_power_of_two_minus_one(4)); - ASSERT_EQ(-1, get_power_of_two_minus_one(5)); - ASSERT_EQ(-1, get_power_of_two_minus_one(6)); - ASSERT_EQ(3, get_power_of_two_minus_one(7)); - ASSERT_EQ(-1, get_power_of_two_minus_one(8)); - ASSERT_EQ(-1, get_power_of_two_minus_one(14)); - ASSERT_EQ(4, get_power_of_two_minus_one(15)); - ASSERT_EQ(-1, get_power_of_two_minus_one(16)); - ASSERT_EQ(8, get_power_of_two_minus_one(255)); - ASSERT_EQ(10, get_power_of_two_minus_one((1 << 10) - 1)); - ASSERT_EQ(30, get_power_of_two_minus_one((1 << 30) - 1)); - ASSERT_EQ(32, get_power_of_two_minus_one((1ULL << 32) - 1)); - ASSERT_EQ(63, get_power_of_two_minus_one((1ULL << 63) - 1)); - ASSERT_EQ(64, get_power_of_two_minus_one(~static_cast(0))); - } - - TEST(UIntCore, DuplicateUIntIfNeeded) - { - //MemoryPool &pool = *global_variables::global_memory_pool; - MemoryPoolST pool; - auto ptr(allocate_uint(2, pool)); - ptr[0] = 0xF0F0F0F0F0; - ptr[1] = 0xABABABABAB; - auto ptr2 = duplicate_uint_if_needed(ptr.get(), 0, 0, false, pool); - // No forcing and sizes are same (although zero) so just alias - ASSERT_TRUE(ptr2.get() == ptr.get()); - - ptr2 = duplicate_uint_if_needed(ptr.get(), 0, 0, true, pool); - // Forcing and size is zero so return nullptr - ASSERT_TRUE(ptr2.get() == nullptr); - - ptr2 = duplicate_uint_if_needed(ptr.get(), 1, 0, false, pool); - ASSERT_TRUE(ptr2.get() == ptr.get()); - - ptr2 = duplicate_uint_if_needed(ptr.get(), 1, 0, true, pool); - ASSERT_TRUE(ptr2.get() == nullptr); - - ptr2 = duplicate_uint_if_needed(ptr.get(), 1, 1, false, pool); - ASSERT_TRUE(ptr2.get() == ptr.get()); - - ptr2 = duplicate_uint_if_needed(ptr.get(), 1, 1, true, pool); - ASSERT_TRUE(ptr2.get() != ptr.get()); - ASSERT_EQ(ptr[0], ptr2[0]); - - ptr2 = duplicate_uint_if_needed(ptr.get(), 2, 2, true, pool); - ASSERT_TRUE(ptr2.get() != ptr.get()); - ASSERT_EQ(ptr[0], ptr2[0]); - ASSERT_EQ(ptr[1], ptr2[1]); - - ptr2 = duplicate_uint_if_needed(ptr.get(), 2, 2, false, pool); - ASSERT_TRUE(ptr2.get() == ptr.get()); - - ptr2 = duplicate_uint_if_needed(ptr.get(), 2, 1, false, pool); - ASSERT_TRUE(ptr2.get() == ptr.get()); - - ptr2 = duplicate_uint_if_needed(ptr.get(), 1, 2, false, pool); - ASSERT_TRUE(ptr2.get() != ptr.get()); - ASSERT_EQ(ptr[0], ptr2[0]); - ASSERT_EQ(0ULL, ptr2[1]); - - ptr2 = duplicate_uint_if_needed(ptr.get(), 1, 2, true, pool); - ASSERT_TRUE(ptr2.get() != ptr.get()); - ASSERT_EQ(ptr[0], ptr2[0]); - ASSERT_EQ(0ULL, ptr2[1]); - } - - TEST(UIntCore, HammingWeight) - { - ASSERT_EQ(0ULL, hamming_weight(0ULL)); - ASSERT_EQ(1ULL, hamming_weight(1ULL)); - ASSERT_EQ(1ULL, hamming_weight(0x10000ULL)); - ASSERT_EQ(2ULL, hamming_weight(0x10001ULL)); - ASSERT_EQ(32ULL, hamming_weight(0xFFFFFFFFULL)); - ASSERT_EQ(64ULL, hamming_weight(0xFFFFFFFFFFFFFFFFULL)); - ASSERT_EQ(32ULL, hamming_weight(0xF0F0F0F0F0F0F0F0ULL)); - ASSERT_EQ(16ULL, hamming_weight(0xA0A0A0A0A0A0A0A0ULL)); - } - - TEST(UIntCore, HammingWeightSplit) - { - ASSERT_EQ(0ULL, hamming_weight_split(0ULL)); - ASSERT_EQ(1ULL, hamming_weight_split(1ULL)); - ASSERT_EQ(0x10000ULL, hamming_weight_split(0x10000ULL)); - ASSERT_EQ(1ULL, hamming_weight_split(0x10001ULL)); - ASSERT_EQ(0xFFFFULL, hamming_weight_split(0xFFFFFFFFULL)); - ASSERT_EQ(0xFFFFFFFFULL, hamming_weight_split(0xFFFFFFFFFFFFFFFFULL)); - ASSERT_EQ(0xF0F0F00ULL, hamming_weight_split(0xF0F0F0000F0F0F00ULL)); - ASSERT_EQ(0xA0A0A0A0ULL, hamming_weight_split(0xA0A0A0A0A0A0A0A0ULL)); - } - } -} diff --git a/build_docker.sh b/build_docker.sh deleted file mode 100755 index e71fabe..0000000 --- a/build_docker.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash - -docker build -t py_seal . diff --git a/examples/4_bgv_basics.py b/examples/4_bgv_basics.py new file mode 100644 index 0000000..ec81cf7 --- /dev/null +++ b/examples/4_bgv_basics.py @@ -0,0 +1,113 @@ +from seal import * +import numpy as np + +def print_vector(vector): + print('[ ', end='') + for i in range(0, 8): + print(vector[i], end=', ') + print('... ]') + + +def example_bgv_basics(): + parms = EncryptionParameters (scheme_type.bgv) + poly_modulus_degree = 8192 + parms.set_poly_modulus_degree(poly_modulus_degree) + parms.set_coeff_modulus(CoeffModulus.BFVDefault(poly_modulus_degree)) + parms.set_plain_modulus(PlainModulus.Batching(poly_modulus_degree, 20)) + context = SEALContext(parms) + + keygen = KeyGenerator(context) + secret_key = keygen.secret_key() + public_key = keygen.create_public_key() + relin_keys = keygen.create_relin_keys() + + encryptor = Encryptor(context, public_key) + evaluator = Evaluator(context) + decryptor = Decryptor(context, secret_key) + + batch_encoder = BatchEncoder(context) + slot_count = batch_encoder.slot_count() + row_size = slot_count / 2 + print(f'Plaintext matrix row size: {row_size}') + + pod_matrix = [0] * slot_count + pod_matrix[0] = 1 + pod_matrix[1] = 2 + pod_matrix[2] = 3 + pod_matrix[3] = 4 + + x_plain = batch_encoder.encode(pod_matrix) + + x_encrypted = encryptor.encrypt(x_plain) + print(f'noise budget in freshly encrypted x: {decryptor.invariant_noise_budget(x_encrypted)}') + print('-'*50) + + x_squared = evaluator.square(x_encrypted) + print(f'size of x_squared: {x_squared.size()}') + evaluator.relinearize_inplace(x_squared, relin_keys) + print(f'size of x_squared (after relinearization): {x_squared.size()}') + print(f'noise budget in x_squared: {decryptor.invariant_noise_budget(x_squared)} bits') + decrypted_result = decryptor.decrypt(x_squared) + pod_result = batch_encoder.decode(decrypted_result) + print_vector(pod_result) + print('-'*50) + + x_4th = evaluator.square(x_squared) + print(f'size of x_4th: {x_4th.size()}') + evaluator.relinearize_inplace(x_4th, relin_keys) + print(f'size of x_4th (after relinearization): { x_4th.size()}') + print(f'noise budget in x_4th: {decryptor.invariant_noise_budget(x_4th)} bits') + decrypted_result = decryptor.decrypt(x_4th) + pod_result = batch_encoder.decode(decrypted_result) + print_vector(pod_result) + print('-'*50) + + x_8th = evaluator.square(x_4th) + print(f'size of x_8th: {x_8th.size()}') + evaluator.relinearize_inplace(x_8th, relin_keys) + print(f'size of x_8th (after relinearization): { x_8th.size()}') + print(f'noise budget in x_8th: {decryptor.invariant_noise_budget(x_8th)} bits') + decrypted_result = decryptor.decrypt(x_8th) + pod_result = batch_encoder.decode(decrypted_result) + print_vector(pod_result) + print('run out of noise budget') + print('-'*100) + + x_encrypted = encryptor.encrypt(x_plain) + print(f'noise budget in freshly encrypted x: {decryptor.invariant_noise_budget(x_encrypted)}') + print('-'*50) + + x_squared = evaluator.square(x_encrypted) + print(f'size of x_squared: {x_squared.size()}') + evaluator.relinearize_inplace(x_squared, relin_keys) + evaluator.mod_switch_to_next_inplace(x_squared) + print(f'noise budget in x_squared (with modulus switching): {decryptor.invariant_noise_budget(x_squared)} bits') + decrypted_result = decryptor.decrypt(x_squared) + pod_result = batch_encoder.decode(decrypted_result) + print_vector(pod_result) + print('-'*50) + + x_4th = evaluator.square(x_squared) + print(f'size of x_4th: {x_4th.size()}') + evaluator.relinearize_inplace(x_4th, relin_keys) + evaluator.mod_switch_to_next_inplace(x_4th) + print(f'size of x_4th (after relinearization): { x_4th.size()}') + print(f'noise budget in x_4th (with modulus switching): {decryptor.invariant_noise_budget(x_4th)} bits') + decrypted_result = decryptor.decrypt(x_4th) + pod_result = batch_encoder.decode(decrypted_result) + print_vector(pod_result) + print('-'*50) + + x_8th = evaluator.square(x_4th) + print(f'size of x_8th: {x_8th.size()}') + evaluator.relinearize_inplace(x_8th, relin_keys) + evaluator.mod_switch_to_next_inplace(x_8th) + print(f'size of x_8th (after relinearization): { x_8th.size()}') + print(f'noise budget in x_8th (with modulus switching): {decryptor.invariant_noise_budget(x_8th)} bits') + decrypted_result = decryptor.decrypt(x_8th) + pod_result = batch_encoder.decode(decrypted_result) + print_vector(pod_result) + + +if __name__ == "__main__": + example_bgv_basics() diff --git a/examples/7_serialization.py b/examples/7_serialization.py new file mode 100644 index 0000000..07bc5d2 --- /dev/null +++ b/examples/7_serialization.py @@ -0,0 +1,73 @@ +from seal import * +import pickle +import time + + +def get_seal(): + parms = EncryptionParameters(scheme_type.ckks) + poly_modulus_degree = 8192 + parms.set_poly_modulus_degree(poly_modulus_degree) + parms.set_coeff_modulus(CoeffModulus.Create(poly_modulus_degree, [60, 40, 40, 60])) + scale = 2.0 ** 40 + + context = SEALContext(parms) + ckks_encoder = CKKSEncoder(context) + slot_count = ckks_encoder.slot_count() + + keygen = KeyGenerator(context) + public_key = keygen.create_public_key() + secret_key = keygen.secret_key() + + encryptor = Encryptor(context, public_key) + # evaluator = Evaluator(context) + decryptor = Decryptor(context, secret_key) + + data = [3.1415926] * slot_count + plain = ckks_encoder.encode(data, scale) + cipher = encryptor.encrypt(plain) + + return cipher, context, ckks_encoder, decryptor + + +def serialization_example(): + print('serialization example') + print('-' * 70) + cipher2, context2, ckks_encoder2, decryptor2 = get_seal() + cipher2.save('cipher2.bin') + print('save cipher2 data success') + + time.sleep(.5) + + cipher3 = Ciphertext() + cipher3.load(context2, 'cipher2.bin') + print('load cipher2 data success') + plain3 = decryptor2.decrypt(cipher3) + data3 = ckks_encoder2.decode(plain3) + print(data3) + print('-' * 70) + + +def pickle_example(): + print('pickle example') + print('-' * 70) + cipher1, context1, ckks_encoder1, decryptor1 = get_seal() + with open('cipher1.bin', 'wb') as f: + pickle.dump(cipher1.to_string(), f) + print('write cipher1 data success') + + time.sleep(.5) + + with open('cipher1.bin', 'rb') as f: + temp = pickle.load(f) + cipher2 = context1.from_cipher_str(temp) + plain2 = decryptor1.decrypt(cipher2) + data = ckks_encoder1.decode(plain2) + print('read cipher1 data success') + print(data) + + print('-' * 70) + + +if __name__ == "__main__": + serialization_example() + pickle_example() diff --git a/examples/matrix_operations.py b/examples/matrix_operations.py new file mode 100644 index 0000000..f1c4567 --- /dev/null +++ b/examples/matrix_operations.py @@ -0,0 +1,334 @@ +import sys +import time +import math +import numpy as np +from seal import * +from seal_helper import * + + +def get_diagonal(position, matrix): + n = matrix.shape[0] + diagonal = np.zeros(n) + + k = 0 + i = 0 + j = position + while i < n-position and j < n: + diagonal[k] = matrix[i][j] + i += 1 + j += 1 + k += 1 + + i = n - position + j = 0 + while i < n and j < position: + diagonal[k] = matrix[i][j] + i += 1 + j += 1 + k += 1 + + return diagonal + + +def get_all_diagonals(matrix): + matrix_diagonals = [] + for i in range(matrix.shape[0]): + matrix_diagonals.append(get_diagonal(i, matrix)) + + return np.array(matrix_diagonals) + + +def get_u_transpose(shape): + u_transpose = np.zeros((shape[0]**2, shape[1]**2)) + n = shape[0] + k = 0 + i = 0 + for row in u_transpose: + row[k+i] = 1 + k += n + if k >= n*n: + k = 0 + i += 1 + + return u_transpose + + +def get_transposed_diagonals(u_transposed): + transposed_diagonals = np.zeros(u_transposed.shape) + for i in range(u_transposed.shape[0]): + a = np.diagonal(u_transposed, offset=i) + b = np.diagonal(u_transposed, offset=u_transposed.shape[0]-i) + transposed_diagonals[i] = np.concatenate([a, b]) + + return transposed_diagonals + + +def linear_transform_plain(cipher_matrix, plain_diags, galois_keys, evaluator): + cipher_rot = evaluator.rotate_vector(cipher_matrix, -len(plain_diags), galois_keys) + cipher_temp = evaluator.add(cipher_matrix, cipher_rot) + cipher_results = [] + temp = evaluator.multiply_plain(cipher_temp, plain_diags[0]) + cipher_results.append(temp) + + i = 1 + while i < len(plain_diags): + temp_rot = evaluator.rotate_vector(cipher_temp, i, galois_keys) + temp = evaluator.multiply_plain(temp_rot, plain_diags[i]) + cipher_results.append(temp) + i += 1 + + cipher_prime = evaluator.add_many(cipher_results) + + return cipher_prime + + +def get_u_sigma(shape): + u_sigma_ = np.zeros(shape) + indices_diagonal = np.diag_indices(shape[0]) + u_sigma_[indices_diagonal] = 1. + + for i in range(shape[0]-1): + u_sigma_ = np.pad(u_sigma_, (0, shape[0]), 'constant') + temp = np.zeros(shape) + j = np.arange(0, shape[0]) + temp[j, j-(shape[0]-1-i)] = 1. + temp = np.pad(temp, ((i+1)*shape[0], 0), 'constant') + u_sigma_ += temp + + return u_sigma_ + + +def get_u_tau(shape): + u_tau_ = np.zeros((shape[0], shape[0]**2)) + index = np.arange(shape[0]) + for i in range(shape[0], 0, -1): + idx = np.concatenate([index[i:], index[:i]], axis=0) + row = np.zeros(shape) + for j in range(shape[0]): + temp = np.zeros(shape) + temp[idx[j], idx[j]] = 1. + if j == 0: + row += temp + else: + row = np.concatenate([row, temp], axis=1) + + if i == shape[0]: + u_tau_ += row + else: + u_tau_ = np.concatenate([u_tau_, row], axis=0) + + return u_tau_ + + +def get_v_k(shape): + v_k_ = [] + index = np.arange(0, shape[0]) + for j in range(1, shape[0]): + temp = np.zeros(shape) + temp[index, index-(shape[0]-j)] = 1. + mat = temp + for i in range(shape[0]-1): + mat = np.pad(mat, (0, shape[0]), 'constant') + temp2 = np.pad(temp, ((i+1)*shape[0], 0), 'constant') + mat += temp2 + + v_k_.append(mat) + + return v_k_ + + +def get_w_k(shape): + w_k_ = [] + index = np.arange(shape[0]**2) + for i in range(shape[0]-1): + temp = np.zeros((shape[0]**2, shape[1]**2)) + temp[index-(i+1)*shape[0], index] = 1. + w_k_.append(temp) + + return w_k_ + + +def matrix_multiplication(n, cm1, cm2, sigma, tau, v, w, galois_keys, evaluator): + cipher_result1 = [] + cipher_result2 = [] + + cipher_result1.append(linear_transform_plain(cm1, sigma, galois_keys, evaluator)) + cipher_result2.append(linear_transform_plain(cm2, tau, galois_keys, evaluator)) + + for i in range(1, n): + cipher_result1.append(linear_transform_plain(cipher_result1[0], v[i-1], galois_keys, evaluator)) + cipher_result2.append(linear_transform_plain(cipher_result2[0], w[i-1], galois_keys, evaluator)) + + for i in range(1, n): + evaluator.rescale_to_next_inplace(cipher_result1[i]) + evaluator.rescale_to_next_inplace(cipher_result2[i]) + + cipher_mult = evaluator.multiply(cipher_result1[0], cipher_result2[0]) + evaluator.mod_switch_to_next_inplace(cipher_mult) + + for i in range(1, n): + cipher_result1[i].scale(2**int(math.log2(cipher_result1[i].scale()))) + cipher_result2[i].scale(2**int(math.log2(cipher_result2[i].scale()))) + + for i in range(1, n): + temp = evaluator.multiply(cipher_result1[i], cipher_result2[i]) + evaluator.add_inplace(cipher_mult, temp) + + return cipher_mult + + +def matrix_mult_test(n=4): + parms = EncryptionParameters(scheme_type.ckks) + poly_modulus_degree = 16384 + parms.set_poly_modulus_degree(poly_modulus_degree) + parms.set_coeff_modulus(CoeffModulus.Create( + poly_modulus_degree, [60, 40, 40, 40, 40, 60])) + scale = 2.0**40 + context = SEALContext(parms) + print_parameters(context) + + ckks_encoder = CKKSEncoder(context) + slot_count = ckks_encoder.slot_count() + print(f'Number of slots: {slot_count}') + + keygen = KeyGenerator(context) + public_key = keygen.create_public_key() + secret_key = keygen.secret_key() + galois_keys = keygen.create_galois_keys() + + encryptor = Encryptor(context, public_key) + evaluator = Evaluator(context) + decryptor = Decryptor(context, secret_key) + + # --------------------------------------------------------- + u_sigma = get_u_sigma((n,n)) + u_tau = get_u_tau((n,n)) + v_k = get_v_k((n, n)) + w_k = get_w_k((n, n)) + + u_sigma_diagonals = get_all_diagonals(u_sigma) + u_sigma_diagonals += 0.00000001 # prevent is_transparent + + u_tau_diagonals = get_all_diagonals(u_tau) + u_tau_diagonals += 0.00000001 + + v_k_diagonals = [] + for v in v_k: + diags = get_all_diagonals(v) + diags += 0.00000001 + v_k_diagonals.append(diags) + + w_k_diagonals = [] + for w in w_k: + diags = get_all_diagonals(w) + diags += 0.00000001 + w_k_diagonals.append(diags) + + plain_u_sigma_diagonals = [] + plain_u_tau_diagonals = [] + plain_v_k_diagonals = [] + plain_w_k_diagonals = [] + + # --------------------------------------------------------- + for i in range(n**2): + plain_u_sigma_diagonals.append(ckks_encoder.encode(u_sigma_diagonals[i], scale)) + plain_u_tau_diagonals.append(ckks_encoder.encode(u_tau_diagonals[i], scale)) + + for i in range(n-1): + temp1 = [] + temp2 = [] + for j in range(n**2): + temp1.append(ckks_encoder.encode(v_k_diagonals[i][j], scale)) + temp2.append(ckks_encoder.encode(w_k_diagonals[i][j], scale)) + + plain_v_k_diagonals.append(temp1) + plain_w_k_diagonals.append(temp2) + + # matrix1 = np.random.rand(n, n) + matrix1 = np.arange(1, n*n+1).reshape(n, n) + matrix2 = matrix1 + print('Plaintext result:') + print(np.dot(matrix1, matrix2)) + + plain_matrix1 = ckks_encoder.encode(matrix1.flatten(), scale) + plain_matrix2 = ckks_encoder.encode(matrix2.flatten(), scale) + cipher_matrix1 = encryptor.encrypt(plain_matrix1) + cipher_matrix2 = encryptor.encrypt(plain_matrix2) + + # --------------------------------------------------------- + start = time.time() + cipher_result = matrix_multiplication(n, cipher_matrix1, cipher_matrix2, plain_u_sigma_diagonals, plain_u_tau_diagonals, plain_v_k_diagonals, plain_w_k_diagonals, galois_keys, evaluator) + end = time.time() + + # --------------------------------------------------------- + plain = decryptor.decrypt(cipher_result) + vec = ckks_encoder.decode(plain) + print('Ciphertext result:') + print(vec[:n**2].reshape(n, n)) + print(f'Mult Time: {(end-start):.3f}s') + + +def matrix_transpose_test(n=4): + parms = EncryptionParameters(scheme_type.ckks) + poly_modulus_degree = 8192 + parms.set_poly_modulus_degree(poly_modulus_degree) + parms.set_coeff_modulus(CoeffModulus.Create( + poly_modulus_degree, [60, 40, 40, 60])) + scale = 2.0**40 + context = SEALContext(parms) + print_parameters(context) + + ckks_encoder = CKKSEncoder(context) + slot_count = ckks_encoder.slot_count() + print(f'Number of slots: {slot_count}') + + keygen = KeyGenerator(context) + public_key = keygen.create_public_key() + secret_key = keygen.secret_key() + galois_keys = keygen.create_galois_keys() + + encryptor = Encryptor(context, public_key) + evaluator = Evaluator(context) + decryptor = Decryptor(context, secret_key) + + # --------------------------------------------------------- + # matrix = np.random.rand(n, n) + matrix = np.arange(1, n*n+1).reshape(n, n) + print('Plaintext result:') + print(matrix) + + u_transposed = get_u_transpose(matrix.shape) + u_transposed_diagonals = get_transposed_diagonals(u_transposed) + u_transposed_diagonals += 0.00000001 # Prevent is_transparent + + # --------------------------------------------------------- + plain_u_diag = [] + for row in u_transposed_diagonals: + plain_u_diag.append(ckks_encoder.encode(row, scale)) + + plain_matrix = ckks_encoder.encode(matrix.flatten(), scale) + cipher_matrix = encryptor.encrypt(plain_matrix) + + # --------------------------------------------------------- + start = time.time() + cipher_result = linear_transform_plain( + cipher_matrix, plain_u_diag, galois_keys, evaluator) + end = time.time() + + # --------------------------------------------------------- + p1 = decryptor.decrypt(cipher_result) + vec = ckks_encoder.decode(p1) + print('Ciphertext result:') + print(vec[:n**2].reshape(n, n)) + print(f'Trans Time: {(end-start):.3f}s') + + +if __name__ == "__main__": + args = sys.argv[1:] + n = int(args[0]) if args else 4 + print(f'n: {n}') + print('-'*18 + 'Matrix Transpose:' + '-'*18) + matrix_transpose_test(n) + + print('-'*18 + 'Matrix Multiplication:' + '-'*18) + matrix_mult_test(n) diff --git a/examples/seal_helper.py b/examples/seal_helper.py new file mode 100644 index 0000000..a3d7980 --- /dev/null +++ b/examples/seal_helper.py @@ -0,0 +1,54 @@ +from seal import scheme_type + + +def print_example_banner(title): + title_length = len(title) + banner_length = title_length + 2 * 10 + banner_top = '+' + '-' * (banner_length - 2) + '+' + banner_middle = '|' + ' ' * 9 + title + ' ' * 9 + '|' + print(banner_top) + print(banner_middle) + print(banner_top) + + +def print_parameters(context): + context_data = context.key_context_data() + if context_data.parms().scheme() == scheme_type.bfv: + scheme_name = 'bfv' + elif context_data.parms().scheme() == scheme_type.ckks: + scheme_name = 'ckks' + else: + scheme_name = 'none' + print('/') + print('| Encryption parameters') + print('| scheme: ' + scheme_name) + print(f'| poly_modulus_degree: {context_data.parms().poly_modulus_degree()}') + coeff_modulus = context_data.parms().coeff_modulus() + coeff_modulus_sum = 0 + for j in coeff_modulus: + coeff_modulus_sum += j.bit_count() + print(f'| coeff_modulus size: {coeff_modulus_sum}(', end='') + for i in range(len(coeff_modulus) - 1): + print(f'{coeff_modulus[i].bit_count()} + ', end='') + print(f'{coeff_modulus[-1].bit_count()}) bits') + if context_data.parms().scheme() == scheme_type.bfv: + print(f'| plain_modulus: {context_data.parms().plain_modulus().value()}') + print('\\') + + +def print_vector(vec, print_size=4, prec=3): + slot_count = len(vec) + print() + if slot_count <= 2*print_size: + print(' [', end='') + for i in range(slot_count): + print(f' {vec[i]:.{prec}f}' + (',' if (i != slot_count - 1) else ' ]\n'), end='') + else: + print(' [', end='') + for i in range(print_size): + print(f' {vec[i]:.{prec}f},', end='') + if slot_count > 2*print_size: + print(' ...,', end='') + for i in range(slot_count - print_size, slot_count): + print(f' {vec[i]:.{prec}f}' + (',' if (i != slot_count - 1) else ' ]\n'), end='') + print() diff --git a/pybind11 b/pybind11 new file mode 160000 index 0000000..ffa3468 --- /dev/null +++ b/pybind11 @@ -0,0 +1 @@ +Subproject commit ffa346860b306c9bbfb341aed9c14c067751feb8 diff --git a/pybind11/.gitignore b/pybind11/.gitignore deleted file mode 100644 index 979fd44..0000000 --- a/pybind11/.gitignore +++ /dev/null @@ -1,38 +0,0 @@ -CMakeCache.txt -CMakeFiles -Makefile -cmake_install.cmake -.DS_Store -*.so -*.pyd -*.dll -*.sln -*.sdf -*.opensdf -*.vcxproj -*.filters -example.dir -Win32 -x64 -Release -Debug -.vs -CTestTestfile.cmake -Testing -autogen -MANIFEST -/.ninja_* -/*.ninja -/docs/.build -*.py[co] -*.egg-info -*~ -.*.swp -.DS_Store -/dist -/build -/cmake/ -.cache/ -sosize-*.txt -pybind11Config*.cmake -pybind11Targets.cmake diff --git a/pybind11/.gitmodules b/pybind11/.gitmodules deleted file mode 100644 index d063a8e..0000000 --- a/pybind11/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "tools/clang"] - path = tools/clang - url = ../../wjakob/clang-cindex-python3 diff --git a/pybind11/CMakeLists.txt b/pybind11/CMakeLists.txt deleted file mode 100644 index 85ecd90..0000000 --- a/pybind11/CMakeLists.txt +++ /dev/null @@ -1,157 +0,0 @@ -# CMakeLists.txt -- Build system for the pybind11 modules -# -# Copyright (c) 2015 Wenzel Jakob -# -# All rights reserved. Use of this source code is governed by a -# BSD-style license that can be found in the LICENSE file. - -cmake_minimum_required(VERSION 2.8.12) - -if (POLICY CMP0048) - # cmake warns if loaded from a min-3.0-required parent dir, so silence the warning: - cmake_policy(SET CMP0048 NEW) -endif() - -# CMake versions < 3.4.0 do not support try_compile/pthread checks without C as active language. -if(CMAKE_VERSION VERSION_LESS 3.4.0) - project(pybind11) -else() - project(pybind11 CXX) -endif() - -# Check if pybind11 is being used directly or via add_subdirectory -set(PYBIND11_MASTER_PROJECT OFF) -if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR) - set(PYBIND11_MASTER_PROJECT ON) -endif() - -option(PYBIND11_INSTALL "Install pybind11 header files?" ${PYBIND11_MASTER_PROJECT}) -option(PYBIND11_TEST "Build pybind11 test suite?" ${PYBIND11_MASTER_PROJECT}) - -list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}/tools") - -include(pybind11Tools) - -# Cache variables so pybind11_add_module can be used in parent projects -set(PYBIND11_INCLUDE_DIR "${CMAKE_CURRENT_LIST_DIR}/include" CACHE INTERNAL "") -set(PYTHON_INCLUDE_DIRS ${PYTHON_INCLUDE_DIRS} CACHE INTERNAL "") -set(PYTHON_LIBRARIES ${PYTHON_LIBRARIES} CACHE INTERNAL "") -set(PYTHON_MODULE_PREFIX ${PYTHON_MODULE_PREFIX} CACHE INTERNAL "") -set(PYTHON_MODULE_EXTENSION ${PYTHON_MODULE_EXTENSION} CACHE INTERNAL "") -set(PYTHON_VERSION_MAJOR ${PYTHON_VERSION_MAJOR} CACHE INTERNAL "") -set(PYTHON_VERSION_MINOR ${PYTHON_VERSION_MINOR} CACHE INTERNAL "") - -# NB: when adding a header don't forget to also add it to setup.py -set(PYBIND11_HEADERS - include/pybind11/detail/class.h - include/pybind11/detail/common.h - include/pybind11/detail/descr.h - include/pybind11/detail/init.h - include/pybind11/detail/internals.h - include/pybind11/detail/typeid.h - include/pybind11/attr.h - include/pybind11/buffer_info.h - include/pybind11/cast.h - include/pybind11/chrono.h - include/pybind11/common.h - include/pybind11/complex.h - include/pybind11/options.h - include/pybind11/eigen.h - include/pybind11/embed.h - include/pybind11/eval.h - include/pybind11/functional.h - include/pybind11/numpy.h - include/pybind11/operators.h - include/pybind11/pybind11.h - include/pybind11/pytypes.h - include/pybind11/stl.h - include/pybind11/stl_bind.h -) -string(REPLACE "include/" "${CMAKE_CURRENT_SOURCE_DIR}/include/" - PYBIND11_HEADERS "${PYBIND11_HEADERS}") - -if (PYBIND11_TEST) - add_subdirectory(tests) -endif() - -include(GNUInstallDirs) -include(CMakePackageConfigHelpers) - -# extract project version from source -file(STRINGS "${PYBIND11_INCLUDE_DIR}/pybind11/detail/common.h" pybind11_version_defines - REGEX "#define PYBIND11_VERSION_(MAJOR|MINOR|PATCH) ") -foreach(ver ${pybind11_version_defines}) - if (ver MATCHES "#define PYBIND11_VERSION_(MAJOR|MINOR|PATCH) +([^ ]+)$") - set(PYBIND11_VERSION_${CMAKE_MATCH_1} "${CMAKE_MATCH_2}" CACHE INTERNAL "") - endif() -endforeach() -set(${PROJECT_NAME}_VERSION ${PYBIND11_VERSION_MAJOR}.${PYBIND11_VERSION_MINOR}.${PYBIND11_VERSION_PATCH}) -message(STATUS "pybind11 v${${PROJECT_NAME}_VERSION}") - -option (USE_PYTHON_INCLUDE_DIR "Install pybind11 headers in Python include directory instead of default installation prefix" OFF) -if (USE_PYTHON_INCLUDE_DIR) - file(RELATIVE_PATH CMAKE_INSTALL_INCLUDEDIR ${CMAKE_INSTALL_PREFIX} ${PYTHON_INCLUDE_DIRS}) -endif() - -if(NOT (CMAKE_VERSION VERSION_LESS 3.0)) # CMake >= 3.0 - # Build an interface library target: - add_library(pybind11 INTERFACE) - add_library(pybind11::pybind11 ALIAS pybind11) # to match exported target - target_include_directories(pybind11 INTERFACE $ - $ - $) - target_compile_options(pybind11 INTERFACE $) - - add_library(module INTERFACE) - add_library(pybind11::module ALIAS module) - if(NOT MSVC) - target_compile_options(module INTERFACE -fvisibility=hidden) - endif() - target_link_libraries(module INTERFACE pybind11::pybind11) - if(WIN32 OR CYGWIN) - target_link_libraries(module INTERFACE $) - elseif(APPLE) - target_link_libraries(module INTERFACE "-undefined dynamic_lookup") - endif() - - add_library(embed INTERFACE) - add_library(pybind11::embed ALIAS embed) - target_link_libraries(embed INTERFACE pybind11::pybind11 $) -endif() - -if (PYBIND11_INSTALL) - install(DIRECTORY ${PYBIND11_INCLUDE_DIR}/pybind11 DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) - # GNUInstallDirs "DATADIR" wrong here; CMake search path wants "share". - set(PYBIND11_CMAKECONFIG_INSTALL_DIR "share/cmake/${PROJECT_NAME}" CACHE STRING "install path for pybind11Config.cmake") - - configure_package_config_file(tools/${PROJECT_NAME}Config.cmake.in - "${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}Config.cmake" - INSTALL_DESTINATION ${PYBIND11_CMAKECONFIG_INSTALL_DIR}) - # Remove CMAKE_SIZEOF_VOID_P from ConfigVersion.cmake since the library does - # not depend on architecture specific settings or libraries. - set(_PYBIND11_CMAKE_SIZEOF_VOID_P ${CMAKE_SIZEOF_VOID_P}) - unset(CMAKE_SIZEOF_VOID_P) - write_basic_package_version_file(${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}ConfigVersion.cmake - VERSION ${${PROJECT_NAME}_VERSION} - COMPATIBILITY AnyNewerVersion) - set(CMAKE_SIZEOF_VOID_P ${_PYBIND11_CMAKE_SIZEOF_VOID_P}) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}Config.cmake - ${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}ConfigVersion.cmake - tools/FindPythonLibsNew.cmake - tools/pybind11Tools.cmake - DESTINATION ${PYBIND11_CMAKECONFIG_INSTALL_DIR}) - - if(NOT (CMAKE_VERSION VERSION_LESS 3.0)) - if(NOT PYBIND11_EXPORT_NAME) - set(PYBIND11_EXPORT_NAME "${PROJECT_NAME}Targets") - endif() - - install(TARGETS pybind11 module embed - EXPORT "${PYBIND11_EXPORT_NAME}") - if(PYBIND11_MASTER_PROJECT) - install(EXPORT "${PYBIND11_EXPORT_NAME}" - NAMESPACE "${PROJECT_NAME}::" - DESTINATION ${PYBIND11_CMAKECONFIG_INSTALL_DIR}) - endif() - endif() -endif() diff --git a/pybind11/LICENSE b/pybind11/LICENSE deleted file mode 100644 index 6f15578..0000000 --- a/pybind11/LICENSE +++ /dev/null @@ -1,29 +0,0 @@ -Copyright (c) 2016 Wenzel Jakob , All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -3. Neither the name of the copyright holder nor the names of its contributors - may be used to endorse or promote products derived from this software - without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -Please also refer to the file CONTRIBUTING.md, which clarifies licensing of -external contributions to this project including patches, pull requests, etc. diff --git a/pybind11/README.md b/pybind11/README.md deleted file mode 100644 index 35d2d76..0000000 --- a/pybind11/README.md +++ /dev/null @@ -1,129 +0,0 @@ -![pybind11 logo](https://github.com/pybind/pybind11/raw/master/docs/pybind11-logo.png) - -# pybind11 — Seamless operability between C++11 and Python - -[![Documentation Status](https://readthedocs.org/projects/pybind11/badge/?version=master)](http://pybind11.readthedocs.org/en/master/?badge=master) -[![Documentation Status](https://readthedocs.org/projects/pybind11/badge/?version=stable)](http://pybind11.readthedocs.org/en/stable/?badge=stable) -[![Gitter chat](https://img.shields.io/gitter/room/gitterHQ/gitter.svg)](https://gitter.im/pybind/Lobby) -[![Build Status](https://travis-ci.org/pybind/pybind11.svg?branch=master)](https://travis-ci.org/pybind/pybind11) -[![Build status](https://ci.appveyor.com/api/projects/status/riaj54pn4h08xy40?svg=true)](https://ci.appveyor.com/project/wjakob/pybind11) - -**pybind11** is a lightweight header-only library that exposes C++ types in Python -and vice versa, mainly to create Python bindings of existing C++ code. Its -goals and syntax are similar to the excellent -[Boost.Python](http://www.boost.org/doc/libs/1_58_0/libs/python/doc/) library -by David Abrahams: to minimize boilerplate code in traditional extension -modules by inferring type information using compile-time introspection. - -The main issue with Boost.Python—and the reason for creating such a similar -project—is Boost. Boost is an enormously large and complex suite of utility -libraries that works with almost every C++ compiler in existence. This -compatibility has its cost: arcane template tricks and workarounds are -necessary to support the oldest and buggiest of compiler specimens. Now that -C++11-compatible compilers are widely available, this heavy machinery has -become an excessively large and unnecessary dependency. - -Think of this library as a tiny self-contained version of Boost.Python with -everything stripped away that isn't relevant for binding generation. Without -comments, the core header files only require ~4K lines of code and depend on -Python (2.7 or 3.x, or PyPy2.7 >= 5.7) and the C++ standard library. This -compact implementation was possible thanks to some of the new C++11 language -features (specifically: tuples, lambda functions and variadic templates). Since -its creation, this library has grown beyond Boost.Python in many ways, leading -to dramatically simpler binding code in many common situations. - -Tutorial and reference documentation is provided at -[http://pybind11.readthedocs.org/en/master](http://pybind11.readthedocs.org/en/master). -A PDF version of the manual is available -[here](https://media.readthedocs.org/pdf/pybind11/master/pybind11.pdf). - -## Core features -pybind11 can map the following core C++ features to Python - -- Functions accepting and returning custom data structures per value, reference, or pointer -- Instance methods and static methods -- Overloaded functions -- Instance attributes and static attributes -- Arbitrary exception types -- Enumerations -- Callbacks -- Iterators and ranges -- Custom operators -- Single and multiple inheritance -- STL data structures -- Smart pointers with reference counting like ``std::shared_ptr`` -- Internal references with correct reference counting -- C++ classes with virtual (and pure virtual) methods can be extended in Python - -## Goodies -In addition to the core functionality, pybind11 provides some extra goodies: - -- Python 2.7, 3.x, and PyPy (PyPy2.7 >= 5.7) are supported with an - implementation-agnostic interface. - -- It is possible to bind C++11 lambda functions with captured variables. The - lambda capture data is stored inside the resulting Python function object. - -- pybind11 uses C++11 move constructors and move assignment operators whenever - possible to efficiently transfer custom data types. - -- It's easy to expose the internal storage of custom data types through - Pythons' buffer protocols. This is handy e.g. for fast conversion between - C++ matrix classes like Eigen and NumPy without expensive copy operations. - -- pybind11 can automatically vectorize functions so that they are transparently - applied to all entries of one or more NumPy array arguments. - -- Python's slice-based access and assignment operations can be supported with - just a few lines of code. - -- Everything is contained in just a few header files; there is no need to link - against any additional libraries. - -- Binaries are generally smaller by a factor of at least 2 compared to - equivalent bindings generated by Boost.Python. A recent pybind11 conversion - of PyRosetta, an enormous Boost.Python binding project, - [reported](http://graylab.jhu.edu/RosettaCon2016/PyRosetta-4.pdf) a binary - size reduction of **5.4x** and compile time reduction by **5.8x**. - -- Function signatures are precomputed at compile time (using ``constexpr``), - leading to smaller binaries. - -- With little extra effort, C++ types can be pickled and unpickled similar to - regular Python objects. - -## Supported compilers - -1. Clang/LLVM 3.3 or newer (for Apple Xcode's clang, this is 5.0.0 or newer) -2. GCC 4.8 or newer -3. Microsoft Visual Studio 2015 Update 3 or newer -4. Intel C++ compiler 17 or newer (16 with pybind11 v2.0 and 15 with pybind11 v2.0 and a [workaround](https://github.com/pybind/pybind11/issues/276)) -5. Cygwin/GCC (tested on 2.5.1) - -## About - -This project was created by [Wenzel Jakob](http://rgl.epfl.ch/people/wjakob). -Significant features and/or improvements to the code were contributed by -Jonas Adler, -Lori A. Burns, -Sylvain Corlay, -Trent Houliston, -Axel Huebl, -@hulucc, -Sergey Lyskov -Johan Mabille, -Tomasz Miąsko, -Dean Moldovan, -Ben Pritchard, -Jason Rhinelander, -Boris Schäling, -Pim Schellart, -Henry Schreiner, -Ivan Smirnov, and -Patrick Stewart. - -### License - -pybind11 is provided under a BSD-style license that can be found in the -``LICENSE`` file. By using, distributing, or contributing to this project, -you agree to the terms and conditions of this license. diff --git a/pybind11/include/pybind11/attr.h b/pybind11/include/pybind11/attr.h deleted file mode 100644 index 6962d6f..0000000 --- a/pybind11/include/pybind11/attr.h +++ /dev/null @@ -1,493 +0,0 @@ -/* - pybind11/attr.h: Infrastructure for processing custom - type and function attributes - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#pragma once - -#include "cast.h" - -NAMESPACE_BEGIN(PYBIND11_NAMESPACE) - -/// \addtogroup annotations -/// @{ - -/// Annotation for methods -struct is_method { handle class_; is_method(const handle &c) : class_(c) { } }; - -/// Annotation for operators -struct is_operator { }; - -/// Annotation for parent scope -struct scope { handle value; scope(const handle &s) : value(s) { } }; - -/// Annotation for documentation -struct doc { const char *value; doc(const char *value) : value(value) { } }; - -/// Annotation for function names -struct name { const char *value; name(const char *value) : value(value) { } }; - -/// Annotation indicating that a function is an overload associated with a given "sibling" -struct sibling { handle value; sibling(const handle &value) : value(value.ptr()) { } }; - -/// Annotation indicating that a class derives from another given type -template struct base { - PYBIND11_DEPRECATED("base() was deprecated in favor of specifying 'T' as a template argument to class_") - base() { } -}; - -/// Keep patient alive while nurse lives -template struct keep_alive { }; - -/// Annotation indicating that a class is involved in a multiple inheritance relationship -struct multiple_inheritance { }; - -/// Annotation which enables dynamic attributes, i.e. adds `__dict__` to a class -struct dynamic_attr { }; - -/// Annotation which enables the buffer protocol for a type -struct buffer_protocol { }; - -/// Annotation which requests that a special metaclass is created for a type -struct metaclass { - handle value; - - PYBIND11_DEPRECATED("py::metaclass() is no longer required. It's turned on by default now.") - metaclass() {} - - /// Override pybind11's default metaclass - explicit metaclass(handle value) : value(value) { } -}; - -/// Annotation that marks a class as local to the module: -struct module_local { const bool value; constexpr module_local(bool v = true) : value(v) { } }; - -/// Annotation to mark enums as an arithmetic type -struct arithmetic { }; - -/** \rst - A call policy which places one or more guard variables (``Ts...``) around the function call. - - For example, this definition: - - .. code-block:: cpp - - m.def("foo", foo, py::call_guard()); - - is equivalent to the following pseudocode: - - .. code-block:: cpp - - m.def("foo", [](args...) { - T scope_guard; - return foo(args...); // forwarded arguments - }); - \endrst */ -template struct call_guard; - -template <> struct call_guard<> { using type = detail::void_type; }; - -template -struct call_guard { - static_assert(std::is_default_constructible::value, - "The guard type must be default constructible"); - - using type = T; -}; - -template -struct call_guard { - struct type { - T guard{}; // Compose multiple guard types with left-to-right default-constructor order - typename call_guard::type next{}; - }; -}; - -/// @} annotations - -NAMESPACE_BEGIN(detail) -/* Forward declarations */ -enum op_id : int; -enum op_type : int; -struct undefined_t; -template struct op_; -inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret); - -/// Internal data structure which holds metadata about a keyword argument -struct argument_record { - const char *name; ///< Argument name - const char *descr; ///< Human-readable version of the argument value - handle value; ///< Associated Python object - bool convert : 1; ///< True if the argument is allowed to convert when loading - bool none : 1; ///< True if None is allowed when loading - - argument_record(const char *name, const char *descr, handle value, bool convert, bool none) - : name(name), descr(descr), value(value), convert(convert), none(none) { } -}; - -/// Internal data structure which holds metadata about a bound function (signature, overloads, etc.) -struct function_record { - function_record() - : is_constructor(false), is_new_style_constructor(false), is_stateless(false), - is_operator(false), has_args(false), has_kwargs(false), is_method(false) { } - - /// Function name - char *name = nullptr; /* why no C++ strings? They generate heavier code.. */ - - // User-specified documentation string - char *doc = nullptr; - - /// Human-readable version of the function signature - char *signature = nullptr; - - /// List of registered keyword arguments - std::vector args; - - /// Pointer to lambda function which converts arguments and performs the actual call - handle (*impl) (function_call &) = nullptr; - - /// Storage for the wrapped function pointer and captured data, if any - void *data[3] = { }; - - /// Pointer to custom destructor for 'data' (if needed) - void (*free_data) (function_record *ptr) = nullptr; - - /// Return value policy associated with this function - return_value_policy policy = return_value_policy::automatic; - - /// True if name == '__init__' - bool is_constructor : 1; - - /// True if this is a new-style `__init__` defined in `detail/init.h` - bool is_new_style_constructor : 1; - - /// True if this is a stateless function pointer - bool is_stateless : 1; - - /// True if this is an operator (__add__), etc. - bool is_operator : 1; - - /// True if the function has a '*args' argument - bool has_args : 1; - - /// True if the function has a '**kwargs' argument - bool has_kwargs : 1; - - /// True if this is a method - bool is_method : 1; - - /// Number of arguments (including py::args and/or py::kwargs, if present) - std::uint16_t nargs; - - /// Python method object - PyMethodDef *def = nullptr; - - /// Python handle to the parent scope (a class or a module) - handle scope; - - /// Python handle to the sibling function representing an overload chain - handle sibling; - - /// Pointer to next overload - function_record *next = nullptr; -}; - -/// Special data structure which (temporarily) holds metadata about a bound class -struct type_record { - PYBIND11_NOINLINE type_record() - : multiple_inheritance(false), dynamic_attr(false), buffer_protocol(false), - default_holder(true), module_local(false) { } - - /// Handle to the parent scope - handle scope; - - /// Name of the class - const char *name = nullptr; - - // Pointer to RTTI type_info data structure - const std::type_info *type = nullptr; - - /// How large is the underlying C++ type? - size_t type_size = 0; - - /// What is the alignment of the underlying C++ type? - size_t type_align = 0; - - /// How large is the type's holder? - size_t holder_size = 0; - - /// The global operator new can be overridden with a class-specific variant - void *(*operator_new)(size_t) = nullptr; - - /// Function pointer to class_<..>::init_instance - void (*init_instance)(instance *, const void *) = nullptr; - - /// Function pointer to class_<..>::dealloc - void (*dealloc)(detail::value_and_holder &) = nullptr; - - /// List of base classes of the newly created type - list bases; - - /// Optional docstring - const char *doc = nullptr; - - /// Custom metaclass (optional) - handle metaclass; - - /// Multiple inheritance marker - bool multiple_inheritance : 1; - - /// Does the class manage a __dict__? - bool dynamic_attr : 1; - - /// Does the class implement the buffer protocol? - bool buffer_protocol : 1; - - /// Is the default (unique_ptr) holder type used? - bool default_holder : 1; - - /// Is the class definition local to the module shared object? - bool module_local : 1; - - PYBIND11_NOINLINE void add_base(const std::type_info &base, void *(*caster)(void *)) { - auto base_info = detail::get_type_info(base, false); - if (!base_info) { - std::string tname(base.name()); - detail::clean_type_id(tname); - pybind11_fail("generic_type: type \"" + std::string(name) + - "\" referenced unknown base type \"" + tname + "\""); - } - - if (default_holder != base_info->default_holder) { - std::string tname(base.name()); - detail::clean_type_id(tname); - pybind11_fail("generic_type: type \"" + std::string(name) + "\" " + - (default_holder ? "does not have" : "has") + - " a non-default holder type while its base \"" + tname + "\" " + - (base_info->default_holder ? "does not" : "does")); - } - - bases.append((PyObject *) base_info->type); - - if (base_info->type->tp_dictoffset != 0) - dynamic_attr = true; - - if (caster) - base_info->implicit_casts.emplace_back(type, caster); - } -}; - -inline function_call::function_call(const function_record &f, handle p) : - func(f), parent(p) { - args.reserve(f.nargs); - args_convert.reserve(f.nargs); -} - -/// Tag for a new-style `__init__` defined in `detail/init.h` -struct is_new_style_constructor { }; - -/** - * Partial template specializations to process custom attributes provided to - * cpp_function_ and class_. These are either used to initialize the respective - * fields in the type_record and function_record data structures or executed at - * runtime to deal with custom call policies (e.g. keep_alive). - */ -template struct process_attribute; - -template struct process_attribute_default { - /// Default implementation: do nothing - static void init(const T &, function_record *) { } - static void init(const T &, type_record *) { } - static void precall(function_call &) { } - static void postcall(function_call &, handle) { } -}; - -/// Process an attribute specifying the function's name -template <> struct process_attribute : process_attribute_default { - static void init(const name &n, function_record *r) { r->name = const_cast(n.value); } -}; - -/// Process an attribute specifying the function's docstring -template <> struct process_attribute : process_attribute_default { - static void init(const doc &n, function_record *r) { r->doc = const_cast(n.value); } -}; - -/// Process an attribute specifying the function's docstring (provided as a C-style string) -template <> struct process_attribute : process_attribute_default { - static void init(const char *d, function_record *r) { r->doc = const_cast(d); } - static void init(const char *d, type_record *r) { r->doc = const_cast(d); } -}; -template <> struct process_attribute : process_attribute { }; - -/// Process an attribute indicating the function's return value policy -template <> struct process_attribute : process_attribute_default { - static void init(const return_value_policy &p, function_record *r) { r->policy = p; } -}; - -/// Process an attribute which indicates that this is an overloaded function associated with a given sibling -template <> struct process_attribute : process_attribute_default { - static void init(const sibling &s, function_record *r) { r->sibling = s.value; } -}; - -/// Process an attribute which indicates that this function is a method -template <> struct process_attribute : process_attribute_default { - static void init(const is_method &s, function_record *r) { r->is_method = true; r->scope = s.class_; } -}; - -/// Process an attribute which indicates the parent scope of a method -template <> struct process_attribute : process_attribute_default { - static void init(const scope &s, function_record *r) { r->scope = s.value; } -}; - -/// Process an attribute which indicates that this function is an operator -template <> struct process_attribute : process_attribute_default { - static void init(const is_operator &, function_record *r) { r->is_operator = true; } -}; - -template <> struct process_attribute : process_attribute_default { - static void init(const is_new_style_constructor &, function_record *r) { r->is_new_style_constructor = true; } -}; - -/// Process a keyword argument attribute (*without* a default value) -template <> struct process_attribute : process_attribute_default { - static void init(const arg &a, function_record *r) { - if (r->is_method && r->args.empty()) - r->args.emplace_back("self", nullptr, handle(), true /*convert*/, false /*none not allowed*/); - r->args.emplace_back(a.name, nullptr, handle(), !a.flag_noconvert, a.flag_none); - } -}; - -/// Process a keyword argument attribute (*with* a default value) -template <> struct process_attribute : process_attribute_default { - static void init(const arg_v &a, function_record *r) { - if (r->is_method && r->args.empty()) - r->args.emplace_back("self", nullptr /*descr*/, handle() /*parent*/, true /*convert*/, false /*none not allowed*/); - - if (!a.value) { -#if !defined(NDEBUG) - std::string descr("'"); - if (a.name) descr += std::string(a.name) + ": "; - descr += a.type + "'"; - if (r->is_method) { - if (r->name) - descr += " in method '" + (std::string) str(r->scope) + "." + (std::string) r->name + "'"; - else - descr += " in method of '" + (std::string) str(r->scope) + "'"; - } else if (r->name) { - descr += " in function '" + (std::string) r->name + "'"; - } - pybind11_fail("arg(): could not convert default argument " - + descr + " into a Python object (type not registered yet?)"); -#else - pybind11_fail("arg(): could not convert default argument " - "into a Python object (type not registered yet?). " - "Compile in debug mode for more information."); -#endif - } - r->args.emplace_back(a.name, a.descr, a.value.inc_ref(), !a.flag_noconvert, a.flag_none); - } -}; - -/// Process a parent class attribute. Single inheritance only (class_ itself already guarantees that) -template -struct process_attribute::value>> : process_attribute_default { - static void init(const handle &h, type_record *r) { r->bases.append(h); } -}; - -/// Process a parent class attribute (deprecated, does not support multiple inheritance) -template -struct process_attribute> : process_attribute_default> { - static void init(const base &, type_record *r) { r->add_base(typeid(T), nullptr); } -}; - -/// Process a multiple inheritance attribute -template <> -struct process_attribute : process_attribute_default { - static void init(const multiple_inheritance &, type_record *r) { r->multiple_inheritance = true; } -}; - -template <> -struct process_attribute : process_attribute_default { - static void init(const dynamic_attr &, type_record *r) { r->dynamic_attr = true; } -}; - -template <> -struct process_attribute : process_attribute_default { - static void init(const buffer_protocol &, type_record *r) { r->buffer_protocol = true; } -}; - -template <> -struct process_attribute : process_attribute_default { - static void init(const metaclass &m, type_record *r) { r->metaclass = m.value; } -}; - -template <> -struct process_attribute : process_attribute_default { - static void init(const module_local &l, type_record *r) { r->module_local = l.value; } -}; - -/// Process an 'arithmetic' attribute for enums (does nothing here) -template <> -struct process_attribute : process_attribute_default {}; - -template -struct process_attribute> : process_attribute_default> { }; - -/** - * Process a keep_alive call policy -- invokes keep_alive_impl during the - * pre-call handler if both Nurse, Patient != 0 and use the post-call handler - * otherwise - */ -template struct process_attribute> : public process_attribute_default> { - template = 0> - static void precall(function_call &call) { keep_alive_impl(Nurse, Patient, call, handle()); } - template = 0> - static void postcall(function_call &, handle) { } - template = 0> - static void precall(function_call &) { } - template = 0> - static void postcall(function_call &call, handle ret) { keep_alive_impl(Nurse, Patient, call, ret); } -}; - -/// Recursively iterate over variadic template arguments -template struct process_attributes { - static void init(const Args&... args, function_record *r) { - int unused[] = { 0, (process_attribute::type>::init(args, r), 0) ... }; - ignore_unused(unused); - } - static void init(const Args&... args, type_record *r) { - int unused[] = { 0, (process_attribute::type>::init(args, r), 0) ... }; - ignore_unused(unused); - } - static void precall(function_call &call) { - int unused[] = { 0, (process_attribute::type>::precall(call), 0) ... }; - ignore_unused(unused); - } - static void postcall(function_call &call, handle fn_ret) { - int unused[] = { 0, (process_attribute::type>::postcall(call, fn_ret), 0) ... }; - ignore_unused(unused); - } -}; - -template -using is_call_guard = is_instantiation; - -/// Extract the ``type`` from the first `call_guard` in `Extras...` (or `void_type` if none found) -template -using extract_guard_t = typename exactly_one_t, Extra...>::type; - -/// Check the number of named arguments at compile time -template ::value...), - size_t self = constexpr_sum(std::is_same::value...)> -constexpr bool expected_num_args(size_t nargs, bool has_args, bool has_kwargs) { - return named == 0 || (self + named + has_args + has_kwargs) == nargs; -} - -NAMESPACE_END(detail) -NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/pybind11/include/pybind11/buffer_info.h b/pybind11/include/pybind11/buffer_info.h deleted file mode 100644 index 9f072fa..0000000 --- a/pybind11/include/pybind11/buffer_info.h +++ /dev/null @@ -1,108 +0,0 @@ -/* - pybind11/buffer_info.h: Python buffer object interface - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#pragma once - -#include "detail/common.h" - -NAMESPACE_BEGIN(PYBIND11_NAMESPACE) - -/// Information record describing a Python buffer object -struct buffer_info { - void *ptr = nullptr; // Pointer to the underlying storage - ssize_t itemsize = 0; // Size of individual items in bytes - ssize_t size = 0; // Total number of entries - std::string format; // For homogeneous buffers, this should be set to format_descriptor::format() - ssize_t ndim = 0; // Number of dimensions - std::vector shape; // Shape of the tensor (1 entry per dimension) - std::vector strides; // Number of entries between adjacent entries (for each per dimension) - - buffer_info() { } - - buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim, - detail::any_container shape_in, detail::any_container strides_in) - : ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim), - shape(std::move(shape_in)), strides(std::move(strides_in)) { - if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size()) - pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length"); - for (size_t i = 0; i < (size_t) ndim; ++i) - size *= shape[i]; - } - - template - buffer_info(T *ptr, detail::any_container shape_in, detail::any_container strides_in) - : buffer_info(private_ctr_tag(), ptr, sizeof(T), format_descriptor::format(), static_cast(shape_in->size()), std::move(shape_in), std::move(strides_in)) { } - - buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t size) - : buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) { } - - template - buffer_info(T *ptr, ssize_t size) - : buffer_info(ptr, sizeof(T), format_descriptor::format(), size) { } - - explicit buffer_info(Py_buffer *view, bool ownview = true) - : buffer_info(view->buf, view->itemsize, view->format, view->ndim, - {view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}) { - this->view = view; - this->ownview = ownview; - } - - buffer_info(const buffer_info &) = delete; - buffer_info& operator=(const buffer_info &) = delete; - - buffer_info(buffer_info &&other) { - (*this) = std::move(other); - } - - buffer_info& operator=(buffer_info &&rhs) { - ptr = rhs.ptr; - itemsize = rhs.itemsize; - size = rhs.size; - format = std::move(rhs.format); - ndim = rhs.ndim; - shape = std::move(rhs.shape); - strides = std::move(rhs.strides); - std::swap(view, rhs.view); - std::swap(ownview, rhs.ownview); - return *this; - } - - ~buffer_info() { - if (view && ownview) { PyBuffer_Release(view); delete view; } - } - -private: - struct private_ctr_tag { }; - - buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim, - detail::any_container &&shape_in, detail::any_container &&strides_in) - : buffer_info(ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in)) { } - - Py_buffer *view = nullptr; - bool ownview = false; -}; - -NAMESPACE_BEGIN(detail) - -template struct compare_buffer_info { - static bool compare(const buffer_info& b) { - return b.format == format_descriptor::format() && b.itemsize == (ssize_t) sizeof(T); - } -}; - -template struct compare_buffer_info::value>> { - static bool compare(const buffer_info& b) { - return (size_t) b.itemsize == sizeof(T) && (b.format == format_descriptor::value || - ((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned::value ? "L" : "l")) || - ((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned::value ? "N" : "n"))); - } -}; - -NAMESPACE_END(detail) -NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/pybind11/include/pybind11/cast.h b/pybind11/include/pybind11/cast.h deleted file mode 100644 index 605acb3..0000000 --- a/pybind11/include/pybind11/cast.h +++ /dev/null @@ -1,2132 +0,0 @@ -/* - pybind11/cast.h: Partial template specializations to cast between - C++ and Python types - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#pragma once - -#include "pytypes.h" -#include "detail/typeid.h" -#include "detail/descr.h" -#include "detail/internals.h" -#include -#include -#include -#include - -#if defined(PYBIND11_CPP17) -# if defined(__has_include) -# if __has_include() -# define PYBIND11_HAS_STRING_VIEW -# endif -# elif defined(_MSC_VER) -# define PYBIND11_HAS_STRING_VIEW -# endif -#endif -#ifdef PYBIND11_HAS_STRING_VIEW -#include -#endif - -NAMESPACE_BEGIN(PYBIND11_NAMESPACE) -NAMESPACE_BEGIN(detail) - -/// A life support system for temporary objects created by `type_caster::load()`. -/// Adding a patient will keep it alive up until the enclosing function returns. -class loader_life_support { -public: - /// A new patient frame is created when a function is entered - loader_life_support() { - get_internals().loader_patient_stack.push_back(nullptr); - } - - /// ... and destroyed after it returns - ~loader_life_support() { - auto &stack = get_internals().loader_patient_stack; - if (stack.empty()) - pybind11_fail("loader_life_support: internal error"); - - auto ptr = stack.back(); - stack.pop_back(); - Py_CLEAR(ptr); - - // A heuristic to reduce the stack's capacity (e.g. after long recursive calls) - if (stack.capacity() > 16 && stack.size() != 0 && stack.capacity() / stack.size() > 2) - stack.shrink_to_fit(); - } - - /// This can only be used inside a pybind11-bound function, either by `argument_loader` - /// at argument preparation time or by `py::cast()` at execution time. - PYBIND11_NOINLINE static void add_patient(handle h) { - auto &stack = get_internals().loader_patient_stack; - if (stack.empty()) - throw cast_error("When called outside a bound function, py::cast() cannot " - "do Python -> C++ conversions which require the creation " - "of temporary values"); - - auto &list_ptr = stack.back(); - if (list_ptr == nullptr) { - list_ptr = PyList_New(1); - if (!list_ptr) - pybind11_fail("loader_life_support: error allocating list"); - PyList_SET_ITEM(list_ptr, 0, h.inc_ref().ptr()); - } else { - auto result = PyList_Append(list_ptr, h.ptr()); - if (result == -1) - pybind11_fail("loader_life_support: error adding patient"); - } - } -}; - -// Gets the cache entry for the given type, creating it if necessary. The return value is the pair -// returned by emplace, i.e. an iterator for the entry and a bool set to `true` if the entry was -// just created. -inline std::pair all_type_info_get_cache(PyTypeObject *type); - -// Populates a just-created cache entry. -PYBIND11_NOINLINE inline void all_type_info_populate(PyTypeObject *t, std::vector &bases) { - std::vector check; - for (handle parent : reinterpret_borrow(t->tp_bases)) - check.push_back((PyTypeObject *) parent.ptr()); - - auto const &type_dict = get_internals().registered_types_py; - for (size_t i = 0; i < check.size(); i++) { - auto type = check[i]; - // Ignore Python2 old-style class super type: - if (!PyType_Check((PyObject *) type)) continue; - - // Check `type` in the current set of registered python types: - auto it = type_dict.find(type); - if (it != type_dict.end()) { - // We found a cache entry for it, so it's either pybind-registered or has pre-computed - // pybind bases, but we have to make sure we haven't already seen the type(s) before: we - // want to follow Python/virtual C++ rules that there should only be one instance of a - // common base. - for (auto *tinfo : it->second) { - // NB: Could use a second set here, rather than doing a linear search, but since - // having a large number of immediate pybind11-registered types seems fairly - // unlikely, that probably isn't worthwhile. - bool found = false; - for (auto *known : bases) { - if (known == tinfo) { found = true; break; } - } - if (!found) bases.push_back(tinfo); - } - } - else if (type->tp_bases) { - // It's some python type, so keep follow its bases classes to look for one or more - // registered types - if (i + 1 == check.size()) { - // When we're at the end, we can pop off the current element to avoid growing - // `check` when adding just one base (which is typical--i.e. when there is no - // multiple inheritance) - check.pop_back(); - i--; - } - for (handle parent : reinterpret_borrow(type->tp_bases)) - check.push_back((PyTypeObject *) parent.ptr()); - } - } -} - -/** - * Extracts vector of type_info pointers of pybind-registered roots of the given Python type. Will - * be just 1 pybind type for the Python type of a pybind-registered class, or for any Python-side - * derived class that uses single inheritance. Will contain as many types as required for a Python - * class that uses multiple inheritance to inherit (directly or indirectly) from multiple - * pybind-registered classes. Will be empty if neither the type nor any base classes are - * pybind-registered. - * - * The value is cached for the lifetime of the Python type. - */ -inline const std::vector &all_type_info(PyTypeObject *type) { - auto ins = all_type_info_get_cache(type); - if (ins.second) - // New cache entry: populate it - all_type_info_populate(type, ins.first->second); - - return ins.first->second; -} - -/** - * Gets a single pybind11 type info for a python type. Returns nullptr if neither the type nor any - * ancestors are pybind11-registered. Throws an exception if there are multiple bases--use - * `all_type_info` instead if you want to support multiple bases. - */ -PYBIND11_NOINLINE inline detail::type_info* get_type_info(PyTypeObject *type) { - auto &bases = all_type_info(type); - if (bases.size() == 0) - return nullptr; - if (bases.size() > 1) - pybind11_fail("pybind11::detail::get_type_info: type has multiple pybind11-registered bases"); - return bases.front(); -} - -inline detail::type_info *get_local_type_info(const std::type_index &tp) { - auto &locals = registered_local_types_cpp(); - auto it = locals.find(tp); - if (it != locals.end()) - return it->second; - return nullptr; -} - -inline detail::type_info *get_global_type_info(const std::type_index &tp) { - auto &types = get_internals().registered_types_cpp; - auto it = types.find(tp); - if (it != types.end()) - return it->second; - return nullptr; -} - -/// Return the type info for a given C++ type; on lookup failure can either throw or return nullptr. -PYBIND11_NOINLINE inline detail::type_info *get_type_info(const std::type_index &tp, - bool throw_if_missing = false) { - if (auto ltype = get_local_type_info(tp)) - return ltype; - if (auto gtype = get_global_type_info(tp)) - return gtype; - - if (throw_if_missing) { - std::string tname = tp.name(); - detail::clean_type_id(tname); - pybind11_fail("pybind11::detail::get_type_info: unable to find type info for \"" + tname + "\""); - } - return nullptr; -} - -PYBIND11_NOINLINE inline handle get_type_handle(const std::type_info &tp, bool throw_if_missing) { - detail::type_info *type_info = get_type_info(tp, throw_if_missing); - return handle(type_info ? ((PyObject *) type_info->type) : nullptr); -} - -struct value_and_holder { - instance *inst = nullptr; - size_t index = 0u; - const detail::type_info *type = nullptr; - void **vh = nullptr; - - // Main constructor for a found value/holder: - value_and_holder(instance *i, const detail::type_info *type, size_t vpos, size_t index) : - inst{i}, index{index}, type{type}, - vh{inst->simple_layout ? inst->simple_value_holder : &inst->nonsimple.values_and_holders[vpos]} - {} - - // Default constructor (used to signal a value-and-holder not found by get_value_and_holder()) - value_and_holder() {} - - // Used for past-the-end iterator - value_and_holder(size_t index) : index{index} {} - - template V *&value_ptr() const { - return reinterpret_cast(vh[0]); - } - // True if this `value_and_holder` has a non-null value pointer - explicit operator bool() const { return value_ptr(); } - - template H &holder() const { - return reinterpret_cast(vh[1]); - } - bool holder_constructed() const { - return inst->simple_layout - ? inst->simple_holder_constructed - : inst->nonsimple.status[index] & instance::status_holder_constructed; - } - void set_holder_constructed(bool v = true) { - if (inst->simple_layout) - inst->simple_holder_constructed = v; - else if (v) - inst->nonsimple.status[index] |= instance::status_holder_constructed; - else - inst->nonsimple.status[index] &= (uint8_t) ~instance::status_holder_constructed; - } - bool instance_registered() const { - return inst->simple_layout - ? inst->simple_instance_registered - : inst->nonsimple.status[index] & instance::status_instance_registered; - } - void set_instance_registered(bool v = true) { - if (inst->simple_layout) - inst->simple_instance_registered = v; - else if (v) - inst->nonsimple.status[index] |= instance::status_instance_registered; - else - inst->nonsimple.status[index] &= (uint8_t) ~instance::status_instance_registered; - } -}; - -// Container for accessing and iterating over an instance's values/holders -struct values_and_holders { -private: - instance *inst; - using type_vec = std::vector; - const type_vec &tinfo; - -public: - values_and_holders(instance *inst) : inst{inst}, tinfo(all_type_info(Py_TYPE(inst))) {} - - struct iterator { - private: - instance *inst = nullptr; - const type_vec *types = nullptr; - value_and_holder curr; - friend struct values_and_holders; - iterator(instance *inst, const type_vec *tinfo) - : inst{inst}, types{tinfo}, - curr(inst /* instance */, - types->empty() ? nullptr : (*types)[0] /* type info */, - 0, /* vpos: (non-simple types only): the first vptr comes first */ - 0 /* index */) - {} - // Past-the-end iterator: - iterator(size_t end) : curr(end) {} - public: - bool operator==(const iterator &other) { return curr.index == other.curr.index; } - bool operator!=(const iterator &other) { return curr.index != other.curr.index; } - iterator &operator++() { - if (!inst->simple_layout) - curr.vh += 1 + (*types)[curr.index]->holder_size_in_ptrs; - ++curr.index; - curr.type = curr.index < types->size() ? (*types)[curr.index] : nullptr; - return *this; - } - value_and_holder &operator*() { return curr; } - value_and_holder *operator->() { return &curr; } - }; - - iterator begin() { return iterator(inst, &tinfo); } - iterator end() { return iterator(tinfo.size()); } - - iterator find(const type_info *find_type) { - auto it = begin(), endit = end(); - while (it != endit && it->type != find_type) ++it; - return it; - } - - size_t size() { return tinfo.size(); } -}; - -/** - * Extracts C++ value and holder pointer references from an instance (which may contain multiple - * values/holders for python-side multiple inheritance) that match the given type. Throws an error - * if the given type (or ValueType, if omitted) is not a pybind11 base of the given instance. If - * `find_type` is omitted (or explicitly specified as nullptr) the first value/holder are returned, - * regardless of type (and the resulting .type will be nullptr). - * - * The returned object should be short-lived: in particular, it must not outlive the called-upon - * instance. - */ -PYBIND11_NOINLINE inline value_and_holder instance::get_value_and_holder(const type_info *find_type /*= nullptr default in common.h*/, bool throw_if_missing /*= true in common.h*/) { - // Optimize common case: - if (!find_type || Py_TYPE(this) == find_type->type) - return value_and_holder(this, find_type, 0, 0); - - detail::values_and_holders vhs(this); - auto it = vhs.find(find_type); - if (it != vhs.end()) - return *it; - - if (!throw_if_missing) - return value_and_holder(); - -#if defined(NDEBUG) - pybind11_fail("pybind11::detail::instance::get_value_and_holder: " - "type is not a pybind11 base of the given instance " - "(compile in debug mode for type details)"); -#else - pybind11_fail("pybind11::detail::instance::get_value_and_holder: `" + - std::string(find_type->type->tp_name) + "' is not a pybind11 base of the given `" + - std::string(Py_TYPE(this)->tp_name) + "' instance"); -#endif -} - -PYBIND11_NOINLINE inline void instance::allocate_layout() { - auto &tinfo = all_type_info(Py_TYPE(this)); - - const size_t n_types = tinfo.size(); - - if (n_types == 0) - pybind11_fail("instance allocation failed: new instance has no pybind11-registered base types"); - - simple_layout = - n_types == 1 && tinfo.front()->holder_size_in_ptrs <= instance_simple_holder_in_ptrs(); - - // Simple path: no python-side multiple inheritance, and a small-enough holder - if (simple_layout) { - simple_value_holder[0] = nullptr; - simple_holder_constructed = false; - simple_instance_registered = false; - } - else { // multiple base types or a too-large holder - // Allocate space to hold: [v1*][h1][v2*][h2]...[bb...] where [vN*] is a value pointer, - // [hN] is the (uninitialized) holder instance for value N, and [bb...] is a set of bool - // values that tracks whether each associated holder has been initialized. Each [block] is - // padded, if necessary, to an integer multiple of sizeof(void *). - size_t space = 0; - for (auto t : tinfo) { - space += 1; // value pointer - space += t->holder_size_in_ptrs; // holder instance - } - size_t flags_at = space; - space += size_in_ptrs(n_types); // status bytes (holder_constructed and instance_registered) - - // Allocate space for flags, values, and holders, and initialize it to 0 (flags and values, - // in particular, need to be 0). Use Python's memory allocation functions: in Python 3.6 - // they default to using pymalloc, which is designed to be efficient for small allocations - // like the one we're doing here; in earlier versions (and for larger allocations) they are - // just wrappers around malloc. -#if PY_VERSION_HEX >= 0x03050000 - nonsimple.values_and_holders = (void **) PyMem_Calloc(space, sizeof(void *)); - if (!nonsimple.values_and_holders) throw std::bad_alloc(); -#else - nonsimple.values_and_holders = (void **) PyMem_New(void *, space); - if (!nonsimple.values_and_holders) throw std::bad_alloc(); - std::memset(nonsimple.values_and_holders, 0, space * sizeof(void *)); -#endif - nonsimple.status = reinterpret_cast(&nonsimple.values_and_holders[flags_at]); - } - owned = true; -} - -PYBIND11_NOINLINE inline void instance::deallocate_layout() { - if (!simple_layout) - PyMem_Free(nonsimple.values_and_holders); -} - -PYBIND11_NOINLINE inline bool isinstance_generic(handle obj, const std::type_info &tp) { - handle type = detail::get_type_handle(tp, false); - if (!type) - return false; - return isinstance(obj, type); -} - -PYBIND11_NOINLINE inline std::string error_string() { - if (!PyErr_Occurred()) { - PyErr_SetString(PyExc_RuntimeError, "Unknown internal error occurred"); - return "Unknown internal error occurred"; - } - - error_scope scope; // Preserve error state - - std::string errorString; - if (scope.type) { - errorString += handle(scope.type).attr("__name__").cast(); - errorString += ": "; - } - if (scope.value) - errorString += (std::string) str(scope.value); - - PyErr_NormalizeException(&scope.type, &scope.value, &scope.trace); - -#if PY_MAJOR_VERSION >= 3 - if (scope.trace != nullptr) - PyException_SetTraceback(scope.value, scope.trace); -#endif - -#if !defined(PYPY_VERSION) - if (scope.trace) { - PyTracebackObject *trace = (PyTracebackObject *) scope.trace; - - /* Get the deepest trace possible */ - while (trace->tb_next) - trace = trace->tb_next; - - PyFrameObject *frame = trace->tb_frame; - errorString += "\n\nAt:\n"; - while (frame) { - int lineno = PyFrame_GetLineNumber(frame); - errorString += - " " + handle(frame->f_code->co_filename).cast() + - "(" + std::to_string(lineno) + "): " + - handle(frame->f_code->co_name).cast() + "\n"; - frame = frame->f_back; - } - } -#endif - - return errorString; -} - -PYBIND11_NOINLINE inline handle get_object_handle(const void *ptr, const detail::type_info *type ) { - auto &instances = get_internals().registered_instances; - auto range = instances.equal_range(ptr); - for (auto it = range.first; it != range.second; ++it) { - for (auto vh : values_and_holders(it->second)) { - if (vh.type == type) - return handle((PyObject *) it->second); - } - } - return handle(); -} - -inline PyThreadState *get_thread_state_unchecked() { -#if defined(PYPY_VERSION) - return PyThreadState_GET(); -#elif PY_VERSION_HEX < 0x03000000 - return _PyThreadState_Current; -#elif PY_VERSION_HEX < 0x03050000 - return (PyThreadState*) _Py_atomic_load_relaxed(&_PyThreadState_Current); -#elif PY_VERSION_HEX < 0x03050200 - return (PyThreadState*) _PyThreadState_Current.value; -#else - return _PyThreadState_UncheckedGet(); -#endif -} - -// Forward declarations -inline void keep_alive_impl(handle nurse, handle patient); -inline PyObject *make_new_instance(PyTypeObject *type); - -class type_caster_generic { -public: - PYBIND11_NOINLINE type_caster_generic(const std::type_info &type_info) - : typeinfo(get_type_info(type_info)), cpptype(&type_info) { } - - type_caster_generic(const type_info *typeinfo) - : typeinfo(typeinfo), cpptype(typeinfo ? typeinfo->cpptype : nullptr) { } - - bool load(handle src, bool convert) { - return load_impl(src, convert); - } - - PYBIND11_NOINLINE static handle cast(const void *_src, return_value_policy policy, handle parent, - const detail::type_info *tinfo, - void *(*copy_constructor)(const void *), - void *(*move_constructor)(const void *), - const void *existing_holder = nullptr) { - if (!tinfo) // no type info: error will be set already - return handle(); - - void *src = const_cast(_src); - if (src == nullptr) - return none().release(); - - auto it_instances = get_internals().registered_instances.equal_range(src); - for (auto it_i = it_instances.first; it_i != it_instances.second; ++it_i) { - for (auto instance_type : detail::all_type_info(Py_TYPE(it_i->second))) { - if (instance_type && same_type(*instance_type->cpptype, *tinfo->cpptype)) - return handle((PyObject *) it_i->second).inc_ref(); - } - } - - auto inst = reinterpret_steal(make_new_instance(tinfo->type)); - auto wrapper = reinterpret_cast(inst.ptr()); - wrapper->owned = false; - void *&valueptr = values_and_holders(wrapper).begin()->value_ptr(); - - switch (policy) { - case return_value_policy::automatic: - case return_value_policy::take_ownership: - valueptr = src; - wrapper->owned = true; - break; - - case return_value_policy::automatic_reference: - case return_value_policy::reference: - valueptr = src; - wrapper->owned = false; - break; - - case return_value_policy::copy: - if (copy_constructor) - valueptr = copy_constructor(src); - else - throw cast_error("return_value_policy = copy, but the " - "object is non-copyable!"); - wrapper->owned = true; - break; - - case return_value_policy::move: - if (move_constructor) - valueptr = move_constructor(src); - else if (copy_constructor) - valueptr = copy_constructor(src); - else - throw cast_error("return_value_policy = move, but the " - "object is neither movable nor copyable!"); - wrapper->owned = true; - break; - - case return_value_policy::reference_internal: - valueptr = src; - wrapper->owned = false; - keep_alive_impl(inst, parent); - break; - - default: - throw cast_error("unhandled return_value_policy: should not happen!"); - } - - tinfo->init_instance(wrapper, existing_holder); - - return inst.release(); - } - - // Base methods for generic caster; there are overridden in copyable_holder_caster - void load_value(value_and_holder &&v_h) { - auto *&vptr = v_h.value_ptr(); - // Lazy allocation for unallocated values: - if (vptr == nullptr) { - auto *type = v_h.type ? v_h.type : typeinfo; - if (type->operator_new) { - vptr = type->operator_new(type->type_size); - } else { - #if defined(PYBIND11_CPP17) - if (type->type_align > __STDCPP_DEFAULT_NEW_ALIGNMENT__) - vptr = ::operator new(type->type_size, - (std::align_val_t) type->type_align); - else - #endif - vptr = ::operator new(type->type_size); - } - } - value = vptr; - } - bool try_implicit_casts(handle src, bool convert) { - for (auto &cast : typeinfo->implicit_casts) { - type_caster_generic sub_caster(*cast.first); - if (sub_caster.load(src, convert)) { - value = cast.second(sub_caster.value); - return true; - } - } - return false; - } - bool try_direct_conversions(handle src) { - for (auto &converter : *typeinfo->direct_conversions) { - if (converter(src.ptr(), value)) - return true; - } - return false; - } - void check_holder_compat() {} - - PYBIND11_NOINLINE static void *local_load(PyObject *src, const type_info *ti) { - auto caster = type_caster_generic(ti); - if (caster.load(src, false)) - return caster.value; - return nullptr; - } - - /// Try to load with foreign typeinfo, if available. Used when there is no - /// native typeinfo, or when the native one wasn't able to produce a value. - PYBIND11_NOINLINE bool try_load_foreign_module_local(handle src) { - constexpr auto *local_key = PYBIND11_MODULE_LOCAL_ID; - const auto pytype = src.get_type(); - if (!hasattr(pytype, local_key)) - return false; - - type_info *foreign_typeinfo = reinterpret_borrow(getattr(pytype, local_key)); - // Only consider this foreign loader if actually foreign and is a loader of the correct cpp type - if (foreign_typeinfo->module_local_load == &local_load - || (cpptype && !same_type(*cpptype, *foreign_typeinfo->cpptype))) - return false; - - if (auto result = foreign_typeinfo->module_local_load(src.ptr(), foreign_typeinfo)) { - value = result; - return true; - } - return false; - } - - // Implementation of `load`; this takes the type of `this` so that it can dispatch the relevant - // bits of code between here and copyable_holder_caster where the two classes need different - // logic (without having to resort to virtual inheritance). - template - PYBIND11_NOINLINE bool load_impl(handle src, bool convert) { - if (!src) return false; - if (!typeinfo) return try_load_foreign_module_local(src); - if (src.is_none()) { - // Defer accepting None to other overloads (if we aren't in convert mode): - if (!convert) return false; - value = nullptr; - return true; - } - - auto &this_ = static_cast(*this); - this_.check_holder_compat(); - - PyTypeObject *srctype = Py_TYPE(src.ptr()); - - // Case 1: If src is an exact type match for the target type then we can reinterpret_cast - // the instance's value pointer to the target type: - if (srctype == typeinfo->type) { - this_.load_value(reinterpret_cast(src.ptr())->get_value_and_holder()); - return true; - } - // Case 2: We have a derived class - else if (PyType_IsSubtype(srctype, typeinfo->type)) { - auto &bases = all_type_info(srctype); - bool no_cpp_mi = typeinfo->simple_type; - - // Case 2a: the python type is a Python-inherited derived class that inherits from just - // one simple (no MI) pybind11 class, or is an exact match, so the C++ instance is of - // the right type and we can use reinterpret_cast. - // (This is essentially the same as case 2b, but because not using multiple inheritance - // is extremely common, we handle it specially to avoid the loop iterator and type - // pointer lookup overhead) - if (bases.size() == 1 && (no_cpp_mi || bases.front()->type == typeinfo->type)) { - this_.load_value(reinterpret_cast(src.ptr())->get_value_and_holder()); - return true; - } - // Case 2b: the python type inherits from multiple C++ bases. Check the bases to see if - // we can find an exact match (or, for a simple C++ type, an inherited match); if so, we - // can safely reinterpret_cast to the relevant pointer. - else if (bases.size() > 1) { - for (auto base : bases) { - if (no_cpp_mi ? PyType_IsSubtype(base->type, typeinfo->type) : base->type == typeinfo->type) { - this_.load_value(reinterpret_cast(src.ptr())->get_value_and_holder(base)); - return true; - } - } - } - - // Case 2c: C++ multiple inheritance is involved and we couldn't find an exact type match - // in the registered bases, above, so try implicit casting (needed for proper C++ casting - // when MI is involved). - if (this_.try_implicit_casts(src, convert)) - return true; - } - - // Perform an implicit conversion - if (convert) { - for (auto &converter : typeinfo->implicit_conversions) { - auto temp = reinterpret_steal(converter(src.ptr(), typeinfo->type)); - if (load_impl(temp, false)) { - loader_life_support::add_patient(temp); - return true; - } - } - if (this_.try_direct_conversions(src)) - return true; - } - - // Failed to match local typeinfo. Try again with global. - if (typeinfo->module_local) { - if (auto gtype = get_global_type_info(*typeinfo->cpptype)) { - typeinfo = gtype; - return load(src, false); - } - } - - // Global typeinfo has precedence over foreign module_local - return try_load_foreign_module_local(src); - } - - - // Called to do type lookup and wrap the pointer and type in a pair when a dynamic_cast - // isn't needed or can't be used. If the type is unknown, sets the error and returns a pair - // with .second = nullptr. (p.first = nullptr is not an error: it becomes None). - PYBIND11_NOINLINE static std::pair src_and_type( - const void *src, const std::type_info &cast_type, const std::type_info *rtti_type = nullptr) { - if (auto *tpi = get_type_info(cast_type)) - return {src, const_cast(tpi)}; - - // Not found, set error: - std::string tname = rtti_type ? rtti_type->name() : cast_type.name(); - detail::clean_type_id(tname); - std::string msg = "Unregistered type : " + tname; - PyErr_SetString(PyExc_TypeError, msg.c_str()); - return {nullptr, nullptr}; - } - - const type_info *typeinfo = nullptr; - const std::type_info *cpptype = nullptr; - void *value = nullptr; -}; - -/** - * Determine suitable casting operator for pointer-or-lvalue-casting type casters. The type caster - * needs to provide `operator T*()` and `operator T&()` operators. - * - * If the type supports moving the value away via an `operator T&&() &&` method, it should use - * `movable_cast_op_type` instead. - */ -template -using cast_op_type = - conditional_t>::value, - typename std::add_pointer>::type, - typename std::add_lvalue_reference>::type>; - -/** - * Determine suitable casting operator for a type caster with a movable value. Such a type caster - * needs to provide `operator T*()`, `operator T&()`, and `operator T&&() &&`. The latter will be - * called in appropriate contexts where the value can be moved rather than copied. - * - * These operator are automatically provided when using the PYBIND11_TYPE_CASTER macro. - */ -template -using movable_cast_op_type = - conditional_t::type>::value, - typename std::add_pointer>::type, - conditional_t::value, - typename std::add_rvalue_reference>::type, - typename std::add_lvalue_reference>::type>>; - -// std::is_copy_constructible isn't quite enough: it lets std::vector (and similar) through when -// T is non-copyable, but code containing such a copy constructor fails to actually compile. -template struct is_copy_constructible : std::is_copy_constructible {}; - -// Specialization for types that appear to be copy constructible but also look like stl containers -// (we specifically check for: has `value_type` and `reference` with `reference = value_type&`): if -// so, copy constructability depends on whether the value_type is copy constructible. -template struct is_copy_constructible, - std::is_same, - // Avoid infinite recursion - negation> - >::value>> : is_copy_constructible {}; - -#if !defined(PYBIND11_CPP17) -// Likewise for std::pair before C++17 (which mandates that the copy constructor not exist when the -// two types aren't themselves copy constructible). -template struct is_copy_constructible> - : all_of, is_copy_constructible> {}; -#endif - -NAMESPACE_END(detail) - -// polymorphic_type_hook::get(src, tinfo) determines whether the object pointed -// to by `src` actually is an instance of some class derived from `itype`. -// If so, it sets `tinfo` to point to the std::type_info representing that derived -// type, and returns a pointer to the start of the most-derived object of that type -// (in which `src` is a subobject; this will be the same address as `src` in most -// single inheritance cases). If not, or if `src` is nullptr, it simply returns `src` -// and leaves `tinfo` at its default value of nullptr. -// -// The default polymorphic_type_hook just returns src. A specialization for polymorphic -// types determines the runtime type of the passed object and adjusts the this-pointer -// appropriately via dynamic_cast. This is what enables a C++ Animal* to appear -// to Python as a Dog (if Dog inherits from Animal, Animal is polymorphic, Dog is -// registered with pybind11, and this Animal is in fact a Dog). -// -// You may specialize polymorphic_type_hook yourself for types that want to appear -// polymorphic to Python but do not use C++ RTTI. (This is a not uncommon pattern -// in performance-sensitive applications, used most notably in LLVM.) -template -struct polymorphic_type_hook -{ - static const void *get(const itype *src, const std::type_info*&) { return src; } -}; -template -struct polymorphic_type_hook::value>> -{ - static const void *get(const itype *src, const std::type_info*& type) { - type = src ? &typeid(*src) : nullptr; - return dynamic_cast(src); - } -}; - -NAMESPACE_BEGIN(detail) - -/// Generic type caster for objects stored on the heap -template class type_caster_base : public type_caster_generic { - using itype = intrinsic_t; - -public: - static constexpr auto name = _(); - - type_caster_base() : type_caster_base(typeid(type)) { } - explicit type_caster_base(const std::type_info &info) : type_caster_generic(info) { } - - static handle cast(const itype &src, return_value_policy policy, handle parent) { - if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference) - policy = return_value_policy::copy; - return cast(&src, policy, parent); - } - - static handle cast(itype &&src, return_value_policy, handle parent) { - return cast(&src, return_value_policy::move, parent); - } - - // Returns a (pointer, type_info) pair taking care of necessary type lookup for a - // polymorphic type (using RTTI by default, but can be overridden by specializing - // polymorphic_type_hook). If the instance isn't derived, returns the base version. - static std::pair src_and_type(const itype *src) { - auto &cast_type = typeid(itype); - const std::type_info *instance_type = nullptr; - const void *vsrc = polymorphic_type_hook::get(src, instance_type); - if (instance_type && !same_type(cast_type, *instance_type)) { - // This is a base pointer to a derived type. If the derived type is registered - // with pybind11, we want to make the full derived object available. - // In the typical case where itype is polymorphic, we get the correct - // derived pointer (which may be != base pointer) by a dynamic_cast to - // most derived type. If itype is not polymorphic, we won't get here - // except via a user-provided specialization of polymorphic_type_hook, - // and the user has promised that no this-pointer adjustment is - // required in that case, so it's OK to use static_cast. - if (const auto *tpi = get_type_info(*instance_type)) - return {vsrc, tpi}; - } - // Otherwise we have either a nullptr, an `itype` pointer, or an unknown derived pointer, so - // don't do a cast - return type_caster_generic::src_and_type(src, cast_type, instance_type); - } - - static handle cast(const itype *src, return_value_policy policy, handle parent) { - auto st = src_and_type(src); - return type_caster_generic::cast( - st.first, policy, parent, st.second, - make_copy_constructor(src), make_move_constructor(src)); - } - - static handle cast_holder(const itype *src, const void *holder) { - auto st = src_and_type(src); - return type_caster_generic::cast( - st.first, return_value_policy::take_ownership, {}, st.second, - nullptr, nullptr, holder); - } - - template using cast_op_type = detail::cast_op_type; - - operator itype*() { return (type *) value; } - operator itype&() { if (!value) throw reference_cast_error(); return *((itype *) value); } - -protected: - using Constructor = void *(*)(const void *); - - /* Only enabled when the types are {copy,move}-constructible *and* when the type - does not have a private operator new implementation. */ - template ::value>> - static auto make_copy_constructor(const T *x) -> decltype(new T(*x), Constructor{}) { - return [](const void *arg) -> void * { - return new T(*reinterpret_cast(arg)); - }; - } - - template ::value>> - static auto make_move_constructor(const T *x) -> decltype(new T(std::move(*const_cast(x))), Constructor{}) { - return [](const void *arg) -> void * { - return new T(std::move(*const_cast(reinterpret_cast(arg)))); - }; - } - - static Constructor make_copy_constructor(...) { return nullptr; } - static Constructor make_move_constructor(...) { return nullptr; } -}; - -template class type_caster : public type_caster_base { }; -template using make_caster = type_caster>; - -// Shortcut for calling a caster's `cast_op_type` cast operator for casting a type_caster to a T -template typename make_caster::template cast_op_type cast_op(make_caster &caster) { - return caster.operator typename make_caster::template cast_op_type(); -} -template typename make_caster::template cast_op_type::type> -cast_op(make_caster &&caster) { - return std::move(caster).operator - typename make_caster::template cast_op_type::type>(); -} - -template class type_caster> { -private: - using caster_t = make_caster; - caster_t subcaster; - using subcaster_cast_op_type = typename caster_t::template cast_op_type; - static_assert(std::is_same::type &, subcaster_cast_op_type>::value, - "std::reference_wrapper caster requires T to have a caster with an `T &` operator"); -public: - bool load(handle src, bool convert) { return subcaster.load(src, convert); } - static constexpr auto name = caster_t::name; - static handle cast(const std::reference_wrapper &src, return_value_policy policy, handle parent) { - // It is definitely wrong to take ownership of this pointer, so mask that rvp - if (policy == return_value_policy::take_ownership || policy == return_value_policy::automatic) - policy = return_value_policy::automatic_reference; - return caster_t::cast(&src.get(), policy, parent); - } - template using cast_op_type = std::reference_wrapper; - operator std::reference_wrapper() { return subcaster.operator subcaster_cast_op_type&(); } -}; - -#define PYBIND11_TYPE_CASTER(type, py_name) \ - protected: \ - type value; \ - public: \ - static constexpr auto name = py_name; \ - template >::value, int> = 0> \ - static handle cast(T_ *src, return_value_policy policy, handle parent) { \ - if (!src) return none().release(); \ - if (policy == return_value_policy::take_ownership) { \ - auto h = cast(std::move(*src), policy, parent); delete src; return h; \ - } else { \ - return cast(*src, policy, parent); \ - } \ - } \ - operator type*() { return &value; } \ - operator type&() { return value; } \ - operator type&&() && { return std::move(value); } \ - template using cast_op_type = pybind11::detail::movable_cast_op_type - - -template using is_std_char_type = any_of< - std::is_same, /* std::string */ - std::is_same, /* std::u16string */ - std::is_same, /* std::u32string */ - std::is_same /* std::wstring */ ->; - -template -struct type_caster::value && !is_std_char_type::value>> { - using _py_type_0 = conditional_t; - using _py_type_1 = conditional_t::value, _py_type_0, typename std::make_unsigned<_py_type_0>::type>; - using py_type = conditional_t::value, double, _py_type_1>; -public: - - bool load(handle src, bool convert) { - py_type py_value; - - if (!src) - return false; - - if (std::is_floating_point::value) { - if (convert || PyFloat_Check(src.ptr())) - py_value = (py_type) PyFloat_AsDouble(src.ptr()); - else - return false; - } else if (PyFloat_Check(src.ptr())) { - return false; - } else if (std::is_unsigned::value) { - py_value = as_unsigned(src.ptr()); - } else { // signed integer: - py_value = sizeof(T) <= sizeof(long) - ? (py_type) PyLong_AsLong(src.ptr()) - : (py_type) PYBIND11_LONG_AS_LONGLONG(src.ptr()); - } - - bool py_err = py_value == (py_type) -1 && PyErr_Occurred(); - - // Protect std::numeric_limits::min/max with parentheses - if (py_err || (std::is_integral::value && sizeof(py_type) != sizeof(T) && - (py_value < (py_type) (std::numeric_limits::min)() || - py_value > (py_type) (std::numeric_limits::max)()))) { - bool type_error = py_err && PyErr_ExceptionMatches( -#if PY_VERSION_HEX < 0x03000000 && !defined(PYPY_VERSION) - PyExc_SystemError -#else - PyExc_TypeError -#endif - ); - PyErr_Clear(); - if (type_error && convert && PyNumber_Check(src.ptr())) { - auto tmp = reinterpret_steal(std::is_floating_point::value - ? PyNumber_Float(src.ptr()) - : PyNumber_Long(src.ptr())); - PyErr_Clear(); - return load(tmp, false); - } - return false; - } - - value = (T) py_value; - return true; - } - - template - static typename std::enable_if::value, handle>::type - cast(U src, return_value_policy /* policy */, handle /* parent */) { - return PyFloat_FromDouble((double) src); - } - - template - static typename std::enable_if::value && std::is_signed::value && (sizeof(U) <= sizeof(long)), handle>::type - cast(U src, return_value_policy /* policy */, handle /* parent */) { - return PYBIND11_LONG_FROM_SIGNED((long) src); - } - - template - static typename std::enable_if::value && std::is_unsigned::value && (sizeof(U) <= sizeof(unsigned long)), handle>::type - cast(U src, return_value_policy /* policy */, handle /* parent */) { - return PYBIND11_LONG_FROM_UNSIGNED((unsigned long) src); - } - - template - static typename std::enable_if::value && std::is_signed::value && (sizeof(U) > sizeof(long)), handle>::type - cast(U src, return_value_policy /* policy */, handle /* parent */) { - return PyLong_FromLongLong((long long) src); - } - - template - static typename std::enable_if::value && std::is_unsigned::value && (sizeof(U) > sizeof(unsigned long)), handle>::type - cast(U src, return_value_policy /* policy */, handle /* parent */) { - return PyLong_FromUnsignedLongLong((unsigned long long) src); - } - - PYBIND11_TYPE_CASTER(T, _::value>("int", "float")); -}; - -template struct void_caster { -public: - bool load(handle src, bool) { - if (src && src.is_none()) - return true; - return false; - } - static handle cast(T, return_value_policy /* policy */, handle /* parent */) { - return none().inc_ref(); - } - PYBIND11_TYPE_CASTER(T, _("None")); -}; - -template <> class type_caster : public void_caster {}; - -template <> class type_caster : public type_caster { -public: - using type_caster::cast; - - bool load(handle h, bool) { - if (!h) { - return false; - } else if (h.is_none()) { - value = nullptr; - return true; - } - - /* Check if this is a capsule */ - if (isinstance(h)) { - value = reinterpret_borrow(h); - return true; - } - - /* Check if this is a C++ type */ - auto &bases = all_type_info((PyTypeObject *) h.get_type().ptr()); - if (bases.size() == 1) { // Only allowing loading from a single-value type - value = values_and_holders(reinterpret_cast(h.ptr())).begin()->value_ptr(); - return true; - } - - /* Fail */ - return false; - } - - static handle cast(const void *ptr, return_value_policy /* policy */, handle /* parent */) { - if (ptr) - return capsule(ptr).release(); - else - return none().inc_ref(); - } - - template using cast_op_type = void*&; - operator void *&() { return value; } - static constexpr auto name = _("capsule"); -private: - void *value = nullptr; -}; - -template <> class type_caster : public void_caster { }; - -template <> class type_caster { -public: - bool load(handle src, bool convert) { - if (!src) return false; - else if (src.ptr() == Py_True) { value = true; return true; } - else if (src.ptr() == Py_False) { value = false; return true; } - else if (convert || !strcmp("numpy.bool_", Py_TYPE(src.ptr())->tp_name)) { - // (allow non-implicit conversion for numpy booleans) - - Py_ssize_t res = -1; - if (src.is_none()) { - res = 0; // None is implicitly converted to False - } - #if defined(PYPY_VERSION) - // On PyPy, check that "__bool__" (or "__nonzero__" on Python 2.7) attr exists - else if (hasattr(src, PYBIND11_BOOL_ATTR)) { - res = PyObject_IsTrue(src.ptr()); - } - #else - // Alternate approach for CPython: this does the same as the above, but optimized - // using the CPython API so as to avoid an unneeded attribute lookup. - else if (auto tp_as_number = src.ptr()->ob_type->tp_as_number) { - if (PYBIND11_NB_BOOL(tp_as_number)) { - res = (*PYBIND11_NB_BOOL(tp_as_number))(src.ptr()); - } - } - #endif - if (res == 0 || res == 1) { - value = (bool) res; - return true; - } - } - return false; - } - static handle cast(bool src, return_value_policy /* policy */, handle /* parent */) { - return handle(src ? Py_True : Py_False).inc_ref(); - } - PYBIND11_TYPE_CASTER(bool, _("bool")); -}; - -// Helper class for UTF-{8,16,32} C++ stl strings: -template struct string_caster { - using CharT = typename StringType::value_type; - - // Simplify life by being able to assume standard char sizes (the standard only guarantees - // minimums, but Python requires exact sizes) - static_assert(!std::is_same::value || sizeof(CharT) == 1, "Unsupported char size != 1"); - static_assert(!std::is_same::value || sizeof(CharT) == 2, "Unsupported char16_t size != 2"); - static_assert(!std::is_same::value || sizeof(CharT) == 4, "Unsupported char32_t size != 4"); - // wchar_t can be either 16 bits (Windows) or 32 (everywhere else) - static_assert(!std::is_same::value || sizeof(CharT) == 2 || sizeof(CharT) == 4, - "Unsupported wchar_t size != 2/4"); - static constexpr size_t UTF_N = 8 * sizeof(CharT); - - bool load(handle src, bool) { -#if PY_MAJOR_VERSION < 3 - object temp; -#endif - handle load_src = src; - if (!src) { - return false; - } else if (!PyUnicode_Check(load_src.ptr())) { -#if PY_MAJOR_VERSION >= 3 - return load_bytes(load_src); -#else - if (sizeof(CharT) == 1) { - return load_bytes(load_src); - } - - // The below is a guaranteed failure in Python 3 when PyUnicode_Check returns false - if (!PYBIND11_BYTES_CHECK(load_src.ptr())) - return false; - - temp = reinterpret_steal(PyUnicode_FromObject(load_src.ptr())); - if (!temp) { PyErr_Clear(); return false; } - load_src = temp; -#endif - } - - object utfNbytes = reinterpret_steal(PyUnicode_AsEncodedString( - load_src.ptr(), UTF_N == 8 ? "utf-8" : UTF_N == 16 ? "utf-16" : "utf-32", nullptr)); - if (!utfNbytes) { PyErr_Clear(); return false; } - - const CharT *buffer = reinterpret_cast(PYBIND11_BYTES_AS_STRING(utfNbytes.ptr())); - size_t length = (size_t) PYBIND11_BYTES_SIZE(utfNbytes.ptr()) / sizeof(CharT); - if (UTF_N > 8) { buffer++; length--; } // Skip BOM for UTF-16/32 - value = StringType(buffer, length); - - // If we're loading a string_view we need to keep the encoded Python object alive: - if (IsView) - loader_life_support::add_patient(utfNbytes); - - return true; - } - - static handle cast(const StringType &src, return_value_policy /* policy */, handle /* parent */) { - const char *buffer = reinterpret_cast(src.data()); - ssize_t nbytes = ssize_t(src.size() * sizeof(CharT)); - handle s = decode_utfN(buffer, nbytes); - if (!s) throw error_already_set(); - return s; - } - - PYBIND11_TYPE_CASTER(StringType, _(PYBIND11_STRING_NAME)); - -private: - static handle decode_utfN(const char *buffer, ssize_t nbytes) { -#if !defined(PYPY_VERSION) - return - UTF_N == 8 ? PyUnicode_DecodeUTF8(buffer, nbytes, nullptr) : - UTF_N == 16 ? PyUnicode_DecodeUTF16(buffer, nbytes, nullptr, nullptr) : - PyUnicode_DecodeUTF32(buffer, nbytes, nullptr, nullptr); -#else - // PyPy seems to have multiple problems related to PyUnicode_UTF*: the UTF8 version - // sometimes segfaults for unknown reasons, while the UTF16 and 32 versions require a - // non-const char * arguments, which is also a nuisance, so bypass the whole thing by just - // passing the encoding as a string value, which works properly: - return PyUnicode_Decode(buffer, nbytes, UTF_N == 8 ? "utf-8" : UTF_N == 16 ? "utf-16" : "utf-32", nullptr); -#endif - } - - // When loading into a std::string or char*, accept a bytes object as-is (i.e. - // without any encoding/decoding attempt). For other C++ char sizes this is a no-op. - // which supports loading a unicode from a str, doesn't take this path. - template - bool load_bytes(enable_if_t src) { - if (PYBIND11_BYTES_CHECK(src.ptr())) { - // We were passed a Python 3 raw bytes; accept it into a std::string or char* - // without any encoding attempt. - const char *bytes = PYBIND11_BYTES_AS_STRING(src.ptr()); - if (bytes) { - value = StringType(bytes, (size_t) PYBIND11_BYTES_SIZE(src.ptr())); - return true; - } - } - - return false; - } - - template - bool load_bytes(enable_if_t) { return false; } -}; - -template -struct type_caster, enable_if_t::value>> - : string_caster> {}; - -#ifdef PYBIND11_HAS_STRING_VIEW -template -struct type_caster, enable_if_t::value>> - : string_caster, true> {}; -#endif - -// Type caster for C-style strings. We basically use a std::string type caster, but also add the -// ability to use None as a nullptr char* (which the string caster doesn't allow). -template struct type_caster::value>> { - using StringType = std::basic_string; - using StringCaster = type_caster; - StringCaster str_caster; - bool none = false; - CharT one_char = 0; -public: - bool load(handle src, bool convert) { - if (!src) return false; - if (src.is_none()) { - // Defer accepting None to other overloads (if we aren't in convert mode): - if (!convert) return false; - none = true; - return true; - } - return str_caster.load(src, convert); - } - - static handle cast(const CharT *src, return_value_policy policy, handle parent) { - if (src == nullptr) return pybind11::none().inc_ref(); - return StringCaster::cast(StringType(src), policy, parent); - } - - static handle cast(CharT src, return_value_policy policy, handle parent) { - if (std::is_same::value) { - handle s = PyUnicode_DecodeLatin1((const char *) &src, 1, nullptr); - if (!s) throw error_already_set(); - return s; - } - return StringCaster::cast(StringType(1, src), policy, parent); - } - - operator CharT*() { return none ? nullptr : const_cast(static_cast(str_caster).c_str()); } - operator CharT&() { - if (none) - throw value_error("Cannot convert None to a character"); - - auto &value = static_cast(str_caster); - size_t str_len = value.size(); - if (str_len == 0) - throw value_error("Cannot convert empty string to a character"); - - // If we're in UTF-8 mode, we have two possible failures: one for a unicode character that - // is too high, and one for multiple unicode characters (caught later), so we need to figure - // out how long the first encoded character is in bytes to distinguish between these two - // errors. We also allow want to allow unicode characters U+0080 through U+00FF, as those - // can fit into a single char value. - if (StringCaster::UTF_N == 8 && str_len > 1 && str_len <= 4) { - unsigned char v0 = static_cast(value[0]); - size_t char0_bytes = !(v0 & 0x80) ? 1 : // low bits only: 0-127 - (v0 & 0xE0) == 0xC0 ? 2 : // 0b110xxxxx - start of 2-byte sequence - (v0 & 0xF0) == 0xE0 ? 3 : // 0b1110xxxx - start of 3-byte sequence - 4; // 0b11110xxx - start of 4-byte sequence - - if (char0_bytes == str_len) { - // If we have a 128-255 value, we can decode it into a single char: - if (char0_bytes == 2 && (v0 & 0xFC) == 0xC0) { // 0x110000xx 0x10xxxxxx - one_char = static_cast(((v0 & 3) << 6) + (static_cast(value[1]) & 0x3F)); - return one_char; - } - // Otherwise we have a single character, but it's > U+00FF - throw value_error("Character code point not in range(0x100)"); - } - } - - // UTF-16 is much easier: we can only have a surrogate pair for values above U+FFFF, thus a - // surrogate pair with total length 2 instantly indicates a range error (but not a "your - // string was too long" error). - else if (StringCaster::UTF_N == 16 && str_len == 2) { - one_char = static_cast(value[0]); - if (one_char >= 0xD800 && one_char < 0xE000) - throw value_error("Character code point not in range(0x10000)"); - } - - if (str_len != 1) - throw value_error("Expected a character, but multi-character string found"); - - one_char = value[0]; - return one_char; - } - - static constexpr auto name = _(PYBIND11_STRING_NAME); - template using cast_op_type = pybind11::detail::cast_op_type<_T>; -}; - -// Base implementation for std::tuple and std::pair -template class Tuple, typename... Ts> class tuple_caster { - using type = Tuple; - static constexpr auto size = sizeof...(Ts); - using indices = make_index_sequence; -public: - - bool load(handle src, bool convert) { - if (!isinstance(src)) - return false; - const auto seq = reinterpret_borrow(src); - if (seq.size() != size) - return false; - return load_impl(seq, convert, indices{}); - } - - template - static handle cast(T &&src, return_value_policy policy, handle parent) { - return cast_impl(std::forward(src), policy, parent, indices{}); - } - - static constexpr auto name = _("Tuple[") + concat(make_caster::name...) + _("]"); - - template using cast_op_type = type; - - operator type() & { return implicit_cast(indices{}); } - operator type() && { return std::move(*this).implicit_cast(indices{}); } - -protected: - template - type implicit_cast(index_sequence) & { return type(cast_op(std::get(subcasters))...); } - template - type implicit_cast(index_sequence) && { return type(cast_op(std::move(std::get(subcasters)))...); } - - static constexpr bool load_impl(const sequence &, bool, index_sequence<>) { return true; } - - template - bool load_impl(const sequence &seq, bool convert, index_sequence) { - for (bool r : {std::get(subcasters).load(seq[Is], convert)...}) - if (!r) - return false; - return true; - } - - /* Implementation: Convert a C++ tuple into a Python tuple */ - template - static handle cast_impl(T &&src, return_value_policy policy, handle parent, index_sequence) { - std::array entries{{ - reinterpret_steal(make_caster::cast(std::get(std::forward(src)), policy, parent))... - }}; - for (const auto &entry: entries) - if (!entry) - return handle(); - tuple result(size); - int counter = 0; - for (auto & entry: entries) - PyTuple_SET_ITEM(result.ptr(), counter++, entry.release().ptr()); - return result.release(); - } - - Tuple...> subcasters; -}; - -template class type_caster> - : public tuple_caster {}; - -template class type_caster> - : public tuple_caster {}; - -/// Helper class which abstracts away certain actions. Users can provide specializations for -/// custom holders, but it's only necessary if the type has a non-standard interface. -template -struct holder_helper { - static auto get(const T &p) -> decltype(p.get()) { return p.get(); } -}; - -/// Type caster for holder types like std::shared_ptr, etc. -template -struct copyable_holder_caster : public type_caster_base { -public: - using base = type_caster_base; - static_assert(std::is_base_of>::value, - "Holder classes are only supported for custom types"); - using base::base; - using base::cast; - using base::typeinfo; - using base::value; - - bool load(handle src, bool convert) { - return base::template load_impl>(src, convert); - } - - explicit operator type*() { return this->value; } - explicit operator type&() { return *(this->value); } - explicit operator holder_type*() { return std::addressof(holder); } - - // Workaround for Intel compiler bug - // see pybind11 issue 94 - #if defined(__ICC) || defined(__INTEL_COMPILER) - operator holder_type&() { return holder; } - #else - explicit operator holder_type&() { return holder; } - #endif - - static handle cast(const holder_type &src, return_value_policy, handle) { - const auto *ptr = holder_helper::get(src); - return type_caster_base::cast_holder(ptr, &src); - } - -protected: - friend class type_caster_generic; - void check_holder_compat() { - if (typeinfo->default_holder) - throw cast_error("Unable to load a custom holder type from a default-holder instance"); - } - - bool load_value(value_and_holder &&v_h) { - if (v_h.holder_constructed()) { - value = v_h.value_ptr(); - holder = v_h.template holder(); - return true; - } else { - throw cast_error("Unable to cast from non-held to held instance (T& to Holder) " -#if defined(NDEBUG) - "(compile in debug mode for type information)"); -#else - "of type '" + type_id() + "''"); -#endif - } - } - - template ::value, int> = 0> - bool try_implicit_casts(handle, bool) { return false; } - - template ::value, int> = 0> - bool try_implicit_casts(handle src, bool convert) { - for (auto &cast : typeinfo->implicit_casts) { - copyable_holder_caster sub_caster(*cast.first); - if (sub_caster.load(src, convert)) { - value = cast.second(sub_caster.value); - holder = holder_type(sub_caster.holder, (type *) value); - return true; - } - } - return false; - } - - static bool try_direct_conversions(handle) { return false; } - - - holder_type holder; -}; - -/// Specialize for the common std::shared_ptr, so users don't need to -template -class type_caster> : public copyable_holder_caster> { }; - -template -struct move_only_holder_caster { - static_assert(std::is_base_of, type_caster>::value, - "Holder classes are only supported for custom types"); - - static handle cast(holder_type &&src, return_value_policy, handle) { - auto *ptr = holder_helper::get(src); - return type_caster_base::cast_holder(ptr, std::addressof(src)); - } - static constexpr auto name = type_caster_base::name; -}; - -template -class type_caster> - : public move_only_holder_caster> { }; - -template -using type_caster_holder = conditional_t::value, - copyable_holder_caster, - move_only_holder_caster>; - -template struct always_construct_holder { static constexpr bool value = Value; }; - -/// Create a specialization for custom holder types (silently ignores std::shared_ptr) -#define PYBIND11_DECLARE_HOLDER_TYPE(type, holder_type, ...) \ - namespace pybind11 { namespace detail { \ - template \ - struct always_construct_holder : always_construct_holder { }; \ - template \ - class type_caster::value>> \ - : public type_caster_holder { }; \ - }} - -// PYBIND11_DECLARE_HOLDER_TYPE holder types: -template struct is_holder_type : - std::is_base_of, detail::type_caster> {}; -// Specialization for always-supported unique_ptr holders: -template struct is_holder_type> : - std::true_type {}; - -template struct handle_type_name { static constexpr auto name = _(); }; -template <> struct handle_type_name { static constexpr auto name = _(PYBIND11_BYTES_NAME); }; -template <> struct handle_type_name { static constexpr auto name = _("*args"); }; -template <> struct handle_type_name { static constexpr auto name = _("**kwargs"); }; - -template -struct pyobject_caster { - template ::value, int> = 0> - bool load(handle src, bool /* convert */) { value = src; return static_cast(value); } - - template ::value, int> = 0> - bool load(handle src, bool /* convert */) { - if (!isinstance(src)) - return false; - value = reinterpret_borrow(src); - return true; - } - - static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) { - return src.inc_ref(); - } - PYBIND11_TYPE_CASTER(type, handle_type_name::name); -}; - -template -class type_caster::value>> : public pyobject_caster { }; - -// Our conditions for enabling moving are quite restrictive: -// At compile time: -// - T needs to be a non-const, non-pointer, non-reference type -// - type_caster::operator T&() must exist -// - the type must be move constructible (obviously) -// At run-time: -// - if the type is non-copy-constructible, the object must be the sole owner of the type (i.e. it -// must have ref_count() == 1)h -// If any of the above are not satisfied, we fall back to copying. -template using move_is_plain_type = satisfies_none_of; -template struct move_always : std::false_type {}; -template struct move_always, - negation>, - std::is_move_constructible, - std::is_same>().operator T&()), T&> ->::value>> : std::true_type {}; -template struct move_if_unreferenced : std::false_type {}; -template struct move_if_unreferenced, - negation>, - std::is_move_constructible, - std::is_same>().operator T&()), T&> ->::value>> : std::true_type {}; -template using move_never = none_of, move_if_unreferenced>; - -// Detect whether returning a `type` from a cast on type's type_caster is going to result in a -// reference or pointer to a local variable of the type_caster. Basically, only -// non-reference/pointer `type`s and reference/pointers from a type_caster_generic are safe; -// everything else returns a reference/pointer to a local variable. -template using cast_is_temporary_value_reference = bool_constant< - (std::is_reference::value || std::is_pointer::value) && - !std::is_base_of>::value && - !std::is_same, void>::value ->; - -// When a value returned from a C++ function is being cast back to Python, we almost always want to -// force `policy = move`, regardless of the return value policy the function/method was declared -// with. -template struct return_value_policy_override { - static return_value_policy policy(return_value_policy p) { return p; } -}; - -template struct return_value_policy_override>::value, void>> { - static return_value_policy policy(return_value_policy p) { - return !std::is_lvalue_reference::value && - !std::is_pointer::value - ? return_value_policy::move : p; - } -}; - -// Basic python -> C++ casting; throws if casting fails -template type_caster &load_type(type_caster &conv, const handle &handle) { - if (!conv.load(handle, true)) { -#if defined(NDEBUG) - throw cast_error("Unable to cast Python instance to C++ type (compile in debug mode for details)"); -#else - throw cast_error("Unable to cast Python instance of type " + - (std::string) str(handle.get_type()) + " to C++ type '" + type_id() + "'"); -#endif - } - return conv; -} -// Wrapper around the above that also constructs and returns a type_caster -template make_caster load_type(const handle &handle) { - make_caster conv; - load_type(conv, handle); - return conv; -} - -NAMESPACE_END(detail) - -// pytype -> C++ type -template ::value, int> = 0> -T cast(const handle &handle) { - using namespace detail; - static_assert(!cast_is_temporary_value_reference::value, - "Unable to cast type to reference: value is local to type caster"); - return cast_op(load_type(handle)); -} - -// pytype -> pytype (calls converting constructor) -template ::value, int> = 0> -T cast(const handle &handle) { return T(reinterpret_borrow(handle)); } - -// C++ type -> py::object -template ::value, int> = 0> -object cast(const T &value, return_value_policy policy = return_value_policy::automatic_reference, - handle parent = handle()) { - if (policy == return_value_policy::automatic) - policy = std::is_pointer::value ? return_value_policy::take_ownership : return_value_policy::copy; - else if (policy == return_value_policy::automatic_reference) - policy = std::is_pointer::value ? return_value_policy::reference : return_value_policy::copy; - return reinterpret_steal(detail::make_caster::cast(value, policy, parent)); -} - -template T handle::cast() const { return pybind11::cast(*this); } -template <> inline void handle::cast() const { return; } - -template -detail::enable_if_t::value, T> move(object &&obj) { - if (obj.ref_count() > 1) -#if defined(NDEBUG) - throw cast_error("Unable to cast Python instance to C++ rvalue: instance has multiple references" - " (compile in debug mode for details)"); -#else - throw cast_error("Unable to move from Python " + (std::string) str(obj.get_type()) + - " instance to C++ " + type_id() + " instance: instance has multiple references"); -#endif - - // Move into a temporary and return that, because the reference may be a local value of `conv` - T ret = std::move(detail::load_type(obj).operator T&()); - return ret; -} - -// Calling cast() on an rvalue calls pybind::cast with the object rvalue, which does: -// - If we have to move (because T has no copy constructor), do it. This will fail if the moved -// object has multiple references, but trying to copy will fail to compile. -// - If both movable and copyable, check ref count: if 1, move; otherwise copy -// - Otherwise (not movable), copy. -template detail::enable_if_t::value, T> cast(object &&object) { - return move(std::move(object)); -} -template detail::enable_if_t::value, T> cast(object &&object) { - if (object.ref_count() > 1) - return cast(object); - else - return move(std::move(object)); -} -template detail::enable_if_t::value, T> cast(object &&object) { - return cast(object); -} - -template T object::cast() const & { return pybind11::cast(*this); } -template T object::cast() && { return pybind11::cast(std::move(*this)); } -template <> inline void object::cast() const & { return; } -template <> inline void object::cast() && { return; } - -NAMESPACE_BEGIN(detail) - -// Declared in pytypes.h: -template ::value, int>> -object object_or_cast(T &&o) { return pybind11::cast(std::forward(o)); } - -struct overload_unused {}; // Placeholder type for the unneeded (and dead code) static variable in the OVERLOAD_INT macro -template using overload_caster_t = conditional_t< - cast_is_temporary_value_reference::value, make_caster, overload_unused>; - -// Trampoline use: for reference/pointer types to value-converted values, we do a value cast, then -// store the result in the given variable. For other types, this is a no-op. -template enable_if_t::value, T> cast_ref(object &&o, make_caster &caster) { - return cast_op(load_type(caster, o)); -} -template enable_if_t::value, T> cast_ref(object &&, overload_unused &) { - pybind11_fail("Internal error: cast_ref fallback invoked"); } - -// Trampoline use: Having a pybind11::cast with an invalid reference type is going to static_assert, even -// though if it's in dead code, so we provide a "trampoline" to pybind11::cast that only does anything in -// cases where pybind11::cast is valid. -template enable_if_t::value, T> cast_safe(object &&o) { - return pybind11::cast(std::move(o)); } -template enable_if_t::value, T> cast_safe(object &&) { - pybind11_fail("Internal error: cast_safe fallback invoked"); } -template <> inline void cast_safe(object &&) {} - -NAMESPACE_END(detail) - -template -tuple make_tuple() { return tuple(0); } - -template tuple make_tuple(Args&&... args_) { - constexpr size_t size = sizeof...(Args); - std::array args { - { reinterpret_steal(detail::make_caster::cast( - std::forward(args_), policy, nullptr))... } - }; - for (size_t i = 0; i < args.size(); i++) { - if (!args[i]) { -#if defined(NDEBUG) - throw cast_error("make_tuple(): unable to convert arguments to Python object (compile in debug mode for details)"); -#else - std::array argtypes { {type_id()...} }; - throw cast_error("make_tuple(): unable to convert argument of type '" + - argtypes[i] + "' to Python object"); -#endif - } - } - tuple result(size); - int counter = 0; - for (auto &arg_value : args) - PyTuple_SET_ITEM(result.ptr(), counter++, arg_value.release().ptr()); - return result; -} - -/// \ingroup annotations -/// Annotation for arguments -struct arg { - /// Constructs an argument with the name of the argument; if null or omitted, this is a positional argument. - constexpr explicit arg(const char *name = nullptr) : name(name), flag_noconvert(false), flag_none(true) { } - /// Assign a value to this argument - template arg_v operator=(T &&value) const; - /// Indicate that the type should not be converted in the type caster - arg &noconvert(bool flag = true) { flag_noconvert = flag; return *this; } - /// Indicates that the argument should/shouldn't allow None (e.g. for nullable pointer args) - arg &none(bool flag = true) { flag_none = flag; return *this; } - - const char *name; ///< If non-null, this is a named kwargs argument - bool flag_noconvert : 1; ///< If set, do not allow conversion (requires a supporting type caster!) - bool flag_none : 1; ///< If set (the default), allow None to be passed to this argument -}; - -/// \ingroup annotations -/// Annotation for arguments with values -struct arg_v : arg { -private: - template - arg_v(arg &&base, T &&x, const char *descr = nullptr) - : arg(base), - value(reinterpret_steal( - detail::make_caster::cast(x, return_value_policy::automatic, {}) - )), - descr(descr) -#if !defined(NDEBUG) - , type(type_id()) -#endif - { } - -public: - /// Direct construction with name, default, and description - template - arg_v(const char *name, T &&x, const char *descr = nullptr) - : arg_v(arg(name), std::forward(x), descr) { } - - /// Called internally when invoking `py::arg("a") = value` - template - arg_v(const arg &base, T &&x, const char *descr = nullptr) - : arg_v(arg(base), std::forward(x), descr) { } - - /// Same as `arg::noconvert()`, but returns *this as arg_v&, not arg& - arg_v &noconvert(bool flag = true) { arg::noconvert(flag); return *this; } - - /// Same as `arg::nonone()`, but returns *this as arg_v&, not arg& - arg_v &none(bool flag = true) { arg::none(flag); return *this; } - - /// The default value - object value; - /// The (optional) description of the default value - const char *descr; -#if !defined(NDEBUG) - /// The C++ type name of the default value (only available when compiled in debug mode) - std::string type; -#endif -}; - -template -arg_v arg::operator=(T &&value) const { return {std::move(*this), std::forward(value)}; } - -/// Alias for backward compatibility -- to be removed in version 2.0 -template using arg_t = arg_v; - -inline namespace literals { -/** \rst - String literal version of `arg` - \endrst */ -constexpr arg operator"" _a(const char *name, size_t) { return arg(name); } -} - -NAMESPACE_BEGIN(detail) - -// forward declaration (definition in attr.h) -struct function_record; - -/// Internal data associated with a single function call -struct function_call { - function_call(const function_record &f, handle p); // Implementation in attr.h - - /// The function data: - const function_record &func; - - /// Arguments passed to the function: - std::vector args; - - /// The `convert` value the arguments should be loaded with - std::vector args_convert; - - /// Extra references for the optional `py::args` and/or `py::kwargs` arguments (which, if - /// present, are also in `args` but without a reference). - object args_ref, kwargs_ref; - - /// The parent, if any - handle parent; - - /// If this is a call to an initializer, this argument contains `self` - handle init_self; -}; - - -/// Helper class which loads arguments for C++ functions called from Python -template -class argument_loader { - using indices = make_index_sequence; - - template using argument_is_args = std::is_same, args>; - template using argument_is_kwargs = std::is_same, kwargs>; - // Get args/kwargs argument positions relative to the end of the argument list: - static constexpr auto args_pos = constexpr_first() - (int) sizeof...(Args), - kwargs_pos = constexpr_first() - (int) sizeof...(Args); - - static constexpr bool args_kwargs_are_last = kwargs_pos >= - 1 && args_pos >= kwargs_pos - 1; - - static_assert(args_kwargs_are_last, "py::args/py::kwargs are only permitted as the last argument(s) of a function"); - -public: - static constexpr bool has_kwargs = kwargs_pos < 0; - static constexpr bool has_args = args_pos < 0; - - static constexpr auto arg_names = concat(type_descr(make_caster::name)...); - - bool load_args(function_call &call) { - return load_impl_sequence(call, indices{}); - } - - template - enable_if_t::value, Return> call(Func &&f) && { - return std::move(*this).template call_impl(std::forward(f), indices{}, Guard{}); - } - - template - enable_if_t::value, void_type> call(Func &&f) && { - std::move(*this).template call_impl(std::forward(f), indices{}, Guard{}); - return void_type(); - } - -private: - - static bool load_impl_sequence(function_call &, index_sequence<>) { return true; } - - template - bool load_impl_sequence(function_call &call, index_sequence) { - for (bool r : {std::get(argcasters).load(call.args[Is], call.args_convert[Is])...}) - if (!r) - return false; - return true; - } - - template - Return call_impl(Func &&f, index_sequence, Guard &&) { - return std::forward(f)(cast_op(std::move(std::get(argcasters)))...); - } - - std::tuple...> argcasters; -}; - -/// Helper class which collects only positional arguments for a Python function call. -/// A fancier version below can collect any argument, but this one is optimal for simple calls. -template -class simple_collector { -public: - template - explicit simple_collector(Ts &&...values) - : m_args(pybind11::make_tuple(std::forward(values)...)) { } - - const tuple &args() const & { return m_args; } - dict kwargs() const { return {}; } - - tuple args() && { return std::move(m_args); } - - /// Call a Python function and pass the collected arguments - object call(PyObject *ptr) const { - PyObject *result = PyObject_CallObject(ptr, m_args.ptr()); - if (!result) - throw error_already_set(); - return reinterpret_steal(result); - } - -private: - tuple m_args; -}; - -/// Helper class which collects positional, keyword, * and ** arguments for a Python function call -template -class unpacking_collector { -public: - template - explicit unpacking_collector(Ts &&...values) { - // Tuples aren't (easily) resizable so a list is needed for collection, - // but the actual function call strictly requires a tuple. - auto args_list = list(); - int _[] = { 0, (process(args_list, std::forward(values)), 0)... }; - ignore_unused(_); - - m_args = std::move(args_list); - } - - const tuple &args() const & { return m_args; } - const dict &kwargs() const & { return m_kwargs; } - - tuple args() && { return std::move(m_args); } - dict kwargs() && { return std::move(m_kwargs); } - - /// Call a Python function and pass the collected arguments - object call(PyObject *ptr) const { - PyObject *result = PyObject_Call(ptr, m_args.ptr(), m_kwargs.ptr()); - if (!result) - throw error_already_set(); - return reinterpret_steal(result); - } - -private: - template - void process(list &args_list, T &&x) { - auto o = reinterpret_steal(detail::make_caster::cast(std::forward(x), policy, {})); - if (!o) { -#if defined(NDEBUG) - argument_cast_error(); -#else - argument_cast_error(std::to_string(args_list.size()), type_id()); -#endif - } - args_list.append(o); - } - - void process(list &args_list, detail::args_proxy ap) { - for (const auto &a : ap) - args_list.append(a); - } - - void process(list &/*args_list*/, arg_v a) { - if (!a.name) -#if defined(NDEBUG) - nameless_argument_error(); -#else - nameless_argument_error(a.type); -#endif - - if (m_kwargs.contains(a.name)) { -#if defined(NDEBUG) - multiple_values_error(); -#else - multiple_values_error(a.name); -#endif - } - if (!a.value) { -#if defined(NDEBUG) - argument_cast_error(); -#else - argument_cast_error(a.name, a.type); -#endif - } - m_kwargs[a.name] = a.value; - } - - void process(list &/*args_list*/, detail::kwargs_proxy kp) { - if (!kp) - return; - for (const auto &k : reinterpret_borrow(kp)) { - if (m_kwargs.contains(k.first)) { -#if defined(NDEBUG) - multiple_values_error(); -#else - multiple_values_error(str(k.first)); -#endif - } - m_kwargs[k.first] = k.second; - } - } - - [[noreturn]] static void nameless_argument_error() { - throw type_error("Got kwargs without a name; only named arguments " - "may be passed via py::arg() to a python function call. " - "(compile in debug mode for details)"); - } - [[noreturn]] static void nameless_argument_error(std::string type) { - throw type_error("Got kwargs without a name of type '" + type + "'; only named " - "arguments may be passed via py::arg() to a python function call. "); - } - [[noreturn]] static void multiple_values_error() { - throw type_error("Got multiple values for keyword argument " - "(compile in debug mode for details)"); - } - - [[noreturn]] static void multiple_values_error(std::string name) { - throw type_error("Got multiple values for keyword argument '" + name + "'"); - } - - [[noreturn]] static void argument_cast_error() { - throw cast_error("Unable to convert call argument to Python object " - "(compile in debug mode for details)"); - } - - [[noreturn]] static void argument_cast_error(std::string name, std::string type) { - throw cast_error("Unable to convert call argument '" + name - + "' of type '" + type + "' to Python object"); - } - -private: - tuple m_args; - dict m_kwargs; -}; - -/// Collect only positional arguments for a Python function call -template ...>::value>> -simple_collector collect_arguments(Args &&...args) { - return simple_collector(std::forward(args)...); -} - -/// Collect all arguments, including keywords and unpacking (only instantiated when needed) -template ...>::value>> -unpacking_collector collect_arguments(Args &&...args) { - // Following argument order rules for generalized unpacking according to PEP 448 - static_assert( - constexpr_last() < constexpr_first() - && constexpr_last() < constexpr_first(), - "Invalid function call: positional args must precede keywords and ** unpacking; " - "* unpacking must precede ** unpacking" - ); - return unpacking_collector(std::forward(args)...); -} - -template -template -object object_api::operator()(Args &&...args) const { - return detail::collect_arguments(std::forward(args)...).call(derived().ptr()); -} - -template -template -object object_api::call(Args &&...args) const { - return operator()(std::forward(args)...); -} - -NAMESPACE_END(detail) - -#define PYBIND11_MAKE_OPAQUE(...) \ - namespace pybind11 { namespace detail { \ - template<> class type_caster<__VA_ARGS__> : public type_caster_base<__VA_ARGS__> { }; \ - }} - -/// Lets you pass a type containing a `,` through a macro parameter without needing a separate -/// typedef, e.g.: `PYBIND11_OVERLOAD(PYBIND11_TYPE(ReturnType), PYBIND11_TYPE(Parent), f, arg)` -#define PYBIND11_TYPE(...) __VA_ARGS__ - -NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/pybind11/include/pybind11/chrono.h b/pybind11/include/pybind11/chrono.h deleted file mode 100644 index ea777e6..0000000 --- a/pybind11/include/pybind11/chrono.h +++ /dev/null @@ -1,184 +0,0 @@ -/* - pybind11/chrono.h: Transparent conversion between std::chrono and python's datetime - - Copyright (c) 2016 Trent Houliston and - Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#pragma once - -#include "pybind11.h" -#include -#include -#include -#include - -// Backport the PyDateTime_DELTA functions from Python3.3 if required -#ifndef PyDateTime_DELTA_GET_DAYS -#define PyDateTime_DELTA_GET_DAYS(o) (((PyDateTime_Delta*)o)->days) -#endif -#ifndef PyDateTime_DELTA_GET_SECONDS -#define PyDateTime_DELTA_GET_SECONDS(o) (((PyDateTime_Delta*)o)->seconds) -#endif -#ifndef PyDateTime_DELTA_GET_MICROSECONDS -#define PyDateTime_DELTA_GET_MICROSECONDS(o) (((PyDateTime_Delta*)o)->microseconds) -#endif - -NAMESPACE_BEGIN(PYBIND11_NAMESPACE) -NAMESPACE_BEGIN(detail) - -template class duration_caster { -public: - typedef typename type::rep rep; - typedef typename type::period period; - - typedef std::chrono::duration> days; - - bool load(handle src, bool) { - using namespace std::chrono; - - // Lazy initialise the PyDateTime import - if (!PyDateTimeAPI) { PyDateTime_IMPORT; } - - if (!src) return false; - // If invoked with datetime.delta object - if (PyDelta_Check(src.ptr())) { - value = type(duration_cast>( - days(PyDateTime_DELTA_GET_DAYS(src.ptr())) - + seconds(PyDateTime_DELTA_GET_SECONDS(src.ptr())) - + microseconds(PyDateTime_DELTA_GET_MICROSECONDS(src.ptr())))); - return true; - } - // If invoked with a float we assume it is seconds and convert - else if (PyFloat_Check(src.ptr())) { - value = type(duration_cast>(duration(PyFloat_AsDouble(src.ptr())))); - return true; - } - else return false; - } - - // If this is a duration just return it back - static const std::chrono::duration& get_duration(const std::chrono::duration &src) { - return src; - } - - // If this is a time_point get the time_since_epoch - template static std::chrono::duration get_duration(const std::chrono::time_point> &src) { - return src.time_since_epoch(); - } - - static handle cast(const type &src, return_value_policy /* policy */, handle /* parent */) { - using namespace std::chrono; - - // Use overloaded function to get our duration from our source - // Works out if it is a duration or time_point and get the duration - auto d = get_duration(src); - - // Lazy initialise the PyDateTime import - if (!PyDateTimeAPI) { PyDateTime_IMPORT; } - - // Declare these special duration types so the conversions happen with the correct primitive types (int) - using dd_t = duration>; - using ss_t = duration>; - using us_t = duration; - - auto dd = duration_cast(d); - auto subd = d - dd; - auto ss = duration_cast(subd); - auto us = duration_cast(subd - ss); - return PyDelta_FromDSU(dd.count(), ss.count(), us.count()); - } - - PYBIND11_TYPE_CASTER(type, _("datetime.timedelta")); -}; - -// This is for casting times on the system clock into datetime.datetime instances -template class type_caster> { -public: - typedef std::chrono::time_point type; - bool load(handle src, bool) { - using namespace std::chrono; - - // Lazy initialise the PyDateTime import - if (!PyDateTimeAPI) { PyDateTime_IMPORT; } - - if (!src) return false; - - std::tm cal; - microseconds msecs; - - if (PyDateTime_Check(src.ptr())) { - cal.tm_sec = PyDateTime_DATE_GET_SECOND(src.ptr()); - cal.tm_min = PyDateTime_DATE_GET_MINUTE(src.ptr()); - cal.tm_hour = PyDateTime_DATE_GET_HOUR(src.ptr()); - cal.tm_mday = PyDateTime_GET_DAY(src.ptr()); - cal.tm_mon = PyDateTime_GET_MONTH(src.ptr()) - 1; - cal.tm_year = PyDateTime_GET_YEAR(src.ptr()) - 1900; - cal.tm_isdst = -1; - msecs = microseconds(PyDateTime_DATE_GET_MICROSECOND(src.ptr())); - } else if (PyDate_Check(src.ptr())) { - cal.tm_sec = 0; - cal.tm_min = 0; - cal.tm_hour = 0; - cal.tm_mday = PyDateTime_GET_DAY(src.ptr()); - cal.tm_mon = PyDateTime_GET_MONTH(src.ptr()) - 1; - cal.tm_year = PyDateTime_GET_YEAR(src.ptr()) - 1900; - cal.tm_isdst = -1; - msecs = microseconds(0); - } else if (PyTime_Check(src.ptr())) { - cal.tm_sec = PyDateTime_TIME_GET_SECOND(src.ptr()); - cal.tm_min = PyDateTime_TIME_GET_MINUTE(src.ptr()); - cal.tm_hour = PyDateTime_TIME_GET_HOUR(src.ptr()); - cal.tm_mday = 1; // This date (day, month, year) = (1, 0, 70) - cal.tm_mon = 0; // represents 1-Jan-1970, which is the first - cal.tm_year = 70; // earliest available date for Python's datetime - cal.tm_isdst = -1; - msecs = microseconds(PyDateTime_TIME_GET_MICROSECOND(src.ptr())); - } - else return false; - - value = system_clock::from_time_t(std::mktime(&cal)) + msecs; - return true; - } - - static handle cast(const std::chrono::time_point &src, return_value_policy /* policy */, handle /* parent */) { - using namespace std::chrono; - - // Lazy initialise the PyDateTime import - if (!PyDateTimeAPI) { PyDateTime_IMPORT; } - - std::time_t tt = system_clock::to_time_t(time_point_cast(src)); - // this function uses static memory so it's best to copy it out asap just in case - // otherwise other code that is using localtime may break this (not just python code) - std::tm localtime = *std::localtime(&tt); - - // Declare these special duration types so the conversions happen with the correct primitive types (int) - using us_t = duration; - - return PyDateTime_FromDateAndTime(localtime.tm_year + 1900, - localtime.tm_mon + 1, - localtime.tm_mday, - localtime.tm_hour, - localtime.tm_min, - localtime.tm_sec, - (duration_cast(src.time_since_epoch() % seconds(1))).count()); - } - PYBIND11_TYPE_CASTER(type, _("datetime.datetime")); -}; - -// Other clocks that are not the system clock are not measured as datetime.datetime objects -// since they are not measured on calendar time. So instead we just make them timedeltas -// Or if they have passed us a time as a float we convert that -template class type_caster> -: public duration_caster> { -}; - -template class type_caster> -: public duration_caster> { -}; - -NAMESPACE_END(detail) -NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/pybind11/include/pybind11/common.h b/pybind11/include/pybind11/common.h deleted file mode 100644 index 6c8a4f1..0000000 --- a/pybind11/include/pybind11/common.h +++ /dev/null @@ -1,2 +0,0 @@ -#include "detail/common.h" -#warning "Including 'common.h' is deprecated. It will be removed in v3.0. Use 'pybind11.h'." diff --git a/pybind11/include/pybind11/complex.h b/pybind11/include/pybind11/complex.h deleted file mode 100644 index 3f89638..0000000 --- a/pybind11/include/pybind11/complex.h +++ /dev/null @@ -1,65 +0,0 @@ -/* - pybind11/complex.h: Complex number support - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#pragma once - -#include "pybind11.h" -#include - -/// glibc defines I as a macro which breaks things, e.g., boost template names -#ifdef I -# undef I -#endif - -NAMESPACE_BEGIN(PYBIND11_NAMESPACE) - -template struct format_descriptor, detail::enable_if_t::value>> { - static constexpr const char c = format_descriptor::c; - static constexpr const char value[3] = { 'Z', c, '\0' }; - static std::string format() { return std::string(value); } -}; - -#ifndef PYBIND11_CPP17 - -template constexpr const char format_descriptor< - std::complex, detail::enable_if_t::value>>::value[3]; - -#endif - -NAMESPACE_BEGIN(detail) - -template struct is_fmt_numeric, detail::enable_if_t::value>> { - static constexpr bool value = true; - static constexpr int index = is_fmt_numeric::index + 3; -}; - -template class type_caster> { -public: - bool load(handle src, bool convert) { - if (!src) - return false; - if (!convert && !PyComplex_Check(src.ptr())) - return false; - Py_complex result = PyComplex_AsCComplex(src.ptr()); - if (result.real == -1.0 && PyErr_Occurred()) { - PyErr_Clear(); - return false; - } - value = std::complex((T) result.real, (T) result.imag); - return true; - } - - static handle cast(const std::complex &src, return_value_policy /* policy */, handle /* parent */) { - return PyComplex_FromDoubles((double) src.real(), (double) src.imag()); - } - - PYBIND11_TYPE_CASTER(std::complex, _("complex")); -}; -NAMESPACE_END(detail) -NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/pybind11/include/pybind11/detail/class.h b/pybind11/include/pybind11/detail/class.h deleted file mode 100644 index 230ae81..0000000 --- a/pybind11/include/pybind11/detail/class.h +++ /dev/null @@ -1,632 +0,0 @@ -/* - pybind11/detail/class.h: Python C API implementation details for py::class_ - - Copyright (c) 2017 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#pragma once - -#include "../attr.h" -#include "../options.h" - -NAMESPACE_BEGIN(PYBIND11_NAMESPACE) -NAMESPACE_BEGIN(detail) - -#if PY_VERSION_HEX >= 0x03030000 -# define PYBIND11_BUILTIN_QUALNAME -# define PYBIND11_SET_OLDPY_QUALNAME(obj, nameobj) -#else -// In pre-3.3 Python, we still set __qualname__ so that we can produce reliable function type -// signatures; in 3.3+ this macro expands to nothing: -# define PYBIND11_SET_OLDPY_QUALNAME(obj, nameobj) setattr((PyObject *) obj, "__qualname__", nameobj) -#endif - -inline PyTypeObject *type_incref(PyTypeObject *type) { - Py_INCREF(type); - return type; -} - -#if !defined(PYPY_VERSION) - -/// `pybind11_static_property.__get__()`: Always pass the class instead of the instance. -extern "C" inline PyObject *pybind11_static_get(PyObject *self, PyObject * /*ob*/, PyObject *cls) { - return PyProperty_Type.tp_descr_get(self, cls, cls); -} - -/// `pybind11_static_property.__set__()`: Just like the above `__get__()`. -extern "C" inline int pybind11_static_set(PyObject *self, PyObject *obj, PyObject *value) { - PyObject *cls = PyType_Check(obj) ? obj : (PyObject *) Py_TYPE(obj); - return PyProperty_Type.tp_descr_set(self, cls, value); -} - -/** A `static_property` is the same as a `property` but the `__get__()` and `__set__()` - methods are modified to always use the object type instead of a concrete instance. - Return value: New reference. */ -inline PyTypeObject *make_static_property_type() { - constexpr auto *name = "pybind11_static_property"; - auto name_obj = reinterpret_steal(PYBIND11_FROM_STRING(name)); - - /* Danger zone: from now (and until PyType_Ready), make sure to - issue no Python C API calls which could potentially invoke the - garbage collector (the GC will call type_traverse(), which will in - turn find the newly constructed type in an invalid state) */ - auto heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0); - if (!heap_type) - pybind11_fail("make_static_property_type(): error allocating type!"); - - heap_type->ht_name = name_obj.inc_ref().ptr(); -#ifdef PYBIND11_BUILTIN_QUALNAME - heap_type->ht_qualname = name_obj.inc_ref().ptr(); -#endif - - auto type = &heap_type->ht_type; - type->tp_name = name; - type->tp_base = type_incref(&PyProperty_Type); - type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; - type->tp_descr_get = pybind11_static_get; - type->tp_descr_set = pybind11_static_set; - - if (PyType_Ready(type) < 0) - pybind11_fail("make_static_property_type(): failure in PyType_Ready()!"); - - setattr((PyObject *) type, "__module__", str("pybind11_builtins")); - PYBIND11_SET_OLDPY_QUALNAME(type, name_obj); - - return type; -} - -#else // PYPY - -/** PyPy has some issues with the above C API, so we evaluate Python code instead. - This function will only be called once so performance isn't really a concern. - Return value: New reference. */ -inline PyTypeObject *make_static_property_type() { - auto d = dict(); - PyObject *result = PyRun_String(R"(\ - class pybind11_static_property(property): - def __get__(self, obj, cls): - return property.__get__(self, cls, cls) - - def __set__(self, obj, value): - cls = obj if isinstance(obj, type) else type(obj) - property.__set__(self, cls, value) - )", Py_file_input, d.ptr(), d.ptr() - ); - if (result == nullptr) - throw error_already_set(); - Py_DECREF(result); - return (PyTypeObject *) d["pybind11_static_property"].cast().release().ptr(); -} - -#endif // PYPY - -/** Types with static properties need to handle `Type.static_prop = x` in a specific way. - By default, Python replaces the `static_property` itself, but for wrapped C++ types - we need to call `static_property.__set__()` in order to propagate the new value to - the underlying C++ data structure. */ -extern "C" inline int pybind11_meta_setattro(PyObject* obj, PyObject* name, PyObject* value) { - // Use `_PyType_Lookup()` instead of `PyObject_GetAttr()` in order to get the raw - // descriptor (`property`) instead of calling `tp_descr_get` (`property.__get__()`). - PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name); - - // The following assignment combinations are possible: - // 1. `Type.static_prop = value` --> descr_set: `Type.static_prop.__set__(value)` - // 2. `Type.static_prop = other_static_prop` --> setattro: replace existing `static_prop` - // 3. `Type.regular_attribute = value` --> setattro: regular attribute assignment - const auto static_prop = (PyObject *) get_internals().static_property_type; - const auto call_descr_set = descr && PyObject_IsInstance(descr, static_prop) - && !PyObject_IsInstance(value, static_prop); - if (call_descr_set) { - // Call `static_property.__set__()` instead of replacing the `static_property`. -#if !defined(PYPY_VERSION) - return Py_TYPE(descr)->tp_descr_set(descr, obj, value); -#else - if (PyObject *result = PyObject_CallMethod(descr, "__set__", "OO", obj, value)) { - Py_DECREF(result); - return 0; - } else { - return -1; - } -#endif - } else { - // Replace existing attribute. - return PyType_Type.tp_setattro(obj, name, value); - } -} - -#if PY_MAJOR_VERSION >= 3 -/** - * Python 3's PyInstanceMethod_Type hides itself via its tp_descr_get, which prevents aliasing - * methods via cls.attr("m2") = cls.attr("m1"): instead the tp_descr_get returns a plain function, - * when called on a class, or a PyMethod, when called on an instance. Override that behaviour here - * to do a special case bypass for PyInstanceMethod_Types. - */ -extern "C" inline PyObject *pybind11_meta_getattro(PyObject *obj, PyObject *name) { - PyObject *descr = _PyType_Lookup((PyTypeObject *) obj, name); - if (descr && PyInstanceMethod_Check(descr)) { - Py_INCREF(descr); - return descr; - } - else { - return PyType_Type.tp_getattro(obj, name); - } -} -#endif - -/** This metaclass is assigned by default to all pybind11 types and is required in order - for static properties to function correctly. Users may override this using `py::metaclass`. - Return value: New reference. */ -inline PyTypeObject* make_default_metaclass() { - constexpr auto *name = "pybind11_type"; - auto name_obj = reinterpret_steal(PYBIND11_FROM_STRING(name)); - - /* Danger zone: from now (and until PyType_Ready), make sure to - issue no Python C API calls which could potentially invoke the - garbage collector (the GC will call type_traverse(), which will in - turn find the newly constructed type in an invalid state) */ - auto heap_type = (PyHeapTypeObject *) PyType_Type.tp_alloc(&PyType_Type, 0); - if (!heap_type) - pybind11_fail("make_default_metaclass(): error allocating metaclass!"); - - heap_type->ht_name = name_obj.inc_ref().ptr(); -#ifdef PYBIND11_BUILTIN_QUALNAME - heap_type->ht_qualname = name_obj.inc_ref().ptr(); -#endif - - auto type = &heap_type->ht_type; - type->tp_name = name; - type->tp_base = type_incref(&PyType_Type); - type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; - - type->tp_setattro = pybind11_meta_setattro; -#if PY_MAJOR_VERSION >= 3 - type->tp_getattro = pybind11_meta_getattro; -#endif - - if (PyType_Ready(type) < 0) - pybind11_fail("make_default_metaclass(): failure in PyType_Ready()!"); - - setattr((PyObject *) type, "__module__", str("pybind11_builtins")); - PYBIND11_SET_OLDPY_QUALNAME(type, name_obj); - - return type; -} - -/// For multiple inheritance types we need to recursively register/deregister base pointers for any -/// base classes with pointers that are difference from the instance value pointer so that we can -/// correctly recognize an offset base class pointer. This calls a function with any offset base ptrs. -inline void traverse_offset_bases(void *valueptr, const detail::type_info *tinfo, instance *self, - bool (*f)(void * /*parentptr*/, instance * /*self*/)) { - for (handle h : reinterpret_borrow(tinfo->type->tp_bases)) { - if (auto parent_tinfo = get_type_info((PyTypeObject *) h.ptr())) { - for (auto &c : parent_tinfo->implicit_casts) { - if (c.first == tinfo->cpptype) { - auto *parentptr = c.second(valueptr); - if (parentptr != valueptr) - f(parentptr, self); - traverse_offset_bases(parentptr, parent_tinfo, self, f); - break; - } - } - } - } -} - -inline bool register_instance_impl(void *ptr, instance *self) { - get_internals().registered_instances.emplace(ptr, self); - return true; // unused, but gives the same signature as the deregister func -} -inline bool deregister_instance_impl(void *ptr, instance *self) { - auto ®istered_instances = get_internals().registered_instances; - auto range = registered_instances.equal_range(ptr); - for (auto it = range.first; it != range.second; ++it) { - if (Py_TYPE(self) == Py_TYPE(it->second)) { - registered_instances.erase(it); - return true; - } - } - return false; -} - -inline void register_instance(instance *self, void *valptr, const type_info *tinfo) { - register_instance_impl(valptr, self); - if (!tinfo->simple_ancestors) - traverse_offset_bases(valptr, tinfo, self, register_instance_impl); -} - -inline bool deregister_instance(instance *self, void *valptr, const type_info *tinfo) { - bool ret = deregister_instance_impl(valptr, self); - if (!tinfo->simple_ancestors) - traverse_offset_bases(valptr, tinfo, self, deregister_instance_impl); - return ret; -} - -/// Instance creation function for all pybind11 types. It allocates the internal instance layout for -/// holding C++ objects and holders. Allocation is done lazily (the first time the instance is cast -/// to a reference or pointer), and initialization is done by an `__init__` function. -inline PyObject *make_new_instance(PyTypeObject *type) { -#if defined(PYPY_VERSION) - // PyPy gets tp_basicsize wrong (issue 2482) under multiple inheritance when the first inherited - // object is a a plain Python type (i.e. not derived from an extension type). Fix it. - ssize_t instance_size = static_cast(sizeof(instance)); - if (type->tp_basicsize < instance_size) { - type->tp_basicsize = instance_size; - } -#endif - PyObject *self = type->tp_alloc(type, 0); - auto inst = reinterpret_cast(self); - // Allocate the value/holder internals: - inst->allocate_layout(); - - inst->owned = true; - - return self; -} - -/// Instance creation function for all pybind11 types. It only allocates space for the -/// C++ object, but doesn't call the constructor -- an `__init__` function must do that. -extern "C" inline PyObject *pybind11_object_new(PyTypeObject *type, PyObject *, PyObject *) { - return make_new_instance(type); -} - -/// An `__init__` function constructs the C++ object. Users should provide at least one -/// of these using `py::init` or directly with `.def(__init__, ...)`. Otherwise, the -/// following default function will be used which simply throws an exception. -extern "C" inline int pybind11_object_init(PyObject *self, PyObject *, PyObject *) { - PyTypeObject *type = Py_TYPE(self); - std::string msg; -#if defined(PYPY_VERSION) - msg += handle((PyObject *) type).attr("__module__").cast() + "."; -#endif - msg += type->tp_name; - msg += ": No constructor defined!"; - PyErr_SetString(PyExc_TypeError, msg.c_str()); - return -1; -} - -inline void add_patient(PyObject *nurse, PyObject *patient) { - auto &internals = get_internals(); - auto instance = reinterpret_cast(nurse); - instance->has_patients = true; - Py_INCREF(patient); - internals.patients[nurse].push_back(patient); -} - -inline void clear_patients(PyObject *self) { - auto instance = reinterpret_cast(self); - auto &internals = get_internals(); - auto pos = internals.patients.find(self); - assert(pos != internals.patients.end()); - // Clearing the patients can cause more Python code to run, which - // can invalidate the iterator. Extract the vector of patients - // from the unordered_map first. - auto patients = std::move(pos->second); - internals.patients.erase(pos); - instance->has_patients = false; - for (PyObject *&patient : patients) - Py_CLEAR(patient); -} - -/// Clears all internal data from the instance and removes it from registered instances in -/// preparation for deallocation. -inline void clear_instance(PyObject *self) { - auto instance = reinterpret_cast(self); - - // Deallocate any values/holders, if present: - for (auto &v_h : values_and_holders(instance)) { - if (v_h) { - - // We have to deregister before we call dealloc because, for virtual MI types, we still - // need to be able to get the parent pointers. - if (v_h.instance_registered() && !deregister_instance(instance, v_h.value_ptr(), v_h.type)) - pybind11_fail("pybind11_object_dealloc(): Tried to deallocate unregistered instance!"); - - if (instance->owned || v_h.holder_constructed()) - v_h.type->dealloc(v_h); - } - } - // Deallocate the value/holder layout internals: - instance->deallocate_layout(); - - if (instance->weakrefs) - PyObject_ClearWeakRefs(self); - - PyObject **dict_ptr = _PyObject_GetDictPtr(self); - if (dict_ptr) - Py_CLEAR(*dict_ptr); - - if (instance->has_patients) - clear_patients(self); -} - -/// Instance destructor function for all pybind11 types. It calls `type_info.dealloc` -/// to destroy the C++ object itself, while the rest is Python bookkeeping. -extern "C" inline void pybind11_object_dealloc(PyObject *self) { - clear_instance(self); - - auto type = Py_TYPE(self); - type->tp_free(self); - -#if PY_VERSION_HEX < 0x03080000 - // `type->tp_dealloc != pybind11_object_dealloc` means that we're being called - // as part of a derived type's dealloc, in which case we're not allowed to decref - // the type here. For cross-module compatibility, we shouldn't compare directly - // with `pybind11_object_dealloc`, but with the common one stashed in internals. - auto pybind11_object_type = (PyTypeObject *) get_internals().instance_base; - if (type->tp_dealloc == pybind11_object_type->tp_dealloc) - Py_DECREF(type); -#else - // This was not needed before Python 3.8 (Python issue 35810) - // https://github.com/pybind/pybind11/issues/1946 - Py_DECREF(type); -#endif -} - -/** Create the type which can be used as a common base for all classes. This is - needed in order to satisfy Python's requirements for multiple inheritance. - Return value: New reference. */ -inline PyObject *make_object_base_type(PyTypeObject *metaclass) { - constexpr auto *name = "pybind11_object"; - auto name_obj = reinterpret_steal(PYBIND11_FROM_STRING(name)); - - /* Danger zone: from now (and until PyType_Ready), make sure to - issue no Python C API calls which could potentially invoke the - garbage collector (the GC will call type_traverse(), which will in - turn find the newly constructed type in an invalid state) */ - auto heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0); - if (!heap_type) - pybind11_fail("make_object_base_type(): error allocating type!"); - - heap_type->ht_name = name_obj.inc_ref().ptr(); -#ifdef PYBIND11_BUILTIN_QUALNAME - heap_type->ht_qualname = name_obj.inc_ref().ptr(); -#endif - - auto type = &heap_type->ht_type; - type->tp_name = name; - type->tp_base = type_incref(&PyBaseObject_Type); - type->tp_basicsize = static_cast(sizeof(instance)); - type->tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; - - type->tp_new = pybind11_object_new; - type->tp_init = pybind11_object_init; - type->tp_dealloc = pybind11_object_dealloc; - - /* Support weak references (needed for the keep_alive feature) */ - type->tp_weaklistoffset = offsetof(instance, weakrefs); - - if (PyType_Ready(type) < 0) - pybind11_fail("PyType_Ready failed in make_object_base_type():" + error_string()); - - setattr((PyObject *) type, "__module__", str("pybind11_builtins")); - PYBIND11_SET_OLDPY_QUALNAME(type, name_obj); - - assert(!PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC)); - return (PyObject *) heap_type; -} - -/// dynamic_attr: Support for `d = instance.__dict__`. -extern "C" inline PyObject *pybind11_get_dict(PyObject *self, void *) { - PyObject *&dict = *_PyObject_GetDictPtr(self); - if (!dict) - dict = PyDict_New(); - Py_XINCREF(dict); - return dict; -} - -/// dynamic_attr: Support for `instance.__dict__ = dict()`. -extern "C" inline int pybind11_set_dict(PyObject *self, PyObject *new_dict, void *) { - if (!PyDict_Check(new_dict)) { - PyErr_Format(PyExc_TypeError, "__dict__ must be set to a dictionary, not a '%.200s'", - Py_TYPE(new_dict)->tp_name); - return -1; - } - PyObject *&dict = *_PyObject_GetDictPtr(self); - Py_INCREF(new_dict); - Py_CLEAR(dict); - dict = new_dict; - return 0; -} - -/// dynamic_attr: Allow the garbage collector to traverse the internal instance `__dict__`. -extern "C" inline int pybind11_traverse(PyObject *self, visitproc visit, void *arg) { - PyObject *&dict = *_PyObject_GetDictPtr(self); - Py_VISIT(dict); - return 0; -} - -/// dynamic_attr: Allow the GC to clear the dictionary. -extern "C" inline int pybind11_clear(PyObject *self) { - PyObject *&dict = *_PyObject_GetDictPtr(self); - Py_CLEAR(dict); - return 0; -} - -/// Give instances of this type a `__dict__` and opt into garbage collection. -inline void enable_dynamic_attributes(PyHeapTypeObject *heap_type) { - auto type = &heap_type->ht_type; -#if defined(PYPY_VERSION) - pybind11_fail(std::string(type->tp_name) + ": dynamic attributes are " - "currently not supported in " - "conjunction with PyPy!"); -#endif - type->tp_flags |= Py_TPFLAGS_HAVE_GC; - type->tp_dictoffset = type->tp_basicsize; // place dict at the end - type->tp_basicsize += (ssize_t)sizeof(PyObject *); // and allocate enough space for it - type->tp_traverse = pybind11_traverse; - type->tp_clear = pybind11_clear; - - static PyGetSetDef getset[] = { - {const_cast("__dict__"), pybind11_get_dict, pybind11_set_dict, nullptr, nullptr}, - {nullptr, nullptr, nullptr, nullptr, nullptr} - }; - type->tp_getset = getset; -} - -/// buffer_protocol: Fill in the view as specified by flags. -extern "C" inline int pybind11_getbuffer(PyObject *obj, Py_buffer *view, int flags) { - // Look for a `get_buffer` implementation in this type's info or any bases (following MRO). - type_info *tinfo = nullptr; - for (auto type : reinterpret_borrow(Py_TYPE(obj)->tp_mro)) { - tinfo = get_type_info((PyTypeObject *) type.ptr()); - if (tinfo && tinfo->get_buffer) - break; - } - if (view == nullptr || !tinfo || !tinfo->get_buffer) { - if (view) - view->obj = nullptr; - PyErr_SetString(PyExc_BufferError, "pybind11_getbuffer(): Internal error"); - return -1; - } - std::memset(view, 0, sizeof(Py_buffer)); - buffer_info *info = tinfo->get_buffer(obj, tinfo->get_buffer_data); - view->obj = obj; - view->ndim = 1; - view->internal = info; - view->buf = info->ptr; - view->itemsize = info->itemsize; - view->len = view->itemsize; - for (auto s : info->shape) - view->len *= s; - if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) - view->format = const_cast(info->format.c_str()); - if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { - view->ndim = (int) info->ndim; - view->strides = &info->strides[0]; - view->shape = &info->shape[0]; - } - Py_INCREF(view->obj); - return 0; -} - -/// buffer_protocol: Release the resources of the buffer. -extern "C" inline void pybind11_releasebuffer(PyObject *, Py_buffer *view) { - delete (buffer_info *) view->internal; -} - -/// Give this type a buffer interface. -inline void enable_buffer_protocol(PyHeapTypeObject *heap_type) { - heap_type->ht_type.tp_as_buffer = &heap_type->as_buffer; -#if PY_MAJOR_VERSION < 3 - heap_type->ht_type.tp_flags |= Py_TPFLAGS_HAVE_NEWBUFFER; -#endif - - heap_type->as_buffer.bf_getbuffer = pybind11_getbuffer; - heap_type->as_buffer.bf_releasebuffer = pybind11_releasebuffer; -} - -/** Create a brand new Python type according to the `type_record` specification. - Return value: New reference. */ -inline PyObject* make_new_python_type(const type_record &rec) { - auto name = reinterpret_steal(PYBIND11_FROM_STRING(rec.name)); - - auto qualname = name; - if (rec.scope && !PyModule_Check(rec.scope.ptr()) && hasattr(rec.scope, "__qualname__")) { -#if PY_MAJOR_VERSION >= 3 - qualname = reinterpret_steal( - PyUnicode_FromFormat("%U.%U", rec.scope.attr("__qualname__").ptr(), name.ptr())); -#else - qualname = str(rec.scope.attr("__qualname__").cast() + "." + rec.name); -#endif - } - - object module; - if (rec.scope) { - if (hasattr(rec.scope, "__module__")) - module = rec.scope.attr("__module__"); - else if (hasattr(rec.scope, "__name__")) - module = rec.scope.attr("__name__"); - } - - auto full_name = c_str( -#if !defined(PYPY_VERSION) - module ? str(module).cast() + "." + rec.name : -#endif - rec.name); - - char *tp_doc = nullptr; - if (rec.doc && options::show_user_defined_docstrings()) { - /* Allocate memory for docstring (using PyObject_MALLOC, since - Python will free this later on) */ - size_t size = strlen(rec.doc) + 1; - tp_doc = (char *) PyObject_MALLOC(size); - memcpy((void *) tp_doc, rec.doc, size); - } - - auto &internals = get_internals(); - auto bases = tuple(rec.bases); - auto base = (bases.size() == 0) ? internals.instance_base - : bases[0].ptr(); - - /* Danger zone: from now (and until PyType_Ready), make sure to - issue no Python C API calls which could potentially invoke the - garbage collector (the GC will call type_traverse(), which will in - turn find the newly constructed type in an invalid state) */ - auto metaclass = rec.metaclass.ptr() ? (PyTypeObject *) rec.metaclass.ptr() - : internals.default_metaclass; - - auto heap_type = (PyHeapTypeObject *) metaclass->tp_alloc(metaclass, 0); - if (!heap_type) - pybind11_fail(std::string(rec.name) + ": Unable to create type object!"); - - heap_type->ht_name = name.release().ptr(); -#ifdef PYBIND11_BUILTIN_QUALNAME - heap_type->ht_qualname = qualname.inc_ref().ptr(); -#endif - - auto type = &heap_type->ht_type; - type->tp_name = full_name; - type->tp_doc = tp_doc; - type->tp_base = type_incref((PyTypeObject *)base); - type->tp_basicsize = static_cast(sizeof(instance)); - if (bases.size() > 0) - type->tp_bases = bases.release().ptr(); - - /* Don't inherit base __init__ */ - type->tp_init = pybind11_object_init; - - /* Supported protocols */ - type->tp_as_number = &heap_type->as_number; - type->tp_as_sequence = &heap_type->as_sequence; - type->tp_as_mapping = &heap_type->as_mapping; -#if PY_VERSION_HEX >= 0x03050000 - type->tp_as_async = &heap_type->as_async; -#endif - - /* Flags */ - type->tp_flags |= Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; -#if PY_MAJOR_VERSION < 3 - type->tp_flags |= Py_TPFLAGS_CHECKTYPES; -#endif - - if (rec.dynamic_attr) - enable_dynamic_attributes(heap_type); - - if (rec.buffer_protocol) - enable_buffer_protocol(heap_type); - - if (PyType_Ready(type) < 0) - pybind11_fail(std::string(rec.name) + ": PyType_Ready failed (" + error_string() + ")!"); - - assert(rec.dynamic_attr ? PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC) - : !PyType_HasFeature(type, Py_TPFLAGS_HAVE_GC)); - - /* Register type with the parent scope */ - if (rec.scope) - setattr(rec.scope, rec.name, (PyObject *) type); - else - Py_INCREF(type); // Keep it alive forever (reference leak) - - if (module) // Needed by pydoc - setattr((PyObject *) type, "__module__", module); - - PYBIND11_SET_OLDPY_QUALNAME(type, qualname); - - return (PyObject *) type; -} - -NAMESPACE_END(detail) -NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/pybind11/include/pybind11/detail/common.h b/pybind11/include/pybind11/detail/common.h deleted file mode 100644 index 6da5470..0000000 --- a/pybind11/include/pybind11/detail/common.h +++ /dev/null @@ -1,808 +0,0 @@ -/* - pybind11/detail/common.h -- Basic macros - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#pragma once - -#if !defined(NAMESPACE_BEGIN) -# define NAMESPACE_BEGIN(name) namespace name { -#endif -#if !defined(NAMESPACE_END) -# define NAMESPACE_END(name) } -#endif - -// Robust support for some features and loading modules compiled against different pybind versions -// requires forcing hidden visibility on pybind code, so we enforce this by setting the attribute on -// the main `pybind11` namespace. -#if !defined(PYBIND11_NAMESPACE) -# ifdef __GNUG__ -# define PYBIND11_NAMESPACE pybind11 __attribute__((visibility("hidden"))) -# else -# define PYBIND11_NAMESPACE pybind11 -# endif -#endif - -#if !(defined(_MSC_VER) && __cplusplus == 199711L) && !defined(__INTEL_COMPILER) -# if __cplusplus >= 201402L -# define PYBIND11_CPP14 -# if __cplusplus >= 201703L -# define PYBIND11_CPP17 -# endif -# endif -#elif defined(_MSC_VER) && __cplusplus == 199711L -// MSVC sets _MSVC_LANG rather than __cplusplus (supposedly until the standard is fully implemented) -// Unless you use the /Zc:__cplusplus flag on Visual Studio 2017 15.7 Preview 3 or newer -# if _MSVC_LANG >= 201402L -# define PYBIND11_CPP14 -# if _MSVC_LANG > 201402L && _MSC_VER >= 1910 -# define PYBIND11_CPP17 -# endif -# endif -#endif - -// Compiler version assertions -#if defined(__INTEL_COMPILER) -# if __INTEL_COMPILER < 1700 -# error pybind11 requires Intel C++ compiler v17 or newer -# endif -#elif defined(__clang__) && !defined(__apple_build_version__) -# if __clang_major__ < 3 || (__clang_major__ == 3 && __clang_minor__ < 3) -# error pybind11 requires clang 3.3 or newer -# endif -#elif defined(__clang__) -// Apple changes clang version macros to its Xcode version; the first Xcode release based on -// (upstream) clang 3.3 was Xcode 5: -# if __clang_major__ < 5 -# error pybind11 requires Xcode/clang 5.0 or newer -# endif -#elif defined(__GNUG__) -# if __GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 8) -# error pybind11 requires gcc 4.8 or newer -# endif -#elif defined(_MSC_VER) -// Pybind hits various compiler bugs in 2015u2 and earlier, and also makes use of some stl features -// (e.g. std::negation) added in 2015u3: -# if _MSC_FULL_VER < 190024210 -# error pybind11 requires MSVC 2015 update 3 or newer -# endif -#endif - -#if !defined(PYBIND11_EXPORT) -# if defined(WIN32) || defined(_WIN32) -# define PYBIND11_EXPORT __declspec(dllexport) -# else -# define PYBIND11_EXPORT __attribute__ ((visibility("default"))) -# endif -#endif - -#if defined(_MSC_VER) -# define PYBIND11_NOINLINE __declspec(noinline) -#else -# define PYBIND11_NOINLINE __attribute__ ((noinline)) -#endif - -#if defined(PYBIND11_CPP14) -# define PYBIND11_DEPRECATED(reason) [[deprecated(reason)]] -#else -# define PYBIND11_DEPRECATED(reason) __attribute__((deprecated(reason))) -#endif - -#define PYBIND11_VERSION_MAJOR 2 -#define PYBIND11_VERSION_MINOR 4 -#define PYBIND11_VERSION_PATCH 3 - -/// Include Python header, disable linking to pythonX_d.lib on Windows in debug mode -#if defined(_MSC_VER) -# if (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 4) -# define HAVE_ROUND 1 -# endif -# pragma warning(push) -# pragma warning(disable: 4510 4610 4512 4005) -# if defined(_DEBUG) -# define PYBIND11_DEBUG_MARKER -# undef _DEBUG -# endif -#endif - -#include -#include -#include - -#if defined(isalnum) -# undef isalnum -# undef isalpha -# undef islower -# undef isspace -# undef isupper -# undef tolower -# undef toupper -#endif - -#if defined(_MSC_VER) -# if defined(PYBIND11_DEBUG_MARKER) -# define _DEBUG -# undef PYBIND11_DEBUG_MARKER -# endif -# pragma warning(pop) -#endif - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if PY_MAJOR_VERSION >= 3 /// Compatibility macros for various Python versions -#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyInstanceMethod_New(ptr) -#define PYBIND11_INSTANCE_METHOD_CHECK PyInstanceMethod_Check -#define PYBIND11_INSTANCE_METHOD_GET_FUNCTION PyInstanceMethod_GET_FUNCTION -#define PYBIND11_BYTES_CHECK PyBytes_Check -#define PYBIND11_BYTES_FROM_STRING PyBytes_FromString -#define PYBIND11_BYTES_FROM_STRING_AND_SIZE PyBytes_FromStringAndSize -#define PYBIND11_BYTES_AS_STRING_AND_SIZE PyBytes_AsStringAndSize -#define PYBIND11_BYTES_AS_STRING PyBytes_AsString -#define PYBIND11_BYTES_SIZE PyBytes_Size -#define PYBIND11_LONG_CHECK(o) PyLong_Check(o) -#define PYBIND11_LONG_AS_LONGLONG(o) PyLong_AsLongLong(o) -#define PYBIND11_LONG_FROM_SIGNED(o) PyLong_FromSsize_t((ssize_t) o) -#define PYBIND11_LONG_FROM_UNSIGNED(o) PyLong_FromSize_t((size_t) o) -#define PYBIND11_BYTES_NAME "bytes" -#define PYBIND11_STRING_NAME "str" -#define PYBIND11_SLICE_OBJECT PyObject -#define PYBIND11_FROM_STRING PyUnicode_FromString -#define PYBIND11_STR_TYPE ::pybind11::str -#define PYBIND11_BOOL_ATTR "__bool__" -#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_bool) -// Providing a separate declaration to make Clang's -Wmissing-prototypes happy -#define PYBIND11_PLUGIN_IMPL(name) \ - extern "C" PYBIND11_EXPORT PyObject *PyInit_##name(); \ - extern "C" PYBIND11_EXPORT PyObject *PyInit_##name() - -#else -#define PYBIND11_INSTANCE_METHOD_NEW(ptr, class_) PyMethod_New(ptr, nullptr, class_) -#define PYBIND11_INSTANCE_METHOD_CHECK PyMethod_Check -#define PYBIND11_INSTANCE_METHOD_GET_FUNCTION PyMethod_GET_FUNCTION -#define PYBIND11_BYTES_CHECK PyString_Check -#define PYBIND11_BYTES_FROM_STRING PyString_FromString -#define PYBIND11_BYTES_FROM_STRING_AND_SIZE PyString_FromStringAndSize -#define PYBIND11_BYTES_AS_STRING_AND_SIZE PyString_AsStringAndSize -#define PYBIND11_BYTES_AS_STRING PyString_AsString -#define PYBIND11_BYTES_SIZE PyString_Size -#define PYBIND11_LONG_CHECK(o) (PyInt_Check(o) || PyLong_Check(o)) -#define PYBIND11_LONG_AS_LONGLONG(o) (PyInt_Check(o) ? (long long) PyLong_AsLong(o) : PyLong_AsLongLong(o)) -#define PYBIND11_LONG_FROM_SIGNED(o) PyInt_FromSsize_t((ssize_t) o) // Returns long if needed. -#define PYBIND11_LONG_FROM_UNSIGNED(o) PyInt_FromSize_t((size_t) o) // Returns long if needed. -#define PYBIND11_BYTES_NAME "str" -#define PYBIND11_STRING_NAME "unicode" -#define PYBIND11_SLICE_OBJECT PySliceObject -#define PYBIND11_FROM_STRING PyString_FromString -#define PYBIND11_STR_TYPE ::pybind11::bytes -#define PYBIND11_BOOL_ATTR "__nonzero__" -#define PYBIND11_NB_BOOL(ptr) ((ptr)->nb_nonzero) -// Providing a separate PyInit decl to make Clang's -Wmissing-prototypes happy -#define PYBIND11_PLUGIN_IMPL(name) \ - static PyObject *pybind11_init_wrapper(); \ - extern "C" PYBIND11_EXPORT void init##name(); \ - extern "C" PYBIND11_EXPORT void init##name() { \ - (void)pybind11_init_wrapper(); \ - } \ - PyObject *pybind11_init_wrapper() -#endif - -#if PY_VERSION_HEX >= 0x03050000 && PY_VERSION_HEX < 0x03050200 -extern "C" { - struct _Py_atomic_address { void *value; }; - PyAPI_DATA(_Py_atomic_address) _PyThreadState_Current; -} -#endif - -#define PYBIND11_TRY_NEXT_OVERLOAD ((PyObject *) 1) // special failure return code -#define PYBIND11_STRINGIFY(x) #x -#define PYBIND11_TOSTRING(x) PYBIND11_STRINGIFY(x) -#define PYBIND11_CONCAT(first, second) first##second - -#define PYBIND11_CHECK_PYTHON_VERSION \ - { \ - const char *compiled_ver = PYBIND11_TOSTRING(PY_MAJOR_VERSION) \ - "." PYBIND11_TOSTRING(PY_MINOR_VERSION); \ - const char *runtime_ver = Py_GetVersion(); \ - size_t len = std::strlen(compiled_ver); \ - if (std::strncmp(runtime_ver, compiled_ver, len) != 0 \ - || (runtime_ver[len] >= '0' && runtime_ver[len] <= '9')) { \ - PyErr_Format(PyExc_ImportError, \ - "Python version mismatch: module was compiled for Python %s, " \ - "but the interpreter version is incompatible: %s.", \ - compiled_ver, runtime_ver); \ - return nullptr; \ - } \ - } - -#define PYBIND11_CATCH_INIT_EXCEPTIONS \ - catch (pybind11::error_already_set &e) { \ - PyErr_SetString(PyExc_ImportError, e.what()); \ - return nullptr; \ - } catch (const std::exception &e) { \ - PyErr_SetString(PyExc_ImportError, e.what()); \ - return nullptr; \ - } \ - -/** \rst - ***Deprecated in favor of PYBIND11_MODULE*** - - This macro creates the entry point that will be invoked when the Python interpreter - imports a plugin library. Please create a `module` in the function body and return - the pointer to its underlying Python object at the end. - - .. code-block:: cpp - - PYBIND11_PLUGIN(example) { - pybind11::module m("example", "pybind11 example plugin"); - /// Set up bindings here - return m.ptr(); - } -\endrst */ -#define PYBIND11_PLUGIN(name) \ - PYBIND11_DEPRECATED("PYBIND11_PLUGIN is deprecated, use PYBIND11_MODULE") \ - static PyObject *pybind11_init(); \ - PYBIND11_PLUGIN_IMPL(name) { \ - PYBIND11_CHECK_PYTHON_VERSION \ - try { \ - return pybind11_init(); \ - } PYBIND11_CATCH_INIT_EXCEPTIONS \ - } \ - PyObject *pybind11_init() - -/** \rst - This macro creates the entry point that will be invoked when the Python interpreter - imports an extension module. The module name is given as the fist argument and it - should not be in quotes. The second macro argument defines a variable of type - `py::module` which can be used to initialize the module. - - .. code-block:: cpp - - PYBIND11_MODULE(example, m) { - m.doc() = "pybind11 example module"; - - // Add bindings here - m.def("foo", []() { - return "Hello, World!"; - }); - } -\endrst */ -#define PYBIND11_MODULE(name, variable) \ - static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \ - PYBIND11_PLUGIN_IMPL(name) { \ - PYBIND11_CHECK_PYTHON_VERSION \ - auto m = pybind11::module(PYBIND11_TOSTRING(name)); \ - try { \ - PYBIND11_CONCAT(pybind11_init_, name)(m); \ - return m.ptr(); \ - } PYBIND11_CATCH_INIT_EXCEPTIONS \ - } \ - void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable) - - -NAMESPACE_BEGIN(PYBIND11_NAMESPACE) - -using ssize_t = Py_ssize_t; -using size_t = std::size_t; - -/// Approach used to cast a previously unknown C++ instance into a Python object -enum class return_value_policy : uint8_t { - /** This is the default return value policy, which falls back to the policy - return_value_policy::take_ownership when the return value is a pointer. - Otherwise, it uses return_value::move or return_value::copy for rvalue - and lvalue references, respectively. See below for a description of what - all of these different policies do. */ - automatic = 0, - - /** As above, but use policy return_value_policy::reference when the return - value is a pointer. This is the default conversion policy for function - arguments when calling Python functions manually from C++ code (i.e. via - handle::operator()). You probably won't need to use this. */ - automatic_reference, - - /** Reference an existing object (i.e. do not create a new copy) and take - ownership. Python will call the destructor and delete operator when the - object’s reference count reaches zero. Undefined behavior ensues when - the C++ side does the same.. */ - take_ownership, - - /** Create a new copy of the returned object, which will be owned by - Python. This policy is comparably safe because the lifetimes of the two - instances are decoupled. */ - copy, - - /** Use std::move to move the return value contents into a new instance - that will be owned by Python. This policy is comparably safe because the - lifetimes of the two instances (move source and destination) are - decoupled. */ - move, - - /** Reference an existing object, but do not take ownership. The C++ side - is responsible for managing the object’s lifetime and deallocating it - when it is no longer used. Warning: undefined behavior will ensue when - the C++ side deletes an object that is still referenced and used by - Python. */ - reference, - - /** This policy only applies to methods and properties. It references the - object without taking ownership similar to the above - return_value_policy::reference policy. In contrast to that policy, the - function or property’s implicit this argument (called the parent) is - considered to be the the owner of the return value (the child). - pybind11 then couples the lifetime of the parent to the child via a - reference relationship that ensures that the parent cannot be garbage - collected while Python is still using the child. More advanced - variations of this scheme are also possible using combinations of - return_value_policy::reference and the keep_alive call policy */ - reference_internal -}; - -NAMESPACE_BEGIN(detail) - -inline static constexpr int log2(size_t n, int k = 0) { return (n <= 1) ? k : log2(n >> 1, k + 1); } - -// Returns the size as a multiple of sizeof(void *), rounded up. -inline static constexpr size_t size_in_ptrs(size_t s) { return 1 + ((s - 1) >> log2(sizeof(void *))); } - -/** - * The space to allocate for simple layout instance holders (see below) in multiple of the size of - * a pointer (e.g. 2 means 16 bytes on 64-bit architectures). The default is the minimum required - * to holder either a std::unique_ptr or std::shared_ptr (which is almost always - * sizeof(std::shared_ptr)). - */ -constexpr size_t instance_simple_holder_in_ptrs() { - static_assert(sizeof(std::shared_ptr) >= sizeof(std::unique_ptr), - "pybind assumes std::shared_ptrs are at least as big as std::unique_ptrs"); - return size_in_ptrs(sizeof(std::shared_ptr)); -} - -// Forward declarations -struct type_info; -struct value_and_holder; - -struct nonsimple_values_and_holders { - void **values_and_holders; - uint8_t *status; -}; - -/// The 'instance' type which needs to be standard layout (need to be able to use 'offsetof') -struct instance { - PyObject_HEAD - /// Storage for pointers and holder; see simple_layout, below, for a description - union { - void *simple_value_holder[1 + instance_simple_holder_in_ptrs()]; - nonsimple_values_and_holders nonsimple; - }; - /// Weak references - PyObject *weakrefs; - /// If true, the pointer is owned which means we're free to manage it with a holder. - bool owned : 1; - /** - * An instance has two possible value/holder layouts. - * - * Simple layout (when this flag is true), means the `simple_value_holder` is set with a pointer - * and the holder object governing that pointer, i.e. [val1*][holder]. This layout is applied - * whenever there is no python-side multiple inheritance of bound C++ types *and* the type's - * holder will fit in the default space (which is large enough to hold either a std::unique_ptr - * or std::shared_ptr). - * - * Non-simple layout applies when using custom holders that require more space than `shared_ptr` - * (which is typically the size of two pointers), or when multiple inheritance is used on the - * python side. Non-simple layout allocates the required amount of memory to have multiple - * bound C++ classes as parents. Under this layout, `nonsimple.values_and_holders` is set to a - * pointer to allocated space of the required space to hold a sequence of value pointers and - * holders followed `status`, a set of bit flags (1 byte each), i.e. - * [val1*][holder1][val2*][holder2]...[bb...] where each [block] is rounded up to a multiple of - * `sizeof(void *)`. `nonsimple.status` is, for convenience, a pointer to the - * beginning of the [bb...] block (but not independently allocated). - * - * Status bits indicate whether the associated holder is constructed (& - * status_holder_constructed) and whether the value pointer is registered (& - * status_instance_registered) in `registered_instances`. - */ - bool simple_layout : 1; - /// For simple layout, tracks whether the holder has been constructed - bool simple_holder_constructed : 1; - /// For simple layout, tracks whether the instance is registered in `registered_instances` - bool simple_instance_registered : 1; - /// If true, get_internals().patients has an entry for this object - bool has_patients : 1; - - /// Initializes all of the above type/values/holders data (but not the instance values themselves) - void allocate_layout(); - - /// Destroys/deallocates all of the above - void deallocate_layout(); - - /// Returns the value_and_holder wrapper for the given type (or the first, if `find_type` - /// omitted). Returns a default-constructed (with `.inst = nullptr`) object on failure if - /// `throw_if_missing` is false. - value_and_holder get_value_and_holder(const type_info *find_type = nullptr, bool throw_if_missing = true); - - /// Bit values for the non-simple status flags - static constexpr uint8_t status_holder_constructed = 1; - static constexpr uint8_t status_instance_registered = 2; -}; - -static_assert(std::is_standard_layout::value, "Internal error: `pybind11::detail::instance` is not standard layout!"); - -/// from __cpp_future__ import (convenient aliases from C++14/17) -#if defined(PYBIND11_CPP14) && (!defined(_MSC_VER) || _MSC_VER >= 1910) -using std::enable_if_t; -using std::conditional_t; -using std::remove_cv_t; -using std::remove_reference_t; -#else -template using enable_if_t = typename std::enable_if::type; -template using conditional_t = typename std::conditional::type; -template using remove_cv_t = typename std::remove_cv::type; -template using remove_reference_t = typename std::remove_reference::type; -#endif - -/// Index sequences -#if defined(PYBIND11_CPP14) -using std::index_sequence; -using std::make_index_sequence; -#else -template struct index_sequence { }; -template struct make_index_sequence_impl : make_index_sequence_impl { }; -template struct make_index_sequence_impl <0, S...> { typedef index_sequence type; }; -template using make_index_sequence = typename make_index_sequence_impl::type; -#endif - -/// Make an index sequence of the indices of true arguments -template struct select_indices_impl { using type = ISeq; }; -template struct select_indices_impl, I, B, Bs...> - : select_indices_impl, index_sequence>, I + 1, Bs...> {}; -template using select_indices = typename select_indices_impl, 0, Bs...>::type; - -/// Backports of std::bool_constant and std::negation to accommodate older compilers -template using bool_constant = std::integral_constant; -template struct negation : bool_constant { }; - -template struct void_t_impl { using type = void; }; -template using void_t = typename void_t_impl::type; - -/// Compile-time all/any/none of that check the boolean value of all template types -#if defined(__cpp_fold_expressions) && !(defined(_MSC_VER) && (_MSC_VER < 1916)) -template using all_of = bool_constant<(Ts::value && ...)>; -template using any_of = bool_constant<(Ts::value || ...)>; -#elif !defined(_MSC_VER) -template struct bools {}; -template using all_of = std::is_same< - bools, - bools>; -template using any_of = negation...>>; -#else -// MSVC has trouble with the above, but supports std::conjunction, which we can use instead (albeit -// at a slight loss of compilation efficiency). -template using all_of = std::conjunction; -template using any_of = std::disjunction; -#endif -template using none_of = negation>; - -template class... Predicates> using satisfies_all_of = all_of...>; -template class... Predicates> using satisfies_any_of = any_of...>; -template class... Predicates> using satisfies_none_of = none_of...>; - -/// Strip the class from a method type -template struct remove_class { }; -template struct remove_class { typedef R type(A...); }; -template struct remove_class { typedef R type(A...); }; - -/// Helper template to strip away type modifiers -template struct intrinsic_type { typedef T type; }; -template struct intrinsic_type { typedef typename intrinsic_type::type type; }; -template struct intrinsic_type { typedef typename intrinsic_type::type type; }; -template struct intrinsic_type { typedef typename intrinsic_type::type type; }; -template struct intrinsic_type { typedef typename intrinsic_type::type type; }; -template struct intrinsic_type { typedef typename intrinsic_type::type type; }; -template struct intrinsic_type { typedef typename intrinsic_type::type type; }; -template using intrinsic_t = typename intrinsic_type::type; - -/// Helper type to replace 'void' in some expressions -struct void_type { }; - -/// Helper template which holds a list of types -template struct type_list { }; - -/// Compile-time integer sum -#ifdef __cpp_fold_expressions -template constexpr size_t constexpr_sum(Ts... ns) { return (0 + ... + size_t{ns}); } -#else -constexpr size_t constexpr_sum() { return 0; } -template -constexpr size_t constexpr_sum(T n, Ts... ns) { return size_t{n} + constexpr_sum(ns...); } -#endif - -NAMESPACE_BEGIN(constexpr_impl) -/// Implementation details for constexpr functions -constexpr int first(int i) { return i; } -template -constexpr int first(int i, T v, Ts... vs) { return v ? i : first(i + 1, vs...); } - -constexpr int last(int /*i*/, int result) { return result; } -template -constexpr int last(int i, int result, T v, Ts... vs) { return last(i + 1, v ? i : result, vs...); } -NAMESPACE_END(constexpr_impl) - -/// Return the index of the first type in Ts which satisfies Predicate. Returns sizeof...(Ts) if -/// none match. -template class Predicate, typename... Ts> -constexpr int constexpr_first() { return constexpr_impl::first(0, Predicate::value...); } - -/// Return the index of the last type in Ts which satisfies Predicate, or -1 if none match. -template class Predicate, typename... Ts> -constexpr int constexpr_last() { return constexpr_impl::last(0, -1, Predicate::value...); } - -/// Return the Nth element from the parameter pack -template -struct pack_element { using type = typename pack_element::type; }; -template -struct pack_element<0, T, Ts...> { using type = T; }; - -/// Return the one and only type which matches the predicate, or Default if none match. -/// If more than one type matches the predicate, fail at compile-time. -template class Predicate, typename Default, typename... Ts> -struct exactly_one { - static constexpr auto found = constexpr_sum(Predicate::value...); - static_assert(found <= 1, "Found more than one type matching the predicate"); - - static constexpr auto index = found ? constexpr_first() : 0; - using type = conditional_t::type, Default>; -}; -template class P, typename Default> -struct exactly_one { using type = Default; }; - -template class Predicate, typename Default, typename... Ts> -using exactly_one_t = typename exactly_one::type; - -/// Defer the evaluation of type T until types Us are instantiated -template struct deferred_type { using type = T; }; -template using deferred_t = typename deferred_type::type; - -/// Like is_base_of, but requires a strict base (i.e. `is_strict_base_of::value == false`, -/// unlike `std::is_base_of`) -template using is_strict_base_of = bool_constant< - std::is_base_of::value && !std::is_same::value>; - -/// Like is_base_of, but also requires that the base type is accessible (i.e. that a Derived pointer -/// can be converted to a Base pointer) -template using is_accessible_base_of = bool_constant< - std::is_base_of::value && std::is_convertible::value>; - -template class Base> -struct is_template_base_of_impl { - template static std::true_type check(Base *); - static std::false_type check(...); -}; - -/// Check if a template is the base of a type. For example: -/// `is_template_base_of` is true if `struct T : Base {}` where U can be anything -template class Base, typename T> -#if !defined(_MSC_VER) -using is_template_base_of = decltype(is_template_base_of_impl::check((intrinsic_t*)nullptr)); -#else // MSVC2015 has trouble with decltype in template aliases -struct is_template_base_of : decltype(is_template_base_of_impl::check((intrinsic_t*)nullptr)) { }; -#endif - -/// Check if T is an instantiation of the template `Class`. For example: -/// `is_instantiation` is true if `T == shared_ptr` where U can be anything. -template class Class, typename T> -struct is_instantiation : std::false_type { }; -template class Class, typename... Us> -struct is_instantiation> : std::true_type { }; - -/// Check if T is std::shared_ptr where U can be anything -template using is_shared_ptr = is_instantiation; - -/// Check if T looks like an input iterator -template struct is_input_iterator : std::false_type {}; -template -struct is_input_iterator()), decltype(++std::declval())>> - : std::true_type {}; - -template using is_function_pointer = bool_constant< - std::is_pointer::value && std::is_function::type>::value>; - -template struct strip_function_object { - using type = typename remove_class::type; -}; - -// Extracts the function signature from a function, function pointer or lambda. -template > -using function_signature_t = conditional_t< - std::is_function::value, - F, - typename conditional_t< - std::is_pointer::value || std::is_member_pointer::value, - std::remove_pointer, - strip_function_object - >::type ->; - -/// Returns true if the type looks like a lambda: that is, isn't a function, pointer or member -/// pointer. Note that this can catch all sorts of other things, too; this is intended to be used -/// in a place where passing a lambda makes sense. -template using is_lambda = satisfies_none_of, - std::is_function, std::is_pointer, std::is_member_pointer>; - -/// Ignore that a variable is unused in compiler warnings -inline void ignore_unused(const int *) { } - -/// Apply a function over each element of a parameter pack -#ifdef __cpp_fold_expressions -#define PYBIND11_EXPAND_SIDE_EFFECTS(PATTERN) (((PATTERN), void()), ...) -#else -using expand_side_effects = bool[]; -#define PYBIND11_EXPAND_SIDE_EFFECTS(PATTERN) pybind11::detail::expand_side_effects{ ((PATTERN), void(), false)..., false } -#endif - -NAMESPACE_END(detail) - -/// C++ bindings of builtin Python exceptions -class builtin_exception : public std::runtime_error { -public: - using std::runtime_error::runtime_error; - /// Set the error using the Python C API - virtual void set_error() const = 0; -}; - -#define PYBIND11_RUNTIME_EXCEPTION(name, type) \ - class name : public builtin_exception { public: \ - using builtin_exception::builtin_exception; \ - name() : name("") { } \ - void set_error() const override { PyErr_SetString(type, what()); } \ - }; - -PYBIND11_RUNTIME_EXCEPTION(stop_iteration, PyExc_StopIteration) -PYBIND11_RUNTIME_EXCEPTION(index_error, PyExc_IndexError) -PYBIND11_RUNTIME_EXCEPTION(key_error, PyExc_KeyError) -PYBIND11_RUNTIME_EXCEPTION(value_error, PyExc_ValueError) -PYBIND11_RUNTIME_EXCEPTION(type_error, PyExc_TypeError) -PYBIND11_RUNTIME_EXCEPTION(buffer_error, PyExc_BufferError) -PYBIND11_RUNTIME_EXCEPTION(cast_error, PyExc_RuntimeError) /// Thrown when pybind11::cast or handle::call fail due to a type casting error -PYBIND11_RUNTIME_EXCEPTION(reference_cast_error, PyExc_RuntimeError) /// Used internally - -[[noreturn]] PYBIND11_NOINLINE inline void pybind11_fail(const char *reason) { throw std::runtime_error(reason); } -[[noreturn]] PYBIND11_NOINLINE inline void pybind11_fail(const std::string &reason) { throw std::runtime_error(reason); } - -template struct format_descriptor { }; - -NAMESPACE_BEGIN(detail) -// Returns the index of the given type in the type char array below, and in the list in numpy.h -// The order here is: bool; 8 ints ((signed,unsigned)x(8,16,32,64)bits); float,double,long double; -// complex float,double,long double. Note that the long double types only participate when long -// double is actually longer than double (it isn't under MSVC). -// NB: not only the string below but also complex.h and numpy.h rely on this order. -template struct is_fmt_numeric { static constexpr bool value = false; }; -template struct is_fmt_numeric::value>> { - static constexpr bool value = true; - static constexpr int index = std::is_same::value ? 0 : 1 + ( - std::is_integral::value ? detail::log2(sizeof(T))*2 + std::is_unsigned::value : 8 + ( - std::is_same::value ? 1 : std::is_same::value ? 2 : 0)); -}; -NAMESPACE_END(detail) - -template struct format_descriptor::value>> { - static constexpr const char c = "?bBhHiIqQfdg"[detail::is_fmt_numeric::index]; - static constexpr const char value[2] = { c, '\0' }; - static std::string format() { return std::string(1, c); } -}; - -#if !defined(PYBIND11_CPP17) - -template constexpr const char format_descriptor< - T, detail::enable_if_t::value>>::value[2]; - -#endif - -/// RAII wrapper that temporarily clears any Python error state -struct error_scope { - PyObject *type, *value, *trace; - error_scope() { PyErr_Fetch(&type, &value, &trace); } - ~error_scope() { PyErr_Restore(type, value, trace); } -}; - -/// Dummy destructor wrapper that can be used to expose classes with a private destructor -struct nodelete { template void operator()(T*) { } }; - -NAMESPACE_BEGIN(detail) -template -struct overload_cast_impl { - constexpr overload_cast_impl() {} // MSVC 2015 needs this - - template - constexpr auto operator()(Return (*pf)(Args...)) const noexcept - -> decltype(pf) { return pf; } - - template - constexpr auto operator()(Return (Class::*pmf)(Args...), std::false_type = {}) const noexcept - -> decltype(pmf) { return pmf; } - - template - constexpr auto operator()(Return (Class::*pmf)(Args...) const, std::true_type) const noexcept - -> decltype(pmf) { return pmf; } -}; -NAMESPACE_END(detail) - -// overload_cast requires variable templates: C++14 -#if defined(PYBIND11_CPP14) -#define PYBIND11_OVERLOAD_CAST 1 -/// Syntax sugar for resolving overloaded function pointers: -/// - regular: static_cast(&Class::func) -/// - sweet: overload_cast(&Class::func) -template -static constexpr detail::overload_cast_impl overload_cast = {}; -// MSVC 2015 only accepts this particular initialization syntax for this variable template. -#endif - -/// Const member function selector for overload_cast -/// - regular: static_cast(&Class::func) -/// - sweet: overload_cast(&Class::func, const_) -static constexpr auto const_ = std::true_type{}; - -#if !defined(PYBIND11_CPP14) // no overload_cast: providing something that static_assert-fails: -template struct overload_cast { - static_assert(detail::deferred_t::value, - "pybind11::overload_cast<...> requires compiling in C++14 mode"); -}; -#endif // overload_cast - -NAMESPACE_BEGIN(detail) - -// Adaptor for converting arbitrary container arguments into a vector; implicitly convertible from -// any standard container (or C-style array) supporting std::begin/std::end, any singleton -// arithmetic type (if T is arithmetic), or explicitly constructible from an iterator pair. -template -class any_container { - std::vector v; -public: - any_container() = default; - - // Can construct from a pair of iterators - template ::value>> - any_container(It first, It last) : v(first, last) { } - - // Implicit conversion constructor from any arbitrary container type with values convertible to T - template ())), T>::value>> - any_container(const Container &c) : any_container(std::begin(c), std::end(c)) { } - - // initializer_list's aren't deducible, so don't get matched by the above template; we need this - // to explicitly allow implicit conversion from one: - template ::value>> - any_container(const std::initializer_list &c) : any_container(c.begin(), c.end()) { } - - // Avoid copying if given an rvalue vector of the correct type. - any_container(std::vector &&v) : v(std::move(v)) { } - - // Moves the vector out of an rvalue any_container - operator std::vector &&() && { return std::move(v); } - - // Dereferencing obtains a reference to the underlying vector - std::vector &operator*() { return v; } - const std::vector &operator*() const { return v; } - - // -> lets you call methods on the underlying vector - std::vector *operator->() { return &v; } - const std::vector *operator->() const { return &v; } -}; - -NAMESPACE_END(detail) - - - -NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/pybind11/include/pybind11/detail/descr.h b/pybind11/include/pybind11/detail/descr.h deleted file mode 100644 index 8d404e5..0000000 --- a/pybind11/include/pybind11/detail/descr.h +++ /dev/null @@ -1,100 +0,0 @@ -/* - pybind11/detail/descr.h: Helper type for concatenating type signatures at compile time - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#pragma once - -#include "common.h" - -NAMESPACE_BEGIN(PYBIND11_NAMESPACE) -NAMESPACE_BEGIN(detail) - -#if !defined(_MSC_VER) -# define PYBIND11_DESCR_CONSTEXPR static constexpr -#else -# define PYBIND11_DESCR_CONSTEXPR const -#endif - -/* Concatenate type signatures at compile time */ -template -struct descr { - char text[N + 1]; - - constexpr descr() : text{'\0'} { } - constexpr descr(char const (&s)[N+1]) : descr(s, make_index_sequence()) { } - - template - constexpr descr(char const (&s)[N+1], index_sequence) : text{s[Is]..., '\0'} { } - - template - constexpr descr(char c, Chars... cs) : text{c, static_cast(cs)..., '\0'} { } - - static constexpr std::array types() { - return {{&typeid(Ts)..., nullptr}}; - } -}; - -template -constexpr descr plus_impl(const descr &a, const descr &b, - index_sequence, index_sequence) { - return {a.text[Is1]..., b.text[Is2]...}; -} - -template -constexpr descr operator+(const descr &a, const descr &b) { - return plus_impl(a, b, make_index_sequence(), make_index_sequence()); -} - -template -constexpr descr _(char const(&text)[N]) { return descr(text); } -constexpr descr<0> _(char const(&)[1]) { return {}; } - -template struct int_to_str : int_to_str { }; -template struct int_to_str<0, Digits...> { - static constexpr auto digits = descr(('0' + Digits)...); -}; - -// Ternary description (like std::conditional) -template -constexpr enable_if_t> _(char const(&text1)[N1], char const(&)[N2]) { - return _(text1); -} -template -constexpr enable_if_t> _(char const(&)[N1], char const(&text2)[N2]) { - return _(text2); -} - -template -constexpr enable_if_t _(const T1 &d, const T2 &) { return d; } -template -constexpr enable_if_t _(const T1 &, const T2 &d) { return d; } - -template auto constexpr _() -> decltype(int_to_str::digits) { - return int_to_str::digits; -} - -template constexpr descr<1, Type> _() { return {'%'}; } - -constexpr descr<0> concat() { return {}; } - -template -constexpr descr concat(const descr &descr) { return descr; } - -template -constexpr auto concat(const descr &d, const Args &...args) - -> decltype(std::declval>() + concat(args...)) { - return d + _(", ") + concat(args...); -} - -template -constexpr descr type_descr(const descr &descr) { - return _("{") + descr + _("}"); -} - -NAMESPACE_END(detail) -NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/pybind11/include/pybind11/detail/init.h b/pybind11/include/pybind11/detail/init.h deleted file mode 100644 index acfe00b..0000000 --- a/pybind11/include/pybind11/detail/init.h +++ /dev/null @@ -1,335 +0,0 @@ -/* - pybind11/detail/init.h: init factory function implementation and support code. - - Copyright (c) 2017 Jason Rhinelander - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#pragma once - -#include "class.h" - -NAMESPACE_BEGIN(PYBIND11_NAMESPACE) -NAMESPACE_BEGIN(detail) - -template <> -class type_caster { -public: - bool load(handle h, bool) { - value = reinterpret_cast(h.ptr()); - return true; - } - - template using cast_op_type = value_and_holder &; - operator value_and_holder &() { return *value; } - static constexpr auto name = _(); - -private: - value_and_holder *value = nullptr; -}; - -NAMESPACE_BEGIN(initimpl) - -inline void no_nullptr(void *ptr) { - if (!ptr) throw type_error("pybind11::init(): factory function returned nullptr"); -} - -// Implementing functions for all forms of py::init<...> and py::init(...) -template using Cpp = typename Class::type; -template using Alias = typename Class::type_alias; -template using Holder = typename Class::holder_type; - -template using is_alias_constructible = std::is_constructible, Cpp &&>; - -// Takes a Cpp pointer and returns true if it actually is a polymorphic Alias instance. -template = 0> -bool is_alias(Cpp *ptr) { - return dynamic_cast *>(ptr) != nullptr; -} -// Failing fallback version of the above for a no-alias class (always returns false) -template -constexpr bool is_alias(void *) { return false; } - -// Constructs and returns a new object; if the given arguments don't map to a constructor, we fall -// back to brace aggregate initiailization so that for aggregate initialization can be used with -// py::init, e.g. `py::init` to initialize a `struct T { int a; int b; }`. For -// non-aggregate types, we need to use an ordinary T(...) constructor (invoking as `T{...}` usually -// works, but will not do the expected thing when `T` has an `initializer_list` constructor). -template ::value, int> = 0> -inline Class *construct_or_initialize(Args &&...args) { return new Class(std::forward(args)...); } -template ::value, int> = 0> -inline Class *construct_or_initialize(Args &&...args) { return new Class{std::forward(args)...}; } - -// Attempts to constructs an alias using a `Alias(Cpp &&)` constructor. This allows types with -// an alias to provide only a single Cpp factory function as long as the Alias can be -// constructed from an rvalue reference of the base Cpp type. This means that Alias classes -// can, when appropriate, simply define a `Alias(Cpp &&)` constructor rather than needing to -// inherit all the base class constructors. -template -void construct_alias_from_cpp(std::true_type /*is_alias_constructible*/, - value_and_holder &v_h, Cpp &&base) { - v_h.value_ptr() = new Alias(std::move(base)); -} -template -[[noreturn]] void construct_alias_from_cpp(std::false_type /*!is_alias_constructible*/, - value_and_holder &, Cpp &&) { - throw type_error("pybind11::init(): unable to convert returned instance to required " - "alias class: no `Alias(Class &&)` constructor available"); -} - -// Error-generating fallback for factories that don't match one of the below construction -// mechanisms. -template -void construct(...) { - static_assert(!std::is_same::value /* always false */, - "pybind11::init(): init function must return a compatible pointer, " - "holder, or value"); -} - -// Pointer return v1: the factory function returns a class pointer for a registered class. -// If we don't need an alias (because this class doesn't have one, or because the final type is -// inherited on the Python side) we can simply take over ownership. Otherwise we need to try to -// construct an Alias from the returned base instance. -template -void construct(value_and_holder &v_h, Cpp *ptr, bool need_alias) { - no_nullptr(ptr); - if (Class::has_alias && need_alias && !is_alias(ptr)) { - // We're going to try to construct an alias by moving the cpp type. Whether or not - // that succeeds, we still need to destroy the original cpp pointer (either the - // moved away leftover, if the alias construction works, or the value itself if we - // throw an error), but we can't just call `delete ptr`: it might have a special - // deleter, or might be shared_from_this. So we construct a holder around it as if - // it was a normal instance, then steal the holder away into a local variable; thus - // the holder and destruction happens when we leave the C++ scope, and the holder - // class gets to handle the destruction however it likes. - v_h.value_ptr() = ptr; - v_h.set_instance_registered(true); // To prevent init_instance from registering it - v_h.type->init_instance(v_h.inst, nullptr); // Set up the holder - Holder temp_holder(std::move(v_h.holder>())); // Steal the holder - v_h.type->dealloc(v_h); // Destroys the moved-out holder remains, resets value ptr to null - v_h.set_instance_registered(false); - - construct_alias_from_cpp(is_alias_constructible{}, v_h, std::move(*ptr)); - } else { - // Otherwise the type isn't inherited, so we don't need an Alias - v_h.value_ptr() = ptr; - } -} - -// Pointer return v2: a factory that always returns an alias instance ptr. We simply take over -// ownership of the pointer. -template = 0> -void construct(value_and_holder &v_h, Alias *alias_ptr, bool) { - no_nullptr(alias_ptr); - v_h.value_ptr() = static_cast *>(alias_ptr); -} - -// Holder return: copy its pointer, and move or copy the returned holder into the new instance's -// holder. This also handles types like std::shared_ptr and std::unique_ptr where T is a -// derived type (through those holder's implicit conversion from derived class holder constructors). -template -void construct(value_and_holder &v_h, Holder holder, bool need_alias) { - auto *ptr = holder_helper>::get(holder); - // If we need an alias, check that the held pointer is actually an alias instance - if (Class::has_alias && need_alias && !is_alias(ptr)) - throw type_error("pybind11::init(): construction failed: returned holder-wrapped instance " - "is not an alias instance"); - - v_h.value_ptr() = ptr; - v_h.type->init_instance(v_h.inst, &holder); -} - -// return-by-value version 1: returning a cpp class by value. If the class has an alias and an -// alias is required the alias must have an `Alias(Cpp &&)` constructor so that we can construct -// the alias from the base when needed (i.e. because of Python-side inheritance). When we don't -// need it, we simply move-construct the cpp value into a new instance. -template -void construct(value_and_holder &v_h, Cpp &&result, bool need_alias) { - static_assert(std::is_move_constructible>::value, - "pybind11::init() return-by-value factory function requires a movable class"); - if (Class::has_alias && need_alias) - construct_alias_from_cpp(is_alias_constructible{}, v_h, std::move(result)); - else - v_h.value_ptr() = new Cpp(std::move(result)); -} - -// return-by-value version 2: returning a value of the alias type itself. We move-construct an -// Alias instance (even if no the python-side inheritance is involved). The is intended for -// cases where Alias initialization is always desired. -template -void construct(value_and_holder &v_h, Alias &&result, bool) { - static_assert(std::is_move_constructible>::value, - "pybind11::init() return-by-alias-value factory function requires a movable alias class"); - v_h.value_ptr() = new Alias(std::move(result)); -} - -// Implementing class for py::init<...>() -template -struct constructor { - template = 0> - static void execute(Class &cl, const Extra&... extra) { - cl.def("__init__", [](value_and_holder &v_h, Args... args) { - v_h.value_ptr() = construct_or_initialize>(std::forward(args)...); - }, is_new_style_constructor(), extra...); - } - - template , Args...>::value, int> = 0> - static void execute(Class &cl, const Extra&... extra) { - cl.def("__init__", [](value_and_holder &v_h, Args... args) { - if (Py_TYPE(v_h.inst) == v_h.type->type) - v_h.value_ptr() = construct_or_initialize>(std::forward(args)...); - else - v_h.value_ptr() = construct_or_initialize>(std::forward(args)...); - }, is_new_style_constructor(), extra...); - } - - template , Args...>::value, int> = 0> - static void execute(Class &cl, const Extra&... extra) { - cl.def("__init__", [](value_and_holder &v_h, Args... args) { - v_h.value_ptr() = construct_or_initialize>(std::forward(args)...); - }, is_new_style_constructor(), extra...); - } -}; - -// Implementing class for py::init_alias<...>() -template struct alias_constructor { - template , Args...>::value, int> = 0> - static void execute(Class &cl, const Extra&... extra) { - cl.def("__init__", [](value_and_holder &v_h, Args... args) { - v_h.value_ptr() = construct_or_initialize>(std::forward(args)...); - }, is_new_style_constructor(), extra...); - } -}; - -// Implementation class for py::init(Func) and py::init(Func, AliasFunc) -template , typename = function_signature_t> -struct factory; - -// Specialization for py::init(Func) -template -struct factory { - remove_reference_t class_factory; - - factory(Func &&f) : class_factory(std::forward(f)) { } - - // The given class either has no alias or has no separate alias factory; - // this always constructs the class itself. If the class is registered with an alias - // type and an alias instance is needed (i.e. because the final type is a Python class - // inheriting from the C++ type) the returned value needs to either already be an alias - // instance, or the alias needs to be constructible from a `Class &&` argument. - template - void execute(Class &cl, const Extra &...extra) && { - #if defined(PYBIND11_CPP14) - cl.def("__init__", [func = std::move(class_factory)] - #else - auto &func = class_factory; - cl.def("__init__", [func] - #endif - (value_and_holder &v_h, Args... args) { - construct(v_h, func(std::forward(args)...), - Py_TYPE(v_h.inst) != v_h.type->type); - }, is_new_style_constructor(), extra...); - } -}; - -// Specialization for py::init(Func, AliasFunc) -template -struct factory { - static_assert(sizeof...(CArgs) == sizeof...(AArgs), - "pybind11::init(class_factory, alias_factory): class and alias factories " - "must have identical argument signatures"); - static_assert(all_of...>::value, - "pybind11::init(class_factory, alias_factory): class and alias factories " - "must have identical argument signatures"); - - remove_reference_t class_factory; - remove_reference_t alias_factory; - - factory(CFunc &&c, AFunc &&a) - : class_factory(std::forward(c)), alias_factory(std::forward(a)) { } - - // The class factory is called when the `self` type passed to `__init__` is the direct - // class (i.e. not inherited), the alias factory when `self` is a Python-side subtype. - template - void execute(Class &cl, const Extra&... extra) && { - static_assert(Class::has_alias, "The two-argument version of `py::init()` can " - "only be used if the class has an alias"); - #if defined(PYBIND11_CPP14) - cl.def("__init__", [class_func = std::move(class_factory), alias_func = std::move(alias_factory)] - #else - auto &class_func = class_factory; - auto &alias_func = alias_factory; - cl.def("__init__", [class_func, alias_func] - #endif - (value_and_holder &v_h, CArgs... args) { - if (Py_TYPE(v_h.inst) == v_h.type->type) - // If the instance type equals the registered type we don't have inheritance, so - // don't need the alias and can construct using the class function: - construct(v_h, class_func(std::forward(args)...), false); - else - construct(v_h, alias_func(std::forward(args)...), true); - }, is_new_style_constructor(), extra...); - } -}; - -/// Set just the C++ state. Same as `__init__`. -template -void setstate(value_and_holder &v_h, T &&result, bool need_alias) { - construct(v_h, std::forward(result), need_alias); -} - -/// Set both the C++ and Python states -template ::value, int> = 0> -void setstate(value_and_holder &v_h, std::pair &&result, bool need_alias) { - construct(v_h, std::move(result.first), need_alias); - setattr((PyObject *) v_h.inst, "__dict__", result.second); -} - -/// Implementation for py::pickle(GetState, SetState) -template , typename = function_signature_t> -struct pickle_factory; - -template -struct pickle_factory { - static_assert(std::is_same, intrinsic_t>::value, - "The type returned by `__getstate__` must be the same " - "as the argument accepted by `__setstate__`"); - - remove_reference_t get; - remove_reference_t set; - - pickle_factory(Get get, Set set) - : get(std::forward(get)), set(std::forward(set)) { } - - template - void execute(Class &cl, const Extra &...extra) && { - cl.def("__getstate__", std::move(get)); - -#if defined(PYBIND11_CPP14) - cl.def("__setstate__", [func = std::move(set)] -#else - auto &func = set; - cl.def("__setstate__", [func] -#endif - (value_and_holder &v_h, ArgState state) { - setstate(v_h, func(std::forward(state)), - Py_TYPE(v_h.inst) != v_h.type->type); - }, is_new_style_constructor(), extra...); - } -}; - -NAMESPACE_END(initimpl) -NAMESPACE_END(detail) -NAMESPACE_END(pybind11) diff --git a/pybind11/include/pybind11/detail/internals.h b/pybind11/include/pybind11/detail/internals.h deleted file mode 100644 index 067780c..0000000 --- a/pybind11/include/pybind11/detail/internals.h +++ /dev/null @@ -1,336 +0,0 @@ -/* - pybind11/detail/internals.h: Internal data structure and related functions - - Copyright (c) 2017 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#pragma once - -#include "../pytypes.h" - -NAMESPACE_BEGIN(PYBIND11_NAMESPACE) -NAMESPACE_BEGIN(detail) -// Forward declarations -inline PyTypeObject *make_static_property_type(); -inline PyTypeObject *make_default_metaclass(); -inline PyObject *make_object_base_type(PyTypeObject *metaclass); - -// The old Python Thread Local Storage (TLS) API is deprecated in Python 3.7 in favor of the new -// Thread Specific Storage (TSS) API. -#if PY_VERSION_HEX >= 0x03070000 -# define PYBIND11_TLS_KEY_INIT(var) Py_tss_t *var = nullptr -# define PYBIND11_TLS_GET_VALUE(key) PyThread_tss_get((key)) -# define PYBIND11_TLS_REPLACE_VALUE(key, value) PyThread_tss_set((key), (value)) -# define PYBIND11_TLS_DELETE_VALUE(key) PyThread_tss_set((key), nullptr) -#else - // Usually an int but a long on Cygwin64 with Python 3.x -# define PYBIND11_TLS_KEY_INIT(var) decltype(PyThread_create_key()) var = 0 -# define PYBIND11_TLS_GET_VALUE(key) PyThread_get_key_value((key)) -# if PY_MAJOR_VERSION < 3 -# define PYBIND11_TLS_DELETE_VALUE(key) \ - PyThread_delete_key_value(key) -# define PYBIND11_TLS_REPLACE_VALUE(key, value) \ - do { \ - PyThread_delete_key_value((key)); \ - PyThread_set_key_value((key), (value)); \ - } while (false) -# else -# define PYBIND11_TLS_DELETE_VALUE(key) \ - PyThread_set_key_value((key), nullptr) -# define PYBIND11_TLS_REPLACE_VALUE(key, value) \ - PyThread_set_key_value((key), (value)) -# endif -#endif - -// Python loads modules by default with dlopen with the RTLD_LOCAL flag; under libc++ and possibly -// other STLs, this means `typeid(A)` from one module won't equal `typeid(A)` from another module -// even when `A` is the same, non-hidden-visibility type (e.g. from a common include). Under -// libstdc++, this doesn't happen: equality and the type_index hash are based on the type name, -// which works. If not under a known-good stl, provide our own name-based hash and equality -// functions that use the type name. -#if defined(__GLIBCXX__) -inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) { return lhs == rhs; } -using type_hash = std::hash; -using type_equal_to = std::equal_to; -#else -inline bool same_type(const std::type_info &lhs, const std::type_info &rhs) { - return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0; -} - -struct type_hash { - size_t operator()(const std::type_index &t) const { - size_t hash = 5381; - const char *ptr = t.name(); - while (auto c = static_cast(*ptr++)) - hash = (hash * 33) ^ c; - return hash; - } -}; - -struct type_equal_to { - bool operator()(const std::type_index &lhs, const std::type_index &rhs) const { - return lhs.name() == rhs.name() || std::strcmp(lhs.name(), rhs.name()) == 0; - } -}; -#endif - -template -using type_map = std::unordered_map; - -struct overload_hash { - inline size_t operator()(const std::pair& v) const { - size_t value = std::hash()(v.first); - value ^= std::hash()(v.second) + 0x9e3779b9 + (value<<6) + (value>>2); - return value; - } -}; - -/// Internal data structure used to track registered instances and types. -/// Whenever binary incompatible changes are made to this structure, -/// `PYBIND11_INTERNALS_VERSION` must be incremented. -struct internals { - type_map registered_types_cpp; // std::type_index -> pybind11's type information - std::unordered_map> registered_types_py; // PyTypeObject* -> base type_info(s) - std::unordered_multimap registered_instances; // void * -> instance* - std::unordered_set, overload_hash> inactive_overload_cache; - type_map> direct_conversions; - std::unordered_map> patients; - std::forward_list registered_exception_translators; - std::unordered_map shared_data; // Custom data to be shared across extensions - std::vector loader_patient_stack; // Used by `loader_life_support` - std::forward_list static_strings; // Stores the std::strings backing detail::c_str() - PyTypeObject *static_property_type; - PyTypeObject *default_metaclass; - PyObject *instance_base; -#if defined(WITH_THREAD) - PYBIND11_TLS_KEY_INIT(tstate); - PyInterpreterState *istate = nullptr; -#endif -}; - -/// Additional type information which does not fit into the PyTypeObject. -/// Changes to this struct also require bumping `PYBIND11_INTERNALS_VERSION`. -struct type_info { - PyTypeObject *type; - const std::type_info *cpptype; - size_t type_size, type_align, holder_size_in_ptrs; - void *(*operator_new)(size_t); - void (*init_instance)(instance *, const void *); - void (*dealloc)(value_and_holder &v_h); - std::vector implicit_conversions; - std::vector> implicit_casts; - std::vector *direct_conversions; - buffer_info *(*get_buffer)(PyObject *, void *) = nullptr; - void *get_buffer_data = nullptr; - void *(*module_local_load)(PyObject *, const type_info *) = nullptr; - /* A simple type never occurs as a (direct or indirect) parent - * of a class that makes use of multiple inheritance */ - bool simple_type : 1; - /* True if there is no multiple inheritance in this type's inheritance tree */ - bool simple_ancestors : 1; - /* for base vs derived holder_type checks */ - bool default_holder : 1; - /* true if this is a type registered with py::module_local */ - bool module_local : 1; -}; - -/// Tracks the `internals` and `type_info` ABI version independent of the main library version -#define PYBIND11_INTERNALS_VERSION 3 - -/// On MSVC, debug and release builds are not ABI-compatible! -#if defined(_MSC_VER) && defined(_DEBUG) -# define PYBIND11_BUILD_TYPE "_debug" -#else -# define PYBIND11_BUILD_TYPE "" -#endif - -/// Let's assume that different compilers are ABI-incompatible. -#if defined(_MSC_VER) -# define PYBIND11_COMPILER_TYPE "_msvc" -#elif defined(__INTEL_COMPILER) -# define PYBIND11_COMPILER_TYPE "_icc" -#elif defined(__clang__) -# define PYBIND11_COMPILER_TYPE "_clang" -#elif defined(__PGI) -# define PYBIND11_COMPILER_TYPE "_pgi" -#elif defined(__MINGW32__) -# define PYBIND11_COMPILER_TYPE "_mingw" -#elif defined(__CYGWIN__) -# define PYBIND11_COMPILER_TYPE "_gcc_cygwin" -#elif defined(__GNUC__) -# define PYBIND11_COMPILER_TYPE "_gcc" -#else -# define PYBIND11_COMPILER_TYPE "_unknown" -#endif - -#if defined(_LIBCPP_VERSION) -# define PYBIND11_STDLIB "_libcpp" -#elif defined(__GLIBCXX__) || defined(__GLIBCPP__) -# define PYBIND11_STDLIB "_libstdcpp" -#else -# define PYBIND11_STDLIB "" -#endif - -/// On Linux/OSX, changes in __GXX_ABI_VERSION__ indicate ABI incompatibility. -#if defined(__GXX_ABI_VERSION) -# define PYBIND11_BUILD_ABI "_cxxabi" PYBIND11_TOSTRING(__GXX_ABI_VERSION) -#else -# define PYBIND11_BUILD_ABI "" -#endif - -#if defined(WITH_THREAD) -# define PYBIND11_INTERNALS_KIND "" -#else -# define PYBIND11_INTERNALS_KIND "_without_thread" -#endif - -#define PYBIND11_INTERNALS_ID "__pybind11_internals_v" \ - PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_COMPILER_TYPE PYBIND11_STDLIB PYBIND11_BUILD_ABI PYBIND11_BUILD_TYPE "__" - -#define PYBIND11_MODULE_LOCAL_ID "__pybind11_module_local_v" \ - PYBIND11_TOSTRING(PYBIND11_INTERNALS_VERSION) PYBIND11_INTERNALS_KIND PYBIND11_COMPILER_TYPE PYBIND11_STDLIB PYBIND11_BUILD_ABI PYBIND11_BUILD_TYPE "__" - -/// Each module locally stores a pointer to the `internals` data. The data -/// itself is shared among modules with the same `PYBIND11_INTERNALS_ID`. -inline internals **&get_internals_pp() { - static internals **internals_pp = nullptr; - return internals_pp; -} - -inline void translate_exception(std::exception_ptr p) { - try { - if (p) std::rethrow_exception(p); - } catch (error_already_set &e) { e.restore(); return; - } catch (const builtin_exception &e) { e.set_error(); return; - } catch (const std::bad_alloc &e) { PyErr_SetString(PyExc_MemoryError, e.what()); return; - } catch (const std::domain_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; - } catch (const std::invalid_argument &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; - } catch (const std::length_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; - } catch (const std::out_of_range &e) { PyErr_SetString(PyExc_IndexError, e.what()); return; - } catch (const std::range_error &e) { PyErr_SetString(PyExc_ValueError, e.what()); return; - } catch (const std::exception &e) { PyErr_SetString(PyExc_RuntimeError, e.what()); return; - } catch (...) { - PyErr_SetString(PyExc_RuntimeError, "Caught an unknown exception!"); - return; - } -} - -#if !defined(__GLIBCXX__) -inline void translate_local_exception(std::exception_ptr p) { - try { - if (p) std::rethrow_exception(p); - } catch (error_already_set &e) { e.restore(); return; - } catch (const builtin_exception &e) { e.set_error(); return; - } -} -#endif - -/// Return a reference to the current `internals` data -PYBIND11_NOINLINE inline internals &get_internals() { - auto **&internals_pp = get_internals_pp(); - if (internals_pp && *internals_pp) - return **internals_pp; - - // Ensure that the GIL is held since we will need to make Python calls. - // Cannot use py::gil_scoped_acquire here since that constructor calls get_internals. - struct gil_scoped_acquire_local { - gil_scoped_acquire_local() : state (PyGILState_Ensure()) {} - ~gil_scoped_acquire_local() { PyGILState_Release(state); } - const PyGILState_STATE state; - } gil; - - constexpr auto *id = PYBIND11_INTERNALS_ID; - auto builtins = handle(PyEval_GetBuiltins()); - if (builtins.contains(id) && isinstance(builtins[id])) { - internals_pp = static_cast(capsule(builtins[id])); - - // We loaded builtins through python's builtins, which means that our `error_already_set` - // and `builtin_exception` may be different local classes than the ones set up in the - // initial exception translator, below, so add another for our local exception classes. - // - // libstdc++ doesn't require this (types there are identified only by name) -#if !defined(__GLIBCXX__) - (*internals_pp)->registered_exception_translators.push_front(&translate_local_exception); -#endif - } else { - if (!internals_pp) internals_pp = new internals*(); - auto *&internals_ptr = *internals_pp; - internals_ptr = new internals(); -#if defined(WITH_THREAD) - PyEval_InitThreads(); - PyThreadState *tstate = PyThreadState_Get(); - #if PY_VERSION_HEX >= 0x03070000 - internals_ptr->tstate = PyThread_tss_alloc(); - if (!internals_ptr->tstate || PyThread_tss_create(internals_ptr->tstate)) - pybind11_fail("get_internals: could not successfully initialize the TSS key!"); - PyThread_tss_set(internals_ptr->tstate, tstate); - #else - internals_ptr->tstate = PyThread_create_key(); - if (internals_ptr->tstate == -1) - pybind11_fail("get_internals: could not successfully initialize the TLS key!"); - PyThread_set_key_value(internals_ptr->tstate, tstate); - #endif - internals_ptr->istate = tstate->interp; -#endif - builtins[id] = capsule(internals_pp); - internals_ptr->registered_exception_translators.push_front(&translate_exception); - internals_ptr->static_property_type = make_static_property_type(); - internals_ptr->default_metaclass = make_default_metaclass(); - internals_ptr->instance_base = make_object_base_type(internals_ptr->default_metaclass); - } - return **internals_pp; -} - -/// Works like `internals.registered_types_cpp`, but for module-local registered types: -inline type_map ®istered_local_types_cpp() { - static type_map locals{}; - return locals; -} - -/// Constructs a std::string with the given arguments, stores it in `internals`, and returns its -/// `c_str()`. Such strings objects have a long storage duration -- the internal strings are only -/// cleared when the program exits or after interpreter shutdown (when embedding), and so are -/// suitable for c-style strings needed by Python internals (such as PyTypeObject's tp_name). -template -const char *c_str(Args &&...args) { - auto &strings = get_internals().static_strings; - strings.emplace_front(std::forward(args)...); - return strings.front().c_str(); -} - -NAMESPACE_END(detail) - -/// Returns a named pointer that is shared among all extension modules (using the same -/// pybind11 version) running in the current interpreter. Names starting with underscores -/// are reserved for internal usage. Returns `nullptr` if no matching entry was found. -inline PYBIND11_NOINLINE void *get_shared_data(const std::string &name) { - auto &internals = detail::get_internals(); - auto it = internals.shared_data.find(name); - return it != internals.shared_data.end() ? it->second : nullptr; -} - -/// Set the shared data that can be later recovered by `get_shared_data()`. -inline PYBIND11_NOINLINE void *set_shared_data(const std::string &name, void *data) { - detail::get_internals().shared_data[name] = data; - return data; -} - -/// Returns a typed reference to a shared data entry (by using `get_shared_data()`) if -/// such entry exists. Otherwise, a new object of default-constructible type `T` is -/// added to the shared data under the given name and a reference to it is returned. -template -T &get_or_create_shared_data(const std::string &name) { - auto &internals = detail::get_internals(); - auto it = internals.shared_data.find(name); - T *ptr = (T *) (it != internals.shared_data.end() ? it->second : nullptr); - if (!ptr) { - ptr = new T(); - internals.shared_data[name] = ptr; - } - return *ptr; -} - -NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/pybind11/include/pybind11/detail/typeid.h b/pybind11/include/pybind11/detail/typeid.h deleted file mode 100644 index 9c8a4fc..0000000 --- a/pybind11/include/pybind11/detail/typeid.h +++ /dev/null @@ -1,55 +0,0 @@ -/* - pybind11/detail/typeid.h: Compiler-independent access to type identifiers - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#pragma once - -#include -#include - -#if defined(__GNUG__) -#include -#endif - -#include "common.h" - -NAMESPACE_BEGIN(PYBIND11_NAMESPACE) -NAMESPACE_BEGIN(detail) -/// Erase all occurrences of a substring -inline void erase_all(std::string &string, const std::string &search) { - for (size_t pos = 0;;) { - pos = string.find(search, pos); - if (pos == std::string::npos) break; - string.erase(pos, search.length()); - } -} - -PYBIND11_NOINLINE inline void clean_type_id(std::string &name) { -#if defined(__GNUG__) - int status = 0; - std::unique_ptr res { - abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free }; - if (status == 0) - name = res.get(); -#else - detail::erase_all(name, "class "); - detail::erase_all(name, "struct "); - detail::erase_all(name, "enum "); -#endif - detail::erase_all(name, "pybind11::"); -} -NAMESPACE_END(detail) - -/// Return a string representation of a C++ type -template static std::string type_id() { - std::string name(typeid(T).name()); - detail::clean_type_id(name); - return name; -} - -NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/pybind11/include/pybind11/eigen.h b/pybind11/include/pybind11/eigen.h deleted file mode 100644 index d963d96..0000000 --- a/pybind11/include/pybind11/eigen.h +++ /dev/null @@ -1,607 +0,0 @@ -/* - pybind11/eigen.h: Transparent conversion for dense and sparse Eigen matrices - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#pragma once - -#include "numpy.h" - -#if defined(__INTEL_COMPILER) -# pragma warning(disable: 1682) // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem) -#elif defined(__GNUG__) || defined(__clang__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wconversion" -# pragma GCC diagnostic ignored "-Wdeprecated-declarations" -# ifdef __clang__ -// Eigen generates a bunch of implicit-copy-constructor-is-deprecated warnings with -Wdeprecated -// under Clang, so disable that warning here: -# pragma GCC diagnostic ignored "-Wdeprecated" -# endif -# if __GNUC__ >= 7 -# pragma GCC diagnostic ignored "-Wint-in-bool-context" -# endif -#endif - -#if defined(_MSC_VER) -# pragma warning(push) -# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant -# pragma warning(disable: 4996) // warning C4996: std::unary_negate is deprecated in C++17 -#endif - -#include -#include - -// Eigen prior to 3.2.7 doesn't have proper move constructors--but worse, some classes get implicit -// move constructors that break things. We could detect this an explicitly copy, but an extra copy -// of matrices seems highly undesirable. -static_assert(EIGEN_VERSION_AT_LEAST(3,2,7), "Eigen support in pybind11 requires Eigen >= 3.2.7"); - -NAMESPACE_BEGIN(PYBIND11_NAMESPACE) - -// Provide a convenience alias for easier pass-by-ref usage with fully dynamic strides: -using EigenDStride = Eigen::Stride; -template using EigenDRef = Eigen::Ref; -template using EigenDMap = Eigen::Map; - -NAMESPACE_BEGIN(detail) - -#if EIGEN_VERSION_AT_LEAST(3,3,0) -using EigenIndex = Eigen::Index; -#else -using EigenIndex = EIGEN_DEFAULT_DENSE_INDEX_TYPE; -#endif - -// Matches Eigen::Map, Eigen::Ref, blocks, etc: -template using is_eigen_dense_map = all_of, std::is_base_of, T>>; -template using is_eigen_mutable_map = std::is_base_of, T>; -template using is_eigen_dense_plain = all_of>, is_template_base_of>; -template using is_eigen_sparse = is_template_base_of; -// Test for objects inheriting from EigenBase that aren't captured by the above. This -// basically covers anything that can be assigned to a dense matrix but that don't have a typical -// matrix data layout that can be copied from their .data(). For example, DiagonalMatrix and -// SelfAdjointView fall into this category. -template using is_eigen_other = all_of< - is_template_base_of, - negation, is_eigen_dense_plain, is_eigen_sparse>> ->; - -// Captures numpy/eigen conformability status (returned by EigenProps::conformable()): -template struct EigenConformable { - bool conformable = false; - EigenIndex rows = 0, cols = 0; - EigenDStride stride{0, 0}; // Only valid if negativestrides is false! - bool negativestrides = false; // If true, do not use stride! - - EigenConformable(bool fits = false) : conformable{fits} {} - // Matrix type: - EigenConformable(EigenIndex r, EigenIndex c, - EigenIndex rstride, EigenIndex cstride) : - conformable{true}, rows{r}, cols{c} { - // TODO: when Eigen bug #747 is fixed, remove the tests for non-negativity. http://eigen.tuxfamily.org/bz/show_bug.cgi?id=747 - if (rstride < 0 || cstride < 0) { - negativestrides = true; - } else { - stride = {EigenRowMajor ? rstride : cstride /* outer stride */, - EigenRowMajor ? cstride : rstride /* inner stride */ }; - } - } - // Vector type: - EigenConformable(EigenIndex r, EigenIndex c, EigenIndex stride) - : EigenConformable(r, c, r == 1 ? c*stride : stride, c == 1 ? r : r*stride) {} - - template bool stride_compatible() const { - // To have compatible strides, we need (on both dimensions) one of fully dynamic strides, - // matching strides, or a dimension size of 1 (in which case the stride value is irrelevant) - return - !negativestrides && - (props::inner_stride == Eigen::Dynamic || props::inner_stride == stride.inner() || - (EigenRowMajor ? cols : rows) == 1) && - (props::outer_stride == Eigen::Dynamic || props::outer_stride == stride.outer() || - (EigenRowMajor ? rows : cols) == 1); - } - operator bool() const { return conformable; } -}; - -template struct eigen_extract_stride { using type = Type; }; -template -struct eigen_extract_stride> { using type = StrideType; }; -template -struct eigen_extract_stride> { using type = StrideType; }; - -// Helper struct for extracting information from an Eigen type -template struct EigenProps { - using Type = Type_; - using Scalar = typename Type::Scalar; - using StrideType = typename eigen_extract_stride::type; - static constexpr EigenIndex - rows = Type::RowsAtCompileTime, - cols = Type::ColsAtCompileTime, - size = Type::SizeAtCompileTime; - static constexpr bool - row_major = Type::IsRowMajor, - vector = Type::IsVectorAtCompileTime, // At least one dimension has fixed size 1 - fixed_rows = rows != Eigen::Dynamic, - fixed_cols = cols != Eigen::Dynamic, - fixed = size != Eigen::Dynamic, // Fully-fixed size - dynamic = !fixed_rows && !fixed_cols; // Fully-dynamic size - - template using if_zero = std::integral_constant; - static constexpr EigenIndex inner_stride = if_zero::value, - outer_stride = if_zero::value; - static constexpr bool dynamic_stride = inner_stride == Eigen::Dynamic && outer_stride == Eigen::Dynamic; - static constexpr bool requires_row_major = !dynamic_stride && !vector && (row_major ? inner_stride : outer_stride) == 1; - static constexpr bool requires_col_major = !dynamic_stride && !vector && (row_major ? outer_stride : inner_stride) == 1; - - // Takes an input array and determines whether we can make it fit into the Eigen type. If - // the array is a vector, we attempt to fit it into either an Eigen 1xN or Nx1 vector - // (preferring the latter if it will fit in either, i.e. for a fully dynamic matrix type). - static EigenConformable conformable(const array &a) { - const auto dims = a.ndim(); - if (dims < 1 || dims > 2) - return false; - - if (dims == 2) { // Matrix type: require exact match (or dynamic) - - EigenIndex - np_rows = a.shape(0), - np_cols = a.shape(1), - np_rstride = a.strides(0) / static_cast(sizeof(Scalar)), - np_cstride = a.strides(1) / static_cast(sizeof(Scalar)); - if ((fixed_rows && np_rows != rows) || (fixed_cols && np_cols != cols)) - return false; - - return {np_rows, np_cols, np_rstride, np_cstride}; - } - - // Otherwise we're storing an n-vector. Only one of the strides will be used, but whichever - // is used, we want the (single) numpy stride value. - const EigenIndex n = a.shape(0), - stride = a.strides(0) / static_cast(sizeof(Scalar)); - - if (vector) { // Eigen type is a compile-time vector - if (fixed && size != n) - return false; // Vector size mismatch - return {rows == 1 ? 1 : n, cols == 1 ? 1 : n, stride}; - } - else if (fixed) { - // The type has a fixed size, but is not a vector: abort - return false; - } - else if (fixed_cols) { - // Since this isn't a vector, cols must be != 1. We allow this only if it exactly - // equals the number of elements (rows is Dynamic, and so 1 row is allowed). - if (cols != n) return false; - return {1, n, stride}; - } - else { - // Otherwise it's either fully dynamic, or column dynamic; both become a column vector - if (fixed_rows && rows != n) return false; - return {n, 1, stride}; - } - } - - static constexpr bool show_writeable = is_eigen_dense_map::value && is_eigen_mutable_map::value; - static constexpr bool show_order = is_eigen_dense_map::value; - static constexpr bool show_c_contiguous = show_order && requires_row_major; - static constexpr bool show_f_contiguous = !show_c_contiguous && show_order && requires_col_major; - - static constexpr auto descriptor = - _("numpy.ndarray[") + npy_format_descriptor::name + - _("[") + _(_<(size_t) rows>(), _("m")) + - _(", ") + _(_<(size_t) cols>(), _("n")) + - _("]") + - // For a reference type (e.g. Ref) we have other constraints that might need to be - // satisfied: writeable=True (for a mutable reference), and, depending on the map's stride - // options, possibly f_contiguous or c_contiguous. We include them in the descriptor output - // to provide some hint as to why a TypeError is occurring (otherwise it can be confusing to - // see that a function accepts a 'numpy.ndarray[float64[3,2]]' and an error message that you - // *gave* a numpy.ndarray of the right type and dimensions. - _(", flags.writeable", "") + - _(", flags.c_contiguous", "") + - _(", flags.f_contiguous", "") + - _("]"); -}; - -// Casts an Eigen type to numpy array. If given a base, the numpy array references the src data, -// otherwise it'll make a copy. writeable lets you turn off the writeable flag for the array. -template handle eigen_array_cast(typename props::Type const &src, handle base = handle(), bool writeable = true) { - constexpr ssize_t elem_size = sizeof(typename props::Scalar); - array a; - if (props::vector) - a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base); - else - a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() }, - src.data(), base); - - if (!writeable) - array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_; - - return a.release(); -} - -// Takes an lvalue ref to some Eigen type and a (python) base object, creating a numpy array that -// reference the Eigen object's data with `base` as the python-registered base class (if omitted, -// the base will be set to None, and lifetime management is up to the caller). The numpy array is -// non-writeable if the given type is const. -template -handle eigen_ref_array(Type &src, handle parent = none()) { - // none here is to get past array's should-we-copy detection, which currently always - // copies when there is no base. Setting the base to None should be harmless. - return eigen_array_cast(src, parent, !std::is_const::value); -} - -// Takes a pointer to some dense, plain Eigen type, builds a capsule around it, then returns a numpy -// array that references the encapsulated data with a python-side reference to the capsule to tie -// its destruction to that of any dependent python objects. Const-ness is determined by whether or -// not the Type of the pointer given is const. -template ::value>> -handle eigen_encapsulate(Type *src) { - capsule base(src, [](void *o) { delete static_cast(o); }); - return eigen_ref_array(*src, base); -} - -// Type caster for regular, dense matrix types (e.g. MatrixXd), but not maps/refs/etc. of dense -// types. -template -struct type_caster::value>> { - using Scalar = typename Type::Scalar; - using props = EigenProps; - - bool load(handle src, bool convert) { - // If we're in no-convert mode, only load if given an array of the correct type - if (!convert && !isinstance>(src)) - return false; - - // Coerce into an array, but don't do type conversion yet; the copy below handles it. - auto buf = array::ensure(src); - - if (!buf) - return false; - - auto dims = buf.ndim(); - if (dims < 1 || dims > 2) - return false; - - auto fits = props::conformable(buf); - if (!fits) - return false; - - // Allocate the new type, then build a numpy reference into it - value = Type(fits.rows, fits.cols); - auto ref = reinterpret_steal(eigen_ref_array(value)); - if (dims == 1) ref = ref.squeeze(); - else if (ref.ndim() == 1) buf = buf.squeeze(); - - int result = detail::npy_api::get().PyArray_CopyInto_(ref.ptr(), buf.ptr()); - - if (result < 0) { // Copy failed! - PyErr_Clear(); - return false; - } - - return true; - } - -private: - - // Cast implementation - template - static handle cast_impl(CType *src, return_value_policy policy, handle parent) { - switch (policy) { - case return_value_policy::take_ownership: - case return_value_policy::automatic: - return eigen_encapsulate(src); - case return_value_policy::move: - return eigen_encapsulate(new CType(std::move(*src))); - case return_value_policy::copy: - return eigen_array_cast(*src); - case return_value_policy::reference: - case return_value_policy::automatic_reference: - return eigen_ref_array(*src); - case return_value_policy::reference_internal: - return eigen_ref_array(*src, parent); - default: - throw cast_error("unhandled return_value_policy: should not happen!"); - }; - } - -public: - - // Normal returned non-reference, non-const value: - static handle cast(Type &&src, return_value_policy /* policy */, handle parent) { - return cast_impl(&src, return_value_policy::move, parent); - } - // If you return a non-reference const, we mark the numpy array readonly: - static handle cast(const Type &&src, return_value_policy /* policy */, handle parent) { - return cast_impl(&src, return_value_policy::move, parent); - } - // lvalue reference return; default (automatic) becomes copy - static handle cast(Type &src, return_value_policy policy, handle parent) { - if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference) - policy = return_value_policy::copy; - return cast_impl(&src, policy, parent); - } - // const lvalue reference return; default (automatic) becomes copy - static handle cast(const Type &src, return_value_policy policy, handle parent) { - if (policy == return_value_policy::automatic || policy == return_value_policy::automatic_reference) - policy = return_value_policy::copy; - return cast(&src, policy, parent); - } - // non-const pointer return - static handle cast(Type *src, return_value_policy policy, handle parent) { - return cast_impl(src, policy, parent); - } - // const pointer return - static handle cast(const Type *src, return_value_policy policy, handle parent) { - return cast_impl(src, policy, parent); - } - - static constexpr auto name = props::descriptor; - - operator Type*() { return &value; } - operator Type&() { return value; } - operator Type&&() && { return std::move(value); } - template using cast_op_type = movable_cast_op_type; - -private: - Type value; -}; - -// Base class for casting reference/map/block/etc. objects back to python. -template struct eigen_map_caster { -private: - using props = EigenProps; - -public: - - // Directly referencing a ref/map's data is a bit dangerous (whatever the map/ref points to has - // to stay around), but we'll allow it under the assumption that you know what you're doing (and - // have an appropriate keep_alive in place). We return a numpy array pointing directly at the - // ref's data (The numpy array ends up read-only if the ref was to a const matrix type.) Note - // that this means you need to ensure you don't destroy the object in some other way (e.g. with - // an appropriate keep_alive, or with a reference to a statically allocated matrix). - static handle cast(const MapType &src, return_value_policy policy, handle parent) { - switch (policy) { - case return_value_policy::copy: - return eigen_array_cast(src); - case return_value_policy::reference_internal: - return eigen_array_cast(src, parent, is_eigen_mutable_map::value); - case return_value_policy::reference: - case return_value_policy::automatic: - case return_value_policy::automatic_reference: - return eigen_array_cast(src, none(), is_eigen_mutable_map::value); - default: - // move, take_ownership don't make any sense for a ref/map: - pybind11_fail("Invalid return_value_policy for Eigen Map/Ref/Block type"); - } - } - - static constexpr auto name = props::descriptor; - - // Explicitly delete these: support python -> C++ conversion on these (i.e. these can be return - // types but not bound arguments). We still provide them (with an explicitly delete) so that - // you end up here if you try anyway. - bool load(handle, bool) = delete; - operator MapType() = delete; - template using cast_op_type = MapType; -}; - -// We can return any map-like object (but can only load Refs, specialized next): -template struct type_caster::value>> - : eigen_map_caster {}; - -// Loader for Ref<...> arguments. See the documentation for info on how to make this work without -// copying (it requires some extra effort in many cases). -template -struct type_caster< - Eigen::Ref, - enable_if_t>::value> -> : public eigen_map_caster> { -private: - using Type = Eigen::Ref; - using props = EigenProps; - using Scalar = typename props::Scalar; - using MapType = Eigen::Map; - using Array = array_t; - static constexpr bool need_writeable = is_eigen_mutable_map::value; - // Delay construction (these have no default constructor) - std::unique_ptr map; - std::unique_ptr ref; - // Our array. When possible, this is just a numpy array pointing to the source data, but - // sometimes we can't avoid copying (e.g. input is not a numpy array at all, has an incompatible - // layout, or is an array of a type that needs to be converted). Using a numpy temporary - // (rather than an Eigen temporary) saves an extra copy when we need both type conversion and - // storage order conversion. (Note that we refuse to use this temporary copy when loading an - // argument for a Ref with M non-const, i.e. a read-write reference). - Array copy_or_ref; -public: - bool load(handle src, bool convert) { - // First check whether what we have is already an array of the right type. If not, we can't - // avoid a copy (because the copy is also going to do type conversion). - bool need_copy = !isinstance(src); - - EigenConformable fits; - if (!need_copy) { - // We don't need a converting copy, but we also need to check whether the strides are - // compatible with the Ref's stride requirements - Array aref = reinterpret_borrow(src); - - if (aref && (!need_writeable || aref.writeable())) { - fits = props::conformable(aref); - if (!fits) return false; // Incompatible dimensions - if (!fits.template stride_compatible()) - need_copy = true; - else - copy_or_ref = std::move(aref); - } - else { - need_copy = true; - } - } - - if (need_copy) { - // We need to copy: If we need a mutable reference, or we're not supposed to convert - // (either because we're in the no-convert overload pass, or because we're explicitly - // instructed not to copy (via `py::arg().noconvert()`) we have to fail loading. - if (!convert || need_writeable) return false; - - Array copy = Array::ensure(src); - if (!copy) return false; - fits = props::conformable(copy); - if (!fits || !fits.template stride_compatible()) - return false; - copy_or_ref = std::move(copy); - loader_life_support::add_patient(copy_or_ref); - } - - ref.reset(); - map.reset(new MapType(data(copy_or_ref), fits.rows, fits.cols, make_stride(fits.stride.outer(), fits.stride.inner()))); - ref.reset(new Type(*map)); - - return true; - } - - operator Type*() { return ref.get(); } - operator Type&() { return *ref; } - template using cast_op_type = pybind11::detail::cast_op_type<_T>; - -private: - template ::value, int> = 0> - Scalar *data(Array &a) { return a.mutable_data(); } - - template ::value, int> = 0> - const Scalar *data(Array &a) { return a.data(); } - - // Attempt to figure out a constructor of `Stride` that will work. - // If both strides are fixed, use a default constructor: - template using stride_ctor_default = bool_constant< - S::InnerStrideAtCompileTime != Eigen::Dynamic && S::OuterStrideAtCompileTime != Eigen::Dynamic && - std::is_default_constructible::value>; - // Otherwise, if there is a two-index constructor, assume it is (outer,inner) like - // Eigen::Stride, and use it: - template using stride_ctor_dual = bool_constant< - !stride_ctor_default::value && std::is_constructible::value>; - // Otherwise, if there is a one-index constructor, and just one of the strides is dynamic, use - // it (passing whichever stride is dynamic). - template using stride_ctor_outer = bool_constant< - !any_of, stride_ctor_dual>::value && - S::OuterStrideAtCompileTime == Eigen::Dynamic && S::InnerStrideAtCompileTime != Eigen::Dynamic && - std::is_constructible::value>; - template using stride_ctor_inner = bool_constant< - !any_of, stride_ctor_dual>::value && - S::InnerStrideAtCompileTime == Eigen::Dynamic && S::OuterStrideAtCompileTime != Eigen::Dynamic && - std::is_constructible::value>; - - template ::value, int> = 0> - static S make_stride(EigenIndex, EigenIndex) { return S(); } - template ::value, int> = 0> - static S make_stride(EigenIndex outer, EigenIndex inner) { return S(outer, inner); } - template ::value, int> = 0> - static S make_stride(EigenIndex outer, EigenIndex) { return S(outer); } - template ::value, int> = 0> - static S make_stride(EigenIndex, EigenIndex inner) { return S(inner); } - -}; - -// type_caster for special matrix types (e.g. DiagonalMatrix), which are EigenBase, but not -// EigenDense (i.e. they don't have a data(), at least not with the usual matrix layout). -// load() is not supported, but we can cast them into the python domain by first copying to a -// regular Eigen::Matrix, then casting that. -template -struct type_caster::value>> { -protected: - using Matrix = Eigen::Matrix; - using props = EigenProps; -public: - static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) { - handle h = eigen_encapsulate(new Matrix(src)); - return h; - } - static handle cast(const Type *src, return_value_policy policy, handle parent) { return cast(*src, policy, parent); } - - static constexpr auto name = props::descriptor; - - // Explicitly delete these: support python -> C++ conversion on these (i.e. these can be return - // types but not bound arguments). We still provide them (with an explicitly delete) so that - // you end up here if you try anyway. - bool load(handle, bool) = delete; - operator Type() = delete; - template using cast_op_type = Type; -}; - -template -struct type_caster::value>> { - typedef typename Type::Scalar Scalar; - typedef remove_reference_t().outerIndexPtr())> StorageIndex; - typedef typename Type::Index Index; - static constexpr bool rowMajor = Type::IsRowMajor; - - bool load(handle src, bool) { - if (!src) - return false; - - auto obj = reinterpret_borrow(src); - object sparse_module = module::import("scipy.sparse"); - object matrix_type = sparse_module.attr( - rowMajor ? "csr_matrix" : "csc_matrix"); - - if (!obj.get_type().is(matrix_type)) { - try { - obj = matrix_type(obj); - } catch (const error_already_set &) { - return false; - } - } - - auto values = array_t((object) obj.attr("data")); - auto innerIndices = array_t((object) obj.attr("indices")); - auto outerIndices = array_t((object) obj.attr("indptr")); - auto shape = pybind11::tuple((pybind11::object) obj.attr("shape")); - auto nnz = obj.attr("nnz").cast(); - - if (!values || !innerIndices || !outerIndices) - return false; - - value = Eigen::MappedSparseMatrix( - shape[0].cast(), shape[1].cast(), nnz, - outerIndices.mutable_data(), innerIndices.mutable_data(), values.mutable_data()); - - return true; - } - - static handle cast(const Type &src, return_value_policy /* policy */, handle /* parent */) { - const_cast(src).makeCompressed(); - - object matrix_type = module::import("scipy.sparse").attr( - rowMajor ? "csr_matrix" : "csc_matrix"); - - array data(src.nonZeros(), src.valuePtr()); - array outerIndices((rowMajor ? src.rows() : src.cols()) + 1, src.outerIndexPtr()); - array innerIndices(src.nonZeros(), src.innerIndexPtr()); - - return matrix_type( - std::make_tuple(data, innerIndices, outerIndices), - std::make_pair(src.rows(), src.cols()) - ).release(); - } - - PYBIND11_TYPE_CASTER(Type, _<(Type::IsRowMajor) != 0>("scipy.sparse.csr_matrix[", "scipy.sparse.csc_matrix[") - + npy_format_descriptor::name + _("]")); -}; - -NAMESPACE_END(detail) -NAMESPACE_END(PYBIND11_NAMESPACE) - -#if defined(__GNUG__) || defined(__clang__) -# pragma GCC diagnostic pop -#elif defined(_MSC_VER) -# pragma warning(pop) -#endif diff --git a/pybind11/include/pybind11/embed.h b/pybind11/include/pybind11/embed.h deleted file mode 100644 index 7265588..0000000 --- a/pybind11/include/pybind11/embed.h +++ /dev/null @@ -1,200 +0,0 @@ -/* - pybind11/embed.h: Support for embedding the interpreter - - Copyright (c) 2017 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#pragma once - -#include "pybind11.h" -#include "eval.h" - -#if defined(PYPY_VERSION) -# error Embedding the interpreter is not supported with PyPy -#endif - -#if PY_MAJOR_VERSION >= 3 -# define PYBIND11_EMBEDDED_MODULE_IMPL(name) \ - extern "C" PyObject *pybind11_init_impl_##name() { \ - return pybind11_init_wrapper_##name(); \ - } -#else -# define PYBIND11_EMBEDDED_MODULE_IMPL(name) \ - extern "C" void pybind11_init_impl_##name() { \ - pybind11_init_wrapper_##name(); \ - } -#endif - -/** \rst - Add a new module to the table of builtins for the interpreter. Must be - defined in global scope. The first macro parameter is the name of the - module (without quotes). The second parameter is the variable which will - be used as the interface to add functions and classes to the module. - - .. code-block:: cpp - - PYBIND11_EMBEDDED_MODULE(example, m) { - // ... initialize functions and classes here - m.def("foo", []() { - return "Hello, World!"; - }); - } - \endrst */ -#define PYBIND11_EMBEDDED_MODULE(name, variable) \ - static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \ - static PyObject PYBIND11_CONCAT(*pybind11_init_wrapper_, name)() { \ - auto m = pybind11::module(PYBIND11_TOSTRING(name)); \ - try { \ - PYBIND11_CONCAT(pybind11_init_, name)(m); \ - return m.ptr(); \ - } catch (pybind11::error_already_set &e) { \ - PyErr_SetString(PyExc_ImportError, e.what()); \ - return nullptr; \ - } catch (const std::exception &e) { \ - PyErr_SetString(PyExc_ImportError, e.what()); \ - return nullptr; \ - } \ - } \ - PYBIND11_EMBEDDED_MODULE_IMPL(name) \ - pybind11::detail::embedded_module name(PYBIND11_TOSTRING(name), \ - PYBIND11_CONCAT(pybind11_init_impl_, name)); \ - void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable) - - -NAMESPACE_BEGIN(PYBIND11_NAMESPACE) -NAMESPACE_BEGIN(detail) - -/// Python 2.7/3.x compatible version of `PyImport_AppendInittab` and error checks. -struct embedded_module { -#if PY_MAJOR_VERSION >= 3 - using init_t = PyObject *(*)(); -#else - using init_t = void (*)(); -#endif - embedded_module(const char *name, init_t init) { - if (Py_IsInitialized()) - pybind11_fail("Can't add new modules after the interpreter has been initialized"); - - auto result = PyImport_AppendInittab(name, init); - if (result == -1) - pybind11_fail("Insufficient memory to add a new module"); - } -}; - -NAMESPACE_END(detail) - -/** \rst - Initialize the Python interpreter. No other pybind11 or CPython API functions can be - called before this is done; with the exception of `PYBIND11_EMBEDDED_MODULE`. The - optional parameter can be used to skip the registration of signal handlers (see the - `Python documentation`_ for details). Calling this function again after the interpreter - has already been initialized is a fatal error. - - If initializing the Python interpreter fails, then the program is terminated. (This - is controlled by the CPython runtime and is an exception to pybind11's normal behavior - of throwing exceptions on errors.) - - .. _Python documentation: https://docs.python.org/3/c-api/init.html#c.Py_InitializeEx - \endrst */ -inline void initialize_interpreter(bool init_signal_handlers = true) { - if (Py_IsInitialized()) - pybind11_fail("The interpreter is already running"); - - Py_InitializeEx(init_signal_handlers ? 1 : 0); - - // Make .py files in the working directory available by default - module::import("sys").attr("path").cast().append("."); -} - -/** \rst - Shut down the Python interpreter. No pybind11 or CPython API functions can be called - after this. In addition, pybind11 objects must not outlive the interpreter: - - .. code-block:: cpp - - { // BAD - py::initialize_interpreter(); - auto hello = py::str("Hello, World!"); - py::finalize_interpreter(); - } // <-- BOOM, hello's destructor is called after interpreter shutdown - - { // GOOD - py::initialize_interpreter(); - { // scoped - auto hello = py::str("Hello, World!"); - } // <-- OK, hello is cleaned up properly - py::finalize_interpreter(); - } - - { // BETTER - py::scoped_interpreter guard{}; - auto hello = py::str("Hello, World!"); - } - - .. warning:: - - The interpreter can be restarted by calling `initialize_interpreter` again. - Modules created using pybind11 can be safely re-initialized. However, Python - itself cannot completely unload binary extension modules and there are several - caveats with regard to interpreter restarting. All the details can be found - in the CPython documentation. In short, not all interpreter memory may be - freed, either due to reference cycles or user-created global data. - - \endrst */ -inline void finalize_interpreter() { - handle builtins(PyEval_GetBuiltins()); - const char *id = PYBIND11_INTERNALS_ID; - - // Get the internals pointer (without creating it if it doesn't exist). It's possible for the - // internals to be created during Py_Finalize() (e.g. if a py::capsule calls `get_internals()` - // during destruction), so we get the pointer-pointer here and check it after Py_Finalize(). - detail::internals **internals_ptr_ptr = detail::get_internals_pp(); - // It could also be stashed in builtins, so look there too: - if (builtins.contains(id) && isinstance(builtins[id])) - internals_ptr_ptr = capsule(builtins[id]); - - Py_Finalize(); - - if (internals_ptr_ptr) { - delete *internals_ptr_ptr; - *internals_ptr_ptr = nullptr; - } -} - -/** \rst - Scope guard version of `initialize_interpreter` and `finalize_interpreter`. - This a move-only guard and only a single instance can exist. - - .. code-block:: cpp - - #include - - int main() { - py::scoped_interpreter guard{}; - py::print(Hello, World!); - } // <-- interpreter shutdown - \endrst */ -class scoped_interpreter { -public: - scoped_interpreter(bool init_signal_handlers = true) { - initialize_interpreter(init_signal_handlers); - } - - scoped_interpreter(const scoped_interpreter &) = delete; - scoped_interpreter(scoped_interpreter &&other) noexcept { other.is_valid = false; } - scoped_interpreter &operator=(const scoped_interpreter &) = delete; - scoped_interpreter &operator=(scoped_interpreter &&) = delete; - - ~scoped_interpreter() { - if (is_valid) - finalize_interpreter(); - } - -private: - bool is_valid = true; -}; - -NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/pybind11/include/pybind11/eval.h b/pybind11/include/pybind11/eval.h deleted file mode 100644 index ea85ba1..0000000 --- a/pybind11/include/pybind11/eval.h +++ /dev/null @@ -1,117 +0,0 @@ -/* - pybind11/exec.h: Support for evaluating Python expressions and statements - from strings and files - - Copyright (c) 2016 Klemens Morgenstern and - Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#pragma once - -#include "pybind11.h" - -NAMESPACE_BEGIN(PYBIND11_NAMESPACE) - -enum eval_mode { - /// Evaluate a string containing an isolated expression - eval_expr, - - /// Evaluate a string containing a single statement. Returns \c none - eval_single_statement, - - /// Evaluate a string containing a sequence of statement. Returns \c none - eval_statements -}; - -template -object eval(str expr, object global = globals(), object local = object()) { - if (!local) - local = global; - - /* PyRun_String does not accept a PyObject / encoding specifier, - this seems to be the only alternative */ - std::string buffer = "# -*- coding: utf-8 -*-\n" + (std::string) expr; - - int start; - switch (mode) { - case eval_expr: start = Py_eval_input; break; - case eval_single_statement: start = Py_single_input; break; - case eval_statements: start = Py_file_input; break; - default: pybind11_fail("invalid evaluation mode"); - } - - PyObject *result = PyRun_String(buffer.c_str(), start, global.ptr(), local.ptr()); - if (!result) - throw error_already_set(); - return reinterpret_steal(result); -} - -template -object eval(const char (&s)[N], object global = globals(), object local = object()) { - /* Support raw string literals by removing common leading whitespace */ - auto expr = (s[0] == '\n') ? str(module::import("textwrap").attr("dedent")(s)) - : str(s); - return eval(expr, global, local); -} - -inline void exec(str expr, object global = globals(), object local = object()) { - eval(expr, global, local); -} - -template -void exec(const char (&s)[N], object global = globals(), object local = object()) { - eval(s, global, local); -} - -template -object eval_file(str fname, object global = globals(), object local = object()) { - if (!local) - local = global; - - int start; - switch (mode) { - case eval_expr: start = Py_eval_input; break; - case eval_single_statement: start = Py_single_input; break; - case eval_statements: start = Py_file_input; break; - default: pybind11_fail("invalid evaluation mode"); - } - - int closeFile = 1; - std::string fname_str = (std::string) fname; -#if PY_VERSION_HEX >= 0x03040000 - FILE *f = _Py_fopen_obj(fname.ptr(), "r"); -#elif PY_VERSION_HEX >= 0x03000000 - FILE *f = _Py_fopen(fname.ptr(), "r"); -#else - /* No unicode support in open() :( */ - auto fobj = reinterpret_steal(PyFile_FromString( - const_cast(fname_str.c_str()), - const_cast("r"))); - FILE *f = nullptr; - if (fobj) - f = PyFile_AsFile(fobj.ptr()); - closeFile = 0; -#endif - if (!f) { - PyErr_Clear(); - pybind11_fail("File \"" + fname_str + "\" could not be opened!"); - } - -#if PY_VERSION_HEX < 0x03000000 && defined(PYPY_VERSION) - PyObject *result = PyRun_File(f, fname_str.c_str(), start, global.ptr(), - local.ptr()); - (void) closeFile; -#else - PyObject *result = PyRun_FileEx(f, fname_str.c_str(), start, global.ptr(), - local.ptr(), closeFile); -#endif - - if (!result) - throw error_already_set(); - return reinterpret_steal(result); -} - -NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/pybind11/include/pybind11/functional.h b/pybind11/include/pybind11/functional.h deleted file mode 100644 index f8bda64..0000000 --- a/pybind11/include/pybind11/functional.h +++ /dev/null @@ -1,101 +0,0 @@ -/* - pybind11/functional.h: std::function<> support - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#pragma once - -#include "pybind11.h" -#include - -NAMESPACE_BEGIN(PYBIND11_NAMESPACE) -NAMESPACE_BEGIN(detail) - -template -struct type_caster> { - using type = std::function; - using retval_type = conditional_t::value, void_type, Return>; - using function_type = Return (*) (Args...); - -public: - bool load(handle src, bool convert) { - if (src.is_none()) { - // Defer accepting None to other overloads (if we aren't in convert mode): - if (!convert) return false; - return true; - } - - if (!isinstance(src)) - return false; - - auto func = reinterpret_borrow(src); - - /* - When passing a C++ function as an argument to another C++ - function via Python, every function call would normally involve - a full C++ -> Python -> C++ roundtrip, which can be prohibitive. - Here, we try to at least detect the case where the function is - stateless (i.e. function pointer or lambda function without - captured variables), in which case the roundtrip can be avoided. - */ - if (auto cfunc = func.cpp_function()) { - auto c = reinterpret_borrow(PyCFunction_GET_SELF(cfunc.ptr())); - auto rec = (function_record *) c; - - if (rec && rec->is_stateless && - same_type(typeid(function_type), *reinterpret_cast(rec->data[1]))) { - struct capture { function_type f; }; - value = ((capture *) &rec->data)->f; - return true; - } - } - - // ensure GIL is held during functor destruction - struct func_handle { - function f; - func_handle(function&& f_) : f(std::move(f_)) {} - func_handle(const func_handle&) = default; - ~func_handle() { - gil_scoped_acquire acq; - function kill_f(std::move(f)); - } - }; - - // to emulate 'move initialization capture' in C++11 - struct func_wrapper { - func_handle hfunc; - func_wrapper(func_handle&& hf): hfunc(std::move(hf)) {} - Return operator()(Args... args) const { - gil_scoped_acquire acq; - object retval(hfunc.f(std::forward(args)...)); - /* Visual studio 2015 parser issue: need parentheses around this expression */ - return (retval.template cast()); - } - }; - - value = func_wrapper(func_handle(std::move(func))); - return true; - } - - template - static handle cast(Func &&f_, return_value_policy policy, handle /* parent */) { - if (!f_) - return none().inc_ref(); - - auto result = f_.template target(); - if (result) - return cpp_function(*result, policy).release(); - else - return cpp_function(std::forward(f_), policy).release(); - } - - PYBIND11_TYPE_CASTER(type, _("Callable[[") + concat(make_caster::name...) + _("], ") - + make_caster::name + _("]")); -}; - -NAMESPACE_END(detail) -NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/pybind11/include/pybind11/iostream.h b/pybind11/include/pybind11/iostream.h deleted file mode 100644 index c43b7c9..0000000 --- a/pybind11/include/pybind11/iostream.h +++ /dev/null @@ -1,209 +0,0 @@ -/* - pybind11/iostream.h -- Tools to assist with redirecting cout and cerr to Python - - Copyright (c) 2017 Henry F. Schreiner - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#pragma once - -#include "pybind11.h" - -#include -#include -#include -#include -#include - -NAMESPACE_BEGIN(PYBIND11_NAMESPACE) -NAMESPACE_BEGIN(detail) - -// Buffer that writes to Python instead of C++ -class pythonbuf : public std::streambuf { -private: - using traits_type = std::streambuf::traits_type; - - const size_t buf_size; - std::unique_ptr d_buffer; - object pywrite; - object pyflush; - - int overflow(int c) { - if (!traits_type::eq_int_type(c, traits_type::eof())) { - *pptr() = traits_type::to_char_type(c); - pbump(1); - } - return sync() == 0 ? traits_type::not_eof(c) : traits_type::eof(); - } - - int sync() { - if (pbase() != pptr()) { - // This subtraction cannot be negative, so dropping the sign - str line(pbase(), static_cast(pptr() - pbase())); - - { - gil_scoped_acquire tmp; - pywrite(line); - pyflush(); - } - - setp(pbase(), epptr()); - } - return 0; - } - -public: - - pythonbuf(object pyostream, size_t buffer_size = 1024) - : buf_size(buffer_size), - d_buffer(new char[buf_size]), - pywrite(pyostream.attr("write")), - pyflush(pyostream.attr("flush")) { - setp(d_buffer.get(), d_buffer.get() + buf_size - 1); - } - - pythonbuf(pythonbuf&&) = default; - - /// Sync before destroy - ~pythonbuf() { - sync(); - } -}; - -NAMESPACE_END(detail) - - -/** \rst - This a move-only guard that redirects output. - - .. code-block:: cpp - - #include - - ... - - { - py::scoped_ostream_redirect output; - std::cout << "Hello, World!"; // Python stdout - } // <-- return std::cout to normal - - You can explicitly pass the c++ stream and the python object, - for example to guard stderr instead. - - .. code-block:: cpp - - { - py::scoped_ostream_redirect output{std::cerr, py::module::import("sys").attr("stderr")}; - std::cerr << "Hello, World!"; - } - \endrst */ -class scoped_ostream_redirect { -protected: - std::streambuf *old; - std::ostream &costream; - detail::pythonbuf buffer; - -public: - scoped_ostream_redirect( - std::ostream &costream = std::cout, - object pyostream = module::import("sys").attr("stdout")) - : costream(costream), buffer(pyostream) { - old = costream.rdbuf(&buffer); - } - - ~scoped_ostream_redirect() { - costream.rdbuf(old); - } - - scoped_ostream_redirect(const scoped_ostream_redirect &) = delete; - scoped_ostream_redirect(scoped_ostream_redirect &&other) = default; - scoped_ostream_redirect &operator=(const scoped_ostream_redirect &) = delete; - scoped_ostream_redirect &operator=(scoped_ostream_redirect &&) = delete; -}; - - -/** \rst - Like `scoped_ostream_redirect`, but redirects cerr by default. This class - is provided primary to make ``py::call_guard`` easier to make. - - .. code-block:: cpp - - m.def("noisy_func", &noisy_func, - py::call_guard()); - -\endrst */ -class scoped_estream_redirect : public scoped_ostream_redirect { -public: - scoped_estream_redirect( - std::ostream &costream = std::cerr, - object pyostream = module::import("sys").attr("stderr")) - : scoped_ostream_redirect(costream,pyostream) {} -}; - - -NAMESPACE_BEGIN(detail) - -// Class to redirect output as a context manager. C++ backend. -class OstreamRedirect { - bool do_stdout_; - bool do_stderr_; - std::unique_ptr redirect_stdout; - std::unique_ptr redirect_stderr; - -public: - OstreamRedirect(bool do_stdout = true, bool do_stderr = true) - : do_stdout_(do_stdout), do_stderr_(do_stderr) {} - - void enter() { - if (do_stdout_) - redirect_stdout.reset(new scoped_ostream_redirect()); - if (do_stderr_) - redirect_stderr.reset(new scoped_estream_redirect()); - } - - void exit() { - redirect_stdout.reset(); - redirect_stderr.reset(); - } -}; - -NAMESPACE_END(detail) - -/** \rst - This is a helper function to add a C++ redirect context manager to Python - instead of using a C++ guard. To use it, add the following to your binding code: - - .. code-block:: cpp - - #include - - ... - - py::add_ostream_redirect(m, "ostream_redirect"); - - You now have a Python context manager that redirects your output: - - .. code-block:: python - - with m.ostream_redirect(): - m.print_to_cout_function() - - This manager can optionally be told which streams to operate on: - - .. code-block:: python - - with m.ostream_redirect(stdout=true, stderr=true): - m.noisy_function_with_error_printing() - - \endrst */ -inline class_ add_ostream_redirect(module m, std::string name = "ostream_redirect") { - return class_(m, name.c_str(), module_local()) - .def(init(), arg("stdout")=true, arg("stderr")=true) - .def("__enter__", &detail::OstreamRedirect::enter) - .def("__exit__", [](detail::OstreamRedirect &self_, args) { self_.exit(); }); -} - -NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/pybind11/include/pybind11/numpy.h b/pybind11/include/pybind11/numpy.h deleted file mode 100644 index ba41a22..0000000 --- a/pybind11/include/pybind11/numpy.h +++ /dev/null @@ -1,1642 +0,0 @@ -/* - pybind11/numpy.h: Basic NumPy support, vectorize() wrapper - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#pragma once - -#include "pybind11.h" -#include "complex.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if defined(_MSC_VER) -# pragma warning(push) -# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant -#endif - -/* This will be true on all flat address space platforms and allows us to reduce the - whole npy_intp / ssize_t / Py_intptr_t business down to just ssize_t for all size - and dimension types (e.g. shape, strides, indexing), instead of inflicting this - upon the library user. */ -static_assert(sizeof(ssize_t) == sizeof(Py_intptr_t), "ssize_t != Py_intptr_t"); - -NAMESPACE_BEGIN(PYBIND11_NAMESPACE) - -class array; // Forward declaration - -NAMESPACE_BEGIN(detail) -template struct npy_format_descriptor; - -struct PyArrayDescr_Proxy { - PyObject_HEAD - PyObject *typeobj; - char kind; - char type; - char byteorder; - char flags; - int type_num; - int elsize; - int alignment; - char *subarray; - PyObject *fields; - PyObject *names; -}; - -struct PyArray_Proxy { - PyObject_HEAD - char *data; - int nd; - ssize_t *dimensions; - ssize_t *strides; - PyObject *base; - PyObject *descr; - int flags; -}; - -struct PyVoidScalarObject_Proxy { - PyObject_VAR_HEAD - char *obval; - PyArrayDescr_Proxy *descr; - int flags; - PyObject *base; -}; - -struct numpy_type_info { - PyObject* dtype_ptr; - std::string format_str; -}; - -struct numpy_internals { - std::unordered_map registered_dtypes; - - numpy_type_info *get_type_info(const std::type_info& tinfo, bool throw_if_missing = true) { - auto it = registered_dtypes.find(std::type_index(tinfo)); - if (it != registered_dtypes.end()) - return &(it->second); - if (throw_if_missing) - pybind11_fail(std::string("NumPy type info missing for ") + tinfo.name()); - return nullptr; - } - - template numpy_type_info *get_type_info(bool throw_if_missing = true) { - return get_type_info(typeid(typename std::remove_cv::type), throw_if_missing); - } -}; - -inline PYBIND11_NOINLINE void load_numpy_internals(numpy_internals* &ptr) { - ptr = &get_or_create_shared_data("_numpy_internals"); -} - -inline numpy_internals& get_numpy_internals() { - static numpy_internals* ptr = nullptr; - if (!ptr) - load_numpy_internals(ptr); - return *ptr; -} - -template struct same_size { - template using as = bool_constant; -}; - -template constexpr int platform_lookup() { return -1; } - -// Lookup a type according to its size, and return a value corresponding to the NumPy typenum. -template -constexpr int platform_lookup(int I, Ints... Is) { - return sizeof(Concrete) == sizeof(T) ? I : platform_lookup(Is...); -} - -struct npy_api { - enum constants { - NPY_ARRAY_C_CONTIGUOUS_ = 0x0001, - NPY_ARRAY_F_CONTIGUOUS_ = 0x0002, - NPY_ARRAY_OWNDATA_ = 0x0004, - NPY_ARRAY_FORCECAST_ = 0x0010, - NPY_ARRAY_ENSUREARRAY_ = 0x0040, - NPY_ARRAY_ALIGNED_ = 0x0100, - NPY_ARRAY_WRITEABLE_ = 0x0400, - NPY_BOOL_ = 0, - NPY_BYTE_, NPY_UBYTE_, - NPY_SHORT_, NPY_USHORT_, - NPY_INT_, NPY_UINT_, - NPY_LONG_, NPY_ULONG_, - NPY_LONGLONG_, NPY_ULONGLONG_, - NPY_FLOAT_, NPY_DOUBLE_, NPY_LONGDOUBLE_, - NPY_CFLOAT_, NPY_CDOUBLE_, NPY_CLONGDOUBLE_, - NPY_OBJECT_ = 17, - NPY_STRING_, NPY_UNICODE_, NPY_VOID_, - // Platform-dependent normalization - NPY_INT8_ = NPY_BYTE_, - NPY_UINT8_ = NPY_UBYTE_, - NPY_INT16_ = NPY_SHORT_, - NPY_UINT16_ = NPY_USHORT_, - // `npy_common.h` defines the integer aliases. In order, it checks: - // NPY_BITSOF_LONG, NPY_BITSOF_LONGLONG, NPY_BITSOF_INT, NPY_BITSOF_SHORT, NPY_BITSOF_CHAR - // and assigns the alias to the first matching size, so we should check in this order. - NPY_INT32_ = platform_lookup( - NPY_LONG_, NPY_INT_, NPY_SHORT_), - NPY_UINT32_ = platform_lookup( - NPY_ULONG_, NPY_UINT_, NPY_USHORT_), - NPY_INT64_ = platform_lookup( - NPY_LONG_, NPY_LONGLONG_, NPY_INT_), - NPY_UINT64_ = platform_lookup( - NPY_ULONG_, NPY_ULONGLONG_, NPY_UINT_), - }; - - typedef struct { - Py_intptr_t *ptr; - int len; - } PyArray_Dims; - - static npy_api& get() { - static npy_api api = lookup(); - return api; - } - - bool PyArray_Check_(PyObject *obj) const { - return (bool) PyObject_TypeCheck(obj, PyArray_Type_); - } - bool PyArrayDescr_Check_(PyObject *obj) const { - return (bool) PyObject_TypeCheck(obj, PyArrayDescr_Type_); - } - - unsigned int (*PyArray_GetNDArrayCFeatureVersion_)(); - PyObject *(*PyArray_DescrFromType_)(int); - PyObject *(*PyArray_NewFromDescr_) - (PyTypeObject *, PyObject *, int, Py_intptr_t *, - Py_intptr_t *, void *, int, PyObject *); - PyObject *(*PyArray_DescrNewFromType_)(int); - int (*PyArray_CopyInto_)(PyObject *, PyObject *); - PyObject *(*PyArray_NewCopy_)(PyObject *, int); - PyTypeObject *PyArray_Type_; - PyTypeObject *PyVoidArrType_Type_; - PyTypeObject *PyArrayDescr_Type_; - PyObject *(*PyArray_DescrFromScalar_)(PyObject *); - PyObject *(*PyArray_FromAny_) (PyObject *, PyObject *, int, int, int, PyObject *); - int (*PyArray_DescrConverter_) (PyObject *, PyObject **); - bool (*PyArray_EquivTypes_) (PyObject *, PyObject *); - int (*PyArray_GetArrayParamsFromObject_)(PyObject *, PyObject *, char, PyObject **, int *, - Py_ssize_t *, PyObject **, PyObject *); - PyObject *(*PyArray_Squeeze_)(PyObject *); - int (*PyArray_SetBaseObject_)(PyObject *, PyObject *); - PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int); -private: - enum functions { - API_PyArray_GetNDArrayCFeatureVersion = 211, - API_PyArray_Type = 2, - API_PyArrayDescr_Type = 3, - API_PyVoidArrType_Type = 39, - API_PyArray_DescrFromType = 45, - API_PyArray_DescrFromScalar = 57, - API_PyArray_FromAny = 69, - API_PyArray_Resize = 80, - API_PyArray_CopyInto = 82, - API_PyArray_NewCopy = 85, - API_PyArray_NewFromDescr = 94, - API_PyArray_DescrNewFromType = 9, - API_PyArray_DescrConverter = 174, - API_PyArray_EquivTypes = 182, - API_PyArray_GetArrayParamsFromObject = 278, - API_PyArray_Squeeze = 136, - API_PyArray_SetBaseObject = 282 - }; - - static npy_api lookup() { - module m = module::import("numpy.core.multiarray"); - auto c = m.attr("_ARRAY_API"); -#if PY_MAJOR_VERSION >= 3 - void **api_ptr = (void **) PyCapsule_GetPointer(c.ptr(), NULL); -#else - void **api_ptr = (void **) PyCObject_AsVoidPtr(c.ptr()); -#endif - npy_api api; -#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func]; - DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion); - if (api.PyArray_GetNDArrayCFeatureVersion_() < 0x7) - pybind11_fail("pybind11 numpy support requires numpy >= 1.7.0"); - DECL_NPY_API(PyArray_Type); - DECL_NPY_API(PyVoidArrType_Type); - DECL_NPY_API(PyArrayDescr_Type); - DECL_NPY_API(PyArray_DescrFromType); - DECL_NPY_API(PyArray_DescrFromScalar); - DECL_NPY_API(PyArray_FromAny); - DECL_NPY_API(PyArray_Resize); - DECL_NPY_API(PyArray_CopyInto); - DECL_NPY_API(PyArray_NewCopy); - DECL_NPY_API(PyArray_NewFromDescr); - DECL_NPY_API(PyArray_DescrNewFromType); - DECL_NPY_API(PyArray_DescrConverter); - DECL_NPY_API(PyArray_EquivTypes); - DECL_NPY_API(PyArray_GetArrayParamsFromObject); - DECL_NPY_API(PyArray_Squeeze); - DECL_NPY_API(PyArray_SetBaseObject); -#undef DECL_NPY_API - return api; - } -}; - -inline PyArray_Proxy* array_proxy(void* ptr) { - return reinterpret_cast(ptr); -} - -inline const PyArray_Proxy* array_proxy(const void* ptr) { - return reinterpret_cast(ptr); -} - -inline PyArrayDescr_Proxy* array_descriptor_proxy(PyObject* ptr) { - return reinterpret_cast(ptr); -} - -inline const PyArrayDescr_Proxy* array_descriptor_proxy(const PyObject* ptr) { - return reinterpret_cast(ptr); -} - -inline bool check_flags(const void* ptr, int flag) { - return (flag == (array_proxy(ptr)->flags & flag)); -} - -template struct is_std_array : std::false_type { }; -template struct is_std_array> : std::true_type { }; -template struct is_complex : std::false_type { }; -template struct is_complex> : std::true_type { }; - -template struct array_info_scalar { - typedef T type; - static constexpr bool is_array = false; - static constexpr bool is_empty = false; - static constexpr auto extents = _(""); - static void append_extents(list& /* shape */) { } -}; -// Computes underlying type and a comma-separated list of extents for array -// types (any mix of std::array and built-in arrays). An array of char is -// treated as scalar because it gets special handling. -template struct array_info : array_info_scalar { }; -template struct array_info> { - using type = typename array_info::type; - static constexpr bool is_array = true; - static constexpr bool is_empty = (N == 0) || array_info::is_empty; - static constexpr size_t extent = N; - - // appends the extents to shape - static void append_extents(list& shape) { - shape.append(N); - array_info::append_extents(shape); - } - - static constexpr auto extents = _::is_array>( - concat(_(), array_info::extents), _() - ); -}; -// For numpy we have special handling for arrays of characters, so we don't include -// the size in the array extents. -template struct array_info : array_info_scalar { }; -template struct array_info> : array_info_scalar> { }; -template struct array_info : array_info> { }; -template using remove_all_extents_t = typename array_info::type; - -template using is_pod_struct = all_of< - std::is_standard_layout, // since we're accessing directly in memory we need a standard layout type -#if !defined(__GNUG__) || defined(_LIBCPP_VERSION) || defined(_GLIBCXX_USE_CXX11_ABI) - // _GLIBCXX_USE_CXX11_ABI indicates that we're using libstdc++ from GCC 5 or newer, independent - // of the actual compiler (Clang can also use libstdc++, but it always defines __GNUC__ == 4). - std::is_trivially_copyable, -#else - // GCC 4 doesn't implement is_trivially_copyable, so approximate it - std::is_trivially_destructible, - satisfies_any_of, -#endif - satisfies_none_of ->; - -template ssize_t byte_offset_unsafe(const Strides &) { return 0; } -template -ssize_t byte_offset_unsafe(const Strides &strides, ssize_t i, Ix... index) { - return i * strides[Dim] + byte_offset_unsafe(strides, index...); -} - -/** - * Proxy class providing unsafe, unchecked const access to array data. This is constructed through - * the `unchecked()` method of `array` or the `unchecked()` method of `array_t`. `Dims` - * will be -1 for dimensions determined at runtime. - */ -template -class unchecked_reference { -protected: - static constexpr bool Dynamic = Dims < 0; - const unsigned char *data_; - // Storing the shape & strides in local variables (i.e. these arrays) allows the compiler to - // make large performance gains on big, nested loops, but requires compile-time dimensions - conditional_t> - shape_, strides_; - const ssize_t dims_; - - friend class pybind11::array; - // Constructor for compile-time dimensions: - template - unchecked_reference(const void *data, const ssize_t *shape, const ssize_t *strides, enable_if_t) - : data_{reinterpret_cast(data)}, dims_{Dims} { - for (size_t i = 0; i < (size_t) dims_; i++) { - shape_[i] = shape[i]; - strides_[i] = strides[i]; - } - } - // Constructor for runtime dimensions: - template - unchecked_reference(const void *data, const ssize_t *shape, const ssize_t *strides, enable_if_t dims) - : data_{reinterpret_cast(data)}, shape_{shape}, strides_{strides}, dims_{dims} {} - -public: - /** - * Unchecked const reference access to data at the given indices. For a compile-time known - * number of dimensions, this requires the correct number of arguments; for run-time - * dimensionality, this is not checked (and so is up to the caller to use safely). - */ - template const T &operator()(Ix... index) const { - static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic, - "Invalid number of indices for unchecked array reference"); - return *reinterpret_cast(data_ + byte_offset_unsafe(strides_, ssize_t(index)...)); - } - /** - * Unchecked const reference access to data; this operator only participates if the reference - * is to a 1-dimensional array. When present, this is exactly equivalent to `obj(index)`. - */ - template > - const T &operator[](ssize_t index) const { return operator()(index); } - - /// Pointer access to the data at the given indices. - template const T *data(Ix... ix) const { return &operator()(ssize_t(ix)...); } - - /// Returns the item size, i.e. sizeof(T) - constexpr static ssize_t itemsize() { return sizeof(T); } - - /// Returns the shape (i.e. size) of dimension `dim` - ssize_t shape(ssize_t dim) const { return shape_[(size_t) dim]; } - - /// Returns the number of dimensions of the array - ssize_t ndim() const { return dims_; } - - /// Returns the total number of elements in the referenced array, i.e. the product of the shapes - template - enable_if_t size() const { - return std::accumulate(shape_.begin(), shape_.end(), (ssize_t) 1, std::multiplies()); - } - template - enable_if_t size() const { - return std::accumulate(shape_, shape_ + ndim(), (ssize_t) 1, std::multiplies()); - } - - /// Returns the total number of bytes used by the referenced data. Note that the actual span in - /// memory may be larger if the referenced array has non-contiguous strides (e.g. for a slice). - ssize_t nbytes() const { - return size() * itemsize(); - } -}; - -template -class unchecked_mutable_reference : public unchecked_reference { - friend class pybind11::array; - using ConstBase = unchecked_reference; - using ConstBase::ConstBase; - using ConstBase::Dynamic; -public: - /// Mutable, unchecked access to data at the given indices. - template T& operator()(Ix... index) { - static_assert(ssize_t{sizeof...(Ix)} == Dims || Dynamic, - "Invalid number of indices for unchecked array reference"); - return const_cast(ConstBase::operator()(index...)); - } - /** - * Mutable, unchecked access data at the given index; this operator only participates if the - * reference is to a 1-dimensional array (or has runtime dimensions). When present, this is - * exactly equivalent to `obj(index)`. - */ - template > - T &operator[](ssize_t index) { return operator()(index); } - - /// Mutable pointer access to the data at the given indices. - template T *mutable_data(Ix... ix) { return &operator()(ssize_t(ix)...); } -}; - -template -struct type_caster> { - static_assert(Dim == 0 && Dim > 0 /* always fail */, "unchecked array proxy object is not castable"); -}; -template -struct type_caster> : type_caster> {}; - -NAMESPACE_END(detail) - -class dtype : public object { -public: - PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_); - - explicit dtype(const buffer_info &info) { - dtype descr(_dtype_from_pep3118()(PYBIND11_STR_TYPE(info.format))); - // If info.itemsize == 0, use the value calculated from the format string - m_ptr = descr.strip_padding(info.itemsize ? info.itemsize : descr.itemsize()).release().ptr(); - } - - explicit dtype(const std::string &format) { - m_ptr = from_args(pybind11::str(format)).release().ptr(); - } - - dtype(const char *format) : dtype(std::string(format)) { } - - dtype(list names, list formats, list offsets, ssize_t itemsize) { - dict args; - args["names"] = names; - args["formats"] = formats; - args["offsets"] = offsets; - args["itemsize"] = pybind11::int_(itemsize); - m_ptr = from_args(args).release().ptr(); - } - - /// This is essentially the same as calling numpy.dtype(args) in Python. - static dtype from_args(object args) { - PyObject *ptr = nullptr; - if (!detail::npy_api::get().PyArray_DescrConverter_(args.ptr(), &ptr) || !ptr) - throw error_already_set(); - return reinterpret_steal(ptr); - } - - /// Return dtype associated with a C++ type. - template static dtype of() { - return detail::npy_format_descriptor::type>::dtype(); - } - - /// Size of the data type in bytes. - ssize_t itemsize() const { - return detail::array_descriptor_proxy(m_ptr)->elsize; - } - - /// Returns true for structured data types. - bool has_fields() const { - return detail::array_descriptor_proxy(m_ptr)->names != nullptr; - } - - /// Single-character type code. - char kind() const { - return detail::array_descriptor_proxy(m_ptr)->kind; - } - -private: - static object _dtype_from_pep3118() { - static PyObject *obj = module::import("numpy.core._internal") - .attr("_dtype_from_pep3118").cast().release().ptr(); - return reinterpret_borrow(obj); - } - - dtype strip_padding(ssize_t itemsize) { - // Recursively strip all void fields with empty names that are generated for - // padding fields (as of NumPy v1.11). - if (!has_fields()) - return *this; - - struct field_descr { PYBIND11_STR_TYPE name; object format; pybind11::int_ offset; }; - std::vector field_descriptors; - - for (auto field : attr("fields").attr("items")()) { - auto spec = field.cast(); - auto name = spec[0].cast(); - auto format = spec[1].cast()[0].cast(); - auto offset = spec[1].cast()[1].cast(); - if (!len(name) && format.kind() == 'V') - continue; - field_descriptors.push_back({(PYBIND11_STR_TYPE) name, format.strip_padding(format.itemsize()), offset}); - } - - std::sort(field_descriptors.begin(), field_descriptors.end(), - [](const field_descr& a, const field_descr& b) { - return a.offset.cast() < b.offset.cast(); - }); - - list names, formats, offsets; - for (auto& descr : field_descriptors) { - names.append(descr.name); - formats.append(descr.format); - offsets.append(descr.offset); - } - return dtype(names, formats, offsets, itemsize); - } -}; - -class array : public buffer { -public: - PYBIND11_OBJECT_CVT(array, buffer, detail::npy_api::get().PyArray_Check_, raw_array) - - enum { - c_style = detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_, - f_style = detail::npy_api::NPY_ARRAY_F_CONTIGUOUS_, - forcecast = detail::npy_api::NPY_ARRAY_FORCECAST_ - }; - - array() : array({{0}}, static_cast(nullptr)) {} - - using ShapeContainer = detail::any_container; - using StridesContainer = detail::any_container; - - // Constructs an array taking shape/strides from arbitrary container types - array(const pybind11::dtype &dt, ShapeContainer shape, StridesContainer strides, - const void *ptr = nullptr, handle base = handle()) { - - if (strides->empty()) - *strides = c_strides(*shape, dt.itemsize()); - - auto ndim = shape->size(); - if (ndim != strides->size()) - pybind11_fail("NumPy: shape ndim doesn't match strides ndim"); - auto descr = dt; - - int flags = 0; - if (base && ptr) { - if (isinstance(base)) - /* Copy flags from base (except ownership bit) */ - flags = reinterpret_borrow(base).flags() & ~detail::npy_api::NPY_ARRAY_OWNDATA_; - else - /* Writable by default, easy to downgrade later on if needed */ - flags = detail::npy_api::NPY_ARRAY_WRITEABLE_; - } - - auto &api = detail::npy_api::get(); - auto tmp = reinterpret_steal(api.PyArray_NewFromDescr_( - api.PyArray_Type_, descr.release().ptr(), (int) ndim, shape->data(), strides->data(), - const_cast(ptr), flags, nullptr)); - if (!tmp) - throw error_already_set(); - if (ptr) { - if (base) { - api.PyArray_SetBaseObject_(tmp.ptr(), base.inc_ref().ptr()); - } else { - tmp = reinterpret_steal(api.PyArray_NewCopy_(tmp.ptr(), -1 /* any order */)); - } - } - m_ptr = tmp.release().ptr(); - } - - array(const pybind11::dtype &dt, ShapeContainer shape, const void *ptr = nullptr, handle base = handle()) - : array(dt, std::move(shape), {}, ptr, base) { } - - template ::value && !std::is_same::value>> - array(const pybind11::dtype &dt, T count, const void *ptr = nullptr, handle base = handle()) - : array(dt, {{count}}, ptr, base) { } - - template - array(ShapeContainer shape, StridesContainer strides, const T *ptr, handle base = handle()) - : array(pybind11::dtype::of(), std::move(shape), std::move(strides), ptr, base) { } - - template - array(ShapeContainer shape, const T *ptr, handle base = handle()) - : array(std::move(shape), {}, ptr, base) { } - - template - explicit array(ssize_t count, const T *ptr, handle base = handle()) : array({count}, {}, ptr, base) { } - - explicit array(const buffer_info &info) - : array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { } - - /// Array descriptor (dtype) - pybind11::dtype dtype() const { - return reinterpret_borrow(detail::array_proxy(m_ptr)->descr); - } - - /// Total number of elements - ssize_t size() const { - return std::accumulate(shape(), shape() + ndim(), (ssize_t) 1, std::multiplies()); - } - - /// Byte size of a single element - ssize_t itemsize() const { - return detail::array_descriptor_proxy(detail::array_proxy(m_ptr)->descr)->elsize; - } - - /// Total number of bytes - ssize_t nbytes() const { - return size() * itemsize(); - } - - /// Number of dimensions - ssize_t ndim() const { - return detail::array_proxy(m_ptr)->nd; - } - - /// Base object - object base() const { - return reinterpret_borrow(detail::array_proxy(m_ptr)->base); - } - - /// Dimensions of the array - const ssize_t* shape() const { - return detail::array_proxy(m_ptr)->dimensions; - } - - /// Dimension along a given axis - ssize_t shape(ssize_t dim) const { - if (dim >= ndim()) - fail_dim_check(dim, "invalid axis"); - return shape()[dim]; - } - - /// Strides of the array - const ssize_t* strides() const { - return detail::array_proxy(m_ptr)->strides; - } - - /// Stride along a given axis - ssize_t strides(ssize_t dim) const { - if (dim >= ndim()) - fail_dim_check(dim, "invalid axis"); - return strides()[dim]; - } - - /// Return the NumPy array flags - int flags() const { - return detail::array_proxy(m_ptr)->flags; - } - - /// If set, the array is writeable (otherwise the buffer is read-only) - bool writeable() const { - return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_WRITEABLE_); - } - - /// If set, the array owns the data (will be freed when the array is deleted) - bool owndata() const { - return detail::check_flags(m_ptr, detail::npy_api::NPY_ARRAY_OWNDATA_); - } - - /// Pointer to the contained data. If index is not provided, points to the - /// beginning of the buffer. May throw if the index would lead to out of bounds access. - template const void* data(Ix... index) const { - return static_cast(detail::array_proxy(m_ptr)->data + offset_at(index...)); - } - - /// Mutable pointer to the contained data. If index is not provided, points to the - /// beginning of the buffer. May throw if the index would lead to out of bounds access. - /// May throw if the array is not writeable. - template void* mutable_data(Ix... index) { - check_writeable(); - return static_cast(detail::array_proxy(m_ptr)->data + offset_at(index...)); - } - - /// Byte offset from beginning of the array to a given index (full or partial). - /// May throw if the index would lead to out of bounds access. - template ssize_t offset_at(Ix... index) const { - if ((ssize_t) sizeof...(index) > ndim()) - fail_dim_check(sizeof...(index), "too many indices for an array"); - return byte_offset(ssize_t(index)...); - } - - ssize_t offset_at() const { return 0; } - - /// Item count from beginning of the array to a given index (full or partial). - /// May throw if the index would lead to out of bounds access. - template ssize_t index_at(Ix... index) const { - return offset_at(index...) / itemsize(); - } - - /** - * Returns a proxy object that provides access to the array's data without bounds or - * dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with - * care: the array must not be destroyed or reshaped for the duration of the returned object, - * and the caller must take care not to access invalid dimensions or dimension indices. - */ - template detail::unchecked_mutable_reference mutable_unchecked() & { - if (Dims >= 0 && ndim() != Dims) - throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) + - "; expected " + std::to_string(Dims)); - return detail::unchecked_mutable_reference(mutable_data(), shape(), strides(), ndim()); - } - - /** - * Returns a proxy object that provides const access to the array's data without bounds or - * dimensionality checking. Unlike `mutable_unchecked()`, this does not require that the - * underlying array have the `writable` flag. Use with care: the array must not be destroyed or - * reshaped for the duration of the returned object, and the caller must take care not to access - * invalid dimensions or dimension indices. - */ - template detail::unchecked_reference unchecked() const & { - if (Dims >= 0 && ndim() != Dims) - throw std::domain_error("array has incorrect number of dimensions: " + std::to_string(ndim()) + - "; expected " + std::to_string(Dims)); - return detail::unchecked_reference(data(), shape(), strides(), ndim()); - } - - /// Return a new view with all of the dimensions of length 1 removed - array squeeze() { - auto& api = detail::npy_api::get(); - return reinterpret_steal(api.PyArray_Squeeze_(m_ptr)); - } - - /// Resize array to given shape - /// If refcheck is true and more that one reference exist to this array - /// then resize will succeed only if it makes a reshape, i.e. original size doesn't change - void resize(ShapeContainer new_shape, bool refcheck = true) { - detail::npy_api::PyArray_Dims d = { - new_shape->data(), int(new_shape->size()) - }; - // try to resize, set ordering param to -1 cause it's not used anyway - object new_array = reinterpret_steal( - detail::npy_api::get().PyArray_Resize_(m_ptr, &d, int(refcheck), -1) - ); - if (!new_array) throw error_already_set(); - if (isinstance(new_array)) { *this = std::move(new_array); } - } - - /// Ensure that the argument is a NumPy array - /// In case of an error, nullptr is returned and the Python error is cleared. - static array ensure(handle h, int ExtraFlags = 0) { - auto result = reinterpret_steal(raw_array(h.ptr(), ExtraFlags)); - if (!result) - PyErr_Clear(); - return result; - } - -protected: - template friend struct detail::npy_format_descriptor; - - void fail_dim_check(ssize_t dim, const std::string& msg) const { - throw index_error(msg + ": " + std::to_string(dim) + - " (ndim = " + std::to_string(ndim()) + ")"); - } - - template ssize_t byte_offset(Ix... index) const { - check_dimensions(index...); - return detail::byte_offset_unsafe(strides(), ssize_t(index)...); - } - - void check_writeable() const { - if (!writeable()) - throw std::domain_error("array is not writeable"); - } - - // Default, C-style strides - static std::vector c_strides(const std::vector &shape, ssize_t itemsize) { - auto ndim = shape.size(); - std::vector strides(ndim, itemsize); - if (ndim > 0) - for (size_t i = ndim - 1; i > 0; --i) - strides[i - 1] = strides[i] * shape[i]; - return strides; - } - - // F-style strides; default when constructing an array_t with `ExtraFlags & f_style` - static std::vector f_strides(const std::vector &shape, ssize_t itemsize) { - auto ndim = shape.size(); - std::vector strides(ndim, itemsize); - for (size_t i = 1; i < ndim; ++i) - strides[i] = strides[i - 1] * shape[i - 1]; - return strides; - } - - template void check_dimensions(Ix... index) const { - check_dimensions_impl(ssize_t(0), shape(), ssize_t(index)...); - } - - void check_dimensions_impl(ssize_t, const ssize_t*) const { } - - template void check_dimensions_impl(ssize_t axis, const ssize_t* shape, ssize_t i, Ix... index) const { - if (i >= *shape) { - throw index_error(std::string("index ") + std::to_string(i) + - " is out of bounds for axis " + std::to_string(axis) + - " with size " + std::to_string(*shape)); - } - check_dimensions_impl(axis + 1, shape + 1, index...); - } - - /// Create array from any object -- always returns a new reference - static PyObject *raw_array(PyObject *ptr, int ExtraFlags = 0) { - if (ptr == nullptr) { - PyErr_SetString(PyExc_ValueError, "cannot create a pybind11::array from a nullptr"); - return nullptr; - } - return detail::npy_api::get().PyArray_FromAny_( - ptr, nullptr, 0, 0, detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr); - } -}; - -template class array_t : public array { -private: - struct private_ctor {}; - // Delegating constructor needed when both moving and accessing in the same constructor - array_t(private_ctor, ShapeContainer &&shape, StridesContainer &&strides, const T *ptr, handle base) - : array(std::move(shape), std::move(strides), ptr, base) {} -public: - static_assert(!detail::array_info::is_array, "Array types cannot be used with array_t"); - - using value_type = T; - - array_t() : array(0, static_cast(nullptr)) {} - array_t(handle h, borrowed_t) : array(h, borrowed_t{}) { } - array_t(handle h, stolen_t) : array(h, stolen_t{}) { } - - PYBIND11_DEPRECATED("Use array_t::ensure() instead") - array_t(handle h, bool is_borrowed) : array(raw_array_t(h.ptr()), stolen_t{}) { - if (!m_ptr) PyErr_Clear(); - if (!is_borrowed) Py_XDECREF(h.ptr()); - } - - array_t(const object &o) : array(raw_array_t(o.ptr()), stolen_t{}) { - if (!m_ptr) throw error_already_set(); - } - - explicit array_t(const buffer_info& info) : array(info) { } - - array_t(ShapeContainer shape, StridesContainer strides, const T *ptr = nullptr, handle base = handle()) - : array(std::move(shape), std::move(strides), ptr, base) { } - - explicit array_t(ShapeContainer shape, const T *ptr = nullptr, handle base = handle()) - : array_t(private_ctor{}, std::move(shape), - ExtraFlags & f_style ? f_strides(*shape, itemsize()) : c_strides(*shape, itemsize()), - ptr, base) { } - - explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle()) - : array({count}, {}, ptr, base) { } - - constexpr ssize_t itemsize() const { - return sizeof(T); - } - - template ssize_t index_at(Ix... index) const { - return offset_at(index...) / itemsize(); - } - - template const T* data(Ix... index) const { - return static_cast(array::data(index...)); - } - - template T* mutable_data(Ix... index) { - return static_cast(array::mutable_data(index...)); - } - - // Reference to element at a given index - template const T& at(Ix... index) const { - if ((ssize_t) sizeof...(index) != ndim()) - fail_dim_check(sizeof...(index), "index dimension mismatch"); - return *(static_cast(array::data()) + byte_offset(ssize_t(index)...) / itemsize()); - } - - // Mutable reference to element at a given index - template T& mutable_at(Ix... index) { - if ((ssize_t) sizeof...(index) != ndim()) - fail_dim_check(sizeof...(index), "index dimension mismatch"); - return *(static_cast(array::mutable_data()) + byte_offset(ssize_t(index)...) / itemsize()); - } - - /** - * Returns a proxy object that provides access to the array's data without bounds or - * dimensionality checking. Will throw if the array is missing the `writeable` flag. Use with - * care: the array must not be destroyed or reshaped for the duration of the returned object, - * and the caller must take care not to access invalid dimensions or dimension indices. - */ - template detail::unchecked_mutable_reference mutable_unchecked() & { - return array::mutable_unchecked(); - } - - /** - * Returns a proxy object that provides const access to the array's data without bounds or - * dimensionality checking. Unlike `unchecked()`, this does not require that the underlying - * array have the `writable` flag. Use with care: the array must not be destroyed or reshaped - * for the duration of the returned object, and the caller must take care not to access invalid - * dimensions or dimension indices. - */ - template detail::unchecked_reference unchecked() const & { - return array::unchecked(); - } - - /// Ensure that the argument is a NumPy array of the correct dtype (and if not, try to convert - /// it). In case of an error, nullptr is returned and the Python error is cleared. - static array_t ensure(handle h) { - auto result = reinterpret_steal(raw_array_t(h.ptr())); - if (!result) - PyErr_Clear(); - return result; - } - - static bool check_(handle h) { - const auto &api = detail::npy_api::get(); - return api.PyArray_Check_(h.ptr()) - && api.PyArray_EquivTypes_(detail::array_proxy(h.ptr())->descr, dtype::of().ptr()); - } - -protected: - /// Create array from any object -- always returns a new reference - static PyObject *raw_array_t(PyObject *ptr) { - if (ptr == nullptr) { - PyErr_SetString(PyExc_ValueError, "cannot create a pybind11::array_t from a nullptr"); - return nullptr; - } - return detail::npy_api::get().PyArray_FromAny_( - ptr, dtype::of().release().ptr(), 0, 0, - detail::npy_api::NPY_ARRAY_ENSUREARRAY_ | ExtraFlags, nullptr); - } -}; - -template -struct format_descriptor::value>> { - static std::string format() { - return detail::npy_format_descriptor::type>::format(); - } -}; - -template struct format_descriptor { - static std::string format() { return std::to_string(N) + "s"; } -}; -template struct format_descriptor> { - static std::string format() { return std::to_string(N) + "s"; } -}; - -template -struct format_descriptor::value>> { - static std::string format() { - return format_descriptor< - typename std::remove_cv::type>::type>::format(); - } -}; - -template -struct format_descriptor::is_array>> { - static std::string format() { - using namespace detail; - static constexpr auto extents = _("(") + array_info::extents + _(")"); - return extents.text + format_descriptor>::format(); - } -}; - -NAMESPACE_BEGIN(detail) -template -struct pyobject_caster> { - using type = array_t; - - bool load(handle src, bool convert) { - if (!convert && !type::check_(src)) - return false; - value = type::ensure(src); - return static_cast(value); - } - - static handle cast(const handle &src, return_value_policy /* policy */, handle /* parent */) { - return src.inc_ref(); - } - PYBIND11_TYPE_CASTER(type, handle_type_name::name); -}; - -template -struct compare_buffer_info::value>> { - static bool compare(const buffer_info& b) { - return npy_api::get().PyArray_EquivTypes_(dtype::of().ptr(), dtype(b).ptr()); - } -}; - -template -struct npy_format_descriptor_name; - -template -struct npy_format_descriptor_name::value>> { - static constexpr auto name = _::value>( - _("bool"), _::value>("int", "uint") + _() - ); -}; - -template -struct npy_format_descriptor_name::value>> { - static constexpr auto name = _::value || std::is_same::value>( - _("float") + _(), _("longdouble") - ); -}; - -template -struct npy_format_descriptor_name::value>> { - static constexpr auto name = _::value - || std::is_same::value>( - _("complex") + _(), _("longcomplex") - ); -}; - -template -struct npy_format_descriptor::value>> - : npy_format_descriptor_name { -private: - // NB: the order here must match the one in common.h - constexpr static const int values[15] = { - npy_api::NPY_BOOL_, - npy_api::NPY_BYTE_, npy_api::NPY_UBYTE_, npy_api::NPY_INT16_, npy_api::NPY_UINT16_, - npy_api::NPY_INT32_, npy_api::NPY_UINT32_, npy_api::NPY_INT64_, npy_api::NPY_UINT64_, - npy_api::NPY_FLOAT_, npy_api::NPY_DOUBLE_, npy_api::NPY_LONGDOUBLE_, - npy_api::NPY_CFLOAT_, npy_api::NPY_CDOUBLE_, npy_api::NPY_CLONGDOUBLE_ - }; - -public: - static constexpr int value = values[detail::is_fmt_numeric::index]; - - static pybind11::dtype dtype() { - if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) - return reinterpret_steal(ptr); - pybind11_fail("Unsupported buffer format!"); - } -}; - -#define PYBIND11_DECL_CHAR_FMT \ - static constexpr auto name = _("S") + _(); \ - static pybind11::dtype dtype() { return pybind11::dtype(std::string("S") + std::to_string(N)); } -template struct npy_format_descriptor { PYBIND11_DECL_CHAR_FMT }; -template struct npy_format_descriptor> { PYBIND11_DECL_CHAR_FMT }; -#undef PYBIND11_DECL_CHAR_FMT - -template struct npy_format_descriptor::is_array>> { -private: - using base_descr = npy_format_descriptor::type>; -public: - static_assert(!array_info::is_empty, "Zero-sized arrays are not supported"); - - static constexpr auto name = _("(") + array_info::extents + _(")") + base_descr::name; - static pybind11::dtype dtype() { - list shape; - array_info::append_extents(shape); - return pybind11::dtype::from_args(pybind11::make_tuple(base_descr::dtype(), shape)); - } -}; - -template struct npy_format_descriptor::value>> { -private: - using base_descr = npy_format_descriptor::type>; -public: - static constexpr auto name = base_descr::name; - static pybind11::dtype dtype() { return base_descr::dtype(); } -}; - -struct field_descriptor { - const char *name; - ssize_t offset; - ssize_t size; - std::string format; - dtype descr; -}; - -inline PYBIND11_NOINLINE void register_structured_dtype( - any_container fields, - const std::type_info& tinfo, ssize_t itemsize, - bool (*direct_converter)(PyObject *, void *&)) { - - auto& numpy_internals = get_numpy_internals(); - if (numpy_internals.get_type_info(tinfo, false)) - pybind11_fail("NumPy: dtype is already registered"); - - // Use ordered fields because order matters as of NumPy 1.14: - // https://docs.scipy.org/doc/numpy/release.html#multiple-field-indexing-assignment-of-structured-arrays - std::vector ordered_fields(std::move(fields)); - std::sort(ordered_fields.begin(), ordered_fields.end(), - [](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; }); - - list names, formats, offsets; - for (auto& field : ordered_fields) { - if (!field.descr) - pybind11_fail(std::string("NumPy: unsupported field dtype: `") + - field.name + "` @ " + tinfo.name()); - names.append(PYBIND11_STR_TYPE(field.name)); - formats.append(field.descr); - offsets.append(pybind11::int_(field.offset)); - } - auto dtype_ptr = pybind11::dtype(names, formats, offsets, itemsize).release().ptr(); - - // There is an existing bug in NumPy (as of v1.11): trailing bytes are - // not encoded explicitly into the format string. This will supposedly - // get fixed in v1.12; for further details, see these: - // - https://github.com/numpy/numpy/issues/7797 - // - https://github.com/numpy/numpy/pull/7798 - // Because of this, we won't use numpy's logic to generate buffer format - // strings and will just do it ourselves. - ssize_t offset = 0; - std::ostringstream oss; - // mark the structure as unaligned with '^', because numpy and C++ don't - // always agree about alignment (particularly for complex), and we're - // explicitly listing all our padding. This depends on none of the fields - // overriding the endianness. Putting the ^ in front of individual fields - // isn't guaranteed to work due to https://github.com/numpy/numpy/issues/9049 - oss << "^T{"; - for (auto& field : ordered_fields) { - if (field.offset > offset) - oss << (field.offset - offset) << 'x'; - oss << field.format << ':' << field.name << ':'; - offset = field.offset + field.size; - } - if (itemsize > offset) - oss << (itemsize - offset) << 'x'; - oss << '}'; - auto format_str = oss.str(); - - // Sanity check: verify that NumPy properly parses our buffer format string - auto& api = npy_api::get(); - auto arr = array(buffer_info(nullptr, itemsize, format_str, 1)); - if (!api.PyArray_EquivTypes_(dtype_ptr, arr.dtype().ptr())) - pybind11_fail("NumPy: invalid buffer descriptor!"); - - auto tindex = std::type_index(tinfo); - numpy_internals.registered_dtypes[tindex] = { dtype_ptr, format_str }; - get_internals().direct_conversions[tindex].push_back(direct_converter); -} - -template struct npy_format_descriptor { - static_assert(is_pod_struct::value, "Attempt to use a non-POD or unimplemented POD type as a numpy dtype"); - - static constexpr auto name = make_caster::name; - - static pybind11::dtype dtype() { - return reinterpret_borrow(dtype_ptr()); - } - - static std::string format() { - static auto format_str = get_numpy_internals().get_type_info(true)->format_str; - return format_str; - } - - static void register_dtype(any_container fields) { - register_structured_dtype(std::move(fields), typeid(typename std::remove_cv::type), - sizeof(T), &direct_converter); - } - -private: - static PyObject* dtype_ptr() { - static PyObject* ptr = get_numpy_internals().get_type_info(true)->dtype_ptr; - return ptr; - } - - static bool direct_converter(PyObject *obj, void*& value) { - auto& api = npy_api::get(); - if (!PyObject_TypeCheck(obj, api.PyVoidArrType_Type_)) - return false; - if (auto descr = reinterpret_steal(api.PyArray_DescrFromScalar_(obj))) { - if (api.PyArray_EquivTypes_(dtype_ptr(), descr.ptr())) { - value = ((PyVoidScalarObject_Proxy *) obj)->obval; - return true; - } - } - return false; - } -}; - -#ifdef __CLION_IDE__ // replace heavy macro with dummy code for the IDE (doesn't affect code) -# define PYBIND11_NUMPY_DTYPE(Type, ...) ((void)0) -# define PYBIND11_NUMPY_DTYPE_EX(Type, ...) ((void)0) -#else - -#define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \ - ::pybind11::detail::field_descriptor { \ - Name, offsetof(T, Field), sizeof(decltype(std::declval().Field)), \ - ::pybind11::format_descriptor().Field)>::format(), \ - ::pybind11::detail::npy_format_descriptor().Field)>::dtype() \ - } - -// Extract name, offset and format descriptor for a struct field -#define PYBIND11_FIELD_DESCRIPTOR(T, Field) PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, #Field) - -// The main idea of this macro is borrowed from https://github.com/swansontec/map-macro -// (C) William Swanson, Paul Fultz -#define PYBIND11_EVAL0(...) __VA_ARGS__ -#define PYBIND11_EVAL1(...) PYBIND11_EVAL0 (PYBIND11_EVAL0 (PYBIND11_EVAL0 (__VA_ARGS__))) -#define PYBIND11_EVAL2(...) PYBIND11_EVAL1 (PYBIND11_EVAL1 (PYBIND11_EVAL1 (__VA_ARGS__))) -#define PYBIND11_EVAL3(...) PYBIND11_EVAL2 (PYBIND11_EVAL2 (PYBIND11_EVAL2 (__VA_ARGS__))) -#define PYBIND11_EVAL4(...) PYBIND11_EVAL3 (PYBIND11_EVAL3 (PYBIND11_EVAL3 (__VA_ARGS__))) -#define PYBIND11_EVAL(...) PYBIND11_EVAL4 (PYBIND11_EVAL4 (PYBIND11_EVAL4 (__VA_ARGS__))) -#define PYBIND11_MAP_END(...) -#define PYBIND11_MAP_OUT -#define PYBIND11_MAP_COMMA , -#define PYBIND11_MAP_GET_END() 0, PYBIND11_MAP_END -#define PYBIND11_MAP_NEXT0(test, next, ...) next PYBIND11_MAP_OUT -#define PYBIND11_MAP_NEXT1(test, next) PYBIND11_MAP_NEXT0 (test, next, 0) -#define PYBIND11_MAP_NEXT(test, next) PYBIND11_MAP_NEXT1 (PYBIND11_MAP_GET_END test, next) -#ifdef _MSC_VER // MSVC is not as eager to expand macros, hence this workaround -#define PYBIND11_MAP_LIST_NEXT1(test, next) \ - PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)) -#else -#define PYBIND11_MAP_LIST_NEXT1(test, next) \ - PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0) -#endif -#define PYBIND11_MAP_LIST_NEXT(test, next) \ - PYBIND11_MAP_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next) -#define PYBIND11_MAP_LIST0(f, t, x, peek, ...) \ - f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST1) (f, t, peek, __VA_ARGS__) -#define PYBIND11_MAP_LIST1(f, t, x, peek, ...) \ - f(t, x) PYBIND11_MAP_LIST_NEXT (peek, PYBIND11_MAP_LIST0) (f, t, peek, __VA_ARGS__) -// PYBIND11_MAP_LIST(f, t, a1, a2, ...) expands to f(t, a1), f(t, a2), ... -#define PYBIND11_MAP_LIST(f, t, ...) \ - PYBIND11_EVAL (PYBIND11_MAP_LIST1 (f, t, __VA_ARGS__, (), 0)) - -#define PYBIND11_NUMPY_DTYPE(Type, ...) \ - ::pybind11::detail::npy_format_descriptor::register_dtype \ - (::std::vector<::pybind11::detail::field_descriptor> \ - {PYBIND11_MAP_LIST (PYBIND11_FIELD_DESCRIPTOR, Type, __VA_ARGS__)}) - -#ifdef _MSC_VER -#define PYBIND11_MAP2_LIST_NEXT1(test, next) \ - PYBIND11_EVAL0 (PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0)) -#else -#define PYBIND11_MAP2_LIST_NEXT1(test, next) \ - PYBIND11_MAP_NEXT0 (test, PYBIND11_MAP_COMMA next, 0) -#endif -#define PYBIND11_MAP2_LIST_NEXT(test, next) \ - PYBIND11_MAP2_LIST_NEXT1 (PYBIND11_MAP_GET_END test, next) -#define PYBIND11_MAP2_LIST0(f, t, x1, x2, peek, ...) \ - f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST1) (f, t, peek, __VA_ARGS__) -#define PYBIND11_MAP2_LIST1(f, t, x1, x2, peek, ...) \ - f(t, x1, x2) PYBIND11_MAP2_LIST_NEXT (peek, PYBIND11_MAP2_LIST0) (f, t, peek, __VA_ARGS__) -// PYBIND11_MAP2_LIST(f, t, a1, a2, ...) expands to f(t, a1, a2), f(t, a3, a4), ... -#define PYBIND11_MAP2_LIST(f, t, ...) \ - PYBIND11_EVAL (PYBIND11_MAP2_LIST1 (f, t, __VA_ARGS__, (), 0)) - -#define PYBIND11_NUMPY_DTYPE_EX(Type, ...) \ - ::pybind11::detail::npy_format_descriptor::register_dtype \ - (::std::vector<::pybind11::detail::field_descriptor> \ - {PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)}) - -#endif // __CLION_IDE__ - -template -using array_iterator = typename std::add_pointer::type; - -template -array_iterator array_begin(const buffer_info& buffer) { - return array_iterator(reinterpret_cast(buffer.ptr)); -} - -template -array_iterator array_end(const buffer_info& buffer) { - return array_iterator(reinterpret_cast(buffer.ptr) + buffer.size); -} - -class common_iterator { -public: - using container_type = std::vector; - using value_type = container_type::value_type; - using size_type = container_type::size_type; - - common_iterator() : p_ptr(0), m_strides() {} - - common_iterator(void* ptr, const container_type& strides, const container_type& shape) - : p_ptr(reinterpret_cast(ptr)), m_strides(strides.size()) { - m_strides.back() = static_cast(strides.back()); - for (size_type i = m_strides.size() - 1; i != 0; --i) { - size_type j = i - 1; - value_type s = static_cast(shape[i]); - m_strides[j] = strides[j] + m_strides[i] - strides[i] * s; - } - } - - void increment(size_type dim) { - p_ptr += m_strides[dim]; - } - - void* data() const { - return p_ptr; - } - -private: - char* p_ptr; - container_type m_strides; -}; - -template class multi_array_iterator { -public: - using container_type = std::vector; - - multi_array_iterator(const std::array &buffers, - const container_type &shape) - : m_shape(shape.size()), m_index(shape.size(), 0), - m_common_iterator() { - - // Manual copy to avoid conversion warning if using std::copy - for (size_t i = 0; i < shape.size(); ++i) - m_shape[i] = shape[i]; - - container_type strides(shape.size()); - for (size_t i = 0; i < N; ++i) - init_common_iterator(buffers[i], shape, m_common_iterator[i], strides); - } - - multi_array_iterator& operator++() { - for (size_t j = m_index.size(); j != 0; --j) { - size_t i = j - 1; - if (++m_index[i] != m_shape[i]) { - increment_common_iterator(i); - break; - } else { - m_index[i] = 0; - } - } - return *this; - } - - template T* data() const { - return reinterpret_cast(m_common_iterator[K].data()); - } - -private: - - using common_iter = common_iterator; - - void init_common_iterator(const buffer_info &buffer, - const container_type &shape, - common_iter &iterator, - container_type &strides) { - auto buffer_shape_iter = buffer.shape.rbegin(); - auto buffer_strides_iter = buffer.strides.rbegin(); - auto shape_iter = shape.rbegin(); - auto strides_iter = strides.rbegin(); - - while (buffer_shape_iter != buffer.shape.rend()) { - if (*shape_iter == *buffer_shape_iter) - *strides_iter = *buffer_strides_iter; - else - *strides_iter = 0; - - ++buffer_shape_iter; - ++buffer_strides_iter; - ++shape_iter; - ++strides_iter; - } - - std::fill(strides_iter, strides.rend(), 0); - iterator = common_iter(buffer.ptr, strides, shape); - } - - void increment_common_iterator(size_t dim) { - for (auto &iter : m_common_iterator) - iter.increment(dim); - } - - container_type m_shape; - container_type m_index; - std::array m_common_iterator; -}; - -enum class broadcast_trivial { non_trivial, c_trivial, f_trivial }; - -// Populates the shape and number of dimensions for the set of buffers. Returns a broadcast_trivial -// enum value indicating whether the broadcast is "trivial"--that is, has each buffer being either a -// singleton or a full-size, C-contiguous (`c_trivial`) or Fortran-contiguous (`f_trivial`) storage -// buffer; returns `non_trivial` otherwise. -template -broadcast_trivial broadcast(const std::array &buffers, ssize_t &ndim, std::vector &shape) { - ndim = std::accumulate(buffers.begin(), buffers.end(), ssize_t(0), [](ssize_t res, const buffer_info &buf) { - return std::max(res, buf.ndim); - }); - - shape.clear(); - shape.resize((size_t) ndim, 1); - - // Figure out the output size, and make sure all input arrays conform (i.e. are either size 1 or - // the full size). - for (size_t i = 0; i < N; ++i) { - auto res_iter = shape.rbegin(); - auto end = buffers[i].shape.rend(); - for (auto shape_iter = buffers[i].shape.rbegin(); shape_iter != end; ++shape_iter, ++res_iter) { - const auto &dim_size_in = *shape_iter; - auto &dim_size_out = *res_iter; - - // Each input dimension can either be 1 or `n`, but `n` values must match across buffers - if (dim_size_out == 1) - dim_size_out = dim_size_in; - else if (dim_size_in != 1 && dim_size_in != dim_size_out) - pybind11_fail("pybind11::vectorize: incompatible size/dimension of inputs!"); - } - } - - bool trivial_broadcast_c = true; - bool trivial_broadcast_f = true; - for (size_t i = 0; i < N && (trivial_broadcast_c || trivial_broadcast_f); ++i) { - if (buffers[i].size == 1) - continue; - - // Require the same number of dimensions: - if (buffers[i].ndim != ndim) - return broadcast_trivial::non_trivial; - - // Require all dimensions be full-size: - if (!std::equal(buffers[i].shape.cbegin(), buffers[i].shape.cend(), shape.cbegin())) - return broadcast_trivial::non_trivial; - - // Check for C contiguity (but only if previous inputs were also C contiguous) - if (trivial_broadcast_c) { - ssize_t expect_stride = buffers[i].itemsize; - auto end = buffers[i].shape.crend(); - for (auto shape_iter = buffers[i].shape.crbegin(), stride_iter = buffers[i].strides.crbegin(); - trivial_broadcast_c && shape_iter != end; ++shape_iter, ++stride_iter) { - if (expect_stride == *stride_iter) - expect_stride *= *shape_iter; - else - trivial_broadcast_c = false; - } - } - - // Check for Fortran contiguity (if previous inputs were also F contiguous) - if (trivial_broadcast_f) { - ssize_t expect_stride = buffers[i].itemsize; - auto end = buffers[i].shape.cend(); - for (auto shape_iter = buffers[i].shape.cbegin(), stride_iter = buffers[i].strides.cbegin(); - trivial_broadcast_f && shape_iter != end; ++shape_iter, ++stride_iter) { - if (expect_stride == *stride_iter) - expect_stride *= *shape_iter; - else - trivial_broadcast_f = false; - } - } - } - - return - trivial_broadcast_c ? broadcast_trivial::c_trivial : - trivial_broadcast_f ? broadcast_trivial::f_trivial : - broadcast_trivial::non_trivial; -} - -template -struct vectorize_arg { - static_assert(!std::is_rvalue_reference::value, "Functions with rvalue reference arguments cannot be vectorized"); - // The wrapped function gets called with this type: - using call_type = remove_reference_t; - // Is this a vectorized argument? - static constexpr bool vectorize = - satisfies_any_of::value && - satisfies_none_of::value && - (!std::is_reference::value || - (std::is_lvalue_reference::value && std::is_const::value)); - // Accept this type: an array for vectorized types, otherwise the type as-is: - using type = conditional_t, array::forcecast>, T>; -}; - -template -struct vectorize_helper { -private: - static constexpr size_t N = sizeof...(Args); - static constexpr size_t NVectorized = constexpr_sum(vectorize_arg::vectorize...); - static_assert(NVectorized >= 1, - "pybind11::vectorize(...) requires a function with at least one vectorizable argument"); - -public: - template - explicit vectorize_helper(T &&f) : f(std::forward(f)) { } - - object operator()(typename vectorize_arg::type... args) { - return run(args..., - make_index_sequence(), - select_indices::vectorize...>(), - make_index_sequence()); - } - -private: - remove_reference_t f; - - // Internal compiler error in MSVC 19.16.27025.1 (Visual Studio 2017 15.9.4), when compiling with "/permissive-" flag - // when arg_call_types is manually inlined. - using arg_call_types = std::tuple::call_type...>; - template using param_n_t = typename std::tuple_element::type; - - // Runs a vectorized function given arguments tuple and three index sequences: - // - Index is the full set of 0 ... (N-1) argument indices; - // - VIndex is the subset of argument indices with vectorized parameters, letting us access - // vectorized arguments (anything not in this sequence is passed through) - // - BIndex is a incremental sequence (beginning at 0) of the same size as VIndex, so that - // we can store vectorized buffer_infos in an array (argument VIndex has its buffer at - // index BIndex in the array). - template object run( - typename vectorize_arg::type &...args, - index_sequence i_seq, index_sequence vi_seq, index_sequence bi_seq) { - - // Pointers to values the function was called with; the vectorized ones set here will start - // out as array_t pointers, but they will be changed them to T pointers before we make - // call the wrapped function. Non-vectorized pointers are left as-is. - std::array params{{ &args... }}; - - // The array of `buffer_info`s of vectorized arguments: - std::array buffers{{ reinterpret_cast(params[VIndex])->request()... }}; - - /* Determine dimensions parameters of output array */ - ssize_t nd = 0; - std::vector shape(0); - auto trivial = broadcast(buffers, nd, shape); - size_t ndim = (size_t) nd; - - size_t size = std::accumulate(shape.begin(), shape.end(), (size_t) 1, std::multiplies()); - - // If all arguments are 0-dimension arrays (i.e. single values) return a plain value (i.e. - // not wrapped in an array). - if (size == 1 && ndim == 0) { - PYBIND11_EXPAND_SIDE_EFFECTS(params[VIndex] = buffers[BIndex].ptr); - return cast(f(*reinterpret_cast *>(params[Index])...)); - } - - array_t result; - if (trivial == broadcast_trivial::f_trivial) result = array_t(shape); - else result = array_t(shape); - - if (size == 0) return std::move(result); - - /* Call the function */ - if (trivial == broadcast_trivial::non_trivial) - apply_broadcast(buffers, params, result, i_seq, vi_seq, bi_seq); - else - apply_trivial(buffers, params, result.mutable_data(), size, i_seq, vi_seq, bi_seq); - - return std::move(result); - } - - template - void apply_trivial(std::array &buffers, - std::array ¶ms, - Return *out, - size_t size, - index_sequence, index_sequence, index_sequence) { - - // Initialize an array of mutable byte references and sizes with references set to the - // appropriate pointer in `params`; as we iterate, we'll increment each pointer by its size - // (except for singletons, which get an increment of 0). - std::array, NVectorized> vecparams{{ - std::pair( - reinterpret_cast(params[VIndex] = buffers[BIndex].ptr), - buffers[BIndex].size == 1 ? 0 : sizeof(param_n_t) - )... - }}; - - for (size_t i = 0; i < size; ++i) { - out[i] = f(*reinterpret_cast *>(params[Index])...); - for (auto &x : vecparams) x.first += x.second; - } - } - - template - void apply_broadcast(std::array &buffers, - std::array ¶ms, - array_t &output_array, - index_sequence, index_sequence, index_sequence) { - - buffer_info output = output_array.request(); - multi_array_iterator input_iter(buffers, output.shape); - - for (array_iterator iter = array_begin(output), end = array_end(output); - iter != end; - ++iter, ++input_iter) { - PYBIND11_EXPAND_SIDE_EFFECTS(( - params[VIndex] = input_iter.template data() - )); - *iter = f(*reinterpret_cast *>(std::get(params))...); - } - } -}; - -template -vectorize_helper -vectorize_extractor(const Func &f, Return (*) (Args ...)) { - return detail::vectorize_helper(f); -} - -template struct handle_type_name> { - static constexpr auto name = _("numpy.ndarray[") + npy_format_descriptor::name + _("]"); -}; - -NAMESPACE_END(detail) - -// Vanilla pointer vectorizer: -template -detail::vectorize_helper -vectorize(Return (*f) (Args ...)) { - return detail::vectorize_helper(f); -} - -// lambda vectorizer: -template ::value, int> = 0> -auto vectorize(Func &&f) -> decltype( - detail::vectorize_extractor(std::forward(f), (detail::function_signature_t *) nullptr)) { - return detail::vectorize_extractor(std::forward(f), (detail::function_signature_t *) nullptr); -} - -// Vectorize a class method (non-const): -template ())), Return, Class *, Args...>> -Helper vectorize(Return (Class::*f)(Args...)) { - return Helper(std::mem_fn(f)); -} - -// Vectorize a class method (const): -template ())), Return, const Class *, Args...>> -Helper vectorize(Return (Class::*f)(Args...) const) { - return Helper(std::mem_fn(f)); -} - -NAMESPACE_END(PYBIND11_NAMESPACE) - -#if defined(_MSC_VER) -#pragma warning(pop) -#endif diff --git a/pybind11/include/pybind11/operators.h b/pybind11/include/pybind11/operators.h deleted file mode 100644 index b3dd62c..0000000 --- a/pybind11/include/pybind11/operators.h +++ /dev/null @@ -1,168 +0,0 @@ -/* - pybind11/operator.h: Metatemplates for operator overloading - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#pragma once - -#include "pybind11.h" - -#if defined(__clang__) && !defined(__INTEL_COMPILER) -# pragma clang diagnostic ignored "-Wunsequenced" // multiple unsequenced modifications to 'self' (when using def(py::self OP Type())) -#elif defined(_MSC_VER) -# pragma warning(push) -# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant -#endif - -NAMESPACE_BEGIN(PYBIND11_NAMESPACE) -NAMESPACE_BEGIN(detail) - -/// Enumeration with all supported operator types -enum op_id : int { - op_add, op_sub, op_mul, op_div, op_mod, op_divmod, op_pow, op_lshift, - op_rshift, op_and, op_xor, op_or, op_neg, op_pos, op_abs, op_invert, - op_int, op_long, op_float, op_str, op_cmp, op_gt, op_ge, op_lt, op_le, - op_eq, op_ne, op_iadd, op_isub, op_imul, op_idiv, op_imod, op_ilshift, - op_irshift, op_iand, op_ixor, op_ior, op_complex, op_bool, op_nonzero, - op_repr, op_truediv, op_itruediv, op_hash -}; - -enum op_type : int { - op_l, /* base type on left */ - op_r, /* base type on right */ - op_u /* unary operator */ -}; - -struct self_t { }; -static const self_t self = self_t(); - -/// Type for an unused type slot -struct undefined_t { }; - -/// Don't warn about an unused variable -inline self_t __self() { return self; } - -/// base template of operator implementations -template struct op_impl { }; - -/// Operator implementation generator -template struct op_ { - template void execute(Class &cl, const Extra&... extra) const { - using Base = typename Class::type; - using L_type = conditional_t::value, Base, L>; - using R_type = conditional_t::value, Base, R>; - using op = op_impl; - cl.def(op::name(), &op::execute, is_operator(), extra...); - #if PY_MAJOR_VERSION < 3 - if (id == op_truediv || id == op_itruediv) - cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__", - &op::execute, is_operator(), extra...); - #endif - } - template void execute_cast(Class &cl, const Extra&... extra) const { - using Base = typename Class::type; - using L_type = conditional_t::value, Base, L>; - using R_type = conditional_t::value, Base, R>; - using op = op_impl; - cl.def(op::name(), &op::execute_cast, is_operator(), extra...); - #if PY_MAJOR_VERSION < 3 - if (id == op_truediv || id == op_itruediv) - cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__", - &op::execute, is_operator(), extra...); - #endif - } -}; - -#define PYBIND11_BINARY_OPERATOR(id, rid, op, expr) \ -template struct op_impl { \ - static char const* name() { return "__" #id "__"; } \ - static auto execute(const L &l, const R &r) -> decltype(expr) { return (expr); } \ - static B execute_cast(const L &l, const R &r) { return B(expr); } \ -}; \ -template struct op_impl { \ - static char const* name() { return "__" #rid "__"; } \ - static auto execute(const R &r, const L &l) -> decltype(expr) { return (expr); } \ - static B execute_cast(const R &r, const L &l) { return B(expr); } \ -}; \ -inline op_ op(const self_t &, const self_t &) { \ - return op_(); \ -} \ -template op_ op(const self_t &, const T &) { \ - return op_(); \ -} \ -template op_ op(const T &, const self_t &) { \ - return op_(); \ -} - -#define PYBIND11_INPLACE_OPERATOR(id, op, expr) \ -template struct op_impl { \ - static char const* name() { return "__" #id "__"; } \ - static auto execute(L &l, const R &r) -> decltype(expr) { return expr; } \ - static B execute_cast(L &l, const R &r) { return B(expr); } \ -}; \ -template op_ op(const self_t &, const T &) { \ - return op_(); \ -} - -#define PYBIND11_UNARY_OPERATOR(id, op, expr) \ -template struct op_impl { \ - static char const* name() { return "__" #id "__"; } \ - static auto execute(const L &l) -> decltype(expr) { return expr; } \ - static B execute_cast(const L &l) { return B(expr); } \ -}; \ -inline op_ op(const self_t &) { \ - return op_(); \ -} - -PYBIND11_BINARY_OPERATOR(sub, rsub, operator-, l - r) -PYBIND11_BINARY_OPERATOR(add, radd, operator+, l + r) -PYBIND11_BINARY_OPERATOR(mul, rmul, operator*, l * r) -PYBIND11_BINARY_OPERATOR(truediv, rtruediv, operator/, l / r) -PYBIND11_BINARY_OPERATOR(mod, rmod, operator%, l % r) -PYBIND11_BINARY_OPERATOR(lshift, rlshift, operator<<, l << r) -PYBIND11_BINARY_OPERATOR(rshift, rrshift, operator>>, l >> r) -PYBIND11_BINARY_OPERATOR(and, rand, operator&, l & r) -PYBIND11_BINARY_OPERATOR(xor, rxor, operator^, l ^ r) -PYBIND11_BINARY_OPERATOR(eq, eq, operator==, l == r) -PYBIND11_BINARY_OPERATOR(ne, ne, operator!=, l != r) -PYBIND11_BINARY_OPERATOR(or, ror, operator|, l | r) -PYBIND11_BINARY_OPERATOR(gt, lt, operator>, l > r) -PYBIND11_BINARY_OPERATOR(ge, le, operator>=, l >= r) -PYBIND11_BINARY_OPERATOR(lt, gt, operator<, l < r) -PYBIND11_BINARY_OPERATOR(le, ge, operator<=, l <= r) -//PYBIND11_BINARY_OPERATOR(pow, rpow, pow, std::pow(l, r)) -PYBIND11_INPLACE_OPERATOR(iadd, operator+=, l += r) -PYBIND11_INPLACE_OPERATOR(isub, operator-=, l -= r) -PYBIND11_INPLACE_OPERATOR(imul, operator*=, l *= r) -PYBIND11_INPLACE_OPERATOR(itruediv, operator/=, l /= r) -PYBIND11_INPLACE_OPERATOR(imod, operator%=, l %= r) -PYBIND11_INPLACE_OPERATOR(ilshift, operator<<=, l <<= r) -PYBIND11_INPLACE_OPERATOR(irshift, operator>>=, l >>= r) -PYBIND11_INPLACE_OPERATOR(iand, operator&=, l &= r) -PYBIND11_INPLACE_OPERATOR(ixor, operator^=, l ^= r) -PYBIND11_INPLACE_OPERATOR(ior, operator|=, l |= r) -PYBIND11_UNARY_OPERATOR(neg, operator-, -l) -PYBIND11_UNARY_OPERATOR(pos, operator+, +l) -PYBIND11_UNARY_OPERATOR(abs, abs, std::abs(l)) -PYBIND11_UNARY_OPERATOR(hash, hash, std::hash()(l)) -PYBIND11_UNARY_OPERATOR(invert, operator~, (~l)) -PYBIND11_UNARY_OPERATOR(bool, operator!, !!l) -PYBIND11_UNARY_OPERATOR(int, int_, (int) l) -PYBIND11_UNARY_OPERATOR(float, float_, (double) l) - -#undef PYBIND11_BINARY_OPERATOR -#undef PYBIND11_INPLACE_OPERATOR -#undef PYBIND11_UNARY_OPERATOR -NAMESPACE_END(detail) - -using detail::self; - -NAMESPACE_END(PYBIND11_NAMESPACE) - -#if defined(_MSC_VER) -# pragma warning(pop) -#endif diff --git a/pybind11/include/pybind11/options.h b/pybind11/include/pybind11/options.h deleted file mode 100644 index cc1e1f6..0000000 --- a/pybind11/include/pybind11/options.h +++ /dev/null @@ -1,65 +0,0 @@ -/* - pybind11/options.h: global settings that are configurable at runtime. - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#pragma once - -#include "detail/common.h" - -NAMESPACE_BEGIN(PYBIND11_NAMESPACE) - -class options { -public: - - // Default RAII constructor, which leaves settings as they currently are. - options() : previous_state(global_state()) {} - - // Class is non-copyable. - options(const options&) = delete; - options& operator=(const options&) = delete; - - // Destructor, which restores settings that were in effect before. - ~options() { - global_state() = previous_state; - } - - // Setter methods (affect the global state): - - options& disable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = false; return *this; } - - options& enable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = true; return *this; } - - options& disable_function_signatures() & { global_state().show_function_signatures = false; return *this; } - - options& enable_function_signatures() & { global_state().show_function_signatures = true; return *this; } - - // Getter methods (return the global state): - - static bool show_user_defined_docstrings() { return global_state().show_user_defined_docstrings; } - - static bool show_function_signatures() { return global_state().show_function_signatures; } - - // This type is not meant to be allocated on the heap. - void* operator new(size_t) = delete; - -private: - - struct state { - bool show_user_defined_docstrings = true; //< Include user-supplied texts in docstrings. - bool show_function_signatures = true; //< Include auto-generated function signatures in docstrings. - }; - - static state &global_state() { - static state instance; - return instance; - } - - state previous_state; -}; - -NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/pybind11/include/pybind11/pybind11.h b/pybind11/include/pybind11/pybind11.h deleted file mode 100644 index c623705..0000000 --- a/pybind11/include/pybind11/pybind11.h +++ /dev/null @@ -1,2176 +0,0 @@ -/* - pybind11/pybind11.h: Main header file of the C++11 python - binding generator library - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#pragma once - -#if defined(__INTEL_COMPILER) -# pragma warning push -# pragma warning disable 68 // integer conversion resulted in a change of sign -# pragma warning disable 186 // pointless comparison of unsigned integer with zero -# pragma warning disable 878 // incompatible exception specifications -# pragma warning disable 1334 // the "template" keyword used for syntactic disambiguation may only be used within a template -# pragma warning disable 1682 // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem) -# pragma warning disable 1786 // function "strdup" was declared deprecated -# pragma warning disable 1875 // offsetof applied to non-POD (Plain Old Data) types is nonstandard -# pragma warning disable 2196 // warning #2196: routine is both "inline" and "noinline" -#elif defined(_MSC_VER) -# pragma warning(push) -# pragma warning(disable: 4100) // warning C4100: Unreferenced formal parameter -# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant -# pragma warning(disable: 4512) // warning C4512: Assignment operator was implicitly defined as deleted -# pragma warning(disable: 4800) // warning C4800: 'int': forcing value to bool 'true' or 'false' (performance warning) -# pragma warning(disable: 4996) // warning C4996: The POSIX name for this item is deprecated. Instead, use the ISO C and C++ conformant name -# pragma warning(disable: 4702) // warning C4702: unreachable code -# pragma warning(disable: 4522) // warning C4522: multiple assignment operators specified -#elif defined(__GNUG__) && !defined(__clang__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wunused-but-set-parameter" -# pragma GCC diagnostic ignored "-Wunused-but-set-variable" -# pragma GCC diagnostic ignored "-Wmissing-field-initializers" -# pragma GCC diagnostic ignored "-Wstrict-aliasing" -# pragma GCC diagnostic ignored "-Wattributes" -# if __GNUC__ >= 7 -# pragma GCC diagnostic ignored "-Wnoexcept-type" -# endif -#endif - -#include "attr.h" -#include "options.h" -#include "detail/class.h" -#include "detail/init.h" - -#if defined(__GNUG__) && !defined(__clang__) -# include -#endif - -NAMESPACE_BEGIN(PYBIND11_NAMESPACE) - -/// Wraps an arbitrary C++ function/method/lambda function/.. into a callable Python object -class cpp_function : public function { -public: - cpp_function() { } - cpp_function(std::nullptr_t) { } - - /// Construct a cpp_function from a vanilla function pointer - template - cpp_function(Return (*f)(Args...), const Extra&... extra) { - initialize(f, f, extra...); - } - - /// Construct a cpp_function from a lambda function (possibly with internal state) - template ::value>> - cpp_function(Func &&f, const Extra&... extra) { - initialize(std::forward(f), - (detail::function_signature_t *) nullptr, extra...); - } - - /// Construct a cpp_function from a class method (non-const) - template - cpp_function(Return (Class::*f)(Arg...), const Extra&... extra) { - initialize([f](Class *c, Arg... args) -> Return { return (c->*f)(args...); }, - (Return (*) (Class *, Arg...)) nullptr, extra...); - } - - /// Construct a cpp_function from a class method (const) - template - cpp_function(Return (Class::*f)(Arg...) const, const Extra&... extra) { - initialize([f](const Class *c, Arg... args) -> Return { return (c->*f)(args...); }, - (Return (*)(const Class *, Arg ...)) nullptr, extra...); - } - - /// Return the function name - object name() const { return attr("__name__"); } - -protected: - /// Space optimization: don't inline this frequently instantiated fragment - PYBIND11_NOINLINE detail::function_record *make_function_record() { - return new detail::function_record(); - } - - /// Special internal constructor for functors, lambda functions, etc. - template - void initialize(Func &&f, Return (*)(Args...), const Extra&... extra) { - using namespace detail; - struct capture { remove_reference_t f; }; - - /* Store the function including any extra state it might have (e.g. a lambda capture object) */ - auto rec = make_function_record(); - - /* Store the capture object directly in the function record if there is enough space */ - if (sizeof(capture) <= sizeof(rec->data)) { - /* Without these pragmas, GCC warns that there might not be - enough space to use the placement new operator. However, the - 'if' statement above ensures that this is the case. */ -#if defined(__GNUG__) && !defined(__clang__) && __GNUC__ >= 6 -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wplacement-new" -#endif - new ((capture *) &rec->data) capture { std::forward(f) }; -#if defined(__GNUG__) && !defined(__clang__) && __GNUC__ >= 6 -# pragma GCC diagnostic pop -#endif - if (!std::is_trivially_destructible::value) - rec->free_data = [](function_record *r) { ((capture *) &r->data)->~capture(); }; - } else { - rec->data[0] = new capture { std::forward(f) }; - rec->free_data = [](function_record *r) { delete ((capture *) r->data[0]); }; - } - - /* Type casters for the function arguments and return value */ - using cast_in = argument_loader; - using cast_out = make_caster< - conditional_t::value, void_type, Return> - >; - - static_assert(expected_num_args(sizeof...(Args), cast_in::has_args, cast_in::has_kwargs), - "The number of argument annotations does not match the number of function arguments"); - - /* Dispatch code which converts function arguments and performs the actual function call */ - rec->impl = [](function_call &call) -> handle { - cast_in args_converter; - - /* Try to cast the function arguments into the C++ domain */ - if (!args_converter.load_args(call)) - return PYBIND11_TRY_NEXT_OVERLOAD; - - /* Invoke call policy pre-call hook */ - process_attributes::precall(call); - - /* Get a pointer to the capture object */ - auto data = (sizeof(capture) <= sizeof(call.func.data) - ? &call.func.data : call.func.data[0]); - capture *cap = const_cast(reinterpret_cast(data)); - - /* Override policy for rvalues -- usually to enforce rvp::move on an rvalue */ - return_value_policy policy = return_value_policy_override::policy(call.func.policy); - - /* Function scope guard -- defaults to the compile-to-nothing `void_type` */ - using Guard = extract_guard_t; - - /* Perform the function call */ - handle result = cast_out::cast( - std::move(args_converter).template call(cap->f), policy, call.parent); - - /* Invoke call policy post-call hook */ - process_attributes::postcall(call, result); - - return result; - }; - - /* Process any user-provided function attributes */ - process_attributes::init(extra..., rec); - - /* Generate a readable signature describing the function's arguments and return value types */ - static constexpr auto signature = _("(") + cast_in::arg_names + _(") -> ") + cast_out::name; - PYBIND11_DESCR_CONSTEXPR auto types = decltype(signature)::types(); - - /* Register the function with Python from generic (non-templated) code */ - initialize_generic(rec, signature.text, types.data(), sizeof...(Args)); - - if (cast_in::has_args) rec->has_args = true; - if (cast_in::has_kwargs) rec->has_kwargs = true; - - /* Stash some additional information used by an important optimization in 'functional.h' */ - using FunctionType = Return (*)(Args...); - constexpr bool is_function_ptr = - std::is_convertible::value && - sizeof(capture) == sizeof(void *); - if (is_function_ptr) { - rec->is_stateless = true; - rec->data[1] = const_cast(reinterpret_cast(&typeid(FunctionType))); - } - } - - /// Register a function call with Python (generic non-templated code goes here) - void initialize_generic(detail::function_record *rec, const char *text, - const std::type_info *const *types, size_t args) { - - /* Create copies of all referenced C-style strings */ - rec->name = strdup(rec->name ? rec->name : ""); - if (rec->doc) rec->doc = strdup(rec->doc); - for (auto &a: rec->args) { - if (a.name) - a.name = strdup(a.name); - if (a.descr) - a.descr = strdup(a.descr); - else if (a.value) - a.descr = strdup(a.value.attr("__repr__")().cast().c_str()); - } - - rec->is_constructor = !strcmp(rec->name, "__init__") || !strcmp(rec->name, "__setstate__"); - -#if !defined(NDEBUG) && !defined(PYBIND11_DISABLE_NEW_STYLE_INIT_WARNING) - if (rec->is_constructor && !rec->is_new_style_constructor) { - const auto class_name = std::string(((PyTypeObject *) rec->scope.ptr())->tp_name); - const auto func_name = std::string(rec->name); - PyErr_WarnEx( - PyExc_FutureWarning, - ("pybind11-bound class '" + class_name + "' is using an old-style " - "placement-new '" + func_name + "' which has been deprecated. See " - "the upgrade guide in pybind11's docs. This message is only visible " - "when compiled in debug mode.").c_str(), 0 - ); - } -#endif - - /* Generate a proper function signature */ - std::string signature; - size_t type_index = 0, arg_index = 0; - for (auto *pc = text; *pc != '\0'; ++pc) { - const auto c = *pc; - - if (c == '{') { - // Write arg name for everything except *args and **kwargs. - if (*(pc + 1) == '*') - continue; - - if (arg_index < rec->args.size() && rec->args[arg_index].name) { - signature += rec->args[arg_index].name; - } else if (arg_index == 0 && rec->is_method) { - signature += "self"; - } else { - signature += "arg" + std::to_string(arg_index - (rec->is_method ? 1 : 0)); - } - signature += ": "; - } else if (c == '}') { - // Write default value if available. - if (arg_index < rec->args.size() && rec->args[arg_index].descr) { - signature += " = "; - signature += rec->args[arg_index].descr; - } - arg_index++; - } else if (c == '%') { - const std::type_info *t = types[type_index++]; - if (!t) - pybind11_fail("Internal error while parsing type signature (1)"); - if (auto tinfo = detail::get_type_info(*t)) { - handle th((PyObject *) tinfo->type); - signature += - th.attr("__module__").cast() + "." + - th.attr("__qualname__").cast(); // Python 3.3+, but we backport it to earlier versions - } else if (rec->is_new_style_constructor && arg_index == 0) { - // A new-style `__init__` takes `self` as `value_and_holder`. - // Rewrite it to the proper class type. - signature += - rec->scope.attr("__module__").cast() + "." + - rec->scope.attr("__qualname__").cast(); - } else { - std::string tname(t->name()); - detail::clean_type_id(tname); - signature += tname; - } - } else { - signature += c; - } - } - if (arg_index != args || types[type_index] != nullptr) - pybind11_fail("Internal error while parsing type signature (2)"); - -#if PY_MAJOR_VERSION < 3 - if (strcmp(rec->name, "__next__") == 0) { - std::free(rec->name); - rec->name = strdup("next"); - } else if (strcmp(rec->name, "__bool__") == 0) { - std::free(rec->name); - rec->name = strdup("__nonzero__"); - } -#endif - rec->signature = strdup(signature.c_str()); - rec->args.shrink_to_fit(); - rec->nargs = (std::uint16_t) args; - - if (rec->sibling && PYBIND11_INSTANCE_METHOD_CHECK(rec->sibling.ptr())) - rec->sibling = PYBIND11_INSTANCE_METHOD_GET_FUNCTION(rec->sibling.ptr()); - - detail::function_record *chain = nullptr, *chain_start = rec; - if (rec->sibling) { - if (PyCFunction_Check(rec->sibling.ptr())) { - auto rec_capsule = reinterpret_borrow(PyCFunction_GET_SELF(rec->sibling.ptr())); - chain = (detail::function_record *) rec_capsule; - /* Never append a method to an overload chain of a parent class; - instead, hide the parent's overloads in this case */ - if (!chain->scope.is(rec->scope)) - chain = nullptr; - } - // Don't trigger for things like the default __init__, which are wrapper_descriptors that we are intentionally replacing - else if (!rec->sibling.is_none() && rec->name[0] != '_') - pybind11_fail("Cannot overload existing non-function object \"" + std::string(rec->name) + - "\" with a function of the same name"); - } - - if (!chain) { - /* No existing overload was found, create a new function object */ - rec->def = new PyMethodDef(); - std::memset(rec->def, 0, sizeof(PyMethodDef)); - rec->def->ml_name = rec->name; - rec->def->ml_meth = reinterpret_cast(reinterpret_cast(*dispatcher)); - rec->def->ml_flags = METH_VARARGS | METH_KEYWORDS; - - capsule rec_capsule(rec, [](void *ptr) { - destruct((detail::function_record *) ptr); - }); - - object scope_module; - if (rec->scope) { - if (hasattr(rec->scope, "__module__")) { - scope_module = rec->scope.attr("__module__"); - } else if (hasattr(rec->scope, "__name__")) { - scope_module = rec->scope.attr("__name__"); - } - } - - m_ptr = PyCFunction_NewEx(rec->def, rec_capsule.ptr(), scope_module.ptr()); - if (!m_ptr) - pybind11_fail("cpp_function::cpp_function(): Could not allocate function object"); - } else { - /* Append at the end of the overload chain */ - m_ptr = rec->sibling.ptr(); - inc_ref(); - chain_start = chain; - if (chain->is_method != rec->is_method) - pybind11_fail("overloading a method with both static and instance methods is not supported; " - #if defined(NDEBUG) - "compile in debug mode for more details" - #else - "error while attempting to bind " + std::string(rec->is_method ? "instance" : "static") + " method " + - std::string(pybind11::str(rec->scope.attr("__name__"))) + "." + std::string(rec->name) + signature - #endif - ); - while (chain->next) - chain = chain->next; - chain->next = rec; - } - - std::string signatures; - int index = 0; - /* Create a nice pydoc rec including all signatures and - docstrings of the functions in the overload chain */ - if (chain && options::show_function_signatures()) { - // First a generic signature - signatures += rec->name; - signatures += "(*args, **kwargs)\n"; - signatures += "Overloaded function.\n\n"; - } - // Then specific overload signatures - bool first_user_def = true; - for (auto it = chain_start; it != nullptr; it = it->next) { - if (options::show_function_signatures()) { - if (index > 0) signatures += "\n"; - if (chain) - signatures += std::to_string(++index) + ". "; - signatures += rec->name; - signatures += it->signature; - signatures += "\n"; - } - if (it->doc && strlen(it->doc) > 0 && options::show_user_defined_docstrings()) { - // If we're appending another docstring, and aren't printing function signatures, we - // need to append a newline first: - if (!options::show_function_signatures()) { - if (first_user_def) first_user_def = false; - else signatures += "\n"; - } - if (options::show_function_signatures()) signatures += "\n"; - signatures += it->doc; - if (options::show_function_signatures()) signatures += "\n"; - } - } - - /* Install docstring */ - PyCFunctionObject *func = (PyCFunctionObject *) m_ptr; - if (func->m_ml->ml_doc) - std::free(const_cast(func->m_ml->ml_doc)); - func->m_ml->ml_doc = strdup(signatures.c_str()); - - if (rec->is_method) { - m_ptr = PYBIND11_INSTANCE_METHOD_NEW(m_ptr, rec->scope.ptr()); - if (!m_ptr) - pybind11_fail("cpp_function::cpp_function(): Could not allocate instance method object"); - Py_DECREF(func); - } - } - - /// When a cpp_function is GCed, release any memory allocated by pybind11 - static void destruct(detail::function_record *rec) { - while (rec) { - detail::function_record *next = rec->next; - if (rec->free_data) - rec->free_data(rec); - std::free((char *) rec->name); - std::free((char *) rec->doc); - std::free((char *) rec->signature); - for (auto &arg: rec->args) { - std::free(const_cast(arg.name)); - std::free(const_cast(arg.descr)); - arg.value.dec_ref(); - } - if (rec->def) { - std::free(const_cast(rec->def->ml_doc)); - delete rec->def; - } - delete rec; - rec = next; - } - } - - /// Main dispatch logic for calls to functions bound using pybind11 - static PyObject *dispatcher(PyObject *self, PyObject *args_in, PyObject *kwargs_in) { - using namespace detail; - - /* Iterator over the list of potentially admissible overloads */ - const function_record *overloads = (function_record *) PyCapsule_GetPointer(self, nullptr), - *it = overloads; - - /* Need to know how many arguments + keyword arguments there are to pick the right overload */ - const size_t n_args_in = (size_t) PyTuple_GET_SIZE(args_in); - - handle parent = n_args_in > 0 ? PyTuple_GET_ITEM(args_in, 0) : nullptr, - result = PYBIND11_TRY_NEXT_OVERLOAD; - - auto self_value_and_holder = value_and_holder(); - if (overloads->is_constructor) { - const auto tinfo = get_type_info((PyTypeObject *) overloads->scope.ptr()); - const auto pi = reinterpret_cast(parent.ptr()); - self_value_and_holder = pi->get_value_and_holder(tinfo, false); - - if (!self_value_and_holder.type || !self_value_and_holder.inst) { - PyErr_SetString(PyExc_TypeError, "__init__(self, ...) called with invalid `self` argument"); - return nullptr; - } - - // If this value is already registered it must mean __init__ is invoked multiple times; - // we really can't support that in C++, so just ignore the second __init__. - if (self_value_and_holder.instance_registered()) - return none().release().ptr(); - } - - try { - // We do this in two passes: in the first pass, we load arguments with `convert=false`; - // in the second, we allow conversion (except for arguments with an explicit - // py::arg().noconvert()). This lets us prefer calls without conversion, with - // conversion as a fallback. - std::vector second_pass; - - // However, if there are no overloads, we can just skip the no-convert pass entirely - const bool overloaded = it != nullptr && it->next != nullptr; - - for (; it != nullptr; it = it->next) { - - /* For each overload: - 1. Copy all positional arguments we were given, also checking to make sure that - named positional arguments weren't *also* specified via kwarg. - 2. If we weren't given enough, try to make up the omitted ones by checking - whether they were provided by a kwarg matching the `py::arg("name")` name. If - so, use it (and remove it from kwargs; if not, see if the function binding - provided a default that we can use. - 3. Ensure that either all keyword arguments were "consumed", or that the function - takes a kwargs argument to accept unconsumed kwargs. - 4. Any positional arguments still left get put into a tuple (for args), and any - leftover kwargs get put into a dict. - 5. Pack everything into a vector; if we have py::args or py::kwargs, they are an - extra tuple or dict at the end of the positional arguments. - 6. Call the function call dispatcher (function_record::impl) - - If one of these fail, move on to the next overload and keep trying until we get a - result other than PYBIND11_TRY_NEXT_OVERLOAD. - */ - - const function_record &func = *it; - size_t pos_args = func.nargs; // Number of positional arguments that we need - if (func.has_args) --pos_args; // (but don't count py::args - if (func.has_kwargs) --pos_args; // or py::kwargs) - - if (!func.has_args && n_args_in > pos_args) - continue; // Too many arguments for this overload - - if (n_args_in < pos_args && func.args.size() < pos_args) - continue; // Not enough arguments given, and not enough defaults to fill in the blanks - - function_call call(func, parent); - - size_t args_to_copy = (std::min)(pos_args, n_args_in); // Protect std::min with parentheses - size_t args_copied = 0; - - // 0. Inject new-style `self` argument - if (func.is_new_style_constructor) { - // The `value` may have been preallocated by an old-style `__init__` - // if it was a preceding candidate for overload resolution. - if (self_value_and_holder) - self_value_and_holder.type->dealloc(self_value_and_holder); - - call.init_self = PyTuple_GET_ITEM(args_in, 0); - call.args.push_back(reinterpret_cast(&self_value_and_holder)); - call.args_convert.push_back(false); - ++args_copied; - } - - // 1. Copy any position arguments given. - bool bad_arg = false; - for (; args_copied < args_to_copy; ++args_copied) { - const argument_record *arg_rec = args_copied < func.args.size() ? &func.args[args_copied] : nullptr; - if (kwargs_in && arg_rec && arg_rec->name && PyDict_GetItemString(kwargs_in, arg_rec->name)) { - bad_arg = true; - break; - } - - handle arg(PyTuple_GET_ITEM(args_in, args_copied)); - if (arg_rec && !arg_rec->none && arg.is_none()) { - bad_arg = true; - break; - } - call.args.push_back(arg); - call.args_convert.push_back(arg_rec ? arg_rec->convert : true); - } - if (bad_arg) - continue; // Maybe it was meant for another overload (issue #688) - - // We'll need to copy this if we steal some kwargs for defaults - dict kwargs = reinterpret_borrow(kwargs_in); - - // 2. Check kwargs and, failing that, defaults that may help complete the list - if (args_copied < pos_args) { - bool copied_kwargs = false; - - for (; args_copied < pos_args; ++args_copied) { - const auto &arg = func.args[args_copied]; - - handle value; - if (kwargs_in && arg.name) - value = PyDict_GetItemString(kwargs.ptr(), arg.name); - - if (value) { - // Consume a kwargs value - if (!copied_kwargs) { - kwargs = reinterpret_steal(PyDict_Copy(kwargs.ptr())); - copied_kwargs = true; - } - PyDict_DelItemString(kwargs.ptr(), arg.name); - } else if (arg.value) { - value = arg.value; - } - - if (value) { - call.args.push_back(value); - call.args_convert.push_back(arg.convert); - } - else - break; - } - - if (args_copied < pos_args) - continue; // Not enough arguments, defaults, or kwargs to fill the positional arguments - } - - // 3. Check everything was consumed (unless we have a kwargs arg) - if (kwargs && kwargs.size() > 0 && !func.has_kwargs) - continue; // Unconsumed kwargs, but no py::kwargs argument to accept them - - // 4a. If we have a py::args argument, create a new tuple with leftovers - if (func.has_args) { - tuple extra_args; - if (args_to_copy == 0) { - // We didn't copy out any position arguments from the args_in tuple, so we - // can reuse it directly without copying: - extra_args = reinterpret_borrow(args_in); - } else if (args_copied >= n_args_in) { - extra_args = tuple(0); - } else { - size_t args_size = n_args_in - args_copied; - extra_args = tuple(args_size); - for (size_t i = 0; i < args_size; ++i) { - extra_args[i] = PyTuple_GET_ITEM(args_in, args_copied + i); - } - } - call.args.push_back(extra_args); - call.args_convert.push_back(false); - call.args_ref = std::move(extra_args); - } - - // 4b. If we have a py::kwargs, pass on any remaining kwargs - if (func.has_kwargs) { - if (!kwargs.ptr()) - kwargs = dict(); // If we didn't get one, send an empty one - call.args.push_back(kwargs); - call.args_convert.push_back(false); - call.kwargs_ref = std::move(kwargs); - } - - // 5. Put everything in a vector. Not technically step 5, we've been building it - // in `call.args` all along. - #if !defined(NDEBUG) - if (call.args.size() != func.nargs || call.args_convert.size() != func.nargs) - pybind11_fail("Internal error: function call dispatcher inserted wrong number of arguments!"); - #endif - - std::vector second_pass_convert; - if (overloaded) { - // We're in the first no-convert pass, so swap out the conversion flags for a - // set of all-false flags. If the call fails, we'll swap the flags back in for - // the conversion-allowed call below. - second_pass_convert.resize(func.nargs, false); - call.args_convert.swap(second_pass_convert); - } - - // 6. Call the function. - try { - loader_life_support guard{}; - result = func.impl(call); - } catch (reference_cast_error &) { - result = PYBIND11_TRY_NEXT_OVERLOAD; - } - - if (result.ptr() != PYBIND11_TRY_NEXT_OVERLOAD) - break; - - if (overloaded) { - // The (overloaded) call failed; if the call has at least one argument that - // permits conversion (i.e. it hasn't been explicitly specified `.noconvert()`) - // then add this call to the list of second pass overloads to try. - for (size_t i = func.is_method ? 1 : 0; i < pos_args; i++) { - if (second_pass_convert[i]) { - // Found one: swap the converting flags back in and store the call for - // the second pass. - call.args_convert.swap(second_pass_convert); - second_pass.push_back(std::move(call)); - break; - } - } - } - } - - if (overloaded && !second_pass.empty() && result.ptr() == PYBIND11_TRY_NEXT_OVERLOAD) { - // The no-conversion pass finished without success, try again with conversion allowed - for (auto &call : second_pass) { - try { - loader_life_support guard{}; - result = call.func.impl(call); - } catch (reference_cast_error &) { - result = PYBIND11_TRY_NEXT_OVERLOAD; - } - - if (result.ptr() != PYBIND11_TRY_NEXT_OVERLOAD) { - // The error reporting logic below expects 'it' to be valid, as it would be - // if we'd encountered this failure in the first-pass loop. - if (!result) - it = &call.func; - break; - } - } - } - } catch (error_already_set &e) { - e.restore(); - return nullptr; -#if defined(__GNUG__) && !defined(__clang__) - } catch ( abi::__forced_unwind& ) { - throw; -#endif - } catch (...) { - /* When an exception is caught, give each registered exception - translator a chance to translate it to a Python exception - in reverse order of registration. - - A translator may choose to do one of the following: - - - catch the exception and call PyErr_SetString or PyErr_SetObject - to set a standard (or custom) Python exception, or - - do nothing and let the exception fall through to the next translator, or - - delegate translation to the next translator by throwing a new type of exception. */ - - auto last_exception = std::current_exception(); - auto ®istered_exception_translators = get_internals().registered_exception_translators; - for (auto& translator : registered_exception_translators) { - try { - translator(last_exception); - } catch (...) { - last_exception = std::current_exception(); - continue; - } - return nullptr; - } - PyErr_SetString(PyExc_SystemError, "Exception escaped from default exception translator!"); - return nullptr; - } - - auto append_note_if_missing_header_is_suspected = [](std::string &msg) { - if (msg.find("std::") != std::string::npos) { - msg += "\n\n" - "Did you forget to `#include `? Or ,\n" - ", , etc. Some automatic\n" - "conversions are optional and require extra headers to be included\n" - "when compiling your pybind11 module."; - } - }; - - if (result.ptr() == PYBIND11_TRY_NEXT_OVERLOAD) { - if (overloads->is_operator) - return handle(Py_NotImplemented).inc_ref().ptr(); - - std::string msg = std::string(overloads->name) + "(): incompatible " + - std::string(overloads->is_constructor ? "constructor" : "function") + - " arguments. The following argument types are supported:\n"; - - int ctr = 0; - for (const function_record *it2 = overloads; it2 != nullptr; it2 = it2->next) { - msg += " "+ std::to_string(++ctr) + ". "; - - bool wrote_sig = false; - if (overloads->is_constructor) { - // For a constructor, rewrite `(self: Object, arg0, ...) -> NoneType` as `Object(arg0, ...)` - std::string sig = it2->signature; - size_t start = sig.find('(') + 7; // skip "(self: " - if (start < sig.size()) { - // End at the , for the next argument - size_t end = sig.find(", "), next = end + 2; - size_t ret = sig.rfind(" -> "); - // Or the ), if there is no comma: - if (end >= sig.size()) next = end = sig.find(')'); - if (start < end && next < sig.size()) { - msg.append(sig, start, end - start); - msg += '('; - msg.append(sig, next, ret - next); - wrote_sig = true; - } - } - } - if (!wrote_sig) msg += it2->signature; - - msg += "\n"; - } - msg += "\nInvoked with: "; - auto args_ = reinterpret_borrow(args_in); - bool some_args = false; - for (size_t ti = overloads->is_constructor ? 1 : 0; ti < args_.size(); ++ti) { - if (!some_args) some_args = true; - else msg += ", "; - msg += pybind11::repr(args_[ti]); - } - if (kwargs_in) { - auto kwargs = reinterpret_borrow(kwargs_in); - if (kwargs.size() > 0) { - if (some_args) msg += "; "; - msg += "kwargs: "; - bool first = true; - for (auto kwarg : kwargs) { - if (first) first = false; - else msg += ", "; - msg += pybind11::str("{}={!r}").format(kwarg.first, kwarg.second); - } - } - } - - append_note_if_missing_header_is_suspected(msg); - PyErr_SetString(PyExc_TypeError, msg.c_str()); - return nullptr; - } else if (!result) { - std::string msg = "Unable to convert function return value to a " - "Python type! The signature was\n\t"; - msg += it->signature; - append_note_if_missing_header_is_suspected(msg); - PyErr_SetString(PyExc_TypeError, msg.c_str()); - return nullptr; - } else { - if (overloads->is_constructor && !self_value_and_holder.holder_constructed()) { - auto *pi = reinterpret_cast(parent.ptr()); - self_value_and_holder.type->init_instance(pi, nullptr); - } - return result.ptr(); - } - } -}; - -/// Wrapper for Python extension modules -class module : public object { -public: - PYBIND11_OBJECT_DEFAULT(module, object, PyModule_Check) - - /// Create a new top-level Python module with the given name and docstring - explicit module(const char *name, const char *doc = nullptr) { - if (!options::show_user_defined_docstrings()) doc = nullptr; -#if PY_MAJOR_VERSION >= 3 - PyModuleDef *def = new PyModuleDef(); - std::memset(def, 0, sizeof(PyModuleDef)); - def->m_name = name; - def->m_doc = doc; - def->m_size = -1; - Py_INCREF(def); - m_ptr = PyModule_Create(def); -#else - m_ptr = Py_InitModule3(name, nullptr, doc); -#endif - if (m_ptr == nullptr) - pybind11_fail("Internal error in module::module()"); - inc_ref(); - } - - /** \rst - Create Python binding for a new function within the module scope. ``Func`` - can be a plain C++ function, a function pointer, or a lambda function. For - details on the ``Extra&& ... extra`` argument, see section :ref:`extras`. - \endrst */ - template - module &def(const char *name_, Func &&f, const Extra& ... extra) { - cpp_function func(std::forward(f), name(name_), scope(*this), - sibling(getattr(*this, name_, none())), extra...); - // NB: allow overwriting here because cpp_function sets up a chain with the intention of - // overwriting (and has already checked internally that it isn't overwriting non-functions). - add_object(name_, func, true /* overwrite */); - return *this; - } - - /** \rst - Create and return a new Python submodule with the given name and docstring. - This also works recursively, i.e. - - .. code-block:: cpp - - py::module m("example", "pybind11 example plugin"); - py::module m2 = m.def_submodule("sub", "A submodule of 'example'"); - py::module m3 = m2.def_submodule("subsub", "A submodule of 'example.sub'"); - \endrst */ - module def_submodule(const char *name, const char *doc = nullptr) { - std::string full_name = std::string(PyModule_GetName(m_ptr)) - + std::string(".") + std::string(name); - auto result = reinterpret_borrow(PyImport_AddModule(full_name.c_str())); - if (doc && options::show_user_defined_docstrings()) - result.attr("__doc__") = pybind11::str(doc); - attr(name) = result; - return result; - } - - /// Import and return a module or throws `error_already_set`. - static module import(const char *name) { - PyObject *obj = PyImport_ImportModule(name); - if (!obj) - throw error_already_set(); - return reinterpret_steal(obj); - } - - /// Reload the module or throws `error_already_set`. - void reload() { - PyObject *obj = PyImport_ReloadModule(ptr()); - if (!obj) - throw error_already_set(); - *this = reinterpret_steal(obj); - } - - // Adds an object to the module using the given name. Throws if an object with the given name - // already exists. - // - // overwrite should almost always be false: attempting to overwrite objects that pybind11 has - // established will, in most cases, break things. - PYBIND11_NOINLINE void add_object(const char *name, handle obj, bool overwrite = false) { - if (!overwrite && hasattr(*this, name)) - pybind11_fail("Error during initialization: multiple incompatible definitions with name \"" + - std::string(name) + "\""); - - PyModule_AddObject(ptr(), name, obj.inc_ref().ptr() /* steals a reference */); - } -}; - -/// \ingroup python_builtins -/// Return a dictionary representing the global variables in the current execution frame, -/// or ``__main__.__dict__`` if there is no frame (usually when the interpreter is embedded). -inline dict globals() { - PyObject *p = PyEval_GetGlobals(); - return reinterpret_borrow(p ? p : module::import("__main__").attr("__dict__").ptr()); -} - -NAMESPACE_BEGIN(detail) -/// Generic support for creating new Python heap types -class generic_type : public object { - template friend class class_; -public: - PYBIND11_OBJECT_DEFAULT(generic_type, object, PyType_Check) -protected: - void initialize(const type_record &rec) { - if (rec.scope && hasattr(rec.scope, rec.name)) - pybind11_fail("generic_type: cannot initialize type \"" + std::string(rec.name) + - "\": an object with that name is already defined"); - - if (rec.module_local ? get_local_type_info(*rec.type) : get_global_type_info(*rec.type)) - pybind11_fail("generic_type: type \"" + std::string(rec.name) + - "\" is already registered!"); - - m_ptr = make_new_python_type(rec); - - /* Register supplemental type information in C++ dict */ - auto *tinfo = new detail::type_info(); - tinfo->type = (PyTypeObject *) m_ptr; - tinfo->cpptype = rec.type; - tinfo->type_size = rec.type_size; - tinfo->type_align = rec.type_align; - tinfo->operator_new = rec.operator_new; - tinfo->holder_size_in_ptrs = size_in_ptrs(rec.holder_size); - tinfo->init_instance = rec.init_instance; - tinfo->dealloc = rec.dealloc; - tinfo->simple_type = true; - tinfo->simple_ancestors = true; - tinfo->default_holder = rec.default_holder; - tinfo->module_local = rec.module_local; - - auto &internals = get_internals(); - auto tindex = std::type_index(*rec.type); - tinfo->direct_conversions = &internals.direct_conversions[tindex]; - if (rec.module_local) - registered_local_types_cpp()[tindex] = tinfo; - else - internals.registered_types_cpp[tindex] = tinfo; - internals.registered_types_py[(PyTypeObject *) m_ptr] = { tinfo }; - - if (rec.bases.size() > 1 || rec.multiple_inheritance) { - mark_parents_nonsimple(tinfo->type); - tinfo->simple_ancestors = false; - } - else if (rec.bases.size() == 1) { - auto parent_tinfo = get_type_info((PyTypeObject *) rec.bases[0].ptr()); - tinfo->simple_ancestors = parent_tinfo->simple_ancestors; - } - - if (rec.module_local) { - // Stash the local typeinfo and loader so that external modules can access it. - tinfo->module_local_load = &type_caster_generic::local_load; - setattr(m_ptr, PYBIND11_MODULE_LOCAL_ID, capsule(tinfo)); - } - } - - /// Helper function which tags all parents of a type using mult. inheritance - void mark_parents_nonsimple(PyTypeObject *value) { - auto t = reinterpret_borrow(value->tp_bases); - for (handle h : t) { - auto tinfo2 = get_type_info((PyTypeObject *) h.ptr()); - if (tinfo2) - tinfo2->simple_type = false; - mark_parents_nonsimple((PyTypeObject *) h.ptr()); - } - } - - void install_buffer_funcs( - buffer_info *(*get_buffer)(PyObject *, void *), - void *get_buffer_data) { - PyHeapTypeObject *type = (PyHeapTypeObject*) m_ptr; - auto tinfo = detail::get_type_info(&type->ht_type); - - if (!type->ht_type.tp_as_buffer) - pybind11_fail( - "To be able to register buffer protocol support for the type '" + - std::string(tinfo->type->tp_name) + - "' the associated class<>(..) invocation must " - "include the pybind11::buffer_protocol() annotation!"); - - tinfo->get_buffer = get_buffer; - tinfo->get_buffer_data = get_buffer_data; - } - - // rec_func must be set for either fget or fset. - void def_property_static_impl(const char *name, - handle fget, handle fset, - detail::function_record *rec_func) { - const auto is_static = rec_func && !(rec_func->is_method && rec_func->scope); - const auto has_doc = rec_func && rec_func->doc && pybind11::options::show_user_defined_docstrings(); - auto property = handle((PyObject *) (is_static ? get_internals().static_property_type - : &PyProperty_Type)); - attr(name) = property(fget.ptr() ? fget : none(), - fset.ptr() ? fset : none(), - /*deleter*/none(), - pybind11::str(has_doc ? rec_func->doc : "")); - } -}; - -/// Set the pointer to operator new if it exists. The cast is needed because it can be overloaded. -template (T::operator new))>> -void set_operator_new(type_record *r) { r->operator_new = &T::operator new; } - -template void set_operator_new(...) { } - -template struct has_operator_delete : std::false_type { }; -template struct has_operator_delete(T::operator delete))>> - : std::true_type { }; -template struct has_operator_delete_size : std::false_type { }; -template struct has_operator_delete_size(T::operator delete))>> - : std::true_type { }; -/// Call class-specific delete if it exists or global otherwise. Can also be an overload set. -template ::value, int> = 0> -void call_operator_delete(T *p, size_t, size_t) { T::operator delete(p); } -template ::value && has_operator_delete_size::value, int> = 0> -void call_operator_delete(T *p, size_t s, size_t) { T::operator delete(p, s); } - -inline void call_operator_delete(void *p, size_t s, size_t a) { - (void)s; (void)a; -#if defined(PYBIND11_CPP17) - if (a > __STDCPP_DEFAULT_NEW_ALIGNMENT__) - ::operator delete(p, s, std::align_val_t(a)); - else - ::operator delete(p, s); -#else - ::operator delete(p); -#endif -} - -NAMESPACE_END(detail) - -/// Given a pointer to a member function, cast it to its `Derived` version. -/// Forward everything else unchanged. -template -auto method_adaptor(F &&f) -> decltype(std::forward(f)) { return std::forward(f); } - -template -auto method_adaptor(Return (Class::*pmf)(Args...)) -> Return (Derived::*)(Args...) { - static_assert(detail::is_accessible_base_of::value, - "Cannot bind an inaccessible base class method; use a lambda definition instead"); - return pmf; -} - -template -auto method_adaptor(Return (Class::*pmf)(Args...) const) -> Return (Derived::*)(Args...) const { - static_assert(detail::is_accessible_base_of::value, - "Cannot bind an inaccessible base class method; use a lambda definition instead"); - return pmf; -} - -template -class class_ : public detail::generic_type { - template using is_holder = detail::is_holder_type; - template using is_subtype = detail::is_strict_base_of; - template using is_base = detail::is_strict_base_of; - // struct instead of using here to help MSVC: - template struct is_valid_class_option : - detail::any_of, is_subtype, is_base> {}; - -public: - using type = type_; - using type_alias = detail::exactly_one_t; - constexpr static bool has_alias = !std::is_void::value; - using holder_type = detail::exactly_one_t, options...>; - - static_assert(detail::all_of...>::value, - "Unknown/invalid class_ template parameters provided"); - - static_assert(!has_alias || std::is_polymorphic::value, - "Cannot use an alias class with a non-polymorphic type"); - - PYBIND11_OBJECT(class_, generic_type, PyType_Check) - - template - class_(handle scope, const char *name, const Extra &... extra) { - using namespace detail; - - // MI can only be specified via class_ template options, not constructor parameters - static_assert( - none_of...>::value || // no base class arguments, or: - ( constexpr_sum(is_pyobject::value...) == 1 && // Exactly one base - constexpr_sum(is_base::value...) == 0 && // no template option bases - none_of...>::value), // no multiple_inheritance attr - "Error: multiple inheritance bases must be specified via class_ template options"); - - type_record record; - record.scope = scope; - record.name = name; - record.type = &typeid(type); - record.type_size = sizeof(conditional_t); - record.type_align = alignof(conditional_t&); - record.holder_size = sizeof(holder_type); - record.init_instance = init_instance; - record.dealloc = dealloc; - record.default_holder = detail::is_instantiation::value; - - set_operator_new(&record); - - /* Register base classes specified via template arguments to class_, if any */ - PYBIND11_EXPAND_SIDE_EFFECTS(add_base(record)); - - /* Process optional arguments, if any */ - process_attributes::init(extra..., &record); - - generic_type::initialize(record); - - if (has_alias) { - auto &instances = record.module_local ? registered_local_types_cpp() : get_internals().registered_types_cpp; - instances[std::type_index(typeid(type_alias))] = instances[std::type_index(typeid(type))]; - } - } - - template ::value, int> = 0> - static void add_base(detail::type_record &rec) { - rec.add_base(typeid(Base), [](void *src) -> void * { - return static_cast(reinterpret_cast(src)); - }); - } - - template ::value, int> = 0> - static void add_base(detail::type_record &) { } - - template - class_ &def(const char *name_, Func&& f, const Extra&... extra) { - cpp_function cf(method_adaptor(std::forward(f)), name(name_), is_method(*this), - sibling(getattr(*this, name_, none())), extra...); - attr(cf.name()) = cf; - return *this; - } - - template class_ & - def_static(const char *name_, Func &&f, const Extra&... extra) { - static_assert(!std::is_member_function_pointer::value, - "def_static(...) called with a non-static member function pointer"); - cpp_function cf(std::forward(f), name(name_), scope(*this), - sibling(getattr(*this, name_, none())), extra...); - attr(cf.name()) = staticmethod(cf); - return *this; - } - - template - class_ &def(const detail::op_ &op, const Extra&... extra) { - op.execute(*this, extra...); - return *this; - } - - template - class_ & def_cast(const detail::op_ &op, const Extra&... extra) { - op.execute_cast(*this, extra...); - return *this; - } - - template - class_ &def(const detail::initimpl::constructor &init, const Extra&... extra) { - init.execute(*this, extra...); - return *this; - } - - template - class_ &def(const detail::initimpl::alias_constructor &init, const Extra&... extra) { - init.execute(*this, extra...); - return *this; - } - - template - class_ &def(detail::initimpl::factory &&init, const Extra&... extra) { - std::move(init).execute(*this, extra...); - return *this; - } - - template - class_ &def(detail::initimpl::pickle_factory &&pf, const Extra &...extra) { - std::move(pf).execute(*this, extra...); - return *this; - } - - template class_& def_buffer(Func &&func) { - struct capture { Func func; }; - capture *ptr = new capture { std::forward(func) }; - install_buffer_funcs([](PyObject *obj, void *ptr) -> buffer_info* { - detail::make_caster caster; - if (!caster.load(obj, false)) - return nullptr; - return new buffer_info(((capture *) ptr)->func(caster)); - }, ptr); - return *this; - } - - template - class_ &def_buffer(Return (Class::*func)(Args...)) { - return def_buffer([func] (type &obj) { return (obj.*func)(); }); - } - - template - class_ &def_buffer(Return (Class::*func)(Args...) const) { - return def_buffer([func] (const type &obj) { return (obj.*func)(); }); - } - - template - class_ &def_readwrite(const char *name, D C::*pm, const Extra&... extra) { - static_assert(std::is_same::value || std::is_base_of::value, "def_readwrite() requires a class member (or base class member)"); - cpp_function fget([pm](const type &c) -> const D &{ return c.*pm; }, is_method(*this)), - fset([pm](type &c, const D &value) { c.*pm = value; }, is_method(*this)); - def_property(name, fget, fset, return_value_policy::reference_internal, extra...); - return *this; - } - - template - class_ &def_readonly(const char *name, const D C::*pm, const Extra& ...extra) { - static_assert(std::is_same::value || std::is_base_of::value, "def_readonly() requires a class member (or base class member)"); - cpp_function fget([pm](const type &c) -> const D &{ return c.*pm; }, is_method(*this)); - def_property_readonly(name, fget, return_value_policy::reference_internal, extra...); - return *this; - } - - template - class_ &def_readwrite_static(const char *name, D *pm, const Extra& ...extra) { - cpp_function fget([pm](object) -> const D &{ return *pm; }, scope(*this)), - fset([pm](object, const D &value) { *pm = value; }, scope(*this)); - def_property_static(name, fget, fset, return_value_policy::reference, extra...); - return *this; - } - - template - class_ &def_readonly_static(const char *name, const D *pm, const Extra& ...extra) { - cpp_function fget([pm](object) -> const D &{ return *pm; }, scope(*this)); - def_property_readonly_static(name, fget, return_value_policy::reference, extra...); - return *this; - } - - /// Uses return_value_policy::reference_internal by default - template - class_ &def_property_readonly(const char *name, const Getter &fget, const Extra& ...extra) { - return def_property_readonly(name, cpp_function(method_adaptor(fget)), - return_value_policy::reference_internal, extra...); - } - - /// Uses cpp_function's return_value_policy by default - template - class_ &def_property_readonly(const char *name, const cpp_function &fget, const Extra& ...extra) { - return def_property(name, fget, nullptr, extra...); - } - - /// Uses return_value_policy::reference by default - template - class_ &def_property_readonly_static(const char *name, const Getter &fget, const Extra& ...extra) { - return def_property_readonly_static(name, cpp_function(fget), return_value_policy::reference, extra...); - } - - /// Uses cpp_function's return_value_policy by default - template - class_ &def_property_readonly_static(const char *name, const cpp_function &fget, const Extra& ...extra) { - return def_property_static(name, fget, nullptr, extra...); - } - - /// Uses return_value_policy::reference_internal by default - template - class_ &def_property(const char *name, const Getter &fget, const Setter &fset, const Extra& ...extra) { - return def_property(name, fget, cpp_function(method_adaptor(fset)), extra...); - } - template - class_ &def_property(const char *name, const Getter &fget, const cpp_function &fset, const Extra& ...extra) { - return def_property(name, cpp_function(method_adaptor(fget)), fset, - return_value_policy::reference_internal, extra...); - } - - /// Uses cpp_function's return_value_policy by default - template - class_ &def_property(const char *name, const cpp_function &fget, const cpp_function &fset, const Extra& ...extra) { - return def_property_static(name, fget, fset, is_method(*this), extra...); - } - - /// Uses return_value_policy::reference by default - template - class_ &def_property_static(const char *name, const Getter &fget, const cpp_function &fset, const Extra& ...extra) { - return def_property_static(name, cpp_function(fget), fset, return_value_policy::reference, extra...); - } - - /// Uses cpp_function's return_value_policy by default - template - class_ &def_property_static(const char *name, const cpp_function &fget, const cpp_function &fset, const Extra& ...extra) { - static_assert( 0 == detail::constexpr_sum(std::is_base_of::value...), - "Argument annotations are not allowed for properties"); - auto rec_fget = get_function_record(fget), rec_fset = get_function_record(fset); - auto *rec_active = rec_fget; - if (rec_fget) { - char *doc_prev = rec_fget->doc; /* 'extra' field may include a property-specific documentation string */ - detail::process_attributes::init(extra..., rec_fget); - if (rec_fget->doc && rec_fget->doc != doc_prev) { - free(doc_prev); - rec_fget->doc = strdup(rec_fget->doc); - } - } - if (rec_fset) { - char *doc_prev = rec_fset->doc; - detail::process_attributes::init(extra..., rec_fset); - if (rec_fset->doc && rec_fset->doc != doc_prev) { - free(doc_prev); - rec_fset->doc = strdup(rec_fset->doc); - } - if (! rec_active) rec_active = rec_fset; - } - def_property_static_impl(name, fget, fset, rec_active); - return *this; - } - -private: - /// Initialize holder object, variant 1: object derives from enable_shared_from_this - template - static void init_holder(detail::instance *inst, detail::value_and_holder &v_h, - const holder_type * /* unused */, const std::enable_shared_from_this * /* dummy */) { - try { - auto sh = std::dynamic_pointer_cast( - v_h.value_ptr()->shared_from_this()); - if (sh) { - new (std::addressof(v_h.holder())) holder_type(std::move(sh)); - v_h.set_holder_constructed(); - } - } catch (const std::bad_weak_ptr &) {} - - if (!v_h.holder_constructed() && inst->owned) { - new (std::addressof(v_h.holder())) holder_type(v_h.value_ptr()); - v_h.set_holder_constructed(); - } - } - - static void init_holder_from_existing(const detail::value_and_holder &v_h, - const holder_type *holder_ptr, std::true_type /*is_copy_constructible*/) { - new (std::addressof(v_h.holder())) holder_type(*reinterpret_cast(holder_ptr)); - } - - static void init_holder_from_existing(const detail::value_and_holder &v_h, - const holder_type *holder_ptr, std::false_type /*is_copy_constructible*/) { - new (std::addressof(v_h.holder())) holder_type(std::move(*const_cast(holder_ptr))); - } - - /// Initialize holder object, variant 2: try to construct from existing holder object, if possible - static void init_holder(detail::instance *inst, detail::value_and_holder &v_h, - const holder_type *holder_ptr, const void * /* dummy -- not enable_shared_from_this) */) { - if (holder_ptr) { - init_holder_from_existing(v_h, holder_ptr, std::is_copy_constructible()); - v_h.set_holder_constructed(); - } else if (inst->owned || detail::always_construct_holder::value) { - new (std::addressof(v_h.holder())) holder_type(v_h.value_ptr()); - v_h.set_holder_constructed(); - } - } - - /// Performs instance initialization including constructing a holder and registering the known - /// instance. Should be called as soon as the `type` value_ptr is set for an instance. Takes an - /// optional pointer to an existing holder to use; if not specified and the instance is - /// `.owned`, a new holder will be constructed to manage the value pointer. - static void init_instance(detail::instance *inst, const void *holder_ptr) { - auto v_h = inst->get_value_and_holder(detail::get_type_info(typeid(type))); - if (!v_h.instance_registered()) { - register_instance(inst, v_h.value_ptr(), v_h.type); - v_h.set_instance_registered(); - } - init_holder(inst, v_h, (const holder_type *) holder_ptr, v_h.value_ptr()); - } - - /// Deallocates an instance; via holder, if constructed; otherwise via operator delete. - static void dealloc(detail::value_and_holder &v_h) { - if (v_h.holder_constructed()) { - v_h.holder().~holder_type(); - v_h.set_holder_constructed(false); - } - else { - detail::call_operator_delete(v_h.value_ptr(), - v_h.type->type_size, - v_h.type->type_align - ); - } - v_h.value_ptr() = nullptr; - } - - static detail::function_record *get_function_record(handle h) { - h = detail::get_function(h); - return h ? (detail::function_record *) reinterpret_borrow(PyCFunction_GET_SELF(h.ptr())) - : nullptr; - } -}; - -/// Binds an existing constructor taking arguments Args... -template detail::initimpl::constructor init() { return {}; } -/// Like `init()`, but the instance is always constructed through the alias class (even -/// when not inheriting on the Python side). -template detail::initimpl::alias_constructor init_alias() { return {}; } - -/// Binds a factory function as a constructor -template > -Ret init(Func &&f) { return {std::forward(f)}; } - -/// Dual-argument factory function: the first function is called when no alias is needed, the second -/// when an alias is needed (i.e. due to python-side inheritance). Arguments must be identical. -template > -Ret init(CFunc &&c, AFunc &&a) { - return {std::forward(c), std::forward(a)}; -} - -/// Binds pickling functions `__getstate__` and `__setstate__` and ensures that the type -/// returned by `__getstate__` is the same as the argument accepted by `__setstate__`. -template -detail::initimpl::pickle_factory pickle(GetState &&g, SetState &&s) { - return {std::forward(g), std::forward(s)}; -} - -NAMESPACE_BEGIN(detail) -struct enum_base { - enum_base(handle base, handle parent) : m_base(base), m_parent(parent) { } - - PYBIND11_NOINLINE void init(bool is_arithmetic, bool is_convertible) { - m_base.attr("__entries") = dict(); - auto property = handle((PyObject *) &PyProperty_Type); - auto static_property = handle((PyObject *) get_internals().static_property_type); - - m_base.attr("__repr__") = cpp_function( - [](handle arg) -> str { - handle type = arg.get_type(); - object type_name = type.attr("__name__"); - dict entries = type.attr("__entries"); - for (const auto &kv : entries) { - object other = kv.second[int_(0)]; - if (other.equal(arg)) - return pybind11::str("{}.{}").format(type_name, kv.first); - } - return pybind11::str("{}.???").format(type_name); - }, is_method(m_base) - ); - - m_base.attr("name") = property(cpp_function( - [](handle arg) -> str { - dict entries = arg.get_type().attr("__entries"); - for (const auto &kv : entries) { - if (handle(kv.second[int_(0)]).equal(arg)) - return pybind11::str(kv.first); - } - return "???"; - }, is_method(m_base) - )); - - m_base.attr("__doc__") = static_property(cpp_function( - [](handle arg) -> std::string { - std::string docstring; - dict entries = arg.attr("__entries"); - if (((PyTypeObject *) arg.ptr())->tp_doc) - docstring += std::string(((PyTypeObject *) arg.ptr())->tp_doc) + "\n\n"; - docstring += "Members:"; - for (const auto &kv : entries) { - auto key = std::string(pybind11::str(kv.first)); - auto comment = kv.second[int_(1)]; - docstring += "\n\n " + key; - if (!comment.is_none()) - docstring += " : " + (std::string) pybind11::str(comment); - } - return docstring; - } - ), none(), none(), ""); - - m_base.attr("__members__") = static_property(cpp_function( - [](handle arg) -> dict { - dict entries = arg.attr("__entries"), m; - for (const auto &kv : entries) - m[kv.first] = kv.second[int_(0)]; - return m; - }), none(), none(), "" - ); - - #define PYBIND11_ENUM_OP_STRICT(op, expr, strict_behavior) \ - m_base.attr(op) = cpp_function( \ - [](object a, object b) { \ - if (!a.get_type().is(b.get_type())) \ - strict_behavior; \ - return expr; \ - }, \ - is_method(m_base)) - - #define PYBIND11_ENUM_OP_CONV(op, expr) \ - m_base.attr(op) = cpp_function( \ - [](object a_, object b_) { \ - int_ a(a_), b(b_); \ - return expr; \ - }, \ - is_method(m_base)) - - #define PYBIND11_ENUM_OP_CONV_LHS(op, expr) \ - m_base.attr(op) = cpp_function( \ - [](object a_, object b) { \ - int_ a(a_); \ - return expr; \ - }, \ - is_method(m_base)) - - if (is_convertible) { - PYBIND11_ENUM_OP_CONV_LHS("__eq__", !b.is_none() && a.equal(b)); - PYBIND11_ENUM_OP_CONV_LHS("__ne__", b.is_none() || !a.equal(b)); - - if (is_arithmetic) { - PYBIND11_ENUM_OP_CONV("__lt__", a < b); - PYBIND11_ENUM_OP_CONV("__gt__", a > b); - PYBIND11_ENUM_OP_CONV("__le__", a <= b); - PYBIND11_ENUM_OP_CONV("__ge__", a >= b); - PYBIND11_ENUM_OP_CONV("__and__", a & b); - PYBIND11_ENUM_OP_CONV("__rand__", a & b); - PYBIND11_ENUM_OP_CONV("__or__", a | b); - PYBIND11_ENUM_OP_CONV("__ror__", a | b); - PYBIND11_ENUM_OP_CONV("__xor__", a ^ b); - PYBIND11_ENUM_OP_CONV("__rxor__", a ^ b); - m_base.attr("__invert__") = cpp_function( - [](object arg) { return ~(int_(arg)); }, is_method(m_base)); - } - } else { - PYBIND11_ENUM_OP_STRICT("__eq__", int_(a).equal(int_(b)), return false); - PYBIND11_ENUM_OP_STRICT("__ne__", !int_(a).equal(int_(b)), return true); - - if (is_arithmetic) { - #define PYBIND11_THROW throw type_error("Expected an enumeration of matching type!"); - PYBIND11_ENUM_OP_STRICT("__lt__", int_(a) < int_(b), PYBIND11_THROW); - PYBIND11_ENUM_OP_STRICT("__gt__", int_(a) > int_(b), PYBIND11_THROW); - PYBIND11_ENUM_OP_STRICT("__le__", int_(a) <= int_(b), PYBIND11_THROW); - PYBIND11_ENUM_OP_STRICT("__ge__", int_(a) >= int_(b), PYBIND11_THROW); - #undef PYBIND11_THROW - } - } - - #undef PYBIND11_ENUM_OP_CONV_LHS - #undef PYBIND11_ENUM_OP_CONV - #undef PYBIND11_ENUM_OP_STRICT - - object getstate = cpp_function( - [](object arg) { return int_(arg); }, is_method(m_base)); - - m_base.attr("__getstate__") = getstate; - m_base.attr("__hash__") = getstate; - } - - PYBIND11_NOINLINE void value(char const* name_, object value, const char *doc = nullptr) { - dict entries = m_base.attr("__entries"); - str name(name_); - if (entries.contains(name)) { - std::string type_name = (std::string) str(m_base.attr("__name__")); - throw value_error(type_name + ": element \"" + std::string(name_) + "\" already exists!"); - } - - entries[name] = std::make_pair(value, doc); - m_base.attr(name) = value; - } - - PYBIND11_NOINLINE void export_values() { - dict entries = m_base.attr("__entries"); - for (const auto &kv : entries) - m_parent.attr(kv.first) = kv.second[int_(0)]; - } - - handle m_base; - handle m_parent; -}; - -NAMESPACE_END(detail) - -/// Binds C++ enumerations and enumeration classes to Python -template class enum_ : public class_ { -public: - using Base = class_; - using Base::def; - using Base::attr; - using Base::def_property_readonly; - using Base::def_property_readonly_static; - using Scalar = typename std::underlying_type::type; - - template - enum_(const handle &scope, const char *name, const Extra&... extra) - : class_(scope, name, extra...), m_base(*this, scope) { - constexpr bool is_arithmetic = detail::any_of...>::value; - constexpr bool is_convertible = std::is_convertible::value; - m_base.init(is_arithmetic, is_convertible); - - def(init([](Scalar i) { return static_cast(i); })); - def("__int__", [](Type value) { return (Scalar) value; }); - #if PY_MAJOR_VERSION < 3 - def("__long__", [](Type value) { return (Scalar) value; }); - #endif - #if PY_MAJOR_VERSION > 3 || (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION >= 8) - def("__index__", [](Type value) { return (Scalar) value; }); - #endif - - cpp_function setstate( - [](Type &value, Scalar arg) { value = static_cast(arg); }, - is_method(*this)); - attr("__setstate__") = setstate; - } - - /// Export enumeration entries into the parent scope - enum_& export_values() { - m_base.export_values(); - return *this; - } - - /// Add an enumeration entry - enum_& value(char const* name, Type value, const char *doc = nullptr) { - m_base.value(name, pybind11::cast(value, return_value_policy::copy), doc); - return *this; - } - -private: - detail::enum_base m_base; -}; - -NAMESPACE_BEGIN(detail) - - -inline void keep_alive_impl(handle nurse, handle patient) { - if (!nurse || !patient) - pybind11_fail("Could not activate keep_alive!"); - - if (patient.is_none() || nurse.is_none()) - return; /* Nothing to keep alive or nothing to be kept alive by */ - - auto tinfo = all_type_info(Py_TYPE(nurse.ptr())); - if (!tinfo.empty()) { - /* It's a pybind-registered type, so we can store the patient in the - * internal list. */ - add_patient(nurse.ptr(), patient.ptr()); - } - else { - /* Fall back to clever approach based on weak references taken from - * Boost.Python. This is not used for pybind-registered types because - * the objects can be destroyed out-of-order in a GC pass. */ - cpp_function disable_lifesupport( - [patient](handle weakref) { patient.dec_ref(); weakref.dec_ref(); }); - - weakref wr(nurse, disable_lifesupport); - - patient.inc_ref(); /* reference patient and leak the weak reference */ - (void) wr.release(); - } -} - -PYBIND11_NOINLINE inline void keep_alive_impl(size_t Nurse, size_t Patient, function_call &call, handle ret) { - auto get_arg = [&](size_t n) { - if (n == 0) - return ret; - else if (n == 1 && call.init_self) - return call.init_self; - else if (n <= call.args.size()) - return call.args[n - 1]; - return handle(); - }; - - keep_alive_impl(get_arg(Nurse), get_arg(Patient)); -} - -inline std::pair all_type_info_get_cache(PyTypeObject *type) { - auto res = get_internals().registered_types_py -#ifdef __cpp_lib_unordered_map_try_emplace - .try_emplace(type); -#else - .emplace(type, std::vector()); -#endif - if (res.second) { - // New cache entry created; set up a weak reference to automatically remove it if the type - // gets destroyed: - weakref((PyObject *) type, cpp_function([type](handle wr) { - get_internals().registered_types_py.erase(type); - wr.dec_ref(); - })).release(); - } - - return res; -} - -template -struct iterator_state { - Iterator it; - Sentinel end; - bool first_or_done; -}; - -NAMESPACE_END(detail) - -/// Makes a python iterator from a first and past-the-end C++ InputIterator. -template ()), - typename... Extra> -iterator make_iterator(Iterator first, Sentinel last, Extra &&... extra) { - typedef detail::iterator_state state; - - if (!detail::get_type_info(typeid(state), false)) { - class_(handle(), "iterator", pybind11::module_local()) - .def("__iter__", [](state &s) -> state& { return s; }) - .def("__next__", [](state &s) -> ValueType { - if (!s.first_or_done) - ++s.it; - else - s.first_or_done = false; - if (s.it == s.end) { - s.first_or_done = true; - throw stop_iteration(); - } - return *s.it; - }, std::forward(extra)..., Policy); - } - - return cast(state{first, last, true}); -} - -/// Makes an python iterator over the keys (`.first`) of a iterator over pairs from a -/// first and past-the-end InputIterator. -template ()).first), - typename... Extra> -iterator make_key_iterator(Iterator first, Sentinel last, Extra &&... extra) { - typedef detail::iterator_state state; - - if (!detail::get_type_info(typeid(state), false)) { - class_(handle(), "iterator", pybind11::module_local()) - .def("__iter__", [](state &s) -> state& { return s; }) - .def("__next__", [](state &s) -> KeyType { - if (!s.first_or_done) - ++s.it; - else - s.first_or_done = false; - if (s.it == s.end) { - s.first_or_done = true; - throw stop_iteration(); - } - return (*s.it).first; - }, std::forward(extra)..., Policy); - } - - return cast(state{first, last, true}); -} - -/// Makes an iterator over values of an stl container or other container supporting -/// `std::begin()`/`std::end()` -template iterator make_iterator(Type &value, Extra&&... extra) { - return make_iterator(std::begin(value), std::end(value), extra...); -} - -/// Makes an iterator over the keys (`.first`) of a stl map-like container supporting -/// `std::begin()`/`std::end()` -template iterator make_key_iterator(Type &value, Extra&&... extra) { - return make_key_iterator(std::begin(value), std::end(value), extra...); -} - -template void implicitly_convertible() { - struct set_flag { - bool &flag; - set_flag(bool &flag) : flag(flag) { flag = true; } - ~set_flag() { flag = false; } - }; - auto implicit_caster = [](PyObject *obj, PyTypeObject *type) -> PyObject * { - static bool currently_used = false; - if (currently_used) // implicit conversions are non-reentrant - return nullptr; - set_flag flag_helper(currently_used); - if (!detail::make_caster().load(obj, false)) - return nullptr; - tuple args(1); - args[0] = obj; - PyObject *result = PyObject_Call((PyObject *) type, args.ptr(), nullptr); - if (result == nullptr) - PyErr_Clear(); - return result; - }; - - if (auto tinfo = detail::get_type_info(typeid(OutputType))) - tinfo->implicit_conversions.push_back(implicit_caster); - else - pybind11_fail("implicitly_convertible: Unable to find type " + type_id()); -} - -template -void register_exception_translator(ExceptionTranslator&& translator) { - detail::get_internals().registered_exception_translators.push_front( - std::forward(translator)); -} - -/** - * Wrapper to generate a new Python exception type. - * - * This should only be used with PyErr_SetString for now. - * It is not (yet) possible to use as a py::base. - * Template type argument is reserved for future use. - */ -template -class exception : public object { -public: - exception() = default; - exception(handle scope, const char *name, PyObject *base = PyExc_Exception) { - std::string full_name = scope.attr("__name__").cast() + - std::string(".") + name; - m_ptr = PyErr_NewException(const_cast(full_name.c_str()), base, NULL); - if (hasattr(scope, name)) - pybind11_fail("Error during initialization: multiple incompatible " - "definitions with name \"" + std::string(name) + "\""); - scope.attr(name) = *this; - } - - // Sets the current python exception to this exception object with the given message - void operator()(const char *message) { - PyErr_SetString(m_ptr, message); - } -}; - -NAMESPACE_BEGIN(detail) -// Returns a reference to a function-local static exception object used in the simple -// register_exception approach below. (It would be simpler to have the static local variable -// directly in register_exception, but that makes clang <3.5 segfault - issue #1349). -template -exception &get_exception_object() { static exception ex; return ex; } -NAMESPACE_END(detail) - -/** - * Registers a Python exception in `m` of the given `name` and installs an exception translator to - * translate the C++ exception to the created Python exception using the exceptions what() method. - * This is intended for simple exception translations; for more complex translation, register the - * exception object and translator directly. - */ -template -exception ®ister_exception(handle scope, - const char *name, - PyObject *base = PyExc_Exception) { - auto &ex = detail::get_exception_object(); - if (!ex) ex = exception(scope, name, base); - - register_exception_translator([](std::exception_ptr p) { - if (!p) return; - try { - std::rethrow_exception(p); - } catch (const CppException &e) { - detail::get_exception_object()(e.what()); - } - }); - return ex; -} - -NAMESPACE_BEGIN(detail) -PYBIND11_NOINLINE inline void print(tuple args, dict kwargs) { - auto strings = tuple(args.size()); - for (size_t i = 0; i < args.size(); ++i) { - strings[i] = str(args[i]); - } - auto sep = kwargs.contains("sep") ? kwargs["sep"] : cast(" "); - auto line = sep.attr("join")(strings); - - object file; - if (kwargs.contains("file")) { - file = kwargs["file"].cast(); - } else { - try { - file = module::import("sys").attr("stdout"); - } catch (const error_already_set &) { - /* If print() is called from code that is executed as - part of garbage collection during interpreter shutdown, - importing 'sys' can fail. Give up rather than crashing the - interpreter in this case. */ - return; - } - } - - auto write = file.attr("write"); - write(line); - write(kwargs.contains("end") ? kwargs["end"] : cast("\n")); - - if (kwargs.contains("flush") && kwargs["flush"].cast()) - file.attr("flush")(); -} -NAMESPACE_END(detail) - -template -void print(Args &&...args) { - auto c = detail::collect_arguments(std::forward(args)...); - detail::print(c.args(), c.kwargs()); -} - -#if defined(WITH_THREAD) && !defined(PYPY_VERSION) - -/* The functions below essentially reproduce the PyGILState_* API using a RAII - * pattern, but there are a few important differences: - * - * 1. When acquiring the GIL from an non-main thread during the finalization - * phase, the GILState API blindly terminates the calling thread, which - * is often not what is wanted. This API does not do this. - * - * 2. The gil_scoped_release function can optionally cut the relationship - * of a PyThreadState and its associated thread, which allows moving it to - * another thread (this is a fairly rare/advanced use case). - * - * 3. The reference count of an acquired thread state can be controlled. This - * can be handy to prevent cases where callbacks issued from an external - * thread would otherwise constantly construct and destroy thread state data - * structures. - * - * See the Python bindings of NanoGUI (http://github.com/wjakob/nanogui) for an - * example which uses features 2 and 3 to migrate the Python thread of - * execution to another thread (to run the event loop on the original thread, - * in this case). - */ - -class gil_scoped_acquire { -public: - PYBIND11_NOINLINE gil_scoped_acquire() { - auto const &internals = detail::get_internals(); - tstate = (PyThreadState *) PYBIND11_TLS_GET_VALUE(internals.tstate); - - if (!tstate) { - /* Check if the GIL was acquired using the PyGILState_* API instead (e.g. if - calling from a Python thread). Since we use a different key, this ensures - we don't create a new thread state and deadlock in PyEval_AcquireThread - below. Note we don't save this state with internals.tstate, since we don't - create it we would fail to clear it (its reference count should be > 0). */ - tstate = PyGILState_GetThisThreadState(); - } - - if (!tstate) { - tstate = PyThreadState_New(internals.istate); - #if !defined(NDEBUG) - if (!tstate) - pybind11_fail("scoped_acquire: could not create thread state!"); - #endif - tstate->gilstate_counter = 0; - PYBIND11_TLS_REPLACE_VALUE(internals.tstate, tstate); - } else { - release = detail::get_thread_state_unchecked() != tstate; - } - - if (release) { - /* Work around an annoying assertion in PyThreadState_Swap */ - #if defined(Py_DEBUG) - PyInterpreterState *interp = tstate->interp; - tstate->interp = nullptr; - #endif - PyEval_AcquireThread(tstate); - #if defined(Py_DEBUG) - tstate->interp = interp; - #endif - } - - inc_ref(); - } - - void inc_ref() { - ++tstate->gilstate_counter; - } - - PYBIND11_NOINLINE void dec_ref() { - --tstate->gilstate_counter; - #if !defined(NDEBUG) - if (detail::get_thread_state_unchecked() != tstate) - pybind11_fail("scoped_acquire::dec_ref(): thread state must be current!"); - if (tstate->gilstate_counter < 0) - pybind11_fail("scoped_acquire::dec_ref(): reference count underflow!"); - #endif - if (tstate->gilstate_counter == 0) { - #if !defined(NDEBUG) - if (!release) - pybind11_fail("scoped_acquire::dec_ref(): internal error!"); - #endif - PyThreadState_Clear(tstate); - PyThreadState_DeleteCurrent(); - PYBIND11_TLS_DELETE_VALUE(detail::get_internals().tstate); - release = false; - } - } - - PYBIND11_NOINLINE ~gil_scoped_acquire() { - dec_ref(); - if (release) - PyEval_SaveThread(); - } -private: - PyThreadState *tstate = nullptr; - bool release = true; -}; - -class gil_scoped_release { -public: - explicit gil_scoped_release(bool disassoc = false) : disassoc(disassoc) { - // `get_internals()` must be called here unconditionally in order to initialize - // `internals.tstate` for subsequent `gil_scoped_acquire` calls. Otherwise, an - // initialization race could occur as multiple threads try `gil_scoped_acquire`. - const auto &internals = detail::get_internals(); - tstate = PyEval_SaveThread(); - if (disassoc) { - auto key = internals.tstate; - PYBIND11_TLS_DELETE_VALUE(key); - } - } - ~gil_scoped_release() { - if (!tstate) - return; - PyEval_RestoreThread(tstate); - if (disassoc) { - auto key = detail::get_internals().tstate; - PYBIND11_TLS_REPLACE_VALUE(key, tstate); - } - } -private: - PyThreadState *tstate; - bool disassoc; -}; -#elif defined(PYPY_VERSION) -class gil_scoped_acquire { - PyGILState_STATE state; -public: - gil_scoped_acquire() { state = PyGILState_Ensure(); } - ~gil_scoped_acquire() { PyGILState_Release(state); } -}; - -class gil_scoped_release { - PyThreadState *state; -public: - gil_scoped_release() { state = PyEval_SaveThread(); } - ~gil_scoped_release() { PyEval_RestoreThread(state); } -}; -#else -class gil_scoped_acquire { }; -class gil_scoped_release { }; -#endif - -error_already_set::~error_already_set() { - if (m_type) { - gil_scoped_acquire gil; - error_scope scope; - m_type.release().dec_ref(); - m_value.release().dec_ref(); - m_trace.release().dec_ref(); - } -} - -inline function get_type_overload(const void *this_ptr, const detail::type_info *this_type, const char *name) { - handle self = detail::get_object_handle(this_ptr, this_type); - if (!self) - return function(); - handle type = self.get_type(); - auto key = std::make_pair(type.ptr(), name); - - /* Cache functions that aren't overloaded in Python to avoid - many costly Python dictionary lookups below */ - auto &cache = detail::get_internals().inactive_overload_cache; - if (cache.find(key) != cache.end()) - return function(); - - function overload = getattr(self, name, function()); - if (overload.is_cpp_function()) { - cache.insert(key); - return function(); - } - - /* Don't call dispatch code if invoked from overridden function. - Unfortunately this doesn't work on PyPy. */ -#if !defined(PYPY_VERSION) - PyFrameObject *frame = PyThreadState_Get()->frame; - if (frame && (std::string) str(frame->f_code->co_name) == name && - frame->f_code->co_argcount > 0) { - PyFrame_FastToLocals(frame); - PyObject *self_caller = PyDict_GetItem( - frame->f_locals, PyTuple_GET_ITEM(frame->f_code->co_varnames, 0)); - if (self_caller == self.ptr()) - return function(); - } -#else - /* PyPy currently doesn't provide a detailed cpyext emulation of - frame objects, so we have to emulate this using Python. This - is going to be slow..*/ - dict d; d["self"] = self; d["name"] = pybind11::str(name); - PyObject *result = PyRun_String( - "import inspect\n" - "frame = inspect.currentframe()\n" - "if frame is not None:\n" - " frame = frame.f_back\n" - " if frame is not None and str(frame.f_code.co_name) == name and " - "frame.f_code.co_argcount > 0:\n" - " self_caller = frame.f_locals[frame.f_code.co_varnames[0]]\n" - " if self_caller == self:\n" - " self = None\n", - Py_file_input, d.ptr(), d.ptr()); - if (result == nullptr) - throw error_already_set(); - if (d["self"].is_none()) - return function(); - Py_DECREF(result); -#endif - - return overload; -} - -/** \rst - Try to retrieve a python method by the provided name from the instance pointed to by the this_ptr. - - :this_ptr: The pointer to the object the overload should be retrieved for. This should be the first - non-trampoline class encountered in the inheritance chain. - :name: The name of the overloaded Python method to retrieve. - :return: The Python method by this name from the object or an empty function wrapper. - \endrst */ -template function get_overload(const T *this_ptr, const char *name) { - auto tinfo = detail::get_type_info(typeid(T)); - return tinfo ? get_type_overload(this_ptr, tinfo, name) : function(); -} - -#define PYBIND11_OVERLOAD_INT(ret_type, cname, name, ...) { \ - pybind11::gil_scoped_acquire gil; \ - pybind11::function overload = pybind11::get_overload(static_cast(this), name); \ - if (overload) { \ - auto o = overload(__VA_ARGS__); \ - if (pybind11::detail::cast_is_temporary_value_reference::value) { \ - static pybind11::detail::overload_caster_t caster; \ - return pybind11::detail::cast_ref(std::move(o), caster); \ - } \ - else return pybind11::detail::cast_safe(std::move(o)); \ - } \ - } - -/** \rst - Macro to populate the virtual method in the trampoline class. This macro tries to look up a method named 'fn' - from the Python side, deals with the :ref:`gil` and necessary argument conversions to call this method and return - the appropriate type. See :ref:`overriding_virtuals` for more information. This macro should be used when the method - name in C is not the same as the method name in Python. For example with `__str__`. - - .. code-block:: cpp - - std::string toString() override { - PYBIND11_OVERLOAD_NAME( - std::string, // Return type (ret_type) - Animal, // Parent class (cname) - toString, // Name of function in C++ (name) - "__str__", // Name of method in Python (fn) - ); - } -\endrst */ -#define PYBIND11_OVERLOAD_NAME(ret_type, cname, name, fn, ...) \ - PYBIND11_OVERLOAD_INT(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), name, __VA_ARGS__) \ - return cname::fn(__VA_ARGS__) - -/** \rst - Macro for pure virtual functions, this function is identical to :c:macro:`PYBIND11_OVERLOAD_NAME`, except that it - throws if no overload can be found. -\endrst */ -#define PYBIND11_OVERLOAD_PURE_NAME(ret_type, cname, name, fn, ...) \ - PYBIND11_OVERLOAD_INT(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), name, __VA_ARGS__) \ - pybind11::pybind11_fail("Tried to call pure virtual function \"" PYBIND11_STRINGIFY(cname) "::" name "\""); - -/** \rst - Macro to populate the virtual method in the trampoline class. This macro tries to look up the method - from the Python side, deals with the :ref:`gil` and necessary argument conversions to call this method and return - the appropriate type. This macro should be used if the method name in C and in Python are identical. - See :ref:`overriding_virtuals` for more information. - - .. code-block:: cpp - - class PyAnimal : public Animal { - public: - // Inherit the constructors - using Animal::Animal; - - // Trampoline (need one for each virtual function) - std::string go(int n_times) override { - PYBIND11_OVERLOAD_PURE( - std::string, // Return type (ret_type) - Animal, // Parent class (cname) - go, // Name of function in C++ (must match Python name) (fn) - n_times // Argument(s) (...) - ); - } - }; -\endrst */ -#define PYBIND11_OVERLOAD(ret_type, cname, fn, ...) \ - PYBIND11_OVERLOAD_NAME(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), #fn, fn, __VA_ARGS__) - -/** \rst - Macro for pure virtual functions, this function is identical to :c:macro:`PYBIND11_OVERLOAD`, except that it throws - if no overload can be found. -\endrst */ -#define PYBIND11_OVERLOAD_PURE(ret_type, cname, fn, ...) \ - PYBIND11_OVERLOAD_PURE_NAME(PYBIND11_TYPE(ret_type), PYBIND11_TYPE(cname), #fn, fn, __VA_ARGS__) - -NAMESPACE_END(PYBIND11_NAMESPACE) - -#if defined(_MSC_VER) && !defined(__INTEL_COMPILER) -# pragma warning(pop) -#elif defined(__GNUG__) && !defined(__clang__) -# pragma GCC diagnostic pop -#endif diff --git a/pybind11/include/pybind11/pytypes.h b/pybind11/include/pybind11/pytypes.h deleted file mode 100644 index 96eab96..0000000 --- a/pybind11/include/pybind11/pytypes.h +++ /dev/null @@ -1,1484 +0,0 @@ -/* - pybind11/pytypes.h: Convenience wrapper classes for basic Python types - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#pragma once - -#include "detail/common.h" -#include "buffer_info.h" -#include -#include - -NAMESPACE_BEGIN(PYBIND11_NAMESPACE) - -/* A few forward declarations */ -class handle; class object; -class str; class iterator; -struct arg; struct arg_v; - -NAMESPACE_BEGIN(detail) -class args_proxy; -inline bool isinstance_generic(handle obj, const std::type_info &tp); - -// Accessor forward declarations -template class accessor; -namespace accessor_policies { - struct obj_attr; - struct str_attr; - struct generic_item; - struct sequence_item; - struct list_item; - struct tuple_item; -} -using obj_attr_accessor = accessor; -using str_attr_accessor = accessor; -using item_accessor = accessor; -using sequence_accessor = accessor; -using list_accessor = accessor; -using tuple_accessor = accessor; - -/// Tag and check to identify a class which implements the Python object API -class pyobject_tag { }; -template using is_pyobject = std::is_base_of>; - -/** \rst - A mixin class which adds common functions to `handle`, `object` and various accessors. - The only requirement for `Derived` is to implement ``PyObject *Derived::ptr() const``. -\endrst */ -template -class object_api : public pyobject_tag { - const Derived &derived() const { return static_cast(*this); } - -public: - /** \rst - Return an iterator equivalent to calling ``iter()`` in Python. The object - must be a collection which supports the iteration protocol. - \endrst */ - iterator begin() const; - /// Return a sentinel which ends iteration. - iterator end() const; - - /** \rst - Return an internal functor to invoke the object's sequence protocol. Casting - the returned ``detail::item_accessor`` instance to a `handle` or `object` - subclass causes a corresponding call to ``__getitem__``. Assigning a `handle` - or `object` subclass causes a call to ``__setitem__``. - \endrst */ - item_accessor operator[](handle key) const; - /// See above (the only difference is that they key is provided as a string literal) - item_accessor operator[](const char *key) const; - - /** \rst - Return an internal functor to access the object's attributes. Casting the - returned ``detail::obj_attr_accessor`` instance to a `handle` or `object` - subclass causes a corresponding call to ``getattr``. Assigning a `handle` - or `object` subclass causes a call to ``setattr``. - \endrst */ - obj_attr_accessor attr(handle key) const; - /// See above (the only difference is that they key is provided as a string literal) - str_attr_accessor attr(const char *key) const; - - /** \rst - Matches * unpacking in Python, e.g. to unpack arguments out of a ``tuple`` - or ``list`` for a function call. Applying another * to the result yields - ** unpacking, e.g. to unpack a dict as function keyword arguments. - See :ref:`calling_python_functions`. - \endrst */ - args_proxy operator*() const; - - /// Check if the given item is contained within this object, i.e. ``item in obj``. - template bool contains(T &&item) const; - - /** \rst - Assuming the Python object is a function or implements the ``__call__`` - protocol, ``operator()`` invokes the underlying function, passing an - arbitrary set of parameters. The result is returned as a `object` and - may need to be converted back into a Python object using `handle::cast()`. - - When some of the arguments cannot be converted to Python objects, the - function will throw a `cast_error` exception. When the Python function - call fails, a `error_already_set` exception is thrown. - \endrst */ - template - object operator()(Args &&...args) const; - template - PYBIND11_DEPRECATED("call(...) was deprecated in favor of operator()(...)") - object call(Args&&... args) const; - - /// Equivalent to ``obj is other`` in Python. - bool is(object_api const& other) const { return derived().ptr() == other.derived().ptr(); } - /// Equivalent to ``obj is None`` in Python. - bool is_none() const { return derived().ptr() == Py_None; } - /// Equivalent to obj == other in Python - bool equal(object_api const &other) const { return rich_compare(other, Py_EQ); } - bool not_equal(object_api const &other) const { return rich_compare(other, Py_NE); } - bool operator<(object_api const &other) const { return rich_compare(other, Py_LT); } - bool operator<=(object_api const &other) const { return rich_compare(other, Py_LE); } - bool operator>(object_api const &other) const { return rich_compare(other, Py_GT); } - bool operator>=(object_api const &other) const { return rich_compare(other, Py_GE); } - - object operator-() const; - object operator~() const; - object operator+(object_api const &other) const; - object operator+=(object_api const &other) const; - object operator-(object_api const &other) const; - object operator-=(object_api const &other) const; - object operator*(object_api const &other) const; - object operator*=(object_api const &other) const; - object operator/(object_api const &other) const; - object operator/=(object_api const &other) const; - object operator|(object_api const &other) const; - object operator|=(object_api const &other) const; - object operator&(object_api const &other) const; - object operator&=(object_api const &other) const; - object operator^(object_api const &other) const; - object operator^=(object_api const &other) const; - object operator<<(object_api const &other) const; - object operator<<=(object_api const &other) const; - object operator>>(object_api const &other) const; - object operator>>=(object_api const &other) const; - - PYBIND11_DEPRECATED("Use py::str(obj) instead") - pybind11::str str() const; - - /// Get or set the object's docstring, i.e. ``obj.__doc__``. - str_attr_accessor doc() const; - - /// Return the object's current reference count - int ref_count() const { return static_cast(Py_REFCNT(derived().ptr())); } - /// Return a handle to the Python type object underlying the instance - handle get_type() const; - -private: - bool rich_compare(object_api const &other, int value) const; -}; - -NAMESPACE_END(detail) - -/** \rst - Holds a reference to a Python object (no reference counting) - - The `handle` class is a thin wrapper around an arbitrary Python object (i.e. a - ``PyObject *`` in Python's C API). It does not perform any automatic reference - counting and merely provides a basic C++ interface to various Python API functions. - - .. seealso:: - The `object` class inherits from `handle` and adds automatic reference - counting features. -\endrst */ -class handle : public detail::object_api { -public: - /// The default constructor creates a handle with a ``nullptr``-valued pointer - handle() = default; - /// Creates a ``handle`` from the given raw Python object pointer - handle(PyObject *ptr) : m_ptr(ptr) { } // Allow implicit conversion from PyObject* - - /// Return the underlying ``PyObject *`` pointer - PyObject *ptr() const { return m_ptr; } - PyObject *&ptr() { return m_ptr; } - - /** \rst - Manually increase the reference count of the Python object. Usually, it is - preferable to use the `object` class which derives from `handle` and calls - this function automatically. Returns a reference to itself. - \endrst */ - const handle& inc_ref() const & { Py_XINCREF(m_ptr); return *this; } - - /** \rst - Manually decrease the reference count of the Python object. Usually, it is - preferable to use the `object` class which derives from `handle` and calls - this function automatically. Returns a reference to itself. - \endrst */ - const handle& dec_ref() const & { Py_XDECREF(m_ptr); return *this; } - - /** \rst - Attempt to cast the Python object into the given C++ type. A `cast_error` - will be throw upon failure. - \endrst */ - template T cast() const; - /// Return ``true`` when the `handle` wraps a valid Python object - explicit operator bool() const { return m_ptr != nullptr; } - /** \rst - Deprecated: Check that the underlying pointers are the same. - Equivalent to ``obj1 is obj2`` in Python. - \endrst */ - PYBIND11_DEPRECATED("Use obj1.is(obj2) instead") - bool operator==(const handle &h) const { return m_ptr == h.m_ptr; } - PYBIND11_DEPRECATED("Use !obj1.is(obj2) instead") - bool operator!=(const handle &h) const { return m_ptr != h.m_ptr; } - PYBIND11_DEPRECATED("Use handle::operator bool() instead") - bool check() const { return m_ptr != nullptr; } -protected: - PyObject *m_ptr = nullptr; -}; - -/** \rst - Holds a reference to a Python object (with reference counting) - - Like `handle`, the `object` class is a thin wrapper around an arbitrary Python - object (i.e. a ``PyObject *`` in Python's C API). In contrast to `handle`, it - optionally increases the object's reference count upon construction, and it - *always* decreases the reference count when the `object` instance goes out of - scope and is destructed. When using `object` instances consistently, it is much - easier to get reference counting right at the first attempt. -\endrst */ -class object : public handle { -public: - object() = default; - PYBIND11_DEPRECATED("Use reinterpret_borrow() or reinterpret_steal()") - object(handle h, bool is_borrowed) : handle(h) { if (is_borrowed) inc_ref(); } - /// Copy constructor; always increases the reference count - object(const object &o) : handle(o) { inc_ref(); } - /// Move constructor; steals the object from ``other`` and preserves its reference count - object(object &&other) noexcept { m_ptr = other.m_ptr; other.m_ptr = nullptr; } - /// Destructor; automatically calls `handle::dec_ref()` - ~object() { dec_ref(); } - - /** \rst - Resets the internal pointer to ``nullptr`` without without decreasing the - object's reference count. The function returns a raw handle to the original - Python object. - \endrst */ - handle release() { - PyObject *tmp = m_ptr; - m_ptr = nullptr; - return handle(tmp); - } - - object& operator=(const object &other) { - other.inc_ref(); - dec_ref(); - m_ptr = other.m_ptr; - return *this; - } - - object& operator=(object &&other) noexcept { - if (this != &other) { - handle temp(m_ptr); - m_ptr = other.m_ptr; - other.m_ptr = nullptr; - temp.dec_ref(); - } - return *this; - } - - // Calling cast() on an object lvalue just copies (via handle::cast) - template T cast() const &; - // Calling on an object rvalue does a move, if needed and/or possible - template T cast() &&; - -protected: - // Tags for choosing constructors from raw PyObject * - struct borrowed_t { }; - struct stolen_t { }; - - template friend T reinterpret_borrow(handle); - template friend T reinterpret_steal(handle); - -public: - // Only accessible from derived classes and the reinterpret_* functions - object(handle h, borrowed_t) : handle(h) { inc_ref(); } - object(handle h, stolen_t) : handle(h) { } -}; - -/** \rst - Declare that a `handle` or ``PyObject *`` is a certain type and borrow the reference. - The target type ``T`` must be `object` or one of its derived classes. The function - doesn't do any conversions or checks. It's up to the user to make sure that the - target type is correct. - - .. code-block:: cpp - - PyObject *p = PyList_GetItem(obj, index); - py::object o = reinterpret_borrow(p); - // or - py::tuple t = reinterpret_borrow(p); // <-- `p` must be already be a `tuple` -\endrst */ -template T reinterpret_borrow(handle h) { return {h, object::borrowed_t{}}; } - -/** \rst - Like `reinterpret_borrow`, but steals the reference. - - .. code-block:: cpp - - PyObject *p = PyObject_Str(obj); - py::str s = reinterpret_steal(p); // <-- `p` must be already be a `str` -\endrst */ -template T reinterpret_steal(handle h) { return {h, object::stolen_t{}}; } - -NAMESPACE_BEGIN(detail) -inline std::string error_string(); -NAMESPACE_END(detail) - -/// Fetch and hold an error which was already set in Python. An instance of this is typically -/// thrown to propagate python-side errors back through C++ which can either be caught manually or -/// else falls back to the function dispatcher (which then raises the captured error back to -/// python). -class error_already_set : public std::runtime_error { -public: - /// Constructs a new exception from the current Python error indicator, if any. The current - /// Python error indicator will be cleared. - error_already_set() : std::runtime_error(detail::error_string()) { - PyErr_Fetch(&m_type.ptr(), &m_value.ptr(), &m_trace.ptr()); - } - - error_already_set(const error_already_set &) = default; - error_already_set(error_already_set &&) = default; - - inline ~error_already_set(); - - /// Give the currently-held error back to Python, if any. If there is currently a Python error - /// already set it is cleared first. After this call, the current object no longer stores the - /// error variables (but the `.what()` string is still available). - void restore() { PyErr_Restore(m_type.release().ptr(), m_value.release().ptr(), m_trace.release().ptr()); } - - // Does nothing; provided for backwards compatibility. - PYBIND11_DEPRECATED("Use of error_already_set.clear() is deprecated") - void clear() {} - - /// Check if the currently trapped error type matches the given Python exception class (or a - /// subclass thereof). May also be passed a tuple to search for any exception class matches in - /// the given tuple. - bool matches(handle exc) const { return PyErr_GivenExceptionMatches(m_type.ptr(), exc.ptr()); } - - const object& type() const { return m_type; } - const object& value() const { return m_value; } - const object& trace() const { return m_trace; } - -private: - object m_type, m_value, m_trace; -}; - -/** \defgroup python_builtins _ - Unless stated otherwise, the following C++ functions behave the same - as their Python counterparts. - */ - -/** \ingroup python_builtins - \rst - Return true if ``obj`` is an instance of ``T``. Type ``T`` must be a subclass of - `object` or a class which was exposed to Python as ``py::class_``. -\endrst */ -template ::value, int> = 0> -bool isinstance(handle obj) { return T::check_(obj); } - -template ::value, int> = 0> -bool isinstance(handle obj) { return detail::isinstance_generic(obj, typeid(T)); } - -template <> inline bool isinstance(handle obj) = delete; -template <> inline bool isinstance(handle obj) { return obj.ptr() != nullptr; } - -/// \ingroup python_builtins -/// Return true if ``obj`` is an instance of the ``type``. -inline bool isinstance(handle obj, handle type) { - const auto result = PyObject_IsInstance(obj.ptr(), type.ptr()); - if (result == -1) - throw error_already_set(); - return result != 0; -} - -/// \addtogroup python_builtins -/// @{ -inline bool hasattr(handle obj, handle name) { - return PyObject_HasAttr(obj.ptr(), name.ptr()) == 1; -} - -inline bool hasattr(handle obj, const char *name) { - return PyObject_HasAttrString(obj.ptr(), name) == 1; -} - -inline void delattr(handle obj, handle name) { - if (PyObject_DelAttr(obj.ptr(), name.ptr()) != 0) { throw error_already_set(); } -} - -inline void delattr(handle obj, const char *name) { - if (PyObject_DelAttrString(obj.ptr(), name) != 0) { throw error_already_set(); } -} - -inline object getattr(handle obj, handle name) { - PyObject *result = PyObject_GetAttr(obj.ptr(), name.ptr()); - if (!result) { throw error_already_set(); } - return reinterpret_steal(result); -} - -inline object getattr(handle obj, const char *name) { - PyObject *result = PyObject_GetAttrString(obj.ptr(), name); - if (!result) { throw error_already_set(); } - return reinterpret_steal(result); -} - -inline object getattr(handle obj, handle name, handle default_) { - if (PyObject *result = PyObject_GetAttr(obj.ptr(), name.ptr())) { - return reinterpret_steal(result); - } else { - PyErr_Clear(); - return reinterpret_borrow(default_); - } -} - -inline object getattr(handle obj, const char *name, handle default_) { - if (PyObject *result = PyObject_GetAttrString(obj.ptr(), name)) { - return reinterpret_steal(result); - } else { - PyErr_Clear(); - return reinterpret_borrow(default_); - } -} - -inline void setattr(handle obj, handle name, handle value) { - if (PyObject_SetAttr(obj.ptr(), name.ptr(), value.ptr()) != 0) { throw error_already_set(); } -} - -inline void setattr(handle obj, const char *name, handle value) { - if (PyObject_SetAttrString(obj.ptr(), name, value.ptr()) != 0) { throw error_already_set(); } -} - -inline ssize_t hash(handle obj) { - auto h = PyObject_Hash(obj.ptr()); - if (h == -1) { throw error_already_set(); } - return h; -} - -/// @} python_builtins - -NAMESPACE_BEGIN(detail) -inline handle get_function(handle value) { - if (value) { -#if PY_MAJOR_VERSION >= 3 - if (PyInstanceMethod_Check(value.ptr())) - value = PyInstanceMethod_GET_FUNCTION(value.ptr()); - else -#endif - if (PyMethod_Check(value.ptr())) - value = PyMethod_GET_FUNCTION(value.ptr()); - } - return value; -} - -// Helper aliases/functions to support implicit casting of values given to python accessors/methods. -// When given a pyobject, this simply returns the pyobject as-is; for other C++ type, the value goes -// through pybind11::cast(obj) to convert it to an `object`. -template ::value, int> = 0> -auto object_or_cast(T &&o) -> decltype(std::forward(o)) { return std::forward(o); } -// The following casting version is implemented in cast.h: -template ::value, int> = 0> -object object_or_cast(T &&o); -// Match a PyObject*, which we want to convert directly to handle via its converting constructor -inline handle object_or_cast(PyObject *ptr) { return ptr; } - -template -class accessor : public object_api> { - using key_type = typename Policy::key_type; - -public: - accessor(handle obj, key_type key) : obj(obj), key(std::move(key)) { } - accessor(const accessor &) = default; - accessor(accessor &&) = default; - - // accessor overload required to override default assignment operator (templates are not allowed - // to replace default compiler-generated assignments). - void operator=(const accessor &a) && { std::move(*this).operator=(handle(a)); } - void operator=(const accessor &a) & { operator=(handle(a)); } - - template void operator=(T &&value) && { - Policy::set(obj, key, object_or_cast(std::forward(value))); - } - template void operator=(T &&value) & { - get_cache() = reinterpret_borrow(object_or_cast(std::forward(value))); - } - - template - PYBIND11_DEPRECATED("Use of obj.attr(...) as bool is deprecated in favor of pybind11::hasattr(obj, ...)") - explicit operator enable_if_t::value || - std::is_same::value, bool>() const { - return hasattr(obj, key); - } - template - PYBIND11_DEPRECATED("Use of obj[key] as bool is deprecated in favor of obj.contains(key)") - explicit operator enable_if_t::value, bool>() const { - return obj.contains(key); - } - - operator object() const { return get_cache(); } - PyObject *ptr() const { return get_cache().ptr(); } - template T cast() const { return get_cache().template cast(); } - -private: - object &get_cache() const { - if (!cache) { cache = Policy::get(obj, key); } - return cache; - } - -private: - handle obj; - key_type key; - mutable object cache; -}; - -NAMESPACE_BEGIN(accessor_policies) -struct obj_attr { - using key_type = object; - static object get(handle obj, handle key) { return getattr(obj, key); } - static void set(handle obj, handle key, handle val) { setattr(obj, key, val); } -}; - -struct str_attr { - using key_type = const char *; - static object get(handle obj, const char *key) { return getattr(obj, key); } - static void set(handle obj, const char *key, handle val) { setattr(obj, key, val); } -}; - -struct generic_item { - using key_type = object; - - static object get(handle obj, handle key) { - PyObject *result = PyObject_GetItem(obj.ptr(), key.ptr()); - if (!result) { throw error_already_set(); } - return reinterpret_steal(result); - } - - static void set(handle obj, handle key, handle val) { - if (PyObject_SetItem(obj.ptr(), key.ptr(), val.ptr()) != 0) { throw error_already_set(); } - } -}; - -struct sequence_item { - using key_type = size_t; - - static object get(handle obj, size_t index) { - PyObject *result = PySequence_GetItem(obj.ptr(), static_cast(index)); - if (!result) { throw error_already_set(); } - return reinterpret_steal(result); - } - - static void set(handle obj, size_t index, handle val) { - // PySequence_SetItem does not steal a reference to 'val' - if (PySequence_SetItem(obj.ptr(), static_cast(index), val.ptr()) != 0) { - throw error_already_set(); - } - } -}; - -struct list_item { - using key_type = size_t; - - static object get(handle obj, size_t index) { - PyObject *result = PyList_GetItem(obj.ptr(), static_cast(index)); - if (!result) { throw error_already_set(); } - return reinterpret_borrow(result); - } - - static void set(handle obj, size_t index, handle val) { - // PyList_SetItem steals a reference to 'val' - if (PyList_SetItem(obj.ptr(), static_cast(index), val.inc_ref().ptr()) != 0) { - throw error_already_set(); - } - } -}; - -struct tuple_item { - using key_type = size_t; - - static object get(handle obj, size_t index) { - PyObject *result = PyTuple_GetItem(obj.ptr(), static_cast(index)); - if (!result) { throw error_already_set(); } - return reinterpret_borrow(result); - } - - static void set(handle obj, size_t index, handle val) { - // PyTuple_SetItem steals a reference to 'val' - if (PyTuple_SetItem(obj.ptr(), static_cast(index), val.inc_ref().ptr()) != 0) { - throw error_already_set(); - } - } -}; -NAMESPACE_END(accessor_policies) - -/// STL iterator template used for tuple, list, sequence and dict -template -class generic_iterator : public Policy { - using It = generic_iterator; - -public: - using difference_type = ssize_t; - using iterator_category = typename Policy::iterator_category; - using value_type = typename Policy::value_type; - using reference = typename Policy::reference; - using pointer = typename Policy::pointer; - - generic_iterator() = default; - generic_iterator(handle seq, ssize_t index) : Policy(seq, index) { } - - reference operator*() const { return Policy::dereference(); } - reference operator[](difference_type n) const { return *(*this + n); } - pointer operator->() const { return **this; } - - It &operator++() { Policy::increment(); return *this; } - It operator++(int) { auto copy = *this; Policy::increment(); return copy; } - It &operator--() { Policy::decrement(); return *this; } - It operator--(int) { auto copy = *this; Policy::decrement(); return copy; } - It &operator+=(difference_type n) { Policy::advance(n); return *this; } - It &operator-=(difference_type n) { Policy::advance(-n); return *this; } - - friend It operator+(const It &a, difference_type n) { auto copy = a; return copy += n; } - friend It operator+(difference_type n, const It &b) { return b + n; } - friend It operator-(const It &a, difference_type n) { auto copy = a; return copy -= n; } - friend difference_type operator-(const It &a, const It &b) { return a.distance_to(b); } - - friend bool operator==(const It &a, const It &b) { return a.equal(b); } - friend bool operator!=(const It &a, const It &b) { return !(a == b); } - friend bool operator< (const It &a, const It &b) { return b - a > 0; } - friend bool operator> (const It &a, const It &b) { return b < a; } - friend bool operator>=(const It &a, const It &b) { return !(a < b); } - friend bool operator<=(const It &a, const It &b) { return !(a > b); } -}; - -NAMESPACE_BEGIN(iterator_policies) -/// Quick proxy class needed to implement ``operator->`` for iterators which can't return pointers -template -struct arrow_proxy { - T value; - - arrow_proxy(T &&value) : value(std::move(value)) { } - T *operator->() const { return &value; } -}; - -/// Lightweight iterator policy using just a simple pointer: see ``PySequence_Fast_ITEMS`` -class sequence_fast_readonly { -protected: - using iterator_category = std::random_access_iterator_tag; - using value_type = handle; - using reference = const handle; - using pointer = arrow_proxy; - - sequence_fast_readonly(handle obj, ssize_t n) : ptr(PySequence_Fast_ITEMS(obj.ptr()) + n) { } - - reference dereference() const { return *ptr; } - void increment() { ++ptr; } - void decrement() { --ptr; } - void advance(ssize_t n) { ptr += n; } - bool equal(const sequence_fast_readonly &b) const { return ptr == b.ptr; } - ssize_t distance_to(const sequence_fast_readonly &b) const { return ptr - b.ptr; } - -private: - PyObject **ptr; -}; - -/// Full read and write access using the sequence protocol: see ``detail::sequence_accessor`` -class sequence_slow_readwrite { -protected: - using iterator_category = std::random_access_iterator_tag; - using value_type = object; - using reference = sequence_accessor; - using pointer = arrow_proxy; - - sequence_slow_readwrite(handle obj, ssize_t index) : obj(obj), index(index) { } - - reference dereference() const { return {obj, static_cast(index)}; } - void increment() { ++index; } - void decrement() { --index; } - void advance(ssize_t n) { index += n; } - bool equal(const sequence_slow_readwrite &b) const { return index == b.index; } - ssize_t distance_to(const sequence_slow_readwrite &b) const { return index - b.index; } - -private: - handle obj; - ssize_t index; -}; - -/// Python's dictionary protocol permits this to be a forward iterator -class dict_readonly { -protected: - using iterator_category = std::forward_iterator_tag; - using value_type = std::pair; - using reference = const value_type; - using pointer = arrow_proxy; - - dict_readonly() = default; - dict_readonly(handle obj, ssize_t pos) : obj(obj), pos(pos) { increment(); } - - reference dereference() const { return {key, value}; } - void increment() { if (!PyDict_Next(obj.ptr(), &pos, &key, &value)) { pos = -1; } } - bool equal(const dict_readonly &b) const { return pos == b.pos; } - -private: - handle obj; - PyObject *key = nullptr, *value = nullptr; - ssize_t pos = -1; -}; -NAMESPACE_END(iterator_policies) - -#if !defined(PYPY_VERSION) -using tuple_iterator = generic_iterator; -using list_iterator = generic_iterator; -#else -using tuple_iterator = generic_iterator; -using list_iterator = generic_iterator; -#endif - -using sequence_iterator = generic_iterator; -using dict_iterator = generic_iterator; - -inline bool PyIterable_Check(PyObject *obj) { - PyObject *iter = PyObject_GetIter(obj); - if (iter) { - Py_DECREF(iter); - return true; - } else { - PyErr_Clear(); - return false; - } -} - -inline bool PyNone_Check(PyObject *o) { return o == Py_None; } -#if PY_MAJOR_VERSION >= 3 -inline bool PyEllipsis_Check(PyObject *o) { return o == Py_Ellipsis; } -#endif - -inline bool PyUnicode_Check_Permissive(PyObject *o) { return PyUnicode_Check(o) || PYBIND11_BYTES_CHECK(o); } - -inline bool PyStaticMethod_Check(PyObject *o) { return o->ob_type == &PyStaticMethod_Type; } - -class kwargs_proxy : public handle { -public: - explicit kwargs_proxy(handle h) : handle(h) { } -}; - -class args_proxy : public handle { -public: - explicit args_proxy(handle h) : handle(h) { } - kwargs_proxy operator*() const { return kwargs_proxy(*this); } -}; - -/// Python argument categories (using PEP 448 terms) -template using is_keyword = std::is_base_of; -template using is_s_unpacking = std::is_same; // * unpacking -template using is_ds_unpacking = std::is_same; // ** unpacking -template using is_positional = satisfies_none_of; -template using is_keyword_or_ds = satisfies_any_of; - -// Call argument collector forward declarations -template -class simple_collector; -template -class unpacking_collector; - -NAMESPACE_END(detail) - -// TODO: After the deprecated constructors are removed, this macro can be simplified by -// inheriting ctors: `using Parent::Parent`. It's not an option right now because -// the `using` statement triggers the parent deprecation warning even if the ctor -// isn't even used. -#define PYBIND11_OBJECT_COMMON(Name, Parent, CheckFun) \ - public: \ - PYBIND11_DEPRECATED("Use reinterpret_borrow<"#Name">() or reinterpret_steal<"#Name">()") \ - Name(handle h, bool is_borrowed) : Parent(is_borrowed ? Parent(h, borrowed_t{}) : Parent(h, stolen_t{})) { } \ - Name(handle h, borrowed_t) : Parent(h, borrowed_t{}) { } \ - Name(handle h, stolen_t) : Parent(h, stolen_t{}) { } \ - PYBIND11_DEPRECATED("Use py::isinstance(obj) instead") \ - bool check() const { return m_ptr != nullptr && (bool) CheckFun(m_ptr); } \ - static bool check_(handle h) { return h.ptr() != nullptr && CheckFun(h.ptr()); } - -#define PYBIND11_OBJECT_CVT(Name, Parent, CheckFun, ConvertFun) \ - PYBIND11_OBJECT_COMMON(Name, Parent, CheckFun) \ - /* This is deliberately not 'explicit' to allow implicit conversion from object: */ \ - Name(const object &o) \ - : Parent(check_(o) ? o.inc_ref().ptr() : ConvertFun(o.ptr()), stolen_t{}) \ - { if (!m_ptr) throw error_already_set(); } \ - Name(object &&o) \ - : Parent(check_(o) ? o.release().ptr() : ConvertFun(o.ptr()), stolen_t{}) \ - { if (!m_ptr) throw error_already_set(); } \ - template \ - Name(const ::pybind11::detail::accessor &a) : Name(object(a)) { } - -#define PYBIND11_OBJECT(Name, Parent, CheckFun) \ - PYBIND11_OBJECT_COMMON(Name, Parent, CheckFun) \ - /* This is deliberately not 'explicit' to allow implicit conversion from object: */ \ - Name(const object &o) : Parent(o) { } \ - Name(object &&o) : Parent(std::move(o)) { } - -#define PYBIND11_OBJECT_DEFAULT(Name, Parent, CheckFun) \ - PYBIND11_OBJECT(Name, Parent, CheckFun) \ - Name() : Parent() { } - -/// \addtogroup pytypes -/// @{ - -/** \rst - Wraps a Python iterator so that it can also be used as a C++ input iterator - - Caveat: copying an iterator does not (and cannot) clone the internal - state of the Python iterable. This also applies to the post-increment - operator. This iterator should only be used to retrieve the current - value using ``operator*()``. -\endrst */ -class iterator : public object { -public: - using iterator_category = std::input_iterator_tag; - using difference_type = ssize_t; - using value_type = handle; - using reference = const handle; - using pointer = const handle *; - - PYBIND11_OBJECT_DEFAULT(iterator, object, PyIter_Check) - - iterator& operator++() { - advance(); - return *this; - } - - iterator operator++(int) { - auto rv = *this; - advance(); - return rv; - } - - reference operator*() const { - if (m_ptr && !value.ptr()) { - auto& self = const_cast(*this); - self.advance(); - } - return value; - } - - pointer operator->() const { operator*(); return &value; } - - /** \rst - The value which marks the end of the iteration. ``it == iterator::sentinel()`` - is equivalent to catching ``StopIteration`` in Python. - - .. code-block:: cpp - - void foo(py::iterator it) { - while (it != py::iterator::sentinel()) { - // use `*it` - ++it; - } - } - \endrst */ - static iterator sentinel() { return {}; } - - friend bool operator==(const iterator &a, const iterator &b) { return a->ptr() == b->ptr(); } - friend bool operator!=(const iterator &a, const iterator &b) { return a->ptr() != b->ptr(); } - -private: - void advance() { - value = reinterpret_steal(PyIter_Next(m_ptr)); - if (PyErr_Occurred()) { throw error_already_set(); } - } - -private: - object value = {}; -}; - -class iterable : public object { -public: - PYBIND11_OBJECT_DEFAULT(iterable, object, detail::PyIterable_Check) -}; - -class bytes; - -class str : public object { -public: - PYBIND11_OBJECT_CVT(str, object, detail::PyUnicode_Check_Permissive, raw_str) - - str(const char *c, size_t n) - : object(PyUnicode_FromStringAndSize(c, (ssize_t) n), stolen_t{}) { - if (!m_ptr) pybind11_fail("Could not allocate string object!"); - } - - // 'explicit' is explicitly omitted from the following constructors to allow implicit conversion to py::str from C++ string-like objects - str(const char *c = "") - : object(PyUnicode_FromString(c), stolen_t{}) { - if (!m_ptr) pybind11_fail("Could not allocate string object!"); - } - - str(const std::string &s) : str(s.data(), s.size()) { } - - explicit str(const bytes &b); - - /** \rst - Return a string representation of the object. This is analogous to - the ``str()`` function in Python. - \endrst */ - explicit str(handle h) : object(raw_str(h.ptr()), stolen_t{}) { } - - operator std::string() const { - object temp = *this; - if (PyUnicode_Check(m_ptr)) { - temp = reinterpret_steal(PyUnicode_AsUTF8String(m_ptr)); - if (!temp) - pybind11_fail("Unable to extract string contents! (encoding issue)"); - } - char *buffer; - ssize_t length; - if (PYBIND11_BYTES_AS_STRING_AND_SIZE(temp.ptr(), &buffer, &length)) - pybind11_fail("Unable to extract string contents! (invalid type)"); - return std::string(buffer, (size_t) length); - } - - template - str format(Args &&...args) const { - return attr("format")(std::forward(args)...); - } - -private: - /// Return string representation -- always returns a new reference, even if already a str - static PyObject *raw_str(PyObject *op) { - PyObject *str_value = PyObject_Str(op); -#if PY_MAJOR_VERSION < 3 - if (!str_value) throw error_already_set(); - PyObject *unicode = PyUnicode_FromEncodedObject(str_value, "utf-8", nullptr); - Py_XDECREF(str_value); str_value = unicode; -#endif - return str_value; - } -}; -/// @} pytypes - -inline namespace literals { -/** \rst - String literal version of `str` - \endrst */ -inline str operator"" _s(const char *s, size_t size) { return {s, size}; } -} - -/// \addtogroup pytypes -/// @{ -class bytes : public object { -public: - PYBIND11_OBJECT(bytes, object, PYBIND11_BYTES_CHECK) - - // Allow implicit conversion: - bytes(const char *c = "") - : object(PYBIND11_BYTES_FROM_STRING(c), stolen_t{}) { - if (!m_ptr) pybind11_fail("Could not allocate bytes object!"); - } - - bytes(const char *c, size_t n) - : object(PYBIND11_BYTES_FROM_STRING_AND_SIZE(c, (ssize_t) n), stolen_t{}) { - if (!m_ptr) pybind11_fail("Could not allocate bytes object!"); - } - - // Allow implicit conversion: - bytes(const std::string &s) : bytes(s.data(), s.size()) { } - - explicit bytes(const pybind11::str &s); - - operator std::string() const { - char *buffer; - ssize_t length; - if (PYBIND11_BYTES_AS_STRING_AND_SIZE(m_ptr, &buffer, &length)) - pybind11_fail("Unable to extract bytes contents!"); - return std::string(buffer, (size_t) length); - } -}; - -inline bytes::bytes(const pybind11::str &s) { - object temp = s; - if (PyUnicode_Check(s.ptr())) { - temp = reinterpret_steal(PyUnicode_AsUTF8String(s.ptr())); - if (!temp) - pybind11_fail("Unable to extract string contents! (encoding issue)"); - } - char *buffer; - ssize_t length; - if (PYBIND11_BYTES_AS_STRING_AND_SIZE(temp.ptr(), &buffer, &length)) - pybind11_fail("Unable to extract string contents! (invalid type)"); - auto obj = reinterpret_steal(PYBIND11_BYTES_FROM_STRING_AND_SIZE(buffer, length)); - if (!obj) - pybind11_fail("Could not allocate bytes object!"); - m_ptr = obj.release().ptr(); -} - -inline str::str(const bytes& b) { - char *buffer; - ssize_t length; - if (PYBIND11_BYTES_AS_STRING_AND_SIZE(b.ptr(), &buffer, &length)) - pybind11_fail("Unable to extract bytes contents!"); - auto obj = reinterpret_steal(PyUnicode_FromStringAndSize(buffer, (ssize_t) length)); - if (!obj) - pybind11_fail("Could not allocate string object!"); - m_ptr = obj.release().ptr(); -} - -class none : public object { -public: - PYBIND11_OBJECT(none, object, detail::PyNone_Check) - none() : object(Py_None, borrowed_t{}) { } -}; - -#if PY_MAJOR_VERSION >= 3 -class ellipsis : public object { -public: - PYBIND11_OBJECT(ellipsis, object, detail::PyEllipsis_Check) - ellipsis() : object(Py_Ellipsis, borrowed_t{}) { } -}; -#endif - -class bool_ : public object { -public: - PYBIND11_OBJECT_CVT(bool_, object, PyBool_Check, raw_bool) - bool_() : object(Py_False, borrowed_t{}) { } - // Allow implicit conversion from and to `bool`: - bool_(bool value) : object(value ? Py_True : Py_False, borrowed_t{}) { } - operator bool() const { return m_ptr && PyLong_AsLong(m_ptr) != 0; } - -private: - /// Return the truth value of an object -- always returns a new reference - static PyObject *raw_bool(PyObject *op) { - const auto value = PyObject_IsTrue(op); - if (value == -1) return nullptr; - return handle(value ? Py_True : Py_False).inc_ref().ptr(); - } -}; - -NAMESPACE_BEGIN(detail) -// Converts a value to the given unsigned type. If an error occurs, you get back (Unsigned) -1; -// otherwise you get back the unsigned long or unsigned long long value cast to (Unsigned). -// (The distinction is critically important when casting a returned -1 error value to some other -// unsigned type: (A)-1 != (B)-1 when A and B are unsigned types of different sizes). -template -Unsigned as_unsigned(PyObject *o) { - if (sizeof(Unsigned) <= sizeof(unsigned long) -#if PY_VERSION_HEX < 0x03000000 - || PyInt_Check(o) -#endif - ) { - unsigned long v = PyLong_AsUnsignedLong(o); - return v == (unsigned long) -1 && PyErr_Occurred() ? (Unsigned) -1 : (Unsigned) v; - } - else { - unsigned long long v = PyLong_AsUnsignedLongLong(o); - return v == (unsigned long long) -1 && PyErr_Occurred() ? (Unsigned) -1 : (Unsigned) v; - } -} -NAMESPACE_END(detail) - -class int_ : public object { -public: - PYBIND11_OBJECT_CVT(int_, object, PYBIND11_LONG_CHECK, PyNumber_Long) - int_() : object(PyLong_FromLong(0), stolen_t{}) { } - // Allow implicit conversion from C++ integral types: - template ::value, int> = 0> - int_(T value) { - if (sizeof(T) <= sizeof(long)) { - if (std::is_signed::value) - m_ptr = PyLong_FromLong((long) value); - else - m_ptr = PyLong_FromUnsignedLong((unsigned long) value); - } else { - if (std::is_signed::value) - m_ptr = PyLong_FromLongLong((long long) value); - else - m_ptr = PyLong_FromUnsignedLongLong((unsigned long long) value); - } - if (!m_ptr) pybind11_fail("Could not allocate int object!"); - } - - template ::value, int> = 0> - operator T() const { - return std::is_unsigned::value - ? detail::as_unsigned(m_ptr) - : sizeof(T) <= sizeof(long) - ? (T) PyLong_AsLong(m_ptr) - : (T) PYBIND11_LONG_AS_LONGLONG(m_ptr); - } -}; - -class float_ : public object { -public: - PYBIND11_OBJECT_CVT(float_, object, PyFloat_Check, PyNumber_Float) - // Allow implicit conversion from float/double: - float_(float value) : object(PyFloat_FromDouble((double) value), stolen_t{}) { - if (!m_ptr) pybind11_fail("Could not allocate float object!"); - } - float_(double value = .0) : object(PyFloat_FromDouble((double) value), stolen_t{}) { - if (!m_ptr) pybind11_fail("Could not allocate float object!"); - } - operator float() const { return (float) PyFloat_AsDouble(m_ptr); } - operator double() const { return (double) PyFloat_AsDouble(m_ptr); } -}; - -class weakref : public object { -public: - PYBIND11_OBJECT_DEFAULT(weakref, object, PyWeakref_Check) - explicit weakref(handle obj, handle callback = {}) - : object(PyWeakref_NewRef(obj.ptr(), callback.ptr()), stolen_t{}) { - if (!m_ptr) pybind11_fail("Could not allocate weak reference!"); - } -}; - -class slice : public object { -public: - PYBIND11_OBJECT_DEFAULT(slice, object, PySlice_Check) - slice(ssize_t start_, ssize_t stop_, ssize_t step_) { - int_ start(start_), stop(stop_), step(step_); - m_ptr = PySlice_New(start.ptr(), stop.ptr(), step.ptr()); - if (!m_ptr) pybind11_fail("Could not allocate slice object!"); - } - bool compute(size_t length, size_t *start, size_t *stop, size_t *step, - size_t *slicelength) const { - return PySlice_GetIndicesEx((PYBIND11_SLICE_OBJECT *) m_ptr, - (ssize_t) length, (ssize_t *) start, - (ssize_t *) stop, (ssize_t *) step, - (ssize_t *) slicelength) == 0; - } - bool compute(ssize_t length, ssize_t *start, ssize_t *stop, ssize_t *step, - ssize_t *slicelength) const { - return PySlice_GetIndicesEx((PYBIND11_SLICE_OBJECT *) m_ptr, - length, start, - stop, step, - slicelength) == 0; - } -}; - -class capsule : public object { -public: - PYBIND11_OBJECT_DEFAULT(capsule, object, PyCapsule_CheckExact) - PYBIND11_DEPRECATED("Use reinterpret_borrow() or reinterpret_steal()") - capsule(PyObject *ptr, bool is_borrowed) : object(is_borrowed ? object(ptr, borrowed_t{}) : object(ptr, stolen_t{})) { } - - explicit capsule(const void *value, const char *name = nullptr, void (*destructor)(PyObject *) = nullptr) - : object(PyCapsule_New(const_cast(value), name, destructor), stolen_t{}) { - if (!m_ptr) - pybind11_fail("Could not allocate capsule object!"); - } - - PYBIND11_DEPRECATED("Please pass a destructor that takes a void pointer as input") - capsule(const void *value, void (*destruct)(PyObject *)) - : object(PyCapsule_New(const_cast(value), nullptr, destruct), stolen_t{}) { - if (!m_ptr) - pybind11_fail("Could not allocate capsule object!"); - } - - capsule(const void *value, void (*destructor)(void *)) { - m_ptr = PyCapsule_New(const_cast(value), nullptr, [](PyObject *o) { - auto destructor = reinterpret_cast(PyCapsule_GetContext(o)); - void *ptr = PyCapsule_GetPointer(o, nullptr); - destructor(ptr); - }); - - if (!m_ptr) - pybind11_fail("Could not allocate capsule object!"); - - if (PyCapsule_SetContext(m_ptr, (void *) destructor) != 0) - pybind11_fail("Could not set capsule context!"); - } - - capsule(void (*destructor)()) { - m_ptr = PyCapsule_New(reinterpret_cast(destructor), nullptr, [](PyObject *o) { - auto destructor = reinterpret_cast(PyCapsule_GetPointer(o, nullptr)); - destructor(); - }); - - if (!m_ptr) - pybind11_fail("Could not allocate capsule object!"); - } - - template operator T *() const { - auto name = this->name(); - T * result = static_cast(PyCapsule_GetPointer(m_ptr, name)); - if (!result) pybind11_fail("Unable to extract capsule contents!"); - return result; - } - - const char *name() const { return PyCapsule_GetName(m_ptr); } -}; - -class tuple : public object { -public: - PYBIND11_OBJECT_CVT(tuple, object, PyTuple_Check, PySequence_Tuple) - explicit tuple(size_t size = 0) : object(PyTuple_New((ssize_t) size), stolen_t{}) { - if (!m_ptr) pybind11_fail("Could not allocate tuple object!"); - } - size_t size() const { return (size_t) PyTuple_Size(m_ptr); } - bool empty() const { return size() == 0; } - detail::tuple_accessor operator[](size_t index) const { return {*this, index}; } - detail::item_accessor operator[](handle h) const { return object::operator[](h); } - detail::tuple_iterator begin() const { return {*this, 0}; } - detail::tuple_iterator end() const { return {*this, PyTuple_GET_SIZE(m_ptr)}; } -}; - -class dict : public object { -public: - PYBIND11_OBJECT_CVT(dict, object, PyDict_Check, raw_dict) - dict() : object(PyDict_New(), stolen_t{}) { - if (!m_ptr) pybind11_fail("Could not allocate dict object!"); - } - template ...>::value>, - // MSVC workaround: it can't compile an out-of-line definition, so defer the collector - typename collector = detail::deferred_t, Args...>> - explicit dict(Args &&...args) : dict(collector(std::forward(args)...).kwargs()) { } - - size_t size() const { return (size_t) PyDict_Size(m_ptr); } - bool empty() const { return size() == 0; } - detail::dict_iterator begin() const { return {*this, 0}; } - detail::dict_iterator end() const { return {}; } - void clear() const { PyDict_Clear(ptr()); } - template bool contains(T &&key) const { - return PyDict_Contains(m_ptr, detail::object_or_cast(std::forward(key)).ptr()) == 1; - } - -private: - /// Call the `dict` Python type -- always returns a new reference - static PyObject *raw_dict(PyObject *op) { - if (PyDict_Check(op)) - return handle(op).inc_ref().ptr(); - return PyObject_CallFunctionObjArgs((PyObject *) &PyDict_Type, op, nullptr); - } -}; - -class sequence : public object { -public: - PYBIND11_OBJECT_DEFAULT(sequence, object, PySequence_Check) - size_t size() const { return (size_t) PySequence_Size(m_ptr); } - bool empty() const { return size() == 0; } - detail::sequence_accessor operator[](size_t index) const { return {*this, index}; } - detail::item_accessor operator[](handle h) const { return object::operator[](h); } - detail::sequence_iterator begin() const { return {*this, 0}; } - detail::sequence_iterator end() const { return {*this, PySequence_Size(m_ptr)}; } -}; - -class list : public object { -public: - PYBIND11_OBJECT_CVT(list, object, PyList_Check, PySequence_List) - explicit list(size_t size = 0) : object(PyList_New((ssize_t) size), stolen_t{}) { - if (!m_ptr) pybind11_fail("Could not allocate list object!"); - } - size_t size() const { return (size_t) PyList_Size(m_ptr); } - bool empty() const { return size() == 0; } - detail::list_accessor operator[](size_t index) const { return {*this, index}; } - detail::item_accessor operator[](handle h) const { return object::operator[](h); } - detail::list_iterator begin() const { return {*this, 0}; } - detail::list_iterator end() const { return {*this, PyList_GET_SIZE(m_ptr)}; } - template void append(T &&val) const { - PyList_Append(m_ptr, detail::object_or_cast(std::forward(val)).ptr()); - } - template void insert(size_t index, T &&val) const { - PyList_Insert(m_ptr, static_cast(index), - detail::object_or_cast(std::forward(val)).ptr()); - } -}; - -class args : public tuple { PYBIND11_OBJECT_DEFAULT(args, tuple, PyTuple_Check) }; -class kwargs : public dict { PYBIND11_OBJECT_DEFAULT(kwargs, dict, PyDict_Check) }; - -class set : public object { -public: - PYBIND11_OBJECT_CVT(set, object, PySet_Check, PySet_New) - set() : object(PySet_New(nullptr), stolen_t{}) { - if (!m_ptr) pybind11_fail("Could not allocate set object!"); - } - size_t size() const { return (size_t) PySet_Size(m_ptr); } - bool empty() const { return size() == 0; } - template bool add(T &&val) const { - return PySet_Add(m_ptr, detail::object_or_cast(std::forward(val)).ptr()) == 0; - } - void clear() const { PySet_Clear(m_ptr); } - template bool contains(T &&val) const { - return PySet_Contains(m_ptr, detail::object_or_cast(std::forward(val)).ptr()) == 1; - } -}; - -class function : public object { -public: - PYBIND11_OBJECT_DEFAULT(function, object, PyCallable_Check) - handle cpp_function() const { - handle fun = detail::get_function(m_ptr); - if (fun && PyCFunction_Check(fun.ptr())) - return fun; - return handle(); - } - bool is_cpp_function() const { return (bool) cpp_function(); } -}; - -class staticmethod : public object { -public: - PYBIND11_OBJECT_CVT(staticmethod, object, detail::PyStaticMethod_Check, PyStaticMethod_New) -}; - -class buffer : public object { -public: - PYBIND11_OBJECT_DEFAULT(buffer, object, PyObject_CheckBuffer) - - buffer_info request(bool writable = false) const { - int flags = PyBUF_STRIDES | PyBUF_FORMAT; - if (writable) flags |= PyBUF_WRITABLE; - Py_buffer *view = new Py_buffer(); - if (PyObject_GetBuffer(m_ptr, view, flags) != 0) { - delete view; - throw error_already_set(); - } - return buffer_info(view); - } -}; - -class memoryview : public object { -public: - explicit memoryview(const buffer_info& info) { - static Py_buffer buf { }; - // Py_buffer uses signed sizes, strides and shape!.. - static std::vector py_strides { }; - static std::vector py_shape { }; - buf.buf = info.ptr; - buf.itemsize = info.itemsize; - buf.format = const_cast(info.format.c_str()); - buf.ndim = (int) info.ndim; - buf.len = info.size; - py_strides.clear(); - py_shape.clear(); - for (size_t i = 0; i < (size_t) info.ndim; ++i) { - py_strides.push_back(info.strides[i]); - py_shape.push_back(info.shape[i]); - } - buf.strides = py_strides.data(); - buf.shape = py_shape.data(); - buf.suboffsets = nullptr; - buf.readonly = false; - buf.internal = nullptr; - - m_ptr = PyMemoryView_FromBuffer(&buf); - if (!m_ptr) - pybind11_fail("Unable to create memoryview from buffer descriptor"); - } - - PYBIND11_OBJECT_CVT(memoryview, object, PyMemoryView_Check, PyMemoryView_FromObject) -}; -/// @} pytypes - -/// \addtogroup python_builtins -/// @{ -inline size_t len(handle h) { - ssize_t result = PyObject_Length(h.ptr()); - if (result < 0) - pybind11_fail("Unable to compute length of object"); - return (size_t) result; -} - -inline size_t len_hint(handle h) { -#if PY_VERSION_HEX >= 0x03040000 - ssize_t result = PyObject_LengthHint(h.ptr(), 0); -#else - ssize_t result = PyObject_Length(h.ptr()); -#endif - if (result < 0) { - // Sometimes a length can't be determined at all (eg generators) - // In which case simply return 0 - PyErr_Clear(); - return 0; - } - return (size_t) result; -} - -inline str repr(handle h) { - PyObject *str_value = PyObject_Repr(h.ptr()); - if (!str_value) throw error_already_set(); -#if PY_MAJOR_VERSION < 3 - PyObject *unicode = PyUnicode_FromEncodedObject(str_value, "utf-8", nullptr); - Py_XDECREF(str_value); str_value = unicode; - if (!str_value) throw error_already_set(); -#endif - return reinterpret_steal(str_value); -} - -inline iterator iter(handle obj) { - PyObject *result = PyObject_GetIter(obj.ptr()); - if (!result) { throw error_already_set(); } - return reinterpret_steal(result); -} -/// @} python_builtins - -NAMESPACE_BEGIN(detail) -template iterator object_api::begin() const { return iter(derived()); } -template iterator object_api::end() const { return iterator::sentinel(); } -template item_accessor object_api::operator[](handle key) const { - return {derived(), reinterpret_borrow(key)}; -} -template item_accessor object_api::operator[](const char *key) const { - return {derived(), pybind11::str(key)}; -} -template obj_attr_accessor object_api::attr(handle key) const { - return {derived(), reinterpret_borrow(key)}; -} -template str_attr_accessor object_api::attr(const char *key) const { - return {derived(), key}; -} -template args_proxy object_api::operator*() const { - return args_proxy(derived().ptr()); -} -template template bool object_api::contains(T &&item) const { - return attr("__contains__")(std::forward(item)).template cast(); -} - -template -pybind11::str object_api::str() const { return pybind11::str(derived()); } - -template -str_attr_accessor object_api::doc() const { return attr("__doc__"); } - -template -handle object_api::get_type() const { return (PyObject *) Py_TYPE(derived().ptr()); } - -template -bool object_api::rich_compare(object_api const &other, int value) const { - int rv = PyObject_RichCompareBool(derived().ptr(), other.derived().ptr(), value); - if (rv == -1) - throw error_already_set(); - return rv == 1; -} - -#define PYBIND11_MATH_OPERATOR_UNARY(op, fn) \ - template object object_api::op() const { \ - object result = reinterpret_steal(fn(derived().ptr())); \ - if (!result.ptr()) \ - throw error_already_set(); \ - return result; \ - } - -#define PYBIND11_MATH_OPERATOR_BINARY(op, fn) \ - template \ - object object_api::op(object_api const &other) const { \ - object result = reinterpret_steal( \ - fn(derived().ptr(), other.derived().ptr())); \ - if (!result.ptr()) \ - throw error_already_set(); \ - return result; \ - } - -PYBIND11_MATH_OPERATOR_UNARY (operator~, PyNumber_Invert) -PYBIND11_MATH_OPERATOR_UNARY (operator-, PyNumber_Negative) -PYBIND11_MATH_OPERATOR_BINARY(operator+, PyNumber_Add) -PYBIND11_MATH_OPERATOR_BINARY(operator+=, PyNumber_InPlaceAdd) -PYBIND11_MATH_OPERATOR_BINARY(operator-, PyNumber_Subtract) -PYBIND11_MATH_OPERATOR_BINARY(operator-=, PyNumber_InPlaceSubtract) -PYBIND11_MATH_OPERATOR_BINARY(operator*, PyNumber_Multiply) -PYBIND11_MATH_OPERATOR_BINARY(operator*=, PyNumber_InPlaceMultiply) -PYBIND11_MATH_OPERATOR_BINARY(operator/, PyNumber_TrueDivide) -PYBIND11_MATH_OPERATOR_BINARY(operator/=, PyNumber_InPlaceTrueDivide) -PYBIND11_MATH_OPERATOR_BINARY(operator|, PyNumber_Or) -PYBIND11_MATH_OPERATOR_BINARY(operator|=, PyNumber_InPlaceOr) -PYBIND11_MATH_OPERATOR_BINARY(operator&, PyNumber_And) -PYBIND11_MATH_OPERATOR_BINARY(operator&=, PyNumber_InPlaceAnd) -PYBIND11_MATH_OPERATOR_BINARY(operator^, PyNumber_Xor) -PYBIND11_MATH_OPERATOR_BINARY(operator^=, PyNumber_InPlaceXor) -PYBIND11_MATH_OPERATOR_BINARY(operator<<, PyNumber_Lshift) -PYBIND11_MATH_OPERATOR_BINARY(operator<<=, PyNumber_InPlaceLshift) -PYBIND11_MATH_OPERATOR_BINARY(operator>>, PyNumber_Rshift) -PYBIND11_MATH_OPERATOR_BINARY(operator>>=, PyNumber_InPlaceRshift) - -#undef PYBIND11_MATH_OPERATOR_UNARY -#undef PYBIND11_MATH_OPERATOR_BINARY - -NAMESPACE_END(detail) -NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/pybind11/include/pybind11/stl.h b/pybind11/include/pybind11/stl.h deleted file mode 100644 index 32f8d29..0000000 --- a/pybind11/include/pybind11/stl.h +++ /dev/null @@ -1,386 +0,0 @@ -/* - pybind11/stl.h: Transparent conversion for STL data types - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#pragma once - -#include "pybind11.h" -#include -#include -#include -#include -#include -#include -#include -#include - -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable: 4127) // warning C4127: Conditional expression is constant -#endif - -#ifdef __has_include -// std::optional (but including it in c++14 mode isn't allowed) -# if defined(PYBIND11_CPP17) && __has_include() -# include -# define PYBIND11_HAS_OPTIONAL 1 -# endif -// std::experimental::optional (but not allowed in c++11 mode) -# if defined(PYBIND11_CPP14) && (__has_include() && \ - !__has_include()) -# include -# define PYBIND11_HAS_EXP_OPTIONAL 1 -# endif -// std::variant -# if defined(PYBIND11_CPP17) && __has_include() -# include -# define PYBIND11_HAS_VARIANT 1 -# endif -#elif defined(_MSC_VER) && defined(PYBIND11_CPP17) -# include -# include -# define PYBIND11_HAS_OPTIONAL 1 -# define PYBIND11_HAS_VARIANT 1 -#endif - -NAMESPACE_BEGIN(PYBIND11_NAMESPACE) -NAMESPACE_BEGIN(detail) - -/// Extracts an const lvalue reference or rvalue reference for U based on the type of T (e.g. for -/// forwarding a container element). Typically used indirect via forwarded_type(), below. -template -using forwarded_type = conditional_t< - std::is_lvalue_reference::value, remove_reference_t &, remove_reference_t &&>; - -/// Forwards a value U as rvalue or lvalue according to whether T is rvalue or lvalue; typically -/// used for forwarding a container's elements. -template -forwarded_type forward_like(U &&u) { - return std::forward>(std::forward(u)); -} - -template struct set_caster { - using type = Type; - using key_conv = make_caster; - - bool load(handle src, bool convert) { - if (!isinstance(src)) - return false; - auto s = reinterpret_borrow(src); - value.clear(); - for (auto entry : s) { - key_conv conv; - if (!conv.load(entry, convert)) - return false; - value.insert(cast_op(std::move(conv))); - } - return true; - } - - template - static handle cast(T &&src, return_value_policy policy, handle parent) { - if (!std::is_lvalue_reference::value) - policy = return_value_policy_override::policy(policy); - pybind11::set s; - for (auto &&value : src) { - auto value_ = reinterpret_steal(key_conv::cast(forward_like(value), policy, parent)); - if (!value_ || !s.add(value_)) - return handle(); - } - return s.release(); - } - - PYBIND11_TYPE_CASTER(type, _("Set[") + key_conv::name + _("]")); -}; - -template struct map_caster { - using key_conv = make_caster; - using value_conv = make_caster; - - bool load(handle src, bool convert) { - if (!isinstance(src)) - return false; - auto d = reinterpret_borrow(src); - value.clear(); - for (auto it : d) { - key_conv kconv; - value_conv vconv; - if (!kconv.load(it.first.ptr(), convert) || - !vconv.load(it.second.ptr(), convert)) - return false; - value.emplace(cast_op(std::move(kconv)), cast_op(std::move(vconv))); - } - return true; - } - - template - static handle cast(T &&src, return_value_policy policy, handle parent) { - dict d; - return_value_policy policy_key = policy; - return_value_policy policy_value = policy; - if (!std::is_lvalue_reference::value) { - policy_key = return_value_policy_override::policy(policy_key); - policy_value = return_value_policy_override::policy(policy_value); - } - for (auto &&kv : src) { - auto key = reinterpret_steal(key_conv::cast(forward_like(kv.first), policy_key, parent)); - auto value = reinterpret_steal(value_conv::cast(forward_like(kv.second), policy_value, parent)); - if (!key || !value) - return handle(); - d[key] = value; - } - return d.release(); - } - - PYBIND11_TYPE_CASTER(Type, _("Dict[") + key_conv::name + _(", ") + value_conv::name + _("]")); -}; - -template struct list_caster { - using value_conv = make_caster; - - bool load(handle src, bool convert) { - if (!isinstance(src) || isinstance(src)) - return false; - auto s = reinterpret_borrow(src); - value.clear(); - reserve_maybe(s, &value); - for (auto it : s) { - value_conv conv; - if (!conv.load(it, convert)) - return false; - value.push_back(cast_op(std::move(conv))); - } - return true; - } - -private: - template ().reserve(0)), void>::value, int> = 0> - void reserve_maybe(sequence s, Type *) { value.reserve(s.size()); } - void reserve_maybe(sequence, void *) { } - -public: - template - static handle cast(T &&src, return_value_policy policy, handle parent) { - if (!std::is_lvalue_reference::value) - policy = return_value_policy_override::policy(policy); - list l(src.size()); - size_t index = 0; - for (auto &&value : src) { - auto value_ = reinterpret_steal(value_conv::cast(forward_like(value), policy, parent)); - if (!value_) - return handle(); - PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference - } - return l.release(); - } - - PYBIND11_TYPE_CASTER(Type, _("List[") + value_conv::name + _("]")); -}; - -template struct type_caster> - : list_caster, Type> { }; - -template struct type_caster> - : list_caster, Type> { }; - -template struct type_caster> - : list_caster, Type> { }; - -template struct array_caster { - using value_conv = make_caster; - -private: - template - bool require_size(enable_if_t size) { - if (value.size() != size) - value.resize(size); - return true; - } - template - bool require_size(enable_if_t size) { - return size == Size; - } - -public: - bool load(handle src, bool convert) { - if (!isinstance(src)) - return false; - auto l = reinterpret_borrow(src); - if (!require_size(l.size())) - return false; - size_t ctr = 0; - for (auto it : l) { - value_conv conv; - if (!conv.load(it, convert)) - return false; - value[ctr++] = cast_op(std::move(conv)); - } - return true; - } - - template - static handle cast(T &&src, return_value_policy policy, handle parent) { - list l(src.size()); - size_t index = 0; - for (auto &&value : src) { - auto value_ = reinterpret_steal(value_conv::cast(forward_like(value), policy, parent)); - if (!value_) - return handle(); - PyList_SET_ITEM(l.ptr(), (ssize_t) index++, value_.release().ptr()); // steals a reference - } - return l.release(); - } - - PYBIND11_TYPE_CASTER(ArrayType, _("List[") + value_conv::name + _(_(""), _("[") + _() + _("]")) + _("]")); -}; - -template struct type_caster> - : array_caster, Type, false, Size> { }; - -template struct type_caster> - : array_caster, Type, true> { }; - -template struct type_caster> - : set_caster, Key> { }; - -template struct type_caster> - : set_caster, Key> { }; - -template struct type_caster> - : map_caster, Key, Value> { }; - -template struct type_caster> - : map_caster, Key, Value> { }; - -// This type caster is intended to be used for std::optional and std::experimental::optional -template struct optional_caster { - using value_conv = make_caster; - - template - static handle cast(T_ &&src, return_value_policy policy, handle parent) { - if (!src) - return none().inc_ref(); - policy = return_value_policy_override::policy(policy); - return value_conv::cast(*std::forward(src), policy, parent); - } - - bool load(handle src, bool convert) { - if (!src) { - return false; - } else if (src.is_none()) { - return true; // default-constructed value is already empty - } - value_conv inner_caster; - if (!inner_caster.load(src, convert)) - return false; - - value.emplace(cast_op(std::move(inner_caster))); - return true; - } - - PYBIND11_TYPE_CASTER(T, _("Optional[") + value_conv::name + _("]")); -}; - -#if PYBIND11_HAS_OPTIONAL -template struct type_caster> - : public optional_caster> {}; - -template<> struct type_caster - : public void_caster {}; -#endif - -#if PYBIND11_HAS_EXP_OPTIONAL -template struct type_caster> - : public optional_caster> {}; - -template<> struct type_caster - : public void_caster {}; -#endif - -/// Visit a variant and cast any found type to Python -struct variant_caster_visitor { - return_value_policy policy; - handle parent; - - using result_type = handle; // required by boost::variant in C++11 - - template - result_type operator()(T &&src) const { - return make_caster::cast(std::forward(src), policy, parent); - } -}; - -/// Helper class which abstracts away variant's `visit` function. `std::variant` and similar -/// `namespace::variant` types which provide a `namespace::visit()` function are handled here -/// automatically using argument-dependent lookup. Users can provide specializations for other -/// variant-like classes, e.g. `boost::variant` and `boost::apply_visitor`. -template class Variant> -struct visit_helper { - template - static auto call(Args &&...args) -> decltype(visit(std::forward(args)...)) { - return visit(std::forward(args)...); - } -}; - -/// Generic variant caster -template struct variant_caster; - -template class V, typename... Ts> -struct variant_caster> { - static_assert(sizeof...(Ts) > 0, "Variant must consist of at least one alternative."); - - template - bool load_alternative(handle src, bool convert, type_list) { - auto caster = make_caster(); - if (caster.load(src, convert)) { - value = cast_op(caster); - return true; - } - return load_alternative(src, convert, type_list{}); - } - - bool load_alternative(handle, bool, type_list<>) { return false; } - - bool load(handle src, bool convert) { - // Do a first pass without conversions to improve constructor resolution. - // E.g. `py::int_(1).cast>()` needs to fill the `int` - // slot of the variant. Without two-pass loading `double` would be filled - // because it appears first and a conversion is possible. - if (convert && load_alternative(src, false, type_list{})) - return true; - return load_alternative(src, convert, type_list{}); - } - - template - static handle cast(Variant &&src, return_value_policy policy, handle parent) { - return visit_helper::call(variant_caster_visitor{policy, parent}, - std::forward(src)); - } - - using Type = V; - PYBIND11_TYPE_CASTER(Type, _("Union[") + detail::concat(make_caster::name...) + _("]")); -}; - -#if PYBIND11_HAS_VARIANT -template -struct type_caster> : variant_caster> { }; -#endif - -NAMESPACE_END(detail) - -inline std::ostream &operator<<(std::ostream &os, const handle &obj) { - os << (std::string) str(obj); - return os; -} - -NAMESPACE_END(PYBIND11_NAMESPACE) - -#if defined(_MSC_VER) -#pragma warning(pop) -#endif diff --git a/pybind11/include/pybind11/stl_bind.h b/pybind11/include/pybind11/stl_bind.h deleted file mode 100644 index d3adaed..0000000 --- a/pybind11/include/pybind11/stl_bind.h +++ /dev/null @@ -1,649 +0,0 @@ -/* - pybind11/std_bind.h: Binding generators for STL data types - - Copyright (c) 2016 Sergey Lyskov and Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#pragma once - -#include "detail/common.h" -#include "operators.h" - -#include -#include - -NAMESPACE_BEGIN(PYBIND11_NAMESPACE) -NAMESPACE_BEGIN(detail) - -/* SFINAE helper class used by 'is_comparable */ -template struct container_traits { - template static std::true_type test_comparable(decltype(std::declval() == std::declval())*); - template static std::false_type test_comparable(...); - template static std::true_type test_value(typename T2::value_type *); - template static std::false_type test_value(...); - template static std::true_type test_pair(typename T2::first_type *, typename T2::second_type *); - template static std::false_type test_pair(...); - - static constexpr const bool is_comparable = std::is_same(nullptr))>::value; - static constexpr const bool is_pair = std::is_same(nullptr, nullptr))>::value; - static constexpr const bool is_vector = std::is_same(nullptr))>::value; - static constexpr const bool is_element = !is_pair && !is_vector; -}; - -/* Default: is_comparable -> std::false_type */ -template -struct is_comparable : std::false_type { }; - -/* For non-map data structures, check whether operator== can be instantiated */ -template -struct is_comparable< - T, enable_if_t::is_element && - container_traits::is_comparable>> - : std::true_type { }; - -/* For a vector/map data structure, recursively check the value type (which is std::pair for maps) */ -template -struct is_comparable::is_vector>> { - static constexpr const bool value = - is_comparable::value; -}; - -/* For pairs, recursively check the two data types */ -template -struct is_comparable::is_pair>> { - static constexpr const bool value = - is_comparable::value && - is_comparable::value; -}; - -/* Fallback functions */ -template void vector_if_copy_constructible(const Args &...) { } -template void vector_if_equal_operator(const Args &...) { } -template void vector_if_insertion_operator(const Args &...) { } -template void vector_modifiers(const Args &...) { } - -template -void vector_if_copy_constructible(enable_if_t::value, Class_> &cl) { - cl.def(init(), "Copy constructor"); -} - -template -void vector_if_equal_operator(enable_if_t::value, Class_> &cl) { - using T = typename Vector::value_type; - - cl.def(self == self); - cl.def(self != self); - - cl.def("count", - [](const Vector &v, const T &x) { - return std::count(v.begin(), v.end(), x); - }, - arg("x"), - "Return the number of times ``x`` appears in the list" - ); - - cl.def("remove", [](Vector &v, const T &x) { - auto p = std::find(v.begin(), v.end(), x); - if (p != v.end()) - v.erase(p); - else - throw value_error(); - }, - arg("x"), - "Remove the first item from the list whose value is x. " - "It is an error if there is no such item." - ); - - cl.def("__contains__", - [](const Vector &v, const T &x) { - return std::find(v.begin(), v.end(), x) != v.end(); - }, - arg("x"), - "Return true the container contains ``x``" - ); -} - -// Vector modifiers -- requires a copyable vector_type: -// (Technically, some of these (pop and __delitem__) don't actually require copyability, but it seems -// silly to allow deletion but not insertion, so include them here too.) -template -void vector_modifiers(enable_if_t::value, Class_> &cl) { - using T = typename Vector::value_type; - using SizeType = typename Vector::size_type; - using DiffType = typename Vector::difference_type; - - auto wrap_i = [](DiffType i, SizeType n) { - if (i < 0) - i += n; - if (i < 0 || (SizeType)i >= n) - throw index_error(); - return i; - }; - - cl.def("append", - [](Vector &v, const T &value) { v.push_back(value); }, - arg("x"), - "Add an item to the end of the list"); - - cl.def(init([](iterable it) { - auto v = std::unique_ptr(new Vector()); - v->reserve(len_hint(it)); - for (handle h : it) - v->push_back(h.cast()); - return v.release(); - })); - - cl.def("extend", - [](Vector &v, const Vector &src) { - v.insert(v.end(), src.begin(), src.end()); - }, - arg("L"), - "Extend the list by appending all the items in the given list" - ); - - cl.def("extend", - [](Vector &v, iterable it) { - const size_t old_size = v.size(); - v.reserve(old_size + len_hint(it)); - try { - for (handle h : it) { - v.push_back(h.cast()); - } - } catch (const cast_error &) { - v.erase(v.begin() + static_cast(old_size), v.end()); - try { - v.shrink_to_fit(); - } catch (const std::exception &) { - // Do nothing - } - throw; - } - }, - arg("L"), - "Extend the list by appending all the items in the given list" - ); - - cl.def("insert", - [](Vector &v, DiffType i, const T &x) { - // Can't use wrap_i; i == v.size() is OK - if (i < 0) - i += v.size(); - if (i < 0 || (SizeType)i > v.size()) - throw index_error(); - v.insert(v.begin() + i, x); - }, - arg("i") , arg("x"), - "Insert an item at a given position." - ); - - cl.def("pop", - [](Vector &v) { - if (v.empty()) - throw index_error(); - T t = v.back(); - v.pop_back(); - return t; - }, - "Remove and return the last item" - ); - - cl.def("pop", - [wrap_i](Vector &v, DiffType i) { - i = wrap_i(i, v.size()); - T t = v[(SizeType) i]; - v.erase(v.begin() + i); - return t; - }, - arg("i"), - "Remove and return the item at index ``i``" - ); - - cl.def("__setitem__", - [wrap_i](Vector &v, DiffType i, const T &t) { - i = wrap_i(i, v.size()); - v[(SizeType)i] = t; - } - ); - - /// Slicing protocol - cl.def("__getitem__", - [](const Vector &v, slice slice) -> Vector * { - size_t start, stop, step, slicelength; - - if (!slice.compute(v.size(), &start, &stop, &step, &slicelength)) - throw error_already_set(); - - Vector *seq = new Vector(); - seq->reserve((size_t) slicelength); - - for (size_t i=0; ipush_back(v[start]); - start += step; - } - return seq; - }, - arg("s"), - "Retrieve list elements using a slice object" - ); - - cl.def("__setitem__", - [](Vector &v, slice slice, const Vector &value) { - size_t start, stop, step, slicelength; - if (!slice.compute(v.size(), &start, &stop, &step, &slicelength)) - throw error_already_set(); - - if (slicelength != value.size()) - throw std::runtime_error("Left and right hand size of slice assignment have different sizes!"); - - for (size_t i=0; i), -// we have to access by copying; otherwise we return by reference. -template using vector_needs_copy = negation< - std::is_same()[typename Vector::size_type()]), typename Vector::value_type &>>; - -// The usual case: access and iterate by reference -template -void vector_accessor(enable_if_t::value, Class_> &cl) { - using T = typename Vector::value_type; - using SizeType = typename Vector::size_type; - using DiffType = typename Vector::difference_type; - using ItType = typename Vector::iterator; - - auto wrap_i = [](DiffType i, SizeType n) { - if (i < 0) - i += n; - if (i < 0 || (SizeType)i >= n) - throw index_error(); - return i; - }; - - cl.def("__getitem__", - [wrap_i](Vector &v, DiffType i) -> T & { - i = wrap_i(i, v.size()); - return v[(SizeType)i]; - }, - return_value_policy::reference_internal // ref + keepalive - ); - - cl.def("__iter__", - [](Vector &v) { - return make_iterator< - return_value_policy::reference_internal, ItType, ItType, T&>( - v.begin(), v.end()); - }, - keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ - ); -} - -// The case for special objects, like std::vector, that have to be returned-by-copy: -template -void vector_accessor(enable_if_t::value, Class_> &cl) { - using T = typename Vector::value_type; - using SizeType = typename Vector::size_type; - using DiffType = typename Vector::difference_type; - using ItType = typename Vector::iterator; - cl.def("__getitem__", - [](const Vector &v, DiffType i) -> T { - if (i < 0 && (i += v.size()) < 0) - throw index_error(); - if ((SizeType)i >= v.size()) - throw index_error(); - return v[(SizeType)i]; - } - ); - - cl.def("__iter__", - [](Vector &v) { - return make_iterator< - return_value_policy::copy, ItType, ItType, T>( - v.begin(), v.end()); - }, - keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ - ); -} - -template auto vector_if_insertion_operator(Class_ &cl, std::string const &name) - -> decltype(std::declval() << std::declval(), void()) { - using size_type = typename Vector::size_type; - - cl.def("__repr__", - [name](Vector &v) { - std::ostringstream s; - s << name << '['; - for (size_type i=0; i < v.size(); ++i) { - s << v[i]; - if (i != v.size() - 1) - s << ", "; - } - s << ']'; - return s.str(); - }, - "Return the canonical string representation of this list." - ); -} - -// Provide the buffer interface for vectors if we have data() and we have a format for it -// GCC seems to have "void std::vector::data()" - doing SFINAE on the existence of data() is insufficient, we need to check it returns an appropriate pointer -template -struct vector_has_data_and_format : std::false_type {}; -template -struct vector_has_data_and_format::format(), std::declval().data()), typename Vector::value_type*>::value>> : std::true_type {}; - -// Add the buffer interface to a vector -template -enable_if_t...>::value> -vector_buffer(Class_& cl) { - using T = typename Vector::value_type; - - static_assert(vector_has_data_and_format::value, "There is not an appropriate format descriptor for this vector"); - - // numpy.h declares this for arbitrary types, but it may raise an exception and crash hard at runtime if PYBIND11_NUMPY_DTYPE hasn't been called, so check here - format_descriptor::format(); - - cl.def_buffer([](Vector& v) -> buffer_info { - return buffer_info(v.data(), static_cast(sizeof(T)), format_descriptor::format(), 1, {v.size()}, {sizeof(T)}); - }); - - cl.def(init([](buffer buf) { - auto info = buf.request(); - if (info.ndim != 1 || info.strides[0] % static_cast(sizeof(T))) - throw type_error("Only valid 1D buffers can be copied to a vector"); - if (!detail::compare_buffer_info::compare(info) || (ssize_t) sizeof(T) != info.itemsize) - throw type_error("Format mismatch (Python: " + info.format + " C++: " + format_descriptor::format() + ")"); - - auto vec = std::unique_ptr(new Vector()); - vec->reserve((size_t) info.shape[0]); - T *p = static_cast(info.ptr); - ssize_t step = info.strides[0] / static_cast(sizeof(T)); - T *end = p + info.shape[0] * step; - for (; p != end; p += step) - vec->push_back(*p); - return vec.release(); - })); - - return; -} - -template -enable_if_t...>::value> vector_buffer(Class_&) {} - -NAMESPACE_END(detail) - -// -// std::vector -// -template , typename... Args> -class_ bind_vector(handle scope, std::string const &name, Args&&... args) { - using Class_ = class_; - - // If the value_type is unregistered (e.g. a converting type) or is itself registered - // module-local then make the vector binding module-local as well: - using vtype = typename Vector::value_type; - auto vtype_info = detail::get_type_info(typeid(vtype)); - bool local = !vtype_info || vtype_info->module_local; - - Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward(args)...); - - // Declare the buffer interface if a buffer_protocol() is passed in - detail::vector_buffer(cl); - - cl.def(init<>()); - - // Register copy constructor (if possible) - detail::vector_if_copy_constructible(cl); - - // Register comparison-related operators and functions (if possible) - detail::vector_if_equal_operator(cl); - - // Register stream insertion operator (if possible) - detail::vector_if_insertion_operator(cl, name); - - // Modifiers require copyable vector value type - detail::vector_modifiers(cl); - - // Accessor and iterator; return by value if copyable, otherwise we return by ref + keep-alive - detail::vector_accessor(cl); - - cl.def("__bool__", - [](const Vector &v) -> bool { - return !v.empty(); - }, - "Check whether the list is nonempty" - ); - - cl.def("__len__", &Vector::size); - - - - -#if 0 - // C++ style functions deprecated, leaving it here as an example - cl.def(init()); - - cl.def("resize", - (void (Vector::*) (size_type count)) & Vector::resize, - "changes the number of elements stored"); - - cl.def("erase", - [](Vector &v, SizeType i) { - if (i >= v.size()) - throw index_error(); - v.erase(v.begin() + i); - }, "erases element at index ``i``"); - - cl.def("empty", &Vector::empty, "checks whether the container is empty"); - cl.def("size", &Vector::size, "returns the number of elements"); - cl.def("push_back", (void (Vector::*)(const T&)) &Vector::push_back, "adds an element to the end"); - cl.def("pop_back", &Vector::pop_back, "removes the last element"); - - cl.def("max_size", &Vector::max_size, "returns the maximum possible number of elements"); - cl.def("reserve", &Vector::reserve, "reserves storage"); - cl.def("capacity", &Vector::capacity, "returns the number of elements that can be held in currently allocated storage"); - cl.def("shrink_to_fit", &Vector::shrink_to_fit, "reduces memory usage by freeing unused memory"); - - cl.def("clear", &Vector::clear, "clears the contents"); - cl.def("swap", &Vector::swap, "swaps the contents"); - - cl.def("front", [](Vector &v) { - if (v.size()) return v.front(); - else throw index_error(); - }, "access the first element"); - - cl.def("back", [](Vector &v) { - if (v.size()) return v.back(); - else throw index_error(); - }, "access the last element "); - -#endif - - return cl; -} - - - -// -// std::map, std::unordered_map -// - -NAMESPACE_BEGIN(detail) - -/* Fallback functions */ -template void map_if_insertion_operator(const Args &...) { } -template void map_assignment(const Args &...) { } - -// Map assignment when copy-assignable: just copy the value -template -void map_assignment(enable_if_t::value, Class_> &cl) { - using KeyType = typename Map::key_type; - using MappedType = typename Map::mapped_type; - - cl.def("__setitem__", - [](Map &m, const KeyType &k, const MappedType &v) { - auto it = m.find(k); - if (it != m.end()) it->second = v; - else m.emplace(k, v); - } - ); -} - -// Not copy-assignable, but still copy-constructible: we can update the value by erasing and reinserting -template -void map_assignment(enable_if_t< - !std::is_copy_assignable::value && - is_copy_constructible::value, - Class_> &cl) { - using KeyType = typename Map::key_type; - using MappedType = typename Map::mapped_type; - - cl.def("__setitem__", - [](Map &m, const KeyType &k, const MappedType &v) { - // We can't use m[k] = v; because value type might not be default constructable - auto r = m.emplace(k, v); - if (!r.second) { - // value type is not copy assignable so the only way to insert it is to erase it first... - m.erase(r.first); - m.emplace(k, v); - } - } - ); -} - - -template auto map_if_insertion_operator(Class_ &cl, std::string const &name) --> decltype(std::declval() << std::declval() << std::declval(), void()) { - - cl.def("__repr__", - [name](Map &m) { - std::ostringstream s; - s << name << '{'; - bool f = false; - for (auto const &kv : m) { - if (f) - s << ", "; - s << kv.first << ": " << kv.second; - f = true; - } - s << '}'; - return s.str(); - }, - "Return the canonical string representation of this map." - ); -} - - -NAMESPACE_END(detail) - -template , typename... Args> -class_ bind_map(handle scope, const std::string &name, Args&&... args) { - using KeyType = typename Map::key_type; - using MappedType = typename Map::mapped_type; - using Class_ = class_; - - // If either type is a non-module-local bound type then make the map binding non-local as well; - // otherwise (e.g. both types are either module-local or converting) the map will be - // module-local. - auto tinfo = detail::get_type_info(typeid(MappedType)); - bool local = !tinfo || tinfo->module_local; - if (local) { - tinfo = detail::get_type_info(typeid(KeyType)); - local = !tinfo || tinfo->module_local; - } - - Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward(args)...); - - cl.def(init<>()); - - // Register stream insertion operator (if possible) - detail::map_if_insertion_operator(cl, name); - - cl.def("__bool__", - [](const Map &m) -> bool { return !m.empty(); }, - "Check whether the map is nonempty" - ); - - cl.def("__iter__", - [](Map &m) { return make_key_iterator(m.begin(), m.end()); }, - keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ - ); - - cl.def("items", - [](Map &m) { return make_iterator(m.begin(), m.end()); }, - keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */ - ); - - cl.def("__getitem__", - [](Map &m, const KeyType &k) -> MappedType & { - auto it = m.find(k); - if (it == m.end()) - throw key_error(); - return it->second; - }, - return_value_policy::reference_internal // ref + keepalive - ); - - cl.def("__contains__", - [](Map &m, const KeyType &k) -> bool { - auto it = m.find(k); - if (it == m.end()) - return false; - return true; - } - ); - - // Assignment provided only if the type is copyable - detail::map_assignment(cl); - - cl.def("__delitem__", - [](Map &m, const KeyType &k) { - auto it = m.find(k); - if (it == m.end()) - throw key_error(); - m.erase(it); - } - ); - - cl.def("__len__", &Map::size); - - return cl; -} - -NAMESPACE_END(PYBIND11_NAMESPACE) diff --git a/pybind11/pybind11/__init__.py b/pybind11/pybind11/__init__.py deleted file mode 100644 index c625e8c..0000000 --- a/pybind11/pybind11/__init__.py +++ /dev/null @@ -1,36 +0,0 @@ -from ._version import version_info, __version__ # noqa: F401 imported but unused - - -def get_include(user=False): - from distutils.dist import Distribution - import os - import sys - - # Are we running in a virtual environment? - virtualenv = hasattr(sys, 'real_prefix') or \ - sys.prefix != getattr(sys, "base_prefix", sys.prefix) - - # Are we running in a conda environment? - conda = os.path.exists(os.path.join(sys.prefix, 'conda-meta')) - - if virtualenv: - return os.path.join(sys.prefix, 'include', 'site', - 'python' + sys.version[:3]) - elif conda: - if os.name == 'nt': - return os.path.join(sys.prefix, 'Library', 'include') - else: - return os.path.join(sys.prefix, 'include') - else: - dist = Distribution({'name': 'pybind11'}) - dist.parse_config_files() - - dist_cobj = dist.get_command_obj('install', create=True) - - # Search for packages in user's home directory? - if user: - dist_cobj.user = user - dist_cobj.prefix = "" - dist_cobj.finalize_options() - - return os.path.dirname(dist_cobj.install_headers) diff --git a/pybind11/pybind11/__main__.py b/pybind11/pybind11/__main__.py deleted file mode 100644 index 9ef8378..0000000 --- a/pybind11/pybind11/__main__.py +++ /dev/null @@ -1,37 +0,0 @@ -from __future__ import print_function - -import argparse -import sys -import sysconfig - -from . import get_include - - -def print_includes(): - dirs = [sysconfig.get_path('include'), - sysconfig.get_path('platinclude'), - get_include(), - get_include(True)] - - # Make unique but preserve order - unique_dirs = [] - for d in dirs: - if d not in unique_dirs: - unique_dirs.append(d) - - print(' '.join('-I' + d for d in unique_dirs)) - - -def main(): - parser = argparse.ArgumentParser(prog='python -m pybind11') - parser.add_argument('--includes', action='store_true', - help='Include flags for both pybind11 and Python headers.') - args = parser.parse_args() - if not sys.argv[1:]: - parser.print_help() - if args.includes: - print_includes() - - -if __name__ == '__main__': - main() diff --git a/pybind11/pybind11/_version.py b/pybind11/pybind11/_version.py deleted file mode 100644 index 2709cc5..0000000 --- a/pybind11/pybind11/_version.py +++ /dev/null @@ -1,2 +0,0 @@ -version_info = (2, 4, 3) -__version__ = '.'.join(map(str, version_info)) diff --git a/pybind11/setup.py b/pybind11/setup.py deleted file mode 100644 index f677f2a..0000000 --- a/pybind11/setup.py +++ /dev/null @@ -1,108 +0,0 @@ -#!/usr/bin/env python - -# Setup script for PyPI; use CMakeFile.txt to build extension modules - -from setuptools import setup -from distutils.command.install_headers import install_headers -from pybind11 import __version__ -import os - -# Prevent installation of pybind11 headers by setting -# PYBIND11_USE_CMAKE. -if os.environ.get('PYBIND11_USE_CMAKE'): - headers = [] -else: - headers = [ - 'include/pybind11/detail/class.h', - 'include/pybind11/detail/common.h', - 'include/pybind11/detail/descr.h', - 'include/pybind11/detail/init.h', - 'include/pybind11/detail/internals.h', - 'include/pybind11/detail/typeid.h', - 'include/pybind11/attr.h', - 'include/pybind11/buffer_info.h', - 'include/pybind11/cast.h', - 'include/pybind11/chrono.h', - 'include/pybind11/common.h', - 'include/pybind11/complex.h', - 'include/pybind11/eigen.h', - 'include/pybind11/embed.h', - 'include/pybind11/eval.h', - 'include/pybind11/functional.h', - 'include/pybind11/iostream.h', - 'include/pybind11/numpy.h', - 'include/pybind11/operators.h', - 'include/pybind11/options.h', - 'include/pybind11/pybind11.h', - 'include/pybind11/pytypes.h', - 'include/pybind11/stl.h', - 'include/pybind11/stl_bind.h', - ] - - -class InstallHeaders(install_headers): - """Use custom header installer because the default one flattens subdirectories""" - def run(self): - if not self.distribution.headers: - return - - for header in self.distribution.headers: - subdir = os.path.dirname(os.path.relpath(header, 'include/pybind11')) - install_dir = os.path.join(self.install_dir, subdir) - self.mkpath(install_dir) - - (out, _) = self.copy_file(header, install_dir) - self.outfiles.append(out) - - -setup( - name='pybind11', - version=__version__, - description='Seamless operability between C++11 and Python', - author='Wenzel Jakob', - author_email='wenzel.jakob@epfl.ch', - url='https://github.com/pybind/pybind11', - download_url='https://github.com/pybind/pybind11/tarball/v' + __version__, - packages=['pybind11'], - license='BSD', - headers=headers, - cmdclass=dict(install_headers=InstallHeaders), - classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Intended Audience :: Developers', - 'Topic :: Software Development :: Libraries :: Python Modules', - 'Topic :: Utilities', - 'Programming Language :: C++', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.2', - 'Programming Language :: Python :: 3.3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'License :: OSI Approved :: BSD License' - ], - keywords='C++11, Python bindings', - long_description="""pybind11 is a lightweight header-only library that -exposes C++ types in Python and vice versa, mainly to create Python bindings of -existing C++ code. Its goals and syntax are similar to the excellent -Boost.Python by David Abrahams: to minimize boilerplate code in traditional -extension modules by inferring type information using compile-time -introspection. - -The main issue with Boost.Python-and the reason for creating such a similar -project-is Boost. Boost is an enormously large and complex suite of utility -libraries that works with almost every C++ compiler in existence. This -compatibility has its cost: arcane template tricks and workarounds are -necessary to support the oldest and buggiest of compiler specimens. Now that -C++11-compatible compilers are widely available, this heavy machinery has -become an excessively large and unnecessary dependency. - -Think of this library as a tiny self-contained version of Boost.Python with -everything stripped away that isn't relevant for binding generation. Without -comments, the core header files only require ~4K lines of code and depend on -Python (2.7 or 3.x, or PyPy2.7 >= 5.7) and the C++ standard library. This -compact implementation was possible thanks to some of the new C++11 language -features (specifically: tuples, lambda functions and variadic templates). Since -its creation, this library has grown beyond Boost.Python in many ways, leading -to dramatically simpler binding code in many common situations.""") diff --git a/pybind11/tests/CMakeLists.txt b/pybind11/tests/CMakeLists.txt deleted file mode 100644 index 765c47a..0000000 --- a/pybind11/tests/CMakeLists.txt +++ /dev/null @@ -1,259 +0,0 @@ -# CMakeLists.txt -- Build system for the pybind11 test suite -# -# Copyright (c) 2015 Wenzel Jakob -# -# All rights reserved. Use of this source code is governed by a -# BSD-style license that can be found in the LICENSE file. - -cmake_minimum_required(VERSION 2.8.12) - -option(PYBIND11_WERROR "Report all warnings as errors" OFF) - -if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR) - # We're being loaded directly, i.e. not via add_subdirectory, so make this - # work as its own project and load the pybind11Config to get the tools we need - project(pybind11_tests CXX) - - find_package(pybind11 REQUIRED CONFIG) -endif() - -if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) - message(STATUS "Setting tests build type to MinSizeRel as none was specified") - set(CMAKE_BUILD_TYPE MinSizeRel CACHE STRING "Choose the type of build." FORCE) - set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" - "MinSizeRel" "RelWithDebInfo") -endif() - -# Full set of test files (you can override these; see below) -set(PYBIND11_TEST_FILES - test_async.cpp - test_buffers.cpp - test_builtin_casters.cpp - test_call_policies.cpp - test_callbacks.cpp - test_chrono.cpp - test_class.cpp - test_constants_and_functions.cpp - test_copy_move.cpp - test_docstring_options.cpp - test_eigen.cpp - test_enum.cpp - test_eval.cpp - test_exceptions.cpp - test_factory_constructors.cpp - test_gil_scoped.cpp - test_iostream.cpp - test_kwargs_and_defaults.cpp - test_local_bindings.cpp - test_methods_and_attributes.cpp - test_modules.cpp - test_multiple_inheritance.cpp - test_numpy_array.cpp - test_numpy_dtypes.cpp - test_numpy_vectorize.cpp - test_opaque_types.cpp - test_operator_overloading.cpp - test_pickling.cpp - test_pytypes.cpp - test_sequences_and_iterators.cpp - test_smart_ptr.cpp - test_stl.cpp - test_stl_binders.cpp - test_tagbased_polymorphic.cpp - test_union.cpp - test_virtual_functions.cpp -) - -# Invoking cmake with something like: -# cmake -DPYBIND11_TEST_OVERRIDE="test_callbacks.cpp;test_picking.cpp" .. -# lets you override the tests that get compiled and run. You can restore to all tests with: -# cmake -DPYBIND11_TEST_OVERRIDE= .. -if (PYBIND11_TEST_OVERRIDE) - set(PYBIND11_TEST_FILES ${PYBIND11_TEST_OVERRIDE}) -endif() - -# Skip test_async for Python < 3.5 -list(FIND PYBIND11_TEST_FILES test_async.cpp PYBIND11_TEST_FILES_ASYNC_I) -if((PYBIND11_TEST_FILES_ASYNC_I GREATER -1) AND ("${PYTHON_VERSION_MAJOR}.${PYTHON_VERSION_MINOR}" VERSION_LESS 3.5)) - message(STATUS "Skipping test_async because Python version ${PYTHON_VERSION_MAJOR}.${PYTHON_VERSION_MINOR} < 3.5") - list(REMOVE_AT PYBIND11_TEST_FILES ${PYBIND11_TEST_FILES_ASYNC_I}) -endif() - -string(REPLACE ".cpp" ".py" PYBIND11_PYTEST_FILES "${PYBIND11_TEST_FILES}") - -# Contains the set of test files that require pybind11_cross_module_tests to be -# built; if none of these are built (i.e. because TEST_OVERRIDE is used and -# doesn't include them) the second module doesn't get built. -set(PYBIND11_CROSS_MODULE_TESTS - test_exceptions.py - test_local_bindings.py - test_stl.py - test_stl_binders.py -) - -set(PYBIND11_CROSS_MODULE_GIL_TESTS - test_gil_scoped.py -) - -# Check if Eigen is available; if not, remove from PYBIND11_TEST_FILES (but -# keep it in PYBIND11_PYTEST_FILES, so that we get the "eigen is not installed" -# skip message). -list(FIND PYBIND11_TEST_FILES test_eigen.cpp PYBIND11_TEST_FILES_EIGEN_I) -if(PYBIND11_TEST_FILES_EIGEN_I GREATER -1) - # Try loading via newer Eigen's Eigen3Config first (bypassing tools/FindEigen3.cmake). - # Eigen 3.3.1+ exports a cmake 3.0+ target for handling dependency requirements, but also - # produces a fatal error if loaded from a pre-3.0 cmake. - if (NOT CMAKE_VERSION VERSION_LESS 3.0) - find_package(Eigen3 3.2.7 QUIET CONFIG) - if (EIGEN3_FOUND) - if (EIGEN3_VERSION_STRING AND NOT EIGEN3_VERSION_STRING VERSION_LESS 3.3.1) - set(PYBIND11_EIGEN_VIA_TARGET 1) - endif() - endif() - endif() - if (NOT EIGEN3_FOUND) - # Couldn't load via target, so fall back to allowing module mode finding, which will pick up - # tools/FindEigen3.cmake - find_package(Eigen3 3.2.7 QUIET) - endif() - - if(EIGEN3_FOUND) - # Eigen 3.3.1+ cmake sets EIGEN3_VERSION_STRING (and hard codes the version when installed - # rather than looking it up in the cmake script); older versions, and the - # tools/FindEigen3.cmake, set EIGEN3_VERSION instead. - if(NOT EIGEN3_VERSION AND EIGEN3_VERSION_STRING) - set(EIGEN3_VERSION ${EIGEN3_VERSION_STRING}) - endif() - message(STATUS "Building tests with Eigen v${EIGEN3_VERSION}") - else() - list(REMOVE_AT PYBIND11_TEST_FILES ${PYBIND11_TEST_FILES_EIGEN_I}) - message(STATUS "Building tests WITHOUT Eigen") - endif() -endif() - -# Optional dependency for some tests (boost::variant is only supported with version >= 1.56) -find_package(Boost 1.56) - -# Compile with compiler warnings turned on -function(pybind11_enable_warnings target_name) - if(MSVC) - target_compile_options(${target_name} PRIVATE /W4) - elseif(CMAKE_CXX_COMPILER_ID MATCHES "(GNU|Intel|Clang)") - target_compile_options(${target_name} PRIVATE -Wall -Wextra -Wconversion -Wcast-qual -Wdeprecated) - endif() - - if(PYBIND11_WERROR) - if(MSVC) - target_compile_options(${target_name} PRIVATE /WX) - elseif(CMAKE_CXX_COMPILER_ID MATCHES "(GNU|Intel|Clang)") - target_compile_options(${target_name} PRIVATE -Werror) - endif() - endif() -endfunction() - -set(test_targets pybind11_tests) - -# Build pybind11_cross_module_tests if any test_whatever.py are being built that require it -foreach(t ${PYBIND11_CROSS_MODULE_TESTS}) - list(FIND PYBIND11_PYTEST_FILES ${t} i) - if (i GREATER -1) - list(APPEND test_targets pybind11_cross_module_tests) - break() - endif() -endforeach() - -foreach(t ${PYBIND11_CROSS_MODULE_GIL_TESTS}) - list(FIND PYBIND11_PYTEST_FILES ${t} i) - if (i GREATER -1) - list(APPEND test_targets cross_module_gil_utils) - break() - endif() -endforeach() - -set(testdir ${CMAKE_CURRENT_SOURCE_DIR}) -foreach(target ${test_targets}) - set(test_files ${PYBIND11_TEST_FILES}) - if(NOT target STREQUAL "pybind11_tests") - set(test_files "") - endif() - - # Create the binding library - pybind11_add_module(${target} THIN_LTO ${target}.cpp ${test_files} ${PYBIND11_HEADERS}) - pybind11_enable_warnings(${target}) - - if(MSVC) - target_compile_options(${target} PRIVATE /utf-8) - endif() - - if(EIGEN3_FOUND) - if (PYBIND11_EIGEN_VIA_TARGET) - target_link_libraries(${target} PRIVATE Eigen3::Eigen) - else() - target_include_directories(${target} PRIVATE ${EIGEN3_INCLUDE_DIR}) - endif() - target_compile_definitions(${target} PRIVATE -DPYBIND11_TEST_EIGEN) - endif() - - if(Boost_FOUND) - target_include_directories(${target} PRIVATE ${Boost_INCLUDE_DIRS}) - target_compile_definitions(${target} PRIVATE -DPYBIND11_TEST_BOOST) - endif() - - # Always write the output file directly into the 'tests' directory (even on MSVC) - if(NOT CMAKE_LIBRARY_OUTPUT_DIRECTORY) - set_target_properties(${target} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${testdir}) - foreach(config ${CMAKE_CONFIGURATION_TYPES}) - string(TOUPPER ${config} config) - set_target_properties(${target} PROPERTIES LIBRARY_OUTPUT_DIRECTORY_${config} ${testdir}) - endforeach() - endif() -endforeach() - -# Make sure pytest is found or produce a fatal error -if(NOT PYBIND11_PYTEST_FOUND) - execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "import pytest; print(pytest.__version__)" - RESULT_VARIABLE pytest_not_found OUTPUT_VARIABLE pytest_version ERROR_QUIET) - if(pytest_not_found) - message(FATAL_ERROR "Running the tests requires pytest. Please install it manually" - " (try: ${PYTHON_EXECUTABLE} -m pip install pytest)") - elseif(pytest_version VERSION_LESS 3.0) - message(FATAL_ERROR "Running the tests requires pytest >= 3.0. Found: ${pytest_version}" - "Please update it (try: ${PYTHON_EXECUTABLE} -m pip install -U pytest)") - endif() - set(PYBIND11_PYTEST_FOUND TRUE CACHE INTERNAL "") -endif() - -if(CMAKE_VERSION VERSION_LESS 3.2) - set(PYBIND11_USES_TERMINAL "") -else() - set(PYBIND11_USES_TERMINAL "USES_TERMINAL") -endif() - -# A single command to compile and run the tests -add_custom_target(pytest COMMAND ${PYTHON_EXECUTABLE} -m pytest ${PYBIND11_PYTEST_FILES} - DEPENDS ${test_targets} WORKING_DIRECTORY ${testdir} ${PYBIND11_USES_TERMINAL}) - -if(PYBIND11_TEST_OVERRIDE) - add_custom_command(TARGET pytest POST_BUILD - COMMAND ${CMAKE_COMMAND} -E echo "Note: not all tests run: -DPYBIND11_TEST_OVERRIDE is in effect") -endif() - -# Add a check target to run all the tests, starting with pytest (we add dependencies to this below) -add_custom_target(check DEPENDS pytest) - -# The remaining tests only apply when being built as part of the pybind11 project, but not if the -# tests are being built independently. -if (NOT PROJECT_NAME STREQUAL "pybind11") - return() -endif() - -# Add a post-build comment to show the primary test suite .so size and, if a previous size, compare it: -add_custom_command(TARGET pybind11_tests POST_BUILD - COMMAND ${PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/tools/libsize.py - $ ${CMAKE_CURRENT_BINARY_DIR}/sosize-$.txt) - -# Test embedding the interpreter. Provides the `cpptest` target. -add_subdirectory(test_embed) - -# Test CMake build using functions and targets from subdirectory or installed location -add_subdirectory(test_cmake_build) diff --git a/pybind11/tests/conftest.py b/pybind11/tests/conftest.py deleted file mode 100644 index 57f681c..0000000 --- a/pybind11/tests/conftest.py +++ /dev/null @@ -1,244 +0,0 @@ -"""pytest configuration - -Extends output capture as needed by pybind11: ignore constructors, optional unordered lines. -Adds docstring and exceptions message sanitizers: ignore Python 2 vs 3 differences. -""" - -import pytest -import textwrap -import difflib -import re -import sys -import contextlib -import platform -import gc - -_unicode_marker = re.compile(r'u(\'[^\']*\')') -_long_marker = re.compile(r'([0-9])L') -_hexadecimal = re.compile(r'0x[0-9a-fA-F]+') - -# test_async.py requires support for async and await -collect_ignore = [] -if sys.version_info[:2] < (3, 5): - collect_ignore.append("test_async.py") - - -def _strip_and_dedent(s): - """For triple-quote strings""" - return textwrap.dedent(s.lstrip('\n').rstrip()) - - -def _split_and_sort(s): - """For output which does not require specific line order""" - return sorted(_strip_and_dedent(s).splitlines()) - - -def _make_explanation(a, b): - """Explanation for a failed assert -- the a and b arguments are List[str]""" - return ["--- actual / +++ expected"] + [line.strip('\n') for line in difflib.ndiff(a, b)] - - -class Output(object): - """Basic output post-processing and comparison""" - def __init__(self, string): - self.string = string - self.explanation = [] - - def __str__(self): - return self.string - - def __eq__(self, other): - # Ignore constructor/destructor output which is prefixed with "###" - a = [line for line in self.string.strip().splitlines() if not line.startswith("###")] - b = _strip_and_dedent(other).splitlines() - if a == b: - return True - else: - self.explanation = _make_explanation(a, b) - return False - - -class Unordered(Output): - """Custom comparison for output without strict line ordering""" - def __eq__(self, other): - a = _split_and_sort(self.string) - b = _split_and_sort(other) - if a == b: - return True - else: - self.explanation = _make_explanation(a, b) - return False - - -class Capture(object): - def __init__(self, capfd): - self.capfd = capfd - self.out = "" - self.err = "" - - def __enter__(self): - self.capfd.readouterr() - return self - - def __exit__(self, *args): - self.out, self.err = self.capfd.readouterr() - - def __eq__(self, other): - a = Output(self.out) - b = other - if a == b: - return True - else: - self.explanation = a.explanation - return False - - def __str__(self): - return self.out - - def __contains__(self, item): - return item in self.out - - @property - def unordered(self): - return Unordered(self.out) - - @property - def stderr(self): - return Output(self.err) - - -@pytest.fixture -def capture(capsys): - """Extended `capsys` with context manager and custom equality operators""" - return Capture(capsys) - - -class SanitizedString(object): - def __init__(self, sanitizer): - self.sanitizer = sanitizer - self.string = "" - self.explanation = [] - - def __call__(self, thing): - self.string = self.sanitizer(thing) - return self - - def __eq__(self, other): - a = self.string - b = _strip_and_dedent(other) - if a == b: - return True - else: - self.explanation = _make_explanation(a.splitlines(), b.splitlines()) - return False - - -def _sanitize_general(s): - s = s.strip() - s = s.replace("pybind11_tests.", "m.") - s = s.replace("unicode", "str") - s = _long_marker.sub(r"\1", s) - s = _unicode_marker.sub(r"\1", s) - return s - - -def _sanitize_docstring(thing): - s = thing.__doc__ - s = _sanitize_general(s) - return s - - -@pytest.fixture -def doc(): - """Sanitize docstrings and add custom failure explanation""" - return SanitizedString(_sanitize_docstring) - - -def _sanitize_message(thing): - s = str(thing) - s = _sanitize_general(s) - s = _hexadecimal.sub("0", s) - return s - - -@pytest.fixture -def msg(): - """Sanitize messages and add custom failure explanation""" - return SanitizedString(_sanitize_message) - - -# noinspection PyUnusedLocal -def pytest_assertrepr_compare(op, left, right): - """Hook to insert custom failure explanation""" - if hasattr(left, 'explanation'): - return left.explanation - - -@contextlib.contextmanager -def suppress(exception): - """Suppress the desired exception""" - try: - yield - except exception: - pass - - -def gc_collect(): - ''' Run the garbage collector twice (needed when running - reference counting tests with PyPy) ''' - gc.collect() - gc.collect() - - -def pytest_configure(): - """Add import suppression and test requirements to `pytest` namespace""" - try: - import numpy as np - except ImportError: - np = None - try: - import scipy - except ImportError: - scipy = None - try: - from pybind11_tests.eigen import have_eigen - except ImportError: - have_eigen = False - pypy = platform.python_implementation() == "PyPy" - - skipif = pytest.mark.skipif - pytest.suppress = suppress - pytest.requires_numpy = skipif(not np, reason="numpy is not installed") - pytest.requires_scipy = skipif(not np, reason="scipy is not installed") - pytest.requires_eigen_and_numpy = skipif(not have_eigen or not np, - reason="eigen and/or numpy are not installed") - pytest.requires_eigen_and_scipy = skipif( - not have_eigen or not scipy, reason="eigen and/or scipy are not installed") - pytest.unsupported_on_pypy = skipif(pypy, reason="unsupported on PyPy") - pytest.unsupported_on_py2 = skipif(sys.version_info.major < 3, - reason="unsupported on Python 2.x") - pytest.gc_collect = gc_collect - - -def _test_import_pybind11(): - """Early diagnostic for test module initialization errors - - When there is an error during initialization, the first import will report the - real error while all subsequent imports will report nonsense. This import test - is done early (in the pytest configuration file, before any tests) in order to - avoid the noise of having all tests fail with identical error messages. - - Any possible exception is caught here and reported manually *without* the stack - trace. This further reduces noise since the trace would only show pytest internals - which are not useful for debugging pybind11 module issues. - """ - # noinspection PyBroadException - try: - import pybind11_tests # noqa: F401 imported but unused - except Exception as e: - print("Failed to import pybind11_tests from pytest:") - print(" {}: {}".format(type(e).__name__, e)) - sys.exit(1) - - -_test_import_pybind11() diff --git a/pybind11/tests/constructor_stats.h b/pybind11/tests/constructor_stats.h deleted file mode 100644 index f026e70..0000000 --- a/pybind11/tests/constructor_stats.h +++ /dev/null @@ -1,276 +0,0 @@ -#pragma once -/* - tests/constructor_stats.h -- framework for printing and tracking object - instance lifetimes in example/test code. - - Copyright (c) 2016 Jason Rhinelander - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. - -This header provides a few useful tools for writing examples or tests that want to check and/or -display object instance lifetimes. It requires that you include this header and add the following -function calls to constructors: - - class MyClass { - MyClass() { ...; print_default_created(this); } - ~MyClass() { ...; print_destroyed(this); } - MyClass(const MyClass &c) { ...; print_copy_created(this); } - MyClass(MyClass &&c) { ...; print_move_created(this); } - MyClass(int a, int b) { ...; print_created(this, a, b); } - MyClass &operator=(const MyClass &c) { ...; print_copy_assigned(this); } - MyClass &operator=(MyClass &&c) { ...; print_move_assigned(this); } - - ... - } - -You can find various examples of these in several of the existing testing .cpp files. (Of course -you don't need to add any of the above constructors/operators that you don't actually have, except -for the destructor). - -Each of these will print an appropriate message such as: - - ### MyClass @ 0x2801910 created via default constructor - ### MyClass @ 0x27fa780 created 100 200 - ### MyClass @ 0x2801910 destroyed - ### MyClass @ 0x27fa780 destroyed - -You can also include extra arguments (such as the 100, 200 in the output above, coming from the -value constructor) for all of the above methods which will be included in the output. - -For testing, each of these also keeps track the created instances and allows you to check how many -of the various constructors have been invoked from the Python side via code such as: - - from pybind11_tests import ConstructorStats - cstats = ConstructorStats.get(MyClass) - print(cstats.alive()) - print(cstats.default_constructions) - -Note that `.alive()` should usually be the first thing you call as it invokes Python's garbage -collector to actually destroy objects that aren't yet referenced. - -For everything except copy and move constructors and destructors, any extra values given to the -print_...() function is stored in a class-specific values list which you can retrieve and inspect -from the ConstructorStats instance `.values()` method. - -In some cases, when you need to track instances of a C++ class not registered with pybind11, you -need to add a function returning the ConstructorStats for the C++ class; this can be done with: - - m.def("get_special_cstats", &ConstructorStats::get, py::return_value_policy::reference) - -Finally, you can suppress the output messages, but keep the constructor tracking (for -inspection/testing in python) by using the functions with `print_` replaced with `track_` (e.g. -`track_copy_created(this)`). - -*/ - -#include "pybind11_tests.h" -#include -#include -#include -#include - -class ConstructorStats { -protected: - std::unordered_map _instances; // Need a map rather than set because members can shared address with parents - std::list _values; // Used to track values (e.g. of value constructors) -public: - int default_constructions = 0; - int copy_constructions = 0; - int move_constructions = 0; - int copy_assignments = 0; - int move_assignments = 0; - - void copy_created(void *inst) { - created(inst); - copy_constructions++; - } - - void move_created(void *inst) { - created(inst); - move_constructions++; - } - - void default_created(void *inst) { - created(inst); - default_constructions++; - } - - void created(void *inst) { - ++_instances[inst]; - } - - void destroyed(void *inst) { - if (--_instances[inst] < 0) - throw std::runtime_error("cstats.destroyed() called with unknown " - "instance; potential double-destruction " - "or a missing cstats.created()"); - } - - static void gc() { - // Force garbage collection to ensure any pending destructors are invoked: -#if defined(PYPY_VERSION) - PyObject *globals = PyEval_GetGlobals(); - PyObject *result = PyRun_String( - "import gc\n" - "for i in range(2):" - " gc.collect()\n", - Py_file_input, globals, globals); - if (result == nullptr) - throw py::error_already_set(); - Py_DECREF(result); -#else - py::module::import("gc").attr("collect")(); -#endif - } - - int alive() { - gc(); - int total = 0; - for (const auto &p : _instances) - if (p.second > 0) - total += p.second; - return total; - } - - void value() {} // Recursion terminator - // Takes one or more values, converts them to strings, then stores them. - template void value(const T &v, Tmore &&...args) { - std::ostringstream oss; - oss << v; - _values.push_back(oss.str()); - value(std::forward(args)...); - } - - // Move out stored values - py::list values() { - py::list l; - for (const auto &v : _values) l.append(py::cast(v)); - _values.clear(); - return l; - } - - // Gets constructor stats from a C++ type index - static ConstructorStats& get(std::type_index type) { - static std::unordered_map all_cstats; - return all_cstats[type]; - } - - // Gets constructor stats from a C++ type - template static ConstructorStats& get() { -#if defined(PYPY_VERSION) - gc(); -#endif - return get(typeid(T)); - } - - // Gets constructor stats from a Python class - static ConstructorStats& get(py::object class_) { - auto &internals = py::detail::get_internals(); - const std::type_index *t1 = nullptr, *t2 = nullptr; - try { - auto *type_info = internals.registered_types_py.at((PyTypeObject *) class_.ptr()).at(0); - for (auto &p : internals.registered_types_cpp) { - if (p.second == type_info) { - if (t1) { - t2 = &p.first; - break; - } - t1 = &p.first; - } - } - } - catch (const std::out_of_range &) {} - if (!t1) throw std::runtime_error("Unknown class passed to ConstructorStats::get()"); - auto &cs1 = get(*t1); - // If we have both a t1 and t2 match, one is probably the trampoline class; return whichever - // has more constructions (typically one or the other will be 0) - if (t2) { - auto &cs2 = get(*t2); - int cs1_total = cs1.default_constructions + cs1.copy_constructions + cs1.move_constructions + (int) cs1._values.size(); - int cs2_total = cs2.default_constructions + cs2.copy_constructions + cs2.move_constructions + (int) cs2._values.size(); - if (cs2_total > cs1_total) return cs2; - } - return cs1; - } -}; - -// To track construction/destruction, you need to call these methods from the various -// constructors/operators. The ones that take extra values record the given values in the -// constructor stats values for later inspection. -template void track_copy_created(T *inst) { ConstructorStats::get().copy_created(inst); } -template void track_move_created(T *inst) { ConstructorStats::get().move_created(inst); } -template void track_copy_assigned(T *, Values &&...values) { - auto &cst = ConstructorStats::get(); - cst.copy_assignments++; - cst.value(std::forward(values)...); -} -template void track_move_assigned(T *, Values &&...values) { - auto &cst = ConstructorStats::get(); - cst.move_assignments++; - cst.value(std::forward(values)...); -} -template void track_default_created(T *inst, Values &&...values) { - auto &cst = ConstructorStats::get(); - cst.default_created(inst); - cst.value(std::forward(values)...); -} -template void track_created(T *inst, Values &&...values) { - auto &cst = ConstructorStats::get(); - cst.created(inst); - cst.value(std::forward(values)...); -} -template void track_destroyed(T *inst) { - ConstructorStats::get().destroyed(inst); -} -template void track_values(T *, Values &&...values) { - ConstructorStats::get().value(std::forward(values)...); -} - -/// Don't cast pointers to Python, print them as strings -inline const char *format_ptrs(const char *p) { return p; } -template -py::str format_ptrs(T *p) { return "{:#x}"_s.format(reinterpret_cast(p)); } -template -auto format_ptrs(T &&x) -> decltype(std::forward(x)) { return std::forward(x); } - -template -void print_constr_details(T *inst, const std::string &action, Output &&...output) { - py::print("###", py::type_id(), "@", format_ptrs(inst), action, - format_ptrs(std::forward(output))...); -} - -// Verbose versions of the above: -template void print_copy_created(T *inst, Values &&...values) { // NB: this prints, but doesn't store, given values - print_constr_details(inst, "created via copy constructor", values...); - track_copy_created(inst); -} -template void print_move_created(T *inst, Values &&...values) { // NB: this prints, but doesn't store, given values - print_constr_details(inst, "created via move constructor", values...); - track_move_created(inst); -} -template void print_copy_assigned(T *inst, Values &&...values) { - print_constr_details(inst, "assigned via copy assignment", values...); - track_copy_assigned(inst, values...); -} -template void print_move_assigned(T *inst, Values &&...values) { - print_constr_details(inst, "assigned via move assignment", values...); - track_move_assigned(inst, values...); -} -template void print_default_created(T *inst, Values &&...values) { - print_constr_details(inst, "created via default constructor", values...); - track_default_created(inst, values...); -} -template void print_created(T *inst, Values &&...values) { - print_constr_details(inst, "created", values...); - track_created(inst, values...); -} -template void print_destroyed(T *inst, Values &&...values) { // Prints but doesn't store given values - print_constr_details(inst, "destroyed", values...); - track_destroyed(inst); -} -template void print_values(T *inst, Values &&...values) { - print_constr_details(inst, ":", values...); - track_values(inst, values...); -} - diff --git a/pybind11/tests/cross_module_gil_utils.cpp b/pybind11/tests/cross_module_gil_utils.cpp deleted file mode 100644 index 07db9f6..0000000 --- a/pybind11/tests/cross_module_gil_utils.cpp +++ /dev/null @@ -1,73 +0,0 @@ -/* - tests/cross_module_gil_utils.cpp -- tools for acquiring GIL from a different module - - Copyright (c) 2019 Google LLC - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ -#include -#include - -// This file mimics a DSO that makes pybind11 calls but does not define a -// PYBIND11_MODULE. The purpose is to test that such a DSO can create a -// py::gil_scoped_acquire when the running thread is in a GIL-released state. -// -// Note that we define a Python module here for convenience, but in general -// this need not be the case. The typical scenario would be a DSO that implements -// shared logic used internally by multiple pybind11 modules. - -namespace { - -namespace py = pybind11; -void gil_acquire() { py::gil_scoped_acquire gil; } - -constexpr char kModuleName[] = "cross_module_gil_utils"; - -#if PY_MAJOR_VERSION >= 3 -struct PyModuleDef moduledef = { - PyModuleDef_HEAD_INIT, - kModuleName, - NULL, - 0, - NULL, - NULL, - NULL, - NULL, - NULL -}; -#else -PyMethodDef module_methods[] = { - {NULL, NULL, 0, NULL} -}; -#endif - -} // namespace - -extern "C" PYBIND11_EXPORT -#if PY_MAJOR_VERSION >= 3 -PyObject* PyInit_cross_module_gil_utils() -#else -void initcross_module_gil_utils() -#endif -{ - - PyObject* m = -#if PY_MAJOR_VERSION >= 3 - PyModule_Create(&moduledef); -#else - Py_InitModule(kModuleName, module_methods); -#endif - - if (m != NULL) { - static_assert( - sizeof(&gil_acquire) == sizeof(void*), - "Function pointer must have the same size as void*"); - PyModule_AddObject(m, "gil_acquire_funcaddr", - PyLong_FromVoidPtr(reinterpret_cast(&gil_acquire))); - } - -#if PY_MAJOR_VERSION >= 3 - return m; -#endif -} diff --git a/pybind11/tests/local_bindings.h b/pybind11/tests/local_bindings.h deleted file mode 100644 index b6afb80..0000000 --- a/pybind11/tests/local_bindings.h +++ /dev/null @@ -1,64 +0,0 @@ -#pragma once -#include "pybind11_tests.h" - -/// Simple class used to test py::local: -template class LocalBase { -public: - LocalBase(int i) : i(i) { } - int i = -1; -}; - -/// Registered with py::module_local in both main and secondary modules: -using LocalType = LocalBase<0>; -/// Registered without py::module_local in both modules: -using NonLocalType = LocalBase<1>; -/// A second non-local type (for stl_bind tests): -using NonLocal2 = LocalBase<2>; -/// Tests within-module, different-compilation-unit local definition conflict: -using LocalExternal = LocalBase<3>; -/// Mixed: registered local first, then global -using MixedLocalGlobal = LocalBase<4>; -/// Mixed: global first, then local -using MixedGlobalLocal = LocalBase<5>; - -/// Registered with py::module_local only in the secondary module: -using ExternalType1 = LocalBase<6>; -using ExternalType2 = LocalBase<7>; - -using LocalVec = std::vector; -using LocalVec2 = std::vector; -using LocalMap = std::unordered_map; -using NonLocalVec = std::vector; -using NonLocalVec2 = std::vector; -using NonLocalMap = std::unordered_map; -using NonLocalMap2 = std::unordered_map; - -PYBIND11_MAKE_OPAQUE(LocalVec); -PYBIND11_MAKE_OPAQUE(LocalVec2); -PYBIND11_MAKE_OPAQUE(LocalMap); -PYBIND11_MAKE_OPAQUE(NonLocalVec); -//PYBIND11_MAKE_OPAQUE(NonLocalVec2); // same type as LocalVec2 -PYBIND11_MAKE_OPAQUE(NonLocalMap); -PYBIND11_MAKE_OPAQUE(NonLocalMap2); - - -// Simple bindings (used with the above): -template -py::class_ bind_local(Args && ...args) { - return py::class_(std::forward(args)...) - .def(py::init()) - .def("get", [](T &i) { return i.i + Adjust; }); -}; - -// Simulate a foreign library base class (to match the example in the docs): -namespace pets { -class Pet { -public: - Pet(std::string name) : name_(name) {} - std::string name_; - const std::string &name() { return name_; } -}; -} - -struct MixGL { int i; MixGL(int i) : i{i} {} }; -struct MixGL2 { int i; MixGL2(int i) : i{i} {} }; diff --git a/pybind11/tests/object.h b/pybind11/tests/object.h deleted file mode 100644 index 9235f19..0000000 --- a/pybind11/tests/object.h +++ /dev/null @@ -1,175 +0,0 @@ -#if !defined(__OBJECT_H) -#define __OBJECT_H - -#include -#include "constructor_stats.h" - -/// Reference counted object base class -class Object { -public: - /// Default constructor - Object() { print_default_created(this); } - - /// Copy constructor - Object(const Object &) : m_refCount(0) { print_copy_created(this); } - - /// Return the current reference count - int getRefCount() const { return m_refCount; }; - - /// Increase the object's reference count by one - void incRef() const { ++m_refCount; } - - /** \brief Decrease the reference count of - * the object and possibly deallocate it. - * - * The object will automatically be deallocated once - * the reference count reaches zero. - */ - void decRef(bool dealloc = true) const { - --m_refCount; - if (m_refCount == 0 && dealloc) - delete this; - else if (m_refCount < 0) - throw std::runtime_error("Internal error: reference count < 0!"); - } - - virtual std::string toString() const = 0; -protected: - /** \brief Virtual protected deconstructor. - * (Will only be called by \ref ref) - */ - virtual ~Object() { print_destroyed(this); } -private: - mutable std::atomic m_refCount { 0 }; -}; - -// Tag class used to track constructions of ref objects. When we track constructors, below, we -// track and print out the actual class (e.g. ref), and *also* add a fake tracker for -// ref_tag. This lets us check that the total number of ref constructors/destructors is -// correct without having to check each individual ref type individually. -class ref_tag {}; - -/** - * \brief Reference counting helper - * - * The \a ref refeference template is a simple wrapper to store a - * pointer to an object. It takes care of increasing and decreasing - * the reference count of the object. When the last reference goes - * out of scope, the associated object will be deallocated. - * - * \ingroup libcore - */ -template class ref { -public: - /// Create a nullptr reference - ref() : m_ptr(nullptr) { print_default_created(this); track_default_created((ref_tag*) this); } - - /// Construct a reference from a pointer - ref(T *ptr) : m_ptr(ptr) { - if (m_ptr) ((Object *) m_ptr)->incRef(); - - print_created(this, "from pointer", m_ptr); track_created((ref_tag*) this, "from pointer"); - - } - - /// Copy constructor - ref(const ref &r) : m_ptr(r.m_ptr) { - if (m_ptr) - ((Object *) m_ptr)->incRef(); - - print_copy_created(this, "with pointer", m_ptr); track_copy_created((ref_tag*) this); - } - - /// Move constructor - ref(ref &&r) : m_ptr(r.m_ptr) { - r.m_ptr = nullptr; - - print_move_created(this, "with pointer", m_ptr); track_move_created((ref_tag*) this); - } - - /// Destroy this reference - ~ref() { - if (m_ptr) - ((Object *) m_ptr)->decRef(); - - print_destroyed(this); track_destroyed((ref_tag*) this); - } - - /// Move another reference into the current one - ref& operator=(ref&& r) { - print_move_assigned(this, "pointer", r.m_ptr); track_move_assigned((ref_tag*) this); - - if (*this == r) - return *this; - if (m_ptr) - ((Object *) m_ptr)->decRef(); - m_ptr = r.m_ptr; - r.m_ptr = nullptr; - return *this; - } - - /// Overwrite this reference with another reference - ref& operator=(const ref& r) { - print_copy_assigned(this, "pointer", r.m_ptr); track_copy_assigned((ref_tag*) this); - - if (m_ptr == r.m_ptr) - return *this; - if (m_ptr) - ((Object *) m_ptr)->decRef(); - m_ptr = r.m_ptr; - if (m_ptr) - ((Object *) m_ptr)->incRef(); - return *this; - } - - /// Overwrite this reference with a pointer to another object - ref& operator=(T *ptr) { - print_values(this, "assigned pointer"); track_values((ref_tag*) this, "assigned pointer"); - - if (m_ptr == ptr) - return *this; - if (m_ptr) - ((Object *) m_ptr)->decRef(); - m_ptr = ptr; - if (m_ptr) - ((Object *) m_ptr)->incRef(); - return *this; - } - - /// Compare this reference with another reference - bool operator==(const ref &r) const { return m_ptr == r.m_ptr; } - - /// Compare this reference with another reference - bool operator!=(const ref &r) const { return m_ptr != r.m_ptr; } - - /// Compare this reference with a pointer - bool operator==(const T* ptr) const { return m_ptr == ptr; } - - /// Compare this reference with a pointer - bool operator!=(const T* ptr) const { return m_ptr != ptr; } - - /// Access the object referenced by this reference - T* operator->() { return m_ptr; } - - /// Access the object referenced by this reference - const T* operator->() const { return m_ptr; } - - /// Return a C++ reference to the referenced object - T& operator*() { return *m_ptr; } - - /// Return a const C++ reference to the referenced object - const T& operator*() const { return *m_ptr; } - - /// Return a pointer to the referenced object - operator T* () { return m_ptr; } - - /// Return a const pointer to the referenced object - T* get_ptr() { return m_ptr; } - - /// Return a pointer to the referenced object - const T* get_ptr() const { return m_ptr; } -private: - T *m_ptr; -}; - -#endif /* __OBJECT_H */ diff --git a/pybind11/tests/pybind11_cross_module_tests.cpp b/pybind11/tests/pybind11_cross_module_tests.cpp deleted file mode 100644 index f705e31..0000000 --- a/pybind11/tests/pybind11_cross_module_tests.cpp +++ /dev/null @@ -1,123 +0,0 @@ -/* - tests/pybind11_cross_module_tests.cpp -- contains tests that require multiple modules - - Copyright (c) 2017 Jason Rhinelander - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" -#include "local_bindings.h" -#include -#include - -PYBIND11_MODULE(pybind11_cross_module_tests, m) { - m.doc() = "pybind11 cross-module test module"; - - // test_local_bindings.py tests: - // - // Definitions here are tested by importing both this module and the - // relevant pybind11_tests submodule from a test_whatever.py - - // test_load_external - bind_local(m, "ExternalType1", py::module_local()); - bind_local(m, "ExternalType2", py::module_local()); - - // test_exceptions.py - m.def("raise_runtime_error", []() { PyErr_SetString(PyExc_RuntimeError, "My runtime error"); throw py::error_already_set(); }); - m.def("raise_value_error", []() { PyErr_SetString(PyExc_ValueError, "My value error"); throw py::error_already_set(); }); - m.def("throw_pybind_value_error", []() { throw py::value_error("pybind11 value error"); }); - m.def("throw_pybind_type_error", []() { throw py::type_error("pybind11 type error"); }); - m.def("throw_stop_iteration", []() { throw py::stop_iteration(); }); - - // test_local_bindings.py - // Local to both: - bind_local(m, "LocalType", py::module_local()) - .def("get2", [](LocalType &t) { return t.i + 2; }) - ; - - // Can only be called with our python type: - m.def("local_value", [](LocalType &l) { return l.i; }); - - // test_nonlocal_failure - // This registration will fail (global registration when LocalFail is already registered - // globally in the main test module): - m.def("register_nonlocal", [m]() { - bind_local(m, "NonLocalType"); - }); - - // test_stl_bind_local - // stl_bind.h binders defaults to py::module_local if the types are local or converting: - py::bind_vector(m, "LocalVec"); - py::bind_map(m, "LocalMap"); - - // test_stl_bind_global - // and global if the type (or one of the types, for the map) is global (so these will fail, - // assuming pybind11_tests is already loaded): - m.def("register_nonlocal_vec", [m]() { - py::bind_vector(m, "NonLocalVec"); - }); - m.def("register_nonlocal_map", [m]() { - py::bind_map(m, "NonLocalMap"); - }); - // The default can, however, be overridden to global using `py::module_local()` or - // `py::module_local(false)`. - // Explicitly made local: - py::bind_vector(m, "NonLocalVec2", py::module_local()); - // Explicitly made global (and so will fail to bind): - m.def("register_nonlocal_map2", [m]() { - py::bind_map(m, "NonLocalMap2", py::module_local(false)); - }); - - // test_mixed_local_global - // We try this both with the global type registered first and vice versa (the order shouldn't - // matter). - m.def("register_mixed_global_local", [m]() { - bind_local(m, "MixedGlobalLocal", py::module_local()); - }); - m.def("register_mixed_local_global", [m]() { - bind_local(m, "MixedLocalGlobal", py::module_local(false)); - }); - m.def("get_mixed_gl", [](int i) { return MixedGlobalLocal(i); }); - m.def("get_mixed_lg", [](int i) { return MixedLocalGlobal(i); }); - - // test_internal_locals_differ - m.def("local_cpp_types_addr", []() { return (uintptr_t) &py::detail::registered_local_types_cpp(); }); - - // test_stl_caster_vs_stl_bind - py::bind_vector>(m, "VectorInt"); - - m.def("load_vector_via_binding", [](std::vector &v) { - return std::accumulate(v.begin(), v.end(), 0); - }); - - // test_cross_module_calls - m.def("return_self", [](LocalVec *v) { return v; }); - m.def("return_copy", [](const LocalVec &v) { return LocalVec(v); }); - - class Dog : public pets::Pet { public: Dog(std::string name) : Pet(name) {}; }; - py::class_(m, "Pet", py::module_local()) - .def("name", &pets::Pet::name); - // Binding for local extending class: - py::class_(m, "Dog") - .def(py::init()); - m.def("pet_name", [](pets::Pet &p) { return p.name(); }); - - py::class_(m, "MixGL", py::module_local()).def(py::init()); - m.def("get_gl_value", [](MixGL &o) { return o.i + 100; }); - - py::class_(m, "MixGL2", py::module_local()).def(py::init()); - - // test_vector_bool - // We can't test both stl.h and stl_bind.h conversions of `std::vector` within - // the same module (it would be an ODR violation). Therefore `bind_vector` of `bool` - // is defined here and tested in `test_stl_binders.py`. - py::bind_vector>(m, "VectorBool"); - - // test_missing_header_message - // The main module already includes stl.h, but we need to test the error message - // which appears when this header is missing. - m.def("missing_header_arg", [](std::vector) { }); - m.def("missing_header_return", []() { return std::vector(); }); -} diff --git a/pybind11/tests/pybind11_tests.cpp b/pybind11/tests/pybind11_tests.cpp deleted file mode 100644 index bc7d2c3..0000000 --- a/pybind11/tests/pybind11_tests.cpp +++ /dev/null @@ -1,93 +0,0 @@ -/* - tests/pybind11_tests.cpp -- pybind example plugin - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" -#include "constructor_stats.h" - -#include -#include - -/* -For testing purposes, we define a static global variable here in a function that each individual -test .cpp calls with its initialization lambda. It's convenient here because we can just not -compile some test files to disable/ignore some of the test code. - -It is NOT recommended as a way to use pybind11 in practice, however: the initialization order will -be essentially random, which is okay for our test scripts (there are no dependencies between the -individual pybind11 test .cpp files), but most likely not what you want when using pybind11 -productively. - -Instead, see the "How can I reduce the build time?" question in the "Frequently asked questions" -section of the documentation for good practice on splitting binding code over multiple files. -*/ -std::list> &initializers() { - static std::list> inits; - return inits; -} - -test_initializer::test_initializer(Initializer init) { - initializers().push_back(init); -} - -test_initializer::test_initializer(const char *submodule_name, Initializer init) { - initializers().push_back([=](py::module &parent) { - auto m = parent.def_submodule(submodule_name); - init(m); - }); -} - -void bind_ConstructorStats(py::module &m) { - py::class_(m, "ConstructorStats") - .def("alive", &ConstructorStats::alive) - .def("values", &ConstructorStats::values) - .def_readwrite("default_constructions", &ConstructorStats::default_constructions) - .def_readwrite("copy_assignments", &ConstructorStats::copy_assignments) - .def_readwrite("move_assignments", &ConstructorStats::move_assignments) - .def_readwrite("copy_constructions", &ConstructorStats::copy_constructions) - .def_readwrite("move_constructions", &ConstructorStats::move_constructions) - .def_static("get", (ConstructorStats &(*)(py::object)) &ConstructorStats::get, py::return_value_policy::reference_internal) - - // Not exactly ConstructorStats, but related: expose the internal pybind number of registered instances - // to allow instance cleanup checks (invokes a GC first) - .def_static("detail_reg_inst", []() { - ConstructorStats::gc(); - return py::detail::get_internals().registered_instances.size(); - }) - ; -} - -PYBIND11_MODULE(pybind11_tests, m) { - m.doc() = "pybind11 test module"; - - bind_ConstructorStats(m); - -#if !defined(NDEBUG) - m.attr("debug_enabled") = true; -#else - m.attr("debug_enabled") = false; -#endif - - py::class_(m, "UserType", "A `py::class_` type for testing") - .def(py::init<>()) - .def(py::init()) - .def("get_value", &UserType::value, "Get value using a method") - .def("set_value", &UserType::set, "Set value using a method") - .def_property("value", &UserType::value, &UserType::set, "Get/set value using a property") - .def("__repr__", [](const UserType& u) { return "UserType({})"_s.format(u.value()); }); - - py::class_(m, "IncType") - .def(py::init<>()) - .def(py::init()) - .def("__repr__", [](const IncType& u) { return "IncType({})"_s.format(u.value()); }); - - for (const auto &initializer : initializers()) - initializer(m); - - if (!py::hasattr(m, "have_eigen")) m.attr("have_eigen") = false; -} diff --git a/pybind11/tests/pybind11_tests.h b/pybind11/tests/pybind11_tests.h deleted file mode 100644 index 90963a5..0000000 --- a/pybind11/tests/pybind11_tests.h +++ /dev/null @@ -1,65 +0,0 @@ -#pragma once -#include - -#if defined(_MSC_VER) && _MSC_VER < 1910 -// We get some really long type names here which causes MSVC 2015 to emit warnings -# pragma warning(disable: 4503) // warning C4503: decorated name length exceeded, name was truncated -#endif - -namespace py = pybind11; -using namespace pybind11::literals; - -class test_initializer { - using Initializer = void (*)(py::module &); - -public: - test_initializer(Initializer init); - test_initializer(const char *submodule_name, Initializer init); -}; - -#define TEST_SUBMODULE(name, variable) \ - void test_submodule_##name(py::module &); \ - test_initializer name(#name, test_submodule_##name); \ - void test_submodule_##name(py::module &variable) - - -/// Dummy type which is not exported anywhere -- something to trigger a conversion error -struct UnregisteredType { }; - -/// A user-defined type which is exported and can be used by any test -class UserType { -public: - UserType() = default; - UserType(int i) : i(i) { } - - int value() const { return i; } - void set(int set) { i = set; } - -private: - int i = -1; -}; - -/// Like UserType, but increments `value` on copy for quick reference vs. copy tests -class IncType : public UserType { -public: - using UserType::UserType; - IncType() = default; - IncType(const IncType &other) : IncType(other.value() + 1) { } - IncType(IncType &&) = delete; - IncType &operator=(const IncType &) = delete; - IncType &operator=(IncType &&) = delete; -}; - -/// Custom cast-only type that casts to a string "rvalue" or "lvalue" depending on the cast context. -/// Used to test recursive casters (e.g. std::tuple, stl containers). -struct RValueCaster {}; -NAMESPACE_BEGIN(pybind11) -NAMESPACE_BEGIN(detail) -template<> class type_caster { -public: - PYBIND11_TYPE_CASTER(RValueCaster, _("RValueCaster")); - static handle cast(RValueCaster &&, return_value_policy, handle) { return py::str("rvalue").release(); } - static handle cast(const RValueCaster &, return_value_policy, handle) { return py::str("lvalue").release(); } -}; -NAMESPACE_END(detail) -NAMESPACE_END(pybind11) diff --git a/pybind11/tests/pytest.ini b/pybind11/tests/pytest.ini deleted file mode 100644 index f209964..0000000 --- a/pybind11/tests/pytest.ini +++ /dev/null @@ -1,16 +0,0 @@ -[pytest] -minversion = 3.0 -norecursedirs = test_cmake_build test_embed -addopts = - # show summary of skipped tests - -rs - # capture only Python print and C++ py::print, but not C output (low-level Python errors) - --capture=sys -filterwarnings = - # make warnings into errors but ignore certain third-party extension issues - error - # importing scipy submodules on some version of Python - ignore::ImportWarning - # bogus numpy ABI warning (see numpy/#432) - ignore:.*numpy.dtype size changed.*:RuntimeWarning - ignore:.*numpy.ufunc size changed.*:RuntimeWarning diff --git a/pybind11/tests/test_async.cpp b/pybind11/tests/test_async.cpp deleted file mode 100644 index f0ad0d5..0000000 --- a/pybind11/tests/test_async.cpp +++ /dev/null @@ -1,26 +0,0 @@ -/* - tests/test_async.cpp -- __await__ support - - Copyright (c) 2019 Google Inc. - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" - -TEST_SUBMODULE(async_module, m) { - struct DoesNotSupportAsync {}; - py::class_(m, "DoesNotSupportAsync") - .def(py::init<>()); - struct SupportsAsync {}; - py::class_(m, "SupportsAsync") - .def(py::init<>()) - .def("__await__", [](const SupportsAsync& self) -> py::object { - static_cast(self); - py::object loop = py::module::import("asyncio.events").attr("get_event_loop")(); - py::object f = loop.attr("create_future")(); - f.attr("set_result")(5); - return f.attr("__await__")(); - }); -} diff --git a/pybind11/tests/test_async.py b/pybind11/tests/test_async.py deleted file mode 100644 index e1c959d..0000000 --- a/pybind11/tests/test_async.py +++ /dev/null @@ -1,23 +0,0 @@ -import asyncio -import pytest -from pybind11_tests import async_module as m - - -@pytest.fixture -def event_loop(): - loop = asyncio.new_event_loop() - yield loop - loop.close() - - -async def get_await_result(x): - return await x - - -def test_await(event_loop): - assert 5 == event_loop.run_until_complete(get_await_result(m.SupportsAsync())) - - -def test_await_missing(event_loop): - with pytest.raises(TypeError): - event_loop.run_until_complete(get_await_result(m.DoesNotSupportAsync())) diff --git a/pybind11/tests/test_buffers.cpp b/pybind11/tests/test_buffers.cpp deleted file mode 100644 index 433dfee..0000000 --- a/pybind11/tests/test_buffers.cpp +++ /dev/null @@ -1,169 +0,0 @@ -/* - tests/test_buffers.cpp -- supporting Pythons' buffer protocol - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" -#include "constructor_stats.h" - -TEST_SUBMODULE(buffers, m) { - // test_from_python / test_to_python: - class Matrix { - public: - Matrix(ssize_t rows, ssize_t cols) : m_rows(rows), m_cols(cols) { - print_created(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix"); - m_data = new float[(size_t) (rows*cols)]; - memset(m_data, 0, sizeof(float) * (size_t) (rows * cols)); - } - - Matrix(const Matrix &s) : m_rows(s.m_rows), m_cols(s.m_cols) { - print_copy_created(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix"); - m_data = new float[(size_t) (m_rows * m_cols)]; - memcpy(m_data, s.m_data, sizeof(float) * (size_t) (m_rows * m_cols)); - } - - Matrix(Matrix &&s) : m_rows(s.m_rows), m_cols(s.m_cols), m_data(s.m_data) { - print_move_created(this); - s.m_rows = 0; - s.m_cols = 0; - s.m_data = nullptr; - } - - ~Matrix() { - print_destroyed(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix"); - delete[] m_data; - } - - Matrix &operator=(const Matrix &s) { - print_copy_assigned(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix"); - delete[] m_data; - m_rows = s.m_rows; - m_cols = s.m_cols; - m_data = new float[(size_t) (m_rows * m_cols)]; - memcpy(m_data, s.m_data, sizeof(float) * (size_t) (m_rows * m_cols)); - return *this; - } - - Matrix &operator=(Matrix &&s) { - print_move_assigned(this, std::to_string(m_rows) + "x" + std::to_string(m_cols) + " matrix"); - if (&s != this) { - delete[] m_data; - m_rows = s.m_rows; m_cols = s.m_cols; m_data = s.m_data; - s.m_rows = 0; s.m_cols = 0; s.m_data = nullptr; - } - return *this; - } - - float operator()(ssize_t i, ssize_t j) const { - return m_data[(size_t) (i*m_cols + j)]; - } - - float &operator()(ssize_t i, ssize_t j) { - return m_data[(size_t) (i*m_cols + j)]; - } - - float *data() { return m_data; } - - ssize_t rows() const { return m_rows; } - ssize_t cols() const { return m_cols; } - private: - ssize_t m_rows; - ssize_t m_cols; - float *m_data; - }; - py::class_(m, "Matrix", py::buffer_protocol()) - .def(py::init()) - /// Construct from a buffer - .def(py::init([](py::buffer const b) { - py::buffer_info info = b.request(); - if (info.format != py::format_descriptor::format() || info.ndim != 2) - throw std::runtime_error("Incompatible buffer format!"); - - auto v = new Matrix(info.shape[0], info.shape[1]); - memcpy(v->data(), info.ptr, sizeof(float) * (size_t) (v->rows() * v->cols())); - return v; - })) - - .def("rows", &Matrix::rows) - .def("cols", &Matrix::cols) - - /// Bare bones interface - .def("__getitem__", [](const Matrix &m, std::pair i) { - if (i.first >= m.rows() || i.second >= m.cols()) - throw py::index_error(); - return m(i.first, i.second); - }) - .def("__setitem__", [](Matrix &m, std::pair i, float v) { - if (i.first >= m.rows() || i.second >= m.cols()) - throw py::index_error(); - m(i.first, i.second) = v; - }) - /// Provide buffer access - .def_buffer([](Matrix &m) -> py::buffer_info { - return py::buffer_info( - m.data(), /* Pointer to buffer */ - { m.rows(), m.cols() }, /* Buffer dimensions */ - { sizeof(float) * size_t(m.cols()), /* Strides (in bytes) for each index */ - sizeof(float) } - ); - }) - ; - - - // test_inherited_protocol - class SquareMatrix : public Matrix { - public: - SquareMatrix(ssize_t n) : Matrix(n, n) { } - }; - // Derived classes inherit the buffer protocol and the buffer access function - py::class_(m, "SquareMatrix") - .def(py::init()); - - - // test_pointer_to_member_fn - // Tests that passing a pointer to member to the base class works in - // the derived class. - struct Buffer { - int32_t value = 0; - - py::buffer_info get_buffer_info() { - return py::buffer_info(&value, sizeof(value), - py::format_descriptor::format(), 1); - } - }; - py::class_(m, "Buffer", py::buffer_protocol()) - .def(py::init<>()) - .def_readwrite("value", &Buffer::value) - .def_buffer(&Buffer::get_buffer_info); - - - class ConstBuffer { - std::unique_ptr value; - - public: - int32_t get_value() const { return *value; } - void set_value(int32_t v) { *value = v; } - - py::buffer_info get_buffer_info() const { - return py::buffer_info(value.get(), sizeof(*value), - py::format_descriptor::format(), 1); - } - - ConstBuffer() : value(new int32_t{0}) { }; - }; - py::class_(m, "ConstBuffer", py::buffer_protocol()) - .def(py::init<>()) - .def_property("value", &ConstBuffer::get_value, &ConstBuffer::set_value) - .def_buffer(&ConstBuffer::get_buffer_info); - - struct DerivedBuffer : public Buffer { }; - py::class_(m, "DerivedBuffer", py::buffer_protocol()) - .def(py::init<>()) - .def_readwrite("value", (int32_t DerivedBuffer::*) &DerivedBuffer::value) - .def_buffer(&DerivedBuffer::get_buffer_info); - -} diff --git a/pybind11/tests/test_buffers.py b/pybind11/tests/test_buffers.py deleted file mode 100644 index f006552..0000000 --- a/pybind11/tests/test_buffers.py +++ /dev/null @@ -1,87 +0,0 @@ -import struct -import pytest -from pybind11_tests import buffers as m -from pybind11_tests import ConstructorStats - -pytestmark = pytest.requires_numpy - -with pytest.suppress(ImportError): - import numpy as np - - -def test_from_python(): - with pytest.raises(RuntimeError) as excinfo: - m.Matrix(np.array([1, 2, 3])) # trying to assign a 1D array - assert str(excinfo.value) == "Incompatible buffer format!" - - m3 = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32) - m4 = m.Matrix(m3) - - for i in range(m4.rows()): - for j in range(m4.cols()): - assert m3[i, j] == m4[i, j] - - cstats = ConstructorStats.get(m.Matrix) - assert cstats.alive() == 1 - del m3, m4 - assert cstats.alive() == 0 - assert cstats.values() == ["2x3 matrix"] - assert cstats.copy_constructions == 0 - # assert cstats.move_constructions >= 0 # Don't invoke any - assert cstats.copy_assignments == 0 - assert cstats.move_assignments == 0 - - -# PyPy: Memory leak in the "np.array(m, copy=False)" call -# https://bitbucket.org/pypy/pypy/issues/2444 -@pytest.unsupported_on_pypy -def test_to_python(): - mat = m.Matrix(5, 4) - assert memoryview(mat).shape == (5, 4) - - assert mat[2, 3] == 0 - mat[2, 3] = 4.0 - mat[3, 2] = 7.0 - assert mat[2, 3] == 4 - assert mat[3, 2] == 7 - assert struct.unpack_from('f', mat, (3 * 4 + 2) * 4) == (7, ) - assert struct.unpack_from('f', mat, (2 * 4 + 3) * 4) == (4, ) - - mat2 = np.array(mat, copy=False) - assert mat2.shape == (5, 4) - assert abs(mat2).sum() == 11 - assert mat2[2, 3] == 4 and mat2[3, 2] == 7 - mat2[2, 3] = 5 - assert mat2[2, 3] == 5 - - cstats = ConstructorStats.get(m.Matrix) - assert cstats.alive() == 1 - del mat - pytest.gc_collect() - assert cstats.alive() == 1 - del mat2 # holds a mat reference - pytest.gc_collect() - assert cstats.alive() == 0 - assert cstats.values() == ["5x4 matrix"] - assert cstats.copy_constructions == 0 - # assert cstats.move_constructions >= 0 # Don't invoke any - assert cstats.copy_assignments == 0 - assert cstats.move_assignments == 0 - - -@pytest.unsupported_on_pypy -def test_inherited_protocol(): - """SquareMatrix is derived from Matrix and inherits the buffer protocol""" - - matrix = m.SquareMatrix(5) - assert memoryview(matrix).shape == (5, 5) - assert np.asarray(matrix).shape == (5, 5) - - -@pytest.unsupported_on_pypy -def test_pointer_to_member_fn(): - for cls in [m.Buffer, m.ConstBuffer, m.DerivedBuffer]: - buf = cls() - buf.value = 0x12345678 - value = struct.unpack('i', bytearray(buf))[0] - assert value == 0x12345678 diff --git a/pybind11/tests/test_builtin_casters.cpp b/pybind11/tests/test_builtin_casters.cpp deleted file mode 100644 index e026127..0000000 --- a/pybind11/tests/test_builtin_casters.cpp +++ /dev/null @@ -1,170 +0,0 @@ -/* - tests/test_builtin_casters.cpp -- Casters available without any additional headers - - Copyright (c) 2017 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" -#include - -#if defined(_MSC_VER) -# pragma warning(push) -# pragma warning(disable: 4127) // warning C4127: Conditional expression is constant -#endif - -TEST_SUBMODULE(builtin_casters, m) { - // test_simple_string - m.def("string_roundtrip", [](const char *s) { return s; }); - - // test_unicode_conversion - // Some test characters in utf16 and utf32 encodings. The last one (the 𝐀) contains a null byte - char32_t a32 = 0x61 /*a*/, z32 = 0x7a /*z*/, ib32 = 0x203d /*‽*/, cake32 = 0x1f382 /*🎂*/, mathbfA32 = 0x1d400 /*𝐀*/; - char16_t b16 = 0x62 /*b*/, z16 = 0x7a, ib16 = 0x203d, cake16_1 = 0xd83c, cake16_2 = 0xdf82, mathbfA16_1 = 0xd835, mathbfA16_2 = 0xdc00; - std::wstring wstr; - wstr.push_back(0x61); // a - wstr.push_back(0x2e18); // ⸘ - if (sizeof(wchar_t) == 2) { wstr.push_back(mathbfA16_1); wstr.push_back(mathbfA16_2); } // 𝐀, utf16 - else { wstr.push_back((wchar_t) mathbfA32); } // 𝐀, utf32 - wstr.push_back(0x7a); // z - - m.def("good_utf8_string", []() { return std::string(u8"Say utf8\u203d \U0001f382 \U0001d400"); }); // Say utf8‽ 🎂 𝐀 - m.def("good_utf16_string", [=]() { return std::u16string({ b16, ib16, cake16_1, cake16_2, mathbfA16_1, mathbfA16_2, z16 }); }); // b‽🎂𝐀z - m.def("good_utf32_string", [=]() { return std::u32string({ a32, mathbfA32, cake32, ib32, z32 }); }); // a𝐀🎂‽z - m.def("good_wchar_string", [=]() { return wstr; }); // a‽𝐀z - m.def("bad_utf8_string", []() { return std::string("abc\xd0" "def"); }); - m.def("bad_utf16_string", [=]() { return std::u16string({ b16, char16_t(0xd800), z16 }); }); - // Under Python 2.7, invalid unicode UTF-32 characters don't appear to trigger UnicodeDecodeError - if (PY_MAJOR_VERSION >= 3) - m.def("bad_utf32_string", [=]() { return std::u32string({ a32, char32_t(0xd800), z32 }); }); - if (PY_MAJOR_VERSION >= 3 || sizeof(wchar_t) == 2) - m.def("bad_wchar_string", [=]() { return std::wstring({ wchar_t(0x61), wchar_t(0xd800) }); }); - m.def("u8_Z", []() -> char { return 'Z'; }); - m.def("u8_eacute", []() -> char { return '\xe9'; }); - m.def("u16_ibang", [=]() -> char16_t { return ib16; }); - m.def("u32_mathbfA", [=]() -> char32_t { return mathbfA32; }); - m.def("wchar_heart", []() -> wchar_t { return 0x2665; }); - - // test_single_char_arguments - m.attr("wchar_size") = py::cast(sizeof(wchar_t)); - m.def("ord_char", [](char c) -> int { return static_cast(c); }); - m.def("ord_char_lv", [](char &c) -> int { return static_cast(c); }); - m.def("ord_char16", [](char16_t c) -> uint16_t { return c; }); - m.def("ord_char16_lv", [](char16_t &c) -> uint16_t { return c; }); - m.def("ord_char32", [](char32_t c) -> uint32_t { return c; }); - m.def("ord_wchar", [](wchar_t c) -> int { return c; }); - - // test_bytes_to_string - m.def("strlen", [](char *s) { return strlen(s); }); - m.def("string_length", [](std::string s) { return s.length(); }); - - // test_string_view -#ifdef PYBIND11_HAS_STRING_VIEW - m.attr("has_string_view") = true; - m.def("string_view_print", [](std::string_view s) { py::print(s, s.size()); }); - m.def("string_view16_print", [](std::u16string_view s) { py::print(s, s.size()); }); - m.def("string_view32_print", [](std::u32string_view s) { py::print(s, s.size()); }); - m.def("string_view_chars", [](std::string_view s) { py::list l; for (auto c : s) l.append((std::uint8_t) c); return l; }); - m.def("string_view16_chars", [](std::u16string_view s) { py::list l; for (auto c : s) l.append((int) c); return l; }); - m.def("string_view32_chars", [](std::u32string_view s) { py::list l; for (auto c : s) l.append((int) c); return l; }); - m.def("string_view_return", []() { return std::string_view(u8"utf8 secret \U0001f382"); }); - m.def("string_view16_return", []() { return std::u16string_view(u"utf16 secret \U0001f382"); }); - m.def("string_view32_return", []() { return std::u32string_view(U"utf32 secret \U0001f382"); }); -#endif - - // test_integer_casting - m.def("i32_str", [](std::int32_t v) { return std::to_string(v); }); - m.def("u32_str", [](std::uint32_t v) { return std::to_string(v); }); - m.def("i64_str", [](std::int64_t v) { return std::to_string(v); }); - m.def("u64_str", [](std::uint64_t v) { return std::to_string(v); }); - - // test_tuple - m.def("pair_passthrough", [](std::pair input) { - return std::make_pair(input.second, input.first); - }, "Return a pair in reversed order"); - m.def("tuple_passthrough", [](std::tuple input) { - return std::make_tuple(std::get<2>(input), std::get<1>(input), std::get<0>(input)); - }, "Return a triple in reversed order"); - m.def("empty_tuple", []() { return std::tuple<>(); }); - static std::pair lvpair; - static std::tuple lvtuple; - static std::pair>> lvnested; - m.def("rvalue_pair", []() { return std::make_pair(RValueCaster{}, RValueCaster{}); }); - m.def("lvalue_pair", []() -> const decltype(lvpair) & { return lvpair; }); - m.def("rvalue_tuple", []() { return std::make_tuple(RValueCaster{}, RValueCaster{}, RValueCaster{}); }); - m.def("lvalue_tuple", []() -> const decltype(lvtuple) & { return lvtuple; }); - m.def("rvalue_nested", []() { - return std::make_pair(RValueCaster{}, std::make_tuple(RValueCaster{}, std::make_pair(RValueCaster{}, RValueCaster{}))); }); - m.def("lvalue_nested", []() -> const decltype(lvnested) & { return lvnested; }); - - // test_builtins_cast_return_none - m.def("return_none_string", []() -> std::string * { return nullptr; }); - m.def("return_none_char", []() -> const char * { return nullptr; }); - m.def("return_none_bool", []() -> bool * { return nullptr; }); - m.def("return_none_int", []() -> int * { return nullptr; }); - m.def("return_none_float", []() -> float * { return nullptr; }); - - // test_none_deferred - m.def("defer_none_cstring", [](char *) { return false; }); - m.def("defer_none_cstring", [](py::none) { return true; }); - m.def("defer_none_custom", [](UserType *) { return false; }); - m.def("defer_none_custom", [](py::none) { return true; }); - m.def("nodefer_none_void", [](void *) { return true; }); - m.def("nodefer_none_void", [](py::none) { return false; }); - - // test_void_caster - m.def("load_nullptr_t", [](std::nullptr_t) {}); // not useful, but it should still compile - m.def("cast_nullptr_t", []() { return std::nullptr_t{}; }); - - // test_bool_caster - m.def("bool_passthrough", [](bool arg) { return arg; }); - m.def("bool_passthrough_noconvert", [](bool arg) { return arg; }, py::arg().noconvert()); - - // test_reference_wrapper - m.def("refwrap_builtin", [](std::reference_wrapper p) { return 10 * p.get(); }); - m.def("refwrap_usertype", [](std::reference_wrapper p) { return p.get().value(); }); - // Not currently supported (std::pair caster has return-by-value cast operator); - // triggers static_assert failure. - //m.def("refwrap_pair", [](std::reference_wrapper>) { }); - - m.def("refwrap_list", [](bool copy) { - static IncType x1(1), x2(2); - py::list l; - for (auto &f : {std::ref(x1), std::ref(x2)}) { - l.append(py::cast(f, copy ? py::return_value_policy::copy - : py::return_value_policy::reference)); - } - return l; - }, "copy"_a); - - m.def("refwrap_iiw", [](const IncType &w) { return w.value(); }); - m.def("refwrap_call_iiw", [](IncType &w, py::function f) { - py::list l; - l.append(f(std::ref(w))); - l.append(f(std::cref(w))); - IncType x(w.value()); - l.append(f(std::ref(x))); - IncType y(w.value()); - auto r3 = std::ref(y); - l.append(f(r3)); - return l; - }); - - // test_complex - m.def("complex_cast", [](float x) { return "{}"_s.format(x); }); - m.def("complex_cast", [](std::complex x) { return "({}, {})"_s.format(x.real(), x.imag()); }); - - // test int vs. long (Python 2) - m.def("int_cast", []() {return (int) 42;}); - m.def("long_cast", []() {return (long) 42;}); - m.def("longlong_cast", []() {return ULLONG_MAX;}); - - /// test void* cast operator - m.def("test_void_caster", []() -> bool { - void *v = (void *) 0xabcd; - py::object o = py::cast(v); - return py::cast(o) == v; - }); -} diff --git a/pybind11/tests/test_builtin_casters.py b/pybind11/tests/test_builtin_casters.py deleted file mode 100644 index 73cc465..0000000 --- a/pybind11/tests/test_builtin_casters.py +++ /dev/null @@ -1,342 +0,0 @@ -# Python < 3 needs this: coding=utf-8 -import pytest - -from pybind11_tests import builtin_casters as m -from pybind11_tests import UserType, IncType - - -def test_simple_string(): - assert m.string_roundtrip("const char *") == "const char *" - - -def test_unicode_conversion(): - """Tests unicode conversion and error reporting.""" - assert m.good_utf8_string() == u"Say utf8‽ 🎂 𝐀" - assert m.good_utf16_string() == u"b‽🎂𝐀z" - assert m.good_utf32_string() == u"a𝐀🎂‽z" - assert m.good_wchar_string() == u"a⸘𝐀z" - - with pytest.raises(UnicodeDecodeError): - m.bad_utf8_string() - - with pytest.raises(UnicodeDecodeError): - m.bad_utf16_string() - - # These are provided only if they actually fail (they don't when 32-bit and under Python 2.7) - if hasattr(m, "bad_utf32_string"): - with pytest.raises(UnicodeDecodeError): - m.bad_utf32_string() - if hasattr(m, "bad_wchar_string"): - with pytest.raises(UnicodeDecodeError): - m.bad_wchar_string() - - assert m.u8_Z() == 'Z' - assert m.u8_eacute() == u'é' - assert m.u16_ibang() == u'‽' - assert m.u32_mathbfA() == u'𝐀' - assert m.wchar_heart() == u'♥' - - -def test_single_char_arguments(): - """Tests failures for passing invalid inputs to char-accepting functions""" - def toobig_message(r): - return "Character code point not in range({0:#x})".format(r) - toolong_message = "Expected a character, but multi-character string found" - - assert m.ord_char(u'a') == 0x61 # simple ASCII - assert m.ord_char_lv(u'b') == 0x62 - assert m.ord_char(u'é') == 0xE9 # requires 2 bytes in utf-8, but can be stuffed in a char - with pytest.raises(ValueError) as excinfo: - assert m.ord_char(u'Ā') == 0x100 # requires 2 bytes, doesn't fit in a char - assert str(excinfo.value) == toobig_message(0x100) - with pytest.raises(ValueError) as excinfo: - assert m.ord_char(u'ab') - assert str(excinfo.value) == toolong_message - - assert m.ord_char16(u'a') == 0x61 - assert m.ord_char16(u'é') == 0xE9 - assert m.ord_char16_lv(u'ê') == 0xEA - assert m.ord_char16(u'Ā') == 0x100 - assert m.ord_char16(u'‽') == 0x203d - assert m.ord_char16(u'♥') == 0x2665 - assert m.ord_char16_lv(u'♡') == 0x2661 - with pytest.raises(ValueError) as excinfo: - assert m.ord_char16(u'🎂') == 0x1F382 # requires surrogate pair - assert str(excinfo.value) == toobig_message(0x10000) - with pytest.raises(ValueError) as excinfo: - assert m.ord_char16(u'aa') - assert str(excinfo.value) == toolong_message - - assert m.ord_char32(u'a') == 0x61 - assert m.ord_char32(u'é') == 0xE9 - assert m.ord_char32(u'Ā') == 0x100 - assert m.ord_char32(u'‽') == 0x203d - assert m.ord_char32(u'♥') == 0x2665 - assert m.ord_char32(u'🎂') == 0x1F382 - with pytest.raises(ValueError) as excinfo: - assert m.ord_char32(u'aa') - assert str(excinfo.value) == toolong_message - - assert m.ord_wchar(u'a') == 0x61 - assert m.ord_wchar(u'é') == 0xE9 - assert m.ord_wchar(u'Ā') == 0x100 - assert m.ord_wchar(u'‽') == 0x203d - assert m.ord_wchar(u'♥') == 0x2665 - if m.wchar_size == 2: - with pytest.raises(ValueError) as excinfo: - assert m.ord_wchar(u'🎂') == 0x1F382 # requires surrogate pair - assert str(excinfo.value) == toobig_message(0x10000) - else: - assert m.ord_wchar(u'🎂') == 0x1F382 - with pytest.raises(ValueError) as excinfo: - assert m.ord_wchar(u'aa') - assert str(excinfo.value) == toolong_message - - -def test_bytes_to_string(): - """Tests the ability to pass bytes to C++ string-accepting functions. Note that this is - one-way: the only way to return bytes to Python is via the pybind11::bytes class.""" - # Issue #816 - import sys - byte = bytes if sys.version_info[0] < 3 else str - - assert m.strlen(byte("hi")) == 2 - assert m.string_length(byte("world")) == 5 - assert m.string_length(byte("a\x00b")) == 3 - assert m.strlen(byte("a\x00b")) == 1 # C-string limitation - - # passing in a utf8 encoded string should work - assert m.string_length(u'💩'.encode("utf8")) == 4 - - -@pytest.mark.skipif(not hasattr(m, "has_string_view"), reason="no ") -def test_string_view(capture): - """Tests support for C++17 string_view arguments and return values""" - assert m.string_view_chars("Hi") == [72, 105] - assert m.string_view_chars("Hi 🎂") == [72, 105, 32, 0xf0, 0x9f, 0x8e, 0x82] - assert m.string_view16_chars("Hi 🎂") == [72, 105, 32, 0xd83c, 0xdf82] - assert m.string_view32_chars("Hi 🎂") == [72, 105, 32, 127874] - - assert m.string_view_return() == "utf8 secret 🎂" - assert m.string_view16_return() == "utf16 secret 🎂" - assert m.string_view32_return() == "utf32 secret 🎂" - - with capture: - m.string_view_print("Hi") - m.string_view_print("utf8 🎂") - m.string_view16_print("utf16 🎂") - m.string_view32_print("utf32 🎂") - assert capture == """ - Hi 2 - utf8 🎂 9 - utf16 🎂 8 - utf32 🎂 7 - """ - - with capture: - m.string_view_print("Hi, ascii") - m.string_view_print("Hi, utf8 🎂") - m.string_view16_print("Hi, utf16 🎂") - m.string_view32_print("Hi, utf32 🎂") - assert capture == """ - Hi, ascii 9 - Hi, utf8 🎂 13 - Hi, utf16 🎂 12 - Hi, utf32 🎂 11 - """ - - -def test_integer_casting(): - """Issue #929 - out-of-range integer values shouldn't be accepted""" - import sys - assert m.i32_str(-1) == "-1" - assert m.i64_str(-1) == "-1" - assert m.i32_str(2000000000) == "2000000000" - assert m.u32_str(2000000000) == "2000000000" - if sys.version_info < (3,): - assert m.i32_str(long(-1)) == "-1" # noqa: F821 undefined name 'long' - assert m.i64_str(long(-1)) == "-1" # noqa: F821 undefined name 'long' - assert m.i64_str(long(-999999999999)) == "-999999999999" # noqa: F821 undefined name - assert m.u64_str(long(999999999999)) == "999999999999" # noqa: F821 undefined name 'long' - else: - assert m.i64_str(-999999999999) == "-999999999999" - assert m.u64_str(999999999999) == "999999999999" - - with pytest.raises(TypeError) as excinfo: - m.u32_str(-1) - assert "incompatible function arguments" in str(excinfo.value) - with pytest.raises(TypeError) as excinfo: - m.u64_str(-1) - assert "incompatible function arguments" in str(excinfo.value) - with pytest.raises(TypeError) as excinfo: - m.i32_str(-3000000000) - assert "incompatible function arguments" in str(excinfo.value) - with pytest.raises(TypeError) as excinfo: - m.i32_str(3000000000) - assert "incompatible function arguments" in str(excinfo.value) - - if sys.version_info < (3,): - with pytest.raises(TypeError) as excinfo: - m.u32_str(long(-1)) # noqa: F821 undefined name 'long' - assert "incompatible function arguments" in str(excinfo.value) - with pytest.raises(TypeError) as excinfo: - m.u64_str(long(-1)) # noqa: F821 undefined name 'long' - assert "incompatible function arguments" in str(excinfo.value) - - -def test_tuple(doc): - """std::pair <-> tuple & std::tuple <-> tuple""" - assert m.pair_passthrough((True, "test")) == ("test", True) - assert m.tuple_passthrough((True, "test", 5)) == (5, "test", True) - # Any sequence can be cast to a std::pair or std::tuple - assert m.pair_passthrough([True, "test"]) == ("test", True) - assert m.tuple_passthrough([True, "test", 5]) == (5, "test", True) - assert m.empty_tuple() == () - - assert doc(m.pair_passthrough) == """ - pair_passthrough(arg0: Tuple[bool, str]) -> Tuple[str, bool] - - Return a pair in reversed order - """ - assert doc(m.tuple_passthrough) == """ - tuple_passthrough(arg0: Tuple[bool, str, int]) -> Tuple[int, str, bool] - - Return a triple in reversed order - """ - - assert m.rvalue_pair() == ("rvalue", "rvalue") - assert m.lvalue_pair() == ("lvalue", "lvalue") - assert m.rvalue_tuple() == ("rvalue", "rvalue", "rvalue") - assert m.lvalue_tuple() == ("lvalue", "lvalue", "lvalue") - assert m.rvalue_nested() == ("rvalue", ("rvalue", ("rvalue", "rvalue"))) - assert m.lvalue_nested() == ("lvalue", ("lvalue", ("lvalue", "lvalue"))) - - -def test_builtins_cast_return_none(): - """Casters produced with PYBIND11_TYPE_CASTER() should convert nullptr to None""" - assert m.return_none_string() is None - assert m.return_none_char() is None - assert m.return_none_bool() is None - assert m.return_none_int() is None - assert m.return_none_float() is None - - -def test_none_deferred(): - """None passed as various argument types should defer to other overloads""" - assert not m.defer_none_cstring("abc") - assert m.defer_none_cstring(None) - assert not m.defer_none_custom(UserType()) - assert m.defer_none_custom(None) - assert m.nodefer_none_void(None) - - -def test_void_caster(): - assert m.load_nullptr_t(None) is None - assert m.cast_nullptr_t() is None - - -def test_reference_wrapper(): - """std::reference_wrapper for builtin and user types""" - assert m.refwrap_builtin(42) == 420 - assert m.refwrap_usertype(UserType(42)) == 42 - - with pytest.raises(TypeError) as excinfo: - m.refwrap_builtin(None) - assert "incompatible function arguments" in str(excinfo.value) - - with pytest.raises(TypeError) as excinfo: - m.refwrap_usertype(None) - assert "incompatible function arguments" in str(excinfo.value) - - a1 = m.refwrap_list(copy=True) - a2 = m.refwrap_list(copy=True) - assert [x.value for x in a1] == [2, 3] - assert [x.value for x in a2] == [2, 3] - assert not a1[0] is a2[0] and not a1[1] is a2[1] - - b1 = m.refwrap_list(copy=False) - b2 = m.refwrap_list(copy=False) - assert [x.value for x in b1] == [1, 2] - assert [x.value for x in b2] == [1, 2] - assert b1[0] is b2[0] and b1[1] is b2[1] - - assert m.refwrap_iiw(IncType(5)) == 5 - assert m.refwrap_call_iiw(IncType(10), m.refwrap_iiw) == [10, 10, 10, 10] - - -def test_complex_cast(): - """std::complex casts""" - assert m.complex_cast(1) == "1.0" - assert m.complex_cast(2j) == "(0.0, 2.0)" - - -def test_bool_caster(): - """Test bool caster implicit conversions.""" - convert, noconvert = m.bool_passthrough, m.bool_passthrough_noconvert - - def require_implicit(v): - pytest.raises(TypeError, noconvert, v) - - def cant_convert(v): - pytest.raises(TypeError, convert, v) - - # straight up bool - assert convert(True) is True - assert convert(False) is False - assert noconvert(True) is True - assert noconvert(False) is False - - # None requires implicit conversion - require_implicit(None) - assert convert(None) is False - - class A(object): - def __init__(self, x): - self.x = x - - def __nonzero__(self): - return self.x - - def __bool__(self): - return self.x - - class B(object): - pass - - # Arbitrary objects are not accepted - cant_convert(object()) - cant_convert(B()) - - # Objects with __nonzero__ / __bool__ defined can be converted - require_implicit(A(True)) - assert convert(A(True)) is True - assert convert(A(False)) is False - - -@pytest.requires_numpy -def test_numpy_bool(): - import numpy as np - convert, noconvert = m.bool_passthrough, m.bool_passthrough_noconvert - - # np.bool_ is not considered implicit - assert convert(np.bool_(True)) is True - assert convert(np.bool_(False)) is False - assert noconvert(np.bool_(True)) is True - assert noconvert(np.bool_(False)) is False - - -def test_int_long(): - """In Python 2, a C++ int should return a Python int rather than long - if possible: longs are not always accepted where ints are used (such - as the argument to sys.exit()). A C++ long long is always a Python - long.""" - - import sys - must_be_long = type(getattr(sys, 'maxint', 1) + 1) - assert isinstance(m.int_cast(), int) - assert isinstance(m.long_cast(), int) - assert isinstance(m.longlong_cast(), must_be_long) - - -def test_void_caster_2(): - assert m.test_void_caster() diff --git a/pybind11/tests/test_call_policies.cpp b/pybind11/tests/test_call_policies.cpp deleted file mode 100644 index fd24557..0000000 --- a/pybind11/tests/test_call_policies.cpp +++ /dev/null @@ -1,100 +0,0 @@ -/* - tests/test_call_policies.cpp -- keep_alive and call_guard - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" - -struct CustomGuard { - static bool enabled; - - CustomGuard() { enabled = true; } - ~CustomGuard() { enabled = false; } - - static const char *report_status() { return enabled ? "guarded" : "unguarded"; } -}; -bool CustomGuard::enabled = false; - -struct DependentGuard { - static bool enabled; - - DependentGuard() { enabled = CustomGuard::enabled; } - ~DependentGuard() { enabled = false; } - - static const char *report_status() { return enabled ? "guarded" : "unguarded"; } -}; -bool DependentGuard::enabled = false; - -TEST_SUBMODULE(call_policies, m) { - // Parent/Child are used in: - // test_keep_alive_argument, test_keep_alive_return_value, test_alive_gc_derived, - // test_alive_gc_multi_derived, test_return_none, test_keep_alive_constructor - class Child { - public: - Child() { py::print("Allocating child."); } - Child(const Child &) = default; - Child(Child &&) = default; - ~Child() { py::print("Releasing child."); } - }; - py::class_(m, "Child") - .def(py::init<>()); - - class Parent { - public: - Parent() { py::print("Allocating parent."); } - ~Parent() { py::print("Releasing parent."); } - void addChild(Child *) { } - Child *returnChild() { return new Child(); } - Child *returnNullChild() { return nullptr; } - }; - py::class_(m, "Parent") - .def(py::init<>()) - .def(py::init([](Child *) { return new Parent(); }), py::keep_alive<1, 2>()) - .def("addChild", &Parent::addChild) - .def("addChildKeepAlive", &Parent::addChild, py::keep_alive<1, 2>()) - .def("returnChild", &Parent::returnChild) - .def("returnChildKeepAlive", &Parent::returnChild, py::keep_alive<1, 0>()) - .def("returnNullChildKeepAliveChild", &Parent::returnNullChild, py::keep_alive<1, 0>()) - .def("returnNullChildKeepAliveParent", &Parent::returnNullChild, py::keep_alive<0, 1>()); - -#if !defined(PYPY_VERSION) - // test_alive_gc - class ParentGC : public Parent { - public: - using Parent::Parent; - }; - py::class_(m, "ParentGC", py::dynamic_attr()) - .def(py::init<>()); -#endif - - // test_call_guard - m.def("unguarded_call", &CustomGuard::report_status); - m.def("guarded_call", &CustomGuard::report_status, py::call_guard()); - - m.def("multiple_guards_correct_order", []() { - return CustomGuard::report_status() + std::string(" & ") + DependentGuard::report_status(); - }, py::call_guard()); - - m.def("multiple_guards_wrong_order", []() { - return DependentGuard::report_status() + std::string(" & ") + CustomGuard::report_status(); - }, py::call_guard()); - -#if defined(WITH_THREAD) && !defined(PYPY_VERSION) - // `py::call_guard()` should work in PyPy as well, - // but it's unclear how to test it without `PyGILState_GetThisThreadState`. - auto report_gil_status = []() { - auto is_gil_held = false; - if (auto tstate = py::detail::get_thread_state_unchecked()) - is_gil_held = (tstate == PyGILState_GetThisThreadState()); - - return is_gil_held ? "GIL held" : "GIL released"; - }; - - m.def("with_gil", report_gil_status); - m.def("without_gil", report_gil_status, py::call_guard()); -#endif -} diff --git a/pybind11/tests/test_call_policies.py b/pybind11/tests/test_call_policies.py deleted file mode 100644 index 7c83559..0000000 --- a/pybind11/tests/test_call_policies.py +++ /dev/null @@ -1,187 +0,0 @@ -import pytest -from pybind11_tests import call_policies as m -from pybind11_tests import ConstructorStats - - -def test_keep_alive_argument(capture): - n_inst = ConstructorStats.detail_reg_inst() - with capture: - p = m.Parent() - assert capture == "Allocating parent." - with capture: - p.addChild(m.Child()) - assert ConstructorStats.detail_reg_inst() == n_inst + 1 - assert capture == """ - Allocating child. - Releasing child. - """ - with capture: - del p - assert ConstructorStats.detail_reg_inst() == n_inst - assert capture == "Releasing parent." - - with capture: - p = m.Parent() - assert capture == "Allocating parent." - with capture: - p.addChildKeepAlive(m.Child()) - assert ConstructorStats.detail_reg_inst() == n_inst + 2 - assert capture == "Allocating child." - with capture: - del p - assert ConstructorStats.detail_reg_inst() == n_inst - assert capture == """ - Releasing parent. - Releasing child. - """ - - -def test_keep_alive_return_value(capture): - n_inst = ConstructorStats.detail_reg_inst() - with capture: - p = m.Parent() - assert capture == "Allocating parent." - with capture: - p.returnChild() - assert ConstructorStats.detail_reg_inst() == n_inst + 1 - assert capture == """ - Allocating child. - Releasing child. - """ - with capture: - del p - assert ConstructorStats.detail_reg_inst() == n_inst - assert capture == "Releasing parent." - - with capture: - p = m.Parent() - assert capture == "Allocating parent." - with capture: - p.returnChildKeepAlive() - assert ConstructorStats.detail_reg_inst() == n_inst + 2 - assert capture == "Allocating child." - with capture: - del p - assert ConstructorStats.detail_reg_inst() == n_inst - assert capture == """ - Releasing parent. - Releasing child. - """ - - -# https://bitbucket.org/pypy/pypy/issues/2447 -@pytest.unsupported_on_pypy -def test_alive_gc(capture): - n_inst = ConstructorStats.detail_reg_inst() - p = m.ParentGC() - p.addChildKeepAlive(m.Child()) - assert ConstructorStats.detail_reg_inst() == n_inst + 2 - lst = [p] - lst.append(lst) # creates a circular reference - with capture: - del p, lst - assert ConstructorStats.detail_reg_inst() == n_inst - assert capture == """ - Releasing parent. - Releasing child. - """ - - -def test_alive_gc_derived(capture): - class Derived(m.Parent): - pass - - n_inst = ConstructorStats.detail_reg_inst() - p = Derived() - p.addChildKeepAlive(m.Child()) - assert ConstructorStats.detail_reg_inst() == n_inst + 2 - lst = [p] - lst.append(lst) # creates a circular reference - with capture: - del p, lst - assert ConstructorStats.detail_reg_inst() == n_inst - assert capture == """ - Releasing parent. - Releasing child. - """ - - -def test_alive_gc_multi_derived(capture): - class Derived(m.Parent, m.Child): - def __init__(self): - m.Parent.__init__(self) - m.Child.__init__(self) - - n_inst = ConstructorStats.detail_reg_inst() - p = Derived() - p.addChildKeepAlive(m.Child()) - # +3 rather than +2 because Derived corresponds to two registered instances - assert ConstructorStats.detail_reg_inst() == n_inst + 3 - lst = [p] - lst.append(lst) # creates a circular reference - with capture: - del p, lst - assert ConstructorStats.detail_reg_inst() == n_inst - assert capture == """ - Releasing parent. - Releasing child. - Releasing child. - """ - - -def test_return_none(capture): - n_inst = ConstructorStats.detail_reg_inst() - with capture: - p = m.Parent() - assert capture == "Allocating parent." - with capture: - p.returnNullChildKeepAliveChild() - assert ConstructorStats.detail_reg_inst() == n_inst + 1 - assert capture == "" - with capture: - del p - assert ConstructorStats.detail_reg_inst() == n_inst - assert capture == "Releasing parent." - - with capture: - p = m.Parent() - assert capture == "Allocating parent." - with capture: - p.returnNullChildKeepAliveParent() - assert ConstructorStats.detail_reg_inst() == n_inst + 1 - assert capture == "" - with capture: - del p - assert ConstructorStats.detail_reg_inst() == n_inst - assert capture == "Releasing parent." - - -def test_keep_alive_constructor(capture): - n_inst = ConstructorStats.detail_reg_inst() - - with capture: - p = m.Parent(m.Child()) - assert ConstructorStats.detail_reg_inst() == n_inst + 2 - assert capture == """ - Allocating child. - Allocating parent. - """ - with capture: - del p - assert ConstructorStats.detail_reg_inst() == n_inst - assert capture == """ - Releasing parent. - Releasing child. - """ - - -def test_call_guard(): - assert m.unguarded_call() == "unguarded" - assert m.guarded_call() == "guarded" - - assert m.multiple_guards_correct_order() == "guarded & guarded" - assert m.multiple_guards_wrong_order() == "unguarded & guarded" - - if hasattr(m, "with_gil"): - assert m.with_gil() == "GIL held" - assert m.without_gil() == "GIL released" diff --git a/pybind11/tests/test_callbacks.cpp b/pybind11/tests/test_callbacks.cpp deleted file mode 100644 index 71b88c4..0000000 --- a/pybind11/tests/test_callbacks.cpp +++ /dev/null @@ -1,168 +0,0 @@ -/* - tests/test_callbacks.cpp -- callbacks - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" -#include "constructor_stats.h" -#include -#include - - -int dummy_function(int i) { return i + 1; } - -TEST_SUBMODULE(callbacks, m) { - // test_callbacks, test_function_signatures - m.def("test_callback1", [](py::object func) { return func(); }); - m.def("test_callback2", [](py::object func) { return func("Hello", 'x', true, 5); }); - m.def("test_callback3", [](const std::function &func) { - return "func(43) = " + std::to_string(func(43)); }); - m.def("test_callback4", []() -> std::function { return [](int i) { return i+1; }; }); - m.def("test_callback5", []() { - return py::cpp_function([](int i) { return i+1; }, py::arg("number")); - }); - - // test_keyword_args_and_generalized_unpacking - m.def("test_tuple_unpacking", [](py::function f) { - auto t1 = py::make_tuple(2, 3); - auto t2 = py::make_tuple(5, 6); - return f("positional", 1, *t1, 4, *t2); - }); - - m.def("test_dict_unpacking", [](py::function f) { - auto d1 = py::dict("key"_a="value", "a"_a=1); - auto d2 = py::dict(); - auto d3 = py::dict("b"_a=2); - return f("positional", 1, **d1, **d2, **d3); - }); - - m.def("test_keyword_args", [](py::function f) { - return f("x"_a=10, "y"_a=20); - }); - - m.def("test_unpacking_and_keywords1", [](py::function f) { - auto args = py::make_tuple(2); - auto kwargs = py::dict("d"_a=4); - return f(1, *args, "c"_a=3, **kwargs); - }); - - m.def("test_unpacking_and_keywords2", [](py::function f) { - auto kwargs1 = py::dict("a"_a=1); - auto kwargs2 = py::dict("c"_a=3, "d"_a=4); - return f("positional", *py::make_tuple(1), 2, *py::make_tuple(3, 4), 5, - "key"_a="value", **kwargs1, "b"_a=2, **kwargs2, "e"_a=5); - }); - - m.def("test_unpacking_error1", [](py::function f) { - auto kwargs = py::dict("x"_a=3); - return f("x"_a=1, "y"_a=2, **kwargs); // duplicate ** after keyword - }); - - m.def("test_unpacking_error2", [](py::function f) { - auto kwargs = py::dict("x"_a=3); - return f(**kwargs, "x"_a=1); // duplicate keyword after ** - }); - - m.def("test_arg_conversion_error1", [](py::function f) { - f(234, UnregisteredType(), "kw"_a=567); - }); - - m.def("test_arg_conversion_error2", [](py::function f) { - f(234, "expected_name"_a=UnregisteredType(), "kw"_a=567); - }); - - // test_lambda_closure_cleanup - struct Payload { - Payload() { print_default_created(this); } - ~Payload() { print_destroyed(this); } - Payload(const Payload &) { print_copy_created(this); } - Payload(Payload &&) { print_move_created(this); } - }; - // Export the payload constructor statistics for testing purposes: - m.def("payload_cstats", &ConstructorStats::get); - /* Test cleanup of lambda closure */ - m.def("test_cleanup", []() -> std::function { - Payload p; - - return [p]() { - /* p should be cleaned up when the returned function is garbage collected */ - (void) p; - }; - }); - - // test_cpp_function_roundtrip - /* Test if passing a function pointer from C++ -> Python -> C++ yields the original pointer */ - m.def("dummy_function", &dummy_function); - m.def("dummy_function2", [](int i, int j) { return i + j; }); - m.def("roundtrip", [](std::function f, bool expect_none = false) { - if (expect_none && f) - throw std::runtime_error("Expected None to be converted to empty std::function"); - return f; - }, py::arg("f"), py::arg("expect_none")=false); - m.def("test_dummy_function", [](const std::function &f) -> std::string { - using fn_type = int (*)(int); - auto result = f.target(); - if (!result) { - auto r = f(1); - return "can't convert to function pointer: eval(1) = " + std::to_string(r); - } else if (*result == dummy_function) { - auto r = (*result)(1); - return "matches dummy_function: eval(1) = " + std::to_string(r); - } else { - return "argument does NOT match dummy_function. This should never happen!"; - } - }); - - class AbstractBase { public: virtual unsigned int func() = 0; }; - m.def("func_accepting_func_accepting_base", [](std::function) { }); - - struct MovableObject { - bool valid = true; - - MovableObject() = default; - MovableObject(const MovableObject &) = default; - MovableObject &operator=(const MovableObject &) = default; - MovableObject(MovableObject &&o) : valid(o.valid) { o.valid = false; } - MovableObject &operator=(MovableObject &&o) { - valid = o.valid; - o.valid = false; - return *this; - } - }; - py::class_(m, "MovableObject"); - - // test_movable_object - m.def("callback_with_movable", [](std::function f) { - auto x = MovableObject(); - f(x); // lvalue reference shouldn't move out object - return x.valid; // must still return `true` - }); - - // test_bound_method_callback - struct CppBoundMethodTest {}; - py::class_(m, "CppBoundMethodTest") - .def(py::init<>()) - .def("triple", [](CppBoundMethodTest &, int val) { return 3 * val; }); - - // test async Python callbacks - using callback_f = std::function; - m.def("test_async_callback", [](callback_f f, py::list work) { - // make detached thread that calls `f` with piece of work after a little delay - auto start_f = [f](int j) { - auto invoke_f = [f, j] { - std::this_thread::sleep_for(std::chrono::milliseconds(50)); - f(j); - }; - auto t = std::thread(std::move(invoke_f)); - t.detach(); - }; - - // spawn worker threads - for (auto i : work) - start_f(py::cast(i)); - }); -} diff --git a/pybind11/tests/test_callbacks.py b/pybind11/tests/test_callbacks.py deleted file mode 100644 index 6439c8e..0000000 --- a/pybind11/tests/test_callbacks.py +++ /dev/null @@ -1,136 +0,0 @@ -import pytest -from pybind11_tests import callbacks as m -from threading import Thread - - -def test_callbacks(): - from functools import partial - - def func1(): - return "func1" - - def func2(a, b, c, d): - return "func2", a, b, c, d - - def func3(a): - return "func3({})".format(a) - - assert m.test_callback1(func1) == "func1" - assert m.test_callback2(func2) == ("func2", "Hello", "x", True, 5) - assert m.test_callback1(partial(func2, 1, 2, 3, 4)) == ("func2", 1, 2, 3, 4) - assert m.test_callback1(partial(func3, "partial")) == "func3(partial)" - assert m.test_callback3(lambda i: i + 1) == "func(43) = 44" - - f = m.test_callback4() - assert f(43) == 44 - f = m.test_callback5() - assert f(number=43) == 44 - - -def test_bound_method_callback(): - # Bound Python method: - class MyClass: - def double(self, val): - return 2 * val - - z = MyClass() - assert m.test_callback3(z.double) == "func(43) = 86" - - z = m.CppBoundMethodTest() - assert m.test_callback3(z.triple) == "func(43) = 129" - - -def test_keyword_args_and_generalized_unpacking(): - - def f(*args, **kwargs): - return args, kwargs - - assert m.test_tuple_unpacking(f) == (("positional", 1, 2, 3, 4, 5, 6), {}) - assert m.test_dict_unpacking(f) == (("positional", 1), {"key": "value", "a": 1, "b": 2}) - assert m.test_keyword_args(f) == ((), {"x": 10, "y": 20}) - assert m.test_unpacking_and_keywords1(f) == ((1, 2), {"c": 3, "d": 4}) - assert m.test_unpacking_and_keywords2(f) == ( - ("positional", 1, 2, 3, 4, 5), - {"key": "value", "a": 1, "b": 2, "c": 3, "d": 4, "e": 5} - ) - - with pytest.raises(TypeError) as excinfo: - m.test_unpacking_error1(f) - assert "Got multiple values for keyword argument" in str(excinfo.value) - - with pytest.raises(TypeError) as excinfo: - m.test_unpacking_error2(f) - assert "Got multiple values for keyword argument" in str(excinfo.value) - - with pytest.raises(RuntimeError) as excinfo: - m.test_arg_conversion_error1(f) - assert "Unable to convert call argument" in str(excinfo.value) - - with pytest.raises(RuntimeError) as excinfo: - m.test_arg_conversion_error2(f) - assert "Unable to convert call argument" in str(excinfo.value) - - -def test_lambda_closure_cleanup(): - m.test_cleanup() - cstats = m.payload_cstats() - assert cstats.alive() == 0 - assert cstats.copy_constructions == 1 - assert cstats.move_constructions >= 1 - - -def test_cpp_function_roundtrip(): - """Test if passing a function pointer from C++ -> Python -> C++ yields the original pointer""" - - assert m.test_dummy_function(m.dummy_function) == "matches dummy_function: eval(1) = 2" - assert (m.test_dummy_function(m.roundtrip(m.dummy_function)) == - "matches dummy_function: eval(1) = 2") - assert m.roundtrip(None, expect_none=True) is None - assert (m.test_dummy_function(lambda x: x + 2) == - "can't convert to function pointer: eval(1) = 3") - - with pytest.raises(TypeError) as excinfo: - m.test_dummy_function(m.dummy_function2) - assert "incompatible function arguments" in str(excinfo.value) - - with pytest.raises(TypeError) as excinfo: - m.test_dummy_function(lambda x, y: x + y) - assert any(s in str(excinfo.value) for s in ("missing 1 required positional argument", - "takes exactly 2 arguments")) - - -def test_function_signatures(doc): - assert doc(m.test_callback3) == "test_callback3(arg0: Callable[[int], int]) -> str" - assert doc(m.test_callback4) == "test_callback4() -> Callable[[int], int]" - - -def test_movable_object(): - assert m.callback_with_movable(lambda _: None) is True - - -def test_async_callbacks(): - # serves as state for async callback - class Item: - def __init__(self, value): - self.value = value - - res = [] - - # generate stateful lambda that will store result in `res` - def gen_f(): - s = Item(3) - return lambda j: res.append(s.value + j) - - # do some work async - work = [1, 2, 3, 4] - m.test_async_callback(gen_f(), work) - # wait until work is done - from time import sleep - sleep(0.5) - assert sum(res) == sum([x + 3 for x in work]) - - -def test_async_async_callbacks(): - t = Thread(target=test_async_callbacks) - t.start() - t.join() diff --git a/pybind11/tests/test_chrono.cpp b/pybind11/tests/test_chrono.cpp deleted file mode 100644 index 899d08d..0000000 --- a/pybind11/tests/test_chrono.cpp +++ /dev/null @@ -1,55 +0,0 @@ -/* - tests/test_chrono.cpp -- test conversions to/from std::chrono types - - Copyright (c) 2016 Trent Houliston and - Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" -#include - -TEST_SUBMODULE(chrono, m) { - using system_time = std::chrono::system_clock::time_point; - using steady_time = std::chrono::steady_clock::time_point; - - using timespan = std::chrono::duration; - using timestamp = std::chrono::time_point; - - // test_chrono_system_clock - // Return the current time off the wall clock - m.def("test_chrono1", []() { return std::chrono::system_clock::now(); }); - - // test_chrono_system_clock_roundtrip - // Round trip the passed in system clock time - m.def("test_chrono2", [](system_time t) { return t; }); - - // test_chrono_duration_roundtrip - // Round trip the passed in duration - m.def("test_chrono3", [](std::chrono::system_clock::duration d) { return d; }); - - // test_chrono_duration_subtraction_equivalence - // Difference between two passed in time_points - m.def("test_chrono4", [](system_time a, system_time b) { return a - b; }); - - // test_chrono_steady_clock - // Return the current time off the steady_clock - m.def("test_chrono5", []() { return std::chrono::steady_clock::now(); }); - - // test_chrono_steady_clock_roundtrip - // Round trip a steady clock timepoint - m.def("test_chrono6", [](steady_time t) { return t; }); - - // test_floating_point_duration - // Roundtrip a duration in microseconds from a float argument - m.def("test_chrono7", [](std::chrono::microseconds t) { return t; }); - // Float durations (issue #719) - m.def("test_chrono_float_diff", [](std::chrono::duration a, std::chrono::duration b) { - return a - b; }); - - m.def("test_nano_timepoint", [](timestamp start, timespan delta) -> timestamp { - return start + delta; - }); -} diff --git a/pybind11/tests/test_chrono.py b/pybind11/tests/test_chrono.py deleted file mode 100644 index 55c9544..0000000 --- a/pybind11/tests/test_chrono.py +++ /dev/null @@ -1,176 +0,0 @@ -from pybind11_tests import chrono as m -import datetime - - -def test_chrono_system_clock(): - - # Get the time from both c++ and datetime - date1 = m.test_chrono1() - date2 = datetime.datetime.today() - - # The returned value should be a datetime - assert isinstance(date1, datetime.datetime) - - # The numbers should vary by a very small amount (time it took to execute) - diff = abs(date1 - date2) - - # There should never be a days/seconds difference - assert diff.days == 0 - assert diff.seconds == 0 - - # We test that no more than about 0.5 seconds passes here - # This makes sure that the dates created are very close to the same - # but if the testing system is incredibly overloaded this should still pass - assert diff.microseconds < 500000 - - -def test_chrono_system_clock_roundtrip(): - date1 = datetime.datetime.today() - - # Roundtrip the time - date2 = m.test_chrono2(date1) - - # The returned value should be a datetime - assert isinstance(date2, datetime.datetime) - - # They should be identical (no information lost on roundtrip) - diff = abs(date1 - date2) - assert diff.days == 0 - assert diff.seconds == 0 - assert diff.microseconds == 0 - - -def test_chrono_system_clock_roundtrip_date(): - date1 = datetime.date.today() - - # Roundtrip the time - datetime2 = m.test_chrono2(date1) - date2 = datetime2.date() - time2 = datetime2.time() - - # The returned value should be a datetime - assert isinstance(datetime2, datetime.datetime) - assert isinstance(date2, datetime.date) - assert isinstance(time2, datetime.time) - - # They should be identical (no information lost on roundtrip) - diff = abs(date1 - date2) - assert diff.days == 0 - assert diff.seconds == 0 - assert diff.microseconds == 0 - - # Year, Month & Day should be the same after the round trip - assert date1.year == date2.year - assert date1.month == date2.month - assert date1.day == date2.day - - # There should be no time information - assert time2.hour == 0 - assert time2.minute == 0 - assert time2.second == 0 - assert time2.microsecond == 0 - - -def test_chrono_system_clock_roundtrip_time(): - time1 = datetime.datetime.today().time() - - # Roundtrip the time - datetime2 = m.test_chrono2(time1) - date2 = datetime2.date() - time2 = datetime2.time() - - # The returned value should be a datetime - assert isinstance(datetime2, datetime.datetime) - assert isinstance(date2, datetime.date) - assert isinstance(time2, datetime.time) - - # Hour, Minute, Second & Microsecond should be the same after the round trip - assert time1.hour == time2.hour - assert time1.minute == time2.minute - assert time1.second == time2.second - assert time1.microsecond == time2.microsecond - - # There should be no date information (i.e. date = python base date) - assert date2.year == 1970 - assert date2.month == 1 - assert date2.day == 1 - - -def test_chrono_duration_roundtrip(): - - # Get the difference between two times (a timedelta) - date1 = datetime.datetime.today() - date2 = datetime.datetime.today() - diff = date2 - date1 - - # Make sure this is a timedelta - assert isinstance(diff, datetime.timedelta) - - cpp_diff = m.test_chrono3(diff) - - assert cpp_diff.days == diff.days - assert cpp_diff.seconds == diff.seconds - assert cpp_diff.microseconds == diff.microseconds - - -def test_chrono_duration_subtraction_equivalence(): - - date1 = datetime.datetime.today() - date2 = datetime.datetime.today() - - diff = date2 - date1 - cpp_diff = m.test_chrono4(date2, date1) - - assert cpp_diff.days == diff.days - assert cpp_diff.seconds == diff.seconds - assert cpp_diff.microseconds == diff.microseconds - - -def test_chrono_duration_subtraction_equivalence_date(): - - date1 = datetime.date.today() - date2 = datetime.date.today() - - diff = date2 - date1 - cpp_diff = m.test_chrono4(date2, date1) - - assert cpp_diff.days == diff.days - assert cpp_diff.seconds == diff.seconds - assert cpp_diff.microseconds == diff.microseconds - - -def test_chrono_steady_clock(): - time1 = m.test_chrono5() - assert isinstance(time1, datetime.timedelta) - - -def test_chrono_steady_clock_roundtrip(): - time1 = datetime.timedelta(days=10, seconds=10, microseconds=100) - time2 = m.test_chrono6(time1) - - assert isinstance(time2, datetime.timedelta) - - # They should be identical (no information lost on roundtrip) - assert time1.days == time2.days - assert time1.seconds == time2.seconds - assert time1.microseconds == time2.microseconds - - -def test_floating_point_duration(): - # Test using a floating point number in seconds - time = m.test_chrono7(35.525123) - - assert isinstance(time, datetime.timedelta) - - assert time.seconds == 35 - assert 525122 <= time.microseconds <= 525123 - - diff = m.test_chrono_float_diff(43.789012, 1.123456) - assert diff.seconds == 42 - assert 665556 <= diff.microseconds <= 665557 - - -def test_nano_timepoint(): - time = datetime.datetime.now() - time1 = m.test_nano_timepoint(time, datetime.timedelta(seconds=60)) - assert(time1 == time + datetime.timedelta(seconds=60)) diff --git a/pybind11/tests/test_class.cpp b/pybind11/tests/test_class.cpp deleted file mode 100644 index 499d0cc..0000000 --- a/pybind11/tests/test_class.cpp +++ /dev/null @@ -1,422 +0,0 @@ -/* - tests/test_class.cpp -- test py::class_ definitions and basic functionality - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" -#include "constructor_stats.h" -#include "local_bindings.h" -#include - -#if defined(_MSC_VER) -# pragma warning(disable: 4324) // warning C4324: structure was padded due to alignment specifier -#endif - -// test_brace_initialization -struct NoBraceInitialization { - NoBraceInitialization(std::vector v) : vec{std::move(v)} {} - template - NoBraceInitialization(std::initializer_list l) : vec(l) {} - - std::vector vec; -}; - -TEST_SUBMODULE(class_, m) { - // test_instance - struct NoConstructor { - NoConstructor() = default; - NoConstructor(const NoConstructor &) = default; - NoConstructor(NoConstructor &&) = default; - static NoConstructor *new_instance() { - auto *ptr = new NoConstructor(); - print_created(ptr, "via new_instance"); - return ptr; - } - ~NoConstructor() { print_destroyed(this); } - }; - - py::class_(m, "NoConstructor") - .def_static("new_instance", &NoConstructor::new_instance, "Return an instance"); - - // test_inheritance - class Pet { - public: - Pet(const std::string &name, const std::string &species) - : m_name(name), m_species(species) {} - std::string name() const { return m_name; } - std::string species() const { return m_species; } - private: - std::string m_name; - std::string m_species; - }; - - class Dog : public Pet { - public: - Dog(const std::string &name) : Pet(name, "dog") {} - std::string bark() const { return "Woof!"; } - }; - - class Rabbit : public Pet { - public: - Rabbit(const std::string &name) : Pet(name, "parrot") {} - }; - - class Hamster : public Pet { - public: - Hamster(const std::string &name) : Pet(name, "rodent") {} - }; - - class Chimera : public Pet { - Chimera() : Pet("Kimmy", "chimera") {} - }; - - py::class_ pet_class(m, "Pet"); - pet_class - .def(py::init()) - .def("name", &Pet::name) - .def("species", &Pet::species); - - /* One way of declaring a subclass relationship: reference parent's class_ object */ - py::class_(m, "Dog", pet_class) - .def(py::init()); - - /* Another way of declaring a subclass relationship: reference parent's C++ type */ - py::class_(m, "Rabbit") - .def(py::init()); - - /* And another: list parent in class template arguments */ - py::class_(m, "Hamster") - .def(py::init()); - - /* Constructors are not inherited by default */ - py::class_(m, "Chimera"); - - m.def("pet_name_species", [](const Pet &pet) { return pet.name() + " is a " + pet.species(); }); - m.def("dog_bark", [](const Dog &dog) { return dog.bark(); }); - - // test_automatic_upcasting - struct BaseClass { - BaseClass() = default; - BaseClass(const BaseClass &) = default; - BaseClass(BaseClass &&) = default; - virtual ~BaseClass() {} - }; - struct DerivedClass1 : BaseClass { }; - struct DerivedClass2 : BaseClass { }; - - py::class_(m, "BaseClass").def(py::init<>()); - py::class_(m, "DerivedClass1").def(py::init<>()); - py::class_(m, "DerivedClass2").def(py::init<>()); - - m.def("return_class_1", []() -> BaseClass* { return new DerivedClass1(); }); - m.def("return_class_2", []() -> BaseClass* { return new DerivedClass2(); }); - m.def("return_class_n", [](int n) -> BaseClass* { - if (n == 1) return new DerivedClass1(); - if (n == 2) return new DerivedClass2(); - return new BaseClass(); - }); - m.def("return_none", []() -> BaseClass* { return nullptr; }); - - // test_isinstance - m.def("check_instances", [](py::list l) { - return py::make_tuple( - py::isinstance(l[0]), - py::isinstance(l[1]), - py::isinstance(l[2]), - py::isinstance(l[3]), - py::isinstance(l[4]), - py::isinstance(l[5]), - py::isinstance(l[6]) - ); - }); - - // test_mismatched_holder - struct MismatchBase1 { }; - struct MismatchDerived1 : MismatchBase1 { }; - - struct MismatchBase2 { }; - struct MismatchDerived2 : MismatchBase2 { }; - - m.def("mismatched_holder_1", []() { - auto mod = py::module::import("__main__"); - py::class_>(mod, "MismatchBase1"); - py::class_(mod, "MismatchDerived1"); - }); - m.def("mismatched_holder_2", []() { - auto mod = py::module::import("__main__"); - py::class_(mod, "MismatchBase2"); - py::class_, - MismatchBase2>(mod, "MismatchDerived2"); - }); - - // test_override_static - // #511: problem with inheritance + overwritten def_static - struct MyBase { - static std::unique_ptr make() { - return std::unique_ptr(new MyBase()); - } - }; - - struct MyDerived : MyBase { - static std::unique_ptr make() { - return std::unique_ptr(new MyDerived()); - } - }; - - py::class_(m, "MyBase") - .def_static("make", &MyBase::make); - - py::class_(m, "MyDerived") - .def_static("make", &MyDerived::make) - .def_static("make2", &MyDerived::make); - - // test_implicit_conversion_life_support - struct ConvertibleFromUserType { - int i; - - ConvertibleFromUserType(UserType u) : i(u.value()) { } - }; - - py::class_(m, "AcceptsUserType") - .def(py::init()); - py::implicitly_convertible(); - - m.def("implicitly_convert_argument", [](const ConvertibleFromUserType &r) { return r.i; }); - m.def("implicitly_convert_variable", [](py::object o) { - // `o` is `UserType` and `r` is a reference to a temporary created by implicit - // conversion. This is valid when called inside a bound function because the temp - // object is attached to the same life support system as the arguments. - const auto &r = o.cast(); - return r.i; - }); - m.add_object("implicitly_convert_variable_fail", [&] { - auto f = [](PyObject *, PyObject *args) -> PyObject * { - auto o = py::reinterpret_borrow(args)[0]; - try { // It should fail here because there is no life support. - o.cast(); - } catch (const py::cast_error &e) { - return py::str(e.what()).release().ptr(); - } - return py::str().release().ptr(); - }; - - auto def = new PyMethodDef{"f", f, METH_VARARGS, nullptr}; - return py::reinterpret_steal(PyCFunction_NewEx(def, nullptr, m.ptr())); - }()); - - // test_operator_new_delete - struct HasOpNewDel { - std::uint64_t i; - static void *operator new(size_t s) { py::print("A new", s); return ::operator new(s); } - static void *operator new(size_t s, void *ptr) { py::print("A placement-new", s); return ptr; } - static void operator delete(void *p) { py::print("A delete"); return ::operator delete(p); } - }; - struct HasOpNewDelSize { - std::uint32_t i; - static void *operator new(size_t s) { py::print("B new", s); return ::operator new(s); } - static void *operator new(size_t s, void *ptr) { py::print("B placement-new", s); return ptr; } - static void operator delete(void *p, size_t s) { py::print("B delete", s); return ::operator delete(p); } - }; - struct AliasedHasOpNewDelSize { - std::uint64_t i; - static void *operator new(size_t s) { py::print("C new", s); return ::operator new(s); } - static void *operator new(size_t s, void *ptr) { py::print("C placement-new", s); return ptr; } - static void operator delete(void *p, size_t s) { py::print("C delete", s); return ::operator delete(p); } - virtual ~AliasedHasOpNewDelSize() = default; - }; - struct PyAliasedHasOpNewDelSize : AliasedHasOpNewDelSize { - PyAliasedHasOpNewDelSize() = default; - PyAliasedHasOpNewDelSize(int) { } - std::uint64_t j; - }; - struct HasOpNewDelBoth { - std::uint32_t i[8]; - static void *operator new(size_t s) { py::print("D new", s); return ::operator new(s); } - static void *operator new(size_t s, void *ptr) { py::print("D placement-new", s); return ptr; } - static void operator delete(void *p) { py::print("D delete"); return ::operator delete(p); } - static void operator delete(void *p, size_t s) { py::print("D wrong delete", s); return ::operator delete(p); } - }; - py::class_(m, "HasOpNewDel").def(py::init<>()); - py::class_(m, "HasOpNewDelSize").def(py::init<>()); - py::class_(m, "HasOpNewDelBoth").def(py::init<>()); - py::class_ aliased(m, "AliasedHasOpNewDelSize"); - aliased.def(py::init<>()); - aliased.attr("size_noalias") = py::int_(sizeof(AliasedHasOpNewDelSize)); - aliased.attr("size_alias") = py::int_(sizeof(PyAliasedHasOpNewDelSize)); - - // This test is actually part of test_local_bindings (test_duplicate_local), but we need a - // definition in a different compilation unit within the same module: - bind_local(m, "LocalExternal", py::module_local()); - - // test_bind_protected_functions - class ProtectedA { - protected: - int foo() const { return value; } - - private: - int value = 42; - }; - - class PublicistA : public ProtectedA { - public: - using ProtectedA::foo; - }; - - py::class_(m, "ProtectedA") - .def(py::init<>()) -#if !defined(_MSC_VER) || _MSC_VER >= 1910 - .def("foo", &PublicistA::foo); -#else - .def("foo", static_cast(&PublicistA::foo)); -#endif - - class ProtectedB { - public: - virtual ~ProtectedB() = default; - - protected: - virtual int foo() const { return value; } - - private: - int value = 42; - }; - - class TrampolineB : public ProtectedB { - public: - int foo() const override { PYBIND11_OVERLOAD(int, ProtectedB, foo, ); } - }; - - class PublicistB : public ProtectedB { - public: - using ProtectedB::foo; - }; - - py::class_(m, "ProtectedB") - .def(py::init<>()) -#if !defined(_MSC_VER) || _MSC_VER >= 1910 - .def("foo", &PublicistB::foo); -#else - .def("foo", static_cast(&PublicistB::foo)); -#endif - - // test_brace_initialization - struct BraceInitialization { - int field1; - std::string field2; - }; - - py::class_(m, "BraceInitialization") - .def(py::init()) - .def_readwrite("field1", &BraceInitialization::field1) - .def_readwrite("field2", &BraceInitialization::field2); - // We *don't* want to construct using braces when the given constructor argument maps to a - // constructor, because brace initialization could go to the wrong place (in particular when - // there is also an `initializer_list`-accept constructor): - py::class_(m, "NoBraceInitialization") - .def(py::init>()) - .def_readonly("vec", &NoBraceInitialization::vec); - - // test_reentrant_implicit_conversion_failure - // #1035: issue with runaway reentrant implicit conversion - struct BogusImplicitConversion { - BogusImplicitConversion(const BogusImplicitConversion &) { } - }; - - py::class_(m, "BogusImplicitConversion") - .def(py::init()); - - py::implicitly_convertible(); - - // test_qualname - // #1166: nested class docstring doesn't show nested name - // Also related: tests that __qualname__ is set properly - struct NestBase {}; - struct Nested {}; - py::class_ base(m, "NestBase"); - base.def(py::init<>()); - py::class_(base, "Nested") - .def(py::init<>()) - .def("fn", [](Nested &, int, NestBase &, Nested &) {}) - .def("fa", [](Nested &, int, NestBase &, Nested &) {}, - "a"_a, "b"_a, "c"_a); - base.def("g", [](NestBase &, Nested &) {}); - base.def("h", []() { return NestBase(); }); - - // test_error_after_conversion - // The second-pass path through dispatcher() previously didn't - // remember which overload was used, and would crash trying to - // generate a useful error message - - struct NotRegistered {}; - struct StringWrapper { std::string str; }; - m.def("test_error_after_conversions", [](int) {}); - m.def("test_error_after_conversions", - [](StringWrapper) -> NotRegistered { return {}; }); - py::class_(m, "StringWrapper").def(py::init()); - py::implicitly_convertible(); - - #if defined(PYBIND11_CPP17) - struct alignas(1024) Aligned { - std::uintptr_t ptr() const { return (uintptr_t) this; } - }; - py::class_(m, "Aligned") - .def(py::init<>()) - .def("ptr", &Aligned::ptr); - #endif -} - -template class BreaksBase { public: virtual ~BreaksBase() = default; }; -template class BreaksTramp : public BreaksBase {}; -// These should all compile just fine: -typedef py::class_, std::unique_ptr>, BreaksTramp<1>> DoesntBreak1; -typedef py::class_, BreaksTramp<2>, std::unique_ptr>> DoesntBreak2; -typedef py::class_, std::unique_ptr>> DoesntBreak3; -typedef py::class_, BreaksTramp<4>> DoesntBreak4; -typedef py::class_> DoesntBreak5; -typedef py::class_, std::shared_ptr>, BreaksTramp<6>> DoesntBreak6; -typedef py::class_, BreaksTramp<7>, std::shared_ptr>> DoesntBreak7; -typedef py::class_, std::shared_ptr>> DoesntBreak8; -#define CHECK_BASE(N) static_assert(std::is_same>::value, \ - "DoesntBreak" #N " has wrong type!") -CHECK_BASE(1); CHECK_BASE(2); CHECK_BASE(3); CHECK_BASE(4); CHECK_BASE(5); CHECK_BASE(6); CHECK_BASE(7); CHECK_BASE(8); -#define CHECK_ALIAS(N) static_assert(DoesntBreak##N::has_alias && std::is_same>::value, \ - "DoesntBreak" #N " has wrong type_alias!") -#define CHECK_NOALIAS(N) static_assert(!DoesntBreak##N::has_alias && std::is_void::value, \ - "DoesntBreak" #N " has type alias, but shouldn't!") -CHECK_ALIAS(1); CHECK_ALIAS(2); CHECK_NOALIAS(3); CHECK_ALIAS(4); CHECK_NOALIAS(5); CHECK_ALIAS(6); CHECK_ALIAS(7); CHECK_NOALIAS(8); -#define CHECK_HOLDER(N, TYPE) static_assert(std::is_same>>::value, \ - "DoesntBreak" #N " has wrong holder_type!") -CHECK_HOLDER(1, unique); CHECK_HOLDER(2, unique); CHECK_HOLDER(3, unique); CHECK_HOLDER(4, unique); CHECK_HOLDER(5, unique); -CHECK_HOLDER(6, shared); CHECK_HOLDER(7, shared); CHECK_HOLDER(8, shared); - -// There's no nice way to test that these fail because they fail to compile; leave them here, -// though, so that they can be manually tested by uncommenting them (and seeing that compilation -// failures occurs). - -// We have to actually look into the type: the typedef alone isn't enough to instantiate the type: -#define CHECK_BROKEN(N) static_assert(std::is_same>::value, \ - "Breaks1 has wrong type!"); - -//// Two holder classes: -//typedef py::class_, std::unique_ptr>, std::unique_ptr>> Breaks1; -//CHECK_BROKEN(1); -//// Two aliases: -//typedef py::class_, BreaksTramp<-2>, BreaksTramp<-2>> Breaks2; -//CHECK_BROKEN(2); -//// Holder + 2 aliases -//typedef py::class_, std::unique_ptr>, BreaksTramp<-3>, BreaksTramp<-3>> Breaks3; -//CHECK_BROKEN(3); -//// Alias + 2 holders -//typedef py::class_, std::unique_ptr>, BreaksTramp<-4>, std::shared_ptr>> Breaks4; -//CHECK_BROKEN(4); -//// Invalid option (not a subclass or holder) -//typedef py::class_, BreaksTramp<-4>> Breaks5; -//CHECK_BROKEN(5); -//// Invalid option: multiple inheritance not supported: -//template <> struct BreaksBase<-8> : BreaksBase<-6>, BreaksBase<-7> {}; -//typedef py::class_, BreaksBase<-6>, BreaksBase<-7>> Breaks8; -//CHECK_BROKEN(8); diff --git a/pybind11/tests/test_class.py b/pybind11/tests/test_class.py deleted file mode 100644 index ed63ca8..0000000 --- a/pybind11/tests/test_class.py +++ /dev/null @@ -1,281 +0,0 @@ -import pytest - -from pybind11_tests import class_ as m -from pybind11_tests import UserType, ConstructorStats - - -def test_repr(): - # In Python 3.3+, repr() accesses __qualname__ - assert "pybind11_type" in repr(type(UserType)) - assert "UserType" in repr(UserType) - - -def test_instance(msg): - with pytest.raises(TypeError) as excinfo: - m.NoConstructor() - assert msg(excinfo.value) == "m.class_.NoConstructor: No constructor defined!" - - instance = m.NoConstructor.new_instance() - - cstats = ConstructorStats.get(m.NoConstructor) - assert cstats.alive() == 1 - del instance - assert cstats.alive() == 0 - - -def test_docstrings(doc): - assert doc(UserType) == "A `py::class_` type for testing" - assert UserType.__name__ == "UserType" - assert UserType.__module__ == "pybind11_tests" - assert UserType.get_value.__name__ == "get_value" - assert UserType.get_value.__module__ == "pybind11_tests" - - assert doc(UserType.get_value) == """ - get_value(self: m.UserType) -> int - - Get value using a method - """ - assert doc(UserType.value) == "Get/set value using a property" - - assert doc(m.NoConstructor.new_instance) == """ - new_instance() -> m.class_.NoConstructor - - Return an instance - """ - - -def test_qualname(doc): - """Tests that a properly qualified name is set in __qualname__ (even in pre-3.3, where we - backport the attribute) and that generated docstrings properly use it and the module name""" - assert m.NestBase.__qualname__ == "NestBase" - assert m.NestBase.Nested.__qualname__ == "NestBase.Nested" - - assert doc(m.NestBase.__init__) == """ - __init__(self: m.class_.NestBase) -> None - """ - assert doc(m.NestBase.g) == """ - g(self: m.class_.NestBase, arg0: m.class_.NestBase.Nested) -> None - """ - assert doc(m.NestBase.Nested.__init__) == """ - __init__(self: m.class_.NestBase.Nested) -> None - """ - assert doc(m.NestBase.Nested.fn) == """ - fn(self: m.class_.NestBase.Nested, arg0: int, arg1: m.class_.NestBase, arg2: m.class_.NestBase.Nested) -> None - """ # noqa: E501 line too long - assert doc(m.NestBase.Nested.fa) == """ - fa(self: m.class_.NestBase.Nested, a: int, b: m.class_.NestBase, c: m.class_.NestBase.Nested) -> None - """ # noqa: E501 line too long - assert m.NestBase.__module__ == "pybind11_tests.class_" - assert m.NestBase.Nested.__module__ == "pybind11_tests.class_" - - -def test_inheritance(msg): - roger = m.Rabbit('Rabbit') - assert roger.name() + " is a " + roger.species() == "Rabbit is a parrot" - assert m.pet_name_species(roger) == "Rabbit is a parrot" - - polly = m.Pet('Polly', 'parrot') - assert polly.name() + " is a " + polly.species() == "Polly is a parrot" - assert m.pet_name_species(polly) == "Polly is a parrot" - - molly = m.Dog('Molly') - assert molly.name() + " is a " + molly.species() == "Molly is a dog" - assert m.pet_name_species(molly) == "Molly is a dog" - - fred = m.Hamster('Fred') - assert fred.name() + " is a " + fred.species() == "Fred is a rodent" - - assert m.dog_bark(molly) == "Woof!" - - with pytest.raises(TypeError) as excinfo: - m.dog_bark(polly) - assert msg(excinfo.value) == """ - dog_bark(): incompatible function arguments. The following argument types are supported: - 1. (arg0: m.class_.Dog) -> str - - Invoked with: - """ - - with pytest.raises(TypeError) as excinfo: - m.Chimera("lion", "goat") - assert "No constructor defined!" in str(excinfo.value) - - -def test_automatic_upcasting(): - assert type(m.return_class_1()).__name__ == "DerivedClass1" - assert type(m.return_class_2()).__name__ == "DerivedClass2" - assert type(m.return_none()).__name__ == "NoneType" - # Repeat these a few times in a random order to ensure no invalid caching is applied - assert type(m.return_class_n(1)).__name__ == "DerivedClass1" - assert type(m.return_class_n(2)).__name__ == "DerivedClass2" - assert type(m.return_class_n(0)).__name__ == "BaseClass" - assert type(m.return_class_n(2)).__name__ == "DerivedClass2" - assert type(m.return_class_n(2)).__name__ == "DerivedClass2" - assert type(m.return_class_n(0)).__name__ == "BaseClass" - assert type(m.return_class_n(1)).__name__ == "DerivedClass1" - - -def test_isinstance(): - objects = [tuple(), dict(), m.Pet("Polly", "parrot")] + [m.Dog("Molly")] * 4 - expected = (True, True, True, True, True, False, False) - assert m.check_instances(objects) == expected - - -def test_mismatched_holder(): - import re - - with pytest.raises(RuntimeError) as excinfo: - m.mismatched_holder_1() - assert re.match('generic_type: type ".*MismatchDerived1" does not have a non-default ' - 'holder type while its base ".*MismatchBase1" does', str(excinfo.value)) - - with pytest.raises(RuntimeError) as excinfo: - m.mismatched_holder_2() - assert re.match('generic_type: type ".*MismatchDerived2" has a non-default holder type ' - 'while its base ".*MismatchBase2" does not', str(excinfo.value)) - - -def test_override_static(): - """#511: problem with inheritance + overwritten def_static""" - b = m.MyBase.make() - d1 = m.MyDerived.make2() - d2 = m.MyDerived.make() - - assert isinstance(b, m.MyBase) - assert isinstance(d1, m.MyDerived) - assert isinstance(d2, m.MyDerived) - - -def test_implicit_conversion_life_support(): - """Ensure the lifetime of temporary objects created for implicit conversions""" - assert m.implicitly_convert_argument(UserType(5)) == 5 - assert m.implicitly_convert_variable(UserType(5)) == 5 - - assert "outside a bound function" in m.implicitly_convert_variable_fail(UserType(5)) - - -def test_operator_new_delete(capture): - """Tests that class-specific operator new/delete functions are invoked""" - - class SubAliased(m.AliasedHasOpNewDelSize): - pass - - with capture: - a = m.HasOpNewDel() - b = m.HasOpNewDelSize() - d = m.HasOpNewDelBoth() - assert capture == """ - A new 8 - B new 4 - D new 32 - """ - sz_alias = str(m.AliasedHasOpNewDelSize.size_alias) - sz_noalias = str(m.AliasedHasOpNewDelSize.size_noalias) - with capture: - c = m.AliasedHasOpNewDelSize() - c2 = SubAliased() - assert capture == ( - "C new " + sz_noalias + "\n" + - "C new " + sz_alias + "\n" - ) - - with capture: - del a - pytest.gc_collect() - del b - pytest.gc_collect() - del d - pytest.gc_collect() - assert capture == """ - A delete - B delete 4 - D delete - """ - - with capture: - del c - pytest.gc_collect() - del c2 - pytest.gc_collect() - assert capture == ( - "C delete " + sz_noalias + "\n" + - "C delete " + sz_alias + "\n" - ) - - -def test_bind_protected_functions(): - """Expose protected member functions to Python using a helper class""" - a = m.ProtectedA() - assert a.foo() == 42 - - b = m.ProtectedB() - assert b.foo() == 42 - - class C(m.ProtectedB): - def __init__(self): - m.ProtectedB.__init__(self) - - def foo(self): - return 0 - - c = C() - assert c.foo() == 0 - - -def test_brace_initialization(): - """ Tests that simple POD classes can be constructed using C++11 brace initialization """ - a = m.BraceInitialization(123, "test") - assert a.field1 == 123 - assert a.field2 == "test" - - # Tests that a non-simple class doesn't get brace initialization (if the - # class defines an initializer_list constructor, in particular, it would - # win over the expected constructor). - b = m.NoBraceInitialization([123, 456]) - assert b.vec == [123, 456] - - -@pytest.unsupported_on_pypy -def test_class_refcount(): - """Instances must correctly increase/decrease the reference count of their types (#1029)""" - from sys import getrefcount - - class PyDog(m.Dog): - pass - - for cls in m.Dog, PyDog: - refcount_1 = getrefcount(cls) - molly = [cls("Molly") for _ in range(10)] - refcount_2 = getrefcount(cls) - - del molly - pytest.gc_collect() - refcount_3 = getrefcount(cls) - - assert refcount_1 == refcount_3 - assert refcount_2 > refcount_1 - - -def test_reentrant_implicit_conversion_failure(msg): - # ensure that there is no runaway reentrant implicit conversion (#1035) - with pytest.raises(TypeError) as excinfo: - m.BogusImplicitConversion(0) - assert msg(excinfo.value) == ''' - __init__(): incompatible constructor arguments. The following argument types are supported: - 1. m.class_.BogusImplicitConversion(arg0: m.class_.BogusImplicitConversion) - - Invoked with: 0 - ''' - - -def test_error_after_conversions(): - with pytest.raises(TypeError) as exc_info: - m.test_error_after_conversions("hello") - assert str(exc_info.value).startswith( - "Unable to convert function return value to a Python type!") - - -def test_aligned(): - if hasattr(m, "Aligned"): - p = m.Aligned().ptr() - assert p % 1024 == 0 diff --git a/pybind11/tests/test_cmake_build/CMakeLists.txt b/pybind11/tests/test_cmake_build/CMakeLists.txt deleted file mode 100644 index c9b5fcb..0000000 --- a/pybind11/tests/test_cmake_build/CMakeLists.txt +++ /dev/null @@ -1,58 +0,0 @@ -add_custom_target(test_cmake_build) - -if(CMAKE_VERSION VERSION_LESS 3.1) - # 3.0 needed for interface library for subdirectory_target/installed_target - # 3.1 needed for cmake -E env for testing - return() -endif() - -include(CMakeParseArguments) -function(pybind11_add_build_test name) - cmake_parse_arguments(ARG "INSTALL" "" "" ${ARGN}) - - set(build_options "-DCMAKE_PREFIX_PATH=${PROJECT_BINARY_DIR}/mock_install" - "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}" - "-DPYTHON_EXECUTABLE:FILEPATH=${PYTHON_EXECUTABLE}" - "-DPYBIND11_CPP_STANDARD=${PYBIND11_CPP_STANDARD}") - if(NOT ARG_INSTALL) - list(APPEND build_options "-DPYBIND11_PROJECT_DIR=${PROJECT_SOURCE_DIR}") - endif() - - add_custom_target(test_${name} ${CMAKE_CTEST_COMMAND} - --quiet --output-log ${name}.log - --build-and-test "${CMAKE_CURRENT_SOURCE_DIR}/${name}" - "${CMAKE_CURRENT_BINARY_DIR}/${name}" - --build-config Release - --build-noclean - --build-generator ${CMAKE_GENERATOR} - $<$:--build-generator-platform> ${CMAKE_GENERATOR_PLATFORM} - --build-makeprogram ${CMAKE_MAKE_PROGRAM} - --build-target check - --build-options ${build_options} - ) - if(ARG_INSTALL) - add_dependencies(test_${name} mock_install) - endif() - add_dependencies(test_cmake_build test_${name}) -endfunction() - -pybind11_add_build_test(subdirectory_function) -pybind11_add_build_test(subdirectory_target) -if(NOT ${PYTHON_MODULE_EXTENSION} MATCHES "pypy") - pybind11_add_build_test(subdirectory_embed) -endif() - -if(PYBIND11_INSTALL) - add_custom_target(mock_install ${CMAKE_COMMAND} - "-DCMAKE_INSTALL_PREFIX=${PROJECT_BINARY_DIR}/mock_install" - -P "${PROJECT_BINARY_DIR}/cmake_install.cmake" - ) - - pybind11_add_build_test(installed_function INSTALL) - pybind11_add_build_test(installed_target INSTALL) - if(NOT ${PYTHON_MODULE_EXTENSION} MATCHES "pypy") - pybind11_add_build_test(installed_embed INSTALL) - endif() -endif() - -add_dependencies(check test_cmake_build) diff --git a/pybind11/tests/test_cmake_build/embed.cpp b/pybind11/tests/test_cmake_build/embed.cpp deleted file mode 100644 index b9581d2..0000000 --- a/pybind11/tests/test_cmake_build/embed.cpp +++ /dev/null @@ -1,21 +0,0 @@ -#include -namespace py = pybind11; - -PYBIND11_EMBEDDED_MODULE(test_cmake_build, m) { - m.def("add", [](int i, int j) { return i + j; }); -} - -int main(int argc, char *argv[]) { - if (argc != 2) - throw std::runtime_error("Expected test.py file as the first argument"); - auto test_py_file = argv[1]; - - py::scoped_interpreter guard{}; - - auto m = py::module::import("test_cmake_build"); - if (m.attr("add")(1, 2).cast() != 3) - throw std::runtime_error("embed.cpp failed"); - - py::module::import("sys").attr("argv") = py::make_tuple("test.py", "embed.cpp"); - py::eval_file(test_py_file, py::globals()); -} diff --git a/pybind11/tests/test_cmake_build/installed_embed/CMakeLists.txt b/pybind11/tests/test_cmake_build/installed_embed/CMakeLists.txt deleted file mode 100644 index f7fc09c..0000000 --- a/pybind11/tests/test_cmake_build/installed_embed/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -cmake_minimum_required(VERSION 3.0) -project(test_installed_embed CXX) - -set(CMAKE_MODULE_PATH "") -find_package(pybind11 CONFIG REQUIRED) -message(STATUS "Found pybind11 v${pybind11_VERSION}: ${pybind11_INCLUDE_DIRS}") - -add_executable(test_cmake_build ../embed.cpp) -target_link_libraries(test_cmake_build PRIVATE pybind11::embed) - -# Do not treat includes from IMPORTED target as SYSTEM (Python headers in pybind11::embed). -# This may be needed to resolve header conflicts, e.g. between Python release and debug headers. -set_target_properties(test_cmake_build PROPERTIES NO_SYSTEM_FROM_IMPORTED ON) - -add_custom_target(check $ ${PROJECT_SOURCE_DIR}/../test.py) diff --git a/pybind11/tests/test_cmake_build/installed_function/CMakeLists.txt b/pybind11/tests/test_cmake_build/installed_function/CMakeLists.txt deleted file mode 100644 index e0c20a8..0000000 --- a/pybind11/tests/test_cmake_build/installed_function/CMakeLists.txt +++ /dev/null @@ -1,12 +0,0 @@ -cmake_minimum_required(VERSION 2.8.12) -project(test_installed_module CXX) - -set(CMAKE_MODULE_PATH "") - -find_package(pybind11 CONFIG REQUIRED) -message(STATUS "Found pybind11 v${pybind11_VERSION}: ${pybind11_INCLUDE_DIRS}") - -pybind11_add_module(test_cmake_build SHARED NO_EXTRAS ../main.cpp) - -add_custom_target(check ${CMAKE_COMMAND} -E env PYTHONPATH=$ - ${PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/../test.py ${PROJECT_NAME}) diff --git a/pybind11/tests/test_cmake_build/installed_target/CMakeLists.txt b/pybind11/tests/test_cmake_build/installed_target/CMakeLists.txt deleted file mode 100644 index cd3ae6f..0000000 --- a/pybind11/tests/test_cmake_build/installed_target/CMakeLists.txt +++ /dev/null @@ -1,22 +0,0 @@ -cmake_minimum_required(VERSION 3.0) -project(test_installed_target CXX) - -set(CMAKE_MODULE_PATH "") - -find_package(pybind11 CONFIG REQUIRED) -message(STATUS "Found pybind11 v${pybind11_VERSION}: ${pybind11_INCLUDE_DIRS}") - -add_library(test_cmake_build MODULE ../main.cpp) - -target_link_libraries(test_cmake_build PRIVATE pybind11::module) - -# make sure result is, for example, test_installed_target.so, not libtest_installed_target.dylib -set_target_properties(test_cmake_build PROPERTIES PREFIX "${PYTHON_MODULE_PREFIX}" - SUFFIX "${PYTHON_MODULE_EXTENSION}") - -# Do not treat includes from IMPORTED target as SYSTEM (Python headers in pybind11::module). -# This may be needed to resolve header conflicts, e.g. between Python release and debug headers. -set_target_properties(test_cmake_build PROPERTIES NO_SYSTEM_FROM_IMPORTED ON) - -add_custom_target(check ${CMAKE_COMMAND} -E env PYTHONPATH=$ - ${PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/../test.py ${PROJECT_NAME}) diff --git a/pybind11/tests/test_cmake_build/main.cpp b/pybind11/tests/test_cmake_build/main.cpp deleted file mode 100644 index e30f2c4..0000000 --- a/pybind11/tests/test_cmake_build/main.cpp +++ /dev/null @@ -1,6 +0,0 @@ -#include -namespace py = pybind11; - -PYBIND11_MODULE(test_cmake_build, m) { - m.def("add", [](int i, int j) { return i + j; }); -} diff --git a/pybind11/tests/test_cmake_build/subdirectory_embed/CMakeLists.txt b/pybind11/tests/test_cmake_build/subdirectory_embed/CMakeLists.txt deleted file mode 100644 index 88ba60d..0000000 --- a/pybind11/tests/test_cmake_build/subdirectory_embed/CMakeLists.txt +++ /dev/null @@ -1,25 +0,0 @@ -cmake_minimum_required(VERSION 3.0) -project(test_subdirectory_embed CXX) - -set(PYBIND11_INSTALL ON CACHE BOOL "") -set(PYBIND11_EXPORT_NAME test_export) - -add_subdirectory(${PYBIND11_PROJECT_DIR} pybind11) - -# Test basic target functionality -add_executable(test_cmake_build ../embed.cpp) -target_link_libraries(test_cmake_build PRIVATE pybind11::embed) - -add_custom_target(check $ ${PROJECT_SOURCE_DIR}/../test.py) - -# Test custom export group -- PYBIND11_EXPORT_NAME -add_library(test_embed_lib ../embed.cpp) -target_link_libraries(test_embed_lib PRIVATE pybind11::embed) - -install(TARGETS test_embed_lib - EXPORT test_export - ARCHIVE DESTINATION bin - LIBRARY DESTINATION lib - RUNTIME DESTINATION lib) -install(EXPORT test_export - DESTINATION lib/cmake/test_export/test_export-Targets.cmake) diff --git a/pybind11/tests/test_cmake_build/subdirectory_function/CMakeLists.txt b/pybind11/tests/test_cmake_build/subdirectory_function/CMakeLists.txt deleted file mode 100644 index 278007a..0000000 --- a/pybind11/tests/test_cmake_build/subdirectory_function/CMakeLists.txt +++ /dev/null @@ -1,8 +0,0 @@ -cmake_minimum_required(VERSION 2.8.12) -project(test_subdirectory_module CXX) - -add_subdirectory(${PYBIND11_PROJECT_DIR} pybind11) -pybind11_add_module(test_cmake_build THIN_LTO ../main.cpp) - -add_custom_target(check ${CMAKE_COMMAND} -E env PYTHONPATH=$ - ${PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/../test.py ${PROJECT_NAME}) diff --git a/pybind11/tests/test_cmake_build/subdirectory_target/CMakeLists.txt b/pybind11/tests/test_cmake_build/subdirectory_target/CMakeLists.txt deleted file mode 100644 index 6b142d6..0000000 --- a/pybind11/tests/test_cmake_build/subdirectory_target/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -cmake_minimum_required(VERSION 3.0) -project(test_subdirectory_target CXX) - -add_subdirectory(${PYBIND11_PROJECT_DIR} pybind11) - -add_library(test_cmake_build MODULE ../main.cpp) - -target_link_libraries(test_cmake_build PRIVATE pybind11::module) - -# make sure result is, for example, test_installed_target.so, not libtest_installed_target.dylib -set_target_properties(test_cmake_build PROPERTIES PREFIX "${PYTHON_MODULE_PREFIX}" - SUFFIX "${PYTHON_MODULE_EXTENSION}") - -add_custom_target(check ${CMAKE_COMMAND} -E env PYTHONPATH=$ - ${PYTHON_EXECUTABLE} ${PROJECT_SOURCE_DIR}/../test.py ${PROJECT_NAME}) diff --git a/pybind11/tests/test_cmake_build/test.py b/pybind11/tests/test_cmake_build/test.py deleted file mode 100644 index 1467a61..0000000 --- a/pybind11/tests/test_cmake_build/test.py +++ /dev/null @@ -1,5 +0,0 @@ -import sys -import test_cmake_build - -assert test_cmake_build.add(1, 2) == 3 -print("{} imports, runs, and adds: 1 + 2 = 3".format(sys.argv[1])) diff --git a/pybind11/tests/test_constants_and_functions.cpp b/pybind11/tests/test_constants_and_functions.cpp deleted file mode 100644 index e8ec74b..0000000 --- a/pybind11/tests/test_constants_and_functions.cpp +++ /dev/null @@ -1,127 +0,0 @@ -/* - tests/test_constants_and_functions.cpp -- global constants and functions, enumerations, raw byte strings - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" - -enum MyEnum { EFirstEntry = 1, ESecondEntry }; - -std::string test_function1() { - return "test_function()"; -} - -std::string test_function2(MyEnum k) { - return "test_function(enum=" + std::to_string(k) + ")"; -} - -std::string test_function3(int i) { - return "test_function(" + std::to_string(i) + ")"; -} - -py::str test_function4() { return "test_function()"; } -py::str test_function4(char *) { return "test_function(char *)"; } -py::str test_function4(int, float) { return "test_function(int, float)"; } -py::str test_function4(float, int) { return "test_function(float, int)"; } - -py::bytes return_bytes() { - const char *data = "\x01\x00\x02\x00"; - return std::string(data, 4); -} - -std::string print_bytes(py::bytes bytes) { - std::string ret = "bytes["; - const auto value = static_cast(bytes); - for (size_t i = 0; i < value.length(); ++i) { - ret += std::to_string(static_cast(value[i])) + " "; - } - ret.back() = ']'; - return ret; -} - -// Test that we properly handle C++17 exception specifiers (which are part of the function signature -// in C++17). These should all still work before C++17, but don't affect the function signature. -namespace test_exc_sp { -int f1(int x) noexcept { return x+1; } -int f2(int x) noexcept(true) { return x+2; } -int f3(int x) noexcept(false) { return x+3; } -#if defined(__GNUG__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wdeprecated" -#endif -int f4(int x) throw() { return x+4; } // Deprecated equivalent to noexcept(true) -#if defined(__GNUG__) -# pragma GCC diagnostic pop -#endif -struct C { - int m1(int x) noexcept { return x-1; } - int m2(int x) const noexcept { return x-2; } - int m3(int x) noexcept(true) { return x-3; } - int m4(int x) const noexcept(true) { return x-4; } - int m5(int x) noexcept(false) { return x-5; } - int m6(int x) const noexcept(false) { return x-6; } -#if defined(__GNUG__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wdeprecated" -#endif - int m7(int x) throw() { return x-7; } - int m8(int x) const throw() { return x-8; } -#if defined(__GNUG__) -# pragma GCC diagnostic pop -#endif -}; -} - - -TEST_SUBMODULE(constants_and_functions, m) { - // test_constants - m.attr("some_constant") = py::int_(14); - - // test_function_overloading - m.def("test_function", &test_function1); - m.def("test_function", &test_function2); - m.def("test_function", &test_function3); - -#if defined(PYBIND11_OVERLOAD_CAST) - m.def("test_function", py::overload_cast<>(&test_function4)); - m.def("test_function", py::overload_cast(&test_function4)); - m.def("test_function", py::overload_cast(&test_function4)); - m.def("test_function", py::overload_cast(&test_function4)); -#else - m.def("test_function", static_cast(&test_function4)); - m.def("test_function", static_cast(&test_function4)); - m.def("test_function", static_cast(&test_function4)); - m.def("test_function", static_cast(&test_function4)); -#endif - - py::enum_(m, "MyEnum") - .value("EFirstEntry", EFirstEntry) - .value("ESecondEntry", ESecondEntry) - .export_values(); - - // test_bytes - m.def("return_bytes", &return_bytes); - m.def("print_bytes", &print_bytes); - - // test_exception_specifiers - using namespace test_exc_sp; - py::class_(m, "C") - .def(py::init<>()) - .def("m1", &C::m1) - .def("m2", &C::m2) - .def("m3", &C::m3) - .def("m4", &C::m4) - .def("m5", &C::m5) - .def("m6", &C::m6) - .def("m7", &C::m7) - .def("m8", &C::m8) - ; - m.def("f1", f1); - m.def("f2", f2); - m.def("f3", f3); - m.def("f4", f4); -} diff --git a/pybind11/tests/test_constants_and_functions.py b/pybind11/tests/test_constants_and_functions.py deleted file mode 100644 index 472682d..0000000 --- a/pybind11/tests/test_constants_and_functions.py +++ /dev/null @@ -1,39 +0,0 @@ -from pybind11_tests import constants_and_functions as m - - -def test_constants(): - assert m.some_constant == 14 - - -def test_function_overloading(): - assert m.test_function() == "test_function()" - assert m.test_function(7) == "test_function(7)" - assert m.test_function(m.MyEnum.EFirstEntry) == "test_function(enum=1)" - assert m.test_function(m.MyEnum.ESecondEntry) == "test_function(enum=2)" - - assert m.test_function() == "test_function()" - assert m.test_function("abcd") == "test_function(char *)" - assert m.test_function(1, 1.0) == "test_function(int, float)" - assert m.test_function(1, 1.0) == "test_function(int, float)" - assert m.test_function(2.0, 2) == "test_function(float, int)" - - -def test_bytes(): - assert m.print_bytes(m.return_bytes()) == "bytes[1 0 2 0]" - - -def test_exception_specifiers(): - c = m.C() - assert c.m1(2) == 1 - assert c.m2(3) == 1 - assert c.m3(5) == 2 - assert c.m4(7) == 3 - assert c.m5(10) == 5 - assert c.m6(14) == 8 - assert c.m7(20) == 13 - assert c.m8(29) == 21 - - assert m.f1(33) == 34 - assert m.f2(53) == 55 - assert m.f3(86) == 89 - assert m.f4(140) == 144 diff --git a/pybind11/tests/test_copy_move.cpp b/pybind11/tests/test_copy_move.cpp deleted file mode 100644 index 98d5e0a..0000000 --- a/pybind11/tests/test_copy_move.cpp +++ /dev/null @@ -1,213 +0,0 @@ -/* - tests/test_copy_move_policies.cpp -- 'copy' and 'move' return value policies - and related tests - - Copyright (c) 2016 Ben North - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" -#include "constructor_stats.h" -#include - -template -struct empty { - static const derived& get_one() { return instance_; } - static derived instance_; -}; - -struct lacking_copy_ctor : public empty { - lacking_copy_ctor() {} - lacking_copy_ctor(const lacking_copy_ctor& other) = delete; -}; - -template <> lacking_copy_ctor empty::instance_ = {}; - -struct lacking_move_ctor : public empty { - lacking_move_ctor() {} - lacking_move_ctor(const lacking_move_ctor& other) = delete; - lacking_move_ctor(lacking_move_ctor&& other) = delete; -}; - -template <> lacking_move_ctor empty::instance_ = {}; - -/* Custom type caster move/copy test classes */ -class MoveOnlyInt { -public: - MoveOnlyInt() { print_default_created(this); } - MoveOnlyInt(int v) : value{std::move(v)} { print_created(this, value); } - MoveOnlyInt(MoveOnlyInt &&m) { print_move_created(this, m.value); std::swap(value, m.value); } - MoveOnlyInt &operator=(MoveOnlyInt &&m) { print_move_assigned(this, m.value); std::swap(value, m.value); return *this; } - MoveOnlyInt(const MoveOnlyInt &) = delete; - MoveOnlyInt &operator=(const MoveOnlyInt &) = delete; - ~MoveOnlyInt() { print_destroyed(this); } - - int value; -}; -class MoveOrCopyInt { -public: - MoveOrCopyInt() { print_default_created(this); } - MoveOrCopyInt(int v) : value{std::move(v)} { print_created(this, value); } - MoveOrCopyInt(MoveOrCopyInt &&m) { print_move_created(this, m.value); std::swap(value, m.value); } - MoveOrCopyInt &operator=(MoveOrCopyInt &&m) { print_move_assigned(this, m.value); std::swap(value, m.value); return *this; } - MoveOrCopyInt(const MoveOrCopyInt &c) { print_copy_created(this, c.value); value = c.value; } - MoveOrCopyInt &operator=(const MoveOrCopyInt &c) { print_copy_assigned(this, c.value); value = c.value; return *this; } - ~MoveOrCopyInt() { print_destroyed(this); } - - int value; -}; -class CopyOnlyInt { -public: - CopyOnlyInt() { print_default_created(this); } - CopyOnlyInt(int v) : value{std::move(v)} { print_created(this, value); } - CopyOnlyInt(const CopyOnlyInt &c) { print_copy_created(this, c.value); value = c.value; } - CopyOnlyInt &operator=(const CopyOnlyInt &c) { print_copy_assigned(this, c.value); value = c.value; return *this; } - ~CopyOnlyInt() { print_destroyed(this); } - - int value; -}; -NAMESPACE_BEGIN(pybind11) -NAMESPACE_BEGIN(detail) -template <> struct type_caster { - PYBIND11_TYPE_CASTER(MoveOnlyInt, _("MoveOnlyInt")); - bool load(handle src, bool) { value = MoveOnlyInt(src.cast()); return true; } - static handle cast(const MoveOnlyInt &m, return_value_policy r, handle p) { return pybind11::cast(m.value, r, p); } -}; - -template <> struct type_caster { - PYBIND11_TYPE_CASTER(MoveOrCopyInt, _("MoveOrCopyInt")); - bool load(handle src, bool) { value = MoveOrCopyInt(src.cast()); return true; } - static handle cast(const MoveOrCopyInt &m, return_value_policy r, handle p) { return pybind11::cast(m.value, r, p); } -}; - -template <> struct type_caster { -protected: - CopyOnlyInt value; -public: - static constexpr auto name = _("CopyOnlyInt"); - bool load(handle src, bool) { value = CopyOnlyInt(src.cast()); return true; } - static handle cast(const CopyOnlyInt &m, return_value_policy r, handle p) { return pybind11::cast(m.value, r, p); } - static handle cast(const CopyOnlyInt *src, return_value_policy policy, handle parent) { - if (!src) return none().release(); - return cast(*src, policy, parent); - } - operator CopyOnlyInt*() { return &value; } - operator CopyOnlyInt&() { return value; } - template using cast_op_type = pybind11::detail::cast_op_type; -}; -NAMESPACE_END(detail) -NAMESPACE_END(pybind11) - -TEST_SUBMODULE(copy_move_policies, m) { - // test_lacking_copy_ctor - py::class_(m, "lacking_copy_ctor") - .def_static("get_one", &lacking_copy_ctor::get_one, - py::return_value_policy::copy); - // test_lacking_move_ctor - py::class_(m, "lacking_move_ctor") - .def_static("get_one", &lacking_move_ctor::get_one, - py::return_value_policy::move); - - // test_move_and_copy_casts - m.def("move_and_copy_casts", [](py::object o) { - int r = 0; - r += py::cast(o).value; /* moves */ - r += py::cast(o).value; /* moves */ - r += py::cast(o).value; /* copies */ - MoveOrCopyInt m1(py::cast(o)); /* moves */ - MoveOnlyInt m2(py::cast(o)); /* moves */ - CopyOnlyInt m3(py::cast(o)); /* copies */ - r += m1.value + m2.value + m3.value; - - return r; - }); - - // test_move_and_copy_loads - m.def("move_only", [](MoveOnlyInt m) { return m.value; }); - m.def("move_or_copy", [](MoveOrCopyInt m) { return m.value; }); - m.def("copy_only", [](CopyOnlyInt m) { return m.value; }); - m.def("move_pair", [](std::pair p) { - return p.first.value + p.second.value; - }); - m.def("move_tuple", [](std::tuple t) { - return std::get<0>(t).value + std::get<1>(t).value + std::get<2>(t).value; - }); - m.def("copy_tuple", [](std::tuple t) { - return std::get<0>(t).value + std::get<1>(t).value; - }); - m.def("move_copy_nested", [](std::pair>, MoveOrCopyInt>> x) { - return x.first.value + std::get<0>(x.second.first).value + std::get<1>(x.second.first).value + - std::get<0>(std::get<2>(x.second.first)).value + x.second.second.value; - }); - m.def("move_and_copy_cstats", []() { - ConstructorStats::gc(); - // Reset counts to 0 so that previous tests don't affect later ones: - auto &mc = ConstructorStats::get(); - mc.move_assignments = mc.move_constructions = mc.copy_assignments = mc.copy_constructions = 0; - auto &mo = ConstructorStats::get(); - mo.move_assignments = mo.move_constructions = mo.copy_assignments = mo.copy_constructions = 0; - auto &co = ConstructorStats::get(); - co.move_assignments = co.move_constructions = co.copy_assignments = co.copy_constructions = 0; - py::dict d; - d["MoveOrCopyInt"] = py::cast(mc, py::return_value_policy::reference); - d["MoveOnlyInt"] = py::cast(mo, py::return_value_policy::reference); - d["CopyOnlyInt"] = py::cast(co, py::return_value_policy::reference); - return d; - }); -#ifdef PYBIND11_HAS_OPTIONAL - // test_move_and_copy_load_optional - m.attr("has_optional") = true; - m.def("move_optional", [](std::optional o) { - return o->value; - }); - m.def("move_or_copy_optional", [](std::optional o) { - return o->value; - }); - m.def("copy_optional", [](std::optional o) { - return o->value; - }); - m.def("move_optional_tuple", [](std::optional> x) { - return std::get<0>(*x).value + std::get<1>(*x).value + std::get<2>(*x).value; - }); -#else - m.attr("has_optional") = false; -#endif - - // #70 compilation issue if operator new is not public - struct PrivateOpNew { - int value = 1; - private: -#if defined(_MSC_VER) -# pragma warning(disable: 4822) // warning C4822: local class member function does not have a body -#endif - void *operator new(size_t bytes); - }; - py::class_(m, "PrivateOpNew").def_readonly("value", &PrivateOpNew::value); - m.def("private_op_new_value", []() { return PrivateOpNew(); }); - m.def("private_op_new_reference", []() -> const PrivateOpNew & { - static PrivateOpNew x{}; - return x; - }, py::return_value_policy::reference); - - // test_move_fallback - // #389: rvp::move should fall-through to copy on non-movable objects - struct MoveIssue1 { - int v; - MoveIssue1(int v) : v{v} {} - MoveIssue1(const MoveIssue1 &c) = default; - MoveIssue1(MoveIssue1 &&) = delete; - }; - py::class_(m, "MoveIssue1").def(py::init()).def_readwrite("value", &MoveIssue1::v); - - struct MoveIssue2 { - int v; - MoveIssue2(int v) : v{v} {} - MoveIssue2(MoveIssue2 &&) = default; - }; - py::class_(m, "MoveIssue2").def(py::init()).def_readwrite("value", &MoveIssue2::v); - - m.def("get_moveissue1", [](int i) { return new MoveIssue1(i); }, py::return_value_policy::move); - m.def("get_moveissue2", [](int i) { return MoveIssue2(i); }, py::return_value_policy::move); -} diff --git a/pybind11/tests/test_copy_move.py b/pybind11/tests/test_copy_move.py deleted file mode 100644 index aff2d99..0000000 --- a/pybind11/tests/test_copy_move.py +++ /dev/null @@ -1,112 +0,0 @@ -import pytest -from pybind11_tests import copy_move_policies as m - - -def test_lacking_copy_ctor(): - with pytest.raises(RuntimeError) as excinfo: - m.lacking_copy_ctor.get_one() - assert "the object is non-copyable!" in str(excinfo.value) - - -def test_lacking_move_ctor(): - with pytest.raises(RuntimeError) as excinfo: - m.lacking_move_ctor.get_one() - assert "the object is neither movable nor copyable!" in str(excinfo.value) - - -def test_move_and_copy_casts(): - """Cast some values in C++ via custom type casters and count the number of moves/copies.""" - - cstats = m.move_and_copy_cstats() - c_m, c_mc, c_c = cstats["MoveOnlyInt"], cstats["MoveOrCopyInt"], cstats["CopyOnlyInt"] - - # The type move constructions/assignments below each get incremented: the move assignment comes - # from the type_caster load; the move construction happens when extracting that via a cast or - # loading into an argument. - assert m.move_and_copy_casts(3) == 18 - assert c_m.copy_assignments + c_m.copy_constructions == 0 - assert c_m.move_assignments == 2 - assert c_m.move_constructions >= 2 - assert c_mc.alive() == 0 - assert c_mc.copy_assignments + c_mc.copy_constructions == 0 - assert c_mc.move_assignments == 2 - assert c_mc.move_constructions >= 2 - assert c_c.alive() == 0 - assert c_c.copy_assignments == 2 - assert c_c.copy_constructions >= 2 - assert c_m.alive() + c_mc.alive() + c_c.alive() == 0 - - -def test_move_and_copy_loads(): - """Call some functions that load arguments via custom type casters and count the number of - moves/copies.""" - - cstats = m.move_and_copy_cstats() - c_m, c_mc, c_c = cstats["MoveOnlyInt"], cstats["MoveOrCopyInt"], cstats["CopyOnlyInt"] - - assert m.move_only(10) == 10 # 1 move, c_m - assert m.move_or_copy(11) == 11 # 1 move, c_mc - assert m.copy_only(12) == 12 # 1 copy, c_c - assert m.move_pair((13, 14)) == 27 # 1 c_m move, 1 c_mc move - assert m.move_tuple((15, 16, 17)) == 48 # 2 c_m moves, 1 c_mc move - assert m.copy_tuple((18, 19)) == 37 # 2 c_c copies - # Direct constructions: 2 c_m moves, 2 c_mc moves, 1 c_c copy - # Extra moves/copies when moving pairs/tuples: 3 c_m, 3 c_mc, 2 c_c - assert m.move_copy_nested((1, ((2, 3, (4,)), 5))) == 15 - - assert c_m.copy_assignments + c_m.copy_constructions == 0 - assert c_m.move_assignments == 6 - assert c_m.move_constructions == 9 - assert c_mc.copy_assignments + c_mc.copy_constructions == 0 - assert c_mc.move_assignments == 5 - assert c_mc.move_constructions == 8 - assert c_c.copy_assignments == 4 - assert c_c.copy_constructions == 6 - assert c_m.alive() + c_mc.alive() + c_c.alive() == 0 - - -@pytest.mark.skipif(not m.has_optional, reason='no ') -def test_move_and_copy_load_optional(): - """Tests move/copy loads of std::optional arguments""" - - cstats = m.move_and_copy_cstats() - c_m, c_mc, c_c = cstats["MoveOnlyInt"], cstats["MoveOrCopyInt"], cstats["CopyOnlyInt"] - - # The extra move/copy constructions below come from the std::optional move (which has to move - # its arguments): - assert m.move_optional(10) == 10 # c_m: 1 move assign, 2 move construct - assert m.move_or_copy_optional(11) == 11 # c_mc: 1 move assign, 2 move construct - assert m.copy_optional(12) == 12 # c_c: 1 copy assign, 2 copy construct - # 1 move assign + move construct moves each of c_m, c_mc, 1 c_c copy - # +1 move/copy construct each from moving the tuple - # +1 move/copy construct each from moving the optional (which moves the tuple again) - assert m.move_optional_tuple((3, 4, 5)) == 12 - - assert c_m.copy_assignments + c_m.copy_constructions == 0 - assert c_m.move_assignments == 2 - assert c_m.move_constructions == 5 - assert c_mc.copy_assignments + c_mc.copy_constructions == 0 - assert c_mc.move_assignments == 2 - assert c_mc.move_constructions == 5 - assert c_c.copy_assignments == 2 - assert c_c.copy_constructions == 5 - assert c_m.alive() + c_mc.alive() + c_c.alive() == 0 - - -def test_private_op_new(): - """An object with a private `operator new` cannot be returned by value""" - - with pytest.raises(RuntimeError) as excinfo: - m.private_op_new_value() - assert "the object is neither movable nor copyable" in str(excinfo.value) - - assert m.private_op_new_reference().value == 1 - - -def test_move_fallback(): - """#389: rvp::move should fall-through to copy on non-movable objects""" - - m2 = m.get_moveissue2(2) - assert m2.value == 2 - m1 = m.get_moveissue1(1) - assert m1.value == 1 diff --git a/pybind11/tests/test_docstring_options.cpp b/pybind11/tests/test_docstring_options.cpp deleted file mode 100644 index 8c8f79f..0000000 --- a/pybind11/tests/test_docstring_options.cpp +++ /dev/null @@ -1,61 +0,0 @@ -/* - tests/test_docstring_options.cpp -- generation of docstrings and signatures - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" - -TEST_SUBMODULE(docstring_options, m) { - // test_docstring_options - { - py::options options; - options.disable_function_signatures(); - - m.def("test_function1", [](int, int) {}, py::arg("a"), py::arg("b")); - m.def("test_function2", [](int, int) {}, py::arg("a"), py::arg("b"), "A custom docstring"); - - m.def("test_overloaded1", [](int) {}, py::arg("i"), "Overload docstring"); - m.def("test_overloaded1", [](double) {}, py::arg("d")); - - m.def("test_overloaded2", [](int) {}, py::arg("i"), "overload docstring 1"); - m.def("test_overloaded2", [](double) {}, py::arg("d"), "overload docstring 2"); - - m.def("test_overloaded3", [](int) {}, py::arg("i")); - m.def("test_overloaded3", [](double) {}, py::arg("d"), "Overload docstr"); - - options.enable_function_signatures(); - - m.def("test_function3", [](int, int) {}, py::arg("a"), py::arg("b")); - m.def("test_function4", [](int, int) {}, py::arg("a"), py::arg("b"), "A custom docstring"); - - options.disable_function_signatures().disable_user_defined_docstrings(); - - m.def("test_function5", [](int, int) {}, py::arg("a"), py::arg("b"), "A custom docstring"); - - { - py::options nested_options; - nested_options.enable_user_defined_docstrings(); - m.def("test_function6", [](int, int) {}, py::arg("a"), py::arg("b"), "A custom docstring"); - } - } - - m.def("test_function7", [](int, int) {}, py::arg("a"), py::arg("b"), "A custom docstring"); - - { - py::options options; - options.disable_user_defined_docstrings(); - - struct DocstringTestFoo { - int value; - void setValue(int v) { value = v; } - int getValue() const { return value; } - }; - py::class_(m, "DocstringTestFoo", "This is a class docstring") - .def_property("value_prop", &DocstringTestFoo::getValue, &DocstringTestFoo::setValue, "This is a property docstring") - ; - } -} diff --git a/pybind11/tests/test_docstring_options.py b/pybind11/tests/test_docstring_options.py deleted file mode 100644 index 0dbca60..0000000 --- a/pybind11/tests/test_docstring_options.py +++ /dev/null @@ -1,38 +0,0 @@ -from pybind11_tests import docstring_options as m - - -def test_docstring_options(): - # options.disable_function_signatures() - assert not m.test_function1.__doc__ - - assert m.test_function2.__doc__ == "A custom docstring" - - # docstring specified on just the first overload definition: - assert m.test_overloaded1.__doc__ == "Overload docstring" - - # docstring on both overloads: - assert m.test_overloaded2.__doc__ == "overload docstring 1\noverload docstring 2" - - # docstring on only second overload: - assert m.test_overloaded3.__doc__ == "Overload docstr" - - # options.enable_function_signatures() - assert m.test_function3.__doc__ .startswith("test_function3(a: int, b: int) -> None") - - assert m.test_function4.__doc__ .startswith("test_function4(a: int, b: int) -> None") - assert m.test_function4.__doc__ .endswith("A custom docstring\n") - - # options.disable_function_signatures() - # options.disable_user_defined_docstrings() - assert not m.test_function5.__doc__ - - # nested options.enable_user_defined_docstrings() - assert m.test_function6.__doc__ == "A custom docstring" - - # RAII destructor - assert m.test_function7.__doc__ .startswith("test_function7(a: int, b: int) -> None") - assert m.test_function7.__doc__ .endswith("A custom docstring\n") - - # Suppression of user-defined docstrings for non-function objects - assert not m.DocstringTestFoo.__doc__ - assert not m.DocstringTestFoo.value_prop.__doc__ diff --git a/pybind11/tests/test_eigen.cpp b/pybind11/tests/test_eigen.cpp deleted file mode 100644 index aba088d..0000000 --- a/pybind11/tests/test_eigen.cpp +++ /dev/null @@ -1,329 +0,0 @@ -/* - tests/eigen.cpp -- automatic conversion of Eigen types - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" -#include "constructor_stats.h" -#include -#include - -#if defined(_MSC_VER) -# pragma warning(disable: 4996) // C4996: std::unary_negation is deprecated -#endif - -#include - -using MatrixXdR = Eigen::Matrix; - - - -// Sets/resets a testing reference matrix to have values of 10*r + c, where r and c are the -// (1-based) row/column number. -template void reset_ref(M &x) { - for (int i = 0; i < x.rows(); i++) for (int j = 0; j < x.cols(); j++) - x(i, j) = 11 + 10*i + j; -} - -// Returns a static, column-major matrix -Eigen::MatrixXd &get_cm() { - static Eigen::MatrixXd *x; - if (!x) { - x = new Eigen::MatrixXd(3, 3); - reset_ref(*x); - } - return *x; -} -// Likewise, but row-major -MatrixXdR &get_rm() { - static MatrixXdR *x; - if (!x) { - x = new MatrixXdR(3, 3); - reset_ref(*x); - } - return *x; -} -// Resets the values of the static matrices returned by get_cm()/get_rm() -void reset_refs() { - reset_ref(get_cm()); - reset_ref(get_rm()); -} - -// Returns element 2,1 from a matrix (used to test copy/nocopy) -double get_elem(Eigen::Ref m) { return m(2, 1); }; - - -// Returns a matrix with 10*r + 100*c added to each matrix element (to help test that the matrix -// reference is referencing rows/columns correctly). -template Eigen::MatrixXd adjust_matrix(MatrixArgType m) { - Eigen::MatrixXd ret(m); - for (int c = 0; c < m.cols(); c++) for (int r = 0; r < m.rows(); r++) - ret(r, c) += 10*r + 100*c; - return ret; -} - -struct CustomOperatorNew { - CustomOperatorNew() = default; - - Eigen::Matrix4d a = Eigen::Matrix4d::Zero(); - Eigen::Matrix4d b = Eigen::Matrix4d::Identity(); - - EIGEN_MAKE_ALIGNED_OPERATOR_NEW; -}; - -TEST_SUBMODULE(eigen, m) { - using FixedMatrixR = Eigen::Matrix; - using FixedMatrixC = Eigen::Matrix; - using DenseMatrixR = Eigen::Matrix; - using DenseMatrixC = Eigen::Matrix; - using FourRowMatrixC = Eigen::Matrix; - using FourColMatrixC = Eigen::Matrix; - using FourRowMatrixR = Eigen::Matrix; - using FourColMatrixR = Eigen::Matrix; - using SparseMatrixR = Eigen::SparseMatrix; - using SparseMatrixC = Eigen::SparseMatrix; - - m.attr("have_eigen") = true; - - // various tests - m.def("double_col", [](const Eigen::VectorXf &x) -> Eigen::VectorXf { return 2.0f * x; }); - m.def("double_row", [](const Eigen::RowVectorXf &x) -> Eigen::RowVectorXf { return 2.0f * x; }); - m.def("double_complex", [](const Eigen::VectorXcf &x) -> Eigen::VectorXcf { return 2.0f * x; }); - m.def("double_threec", [](py::EigenDRef x) { x *= 2; }); - m.def("double_threer", [](py::EigenDRef x) { x *= 2; }); - m.def("double_mat_cm", [](Eigen::MatrixXf x) -> Eigen::MatrixXf { return 2.0f * x; }); - m.def("double_mat_rm", [](DenseMatrixR x) -> DenseMatrixR { return 2.0f * x; }); - - // test_eigen_ref_to_python - // Different ways of passing via Eigen::Ref; the first and second are the Eigen-recommended - m.def("cholesky1", [](Eigen::Ref x) -> Eigen::MatrixXd { return x.llt().matrixL(); }); - m.def("cholesky2", [](const Eigen::Ref &x) -> Eigen::MatrixXd { return x.llt().matrixL(); }); - m.def("cholesky3", [](const Eigen::Ref &x) -> Eigen::MatrixXd { return x.llt().matrixL(); }); - m.def("cholesky4", [](Eigen::Ref x) -> Eigen::MatrixXd { return x.llt().matrixL(); }); - - // test_eigen_ref_mutators - // Mutators: these add some value to the given element using Eigen, but Eigen should be mapping into - // the numpy array data and so the result should show up there. There are three versions: one that - // works on a contiguous-row matrix (numpy's default), one for a contiguous-column matrix, and one - // for any matrix. - auto add_rm = [](Eigen::Ref x, int r, int c, double v) { x(r,c) += v; }; - auto add_cm = [](Eigen::Ref x, int r, int c, double v) { x(r,c) += v; }; - - // Mutators (Eigen maps into numpy variables): - m.def("add_rm", add_rm); // Only takes row-contiguous - m.def("add_cm", add_cm); // Only takes column-contiguous - // Overloaded versions that will accept either row or column contiguous: - m.def("add1", add_rm); - m.def("add1", add_cm); - m.def("add2", add_cm); - m.def("add2", add_rm); - // This one accepts a matrix of any stride: - m.def("add_any", [](py::EigenDRef x, int r, int c, double v) { x(r,c) += v; }); - - // Return mutable references (numpy maps into eigen variables) - m.def("get_cm_ref", []() { return Eigen::Ref(get_cm()); }); - m.def("get_rm_ref", []() { return Eigen::Ref(get_rm()); }); - // The same references, but non-mutable (numpy maps into eigen variables, but is !writeable) - m.def("get_cm_const_ref", []() { return Eigen::Ref(get_cm()); }); - m.def("get_rm_const_ref", []() { return Eigen::Ref(get_rm()); }); - - m.def("reset_refs", reset_refs); // Restores get_{cm,rm}_ref to original values - - // Increments and returns ref to (same) matrix - m.def("incr_matrix", [](Eigen::Ref m, double v) { - m += Eigen::MatrixXd::Constant(m.rows(), m.cols(), v); - return m; - }, py::return_value_policy::reference); - - // Same, but accepts a matrix of any strides - m.def("incr_matrix_any", [](py::EigenDRef m, double v) { - m += Eigen::MatrixXd::Constant(m.rows(), m.cols(), v); - return m; - }, py::return_value_policy::reference); - - // Returns an eigen slice of even rows - m.def("even_rows", [](py::EigenDRef m) { - return py::EigenDMap( - m.data(), (m.rows() + 1) / 2, m.cols(), - py::EigenDStride(m.outerStride(), 2 * m.innerStride())); - }, py::return_value_policy::reference); - - // Returns an eigen slice of even columns - m.def("even_cols", [](py::EigenDRef m) { - return py::EigenDMap( - m.data(), m.rows(), (m.cols() + 1) / 2, - py::EigenDStride(2 * m.outerStride(), m.innerStride())); - }, py::return_value_policy::reference); - - // Returns diagonals: a vector-like object with an inner stride != 1 - m.def("diagonal", [](const Eigen::Ref &x) { return x.diagonal(); }); - m.def("diagonal_1", [](const Eigen::Ref &x) { return x.diagonal<1>(); }); - m.def("diagonal_n", [](const Eigen::Ref &x, int index) { return x.diagonal(index); }); - - // Return a block of a matrix (gives non-standard strides) - m.def("block", [](const Eigen::Ref &x, int start_row, int start_col, int block_rows, int block_cols) { - return x.block(start_row, start_col, block_rows, block_cols); - }); - - // test_eigen_return_references, test_eigen_keepalive - // return value referencing/copying tests: - class ReturnTester { - Eigen::MatrixXd mat = create(); - public: - ReturnTester() { print_created(this); } - ~ReturnTester() { print_destroyed(this); } - static Eigen::MatrixXd create() { return Eigen::MatrixXd::Ones(10, 10); } - static const Eigen::MatrixXd createConst() { return Eigen::MatrixXd::Ones(10, 10); } - Eigen::MatrixXd &get() { return mat; } - Eigen::MatrixXd *getPtr() { return &mat; } - const Eigen::MatrixXd &view() { return mat; } - const Eigen::MatrixXd *viewPtr() { return &mat; } - Eigen::Ref ref() { return mat; } - Eigen::Ref refConst() { return mat; } - Eigen::Block block(int r, int c, int nrow, int ncol) { return mat.block(r, c, nrow, ncol); } - Eigen::Block blockConst(int r, int c, int nrow, int ncol) const { return mat.block(r, c, nrow, ncol); } - py::EigenDMap corners() { return py::EigenDMap(mat.data(), - py::EigenDStride(mat.outerStride() * (mat.outerSize()-1), mat.innerStride() * (mat.innerSize()-1))); } - py::EigenDMap cornersConst() const { return py::EigenDMap(mat.data(), - py::EigenDStride(mat.outerStride() * (mat.outerSize()-1), mat.innerStride() * (mat.innerSize()-1))); } - }; - using rvp = py::return_value_policy; - py::class_(m, "ReturnTester") - .def(py::init<>()) - .def_static("create", &ReturnTester::create) - .def_static("create_const", &ReturnTester::createConst) - .def("get", &ReturnTester::get, rvp::reference_internal) - .def("get_ptr", &ReturnTester::getPtr, rvp::reference_internal) - .def("view", &ReturnTester::view, rvp::reference_internal) - .def("view_ptr", &ReturnTester::view, rvp::reference_internal) - .def("copy_get", &ReturnTester::get) // Default rvp: copy - .def("copy_view", &ReturnTester::view) // " - .def("ref", &ReturnTester::ref) // Default for Ref is to reference - .def("ref_const", &ReturnTester::refConst) // Likewise, but const - .def("ref_safe", &ReturnTester::ref, rvp::reference_internal) - .def("ref_const_safe", &ReturnTester::refConst, rvp::reference_internal) - .def("copy_ref", &ReturnTester::ref, rvp::copy) - .def("copy_ref_const", &ReturnTester::refConst, rvp::copy) - .def("block", &ReturnTester::block) - .def("block_safe", &ReturnTester::block, rvp::reference_internal) - .def("block_const", &ReturnTester::blockConst, rvp::reference_internal) - .def("copy_block", &ReturnTester::block, rvp::copy) - .def("corners", &ReturnTester::corners, rvp::reference_internal) - .def("corners_const", &ReturnTester::cornersConst, rvp::reference_internal) - ; - - // test_special_matrix_objects - // Returns a DiagonalMatrix with diagonal (1,2,3,...) - m.def("incr_diag", [](int k) { - Eigen::DiagonalMatrix m(k); - for (int i = 0; i < k; i++) m.diagonal()[i] = i+1; - return m; - }); - - // Returns a SelfAdjointView referencing the lower triangle of m - m.def("symmetric_lower", [](const Eigen::MatrixXi &m) { - return m.selfadjointView(); - }); - // Returns a SelfAdjointView referencing the lower triangle of m - m.def("symmetric_upper", [](const Eigen::MatrixXi &m) { - return m.selfadjointView(); - }); - - // Test matrix for various functions below. - Eigen::MatrixXf mat(5, 6); - mat << 0, 3, 0, 0, 0, 11, - 22, 0, 0, 0, 17, 11, - 7, 5, 0, 1, 0, 11, - 0, 0, 0, 0, 0, 11, - 0, 0, 14, 0, 8, 11; - - // test_fixed, and various other tests - m.def("fixed_r", [mat]() -> FixedMatrixR { return FixedMatrixR(mat); }); - m.def("fixed_r_const", [mat]() -> const FixedMatrixR { return FixedMatrixR(mat); }); - m.def("fixed_c", [mat]() -> FixedMatrixC { return FixedMatrixC(mat); }); - m.def("fixed_copy_r", [](const FixedMatrixR &m) -> FixedMatrixR { return m; }); - m.def("fixed_copy_c", [](const FixedMatrixC &m) -> FixedMatrixC { return m; }); - // test_mutator_descriptors - m.def("fixed_mutator_r", [](Eigen::Ref) {}); - m.def("fixed_mutator_c", [](Eigen::Ref) {}); - m.def("fixed_mutator_a", [](py::EigenDRef) {}); - // test_dense - m.def("dense_r", [mat]() -> DenseMatrixR { return DenseMatrixR(mat); }); - m.def("dense_c", [mat]() -> DenseMatrixC { return DenseMatrixC(mat); }); - m.def("dense_copy_r", [](const DenseMatrixR &m) -> DenseMatrixR { return m; }); - m.def("dense_copy_c", [](const DenseMatrixC &m) -> DenseMatrixC { return m; }); - // test_sparse, test_sparse_signature - m.def("sparse_r", [mat]() -> SparseMatrixR { return Eigen::SparseView(mat); }); - m.def("sparse_c", [mat]() -> SparseMatrixC { return Eigen::SparseView(mat); }); - m.def("sparse_copy_r", [](const SparseMatrixR &m) -> SparseMatrixR { return m; }); - m.def("sparse_copy_c", [](const SparseMatrixC &m) -> SparseMatrixC { return m; }); - // test_partially_fixed - m.def("partial_copy_four_rm_r", [](const FourRowMatrixR &m) -> FourRowMatrixR { return m; }); - m.def("partial_copy_four_rm_c", [](const FourColMatrixR &m) -> FourColMatrixR { return m; }); - m.def("partial_copy_four_cm_r", [](const FourRowMatrixC &m) -> FourRowMatrixC { return m; }); - m.def("partial_copy_four_cm_c", [](const FourColMatrixC &m) -> FourColMatrixC { return m; }); - - // test_cpp_casting - // Test that we can cast a numpy object to a Eigen::MatrixXd explicitly - m.def("cpp_copy", [](py::handle m) { return m.cast()(1, 0); }); - m.def("cpp_ref_c", [](py::handle m) { return m.cast>()(1, 0); }); - m.def("cpp_ref_r", [](py::handle m) { return m.cast>()(1, 0); }); - m.def("cpp_ref_any", [](py::handle m) { return m.cast>()(1, 0); }); - - - // test_nocopy_wrapper - // Test that we can prevent copying into an argument that would normally copy: First a version - // that would allow copying (if types or strides don't match) for comparison: - m.def("get_elem", &get_elem); - // Now this alternative that calls the tells pybind to fail rather than copy: - m.def("get_elem_nocopy", [](Eigen::Ref m) -> double { return get_elem(m); }, - py::arg().noconvert()); - // Also test a row-major-only no-copy const ref: - m.def("get_elem_rm_nocopy", [](Eigen::Ref> &m) -> long { return m(2, 1); }, - py::arg().noconvert()); - - // test_issue738 - // Issue #738: 1xN or Nx1 2D matrices were neither accepted nor properly copied with an - // incompatible stride value on the length-1 dimension--but that should be allowed (without - // requiring a copy!) because the stride value can be safely ignored on a size-1 dimension. - m.def("iss738_f1", &adjust_matrix &>, py::arg().noconvert()); - m.def("iss738_f2", &adjust_matrix> &>, py::arg().noconvert()); - - // test_issue1105 - // Issue #1105: when converting from a numpy two-dimensional (Nx1) or (1xN) value into a dense - // eigen Vector or RowVector, the argument would fail to load because the numpy copy would fail: - // numpy won't broadcast a Nx1 into a 1-dimensional vector. - m.def("iss1105_col", [](Eigen::VectorXd) { return true; }); - m.def("iss1105_row", [](Eigen::RowVectorXd) { return true; }); - - // test_named_arguments - // Make sure named arguments are working properly: - m.def("matrix_multiply", [](const py::EigenDRef A, const py::EigenDRef B) - -> Eigen::MatrixXd { - if (A.cols() != B.rows()) throw std::domain_error("Nonconformable matrices!"); - return A * B; - }, py::arg("A"), py::arg("B")); - - // test_custom_operator_new - py::class_(m, "CustomOperatorNew") - .def(py::init<>()) - .def_readonly("a", &CustomOperatorNew::a) - .def_readonly("b", &CustomOperatorNew::b); - - // test_eigen_ref_life_support - // In case of a failure (the caster's temp array does not live long enough), creating - // a new array (np.ones(10)) increases the chances that the temp array will be garbage - // collected and/or that its memory will be overridden with different values. - m.def("get_elem_direct", [](Eigen::Ref v) { - py::module::import("numpy").attr("ones")(10); - return v(5); - }); - m.def("get_elem_indirect", [](std::vector> v) { - py::module::import("numpy").attr("ones")(10); - return v[0](5); - }); -} diff --git a/pybind11/tests/test_eigen.py b/pybind11/tests/test_eigen.py deleted file mode 100644 index 55d9351..0000000 --- a/pybind11/tests/test_eigen.py +++ /dev/null @@ -1,694 +0,0 @@ -import pytest -from pybind11_tests import ConstructorStats - -pytestmark = pytest.requires_eigen_and_numpy - -with pytest.suppress(ImportError): - from pybind11_tests import eigen as m - import numpy as np - - ref = np.array([[ 0., 3, 0, 0, 0, 11], - [22, 0, 0, 0, 17, 11], - [ 7, 5, 0, 1, 0, 11], - [ 0, 0, 0, 0, 0, 11], - [ 0, 0, 14, 0, 8, 11]]) - - -def assert_equal_ref(mat): - np.testing.assert_array_equal(mat, ref) - - -def assert_sparse_equal_ref(sparse_mat): - assert_equal_ref(sparse_mat.toarray()) - - -def test_fixed(): - assert_equal_ref(m.fixed_c()) - assert_equal_ref(m.fixed_r()) - assert_equal_ref(m.fixed_copy_r(m.fixed_r())) - assert_equal_ref(m.fixed_copy_c(m.fixed_c())) - assert_equal_ref(m.fixed_copy_r(m.fixed_c())) - assert_equal_ref(m.fixed_copy_c(m.fixed_r())) - - -def test_dense(): - assert_equal_ref(m.dense_r()) - assert_equal_ref(m.dense_c()) - assert_equal_ref(m.dense_copy_r(m.dense_r())) - assert_equal_ref(m.dense_copy_c(m.dense_c())) - assert_equal_ref(m.dense_copy_r(m.dense_c())) - assert_equal_ref(m.dense_copy_c(m.dense_r())) - - -def test_partially_fixed(): - ref2 = np.array([[0., 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]) - np.testing.assert_array_equal(m.partial_copy_four_rm_r(ref2), ref2) - np.testing.assert_array_equal(m.partial_copy_four_rm_c(ref2), ref2) - np.testing.assert_array_equal(m.partial_copy_four_rm_r(ref2[:, 1]), ref2[:, [1]]) - np.testing.assert_array_equal(m.partial_copy_four_rm_c(ref2[0, :]), ref2[[0], :]) - np.testing.assert_array_equal(m.partial_copy_four_rm_r(ref2[:, (0, 2)]), ref2[:, (0, 2)]) - np.testing.assert_array_equal( - m.partial_copy_four_rm_c(ref2[(3, 1, 2), :]), ref2[(3, 1, 2), :]) - - np.testing.assert_array_equal(m.partial_copy_four_cm_r(ref2), ref2) - np.testing.assert_array_equal(m.partial_copy_four_cm_c(ref2), ref2) - np.testing.assert_array_equal(m.partial_copy_four_cm_r(ref2[:, 1]), ref2[:, [1]]) - np.testing.assert_array_equal(m.partial_copy_four_cm_c(ref2[0, :]), ref2[[0], :]) - np.testing.assert_array_equal(m.partial_copy_four_cm_r(ref2[:, (0, 2)]), ref2[:, (0, 2)]) - np.testing.assert_array_equal( - m.partial_copy_four_cm_c(ref2[(3, 1, 2), :]), ref2[(3, 1, 2), :]) - - # TypeError should be raise for a shape mismatch - functions = [m.partial_copy_four_rm_r, m.partial_copy_four_rm_c, - m.partial_copy_four_cm_r, m.partial_copy_four_cm_c] - matrix_with_wrong_shape = [[1, 2], - [3, 4]] - for f in functions: - with pytest.raises(TypeError) as excinfo: - f(matrix_with_wrong_shape) - assert "incompatible function arguments" in str(excinfo.value) - - -def test_mutator_descriptors(): - zr = np.arange(30, dtype='float32').reshape(5, 6) # row-major - zc = zr.reshape(6, 5).transpose() # column-major - - m.fixed_mutator_r(zr) - m.fixed_mutator_c(zc) - m.fixed_mutator_a(zr) - m.fixed_mutator_a(zc) - with pytest.raises(TypeError) as excinfo: - m.fixed_mutator_r(zc) - assert ('(arg0: numpy.ndarray[float32[5, 6], flags.writeable, flags.c_contiguous]) -> None' - in str(excinfo.value)) - with pytest.raises(TypeError) as excinfo: - m.fixed_mutator_c(zr) - assert ('(arg0: numpy.ndarray[float32[5, 6], flags.writeable, flags.f_contiguous]) -> None' - in str(excinfo.value)) - with pytest.raises(TypeError) as excinfo: - m.fixed_mutator_a(np.array([[1, 2], [3, 4]], dtype='float32')) - assert ('(arg0: numpy.ndarray[float32[5, 6], flags.writeable]) -> None' - in str(excinfo.value)) - zr.flags.writeable = False - with pytest.raises(TypeError): - m.fixed_mutator_r(zr) - with pytest.raises(TypeError): - m.fixed_mutator_a(zr) - - -def test_cpp_casting(): - assert m.cpp_copy(m.fixed_r()) == 22. - assert m.cpp_copy(m.fixed_c()) == 22. - z = np.array([[5., 6], [7, 8]]) - assert m.cpp_copy(z) == 7. - assert m.cpp_copy(m.get_cm_ref()) == 21. - assert m.cpp_copy(m.get_rm_ref()) == 21. - assert m.cpp_ref_c(m.get_cm_ref()) == 21. - assert m.cpp_ref_r(m.get_rm_ref()) == 21. - with pytest.raises(RuntimeError) as excinfo: - # Can't reference m.fixed_c: it contains floats, m.cpp_ref_any wants doubles - m.cpp_ref_any(m.fixed_c()) - assert 'Unable to cast Python instance' in str(excinfo.value) - with pytest.raises(RuntimeError) as excinfo: - # Can't reference m.fixed_r: it contains floats, m.cpp_ref_any wants doubles - m.cpp_ref_any(m.fixed_r()) - assert 'Unable to cast Python instance' in str(excinfo.value) - assert m.cpp_ref_any(m.ReturnTester.create()) == 1. - - assert m.cpp_ref_any(m.get_cm_ref()) == 21. - assert m.cpp_ref_any(m.get_cm_ref()) == 21. - - -def test_pass_readonly_array(): - z = np.full((5, 6), 42.0) - z.flags.writeable = False - np.testing.assert_array_equal(z, m.fixed_copy_r(z)) - np.testing.assert_array_equal(m.fixed_r_const(), m.fixed_r()) - assert not m.fixed_r_const().flags.writeable - np.testing.assert_array_equal(m.fixed_copy_r(m.fixed_r_const()), m.fixed_r_const()) - - -def test_nonunit_stride_from_python(): - counting_mat = np.arange(9.0, dtype=np.float32).reshape((3, 3)) - second_row = counting_mat[1, :] - second_col = counting_mat[:, 1] - np.testing.assert_array_equal(m.double_row(second_row), 2.0 * second_row) - np.testing.assert_array_equal(m.double_col(second_row), 2.0 * second_row) - np.testing.assert_array_equal(m.double_complex(second_row), 2.0 * second_row) - np.testing.assert_array_equal(m.double_row(second_col), 2.0 * second_col) - np.testing.assert_array_equal(m.double_col(second_col), 2.0 * second_col) - np.testing.assert_array_equal(m.double_complex(second_col), 2.0 * second_col) - - counting_3d = np.arange(27.0, dtype=np.float32).reshape((3, 3, 3)) - slices = [counting_3d[0, :, :], counting_3d[:, 0, :], counting_3d[:, :, 0]] - for slice_idx, ref_mat in enumerate(slices): - np.testing.assert_array_equal(m.double_mat_cm(ref_mat), 2.0 * ref_mat) - np.testing.assert_array_equal(m.double_mat_rm(ref_mat), 2.0 * ref_mat) - - # Mutator: - m.double_threer(second_row) - m.double_threec(second_col) - np.testing.assert_array_equal(counting_mat, [[0., 2, 2], [6, 16, 10], [6, 14, 8]]) - - -def test_negative_stride_from_python(msg): - """Eigen doesn't support (as of yet) negative strides. When a function takes an Eigen matrix by - copy or const reference, we can pass a numpy array that has negative strides. Otherwise, an - exception will be thrown as Eigen will not be able to map the numpy array.""" - - counting_mat = np.arange(9.0, dtype=np.float32).reshape((3, 3)) - counting_mat = counting_mat[::-1, ::-1] - second_row = counting_mat[1, :] - second_col = counting_mat[:, 1] - np.testing.assert_array_equal(m.double_row(second_row), 2.0 * second_row) - np.testing.assert_array_equal(m.double_col(second_row), 2.0 * second_row) - np.testing.assert_array_equal(m.double_complex(second_row), 2.0 * second_row) - np.testing.assert_array_equal(m.double_row(second_col), 2.0 * second_col) - np.testing.assert_array_equal(m.double_col(second_col), 2.0 * second_col) - np.testing.assert_array_equal(m.double_complex(second_col), 2.0 * second_col) - - counting_3d = np.arange(27.0, dtype=np.float32).reshape((3, 3, 3)) - counting_3d = counting_3d[::-1, ::-1, ::-1] - slices = [counting_3d[0, :, :], counting_3d[:, 0, :], counting_3d[:, :, 0]] - for slice_idx, ref_mat in enumerate(slices): - np.testing.assert_array_equal(m.double_mat_cm(ref_mat), 2.0 * ref_mat) - np.testing.assert_array_equal(m.double_mat_rm(ref_mat), 2.0 * ref_mat) - - # Mutator: - with pytest.raises(TypeError) as excinfo: - m.double_threer(second_row) - assert msg(excinfo.value) == """ - double_threer(): incompatible function arguments. The following argument types are supported: - 1. (arg0: numpy.ndarray[float32[1, 3], flags.writeable]) -> None - - Invoked with: """ + repr(np.array([ 5., 4., 3.], dtype='float32')) # noqa: E501 line too long - - with pytest.raises(TypeError) as excinfo: - m.double_threec(second_col) - assert msg(excinfo.value) == """ - double_threec(): incompatible function arguments. The following argument types are supported: - 1. (arg0: numpy.ndarray[float32[3, 1], flags.writeable]) -> None - - Invoked with: """ + repr(np.array([ 7., 4., 1.], dtype='float32')) # noqa: E501 line too long - - -def test_nonunit_stride_to_python(): - assert np.all(m.diagonal(ref) == ref.diagonal()) - assert np.all(m.diagonal_1(ref) == ref.diagonal(1)) - for i in range(-5, 7): - assert np.all(m.diagonal_n(ref, i) == ref.diagonal(i)), "m.diagonal_n({})".format(i) - - assert np.all(m.block(ref, 2, 1, 3, 3) == ref[2:5, 1:4]) - assert np.all(m.block(ref, 1, 4, 4, 2) == ref[1:, 4:]) - assert np.all(m.block(ref, 1, 4, 3, 2) == ref[1:4, 4:]) - - -def test_eigen_ref_to_python(): - chols = [m.cholesky1, m.cholesky2, m.cholesky3, m.cholesky4] - for i, chol in enumerate(chols, start=1): - mymat = chol(np.array([[1., 2, 4], [2, 13, 23], [4, 23, 77]])) - assert np.all(mymat == np.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]])), "cholesky{}".format(i) - - -def assign_both(a1, a2, r, c, v): - a1[r, c] = v - a2[r, c] = v - - -def array_copy_but_one(a, r, c, v): - z = np.array(a, copy=True) - z[r, c] = v - return z - - -def test_eigen_return_references(): - """Tests various ways of returning references and non-referencing copies""" - - master = np.ones((10, 10)) - a = m.ReturnTester() - a_get1 = a.get() - assert not a_get1.flags.owndata and a_get1.flags.writeable - assign_both(a_get1, master, 3, 3, 5) - a_get2 = a.get_ptr() - assert not a_get2.flags.owndata and a_get2.flags.writeable - assign_both(a_get1, master, 2, 3, 6) - - a_view1 = a.view() - assert not a_view1.flags.owndata and not a_view1.flags.writeable - with pytest.raises(ValueError): - a_view1[2, 3] = 4 - a_view2 = a.view_ptr() - assert not a_view2.flags.owndata and not a_view2.flags.writeable - with pytest.raises(ValueError): - a_view2[2, 3] = 4 - - a_copy1 = a.copy_get() - assert a_copy1.flags.owndata and a_copy1.flags.writeable - np.testing.assert_array_equal(a_copy1, master) - a_copy1[7, 7] = -44 # Shouldn't affect anything else - c1want = array_copy_but_one(master, 7, 7, -44) - a_copy2 = a.copy_view() - assert a_copy2.flags.owndata and a_copy2.flags.writeable - np.testing.assert_array_equal(a_copy2, master) - a_copy2[4, 4] = -22 # Shouldn't affect anything else - c2want = array_copy_but_one(master, 4, 4, -22) - - a_ref1 = a.ref() - assert not a_ref1.flags.owndata and a_ref1.flags.writeable - assign_both(a_ref1, master, 1, 1, 15) - a_ref2 = a.ref_const() - assert not a_ref2.flags.owndata and not a_ref2.flags.writeable - with pytest.raises(ValueError): - a_ref2[5, 5] = 33 - a_ref3 = a.ref_safe() - assert not a_ref3.flags.owndata and a_ref3.flags.writeable - assign_both(a_ref3, master, 0, 7, 99) - a_ref4 = a.ref_const_safe() - assert not a_ref4.flags.owndata and not a_ref4.flags.writeable - with pytest.raises(ValueError): - a_ref4[7, 0] = 987654321 - - a_copy3 = a.copy_ref() - assert a_copy3.flags.owndata and a_copy3.flags.writeable - np.testing.assert_array_equal(a_copy3, master) - a_copy3[8, 1] = 11 - c3want = array_copy_but_one(master, 8, 1, 11) - a_copy4 = a.copy_ref_const() - assert a_copy4.flags.owndata and a_copy4.flags.writeable - np.testing.assert_array_equal(a_copy4, master) - a_copy4[8, 4] = 88 - c4want = array_copy_but_one(master, 8, 4, 88) - - a_block1 = a.block(3, 3, 2, 2) - assert not a_block1.flags.owndata and a_block1.flags.writeable - a_block1[0, 0] = 55 - master[3, 3] = 55 - a_block2 = a.block_safe(2, 2, 3, 2) - assert not a_block2.flags.owndata and a_block2.flags.writeable - a_block2[2, 1] = -123 - master[4, 3] = -123 - a_block3 = a.block_const(6, 7, 4, 3) - assert not a_block3.flags.owndata and not a_block3.flags.writeable - with pytest.raises(ValueError): - a_block3[2, 2] = -44444 - - a_copy5 = a.copy_block(2, 2, 2, 3) - assert a_copy5.flags.owndata and a_copy5.flags.writeable - np.testing.assert_array_equal(a_copy5, master[2:4, 2:5]) - a_copy5[1, 1] = 777 - c5want = array_copy_but_one(master[2:4, 2:5], 1, 1, 777) - - a_corn1 = a.corners() - assert not a_corn1.flags.owndata and a_corn1.flags.writeable - a_corn1 *= 50 - a_corn1[1, 1] = 999 - master[0, 0] = 50 - master[0, 9] = 50 - master[9, 0] = 50 - master[9, 9] = 999 - a_corn2 = a.corners_const() - assert not a_corn2.flags.owndata and not a_corn2.flags.writeable - with pytest.raises(ValueError): - a_corn2[1, 0] = 51 - - # All of the changes made all the way along should be visible everywhere - # now (except for the copies, of course) - np.testing.assert_array_equal(a_get1, master) - np.testing.assert_array_equal(a_get2, master) - np.testing.assert_array_equal(a_view1, master) - np.testing.assert_array_equal(a_view2, master) - np.testing.assert_array_equal(a_ref1, master) - np.testing.assert_array_equal(a_ref2, master) - np.testing.assert_array_equal(a_ref3, master) - np.testing.assert_array_equal(a_ref4, master) - np.testing.assert_array_equal(a_block1, master[3:5, 3:5]) - np.testing.assert_array_equal(a_block2, master[2:5, 2:4]) - np.testing.assert_array_equal(a_block3, master[6:10, 7:10]) - np.testing.assert_array_equal(a_corn1, master[0::master.shape[0] - 1, 0::master.shape[1] - 1]) - np.testing.assert_array_equal(a_corn2, master[0::master.shape[0] - 1, 0::master.shape[1] - 1]) - - np.testing.assert_array_equal(a_copy1, c1want) - np.testing.assert_array_equal(a_copy2, c2want) - np.testing.assert_array_equal(a_copy3, c3want) - np.testing.assert_array_equal(a_copy4, c4want) - np.testing.assert_array_equal(a_copy5, c5want) - - -def assert_keeps_alive(cl, method, *args): - cstats = ConstructorStats.get(cl) - start_with = cstats.alive() - a = cl() - assert cstats.alive() == start_with + 1 - z = method(a, *args) - assert cstats.alive() == start_with + 1 - del a - # Here's the keep alive in action: - assert cstats.alive() == start_with + 1 - del z - # Keep alive should have expired: - assert cstats.alive() == start_with - - -def test_eigen_keepalive(): - a = m.ReturnTester() - cstats = ConstructorStats.get(m.ReturnTester) - assert cstats.alive() == 1 - unsafe = [a.ref(), a.ref_const(), a.block(1, 2, 3, 4)] - copies = [a.copy_get(), a.copy_view(), a.copy_ref(), a.copy_ref_const(), - a.copy_block(4, 3, 2, 1)] - del a - assert cstats.alive() == 0 - del unsafe - del copies - - for meth in [m.ReturnTester.get, m.ReturnTester.get_ptr, m.ReturnTester.view, - m.ReturnTester.view_ptr, m.ReturnTester.ref_safe, m.ReturnTester.ref_const_safe, - m.ReturnTester.corners, m.ReturnTester.corners_const]: - assert_keeps_alive(m.ReturnTester, meth) - - for meth in [m.ReturnTester.block_safe, m.ReturnTester.block_const]: - assert_keeps_alive(m.ReturnTester, meth, 4, 3, 2, 1) - - -def test_eigen_ref_mutators(): - """Tests Eigen's ability to mutate numpy values""" - - orig = np.array([[1., 2, 3], [4, 5, 6], [7, 8, 9]]) - zr = np.array(orig) - zc = np.array(orig, order='F') - m.add_rm(zr, 1, 0, 100) - assert np.all(zr == np.array([[1., 2, 3], [104, 5, 6], [7, 8, 9]])) - m.add_cm(zc, 1, 0, 200) - assert np.all(zc == np.array([[1., 2, 3], [204, 5, 6], [7, 8, 9]])) - - m.add_any(zr, 1, 0, 20) - assert np.all(zr == np.array([[1., 2, 3], [124, 5, 6], [7, 8, 9]])) - m.add_any(zc, 1, 0, 10) - assert np.all(zc == np.array([[1., 2, 3], [214, 5, 6], [7, 8, 9]])) - - # Can't reference a col-major array with a row-major Ref, and vice versa: - with pytest.raises(TypeError): - m.add_rm(zc, 1, 0, 1) - with pytest.raises(TypeError): - m.add_cm(zr, 1, 0, 1) - - # Overloads: - m.add1(zr, 1, 0, -100) - m.add2(zr, 1, 0, -20) - assert np.all(zr == orig) - m.add1(zc, 1, 0, -200) - m.add2(zc, 1, 0, -10) - assert np.all(zc == orig) - - # a non-contiguous slice (this won't work on either the row- or - # column-contiguous refs, but should work for the any) - cornersr = zr[0::2, 0::2] - cornersc = zc[0::2, 0::2] - - assert np.all(cornersr == np.array([[1., 3], [7, 9]])) - assert np.all(cornersc == np.array([[1., 3], [7, 9]])) - - with pytest.raises(TypeError): - m.add_rm(cornersr, 0, 1, 25) - with pytest.raises(TypeError): - m.add_cm(cornersr, 0, 1, 25) - with pytest.raises(TypeError): - m.add_rm(cornersc, 0, 1, 25) - with pytest.raises(TypeError): - m.add_cm(cornersc, 0, 1, 25) - m.add_any(cornersr, 0, 1, 25) - m.add_any(cornersc, 0, 1, 44) - assert np.all(zr == np.array([[1., 2, 28], [4, 5, 6], [7, 8, 9]])) - assert np.all(zc == np.array([[1., 2, 47], [4, 5, 6], [7, 8, 9]])) - - # You shouldn't be allowed to pass a non-writeable array to a mutating Eigen method: - zro = zr[0:4, 0:4] - zro.flags.writeable = False - with pytest.raises(TypeError): - m.add_rm(zro, 0, 0, 0) - with pytest.raises(TypeError): - m.add_any(zro, 0, 0, 0) - with pytest.raises(TypeError): - m.add1(zro, 0, 0, 0) - with pytest.raises(TypeError): - m.add2(zro, 0, 0, 0) - - # integer array shouldn't be passable to a double-matrix-accepting mutating func: - zi = np.array([[1, 2], [3, 4]]) - with pytest.raises(TypeError): - m.add_rm(zi) - - -def test_numpy_ref_mutators(): - """Tests numpy mutating Eigen matrices (for returned Eigen::Ref<...>s)""" - - m.reset_refs() # In case another test already changed it - - zc = m.get_cm_ref() - zcro = m.get_cm_const_ref() - zr = m.get_rm_ref() - zrro = m.get_rm_const_ref() - - assert [zc[1, 2], zcro[1, 2], zr[1, 2], zrro[1, 2]] == [23] * 4 - - assert not zc.flags.owndata and zc.flags.writeable - assert not zr.flags.owndata and zr.flags.writeable - assert not zcro.flags.owndata and not zcro.flags.writeable - assert not zrro.flags.owndata and not zrro.flags.writeable - - zc[1, 2] = 99 - expect = np.array([[11., 12, 13], [21, 22, 99], [31, 32, 33]]) - # We should have just changed zc, of course, but also zcro and the original eigen matrix - assert np.all(zc == expect) - assert np.all(zcro == expect) - assert np.all(m.get_cm_ref() == expect) - - zr[1, 2] = 99 - assert np.all(zr == expect) - assert np.all(zrro == expect) - assert np.all(m.get_rm_ref() == expect) - - # Make sure the readonly ones are numpy-readonly: - with pytest.raises(ValueError): - zcro[1, 2] = 6 - with pytest.raises(ValueError): - zrro[1, 2] = 6 - - # We should be able to explicitly copy like this (and since we're copying, - # the const should drop away) - y1 = np.array(m.get_cm_const_ref()) - - assert y1.flags.owndata and y1.flags.writeable - # We should get copies of the eigen data, which was modified above: - assert y1[1, 2] == 99 - y1[1, 2] += 12 - assert y1[1, 2] == 111 - assert zc[1, 2] == 99 # Make sure we aren't referencing the original - - -def test_both_ref_mutators(): - """Tests a complex chain of nested eigen/numpy references""" - - m.reset_refs() # In case another test already changed it - - z = m.get_cm_ref() # numpy -> eigen - z[0, 2] -= 3 - z2 = m.incr_matrix(z, 1) # numpy -> eigen -> numpy -> eigen - z2[1, 1] += 6 - z3 = m.incr_matrix(z, 2) # (numpy -> eigen)^3 - z3[2, 2] += -5 - z4 = m.incr_matrix(z, 3) # (numpy -> eigen)^4 - z4[1, 1] -= 1 - z5 = m.incr_matrix(z, 4) # (numpy -> eigen)^5 - z5[0, 0] = 0 - assert np.all(z == z2) - assert np.all(z == z3) - assert np.all(z == z4) - assert np.all(z == z5) - expect = np.array([[0., 22, 20], [31, 37, 33], [41, 42, 38]]) - assert np.all(z == expect) - - y = np.array(range(100), dtype='float64').reshape(10, 10) - y2 = m.incr_matrix_any(y, 10) # np -> eigen -> np - y3 = m.incr_matrix_any(y2[0::2, 0::2], -33) # np -> eigen -> np slice -> np -> eigen -> np - y4 = m.even_rows(y3) # numpy -> eigen slice -> (... y3) - y5 = m.even_cols(y4) # numpy -> eigen slice -> (... y4) - y6 = m.incr_matrix_any(y5, 1000) # numpy -> eigen -> (... y5) - - # Apply same mutations using just numpy: - yexpect = np.array(range(100), dtype='float64').reshape(10, 10) - yexpect += 10 - yexpect[0::2, 0::2] -= 33 - yexpect[0::4, 0::4] += 1000 - assert np.all(y6 == yexpect[0::4, 0::4]) - assert np.all(y5 == yexpect[0::4, 0::4]) - assert np.all(y4 == yexpect[0::4, 0::2]) - assert np.all(y3 == yexpect[0::2, 0::2]) - assert np.all(y2 == yexpect) - assert np.all(y == yexpect) - - -def test_nocopy_wrapper(): - # get_elem requires a column-contiguous matrix reference, but should be - # callable with other types of matrix (via copying): - int_matrix_colmajor = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], order='F') - dbl_matrix_colmajor = np.array(int_matrix_colmajor, dtype='double', order='F', copy=True) - int_matrix_rowmajor = np.array(int_matrix_colmajor, order='C', copy=True) - dbl_matrix_rowmajor = np.array(int_matrix_rowmajor, dtype='double', order='C', copy=True) - - # All should be callable via get_elem: - assert m.get_elem(int_matrix_colmajor) == 8 - assert m.get_elem(dbl_matrix_colmajor) == 8 - assert m.get_elem(int_matrix_rowmajor) == 8 - assert m.get_elem(dbl_matrix_rowmajor) == 8 - - # All but the second should fail with m.get_elem_nocopy: - with pytest.raises(TypeError) as excinfo: - m.get_elem_nocopy(int_matrix_colmajor) - assert ('get_elem_nocopy(): incompatible function arguments.' in str(excinfo.value) and - ', flags.f_contiguous' in str(excinfo.value)) - assert m.get_elem_nocopy(dbl_matrix_colmajor) == 8 - with pytest.raises(TypeError) as excinfo: - m.get_elem_nocopy(int_matrix_rowmajor) - assert ('get_elem_nocopy(): incompatible function arguments.' in str(excinfo.value) and - ', flags.f_contiguous' in str(excinfo.value)) - with pytest.raises(TypeError) as excinfo: - m.get_elem_nocopy(dbl_matrix_rowmajor) - assert ('get_elem_nocopy(): incompatible function arguments.' in str(excinfo.value) and - ', flags.f_contiguous' in str(excinfo.value)) - - # For the row-major test, we take a long matrix in row-major, so only the third is allowed: - with pytest.raises(TypeError) as excinfo: - m.get_elem_rm_nocopy(int_matrix_colmajor) - assert ('get_elem_rm_nocopy(): incompatible function arguments.' in str(excinfo.value) and - ', flags.c_contiguous' in str(excinfo.value)) - with pytest.raises(TypeError) as excinfo: - m.get_elem_rm_nocopy(dbl_matrix_colmajor) - assert ('get_elem_rm_nocopy(): incompatible function arguments.' in str(excinfo.value) and - ', flags.c_contiguous' in str(excinfo.value)) - assert m.get_elem_rm_nocopy(int_matrix_rowmajor) == 8 - with pytest.raises(TypeError) as excinfo: - m.get_elem_rm_nocopy(dbl_matrix_rowmajor) - assert ('get_elem_rm_nocopy(): incompatible function arguments.' in str(excinfo.value) and - ', flags.c_contiguous' in str(excinfo.value)) - - -def test_eigen_ref_life_support(): - """Ensure the lifetime of temporary arrays created by the `Ref` caster - - The `Ref` caster sometimes creates a copy which needs to stay alive. This needs to - happen both for directs casts (just the array) or indirectly (e.g. list of arrays). - """ - - a = np.full(shape=10, fill_value=8, dtype=np.int8) - assert m.get_elem_direct(a) == 8 - - list_of_a = [a] - assert m.get_elem_indirect(list_of_a) == 8 - - -def test_special_matrix_objects(): - assert np.all(m.incr_diag(7) == np.diag([1., 2, 3, 4, 5, 6, 7])) - - asymm = np.array([[ 1., 2, 3, 4], - [ 5, 6, 7, 8], - [ 9, 10, 11, 12], - [13, 14, 15, 16]]) - symm_lower = np.array(asymm) - symm_upper = np.array(asymm) - for i in range(4): - for j in range(i + 1, 4): - symm_lower[i, j] = symm_lower[j, i] - symm_upper[j, i] = symm_upper[i, j] - - assert np.all(m.symmetric_lower(asymm) == symm_lower) - assert np.all(m.symmetric_upper(asymm) == symm_upper) - - -def test_dense_signature(doc): - assert doc(m.double_col) == """ - double_col(arg0: numpy.ndarray[float32[m, 1]]) -> numpy.ndarray[float32[m, 1]] - """ - assert doc(m.double_row) == """ - double_row(arg0: numpy.ndarray[float32[1, n]]) -> numpy.ndarray[float32[1, n]] - """ - assert doc(m.double_complex) == """ - double_complex(arg0: numpy.ndarray[complex64[m, 1]]) -> numpy.ndarray[complex64[m, 1]] - """ - assert doc(m.double_mat_rm) == """ - double_mat_rm(arg0: numpy.ndarray[float32[m, n]]) -> numpy.ndarray[float32[m, n]] - """ - - -def test_named_arguments(): - a = np.array([[1.0, 2], [3, 4], [5, 6]]) - b = np.ones((2, 1)) - - assert np.all(m.matrix_multiply(a, b) == np.array([[3.], [7], [11]])) - assert np.all(m.matrix_multiply(A=a, B=b) == np.array([[3.], [7], [11]])) - assert np.all(m.matrix_multiply(B=b, A=a) == np.array([[3.], [7], [11]])) - - with pytest.raises(ValueError) as excinfo: - m.matrix_multiply(b, a) - assert str(excinfo.value) == 'Nonconformable matrices!' - - with pytest.raises(ValueError) as excinfo: - m.matrix_multiply(A=b, B=a) - assert str(excinfo.value) == 'Nonconformable matrices!' - - with pytest.raises(ValueError) as excinfo: - m.matrix_multiply(B=a, A=b) - assert str(excinfo.value) == 'Nonconformable matrices!' - - -@pytest.requires_eigen_and_scipy -def test_sparse(): - assert_sparse_equal_ref(m.sparse_r()) - assert_sparse_equal_ref(m.sparse_c()) - assert_sparse_equal_ref(m.sparse_copy_r(m.sparse_r())) - assert_sparse_equal_ref(m.sparse_copy_c(m.sparse_c())) - assert_sparse_equal_ref(m.sparse_copy_r(m.sparse_c())) - assert_sparse_equal_ref(m.sparse_copy_c(m.sparse_r())) - - -@pytest.requires_eigen_and_scipy -def test_sparse_signature(doc): - assert doc(m.sparse_copy_r) == """ - sparse_copy_r(arg0: scipy.sparse.csr_matrix[float32]) -> scipy.sparse.csr_matrix[float32] - """ # noqa: E501 line too long - assert doc(m.sparse_copy_c) == """ - sparse_copy_c(arg0: scipy.sparse.csc_matrix[float32]) -> scipy.sparse.csc_matrix[float32] - """ # noqa: E501 line too long - - -def test_issue738(): - """Ignore strides on a length-1 dimension (even if they would be incompatible length > 1)""" - assert np.all(m.iss738_f1(np.array([[1., 2, 3]])) == np.array([[1., 102, 203]])) - assert np.all(m.iss738_f1(np.array([[1.], [2], [3]])) == np.array([[1.], [12], [23]])) - - assert np.all(m.iss738_f2(np.array([[1., 2, 3]])) == np.array([[1., 102, 203]])) - assert np.all(m.iss738_f2(np.array([[1.], [2], [3]])) == np.array([[1.], [12], [23]])) - - -def test_issue1105(): - """Issue 1105: 1xN or Nx1 input arrays weren't accepted for eigen - compile-time row vectors or column vector""" - assert m.iss1105_row(np.ones((1, 7))) - assert m.iss1105_col(np.ones((7, 1))) - - # These should still fail (incompatible dimensions): - with pytest.raises(TypeError) as excinfo: - m.iss1105_row(np.ones((7, 1))) - assert "incompatible function arguments" in str(excinfo.value) - with pytest.raises(TypeError) as excinfo: - m.iss1105_col(np.ones((1, 7))) - assert "incompatible function arguments" in str(excinfo.value) - - -def test_custom_operator_new(): - """Using Eigen types as member variables requires a class-specific - operator new with proper alignment""" - - o = m.CustomOperatorNew() - np.testing.assert_allclose(o.a, 0.0) - np.testing.assert_allclose(o.b.diagonal(), 1.0) diff --git a/pybind11/tests/test_embed/CMakeLists.txt b/pybind11/tests/test_embed/CMakeLists.txt deleted file mode 100644 index 8b4f1f8..0000000 --- a/pybind11/tests/test_embed/CMakeLists.txt +++ /dev/null @@ -1,41 +0,0 @@ -if(${PYTHON_MODULE_EXTENSION} MATCHES "pypy") - add_custom_target(cpptest) # Dummy target on PyPy. Embedding is not supported. - set(_suppress_unused_variable_warning "${DOWNLOAD_CATCH}") - return() -endif() - -find_package(Catch 1.9.3) -if(CATCH_FOUND) - message(STATUS "Building interpreter tests using Catch v${CATCH_VERSION}") -else() - message(STATUS "Catch not detected. Interpreter tests will be skipped. Install Catch headers" - " manually or use `cmake -DDOWNLOAD_CATCH=1` to fetch them automatically.") - return() -endif() - -add_executable(test_embed - catch.cpp - test_interpreter.cpp -) -target_include_directories(test_embed PRIVATE ${CATCH_INCLUDE_DIR}) -pybind11_enable_warnings(test_embed) - -if(NOT CMAKE_VERSION VERSION_LESS 3.0) - target_link_libraries(test_embed PRIVATE pybind11::embed) -else() - target_include_directories(test_embed PRIVATE ${PYBIND11_INCLUDE_DIR} ${PYTHON_INCLUDE_DIRS}) - target_compile_options(test_embed PRIVATE ${PYBIND11_CPP_STANDARD}) - target_link_libraries(test_embed PRIVATE ${PYTHON_LIBRARIES}) -endif() - -find_package(Threads REQUIRED) -target_link_libraries(test_embed PUBLIC ${CMAKE_THREAD_LIBS_INIT}) - -add_custom_target(cpptest COMMAND $ - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) - -pybind11_add_module(external_module THIN_LTO external_module.cpp) -set_target_properties(external_module PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) -add_dependencies(cpptest external_module) - -add_dependencies(check cpptest) diff --git a/pybind11/tests/test_embed/catch.cpp b/pybind11/tests/test_embed/catch.cpp deleted file mode 100644 index dd13738..0000000 --- a/pybind11/tests/test_embed/catch.cpp +++ /dev/null @@ -1,22 +0,0 @@ -// The Catch implementation is compiled here. This is a standalone -// translation unit to avoid recompiling it for every test change. - -#include - -#ifdef _MSC_VER -// Silence MSVC C++17 deprecation warning from Catch regarding std::uncaught_exceptions (up to catch -// 2.0.1; this should be fixed in the next catch release after 2.0.1). -# pragma warning(disable: 4996) -#endif - -#define CATCH_CONFIG_RUNNER -#include - -namespace py = pybind11; - -int main(int argc, char *argv[]) { - py::scoped_interpreter guard{}; - auto result = Catch::Session().run(argc, argv); - - return result < 0xff ? result : 0xff; -} diff --git a/pybind11/tests/test_embed/external_module.cpp b/pybind11/tests/test_embed/external_module.cpp deleted file mode 100644 index e9a6058..0000000 --- a/pybind11/tests/test_embed/external_module.cpp +++ /dev/null @@ -1,23 +0,0 @@ -#include - -namespace py = pybind11; - -/* Simple test module/test class to check that the referenced internals data of external pybind11 - * modules aren't preserved over a finalize/initialize. - */ - -PYBIND11_MODULE(external_module, m) { - class A { - public: - A(int value) : v{value} {}; - int v; - }; - - py::class_(m, "A") - .def(py::init()) - .def_readwrite("value", &A::v); - - m.def("internals_at", []() { - return reinterpret_cast(&py::detail::get_internals()); - }); -} diff --git a/pybind11/tests/test_embed/test_interpreter.cpp b/pybind11/tests/test_embed/test_interpreter.cpp deleted file mode 100644 index 222bd56..0000000 --- a/pybind11/tests/test_embed/test_interpreter.cpp +++ /dev/null @@ -1,284 +0,0 @@ -#include - -#ifdef _MSC_VER -// Silence MSVC C++17 deprecation warning from Catch regarding std::uncaught_exceptions (up to catch -// 2.0.1; this should be fixed in the next catch release after 2.0.1). -# pragma warning(disable: 4996) -#endif - -#include - -#include -#include -#include - -namespace py = pybind11; -using namespace py::literals; - -class Widget { -public: - Widget(std::string message) : message(message) { } - virtual ~Widget() = default; - - std::string the_message() const { return message; } - virtual int the_answer() const = 0; - -private: - std::string message; -}; - -class PyWidget final : public Widget { - using Widget::Widget; - - int the_answer() const override { PYBIND11_OVERLOAD_PURE(int, Widget, the_answer); } -}; - -PYBIND11_EMBEDDED_MODULE(widget_module, m) { - py::class_(m, "Widget") - .def(py::init()) - .def_property_readonly("the_message", &Widget::the_message); - - m.def("add", [](int i, int j) { return i + j; }); -} - -PYBIND11_EMBEDDED_MODULE(throw_exception, ) { - throw std::runtime_error("C++ Error"); -} - -PYBIND11_EMBEDDED_MODULE(throw_error_already_set, ) { - auto d = py::dict(); - d["missing"].cast(); -} - -TEST_CASE("Pass classes and data between modules defined in C++ and Python") { - auto module = py::module::import("test_interpreter"); - REQUIRE(py::hasattr(module, "DerivedWidget")); - - auto locals = py::dict("hello"_a="Hello, World!", "x"_a=5, **module.attr("__dict__")); - py::exec(R"( - widget = DerivedWidget("{} - {}".format(hello, x)) - message = widget.the_message - )", py::globals(), locals); - REQUIRE(locals["message"].cast() == "Hello, World! - 5"); - - auto py_widget = module.attr("DerivedWidget")("The question"); - auto message = py_widget.attr("the_message"); - REQUIRE(message.cast() == "The question"); - - const auto &cpp_widget = py_widget.cast(); - REQUIRE(cpp_widget.the_answer() == 42); -} - -TEST_CASE("Import error handling") { - REQUIRE_NOTHROW(py::module::import("widget_module")); - REQUIRE_THROWS_WITH(py::module::import("throw_exception"), - "ImportError: C++ Error"); - REQUIRE_THROWS_WITH(py::module::import("throw_error_already_set"), - Catch::Contains("ImportError: KeyError")); -} - -TEST_CASE("There can be only one interpreter") { - static_assert(std::is_move_constructible::value, ""); - static_assert(!std::is_move_assignable::value, ""); - static_assert(!std::is_copy_constructible::value, ""); - static_assert(!std::is_copy_assignable::value, ""); - - REQUIRE_THROWS_WITH(py::initialize_interpreter(), "The interpreter is already running"); - REQUIRE_THROWS_WITH(py::scoped_interpreter(), "The interpreter is already running"); - - py::finalize_interpreter(); - REQUIRE_NOTHROW(py::scoped_interpreter()); - { - auto pyi1 = py::scoped_interpreter(); - auto pyi2 = std::move(pyi1); - } - py::initialize_interpreter(); -} - -bool has_pybind11_internals_builtin() { - auto builtins = py::handle(PyEval_GetBuiltins()); - return builtins.contains(PYBIND11_INTERNALS_ID); -}; - -bool has_pybind11_internals_static() { - auto **&ipp = py::detail::get_internals_pp(); - return ipp && *ipp; -} - -TEST_CASE("Restart the interpreter") { - // Verify pre-restart state. - REQUIRE(py::module::import("widget_module").attr("add")(1, 2).cast() == 3); - REQUIRE(has_pybind11_internals_builtin()); - REQUIRE(has_pybind11_internals_static()); - REQUIRE(py::module::import("external_module").attr("A")(123).attr("value").cast() == 123); - - // local and foreign module internals should point to the same internals: - REQUIRE(reinterpret_cast(*py::detail::get_internals_pp()) == - py::module::import("external_module").attr("internals_at")().cast()); - - // Restart the interpreter. - py::finalize_interpreter(); - REQUIRE(Py_IsInitialized() == 0); - - py::initialize_interpreter(); - REQUIRE(Py_IsInitialized() == 1); - - // Internals are deleted after a restart. - REQUIRE_FALSE(has_pybind11_internals_builtin()); - REQUIRE_FALSE(has_pybind11_internals_static()); - pybind11::detail::get_internals(); - REQUIRE(has_pybind11_internals_builtin()); - REQUIRE(has_pybind11_internals_static()); - REQUIRE(reinterpret_cast(*py::detail::get_internals_pp()) == - py::module::import("external_module").attr("internals_at")().cast()); - - // Make sure that an interpreter with no get_internals() created until finalize still gets the - // internals destroyed - py::finalize_interpreter(); - py::initialize_interpreter(); - bool ran = false; - py::module::import("__main__").attr("internals_destroy_test") = - py::capsule(&ran, [](void *ran) { py::detail::get_internals(); *static_cast(ran) = true; }); - REQUIRE_FALSE(has_pybind11_internals_builtin()); - REQUIRE_FALSE(has_pybind11_internals_static()); - REQUIRE_FALSE(ran); - py::finalize_interpreter(); - REQUIRE(ran); - py::initialize_interpreter(); - REQUIRE_FALSE(has_pybind11_internals_builtin()); - REQUIRE_FALSE(has_pybind11_internals_static()); - - // C++ modules can be reloaded. - auto cpp_module = py::module::import("widget_module"); - REQUIRE(cpp_module.attr("add")(1, 2).cast() == 3); - - // C++ type information is reloaded and can be used in python modules. - auto py_module = py::module::import("test_interpreter"); - auto py_widget = py_module.attr("DerivedWidget")("Hello after restart"); - REQUIRE(py_widget.attr("the_message").cast() == "Hello after restart"); -} - -TEST_CASE("Subinterpreter") { - // Add tags to the modules in the main interpreter and test the basics. - py::module::import("__main__").attr("main_tag") = "main interpreter"; - { - auto m = py::module::import("widget_module"); - m.attr("extension_module_tag") = "added to module in main interpreter"; - - REQUIRE(m.attr("add")(1, 2).cast() == 3); - } - REQUIRE(has_pybind11_internals_builtin()); - REQUIRE(has_pybind11_internals_static()); - - /// Create and switch to a subinterpreter. - auto main_tstate = PyThreadState_Get(); - auto sub_tstate = Py_NewInterpreter(); - - // Subinterpreters get their own copy of builtins. detail::get_internals() still - // works by returning from the static variable, i.e. all interpreters share a single - // global pybind11::internals; - REQUIRE_FALSE(has_pybind11_internals_builtin()); - REQUIRE(has_pybind11_internals_static()); - - // Modules tags should be gone. - REQUIRE_FALSE(py::hasattr(py::module::import("__main__"), "tag")); - { - auto m = py::module::import("widget_module"); - REQUIRE_FALSE(py::hasattr(m, "extension_module_tag")); - - // Function bindings should still work. - REQUIRE(m.attr("add")(1, 2).cast() == 3); - } - - // Restore main interpreter. - Py_EndInterpreter(sub_tstate); - PyThreadState_Swap(main_tstate); - - REQUIRE(py::hasattr(py::module::import("__main__"), "main_tag")); - REQUIRE(py::hasattr(py::module::import("widget_module"), "extension_module_tag")); -} - -TEST_CASE("Execution frame") { - // When the interpreter is embedded, there is no execution frame, but `py::exec` - // should still function by using reasonable globals: `__main__.__dict__`. - py::exec("var = dict(number=42)"); - REQUIRE(py::globals()["var"]["number"].cast() == 42); -} - -TEST_CASE("Threads") { - // Restart interpreter to ensure threads are not initialized - py::finalize_interpreter(); - py::initialize_interpreter(); - REQUIRE_FALSE(has_pybind11_internals_static()); - - constexpr auto num_threads = 10; - auto locals = py::dict("count"_a=0); - - { - py::gil_scoped_release gil_release{}; - REQUIRE(has_pybind11_internals_static()); - - auto threads = std::vector(); - for (auto i = 0; i < num_threads; ++i) { - threads.emplace_back([&]() { - py::gil_scoped_acquire gil{}; - locals["count"] = locals["count"].cast() + 1; - }); - } - - for (auto &thread : threads) { - thread.join(); - } - } - - REQUIRE(locals["count"].cast() == num_threads); -} - -// Scope exit utility https://stackoverflow.com/a/36644501/7255855 -struct scope_exit { - std::function f_; - explicit scope_exit(std::function f) noexcept : f_(std::move(f)) {} - ~scope_exit() { if (f_) f_(); } -}; - -TEST_CASE("Reload module from file") { - // Disable generation of cached bytecode (.pyc files) for this test, otherwise - // Python might pick up an old version from the cache instead of the new versions - // of the .py files generated below - auto sys = py::module::import("sys"); - bool dont_write_bytecode = sys.attr("dont_write_bytecode").cast(); - sys.attr("dont_write_bytecode") = true; - // Reset the value at scope exit - scope_exit reset_dont_write_bytecode([&]() { - sys.attr("dont_write_bytecode") = dont_write_bytecode; - }); - - std::string module_name = "test_module_reload"; - std::string module_file = module_name + ".py"; - - // Create the module .py file - std::ofstream test_module(module_file); - test_module << "def test():\n"; - test_module << " return 1\n"; - test_module.close(); - // Delete the file at scope exit - scope_exit delete_module_file([&]() { - std::remove(module_file.c_str()); - }); - - // Import the module from file - auto module = py::module::import(module_name.c_str()); - int result = module.attr("test")().cast(); - REQUIRE(result == 1); - - // Update the module .py file with a small change - test_module.open(module_file); - test_module << "def test():\n"; - test_module << " return 2\n"; - test_module.close(); - - // Reload the module - module.reload(); - result = module.attr("test")().cast(); - REQUIRE(result == 2); -} diff --git a/pybind11/tests/test_embed/test_interpreter.py b/pybind11/tests/test_embed/test_interpreter.py deleted file mode 100644 index 26a0479..0000000 --- a/pybind11/tests/test_embed/test_interpreter.py +++ /dev/null @@ -1,9 +0,0 @@ -from widget_module import Widget - - -class DerivedWidget(Widget): - def __init__(self, message): - super(DerivedWidget, self).__init__(message) - - def the_answer(self): - return 42 diff --git a/pybind11/tests/test_enum.cpp b/pybind11/tests/test_enum.cpp deleted file mode 100644 index 3153089..0000000 --- a/pybind11/tests/test_enum.cpp +++ /dev/null @@ -1,87 +0,0 @@ -/* - tests/test_enums.cpp -- enumerations - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" - -TEST_SUBMODULE(enums, m) { - // test_unscoped_enum - enum UnscopedEnum { - EOne = 1, - ETwo, - EThree - }; - py::enum_(m, "UnscopedEnum", py::arithmetic(), "An unscoped enumeration") - .value("EOne", EOne, "Docstring for EOne") - .value("ETwo", ETwo, "Docstring for ETwo") - .value("EThree", EThree, "Docstring for EThree") - .export_values(); - - // test_scoped_enum - enum class ScopedEnum { - Two = 2, - Three - }; - py::enum_(m, "ScopedEnum", py::arithmetic()) - .value("Two", ScopedEnum::Two) - .value("Three", ScopedEnum::Three); - - m.def("test_scoped_enum", [](ScopedEnum z) { - return "ScopedEnum::" + std::string(z == ScopedEnum::Two ? "Two" : "Three"); - }); - - // test_binary_operators - enum Flags { - Read = 4, - Write = 2, - Execute = 1 - }; - py::enum_(m, "Flags", py::arithmetic()) - .value("Read", Flags::Read) - .value("Write", Flags::Write) - .value("Execute", Flags::Execute) - .export_values(); - - // test_implicit_conversion - class ClassWithUnscopedEnum { - public: - enum EMode { - EFirstMode = 1, - ESecondMode - }; - - static EMode test_function(EMode mode) { - return mode; - } - }; - py::class_ exenum_class(m, "ClassWithUnscopedEnum"); - exenum_class.def_static("test_function", &ClassWithUnscopedEnum::test_function); - py::enum_(exenum_class, "EMode") - .value("EFirstMode", ClassWithUnscopedEnum::EFirstMode) - .value("ESecondMode", ClassWithUnscopedEnum::ESecondMode) - .export_values(); - - // test_enum_to_int - m.def("test_enum_to_int", [](int) { }); - m.def("test_enum_to_uint", [](uint32_t) { }); - m.def("test_enum_to_long_long", [](long long) { }); - - // test_duplicate_enum_name - enum SimpleEnum - { - ONE, TWO, THREE - }; - - m.def("register_bad_enum", [m]() { - py::enum_(m, "SimpleEnum") - .value("ONE", SimpleEnum::ONE) //NOTE: all value function calls are called with the same first parameter value - .value("ONE", SimpleEnum::TWO) - .value("ONE", SimpleEnum::THREE) - .export_values(); - }); -} diff --git a/pybind11/tests/test_enum.py b/pybind11/tests/test_enum.py deleted file mode 100644 index 7fe9b61..0000000 --- a/pybind11/tests/test_enum.py +++ /dev/null @@ -1,206 +0,0 @@ -import pytest -from pybind11_tests import enums as m - - -def test_unscoped_enum(): - assert str(m.UnscopedEnum.EOne) == "UnscopedEnum.EOne" - assert str(m.UnscopedEnum.ETwo) == "UnscopedEnum.ETwo" - assert str(m.EOne) == "UnscopedEnum.EOne" - - # name property - assert m.UnscopedEnum.EOne.name == "EOne" - assert m.UnscopedEnum.ETwo.name == "ETwo" - assert m.EOne.name == "EOne" - # name readonly - with pytest.raises(AttributeError): - m.UnscopedEnum.EOne.name = "" - # name returns a copy - foo = m.UnscopedEnum.EOne.name - foo = "bar" - assert m.UnscopedEnum.EOne.name == "EOne" - - # __members__ property - assert m.UnscopedEnum.__members__ == \ - {"EOne": m.UnscopedEnum.EOne, "ETwo": m.UnscopedEnum.ETwo, "EThree": m.UnscopedEnum.EThree} - # __members__ readonly - with pytest.raises(AttributeError): - m.UnscopedEnum.__members__ = {} - # __members__ returns a copy - foo = m.UnscopedEnum.__members__ - foo["bar"] = "baz" - assert m.UnscopedEnum.__members__ == \ - {"EOne": m.UnscopedEnum.EOne, "ETwo": m.UnscopedEnum.ETwo, "EThree": m.UnscopedEnum.EThree} - - for docstring_line in '''An unscoped enumeration - -Members: - - EOne : Docstring for EOne - - ETwo : Docstring for ETwo - - EThree : Docstring for EThree'''.split('\n'): - assert docstring_line in m.UnscopedEnum.__doc__ - - # Unscoped enums will accept ==/!= int comparisons - y = m.UnscopedEnum.ETwo - assert y == 2 - assert 2 == y - assert y != 3 - assert 3 != y - # Compare with None - assert (y != None) # noqa: E711 - assert not (y == None) # noqa: E711 - # Compare with an object - assert (y != object()) - assert not (y == object()) - # Compare with string - assert y != "2" - assert "2" != y - assert not ("2" == y) - assert not (y == "2") - - with pytest.raises(TypeError): - y < object() - - with pytest.raises(TypeError): - y <= object() - - with pytest.raises(TypeError): - y > object() - - with pytest.raises(TypeError): - y >= object() - - with pytest.raises(TypeError): - y | object() - - with pytest.raises(TypeError): - y & object() - - with pytest.raises(TypeError): - y ^ object() - - assert int(m.UnscopedEnum.ETwo) == 2 - assert str(m.UnscopedEnum(2)) == "UnscopedEnum.ETwo" - - # order - assert m.UnscopedEnum.EOne < m.UnscopedEnum.ETwo - assert m.UnscopedEnum.EOne < 2 - assert m.UnscopedEnum.ETwo > m.UnscopedEnum.EOne - assert m.UnscopedEnum.ETwo > 1 - assert m.UnscopedEnum.ETwo <= 2 - assert m.UnscopedEnum.ETwo >= 2 - assert m.UnscopedEnum.EOne <= m.UnscopedEnum.ETwo - assert m.UnscopedEnum.EOne <= 2 - assert m.UnscopedEnum.ETwo >= m.UnscopedEnum.EOne - assert m.UnscopedEnum.ETwo >= 1 - assert not (m.UnscopedEnum.ETwo < m.UnscopedEnum.EOne) - assert not (2 < m.UnscopedEnum.EOne) - - # arithmetic - assert m.UnscopedEnum.EOne & m.UnscopedEnum.EThree == m.UnscopedEnum.EOne - assert m.UnscopedEnum.EOne | m.UnscopedEnum.ETwo == m.UnscopedEnum.EThree - assert m.UnscopedEnum.EOne ^ m.UnscopedEnum.EThree == m.UnscopedEnum.ETwo - - -def test_scoped_enum(): - assert m.test_scoped_enum(m.ScopedEnum.Three) == "ScopedEnum::Three" - z = m.ScopedEnum.Two - assert m.test_scoped_enum(z) == "ScopedEnum::Two" - - # Scoped enums will *NOT* accept ==/!= int comparisons (Will always return False) - assert not z == 3 - assert not 3 == z - assert z != 3 - assert 3 != z - # Compare with None - assert (z != None) # noqa: E711 - assert not (z == None) # noqa: E711 - # Compare with an object - assert (z != object()) - assert not (z == object()) - # Scoped enums will *NOT* accept >, <, >= and <= int comparisons (Will throw exceptions) - with pytest.raises(TypeError): - z > 3 - with pytest.raises(TypeError): - z < 3 - with pytest.raises(TypeError): - z >= 3 - with pytest.raises(TypeError): - z <= 3 - - # order - assert m.ScopedEnum.Two < m.ScopedEnum.Three - assert m.ScopedEnum.Three > m.ScopedEnum.Two - assert m.ScopedEnum.Two <= m.ScopedEnum.Three - assert m.ScopedEnum.Two <= m.ScopedEnum.Two - assert m.ScopedEnum.Two >= m.ScopedEnum.Two - assert m.ScopedEnum.Three >= m.ScopedEnum.Two - - -def test_implicit_conversion(): - assert str(m.ClassWithUnscopedEnum.EMode.EFirstMode) == "EMode.EFirstMode" - assert str(m.ClassWithUnscopedEnum.EFirstMode) == "EMode.EFirstMode" - - f = m.ClassWithUnscopedEnum.test_function - first = m.ClassWithUnscopedEnum.EFirstMode - second = m.ClassWithUnscopedEnum.ESecondMode - - assert f(first) == 1 - - assert f(first) == f(first) - assert not f(first) != f(first) - - assert f(first) != f(second) - assert not f(first) == f(second) - - assert f(first) == int(f(first)) - assert not f(first) != int(f(first)) - - assert f(first) != int(f(second)) - assert not f(first) == int(f(second)) - - # noinspection PyDictCreation - x = {f(first): 1, f(second): 2} - x[f(first)] = 3 - x[f(second)] = 4 - # Hashing test - assert str(x) == "{EMode.EFirstMode: 3, EMode.ESecondMode: 4}" - - -def test_binary_operators(): - assert int(m.Flags.Read) == 4 - assert int(m.Flags.Write) == 2 - assert int(m.Flags.Execute) == 1 - assert int(m.Flags.Read | m.Flags.Write | m.Flags.Execute) == 7 - assert int(m.Flags.Read | m.Flags.Write) == 6 - assert int(m.Flags.Read | m.Flags.Execute) == 5 - assert int(m.Flags.Write | m.Flags.Execute) == 3 - assert int(m.Flags.Write | 1) == 3 - assert ~m.Flags.Write == -3 - - state = m.Flags.Read | m.Flags.Write - assert (state & m.Flags.Read) != 0 - assert (state & m.Flags.Write) != 0 - assert (state & m.Flags.Execute) == 0 - assert (state & 1) == 0 - - state2 = ~state - assert state2 == -7 - assert int(state ^ state2) == -1 - - -def test_enum_to_int(): - m.test_enum_to_int(m.Flags.Read) - m.test_enum_to_int(m.ClassWithUnscopedEnum.EMode.EFirstMode) - m.test_enum_to_uint(m.Flags.Read) - m.test_enum_to_uint(m.ClassWithUnscopedEnum.EMode.EFirstMode) - m.test_enum_to_long_long(m.Flags.Read) - m.test_enum_to_long_long(m.ClassWithUnscopedEnum.EMode.EFirstMode) - - -def test_duplicate_enum_name(): - with pytest.raises(ValueError) as excinfo: - m.register_bad_enum() - assert str(excinfo.value) == 'SimpleEnum: element "ONE" already exists!' diff --git a/pybind11/tests/test_eval.cpp b/pybind11/tests/test_eval.cpp deleted file mode 100644 index e094821..0000000 --- a/pybind11/tests/test_eval.cpp +++ /dev/null @@ -1,91 +0,0 @@ -/* - tests/test_eval.cpp -- Usage of eval() and eval_file() - - Copyright (c) 2016 Klemens D. Morgenstern - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - - -#include -#include "pybind11_tests.h" - -TEST_SUBMODULE(eval_, m) { - // test_evals - - auto global = py::dict(py::module::import("__main__").attr("__dict__")); - - m.def("test_eval_statements", [global]() { - auto local = py::dict(); - local["call_test"] = py::cpp_function([&]() -> int { - return 42; - }); - - // Regular string literal - py::exec( - "message = 'Hello World!'\n" - "x = call_test()", - global, local - ); - - // Multi-line raw string literal - py::exec(R"( - if x == 42: - print(message) - else: - raise RuntimeError - )", global, local - ); - auto x = local["x"].cast(); - - return x == 42; - }); - - m.def("test_eval", [global]() { - auto local = py::dict(); - local["x"] = py::int_(42); - auto x = py::eval("x", global, local); - return x.cast() == 42; - }); - - m.def("test_eval_single_statement", []() { - auto local = py::dict(); - local["call_test"] = py::cpp_function([&]() -> int { - return 42; - }); - - auto result = py::eval("x = call_test()", py::dict(), local); - auto x = local["x"].cast(); - return result.is_none() && x == 42; - }); - - m.def("test_eval_file", [global](py::str filename) { - auto local = py::dict(); - local["y"] = py::int_(43); - - int val_out; - local["call_test2"] = py::cpp_function([&](int value) { val_out = value; }); - - auto result = py::eval_file(filename, global, local); - return val_out == 43 && result.is_none(); - }); - - m.def("test_eval_failure", []() { - try { - py::eval("nonsense code ..."); - } catch (py::error_already_set &) { - return true; - } - return false; - }); - - m.def("test_eval_file_failure", []() { - try { - py::eval_file("non-existing file"); - } catch (std::exception &) { - return true; - } - return false; - }); -} diff --git a/pybind11/tests/test_eval.py b/pybind11/tests/test_eval.py deleted file mode 100644 index bda4ef6..0000000 --- a/pybind11/tests/test_eval.py +++ /dev/null @@ -1,17 +0,0 @@ -import os -from pybind11_tests import eval_ as m - - -def test_evals(capture): - with capture: - assert m.test_eval_statements() - assert capture == "Hello World!" - - assert m.test_eval() - assert m.test_eval_single_statement() - - filename = os.path.join(os.path.dirname(__file__), "test_eval_call.py") - assert m.test_eval_file(filename) - - assert m.test_eval_failure() - assert m.test_eval_file_failure() diff --git a/pybind11/tests/test_eval_call.py b/pybind11/tests/test_eval_call.py deleted file mode 100644 index 53c7e72..0000000 --- a/pybind11/tests/test_eval_call.py +++ /dev/null @@ -1,4 +0,0 @@ -# This file is called from 'test_eval.py' - -if 'call_test2' in locals(): - call_test2(y) # noqa: F821 undefined name diff --git a/pybind11/tests/test_exceptions.cpp b/pybind11/tests/test_exceptions.cpp deleted file mode 100644 index d301390..0000000 --- a/pybind11/tests/test_exceptions.cpp +++ /dev/null @@ -1,196 +0,0 @@ -/* - tests/test_custom-exceptions.cpp -- exception translation - - Copyright (c) 2016 Pim Schellart - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" - -// A type that should be raised as an exception in Python -class MyException : public std::exception { -public: - explicit MyException(const char * m) : message{m} {} - virtual const char * what() const noexcept override {return message.c_str();} -private: - std::string message = ""; -}; - -// A type that should be translated to a standard Python exception -class MyException2 : public std::exception { -public: - explicit MyException2(const char * m) : message{m} {} - virtual const char * what() const noexcept override {return message.c_str();} -private: - std::string message = ""; -}; - -// A type that is not derived from std::exception (and is thus unknown) -class MyException3 { -public: - explicit MyException3(const char * m) : message{m} {} - virtual const char * what() const noexcept {return message.c_str();} -private: - std::string message = ""; -}; - -// A type that should be translated to MyException -// and delegated to its exception translator -class MyException4 : public std::exception { -public: - explicit MyException4(const char * m) : message{m} {} - virtual const char * what() const noexcept override {return message.c_str();} -private: - std::string message = ""; -}; - - -// Like the above, but declared via the helper function -class MyException5 : public std::logic_error { -public: - explicit MyException5(const std::string &what) : std::logic_error(what) {} -}; - -// Inherits from MyException5 -class MyException5_1 : public MyException5 { - using MyException5::MyException5; -}; - -struct PythonCallInDestructor { - PythonCallInDestructor(const py::dict &d) : d(d) {} - ~PythonCallInDestructor() { d["good"] = true; } - - py::dict d; -}; - -TEST_SUBMODULE(exceptions, m) { - m.def("throw_std_exception", []() { - throw std::runtime_error("This exception was intentionally thrown."); - }); - - // make a new custom exception and use it as a translation target - static py::exception ex(m, "MyException"); - py::register_exception_translator([](std::exception_ptr p) { - try { - if (p) std::rethrow_exception(p); - } catch (const MyException &e) { - // Set MyException as the active python error - ex(e.what()); - } - }); - - // register new translator for MyException2 - // no need to store anything here because this type will - // never by visible from Python - py::register_exception_translator([](std::exception_ptr p) { - try { - if (p) std::rethrow_exception(p); - } catch (const MyException2 &e) { - // Translate this exception to a standard RuntimeError - PyErr_SetString(PyExc_RuntimeError, e.what()); - } - }); - - // register new translator for MyException4 - // which will catch it and delegate to the previously registered - // translator for MyException by throwing a new exception - py::register_exception_translator([](std::exception_ptr p) { - try { - if (p) std::rethrow_exception(p); - } catch (const MyException4 &e) { - throw MyException(e.what()); - } - }); - - // A simple exception translation: - auto ex5 = py::register_exception(m, "MyException5"); - // A slightly more complicated one that declares MyException5_1 as a subclass of MyException5 - py::register_exception(m, "MyException5_1", ex5.ptr()); - - m.def("throws1", []() { throw MyException("this error should go to a custom type"); }); - m.def("throws2", []() { throw MyException2("this error should go to a standard Python exception"); }); - m.def("throws3", []() { throw MyException3("this error cannot be translated"); }); - m.def("throws4", []() { throw MyException4("this error is rethrown"); }); - m.def("throws5", []() { throw MyException5("this is a helper-defined translated exception"); }); - m.def("throws5_1", []() { throw MyException5_1("MyException5 subclass"); }); - m.def("throws_logic_error", []() { throw std::logic_error("this error should fall through to the standard handler"); }); - m.def("exception_matches", []() { - py::dict foo; - try { - // Assign to a py::object to force read access of nonexistent dict entry - py::object o = foo["bar"]; - } - catch (py::error_already_set& ex) { - if (!ex.matches(PyExc_KeyError)) throw; - return true; - } - return false; - }); - m.def("exception_matches_base", []() { - py::dict foo; - try { - // Assign to a py::object to force read access of nonexistent dict entry - py::object o = foo["bar"]; - } - catch (py::error_already_set &ex) { - if (!ex.matches(PyExc_Exception)) throw; - return true; - } - return false; - }); - m.def("modulenotfound_exception_matches_base", []() { - try { - // On Python >= 3.6, this raises a ModuleNotFoundError, a subclass of ImportError - py::module::import("nonexistent"); - } - catch (py::error_already_set &ex) { - if (!ex.matches(PyExc_ImportError)) throw; - return true; - } - return false; - }); - - m.def("throw_already_set", [](bool err) { - if (err) - PyErr_SetString(PyExc_ValueError, "foo"); - try { - throw py::error_already_set(); - } catch (const std::runtime_error& e) { - if ((err && e.what() != std::string("ValueError: foo")) || - (!err && e.what() != std::string("Unknown internal error occurred"))) - { - PyErr_Clear(); - throw std::runtime_error("error message mismatch"); - } - } - PyErr_Clear(); - if (err) - PyErr_SetString(PyExc_ValueError, "foo"); - throw py::error_already_set(); - }); - - m.def("python_call_in_destructor", [](py::dict d) { - try { - PythonCallInDestructor set_dict_in_destructor(d); - PyErr_SetString(PyExc_ValueError, "foo"); - throw py::error_already_set(); - } catch (const py::error_already_set&) { - return true; - } - return false; - }); - - // test_nested_throws - m.def("try_catch", [m](py::object exc_type, py::function f, py::args args) { - try { f(*args); } - catch (py::error_already_set &ex) { - if (ex.matches(exc_type)) - py::print(ex.what()); - else - throw; - } - }); - -} diff --git a/pybind11/tests/test_exceptions.py b/pybind11/tests/test_exceptions.py deleted file mode 100644 index 6edff9f..0000000 --- a/pybind11/tests/test_exceptions.py +++ /dev/null @@ -1,146 +0,0 @@ -import pytest - -from pybind11_tests import exceptions as m -import pybind11_cross_module_tests as cm - - -def test_std_exception(msg): - with pytest.raises(RuntimeError) as excinfo: - m.throw_std_exception() - assert msg(excinfo.value) == "This exception was intentionally thrown." - - -def test_error_already_set(msg): - with pytest.raises(RuntimeError) as excinfo: - m.throw_already_set(False) - assert msg(excinfo.value) == "Unknown internal error occurred" - - with pytest.raises(ValueError) as excinfo: - m.throw_already_set(True) - assert msg(excinfo.value) == "foo" - - -def test_cross_module_exceptions(): - with pytest.raises(RuntimeError) as excinfo: - cm.raise_runtime_error() - assert str(excinfo.value) == "My runtime error" - - with pytest.raises(ValueError) as excinfo: - cm.raise_value_error() - assert str(excinfo.value) == "My value error" - - with pytest.raises(ValueError) as excinfo: - cm.throw_pybind_value_error() - assert str(excinfo.value) == "pybind11 value error" - - with pytest.raises(TypeError) as excinfo: - cm.throw_pybind_type_error() - assert str(excinfo.value) == "pybind11 type error" - - with pytest.raises(StopIteration) as excinfo: - cm.throw_stop_iteration() - - -def test_python_call_in_catch(): - d = {} - assert m.python_call_in_destructor(d) is True - assert d["good"] is True - - -def test_exception_matches(): - assert m.exception_matches() - assert m.exception_matches_base() - assert m.modulenotfound_exception_matches_base() - - -def test_custom(msg): - # Can we catch a MyException? - with pytest.raises(m.MyException) as excinfo: - m.throws1() - assert msg(excinfo.value) == "this error should go to a custom type" - - # Can we translate to standard Python exceptions? - with pytest.raises(RuntimeError) as excinfo: - m.throws2() - assert msg(excinfo.value) == "this error should go to a standard Python exception" - - # Can we handle unknown exceptions? - with pytest.raises(RuntimeError) as excinfo: - m.throws3() - assert msg(excinfo.value) == "Caught an unknown exception!" - - # Can we delegate to another handler by rethrowing? - with pytest.raises(m.MyException) as excinfo: - m.throws4() - assert msg(excinfo.value) == "this error is rethrown" - - # Can we fall-through to the default handler? - with pytest.raises(RuntimeError) as excinfo: - m.throws_logic_error() - assert msg(excinfo.value) == "this error should fall through to the standard handler" - - # Can we handle a helper-declared exception? - with pytest.raises(m.MyException5) as excinfo: - m.throws5() - assert msg(excinfo.value) == "this is a helper-defined translated exception" - - # Exception subclassing: - with pytest.raises(m.MyException5) as excinfo: - m.throws5_1() - assert msg(excinfo.value) == "MyException5 subclass" - assert isinstance(excinfo.value, m.MyException5_1) - - with pytest.raises(m.MyException5_1) as excinfo: - m.throws5_1() - assert msg(excinfo.value) == "MyException5 subclass" - - with pytest.raises(m.MyException5) as excinfo: - try: - m.throws5() - except m.MyException5_1: - raise RuntimeError("Exception error: caught child from parent") - assert msg(excinfo.value) == "this is a helper-defined translated exception" - - -def test_nested_throws(capture): - """Tests nested (e.g. C++ -> Python -> C++) exception handling""" - - def throw_myex(): - raise m.MyException("nested error") - - def throw_myex5(): - raise m.MyException5("nested error 5") - - # In the comments below, the exception is caught in the first step, thrown in the last step - - # C++ -> Python - with capture: - m.try_catch(m.MyException5, throw_myex5) - assert str(capture).startswith("MyException5: nested error 5") - - # Python -> C++ -> Python - with pytest.raises(m.MyException) as excinfo: - m.try_catch(m.MyException5, throw_myex) - assert str(excinfo.value) == "nested error" - - def pycatch(exctype, f, *args): - try: - f(*args) - except m.MyException as e: - print(e) - - # C++ -> Python -> C++ -> Python - with capture: - m.try_catch( - m.MyException5, pycatch, m.MyException, m.try_catch, m.MyException, throw_myex5) - assert str(capture).startswith("MyException5: nested error 5") - - # C++ -> Python -> C++ - with capture: - m.try_catch(m.MyException, pycatch, m.MyException5, m.throws4) - assert capture == "this error is rethrown" - - # Python -> C++ -> Python -> C++ - with pytest.raises(m.MyException5) as excinfo: - m.try_catch(m.MyException, pycatch, m.MyException, m.throws5) - assert str(excinfo.value) == "this is a helper-defined translated exception" diff --git a/pybind11/tests/test_factory_constructors.cpp b/pybind11/tests/test_factory_constructors.cpp deleted file mode 100644 index 5cfbfdc..0000000 --- a/pybind11/tests/test_factory_constructors.cpp +++ /dev/null @@ -1,338 +0,0 @@ -/* - tests/test_factory_constructors.cpp -- tests construction from a factory function - via py::init_factory() - - Copyright (c) 2017 Jason Rhinelander - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" -#include "constructor_stats.h" -#include - -// Classes for testing python construction via C++ factory function: -// Not publicly constructible, copyable, or movable: -class TestFactory1 { - friend class TestFactoryHelper; - TestFactory1() : value("(empty)") { print_default_created(this); } - TestFactory1(int v) : value(std::to_string(v)) { print_created(this, value); } - TestFactory1(std::string v) : value(std::move(v)) { print_created(this, value); } - TestFactory1(TestFactory1 &&) = delete; - TestFactory1(const TestFactory1 &) = delete; - TestFactory1 &operator=(TestFactory1 &&) = delete; - TestFactory1 &operator=(const TestFactory1 &) = delete; -public: - std::string value; - ~TestFactory1() { print_destroyed(this); } -}; -// Non-public construction, but moveable: -class TestFactory2 { - friend class TestFactoryHelper; - TestFactory2() : value("(empty2)") { print_default_created(this); } - TestFactory2(int v) : value(std::to_string(v)) { print_created(this, value); } - TestFactory2(std::string v) : value(std::move(v)) { print_created(this, value); } -public: - TestFactory2(TestFactory2 &&m) { value = std::move(m.value); print_move_created(this); } - TestFactory2 &operator=(TestFactory2 &&m) { value = std::move(m.value); print_move_assigned(this); return *this; } - std::string value; - ~TestFactory2() { print_destroyed(this); } -}; -// Mixed direct/factory construction: -class TestFactory3 { -protected: - friend class TestFactoryHelper; - TestFactory3() : value("(empty3)") { print_default_created(this); } - TestFactory3(int v) : value(std::to_string(v)) { print_created(this, value); } -public: - TestFactory3(std::string v) : value(std::move(v)) { print_created(this, value); } - TestFactory3(TestFactory3 &&m) { value = std::move(m.value); print_move_created(this); } - TestFactory3 &operator=(TestFactory3 &&m) { value = std::move(m.value); print_move_assigned(this); return *this; } - std::string value; - virtual ~TestFactory3() { print_destroyed(this); } -}; -// Inheritance test -class TestFactory4 : public TestFactory3 { -public: - TestFactory4() : TestFactory3() { print_default_created(this); } - TestFactory4(int v) : TestFactory3(v) { print_created(this, v); } - virtual ~TestFactory4() { print_destroyed(this); } -}; -// Another class for an invalid downcast test -class TestFactory5 : public TestFactory3 { -public: - TestFactory5(int i) : TestFactory3(i) { print_created(this, i); } - virtual ~TestFactory5() { print_destroyed(this); } -}; - -class TestFactory6 { -protected: - int value; - bool alias = false; -public: - TestFactory6(int i) : value{i} { print_created(this, i); } - TestFactory6(TestFactory6 &&f) { print_move_created(this); value = f.value; alias = f.alias; } - TestFactory6(const TestFactory6 &f) { print_copy_created(this); value = f.value; alias = f.alias; } - virtual ~TestFactory6() { print_destroyed(this); } - virtual int get() { return value; } - bool has_alias() { return alias; } -}; -class PyTF6 : public TestFactory6 { -public: - // Special constructor that allows the factory to construct a PyTF6 from a TestFactory6 only - // when an alias is needed: - PyTF6(TestFactory6 &&base) : TestFactory6(std::move(base)) { alias = true; print_created(this, "move", value); } - PyTF6(int i) : TestFactory6(i) { alias = true; print_created(this, i); } - PyTF6(PyTF6 &&f) : TestFactory6(std::move(f)) { print_move_created(this); } - PyTF6(const PyTF6 &f) : TestFactory6(f) { print_copy_created(this); } - PyTF6(std::string s) : TestFactory6((int) s.size()) { alias = true; print_created(this, s); } - virtual ~PyTF6() { print_destroyed(this); } - int get() override { PYBIND11_OVERLOAD(int, TestFactory6, get, /*no args*/); } -}; - -class TestFactory7 { -protected: - int value; - bool alias = false; -public: - TestFactory7(int i) : value{i} { print_created(this, i); } - TestFactory7(TestFactory7 &&f) { print_move_created(this); value = f.value; alias = f.alias; } - TestFactory7(const TestFactory7 &f) { print_copy_created(this); value = f.value; alias = f.alias; } - virtual ~TestFactory7() { print_destroyed(this); } - virtual int get() { return value; } - bool has_alias() { return alias; } -}; -class PyTF7 : public TestFactory7 { -public: - PyTF7(int i) : TestFactory7(i) { alias = true; print_created(this, i); } - PyTF7(PyTF7 &&f) : TestFactory7(std::move(f)) { print_move_created(this); } - PyTF7(const PyTF7 &f) : TestFactory7(f) { print_copy_created(this); } - virtual ~PyTF7() { print_destroyed(this); } - int get() override { PYBIND11_OVERLOAD(int, TestFactory7, get, /*no args*/); } -}; - - -class TestFactoryHelper { -public: - // Non-movable, non-copyable type: - // Return via pointer: - static TestFactory1 *construct1() { return new TestFactory1(); } - // Holder: - static std::unique_ptr construct1(int a) { return std::unique_ptr(new TestFactory1(a)); } - // pointer again - static TestFactory1 *construct1_string(std::string a) { return new TestFactory1(a); } - - // Moveable type: - // pointer: - static TestFactory2 *construct2() { return new TestFactory2(); } - // holder: - static std::unique_ptr construct2(int a) { return std::unique_ptr(new TestFactory2(a)); } - // by value moving: - static TestFactory2 construct2(std::string a) { return TestFactory2(a); } - - // shared_ptr holder type: - // pointer: - static TestFactory3 *construct3() { return new TestFactory3(); } - // holder: - static std::shared_ptr construct3(int a) { return std::shared_ptr(new TestFactory3(a)); } -}; - -TEST_SUBMODULE(factory_constructors, m) { - - // Define various trivial types to allow simpler overload resolution: - py::module m_tag = m.def_submodule("tag"); -#define MAKE_TAG_TYPE(Name) \ - struct Name##_tag {}; \ - py::class_(m_tag, #Name "_tag").def(py::init<>()); \ - m_tag.attr(#Name) = py::cast(Name##_tag{}) - MAKE_TAG_TYPE(pointer); - MAKE_TAG_TYPE(unique_ptr); - MAKE_TAG_TYPE(move); - MAKE_TAG_TYPE(shared_ptr); - MAKE_TAG_TYPE(derived); - MAKE_TAG_TYPE(TF4); - MAKE_TAG_TYPE(TF5); - MAKE_TAG_TYPE(null_ptr); - MAKE_TAG_TYPE(base); - MAKE_TAG_TYPE(invalid_base); - MAKE_TAG_TYPE(alias); - MAKE_TAG_TYPE(unaliasable); - MAKE_TAG_TYPE(mixed); - - // test_init_factory_basic, test_bad_type - py::class_(m, "TestFactory1") - .def(py::init([](unique_ptr_tag, int v) { return TestFactoryHelper::construct1(v); })) - .def(py::init(&TestFactoryHelper::construct1_string)) // raw function pointer - .def(py::init([](pointer_tag) { return TestFactoryHelper::construct1(); })) - .def(py::init([](py::handle, int v, py::handle) { return TestFactoryHelper::construct1(v); })) - .def_readwrite("value", &TestFactory1::value) - ; - py::class_(m, "TestFactory2") - .def(py::init([](pointer_tag, int v) { return TestFactoryHelper::construct2(v); })) - .def(py::init([](unique_ptr_tag, std::string v) { return TestFactoryHelper::construct2(v); })) - .def(py::init([](move_tag) { return TestFactoryHelper::construct2(); })) - .def_readwrite("value", &TestFactory2::value) - ; - - // Stateful & reused: - int c = 1; - auto c4a = [c](pointer_tag, TF4_tag, int a) { (void) c; return new TestFactory4(a);}; - - // test_init_factory_basic, test_init_factory_casting - py::class_>(m, "TestFactory3") - .def(py::init([](pointer_tag, int v) { return TestFactoryHelper::construct3(v); })) - .def(py::init([](shared_ptr_tag) { return TestFactoryHelper::construct3(); })) - .def("__init__", [](TestFactory3 &self, std::string v) { new (&self) TestFactory3(v); }) // placement-new ctor - - // factories returning a derived type: - .def(py::init(c4a)) // derived ptr - .def(py::init([](pointer_tag, TF5_tag, int a) { return new TestFactory5(a); })) - // derived shared ptr: - .def(py::init([](shared_ptr_tag, TF4_tag, int a) { return std::make_shared(a); })) - .def(py::init([](shared_ptr_tag, TF5_tag, int a) { return std::make_shared(a); })) - - // Returns nullptr: - .def(py::init([](null_ptr_tag) { return (TestFactory3 *) nullptr; })) - - .def_readwrite("value", &TestFactory3::value) - ; - - // test_init_factory_casting - py::class_>(m, "TestFactory4") - .def(py::init(c4a)) // pointer - ; - - // Doesn't need to be registered, but registering makes getting ConstructorStats easier: - py::class_>(m, "TestFactory5"); - - // test_init_factory_alias - // Alias testing - py::class_(m, "TestFactory6") - .def(py::init([](base_tag, int i) { return TestFactory6(i); })) - .def(py::init([](alias_tag, int i) { return PyTF6(i); })) - .def(py::init([](alias_tag, std::string s) { return PyTF6(s); })) - .def(py::init([](alias_tag, pointer_tag, int i) { return new PyTF6(i); })) - .def(py::init([](base_tag, pointer_tag, int i) { return new TestFactory6(i); })) - .def(py::init([](base_tag, alias_tag, pointer_tag, int i) { return (TestFactory6 *) new PyTF6(i); })) - - .def("get", &TestFactory6::get) - .def("has_alias", &TestFactory6::has_alias) - - .def_static("get_cstats", &ConstructorStats::get, py::return_value_policy::reference) - .def_static("get_alias_cstats", &ConstructorStats::get, py::return_value_policy::reference) - ; - - // test_init_factory_dual - // Separate alias constructor testing - py::class_>(m, "TestFactory7") - .def(py::init( - [](int i) { return TestFactory7(i); }, - [](int i) { return PyTF7(i); })) - .def(py::init( - [](pointer_tag, int i) { return new TestFactory7(i); }, - [](pointer_tag, int i) { return new PyTF7(i); })) - .def(py::init( - [](mixed_tag, int i) { return new TestFactory7(i); }, - [](mixed_tag, int i) { return PyTF7(i); })) - .def(py::init( - [](mixed_tag, std::string s) { return TestFactory7((int) s.size()); }, - [](mixed_tag, std::string s) { return new PyTF7((int) s.size()); })) - .def(py::init( - [](base_tag, pointer_tag, int i) { return new TestFactory7(i); }, - [](base_tag, pointer_tag, int i) { return (TestFactory7 *) new PyTF7(i); })) - .def(py::init( - [](alias_tag, pointer_tag, int i) { return new PyTF7(i); }, - [](alias_tag, pointer_tag, int i) { return new PyTF7(10*i); })) - .def(py::init( - [](shared_ptr_tag, base_tag, int i) { return std::make_shared(i); }, - [](shared_ptr_tag, base_tag, int i) { auto *p = new PyTF7(i); return std::shared_ptr(p); })) - .def(py::init( - [](shared_ptr_tag, invalid_base_tag, int i) { return std::make_shared(i); }, - [](shared_ptr_tag, invalid_base_tag, int i) { return std::make_shared(i); })) // <-- invalid alias factory - - .def("get", &TestFactory7::get) - .def("has_alias", &TestFactory7::has_alias) - - .def_static("get_cstats", &ConstructorStats::get, py::return_value_policy::reference) - .def_static("get_alias_cstats", &ConstructorStats::get, py::return_value_policy::reference) - ; - - // test_placement_new_alternative - // Class with a custom new operator but *without* a placement new operator (issue #948) - class NoPlacementNew { - public: - NoPlacementNew(int i) : i(i) { } - static void *operator new(std::size_t s) { - auto *p = ::operator new(s); - py::print("operator new called, returning", reinterpret_cast(p)); - return p; - } - static void operator delete(void *p) { - py::print("operator delete called on", reinterpret_cast(p)); - ::operator delete(p); - } - int i; - }; - // As of 2.2, `py::init` no longer requires placement new - py::class_(m, "NoPlacementNew") - .def(py::init()) - .def(py::init([]() { return new NoPlacementNew(100); })) - .def_readwrite("i", &NoPlacementNew::i) - ; - - - // test_reallocations - // Class that has verbose operator_new/operator_delete calls - struct NoisyAlloc { - NoisyAlloc(const NoisyAlloc &) = default; - NoisyAlloc(int i) { py::print(py::str("NoisyAlloc(int {})").format(i)); } - NoisyAlloc(double d) { py::print(py::str("NoisyAlloc(double {})").format(d)); } - ~NoisyAlloc() { py::print("~NoisyAlloc()"); } - - static void *operator new(size_t s) { py::print("noisy new"); return ::operator new(s); } - static void *operator new(size_t, void *p) { py::print("noisy placement new"); return p; } - static void operator delete(void *p, size_t) { py::print("noisy delete"); ::operator delete(p); } - static void operator delete(void *, void *) { py::print("noisy placement delete"); } -#if defined(_MSC_VER) && _MSC_VER < 1910 - // MSVC 2015 bug: the above "noisy delete" isn't invoked (fixed in MSVC 2017) - static void operator delete(void *p) { py::print("noisy delete"); ::operator delete(p); } -#endif - }; - py::class_(m, "NoisyAlloc") - // Since these overloads have the same number of arguments, the dispatcher will try each of - // them until the arguments convert. Thus we can get a pre-allocation here when passing a - // single non-integer: - .def("__init__", [](NoisyAlloc *a, int i) { new (a) NoisyAlloc(i); }) // Regular constructor, runs first, requires preallocation - .def(py::init([](double d) { return new NoisyAlloc(d); })) - - // The two-argument version: first the factory pointer overload. - .def(py::init([](int i, int) { return new NoisyAlloc(i); })) - // Return-by-value: - .def(py::init([](double d, int) { return NoisyAlloc(d); })) - // Old-style placement new init; requires preallocation - .def("__init__", [](NoisyAlloc &a, double d, double) { new (&a) NoisyAlloc(d); }) - // Requires deallocation of previous overload preallocated value: - .def(py::init([](int i, double) { return new NoisyAlloc(i); })) - // Regular again: requires yet another preallocation - .def("__init__", [](NoisyAlloc &a, int i, std::string) { new (&a) NoisyAlloc(i); }) - ; - - - - - // static_assert testing (the following def's should all fail with appropriate compilation errors): -#if 0 - struct BadF1Base {}; - struct BadF1 : BadF1Base {}; - struct PyBadF1 : BadF1 {}; - py::class_> bf1(m, "BadF1"); - // wrapped factory function must return a compatible pointer, holder, or value - bf1.def(py::init([]() { return 3; })); - // incompatible factory function pointer return type - bf1.def(py::init([]() { static int three = 3; return &three; })); - // incompatible factory function std::shared_ptr return type: cannot convert shared_ptr to holder - // (non-polymorphic base) - bf1.def(py::init([]() { return std::shared_ptr(new BadF1()); })); -#endif -} diff --git a/pybind11/tests/test_factory_constructors.py b/pybind11/tests/test_factory_constructors.py deleted file mode 100644 index 78a3910..0000000 --- a/pybind11/tests/test_factory_constructors.py +++ /dev/null @@ -1,459 +0,0 @@ -import pytest -import re - -from pybind11_tests import factory_constructors as m -from pybind11_tests.factory_constructors import tag -from pybind11_tests import ConstructorStats - - -def test_init_factory_basic(): - """Tests py::init_factory() wrapper around various ways of returning the object""" - - cstats = [ConstructorStats.get(c) for c in [m.TestFactory1, m.TestFactory2, m.TestFactory3]] - cstats[0].alive() # force gc - n_inst = ConstructorStats.detail_reg_inst() - - x1 = m.TestFactory1(tag.unique_ptr, 3) - assert x1.value == "3" - y1 = m.TestFactory1(tag.pointer) - assert y1.value == "(empty)" - z1 = m.TestFactory1("hi!") - assert z1.value == "hi!" - - assert ConstructorStats.detail_reg_inst() == n_inst + 3 - - x2 = m.TestFactory2(tag.move) - assert x2.value == "(empty2)" - y2 = m.TestFactory2(tag.pointer, 7) - assert y2.value == "7" - z2 = m.TestFactory2(tag.unique_ptr, "hi again") - assert z2.value == "hi again" - - assert ConstructorStats.detail_reg_inst() == n_inst + 6 - - x3 = m.TestFactory3(tag.shared_ptr) - assert x3.value == "(empty3)" - y3 = m.TestFactory3(tag.pointer, 42) - assert y3.value == "42" - z3 = m.TestFactory3("bye") - assert z3.value == "bye" - - with pytest.raises(TypeError) as excinfo: - m.TestFactory3(tag.null_ptr) - assert str(excinfo.value) == "pybind11::init(): factory function returned nullptr" - - assert [i.alive() for i in cstats] == [3, 3, 3] - assert ConstructorStats.detail_reg_inst() == n_inst + 9 - - del x1, y2, y3, z3 - assert [i.alive() for i in cstats] == [2, 2, 1] - assert ConstructorStats.detail_reg_inst() == n_inst + 5 - del x2, x3, y1, z1, z2 - assert [i.alive() for i in cstats] == [0, 0, 0] - assert ConstructorStats.detail_reg_inst() == n_inst - - assert [i.values() for i in cstats] == [ - ["3", "hi!"], - ["7", "hi again"], - ["42", "bye"] - ] - assert [i.default_constructions for i in cstats] == [1, 1, 1] - - -def test_init_factory_signature(msg): - with pytest.raises(TypeError) as excinfo: - m.TestFactory1("invalid", "constructor", "arguments") - assert msg(excinfo.value) == """ - __init__(): incompatible constructor arguments. The following argument types are supported: - 1. m.factory_constructors.TestFactory1(arg0: m.factory_constructors.tag.unique_ptr_tag, arg1: int) - 2. m.factory_constructors.TestFactory1(arg0: str) - 3. m.factory_constructors.TestFactory1(arg0: m.factory_constructors.tag.pointer_tag) - 4. m.factory_constructors.TestFactory1(arg0: handle, arg1: int, arg2: handle) - - Invoked with: 'invalid', 'constructor', 'arguments' - """ # noqa: E501 line too long - - assert msg(m.TestFactory1.__init__.__doc__) == """ - __init__(*args, **kwargs) - Overloaded function. - - 1. __init__(self: m.factory_constructors.TestFactory1, arg0: m.factory_constructors.tag.unique_ptr_tag, arg1: int) -> None - - 2. __init__(self: m.factory_constructors.TestFactory1, arg0: str) -> None - - 3. __init__(self: m.factory_constructors.TestFactory1, arg0: m.factory_constructors.tag.pointer_tag) -> None - - 4. __init__(self: m.factory_constructors.TestFactory1, arg0: handle, arg1: int, arg2: handle) -> None - """ # noqa: E501 line too long - - -def test_init_factory_casting(): - """Tests py::init_factory() wrapper with various upcasting and downcasting returns""" - - cstats = [ConstructorStats.get(c) for c in [m.TestFactory3, m.TestFactory4, m.TestFactory5]] - cstats[0].alive() # force gc - n_inst = ConstructorStats.detail_reg_inst() - - # Construction from derived references: - a = m.TestFactory3(tag.pointer, tag.TF4, 4) - assert a.value == "4" - b = m.TestFactory3(tag.shared_ptr, tag.TF4, 5) - assert b.value == "5" - c = m.TestFactory3(tag.pointer, tag.TF5, 6) - assert c.value == "6" - d = m.TestFactory3(tag.shared_ptr, tag.TF5, 7) - assert d.value == "7" - - assert ConstructorStats.detail_reg_inst() == n_inst + 4 - - # Shared a lambda with TF3: - e = m.TestFactory4(tag.pointer, tag.TF4, 8) - assert e.value == "8" - - assert ConstructorStats.detail_reg_inst() == n_inst + 5 - assert [i.alive() for i in cstats] == [5, 3, 2] - - del a - assert [i.alive() for i in cstats] == [4, 2, 2] - assert ConstructorStats.detail_reg_inst() == n_inst + 4 - - del b, c, e - assert [i.alive() for i in cstats] == [1, 0, 1] - assert ConstructorStats.detail_reg_inst() == n_inst + 1 - - del d - assert [i.alive() for i in cstats] == [0, 0, 0] - assert ConstructorStats.detail_reg_inst() == n_inst - - assert [i.values() for i in cstats] == [ - ["4", "5", "6", "7", "8"], - ["4", "5", "8"], - ["6", "7"] - ] - - -def test_init_factory_alias(): - """Tests py::init_factory() wrapper with value conversions and alias types""" - - cstats = [m.TestFactory6.get_cstats(), m.TestFactory6.get_alias_cstats()] - cstats[0].alive() # force gc - n_inst = ConstructorStats.detail_reg_inst() - - a = m.TestFactory6(tag.base, 1) - assert a.get() == 1 - assert not a.has_alias() - b = m.TestFactory6(tag.alias, "hi there") - assert b.get() == 8 - assert b.has_alias() - c = m.TestFactory6(tag.alias, 3) - assert c.get() == 3 - assert c.has_alias() - d = m.TestFactory6(tag.alias, tag.pointer, 4) - assert d.get() == 4 - assert d.has_alias() - e = m.TestFactory6(tag.base, tag.pointer, 5) - assert e.get() == 5 - assert not e.has_alias() - f = m.TestFactory6(tag.base, tag.alias, tag.pointer, 6) - assert f.get() == 6 - assert f.has_alias() - - assert ConstructorStats.detail_reg_inst() == n_inst + 6 - assert [i.alive() for i in cstats] == [6, 4] - - del a, b, e - assert [i.alive() for i in cstats] == [3, 3] - assert ConstructorStats.detail_reg_inst() == n_inst + 3 - del f, c, d - assert [i.alive() for i in cstats] == [0, 0] - assert ConstructorStats.detail_reg_inst() == n_inst - - class MyTest(m.TestFactory6): - def __init__(self, *args): - m.TestFactory6.__init__(self, *args) - - def get(self): - return -5 + m.TestFactory6.get(self) - - # Return Class by value, moved into new alias: - z = MyTest(tag.base, 123) - assert z.get() == 118 - assert z.has_alias() - - # Return alias by value, moved into new alias: - y = MyTest(tag.alias, "why hello!") - assert y.get() == 5 - assert y.has_alias() - - # Return Class by pointer, moved into new alias then original destroyed: - x = MyTest(tag.base, tag.pointer, 47) - assert x.get() == 42 - assert x.has_alias() - - assert ConstructorStats.detail_reg_inst() == n_inst + 3 - assert [i.alive() for i in cstats] == [3, 3] - del x, y, z - assert [i.alive() for i in cstats] == [0, 0] - assert ConstructorStats.detail_reg_inst() == n_inst - - assert [i.values() for i in cstats] == [ - ["1", "8", "3", "4", "5", "6", "123", "10", "47"], - ["hi there", "3", "4", "6", "move", "123", "why hello!", "move", "47"] - ] - - -def test_init_factory_dual(): - """Tests init factory functions with dual main/alias factory functions""" - from pybind11_tests.factory_constructors import TestFactory7 - - cstats = [TestFactory7.get_cstats(), TestFactory7.get_alias_cstats()] - cstats[0].alive() # force gc - n_inst = ConstructorStats.detail_reg_inst() - - class PythFactory7(TestFactory7): - def get(self): - return 100 + TestFactory7.get(self) - - a1 = TestFactory7(1) - a2 = PythFactory7(2) - assert a1.get() == 1 - assert a2.get() == 102 - assert not a1.has_alias() - assert a2.has_alias() - - b1 = TestFactory7(tag.pointer, 3) - b2 = PythFactory7(tag.pointer, 4) - assert b1.get() == 3 - assert b2.get() == 104 - assert not b1.has_alias() - assert b2.has_alias() - - c1 = TestFactory7(tag.mixed, 5) - c2 = PythFactory7(tag.mixed, 6) - assert c1.get() == 5 - assert c2.get() == 106 - assert not c1.has_alias() - assert c2.has_alias() - - d1 = TestFactory7(tag.base, tag.pointer, 7) - d2 = PythFactory7(tag.base, tag.pointer, 8) - assert d1.get() == 7 - assert d2.get() == 108 - assert not d1.has_alias() - assert d2.has_alias() - - # Both return an alias; the second multiplies the value by 10: - e1 = TestFactory7(tag.alias, tag.pointer, 9) - e2 = PythFactory7(tag.alias, tag.pointer, 10) - assert e1.get() == 9 - assert e2.get() == 200 - assert e1.has_alias() - assert e2.has_alias() - - f1 = TestFactory7(tag.shared_ptr, tag.base, 11) - f2 = PythFactory7(tag.shared_ptr, tag.base, 12) - assert f1.get() == 11 - assert f2.get() == 112 - assert not f1.has_alias() - assert f2.has_alias() - - g1 = TestFactory7(tag.shared_ptr, tag.invalid_base, 13) - assert g1.get() == 13 - assert not g1.has_alias() - with pytest.raises(TypeError) as excinfo: - PythFactory7(tag.shared_ptr, tag.invalid_base, 14) - assert (str(excinfo.value) == - "pybind11::init(): construction failed: returned holder-wrapped instance is not an " - "alias instance") - - assert [i.alive() for i in cstats] == [13, 7] - assert ConstructorStats.detail_reg_inst() == n_inst + 13 - - del a1, a2, b1, d1, e1, e2 - assert [i.alive() for i in cstats] == [7, 4] - assert ConstructorStats.detail_reg_inst() == n_inst + 7 - del b2, c1, c2, d2, f1, f2, g1 - assert [i.alive() for i in cstats] == [0, 0] - assert ConstructorStats.detail_reg_inst() == n_inst - - assert [i.values() for i in cstats] == [ - ["1", "2", "3", "4", "5", "6", "7", "8", "9", "100", "11", "12", "13", "14"], - ["2", "4", "6", "8", "9", "100", "12"] - ] - - -def test_no_placement_new(capture): - """Prior to 2.2, `py::init<...>` relied on the type supporting placement - new; this tests a class without placement new support.""" - with capture: - a = m.NoPlacementNew(123) - - found = re.search(r'^operator new called, returning (\d+)\n$', str(capture)) - assert found - assert a.i == 123 - with capture: - del a - pytest.gc_collect() - assert capture == "operator delete called on " + found.group(1) - - with capture: - b = m.NoPlacementNew() - - found = re.search(r'^operator new called, returning (\d+)\n$', str(capture)) - assert found - assert b.i == 100 - with capture: - del b - pytest.gc_collect() - assert capture == "operator delete called on " + found.group(1) - - -def test_multiple_inheritance(): - class MITest(m.TestFactory1, m.TestFactory2): - def __init__(self): - m.TestFactory1.__init__(self, tag.unique_ptr, 33) - m.TestFactory2.__init__(self, tag.move) - - a = MITest() - assert m.TestFactory1.value.fget(a) == "33" - assert m.TestFactory2.value.fget(a) == "(empty2)" - - -def create_and_destroy(*args): - a = m.NoisyAlloc(*args) - print("---") - del a - pytest.gc_collect() - - -def strip_comments(s): - return re.sub(r'\s+#.*', '', s) - - -def test_reallocations(capture, msg): - """When the constructor is overloaded, previous overloads can require a preallocated value. - This test makes sure that such preallocated values only happen when they might be necessary, - and that they are deallocated properly""" - - pytest.gc_collect() - - with capture: - create_and_destroy(1) - assert msg(capture) == """ - noisy new - noisy placement new - NoisyAlloc(int 1) - --- - ~NoisyAlloc() - noisy delete - """ - with capture: - create_and_destroy(1.5) - assert msg(capture) == strip_comments(""" - noisy new # allocation required to attempt first overload - noisy delete # have to dealloc before considering factory init overload - noisy new # pointer factory calling "new", part 1: allocation - NoisyAlloc(double 1.5) # ... part two, invoking constructor - --- - ~NoisyAlloc() # Destructor - noisy delete # operator delete - """) - - with capture: - create_and_destroy(2, 3) - assert msg(capture) == strip_comments(""" - noisy new # pointer factory calling "new", allocation - NoisyAlloc(int 2) # constructor - --- - ~NoisyAlloc() # Destructor - noisy delete # operator delete - """) - - with capture: - create_and_destroy(2.5, 3) - assert msg(capture) == strip_comments(""" - NoisyAlloc(double 2.5) # construction (local func variable: operator_new not called) - noisy new # return-by-value "new" part 1: allocation - ~NoisyAlloc() # moved-away local func variable destruction - --- - ~NoisyAlloc() # Destructor - noisy delete # operator delete - """) - - with capture: - create_and_destroy(3.5, 4.5) - assert msg(capture) == strip_comments(""" - noisy new # preallocation needed before invoking placement-new overload - noisy placement new # Placement new - NoisyAlloc(double 3.5) # construction - --- - ~NoisyAlloc() # Destructor - noisy delete # operator delete - """) - - with capture: - create_and_destroy(4, 0.5) - assert msg(capture) == strip_comments(""" - noisy new # preallocation needed before invoking placement-new overload - noisy delete # deallocation of preallocated storage - noisy new # Factory pointer allocation - NoisyAlloc(int 4) # factory pointer construction - --- - ~NoisyAlloc() # Destructor - noisy delete # operator delete - """) - - with capture: - create_and_destroy(5, "hi") - assert msg(capture) == strip_comments(""" - noisy new # preallocation needed before invoking first placement new - noisy delete # delete before considering new-style constructor - noisy new # preallocation for second placement new - noisy placement new # Placement new in the second placement new overload - NoisyAlloc(int 5) # construction - --- - ~NoisyAlloc() # Destructor - noisy delete # operator delete - """) - - -@pytest.unsupported_on_py2 -def test_invalid_self(): - """Tests invocation of the pybind-registered base class with an invalid `self` argument. You - can only actually do this on Python 3: Python 2 raises an exception itself if you try.""" - class NotPybindDerived(object): - pass - - # Attempts to initialize with an invalid type passed as `self`: - class BrokenTF1(m.TestFactory1): - def __init__(self, bad): - if bad == 1: - a = m.TestFactory2(tag.pointer, 1) - m.TestFactory1.__init__(a, tag.pointer) - elif bad == 2: - a = NotPybindDerived() - m.TestFactory1.__init__(a, tag.pointer) - - # Same as above, but for a class with an alias: - class BrokenTF6(m.TestFactory6): - def __init__(self, bad): - if bad == 1: - a = m.TestFactory2(tag.pointer, 1) - m.TestFactory6.__init__(a, tag.base, 1) - elif bad == 2: - a = m.TestFactory2(tag.pointer, 1) - m.TestFactory6.__init__(a, tag.alias, 1) - elif bad == 3: - m.TestFactory6.__init__(NotPybindDerived.__new__(NotPybindDerived), tag.base, 1) - elif bad == 4: - m.TestFactory6.__init__(NotPybindDerived.__new__(NotPybindDerived), tag.alias, 1) - - for arg in (1, 2): - with pytest.raises(TypeError) as excinfo: - BrokenTF1(arg) - assert str(excinfo.value) == "__init__(self, ...) called with invalid `self` argument" - - for arg in (1, 2, 3, 4): - with pytest.raises(TypeError) as excinfo: - BrokenTF6(arg) - assert str(excinfo.value) == "__init__(self, ...) called with invalid `self` argument" diff --git a/pybind11/tests/test_gil_scoped.cpp b/pybind11/tests/test_gil_scoped.cpp deleted file mode 100644 index 76c17fd..0000000 --- a/pybind11/tests/test_gil_scoped.cpp +++ /dev/null @@ -1,52 +0,0 @@ -/* - tests/test_gil_scoped.cpp -- acquire and release gil - - Copyright (c) 2017 Borja Zarco (Google LLC) - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" -#include - - -class VirtClass { -public: - virtual ~VirtClass() {} - virtual void virtual_func() {} - virtual void pure_virtual_func() = 0; -}; - -class PyVirtClass : public VirtClass { - void virtual_func() override { - PYBIND11_OVERLOAD(void, VirtClass, virtual_func,); - } - void pure_virtual_func() override { - PYBIND11_OVERLOAD_PURE(void, VirtClass, pure_virtual_func,); - } -}; - -TEST_SUBMODULE(gil_scoped, m) { - py::class_(m, "VirtClass") - .def(py::init<>()) - .def("virtual_func", &VirtClass::virtual_func) - .def("pure_virtual_func", &VirtClass::pure_virtual_func); - - m.def("test_callback_py_obj", - [](py::object func) { func(); }); - m.def("test_callback_std_func", - [](const std::function &func) { func(); }); - m.def("test_callback_virtual_func", - [](VirtClass &virt) { virt.virtual_func(); }); - m.def("test_callback_pure_virtual_func", - [](VirtClass &virt) { virt.pure_virtual_func(); }); - m.def("test_cross_module_gil", - []() { - auto cm = py::module::import("cross_module_gil_utils"); - auto gil_acquire = reinterpret_cast( - PyLong_AsVoidPtr(cm.attr("gil_acquire_funcaddr").ptr())); - py::gil_scoped_release gil_release; - gil_acquire(); - }); -} diff --git a/pybind11/tests/test_gil_scoped.py b/pybind11/tests/test_gil_scoped.py deleted file mode 100644 index 1548337..0000000 --- a/pybind11/tests/test_gil_scoped.py +++ /dev/null @@ -1,85 +0,0 @@ -import multiprocessing -import threading -from pybind11_tests import gil_scoped as m - - -def _run_in_process(target, *args, **kwargs): - """Runs target in process and returns its exitcode after 10s (None if still alive).""" - process = multiprocessing.Process(target=target, args=args, kwargs=kwargs) - process.daemon = True - try: - process.start() - # Do not need to wait much, 10s should be more than enough. - process.join(timeout=10) - return process.exitcode - finally: - if process.is_alive(): - process.terminate() - - -def _python_to_cpp_to_python(): - """Calls different C++ functions that come back to Python.""" - class ExtendedVirtClass(m.VirtClass): - def virtual_func(self): - pass - - def pure_virtual_func(self): - pass - - extended = ExtendedVirtClass() - m.test_callback_py_obj(lambda: None) - m.test_callback_std_func(lambda: None) - m.test_callback_virtual_func(extended) - m.test_callback_pure_virtual_func(extended) - - -def _python_to_cpp_to_python_from_threads(num_threads, parallel=False): - """Calls different C++ functions that come back to Python, from Python threads.""" - threads = [] - for _ in range(num_threads): - thread = threading.Thread(target=_python_to_cpp_to_python) - thread.daemon = True - thread.start() - if parallel: - threads.append(thread) - else: - thread.join() - for thread in threads: - thread.join() - - -def test_python_to_cpp_to_python_from_thread(): - """Makes sure there is no GIL deadlock when running in a thread. - - It runs in a separate process to be able to stop and assert if it deadlocks. - """ - assert _run_in_process(_python_to_cpp_to_python_from_threads, 1) == 0 - - -def test_python_to_cpp_to_python_from_thread_multiple_parallel(): - """Makes sure there is no GIL deadlock when running in a thread multiple times in parallel. - - It runs in a separate process to be able to stop and assert if it deadlocks. - """ - assert _run_in_process(_python_to_cpp_to_python_from_threads, 8, parallel=True) == 0 - - -def test_python_to_cpp_to_python_from_thread_multiple_sequential(): - """Makes sure there is no GIL deadlock when running in a thread multiple times sequentially. - - It runs in a separate process to be able to stop and assert if it deadlocks. - """ - assert _run_in_process(_python_to_cpp_to_python_from_threads, 8, parallel=False) == 0 - - -def test_python_to_cpp_to_python_from_process(): - """Makes sure there is no GIL deadlock when using processes. - - This test is for completion, but it was never an issue. - """ - assert _run_in_process(_python_to_cpp_to_python) == 0 - - -def test_cross_module_gil(): - """Makes sure that the GIL can be acquired by another module from a GIL-released state.""" - m.test_cross_module_gil() # Should not raise a SIGSEGV diff --git a/pybind11/tests/test_iostream.cpp b/pybind11/tests/test_iostream.cpp deleted file mode 100644 index e67f88a..0000000 --- a/pybind11/tests/test_iostream.cpp +++ /dev/null @@ -1,73 +0,0 @@ -/* - tests/test_iostream.cpp -- Usage of scoped_output_redirect - - Copyright (c) 2017 Henry F. Schreiner - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - - -#include -#include "pybind11_tests.h" -#include - - -void noisy_function(std::string msg, bool flush) { - - std::cout << msg; - if (flush) - std::cout << std::flush; -} - -void noisy_funct_dual(std::string msg, std::string emsg) { - std::cout << msg; - std::cerr << emsg; -} - -TEST_SUBMODULE(iostream, m) { - - add_ostream_redirect(m); - - // test_evals - - m.def("captured_output_default", [](std::string msg) { - py::scoped_ostream_redirect redir; - std::cout << msg << std::flush; - }); - - m.def("captured_output", [](std::string msg) { - py::scoped_ostream_redirect redir(std::cout, py::module::import("sys").attr("stdout")); - std::cout << msg << std::flush; - }); - - m.def("guard_output", &noisy_function, - py::call_guard(), - py::arg("msg"), py::arg("flush")=true); - - m.def("captured_err", [](std::string msg) { - py::scoped_ostream_redirect redir(std::cerr, py::module::import("sys").attr("stderr")); - std::cerr << msg << std::flush; - }); - - m.def("noisy_function", &noisy_function, py::arg("msg"), py::arg("flush") = true); - - m.def("dual_guard", &noisy_funct_dual, - py::call_guard(), - py::arg("msg"), py::arg("emsg")); - - m.def("raw_output", [](std::string msg) { - std::cout << msg << std::flush; - }); - - m.def("raw_err", [](std::string msg) { - std::cerr << msg << std::flush; - }); - - m.def("captured_dual", [](std::string msg, std::string emsg) { - py::scoped_ostream_redirect redirout(std::cout, py::module::import("sys").attr("stdout")); - py::scoped_ostream_redirect redirerr(std::cerr, py::module::import("sys").attr("stderr")); - std::cout << msg << std::flush; - std::cerr << emsg << std::flush; - }); -} diff --git a/pybind11/tests/test_iostream.py b/pybind11/tests/test_iostream.py deleted file mode 100644 index 27095b2..0000000 --- a/pybind11/tests/test_iostream.py +++ /dev/null @@ -1,214 +0,0 @@ -from pybind11_tests import iostream as m -import sys - -from contextlib import contextmanager - -try: - # Python 3 - from io import StringIO -except ImportError: - # Python 2 - try: - from cStringIO import StringIO - except ImportError: - from StringIO import StringIO - -try: - # Python 3.4 - from contextlib import redirect_stdout -except ImportError: - @contextmanager - def redirect_stdout(target): - original = sys.stdout - sys.stdout = target - yield - sys.stdout = original - -try: - # Python 3.5 - from contextlib import redirect_stderr -except ImportError: - @contextmanager - def redirect_stderr(target): - original = sys.stderr - sys.stderr = target - yield - sys.stderr = original - - -def test_captured(capsys): - msg = "I've been redirected to Python, I hope!" - m.captured_output(msg) - stdout, stderr = capsys.readouterr() - assert stdout == msg - assert stderr == '' - - m.captured_output_default(msg) - stdout, stderr = capsys.readouterr() - assert stdout == msg - assert stderr == '' - - m.captured_err(msg) - stdout, stderr = capsys.readouterr() - assert stdout == '' - assert stderr == msg - - -def test_captured_large_string(capsys): - # Make this bigger than the buffer used on the C++ side: 1024 chars - msg = "I've been redirected to Python, I hope!" - msg = msg * (1024 // len(msg) + 1) - - m.captured_output_default(msg) - stdout, stderr = capsys.readouterr() - assert stdout == msg - assert stderr == '' - - -def test_guard_capture(capsys): - msg = "I've been redirected to Python, I hope!" - m.guard_output(msg) - stdout, stderr = capsys.readouterr() - assert stdout == msg - assert stderr == '' - - -def test_series_captured(capture): - with capture: - m.captured_output("a") - m.captured_output("b") - assert capture == "ab" - - -def test_flush(capfd): - msg = "(not flushed)" - msg2 = "(flushed)" - - with m.ostream_redirect(): - m.noisy_function(msg, flush=False) - stdout, stderr = capfd.readouterr() - assert stdout == '' - - m.noisy_function(msg2, flush=True) - stdout, stderr = capfd.readouterr() - assert stdout == msg + msg2 - - m.noisy_function(msg, flush=False) - - stdout, stderr = capfd.readouterr() - assert stdout == msg - - -def test_not_captured(capfd): - msg = "Something that should not show up in log" - stream = StringIO() - with redirect_stdout(stream): - m.raw_output(msg) - stdout, stderr = capfd.readouterr() - assert stdout == msg - assert stderr == '' - assert stream.getvalue() == '' - - stream = StringIO() - with redirect_stdout(stream): - m.captured_output(msg) - stdout, stderr = capfd.readouterr() - assert stdout == '' - assert stderr == '' - assert stream.getvalue() == msg - - -def test_err(capfd): - msg = "Something that should not show up in log" - stream = StringIO() - with redirect_stderr(stream): - m.raw_err(msg) - stdout, stderr = capfd.readouterr() - assert stdout == '' - assert stderr == msg - assert stream.getvalue() == '' - - stream = StringIO() - with redirect_stderr(stream): - m.captured_err(msg) - stdout, stderr = capfd.readouterr() - assert stdout == '' - assert stderr == '' - assert stream.getvalue() == msg - - -def test_multi_captured(capfd): - stream = StringIO() - with redirect_stdout(stream): - m.captured_output("a") - m.raw_output("b") - m.captured_output("c") - m.raw_output("d") - stdout, stderr = capfd.readouterr() - assert stdout == 'bd' - assert stream.getvalue() == 'ac' - - -def test_dual(capsys): - m.captured_dual("a", "b") - stdout, stderr = capsys.readouterr() - assert stdout == "a" - assert stderr == "b" - - -def test_redirect(capfd): - msg = "Should not be in log!" - stream = StringIO() - with redirect_stdout(stream): - m.raw_output(msg) - stdout, stderr = capfd.readouterr() - assert stdout == msg - assert stream.getvalue() == '' - - stream = StringIO() - with redirect_stdout(stream): - with m.ostream_redirect(): - m.raw_output(msg) - stdout, stderr = capfd.readouterr() - assert stdout == '' - assert stream.getvalue() == msg - - stream = StringIO() - with redirect_stdout(stream): - m.raw_output(msg) - stdout, stderr = capfd.readouterr() - assert stdout == msg - assert stream.getvalue() == '' - - -def test_redirect_err(capfd): - msg = "StdOut" - msg2 = "StdErr" - - stream = StringIO() - with redirect_stderr(stream): - with m.ostream_redirect(stdout=False): - m.raw_output(msg) - m.raw_err(msg2) - stdout, stderr = capfd.readouterr() - assert stdout == msg - assert stderr == '' - assert stream.getvalue() == msg2 - - -def test_redirect_both(capfd): - msg = "StdOut" - msg2 = "StdErr" - - stream = StringIO() - stream2 = StringIO() - with redirect_stdout(stream): - with redirect_stderr(stream2): - with m.ostream_redirect(): - m.raw_output(msg) - m.raw_err(msg2) - stdout, stderr = capfd.readouterr() - assert stdout == '' - assert stderr == '' - assert stream.getvalue() == msg - assert stream2.getvalue() == msg2 diff --git a/pybind11/tests/test_kwargs_and_defaults.cpp b/pybind11/tests/test_kwargs_and_defaults.cpp deleted file mode 100644 index 6563fb9..0000000 --- a/pybind11/tests/test_kwargs_and_defaults.cpp +++ /dev/null @@ -1,102 +0,0 @@ -/* - tests/test_kwargs_and_defaults.cpp -- keyword arguments and default values - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" -#include "constructor_stats.h" -#include - -TEST_SUBMODULE(kwargs_and_defaults, m) { - auto kw_func = [](int x, int y) { return "x=" + std::to_string(x) + ", y=" + std::to_string(y); }; - - // test_named_arguments - m.def("kw_func0", kw_func); - m.def("kw_func1", kw_func, py::arg("x"), py::arg("y")); - m.def("kw_func2", kw_func, py::arg("x") = 100, py::arg("y") = 200); - m.def("kw_func3", [](const char *) { }, py::arg("data") = std::string("Hello world!")); - - /* A fancier default argument */ - std::vector list{{13, 17}}; - m.def("kw_func4", [](const std::vector &entries) { - std::string ret = "{"; - for (int i : entries) - ret += std::to_string(i) + " "; - ret.back() = '}'; - return ret; - }, py::arg("myList") = list); - - m.def("kw_func_udl", kw_func, "x"_a, "y"_a=300); - m.def("kw_func_udl_z", kw_func, "x"_a, "y"_a=0); - - // test_args_and_kwargs - m.def("args_function", [](py::args args) -> py::tuple { - return std::move(args); - }); - m.def("args_kwargs_function", [](py::args args, py::kwargs kwargs) { - return py::make_tuple(args, kwargs); - }); - - // test_mixed_args_and_kwargs - m.def("mixed_plus_args", [](int i, double j, py::args args) { - return py::make_tuple(i, j, args); - }); - m.def("mixed_plus_kwargs", [](int i, double j, py::kwargs kwargs) { - return py::make_tuple(i, j, kwargs); - }); - auto mixed_plus_both = [](int i, double j, py::args args, py::kwargs kwargs) { - return py::make_tuple(i, j, args, kwargs); - }; - m.def("mixed_plus_args_kwargs", mixed_plus_both); - - m.def("mixed_plus_args_kwargs_defaults", mixed_plus_both, - py::arg("i") = 1, py::arg("j") = 3.14159); - - // test_args_refcount - // PyPy needs a garbage collection to get the reference count values to match CPython's behaviour - #ifdef PYPY_VERSION - #define GC_IF_NEEDED ConstructorStats::gc() - #else - #define GC_IF_NEEDED - #endif - m.def("arg_refcount_h", [](py::handle h) { GC_IF_NEEDED; return h.ref_count(); }); - m.def("arg_refcount_h", [](py::handle h, py::handle, py::handle) { GC_IF_NEEDED; return h.ref_count(); }); - m.def("arg_refcount_o", [](py::object o) { GC_IF_NEEDED; return o.ref_count(); }); - m.def("args_refcount", [](py::args a) { - GC_IF_NEEDED; - py::tuple t(a.size()); - for (size_t i = 0; i < a.size(); i++) - // Use raw Python API here to avoid an extra, intermediate incref on the tuple item: - t[i] = (int) Py_REFCNT(PyTuple_GET_ITEM(a.ptr(), static_cast(i))); - return t; - }); - m.def("mixed_args_refcount", [](py::object o, py::args a) { - GC_IF_NEEDED; - py::tuple t(a.size() + 1); - t[0] = o.ref_count(); - for (size_t i = 0; i < a.size(); i++) - // Use raw Python API here to avoid an extra, intermediate incref on the tuple item: - t[i + 1] = (int) Py_REFCNT(PyTuple_GET_ITEM(a.ptr(), static_cast(i))); - return t; - }); - - // pybind11 won't allow these to be bound: args and kwargs, if present, must be at the end. - // Uncomment these to test that the static_assert is indeed working: -// m.def("bad_args1", [](py::args, int) {}); -// m.def("bad_args2", [](py::kwargs, int) {}); -// m.def("bad_args3", [](py::kwargs, py::args) {}); -// m.def("bad_args4", [](py::args, int, py::kwargs) {}); -// m.def("bad_args5", [](py::args, py::kwargs, int) {}); -// m.def("bad_args6", [](py::args, py::args) {}); -// m.def("bad_args7", [](py::kwargs, py::kwargs) {}); - - // test_function_signatures (along with most of the above) - struct KWClass { void foo(int, float) {} }; - py::class_(m, "KWClass") - .def("foo0", &KWClass::foo) - .def("foo1", &KWClass::foo, "x"_a, "y"_a); -} diff --git a/pybind11/tests/test_kwargs_and_defaults.py b/pybind11/tests/test_kwargs_and_defaults.py deleted file mode 100644 index 27a05a0..0000000 --- a/pybind11/tests/test_kwargs_and_defaults.py +++ /dev/null @@ -1,147 +0,0 @@ -import pytest -from pybind11_tests import kwargs_and_defaults as m - - -def test_function_signatures(doc): - assert doc(m.kw_func0) == "kw_func0(arg0: int, arg1: int) -> str" - assert doc(m.kw_func1) == "kw_func1(x: int, y: int) -> str" - assert doc(m.kw_func2) == "kw_func2(x: int = 100, y: int = 200) -> str" - assert doc(m.kw_func3) == "kw_func3(data: str = 'Hello world!') -> None" - assert doc(m.kw_func4) == "kw_func4(myList: List[int] = [13, 17]) -> str" - assert doc(m.kw_func_udl) == "kw_func_udl(x: int, y: int = 300) -> str" - assert doc(m.kw_func_udl_z) == "kw_func_udl_z(x: int, y: int = 0) -> str" - assert doc(m.args_function) == "args_function(*args) -> tuple" - assert doc(m.args_kwargs_function) == "args_kwargs_function(*args, **kwargs) -> tuple" - assert doc(m.KWClass.foo0) == \ - "foo0(self: m.kwargs_and_defaults.KWClass, arg0: int, arg1: float) -> None" - assert doc(m.KWClass.foo1) == \ - "foo1(self: m.kwargs_and_defaults.KWClass, x: int, y: float) -> None" - - -def test_named_arguments(msg): - assert m.kw_func0(5, 10) == "x=5, y=10" - - assert m.kw_func1(5, 10) == "x=5, y=10" - assert m.kw_func1(5, y=10) == "x=5, y=10" - assert m.kw_func1(y=10, x=5) == "x=5, y=10" - - assert m.kw_func2() == "x=100, y=200" - assert m.kw_func2(5) == "x=5, y=200" - assert m.kw_func2(x=5) == "x=5, y=200" - assert m.kw_func2(y=10) == "x=100, y=10" - assert m.kw_func2(5, 10) == "x=5, y=10" - assert m.kw_func2(x=5, y=10) == "x=5, y=10" - - with pytest.raises(TypeError) as excinfo: - # noinspection PyArgumentList - m.kw_func2(x=5, y=10, z=12) - assert excinfo.match( - r'(?s)^kw_func2\(\): incompatible.*Invoked with: kwargs: ((x=5|y=10|z=12)(, |$))' + '{3}$') - - assert m.kw_func4() == "{13 17}" - assert m.kw_func4(myList=[1, 2, 3]) == "{1 2 3}" - - assert m.kw_func_udl(x=5, y=10) == "x=5, y=10" - assert m.kw_func_udl_z(x=5) == "x=5, y=0" - - -def test_arg_and_kwargs(): - args = 'arg1_value', 'arg2_value', 3 - assert m.args_function(*args) == args - - args = 'a1', 'a2' - kwargs = dict(arg3='a3', arg4=4) - assert m.args_kwargs_function(*args, **kwargs) == (args, kwargs) - - -def test_mixed_args_and_kwargs(msg): - mpa = m.mixed_plus_args - mpk = m.mixed_plus_kwargs - mpak = m.mixed_plus_args_kwargs - mpakd = m.mixed_plus_args_kwargs_defaults - - assert mpa(1, 2.5, 4, 99.5, None) == (1, 2.5, (4, 99.5, None)) - assert mpa(1, 2.5) == (1, 2.5, ()) - with pytest.raises(TypeError) as excinfo: - assert mpa(1) - assert msg(excinfo.value) == """ - mixed_plus_args(): incompatible function arguments. The following argument types are supported: - 1. (arg0: int, arg1: float, *args) -> tuple - - Invoked with: 1 - """ # noqa: E501 line too long - with pytest.raises(TypeError) as excinfo: - assert mpa() - assert msg(excinfo.value) == """ - mixed_plus_args(): incompatible function arguments. The following argument types are supported: - 1. (arg0: int, arg1: float, *args) -> tuple - - Invoked with: - """ # noqa: E501 line too long - - assert mpk(-2, 3.5, pi=3.14159, e=2.71828) == (-2, 3.5, {'e': 2.71828, 'pi': 3.14159}) - assert mpak(7, 7.7, 7.77, 7.777, 7.7777, minusseven=-7) == ( - 7, 7.7, (7.77, 7.777, 7.7777), {'minusseven': -7}) - assert mpakd() == (1, 3.14159, (), {}) - assert mpakd(3) == (3, 3.14159, (), {}) - assert mpakd(j=2.71828) == (1, 2.71828, (), {}) - assert mpakd(k=42) == (1, 3.14159, (), {'k': 42}) - assert mpakd(1, 1, 2, 3, 5, 8, then=13, followedby=21) == ( - 1, 1, (2, 3, 5, 8), {'then': 13, 'followedby': 21}) - # Arguments specified both positionally and via kwargs should fail: - with pytest.raises(TypeError) as excinfo: - assert mpakd(1, i=1) - assert msg(excinfo.value) == """ - mixed_plus_args_kwargs_defaults(): incompatible function arguments. The following argument types are supported: - 1. (i: int = 1, j: float = 3.14159, *args, **kwargs) -> tuple - - Invoked with: 1; kwargs: i=1 - """ # noqa: E501 line too long - with pytest.raises(TypeError) as excinfo: - assert mpakd(1, 2, j=1) - assert msg(excinfo.value) == """ - mixed_plus_args_kwargs_defaults(): incompatible function arguments. The following argument types are supported: - 1. (i: int = 1, j: float = 3.14159, *args, **kwargs) -> tuple - - Invoked with: 1, 2; kwargs: j=1 - """ # noqa: E501 line too long - - -def test_args_refcount(): - """Issue/PR #1216 - py::args elements get double-inc_ref()ed when combined with regular - arguments""" - refcount = m.arg_refcount_h - - myval = 54321 - expected = refcount(myval) - assert m.arg_refcount_h(myval) == expected - assert m.arg_refcount_o(myval) == expected + 1 - assert m.arg_refcount_h(myval) == expected - assert refcount(myval) == expected - - assert m.mixed_plus_args(1, 2.0, "a", myval) == (1, 2.0, ("a", myval)) - assert refcount(myval) == expected - - assert m.mixed_plus_kwargs(3, 4.0, a=1, b=myval) == (3, 4.0, {"a": 1, "b": myval}) - assert refcount(myval) == expected - - assert m.args_function(-1, myval) == (-1, myval) - assert refcount(myval) == expected - - assert m.mixed_plus_args_kwargs(5, 6.0, myval, a=myval) == (5, 6.0, (myval,), {"a": myval}) - assert refcount(myval) == expected - - assert m.args_kwargs_function(7, 8, myval, a=1, b=myval) == \ - ((7, 8, myval), {"a": 1, "b": myval}) - assert refcount(myval) == expected - - exp3 = refcount(myval, myval, myval) - assert m.args_refcount(myval, myval, myval) == (exp3, exp3, exp3) - assert refcount(myval) == expected - - # This function takes the first arg as a `py::object` and the rest as a `py::args`. Unlike the - # previous case, when we have both positional and `py::args` we need to construct a new tuple - # for the `py::args`; in the previous case, we could simply inc_ref and pass on Python's input - # tuple without having to inc_ref the individual elements, but here we can't, hence the extra - # refs. - assert m.mixed_args_refcount(myval, myval, myval) == (exp3 + 3, exp3 + 3, exp3 + 3) diff --git a/pybind11/tests/test_local_bindings.cpp b/pybind11/tests/test_local_bindings.cpp deleted file mode 100644 index 97c02db..0000000 --- a/pybind11/tests/test_local_bindings.cpp +++ /dev/null @@ -1,101 +0,0 @@ -/* - tests/test_local_bindings.cpp -- tests the py::module_local class feature which makes a class - binding local to the module in which it is defined. - - Copyright (c) 2017 Jason Rhinelander - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" -#include "local_bindings.h" -#include -#include -#include - -TEST_SUBMODULE(local_bindings, m) { - // test_load_external - m.def("load_external1", [](ExternalType1 &e) { return e.i; }); - m.def("load_external2", [](ExternalType2 &e) { return e.i; }); - - // test_local_bindings - // Register a class with py::module_local: - bind_local(m, "LocalType", py::module_local()) - .def("get3", [](LocalType &t) { return t.i + 3; }) - ; - - m.def("local_value", [](LocalType &l) { return l.i; }); - - // test_nonlocal_failure - // The main pybind11 test module is loaded first, so this registration will succeed (the second - // one, in pybind11_cross_module_tests.cpp, is designed to fail): - bind_local(m, "NonLocalType") - .def(py::init()) - .def("get", [](LocalType &i) { return i.i; }) - ; - - // test_duplicate_local - // py::module_local declarations should be visible across compilation units that get linked together; - // this tries to register a duplicate local. It depends on a definition in test_class.cpp and - // should raise a runtime error from the duplicate definition attempt. If test_class isn't - // available it *also* throws a runtime error (with "test_class not enabled" as value). - m.def("register_local_external", [m]() { - auto main = py::module::import("pybind11_tests"); - if (py::hasattr(main, "class_")) { - bind_local(m, "LocalExternal", py::module_local()); - } - else throw std::runtime_error("test_class not enabled"); - }); - - // test_stl_bind_local - // stl_bind.h binders defaults to py::module_local if the types are local or converting: - py::bind_vector(m, "LocalVec"); - py::bind_map(m, "LocalMap"); - // and global if the type (or one of the types, for the map) is global: - py::bind_vector(m, "NonLocalVec"); - py::bind_map(m, "NonLocalMap"); - - // test_stl_bind_global - // They can, however, be overridden to global using `py::module_local(false)`: - bind_local(m, "NonLocal2"); - py::bind_vector(m, "LocalVec2", py::module_local()); - py::bind_map(m, "NonLocalMap2", py::module_local(false)); - - // test_mixed_local_global - // We try this both with the global type registered first and vice versa (the order shouldn't - // matter). - m.def("register_mixed_global", [m]() { - bind_local(m, "MixedGlobalLocal", py::module_local(false)); - }); - m.def("register_mixed_local", [m]() { - bind_local(m, "MixedLocalGlobal", py::module_local()); - }); - m.def("get_mixed_gl", [](int i) { return MixedGlobalLocal(i); }); - m.def("get_mixed_lg", [](int i) { return MixedLocalGlobal(i); }); - - // test_internal_locals_differ - m.def("local_cpp_types_addr", []() { return (uintptr_t) &py::detail::registered_local_types_cpp(); }); - - // test_stl_caster_vs_stl_bind - m.def("load_vector_via_caster", [](std::vector v) { - return std::accumulate(v.begin(), v.end(), 0); - }); - - // test_cross_module_calls - m.def("return_self", [](LocalVec *v) { return v; }); - m.def("return_copy", [](const LocalVec &v) { return LocalVec(v); }); - - class Cat : public pets::Pet { public: Cat(std::string name) : Pet(name) {}; }; - py::class_(m, "Pet", py::module_local()) - .def("get_name", &pets::Pet::name); - // Binding for local extending class: - py::class_(m, "Cat") - .def(py::init()); - m.def("pet_name", [](pets::Pet &p) { return p.name(); }); - - py::class_(m, "MixGL").def(py::init()); - m.def("get_gl_value", [](MixGL &o) { return o.i + 10; }); - - py::class_(m, "MixGL2").def(py::init()); -} diff --git a/pybind11/tests/test_local_bindings.py b/pybind11/tests/test_local_bindings.py deleted file mode 100644 index b380376..0000000 --- a/pybind11/tests/test_local_bindings.py +++ /dev/null @@ -1,226 +0,0 @@ -import pytest - -from pybind11_tests import local_bindings as m - - -def test_load_external(): - """Load a `py::module_local` type that's only registered in an external module""" - import pybind11_cross_module_tests as cm - - assert m.load_external1(cm.ExternalType1(11)) == 11 - assert m.load_external2(cm.ExternalType2(22)) == 22 - - with pytest.raises(TypeError) as excinfo: - assert m.load_external2(cm.ExternalType1(21)) == 21 - assert "incompatible function arguments" in str(excinfo.value) - - with pytest.raises(TypeError) as excinfo: - assert m.load_external1(cm.ExternalType2(12)) == 12 - assert "incompatible function arguments" in str(excinfo.value) - - -def test_local_bindings(): - """Tests that duplicate `py::module_local` class bindings work across modules""" - - # Make sure we can load the second module with the conflicting (but local) definition: - import pybind11_cross_module_tests as cm - - i1 = m.LocalType(5) - assert i1.get() == 4 - assert i1.get3() == 8 - - i2 = cm.LocalType(10) - assert i2.get() == 11 - assert i2.get2() == 12 - - assert not hasattr(i1, 'get2') - assert not hasattr(i2, 'get3') - - # Loading within the local module - assert m.local_value(i1) == 5 - assert cm.local_value(i2) == 10 - - # Cross-module loading works as well (on failure, the type loader looks for - # external module-local converters): - assert m.local_value(i2) == 10 - assert cm.local_value(i1) == 5 - - -def test_nonlocal_failure(): - """Tests that attempting to register a non-local type in multiple modules fails""" - import pybind11_cross_module_tests as cm - - with pytest.raises(RuntimeError) as excinfo: - cm.register_nonlocal() - assert str(excinfo.value) == 'generic_type: type "NonLocalType" is already registered!' - - -def test_duplicate_local(): - """Tests expected failure when registering a class twice with py::local in the same module""" - with pytest.raises(RuntimeError) as excinfo: - m.register_local_external() - import pybind11_tests - assert str(excinfo.value) == ( - 'generic_type: type "LocalExternal" is already registered!' - if hasattr(pybind11_tests, 'class_') else 'test_class not enabled') - - -def test_stl_bind_local(): - import pybind11_cross_module_tests as cm - - v1, v2 = m.LocalVec(), cm.LocalVec() - v1.append(m.LocalType(1)) - v1.append(m.LocalType(2)) - v2.append(cm.LocalType(1)) - v2.append(cm.LocalType(2)) - - # Cross module value loading: - v1.append(cm.LocalType(3)) - v2.append(m.LocalType(3)) - - assert [i.get() for i in v1] == [0, 1, 2] - assert [i.get() for i in v2] == [2, 3, 4] - - v3, v4 = m.NonLocalVec(), cm.NonLocalVec2() - v3.append(m.NonLocalType(1)) - v3.append(m.NonLocalType(2)) - v4.append(m.NonLocal2(3)) - v4.append(m.NonLocal2(4)) - - assert [i.get() for i in v3] == [1, 2] - assert [i.get() for i in v4] == [13, 14] - - d1, d2 = m.LocalMap(), cm.LocalMap() - d1["a"] = v1[0] - d1["b"] = v1[1] - d2["c"] = v2[0] - d2["d"] = v2[1] - assert {i: d1[i].get() for i in d1} == {'a': 0, 'b': 1} - assert {i: d2[i].get() for i in d2} == {'c': 2, 'd': 3} - - -def test_stl_bind_global(): - import pybind11_cross_module_tests as cm - - with pytest.raises(RuntimeError) as excinfo: - cm.register_nonlocal_map() - assert str(excinfo.value) == 'generic_type: type "NonLocalMap" is already registered!' - - with pytest.raises(RuntimeError) as excinfo: - cm.register_nonlocal_vec() - assert str(excinfo.value) == 'generic_type: type "NonLocalVec" is already registered!' - - with pytest.raises(RuntimeError) as excinfo: - cm.register_nonlocal_map2() - assert str(excinfo.value) == 'generic_type: type "NonLocalMap2" is already registered!' - - -def test_mixed_local_global(): - """Local types take precedence over globally registered types: a module with a `module_local` - type can be registered even if the type is already registered globally. With the module, - casting will go to the local type; outside the module casting goes to the global type.""" - import pybind11_cross_module_tests as cm - m.register_mixed_global() - m.register_mixed_local() - - a = [] - a.append(m.MixedGlobalLocal(1)) - a.append(m.MixedLocalGlobal(2)) - a.append(m.get_mixed_gl(3)) - a.append(m.get_mixed_lg(4)) - - assert [x.get() for x in a] == [101, 1002, 103, 1004] - - cm.register_mixed_global_local() - cm.register_mixed_local_global() - a.append(m.MixedGlobalLocal(5)) - a.append(m.MixedLocalGlobal(6)) - a.append(cm.MixedGlobalLocal(7)) - a.append(cm.MixedLocalGlobal(8)) - a.append(m.get_mixed_gl(9)) - a.append(m.get_mixed_lg(10)) - a.append(cm.get_mixed_gl(11)) - a.append(cm.get_mixed_lg(12)) - - assert [x.get() for x in a] == \ - [101, 1002, 103, 1004, 105, 1006, 207, 2008, 109, 1010, 211, 2012] - - -def test_internal_locals_differ(): - """Makes sure the internal local type map differs across the two modules""" - import pybind11_cross_module_tests as cm - assert m.local_cpp_types_addr() != cm.local_cpp_types_addr() - - -def test_stl_caster_vs_stl_bind(msg): - """One module uses a generic vector caster from `` while the other - exports `std::vector` via `py:bind_vector` and `py::module_local`""" - import pybind11_cross_module_tests as cm - - v1 = cm.VectorInt([1, 2, 3]) - assert m.load_vector_via_caster(v1) == 6 - assert cm.load_vector_via_binding(v1) == 6 - - v2 = [1, 2, 3] - assert m.load_vector_via_caster(v2) == 6 - with pytest.raises(TypeError) as excinfo: - cm.load_vector_via_binding(v2) == 6 - assert msg(excinfo.value) == """ - load_vector_via_binding(): incompatible function arguments. The following argument types are supported: - 1. (arg0: pybind11_cross_module_tests.VectorInt) -> int - - Invoked with: [1, 2, 3] - """ # noqa: E501 line too long - - -def test_cross_module_calls(): - import pybind11_cross_module_tests as cm - - v1 = m.LocalVec() - v1.append(m.LocalType(1)) - v2 = cm.LocalVec() - v2.append(cm.LocalType(2)) - - # Returning the self pointer should get picked up as returning an existing - # instance (even when that instance is of a foreign, non-local type). - assert m.return_self(v1) is v1 - assert cm.return_self(v2) is v2 - assert m.return_self(v2) is v2 - assert cm.return_self(v1) is v1 - - assert m.LocalVec is not cm.LocalVec - # Returning a copy, on the other hand, always goes to the local type, - # regardless of where the source type came from. - assert type(m.return_copy(v1)) is m.LocalVec - assert type(m.return_copy(v2)) is m.LocalVec - assert type(cm.return_copy(v1)) is cm.LocalVec - assert type(cm.return_copy(v2)) is cm.LocalVec - - # Test the example given in the documentation (which also tests inheritance casting): - mycat = m.Cat("Fluffy") - mydog = cm.Dog("Rover") - assert mycat.get_name() == "Fluffy" - assert mydog.name() == "Rover" - assert m.Cat.__base__.__name__ == "Pet" - assert cm.Dog.__base__.__name__ == "Pet" - assert m.Cat.__base__ is not cm.Dog.__base__ - assert m.pet_name(mycat) == "Fluffy" - assert m.pet_name(mydog) == "Rover" - assert cm.pet_name(mycat) == "Fluffy" - assert cm.pet_name(mydog) == "Rover" - - assert m.MixGL is not cm.MixGL - a = m.MixGL(1) - b = cm.MixGL(2) - assert m.get_gl_value(a) == 11 - assert m.get_gl_value(b) == 12 - assert cm.get_gl_value(a) == 101 - assert cm.get_gl_value(b) == 102 - - c, d = m.MixGL2(3), cm.MixGL2(4) - with pytest.raises(TypeError) as excinfo: - m.get_gl_value(c) - assert "incompatible function arguments" in str(excinfo.value) - with pytest.raises(TypeError) as excinfo: - m.get_gl_value(d) - assert "incompatible function arguments" in str(excinfo.value) diff --git a/pybind11/tests/test_methods_and_attributes.cpp b/pybind11/tests/test_methods_and_attributes.cpp deleted file mode 100644 index c7b82f1..0000000 --- a/pybind11/tests/test_methods_and_attributes.cpp +++ /dev/null @@ -1,460 +0,0 @@ -/* - tests/test_methods_and_attributes.cpp -- constructors, deconstructors, attribute access, - __str__, argument and return value conventions - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" -#include "constructor_stats.h" - -#if !defined(PYBIND11_OVERLOAD_CAST) -template -using overload_cast_ = pybind11::detail::overload_cast_impl; -#endif - -class ExampleMandA { -public: - ExampleMandA() { print_default_created(this); } - ExampleMandA(int value) : value(value) { print_created(this, value); } - ExampleMandA(const ExampleMandA &e) : value(e.value) { print_copy_created(this); } - ExampleMandA(ExampleMandA &&e) : value(e.value) { print_move_created(this); } - ~ExampleMandA() { print_destroyed(this); } - - std::string toString() { - return "ExampleMandA[value=" + std::to_string(value) + "]"; - } - - void operator=(const ExampleMandA &e) { print_copy_assigned(this); value = e.value; } - void operator=(ExampleMandA &&e) { print_move_assigned(this); value = e.value; } - - void add1(ExampleMandA other) { value += other.value; } // passing by value - void add2(ExampleMandA &other) { value += other.value; } // passing by reference - void add3(const ExampleMandA &other) { value += other.value; } // passing by const reference - void add4(ExampleMandA *other) { value += other->value; } // passing by pointer - void add5(const ExampleMandA *other) { value += other->value; } // passing by const pointer - - void add6(int other) { value += other; } // passing by value - void add7(int &other) { value += other; } // passing by reference - void add8(const int &other) { value += other; } // passing by const reference - void add9(int *other) { value += *other; } // passing by pointer - void add10(const int *other) { value += *other; } // passing by const pointer - - ExampleMandA self1() { return *this; } // return by value - ExampleMandA &self2() { return *this; } // return by reference - const ExampleMandA &self3() { return *this; } // return by const reference - ExampleMandA *self4() { return this; } // return by pointer - const ExampleMandA *self5() { return this; } // return by const pointer - - int internal1() { return value; } // return by value - int &internal2() { return value; } // return by reference - const int &internal3() { return value; } // return by const reference - int *internal4() { return &value; } // return by pointer - const int *internal5() { return &value; } // return by const pointer - - py::str overloaded() { return "()"; } - py::str overloaded(int) { return "(int)"; } - py::str overloaded(int, float) { return "(int, float)"; } - py::str overloaded(float, int) { return "(float, int)"; } - py::str overloaded(int, int) { return "(int, int)"; } - py::str overloaded(float, float) { return "(float, float)"; } - py::str overloaded(int) const { return "(int) const"; } - py::str overloaded(int, float) const { return "(int, float) const"; } - py::str overloaded(float, int) const { return "(float, int) const"; } - py::str overloaded(int, int) const { return "(int, int) const"; } - py::str overloaded(float, float) const { return "(float, float) const"; } - - static py::str overloaded(float) { return "static float"; } - - int value = 0; -}; - -struct TestProperties { - int value = 1; - static int static_value; - - int get() const { return value; } - void set(int v) { value = v; } - - static int static_get() { return static_value; } - static void static_set(int v) { static_value = v; } -}; -int TestProperties::static_value = 1; - -struct TestPropertiesOverride : TestProperties { - int value = 99; - static int static_value; -}; -int TestPropertiesOverride::static_value = 99; - -struct TestPropRVP { - UserType v1{1}; - UserType v2{1}; - static UserType sv1; - static UserType sv2; - - const UserType &get1() const { return v1; } - const UserType &get2() const { return v2; } - UserType get_rvalue() const { return v2; } - void set1(int v) { v1.set(v); } - void set2(int v) { v2.set(v); } -}; -UserType TestPropRVP::sv1(1); -UserType TestPropRVP::sv2(1); - -// py::arg/py::arg_v testing: these arguments just record their argument when invoked -class ArgInspector1 { public: std::string arg = "(default arg inspector 1)"; }; -class ArgInspector2 { public: std::string arg = "(default arg inspector 2)"; }; -class ArgAlwaysConverts { }; -namespace pybind11 { namespace detail { -template <> struct type_caster { -public: - PYBIND11_TYPE_CASTER(ArgInspector1, _("ArgInspector1")); - - bool load(handle src, bool convert) { - value.arg = "loading ArgInspector1 argument " + - std::string(convert ? "WITH" : "WITHOUT") + " conversion allowed. " - "Argument value = " + (std::string) str(src); - return true; - } - - static handle cast(const ArgInspector1 &src, return_value_policy, handle) { - return str(src.arg).release(); - } -}; -template <> struct type_caster { -public: - PYBIND11_TYPE_CASTER(ArgInspector2, _("ArgInspector2")); - - bool load(handle src, bool convert) { - value.arg = "loading ArgInspector2 argument " + - std::string(convert ? "WITH" : "WITHOUT") + " conversion allowed. " - "Argument value = " + (std::string) str(src); - return true; - } - - static handle cast(const ArgInspector2 &src, return_value_policy, handle) { - return str(src.arg).release(); - } -}; -template <> struct type_caster { -public: - PYBIND11_TYPE_CASTER(ArgAlwaysConverts, _("ArgAlwaysConverts")); - - bool load(handle, bool convert) { - return convert; - } - - static handle cast(const ArgAlwaysConverts &, return_value_policy, handle) { - return py::none().release(); - } -}; -}} - -// test_custom_caster_destruction -class DestructionTester { -public: - DestructionTester() { print_default_created(this); } - ~DestructionTester() { print_destroyed(this); } - DestructionTester(const DestructionTester &) { print_copy_created(this); } - DestructionTester(DestructionTester &&) { print_move_created(this); } - DestructionTester &operator=(const DestructionTester &) { print_copy_assigned(this); return *this; } - DestructionTester &operator=(DestructionTester &&) { print_move_assigned(this); return *this; } -}; -namespace pybind11 { namespace detail { -template <> struct type_caster { - PYBIND11_TYPE_CASTER(DestructionTester, _("DestructionTester")); - bool load(handle, bool) { return true; } - - static handle cast(const DestructionTester &, return_value_policy, handle) { - return py::bool_(true).release(); - } -}; -}} - -// Test None-allowed py::arg argument policy -class NoneTester { public: int answer = 42; }; -int none1(const NoneTester &obj) { return obj.answer; } -int none2(NoneTester *obj) { return obj ? obj->answer : -1; } -int none3(std::shared_ptr &obj) { return obj ? obj->answer : -1; } -int none4(std::shared_ptr *obj) { return obj && *obj ? (*obj)->answer : -1; } -int none5(std::shared_ptr obj) { return obj ? obj->answer : -1; } - -struct StrIssue { - int val = -1; - - StrIssue() = default; - StrIssue(int i) : val{i} {} -}; - -// Issues #854, #910: incompatible function args when member function/pointer is in unregistered base class -class UnregisteredBase { -public: - void do_nothing() const {} - void increase_value() { rw_value++; ro_value += 0.25; } - void set_int(int v) { rw_value = v; } - int get_int() const { return rw_value; } - double get_double() const { return ro_value; } - int rw_value = 42; - double ro_value = 1.25; -}; -class RegisteredDerived : public UnregisteredBase { -public: - using UnregisteredBase::UnregisteredBase; - double sum() const { return rw_value + ro_value; } -}; - -TEST_SUBMODULE(methods_and_attributes, m) { - // test_methods_and_attributes - py::class_ emna(m, "ExampleMandA"); - emna.def(py::init<>()) - .def(py::init()) - .def(py::init()) - .def("add1", &ExampleMandA::add1) - .def("add2", &ExampleMandA::add2) - .def("add3", &ExampleMandA::add3) - .def("add4", &ExampleMandA::add4) - .def("add5", &ExampleMandA::add5) - .def("add6", &ExampleMandA::add6) - .def("add7", &ExampleMandA::add7) - .def("add8", &ExampleMandA::add8) - .def("add9", &ExampleMandA::add9) - .def("add10", &ExampleMandA::add10) - .def("self1", &ExampleMandA::self1) - .def("self2", &ExampleMandA::self2) - .def("self3", &ExampleMandA::self3) - .def("self4", &ExampleMandA::self4) - .def("self5", &ExampleMandA::self5) - .def("internal1", &ExampleMandA::internal1) - .def("internal2", &ExampleMandA::internal2) - .def("internal3", &ExampleMandA::internal3) - .def("internal4", &ExampleMandA::internal4) - .def("internal5", &ExampleMandA::internal5) -#if defined(PYBIND11_OVERLOAD_CAST) - .def("overloaded", py::overload_cast<>(&ExampleMandA::overloaded)) - .def("overloaded", py::overload_cast(&ExampleMandA::overloaded)) - .def("overloaded", py::overload_cast(&ExampleMandA::overloaded)) - .def("overloaded", py::overload_cast(&ExampleMandA::overloaded)) - .def("overloaded", py::overload_cast(&ExampleMandA::overloaded)) - .def("overloaded", py::overload_cast(&ExampleMandA::overloaded)) - .def("overloaded_float", py::overload_cast(&ExampleMandA::overloaded)) - .def("overloaded_const", py::overload_cast(&ExampleMandA::overloaded, py::const_)) - .def("overloaded_const", py::overload_cast(&ExampleMandA::overloaded, py::const_)) - .def("overloaded_const", py::overload_cast(&ExampleMandA::overloaded, py::const_)) - .def("overloaded_const", py::overload_cast(&ExampleMandA::overloaded, py::const_)) - .def("overloaded_const", py::overload_cast(&ExampleMandA::overloaded, py::const_)) -#else - // Use both the traditional static_cast method and the C++11 compatible overload_cast_ - .def("overloaded", overload_cast_<>()(&ExampleMandA::overloaded)) - .def("overloaded", overload_cast_()(&ExampleMandA::overloaded)) - .def("overloaded", overload_cast_()(&ExampleMandA::overloaded)) - .def("overloaded", static_cast(&ExampleMandA::overloaded)) - .def("overloaded", static_cast(&ExampleMandA::overloaded)) - .def("overloaded", static_cast(&ExampleMandA::overloaded)) - .def("overloaded_float", overload_cast_()(&ExampleMandA::overloaded)) - .def("overloaded_const", overload_cast_()(&ExampleMandA::overloaded, py::const_)) - .def("overloaded_const", overload_cast_()(&ExampleMandA::overloaded, py::const_)) - .def("overloaded_const", static_cast(&ExampleMandA::overloaded)) - .def("overloaded_const", static_cast(&ExampleMandA::overloaded)) - .def("overloaded_const", static_cast(&ExampleMandA::overloaded)) -#endif - // test_no_mixed_overloads - // Raise error if trying to mix static/non-static overloads on the same name: - .def_static("add_mixed_overloads1", []() { - auto emna = py::reinterpret_borrow>(py::module::import("pybind11_tests.methods_and_attributes").attr("ExampleMandA")); - emna.def ("overload_mixed1", static_cast(&ExampleMandA::overloaded)) - .def_static("overload_mixed1", static_cast(&ExampleMandA::overloaded)); - }) - .def_static("add_mixed_overloads2", []() { - auto emna = py::reinterpret_borrow>(py::module::import("pybind11_tests.methods_and_attributes").attr("ExampleMandA")); - emna.def_static("overload_mixed2", static_cast(&ExampleMandA::overloaded)) - .def ("overload_mixed2", static_cast(&ExampleMandA::overloaded)); - }) - .def("__str__", &ExampleMandA::toString) - .def_readwrite("value", &ExampleMandA::value); - - // test_copy_method - // Issue #443: can't call copied methods in Python 3 - emna.attr("add2b") = emna.attr("add2"); - - // test_properties, test_static_properties, test_static_cls - py::class_(m, "TestProperties") - .def(py::init<>()) - .def_readonly("def_readonly", &TestProperties::value) - .def_readwrite("def_readwrite", &TestProperties::value) - .def_property("def_writeonly", nullptr, - [](TestProperties& s,int v) { s.value = v; } ) - .def_property("def_property_writeonly", nullptr, &TestProperties::set) - .def_property_readonly("def_property_readonly", &TestProperties::get) - .def_property("def_property", &TestProperties::get, &TestProperties::set) - .def_property("def_property_impossible", nullptr, nullptr) - .def_readonly_static("def_readonly_static", &TestProperties::static_value) - .def_readwrite_static("def_readwrite_static", &TestProperties::static_value) - .def_property_static("def_writeonly_static", nullptr, - [](py::object, int v) { TestProperties::static_value = v; }) - .def_property_readonly_static("def_property_readonly_static", - [](py::object) { return TestProperties::static_get(); }) - .def_property_static("def_property_writeonly_static", nullptr, - [](py::object, int v) { return TestProperties::static_set(v); }) - .def_property_static("def_property_static", - [](py::object) { return TestProperties::static_get(); }, - [](py::object, int v) { TestProperties::static_set(v); }) - .def_property_static("static_cls", - [](py::object cls) { return cls; }, - [](py::object cls, py::function f) { f(cls); }); - - py::class_(m, "TestPropertiesOverride") - .def(py::init<>()) - .def_readonly("def_readonly", &TestPropertiesOverride::value) - .def_readonly_static("def_readonly_static", &TestPropertiesOverride::static_value); - - auto static_get1 = [](py::object) -> const UserType & { return TestPropRVP::sv1; }; - auto static_get2 = [](py::object) -> const UserType & { return TestPropRVP::sv2; }; - auto static_set1 = [](py::object, int v) { TestPropRVP::sv1.set(v); }; - auto static_set2 = [](py::object, int v) { TestPropRVP::sv2.set(v); }; - auto rvp_copy = py::return_value_policy::copy; - - // test_property_return_value_policies - py::class_(m, "TestPropRVP") - .def(py::init<>()) - .def_property_readonly("ro_ref", &TestPropRVP::get1) - .def_property_readonly("ro_copy", &TestPropRVP::get2, rvp_copy) - .def_property_readonly("ro_func", py::cpp_function(&TestPropRVP::get2, rvp_copy)) - .def_property("rw_ref", &TestPropRVP::get1, &TestPropRVP::set1) - .def_property("rw_copy", &TestPropRVP::get2, &TestPropRVP::set2, rvp_copy) - .def_property("rw_func", py::cpp_function(&TestPropRVP::get2, rvp_copy), &TestPropRVP::set2) - .def_property_readonly_static("static_ro_ref", static_get1) - .def_property_readonly_static("static_ro_copy", static_get2, rvp_copy) - .def_property_readonly_static("static_ro_func", py::cpp_function(static_get2, rvp_copy)) - .def_property_static("static_rw_ref", static_get1, static_set1) - .def_property_static("static_rw_copy", static_get2, static_set2, rvp_copy) - .def_property_static("static_rw_func", py::cpp_function(static_get2, rvp_copy), static_set2) - // test_property_rvalue_policy - .def_property_readonly("rvalue", &TestPropRVP::get_rvalue) - .def_property_readonly_static("static_rvalue", [](py::object) { return UserType(1); }); - - // test_metaclass_override - struct MetaclassOverride { }; - py::class_(m, "MetaclassOverride", py::metaclass((PyObject *) &PyType_Type)) - .def_property_readonly_static("readonly", [](py::object) { return 1; }); - -#if !defined(PYPY_VERSION) - // test_dynamic_attributes - class DynamicClass { - public: - DynamicClass() { print_default_created(this); } - ~DynamicClass() { print_destroyed(this); } - }; - py::class_(m, "DynamicClass", py::dynamic_attr()) - .def(py::init()); - - class CppDerivedDynamicClass : public DynamicClass { }; - py::class_(m, "CppDerivedDynamicClass") - .def(py::init()); -#endif - - // test_noconvert_args - // - // Test converting. The ArgAlwaysConverts is just there to make the first no-conversion pass - // fail so that our call always ends up happening via the second dispatch (the one that allows - // some conversion). - class ArgInspector { - public: - ArgInspector1 f(ArgInspector1 a, ArgAlwaysConverts) { return a; } - std::string g(ArgInspector1 a, const ArgInspector1 &b, int c, ArgInspector2 *d, ArgAlwaysConverts) { - return a.arg + "\n" + b.arg + "\n" + std::to_string(c) + "\n" + d->arg; - } - static ArgInspector2 h(ArgInspector2 a, ArgAlwaysConverts) { return a; } - }; - py::class_(m, "ArgInspector") - .def(py::init<>()) - .def("f", &ArgInspector::f, py::arg(), py::arg() = ArgAlwaysConverts()) - .def("g", &ArgInspector::g, "a"_a.noconvert(), "b"_a, "c"_a.noconvert()=13, "d"_a=ArgInspector2(), py::arg() = ArgAlwaysConverts()) - .def_static("h", &ArgInspector::h, py::arg().noconvert(), py::arg() = ArgAlwaysConverts()) - ; - m.def("arg_inspect_func", [](ArgInspector2 a, ArgInspector1 b, ArgAlwaysConverts) { return a.arg + "\n" + b.arg; }, - py::arg().noconvert(false), py::arg_v(nullptr, ArgInspector1()).noconvert(true), py::arg() = ArgAlwaysConverts()); - - m.def("floats_preferred", [](double f) { return 0.5 * f; }, py::arg("f")); - m.def("floats_only", [](double f) { return 0.5 * f; }, py::arg("f").noconvert()); - m.def("ints_preferred", [](int i) { return i / 2; }, py::arg("i")); - m.def("ints_only", [](int i) { return i / 2; }, py::arg("i").noconvert()); - - // test_bad_arg_default - // Issue/PR #648: bad arg default debugging output -#if !defined(NDEBUG) - m.attr("debug_enabled") = true; -#else - m.attr("debug_enabled") = false; -#endif - m.def("bad_arg_def_named", []{ - auto m = py::module::import("pybind11_tests"); - m.def("should_fail", [](int, UnregisteredType) {}, py::arg(), py::arg("a") = UnregisteredType()); - }); - m.def("bad_arg_def_unnamed", []{ - auto m = py::module::import("pybind11_tests"); - m.def("should_fail", [](int, UnregisteredType) {}, py::arg(), py::arg() = UnregisteredType()); - }); - - // test_accepts_none - py::class_>(m, "NoneTester") - .def(py::init<>()); - m.def("no_none1", &none1, py::arg().none(false)); - m.def("no_none2", &none2, py::arg().none(false)); - m.def("no_none3", &none3, py::arg().none(false)); - m.def("no_none4", &none4, py::arg().none(false)); - m.def("no_none5", &none5, py::arg().none(false)); - m.def("ok_none1", &none1); - m.def("ok_none2", &none2, py::arg().none(true)); - m.def("ok_none3", &none3); - m.def("ok_none4", &none4, py::arg().none(true)); - m.def("ok_none5", &none5); - - // test_str_issue - // Issue #283: __str__ called on uninitialized instance when constructor arguments invalid - py::class_(m, "StrIssue") - .def(py::init()) - .def(py::init<>()) - .def("__str__", [](const StrIssue &si) { - return "StrIssue[" + std::to_string(si.val) + "]"; } - ); - - // test_unregistered_base_implementations - // - // Issues #854/910: incompatible function args when member function/pointer is in unregistered - // base class The methods and member pointers below actually resolve to members/pointers in - // UnregisteredBase; before this test/fix they would be registered via lambda with a first - // argument of an unregistered type, and thus uncallable. - py::class_(m, "RegisteredDerived") - .def(py::init<>()) - .def("do_nothing", &RegisteredDerived::do_nothing) - .def("increase_value", &RegisteredDerived::increase_value) - .def_readwrite("rw_value", &RegisteredDerived::rw_value) - .def_readonly("ro_value", &RegisteredDerived::ro_value) - // These should trigger a static_assert if uncommented - //.def_readwrite("fails", &UserType::value) // should trigger a static_assert if uncommented - //.def_readonly("fails", &UserType::value) // should trigger a static_assert if uncommented - .def_property("rw_value_prop", &RegisteredDerived::get_int, &RegisteredDerived::set_int) - .def_property_readonly("ro_value_prop", &RegisteredDerived::get_double) - // This one is in the registered class: - .def("sum", &RegisteredDerived::sum) - ; - - using Adapted = decltype(py::method_adaptor(&RegisteredDerived::do_nothing)); - static_assert(std::is_same::value, ""); - - // test_custom_caster_destruction - // Test that `take_ownership` works on types with a custom type caster when given a pointer - - // default policy: don't take ownership: - m.def("custom_caster_no_destroy", []() { static auto *dt = new DestructionTester(); return dt; }); - - m.def("custom_caster_destroy", []() { return new DestructionTester(); }, - py::return_value_policy::take_ownership); // Takes ownership: destroy when finished - m.def("custom_caster_destroy_const", []() -> const DestructionTester * { return new DestructionTester(); }, - py::return_value_policy::take_ownership); // Likewise (const doesn't inhibit destruction) - m.def("destruction_tester_cstats", &ConstructorStats::get, py::return_value_policy::reference); -} diff --git a/pybind11/tests/test_methods_and_attributes.py b/pybind11/tests/test_methods_and_attributes.py deleted file mode 100644 index f1c862b..0000000 --- a/pybind11/tests/test_methods_and_attributes.py +++ /dev/null @@ -1,512 +0,0 @@ -import pytest -from pybind11_tests import methods_and_attributes as m -from pybind11_tests import ConstructorStats - - -def test_methods_and_attributes(): - instance1 = m.ExampleMandA() - instance2 = m.ExampleMandA(32) - - instance1.add1(instance2) - instance1.add2(instance2) - instance1.add3(instance2) - instance1.add4(instance2) - instance1.add5(instance2) - instance1.add6(32) - instance1.add7(32) - instance1.add8(32) - instance1.add9(32) - instance1.add10(32) - - assert str(instance1) == "ExampleMandA[value=320]" - assert str(instance2) == "ExampleMandA[value=32]" - assert str(instance1.self1()) == "ExampleMandA[value=320]" - assert str(instance1.self2()) == "ExampleMandA[value=320]" - assert str(instance1.self3()) == "ExampleMandA[value=320]" - assert str(instance1.self4()) == "ExampleMandA[value=320]" - assert str(instance1.self5()) == "ExampleMandA[value=320]" - - assert instance1.internal1() == 320 - assert instance1.internal2() == 320 - assert instance1.internal3() == 320 - assert instance1.internal4() == 320 - assert instance1.internal5() == 320 - - assert instance1.overloaded() == "()" - assert instance1.overloaded(0) == "(int)" - assert instance1.overloaded(1, 1.0) == "(int, float)" - assert instance1.overloaded(2.0, 2) == "(float, int)" - assert instance1.overloaded(3, 3) == "(int, int)" - assert instance1.overloaded(4., 4.) == "(float, float)" - assert instance1.overloaded_const(-3) == "(int) const" - assert instance1.overloaded_const(5, 5.0) == "(int, float) const" - assert instance1.overloaded_const(6.0, 6) == "(float, int) const" - assert instance1.overloaded_const(7, 7) == "(int, int) const" - assert instance1.overloaded_const(8., 8.) == "(float, float) const" - assert instance1.overloaded_float(1, 1) == "(float, float)" - assert instance1.overloaded_float(1, 1.) == "(float, float)" - assert instance1.overloaded_float(1., 1) == "(float, float)" - assert instance1.overloaded_float(1., 1.) == "(float, float)" - - assert instance1.value == 320 - instance1.value = 100 - assert str(instance1) == "ExampleMandA[value=100]" - - cstats = ConstructorStats.get(m.ExampleMandA) - assert cstats.alive() == 2 - del instance1, instance2 - assert cstats.alive() == 0 - assert cstats.values() == ["32"] - assert cstats.default_constructions == 1 - assert cstats.copy_constructions == 3 - assert cstats.move_constructions >= 1 - assert cstats.copy_assignments == 0 - assert cstats.move_assignments == 0 - - -def test_copy_method(): - """Issue #443: calling copied methods fails in Python 3""" - - m.ExampleMandA.add2c = m.ExampleMandA.add2 - m.ExampleMandA.add2d = m.ExampleMandA.add2b - a = m.ExampleMandA(123) - assert a.value == 123 - a.add2(m.ExampleMandA(-100)) - assert a.value == 23 - a.add2b(m.ExampleMandA(20)) - assert a.value == 43 - a.add2c(m.ExampleMandA(6)) - assert a.value == 49 - a.add2d(m.ExampleMandA(-7)) - assert a.value == 42 - - -def test_properties(): - instance = m.TestProperties() - - assert instance.def_readonly == 1 - with pytest.raises(AttributeError): - instance.def_readonly = 2 - - instance.def_readwrite = 2 - assert instance.def_readwrite == 2 - - assert instance.def_property_readonly == 2 - with pytest.raises(AttributeError): - instance.def_property_readonly = 3 - - instance.def_property = 3 - assert instance.def_property == 3 - - with pytest.raises(AttributeError) as excinfo: - dummy = instance.def_property_writeonly # noqa: F841 unused var - assert "unreadable attribute" in str(excinfo.value) - - instance.def_property_writeonly = 4 - assert instance.def_property_readonly == 4 - - with pytest.raises(AttributeError) as excinfo: - dummy = instance.def_property_impossible # noqa: F841 unused var - assert "unreadable attribute" in str(excinfo.value) - - with pytest.raises(AttributeError) as excinfo: - instance.def_property_impossible = 5 - assert "can't set attribute" in str(excinfo.value) - - -def test_static_properties(): - assert m.TestProperties.def_readonly_static == 1 - with pytest.raises(AttributeError) as excinfo: - m.TestProperties.def_readonly_static = 2 - assert "can't set attribute" in str(excinfo.value) - - m.TestProperties.def_readwrite_static = 2 - assert m.TestProperties.def_readwrite_static == 2 - - with pytest.raises(AttributeError) as excinfo: - dummy = m.TestProperties.def_writeonly_static # noqa: F841 unused var - assert "unreadable attribute" in str(excinfo.value) - - m.TestProperties.def_writeonly_static = 3 - assert m.TestProperties.def_readonly_static == 3 - - assert m.TestProperties.def_property_readonly_static == 3 - with pytest.raises(AttributeError) as excinfo: - m.TestProperties.def_property_readonly_static = 99 - assert "can't set attribute" in str(excinfo.value) - - m.TestProperties.def_property_static = 4 - assert m.TestProperties.def_property_static == 4 - - with pytest.raises(AttributeError) as excinfo: - dummy = m.TestProperties.def_property_writeonly_static - assert "unreadable attribute" in str(excinfo.value) - - m.TestProperties.def_property_writeonly_static = 5 - assert m.TestProperties.def_property_static == 5 - - # Static property read and write via instance - instance = m.TestProperties() - - m.TestProperties.def_readwrite_static = 0 - assert m.TestProperties.def_readwrite_static == 0 - assert instance.def_readwrite_static == 0 - - instance.def_readwrite_static = 2 - assert m.TestProperties.def_readwrite_static == 2 - assert instance.def_readwrite_static == 2 - - with pytest.raises(AttributeError) as excinfo: - dummy = instance.def_property_writeonly_static # noqa: F841 unused var - assert "unreadable attribute" in str(excinfo.value) - - instance.def_property_writeonly_static = 4 - assert instance.def_property_static == 4 - - # It should be possible to override properties in derived classes - assert m.TestPropertiesOverride().def_readonly == 99 - assert m.TestPropertiesOverride.def_readonly_static == 99 - - -def test_static_cls(): - """Static property getter and setters expect the type object as the their only argument""" - - instance = m.TestProperties() - assert m.TestProperties.static_cls is m.TestProperties - assert instance.static_cls is m.TestProperties - - def check_self(self): - assert self is m.TestProperties - - m.TestProperties.static_cls = check_self - instance.static_cls = check_self - - -def test_metaclass_override(): - """Overriding pybind11's default metaclass changes the behavior of `static_property`""" - - assert type(m.ExampleMandA).__name__ == "pybind11_type" - assert type(m.MetaclassOverride).__name__ == "type" - - assert m.MetaclassOverride.readonly == 1 - assert type(m.MetaclassOverride.__dict__["readonly"]).__name__ == "pybind11_static_property" - - # Regular `type` replaces the property instead of calling `__set__()` - m.MetaclassOverride.readonly = 2 - assert m.MetaclassOverride.readonly == 2 - assert isinstance(m.MetaclassOverride.__dict__["readonly"], int) - - -def test_no_mixed_overloads(): - from pybind11_tests import debug_enabled - - with pytest.raises(RuntimeError) as excinfo: - m.ExampleMandA.add_mixed_overloads1() - assert (str(excinfo.value) == - "overloading a method with both static and instance methods is not supported; " + - ("compile in debug mode for more details" if not debug_enabled else - "error while attempting to bind static method ExampleMandA.overload_mixed1" - "(arg0: float) -> str") - ) - - with pytest.raises(RuntimeError) as excinfo: - m.ExampleMandA.add_mixed_overloads2() - assert (str(excinfo.value) == - "overloading a method with both static and instance methods is not supported; " + - ("compile in debug mode for more details" if not debug_enabled else - "error while attempting to bind instance method ExampleMandA.overload_mixed2" - "(self: pybind11_tests.methods_and_attributes.ExampleMandA, arg0: int, arg1: int)" - " -> str") - ) - - -@pytest.mark.parametrize("access", ["ro", "rw", "static_ro", "static_rw"]) -def test_property_return_value_policies(access): - if not access.startswith("static"): - obj = m.TestPropRVP() - else: - obj = m.TestPropRVP - - ref = getattr(obj, access + "_ref") - assert ref.value == 1 - ref.value = 2 - assert getattr(obj, access + "_ref").value == 2 - ref.value = 1 # restore original value for static properties - - copy = getattr(obj, access + "_copy") - assert copy.value == 1 - copy.value = 2 - assert getattr(obj, access + "_copy").value == 1 - - copy = getattr(obj, access + "_func") - assert copy.value == 1 - copy.value = 2 - assert getattr(obj, access + "_func").value == 1 - - -def test_property_rvalue_policy(): - """When returning an rvalue, the return value policy is automatically changed from - `reference(_internal)` to `move`. The following would not work otherwise.""" - - instance = m.TestPropRVP() - o = instance.rvalue - assert o.value == 1 - - os = m.TestPropRVP.static_rvalue - assert os.value == 1 - - -# https://bitbucket.org/pypy/pypy/issues/2447 -@pytest.unsupported_on_pypy -def test_dynamic_attributes(): - instance = m.DynamicClass() - assert not hasattr(instance, "foo") - assert "foo" not in dir(instance) - - # Dynamically add attribute - instance.foo = 42 - assert hasattr(instance, "foo") - assert instance.foo == 42 - assert "foo" in dir(instance) - - # __dict__ should be accessible and replaceable - assert "foo" in instance.__dict__ - instance.__dict__ = {"bar": True} - assert not hasattr(instance, "foo") - assert hasattr(instance, "bar") - - with pytest.raises(TypeError) as excinfo: - instance.__dict__ = [] - assert str(excinfo.value) == "__dict__ must be set to a dictionary, not a 'list'" - - cstats = ConstructorStats.get(m.DynamicClass) - assert cstats.alive() == 1 - del instance - assert cstats.alive() == 0 - - # Derived classes should work as well - class PythonDerivedDynamicClass(m.DynamicClass): - pass - - for cls in m.CppDerivedDynamicClass, PythonDerivedDynamicClass: - derived = cls() - derived.foobar = 100 - assert derived.foobar == 100 - - assert cstats.alive() == 1 - del derived - assert cstats.alive() == 0 - - -# https://bitbucket.org/pypy/pypy/issues/2447 -@pytest.unsupported_on_pypy -def test_cyclic_gc(): - # One object references itself - instance = m.DynamicClass() - instance.circular_reference = instance - - cstats = ConstructorStats.get(m.DynamicClass) - assert cstats.alive() == 1 - del instance - assert cstats.alive() == 0 - - # Two object reference each other - i1 = m.DynamicClass() - i2 = m.DynamicClass() - i1.cycle = i2 - i2.cycle = i1 - - assert cstats.alive() == 2 - del i1, i2 - assert cstats.alive() == 0 - - -def test_noconvert_args(msg): - a = m.ArgInspector() - assert msg(a.f("hi")) == """ - loading ArgInspector1 argument WITH conversion allowed. Argument value = hi - """ - assert msg(a.g("this is a", "this is b")) == """ - loading ArgInspector1 argument WITHOUT conversion allowed. Argument value = this is a - loading ArgInspector1 argument WITH conversion allowed. Argument value = this is b - 13 - loading ArgInspector2 argument WITH conversion allowed. Argument value = (default arg inspector 2) - """ # noqa: E501 line too long - assert msg(a.g("this is a", "this is b", 42)) == """ - loading ArgInspector1 argument WITHOUT conversion allowed. Argument value = this is a - loading ArgInspector1 argument WITH conversion allowed. Argument value = this is b - 42 - loading ArgInspector2 argument WITH conversion allowed. Argument value = (default arg inspector 2) - """ # noqa: E501 line too long - assert msg(a.g("this is a", "this is b", 42, "this is d")) == """ - loading ArgInspector1 argument WITHOUT conversion allowed. Argument value = this is a - loading ArgInspector1 argument WITH conversion allowed. Argument value = this is b - 42 - loading ArgInspector2 argument WITH conversion allowed. Argument value = this is d - """ - assert (a.h("arg 1") == - "loading ArgInspector2 argument WITHOUT conversion allowed. Argument value = arg 1") - assert msg(m.arg_inspect_func("A1", "A2")) == """ - loading ArgInspector2 argument WITH conversion allowed. Argument value = A1 - loading ArgInspector1 argument WITHOUT conversion allowed. Argument value = A2 - """ - - assert m.floats_preferred(4) == 2.0 - assert m.floats_only(4.0) == 2.0 - with pytest.raises(TypeError) as excinfo: - m.floats_only(4) - assert msg(excinfo.value) == """ - floats_only(): incompatible function arguments. The following argument types are supported: - 1. (f: float) -> float - - Invoked with: 4 - """ - - assert m.ints_preferred(4) == 2 - assert m.ints_preferred(True) == 0 - with pytest.raises(TypeError) as excinfo: - m.ints_preferred(4.0) - assert msg(excinfo.value) == """ - ints_preferred(): incompatible function arguments. The following argument types are supported: - 1. (i: int) -> int - - Invoked with: 4.0 - """ # noqa: E501 line too long - - assert m.ints_only(4) == 2 - with pytest.raises(TypeError) as excinfo: - m.ints_only(4.0) - assert msg(excinfo.value) == """ - ints_only(): incompatible function arguments. The following argument types are supported: - 1. (i: int) -> int - - Invoked with: 4.0 - """ - - -def test_bad_arg_default(msg): - from pybind11_tests import debug_enabled - - with pytest.raises(RuntimeError) as excinfo: - m.bad_arg_def_named() - assert msg(excinfo.value) == ( - "arg(): could not convert default argument 'a: UnregisteredType' in function " - "'should_fail' into a Python object (type not registered yet?)" - if debug_enabled else - "arg(): could not convert default argument into a Python object (type not registered " - "yet?). Compile in debug mode for more information." - ) - - with pytest.raises(RuntimeError) as excinfo: - m.bad_arg_def_unnamed() - assert msg(excinfo.value) == ( - "arg(): could not convert default argument 'UnregisteredType' in function " - "'should_fail' into a Python object (type not registered yet?)" - if debug_enabled else - "arg(): could not convert default argument into a Python object (type not registered " - "yet?). Compile in debug mode for more information." - ) - - -def test_accepts_none(msg): - a = m.NoneTester() - assert m.no_none1(a) == 42 - assert m.no_none2(a) == 42 - assert m.no_none3(a) == 42 - assert m.no_none4(a) == 42 - assert m.no_none5(a) == 42 - assert m.ok_none1(a) == 42 - assert m.ok_none2(a) == 42 - assert m.ok_none3(a) == 42 - assert m.ok_none4(a) == 42 - assert m.ok_none5(a) == 42 - - with pytest.raises(TypeError) as excinfo: - m.no_none1(None) - assert "incompatible function arguments" in str(excinfo.value) - with pytest.raises(TypeError) as excinfo: - m.no_none2(None) - assert "incompatible function arguments" in str(excinfo.value) - with pytest.raises(TypeError) as excinfo: - m.no_none3(None) - assert "incompatible function arguments" in str(excinfo.value) - with pytest.raises(TypeError) as excinfo: - m.no_none4(None) - assert "incompatible function arguments" in str(excinfo.value) - with pytest.raises(TypeError) as excinfo: - m.no_none5(None) - assert "incompatible function arguments" in str(excinfo.value) - - # The first one still raises because you can't pass None as a lvalue reference arg: - with pytest.raises(TypeError) as excinfo: - assert m.ok_none1(None) == -1 - assert msg(excinfo.value) == """ - ok_none1(): incompatible function arguments. The following argument types are supported: - 1. (arg0: m.methods_and_attributes.NoneTester) -> int - - Invoked with: None - """ - - # The rest take the argument as pointer or holder, and accept None: - assert m.ok_none2(None) == -1 - assert m.ok_none3(None) == -1 - assert m.ok_none4(None) == -1 - assert m.ok_none5(None) == -1 - - -def test_str_issue(msg): - """#283: __str__ called on uninitialized instance when constructor arguments invalid""" - - assert str(m.StrIssue(3)) == "StrIssue[3]" - - with pytest.raises(TypeError) as excinfo: - str(m.StrIssue("no", "such", "constructor")) - assert msg(excinfo.value) == """ - __init__(): incompatible constructor arguments. The following argument types are supported: - 1. m.methods_and_attributes.StrIssue(arg0: int) - 2. m.methods_and_attributes.StrIssue() - - Invoked with: 'no', 'such', 'constructor' - """ - - -def test_unregistered_base_implementations(): - a = m.RegisteredDerived() - a.do_nothing() - assert a.rw_value == 42 - assert a.ro_value == 1.25 - a.rw_value += 5 - assert a.sum() == 48.25 - a.increase_value() - assert a.rw_value == 48 - assert a.ro_value == 1.5 - assert a.sum() == 49.5 - assert a.rw_value_prop == 48 - a.rw_value_prop += 1 - assert a.rw_value_prop == 49 - a.increase_value() - assert a.ro_value_prop == 1.75 - - -def test_custom_caster_destruction(): - """Tests that returning a pointer to a type that gets converted with a custom type caster gets - destroyed when the function has py::return_value_policy::take_ownership policy applied.""" - - cstats = m.destruction_tester_cstats() - # This one *doesn't* have take_ownership: the pointer should be used but not destroyed: - z = m.custom_caster_no_destroy() - assert cstats.alive() == 1 and cstats.default_constructions == 1 - assert z - - # take_ownership applied: this constructs a new object, casts it, then destroys it: - z = m.custom_caster_destroy() - assert z - assert cstats.default_constructions == 2 - - # Same, but with a const pointer return (which should *not* inhibit destruction): - z = m.custom_caster_destroy_const() - assert z - assert cstats.default_constructions == 3 - - # Make sure we still only have the original object (from ..._no_destroy()) alive: - assert cstats.alive() == 1 diff --git a/pybind11/tests/test_modules.cpp b/pybind11/tests/test_modules.cpp deleted file mode 100644 index c1475fa..0000000 --- a/pybind11/tests/test_modules.cpp +++ /dev/null @@ -1,98 +0,0 @@ -/* - tests/test_modules.cpp -- nested modules, importing modules, and - internal references - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" -#include "constructor_stats.h" - -TEST_SUBMODULE(modules, m) { - // test_nested_modules - py::module m_sub = m.def_submodule("subsubmodule"); - m_sub.def("submodule_func", []() { return "submodule_func()"; }); - - // test_reference_internal - class A { - public: - A(int v) : v(v) { print_created(this, v); } - ~A() { print_destroyed(this); } - A(const A&) { print_copy_created(this); } - A& operator=(const A ©) { print_copy_assigned(this); v = copy.v; return *this; } - std::string toString() { return "A[" + std::to_string(v) + "]"; } - private: - int v; - }; - py::class_(m_sub, "A") - .def(py::init()) - .def("__repr__", &A::toString); - - class B { - public: - B() { print_default_created(this); } - ~B() { print_destroyed(this); } - B(const B&) { print_copy_created(this); } - B& operator=(const B ©) { print_copy_assigned(this); a1 = copy.a1; a2 = copy.a2; return *this; } - A &get_a1() { return a1; } - A &get_a2() { return a2; } - - A a1{1}; - A a2{2}; - }; - py::class_(m_sub, "B") - .def(py::init<>()) - .def("get_a1", &B::get_a1, "Return the internal A 1", py::return_value_policy::reference_internal) - .def("get_a2", &B::get_a2, "Return the internal A 2", py::return_value_policy::reference_internal) - .def_readwrite("a1", &B::a1) // def_readonly uses an internal reference return policy by default - .def_readwrite("a2", &B::a2); - - m.attr("OD") = py::module::import("collections").attr("OrderedDict"); - - // test_duplicate_registration - // Registering two things with the same name - m.def("duplicate_registration", []() { - class Dupe1 { }; - class Dupe2 { }; - class Dupe3 { }; - class DupeException { }; - - auto dm = py::module("dummy"); - auto failures = py::list(); - - py::class_(dm, "Dupe1"); - py::class_(dm, "Dupe2"); - dm.def("dupe1_factory", []() { return Dupe1(); }); - py::exception(dm, "DupeException"); - - try { - py::class_(dm, "Dupe1"); - failures.append("Dupe1 class"); - } catch (std::runtime_error &) {} - try { - dm.def("Dupe1", []() { return Dupe1(); }); - failures.append("Dupe1 function"); - } catch (std::runtime_error &) {} - try { - py::class_(dm, "dupe1_factory"); - failures.append("dupe1_factory"); - } catch (std::runtime_error &) {} - try { - py::exception(dm, "Dupe2"); - failures.append("Dupe2"); - } catch (std::runtime_error &) {} - try { - dm.def("DupeException", []() { return 30; }); - failures.append("DupeException1"); - } catch (std::runtime_error &) {} - try { - py::class_(dm, "DupeException"); - failures.append("DupeException2"); - } catch (std::runtime_error &) {} - - return failures; - }); -} diff --git a/pybind11/tests/test_modules.py b/pybind11/tests/test_modules.py deleted file mode 100644 index 2552838..0000000 --- a/pybind11/tests/test_modules.py +++ /dev/null @@ -1,72 +0,0 @@ -from pybind11_tests import modules as m -from pybind11_tests.modules import subsubmodule as ms -from pybind11_tests import ConstructorStats - - -def test_nested_modules(): - import pybind11_tests - assert pybind11_tests.__name__ == "pybind11_tests" - assert pybind11_tests.modules.__name__ == "pybind11_tests.modules" - assert pybind11_tests.modules.subsubmodule.__name__ == "pybind11_tests.modules.subsubmodule" - assert m.__name__ == "pybind11_tests.modules" - assert ms.__name__ == "pybind11_tests.modules.subsubmodule" - - assert ms.submodule_func() == "submodule_func()" - - -def test_reference_internal(): - b = ms.B() - assert str(b.get_a1()) == "A[1]" - assert str(b.a1) == "A[1]" - assert str(b.get_a2()) == "A[2]" - assert str(b.a2) == "A[2]" - - b.a1 = ms.A(42) - b.a2 = ms.A(43) - assert str(b.get_a1()) == "A[42]" - assert str(b.a1) == "A[42]" - assert str(b.get_a2()) == "A[43]" - assert str(b.a2) == "A[43]" - - astats, bstats = ConstructorStats.get(ms.A), ConstructorStats.get(ms.B) - assert astats.alive() == 2 - assert bstats.alive() == 1 - del b - assert astats.alive() == 0 - assert bstats.alive() == 0 - assert astats.values() == ['1', '2', '42', '43'] - assert bstats.values() == [] - assert astats.default_constructions == 0 - assert bstats.default_constructions == 1 - assert astats.copy_constructions == 0 - assert bstats.copy_constructions == 0 - # assert astats.move_constructions >= 0 # Don't invoke any - # assert bstats.move_constructions >= 0 # Don't invoke any - assert astats.copy_assignments == 2 - assert bstats.copy_assignments == 0 - assert astats.move_assignments == 0 - assert bstats.move_assignments == 0 - - -def test_importing(): - from pybind11_tests.modules import OD - from collections import OrderedDict - - assert OD is OrderedDict - assert str(OD([(1, 'a'), (2, 'b')])) == "OrderedDict([(1, 'a'), (2, 'b')])" - - -def test_pydoc(): - """Pydoc needs to be able to provide help() for everything inside a pybind11 module""" - import pybind11_tests - import pydoc - - assert pybind11_tests.__name__ == "pybind11_tests" - assert pybind11_tests.__doc__ == "pybind11 test module" - assert pydoc.text.docmodule(pybind11_tests) - - -def test_duplicate_registration(): - """Registering two things with the same name""" - - assert m.duplicate_registration() == [] diff --git a/pybind11/tests/test_multiple_inheritance.cpp b/pybind11/tests/test_multiple_inheritance.cpp deleted file mode 100644 index ba1674f..0000000 --- a/pybind11/tests/test_multiple_inheritance.cpp +++ /dev/null @@ -1,220 +0,0 @@ -/* - tests/test_multiple_inheritance.cpp -- multiple inheritance, - implicit MI casts - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" -#include "constructor_stats.h" - -// Many bases for testing that multiple inheritance from many classes (i.e. requiring extra -// space for holder constructed flags) works. -template struct BaseN { - BaseN(int i) : i(i) { } - int i; -}; - -// test_mi_static_properties -struct Vanilla { - std::string vanilla() { return "Vanilla"; }; -}; -struct WithStatic1 { - static std::string static_func1() { return "WithStatic1"; }; - static int static_value1; -}; -struct WithStatic2 { - static std::string static_func2() { return "WithStatic2"; }; - static int static_value2; -}; -struct VanillaStaticMix1 : Vanilla, WithStatic1, WithStatic2 { - static std::string static_func() { return "VanillaStaticMix1"; } - static int static_value; -}; -struct VanillaStaticMix2 : WithStatic1, Vanilla, WithStatic2 { - static std::string static_func() { return "VanillaStaticMix2"; } - static int static_value; -}; -int WithStatic1::static_value1 = 1; -int WithStatic2::static_value2 = 2; -int VanillaStaticMix1::static_value = 12; -int VanillaStaticMix2::static_value = 12; - -TEST_SUBMODULE(multiple_inheritance, m) { - - // test_multiple_inheritance_mix1 - // test_multiple_inheritance_mix2 - struct Base1 { - Base1(int i) : i(i) { } - int foo() { return i; } - int i; - }; - py::class_ b1(m, "Base1"); - b1.def(py::init()) - .def("foo", &Base1::foo); - - struct Base2 { - Base2(int i) : i(i) { } - int bar() { return i; } - int i; - }; - py::class_ b2(m, "Base2"); - b2.def(py::init()) - .def("bar", &Base2::bar); - - - // test_multiple_inheritance_cpp - struct Base12 : Base1, Base2 { - Base12(int i, int j) : Base1(i), Base2(j) { } - }; - struct MIType : Base12 { - MIType(int i, int j) : Base12(i, j) { } - }; - py::class_(m, "Base12"); - py::class_(m, "MIType") - .def(py::init()); - - - // test_multiple_inheritance_python_many_bases - #define PYBIND11_BASEN(N) py::class_>(m, "BaseN" #N).def(py::init()).def("f" #N, [](BaseN &b) { return b.i + N; }) - PYBIND11_BASEN( 1); PYBIND11_BASEN( 2); PYBIND11_BASEN( 3); PYBIND11_BASEN( 4); - PYBIND11_BASEN( 5); PYBIND11_BASEN( 6); PYBIND11_BASEN( 7); PYBIND11_BASEN( 8); - PYBIND11_BASEN( 9); PYBIND11_BASEN(10); PYBIND11_BASEN(11); PYBIND11_BASEN(12); - PYBIND11_BASEN(13); PYBIND11_BASEN(14); PYBIND11_BASEN(15); PYBIND11_BASEN(16); - PYBIND11_BASEN(17); - - // Uncommenting this should result in a compile time failure (MI can only be specified via - // template parameters because pybind has to know the types involved; see discussion in #742 for - // details). -// struct Base12v2 : Base1, Base2 { -// Base12v2(int i, int j) : Base1(i), Base2(j) { } -// }; -// py::class_(m, "Base12v2", b1, b2) -// .def(py::init()); - - - // test_multiple_inheritance_virtbase - // Test the case where not all base classes are specified, and where pybind11 requires the - // py::multiple_inheritance flag to perform proper casting between types. - struct Base1a { - Base1a(int i) : i(i) { } - int foo() { return i; } - int i; - }; - py::class_>(m, "Base1a") - .def(py::init()) - .def("foo", &Base1a::foo); - - struct Base2a { - Base2a(int i) : i(i) { } - int bar() { return i; } - int i; - }; - py::class_>(m, "Base2a") - .def(py::init()) - .def("bar", &Base2a::bar); - - struct Base12a : Base1a, Base2a { - Base12a(int i, int j) : Base1a(i), Base2a(j) { } - }; - py::class_>(m, "Base12a", py::multiple_inheritance()) - .def(py::init()); - - m.def("bar_base2a", [](Base2a *b) { return b->bar(); }); - m.def("bar_base2a_sharedptr", [](std::shared_ptr b) { return b->bar(); }); - - // test_mi_unaligned_base - // test_mi_base_return - // Issue #801: invalid casting to derived type with MI bases - struct I801B1 { int a = 1; I801B1() = default; I801B1(const I801B1 &) = default; virtual ~I801B1() = default; }; - struct I801B2 { int b = 2; I801B2() = default; I801B2(const I801B2 &) = default; virtual ~I801B2() = default; }; - struct I801C : I801B1, I801B2 {}; - struct I801D : I801C {}; // Indirect MI - // Unregistered classes: - struct I801B3 { int c = 3; virtual ~I801B3() = default; }; - struct I801E : I801B3, I801D {}; - - py::class_>(m, "I801B1").def(py::init<>()).def_readonly("a", &I801B1::a); - py::class_>(m, "I801B2").def(py::init<>()).def_readonly("b", &I801B2::b); - py::class_>(m, "I801C").def(py::init<>()); - py::class_>(m, "I801D").def(py::init<>()); - - // Two separate issues here: first, we want to recognize a pointer to a base type as being a - // known instance even when the pointer value is unequal (i.e. due to a non-first - // multiple-inheritance base class): - m.def("i801b1_c", [](I801C *c) { return static_cast(c); }); - m.def("i801b2_c", [](I801C *c) { return static_cast(c); }); - m.def("i801b1_d", [](I801D *d) { return static_cast(d); }); - m.def("i801b2_d", [](I801D *d) { return static_cast(d); }); - - // Second, when returned a base class pointer to a derived instance, we cannot assume that the - // pointer is `reinterpret_cast`able to the derived pointer because, like above, the base class - // pointer could be offset. - m.def("i801c_b1", []() -> I801B1 * { return new I801C(); }); - m.def("i801c_b2", []() -> I801B2 * { return new I801C(); }); - m.def("i801d_b1", []() -> I801B1 * { return new I801D(); }); - m.def("i801d_b2", []() -> I801B2 * { return new I801D(); }); - - // Return a base class pointer to a pybind-registered type when the actual derived type - // isn't pybind-registered (and uses multiple-inheritance to offset the pybind base) - m.def("i801e_c", []() -> I801C * { return new I801E(); }); - m.def("i801e_b2", []() -> I801B2 * { return new I801E(); }); - - - // test_mi_static_properties - py::class_(m, "Vanilla") - .def(py::init<>()) - .def("vanilla", &Vanilla::vanilla); - - py::class_(m, "WithStatic1") - .def(py::init<>()) - .def_static("static_func1", &WithStatic1::static_func1) - .def_readwrite_static("static_value1", &WithStatic1::static_value1); - - py::class_(m, "WithStatic2") - .def(py::init<>()) - .def_static("static_func2", &WithStatic2::static_func2) - .def_readwrite_static("static_value2", &WithStatic2::static_value2); - - py::class_( - m, "VanillaStaticMix1") - .def(py::init<>()) - .def_static("static_func", &VanillaStaticMix1::static_func) - .def_readwrite_static("static_value", &VanillaStaticMix1::static_value); - - py::class_( - m, "VanillaStaticMix2") - .def(py::init<>()) - .def_static("static_func", &VanillaStaticMix2::static_func) - .def_readwrite_static("static_value", &VanillaStaticMix2::static_value); - - -#if !defined(PYPY_VERSION) - struct WithDict { }; - struct VanillaDictMix1 : Vanilla, WithDict { }; - struct VanillaDictMix2 : WithDict, Vanilla { }; - py::class_(m, "WithDict", py::dynamic_attr()).def(py::init<>()); - py::class_(m, "VanillaDictMix1").def(py::init<>()); - py::class_(m, "VanillaDictMix2").def(py::init<>()); -#endif - - // test_diamond_inheritance - // Issue #959: segfault when constructing diamond inheritance instance - // All of these have int members so that there will be various unequal pointers involved. - struct B { int b; B() = default; B(const B&) = default; virtual ~B() = default; }; - struct C0 : public virtual B { int c0; }; - struct C1 : public virtual B { int c1; }; - struct D : public C0, public C1 { int d; }; - py::class_(m, "B") - .def("b", [](B *self) { return self; }); - py::class_(m, "C0") - .def("c0", [](C0 *self) { return self; }); - py::class_(m, "C1") - .def("c1", [](C1 *self) { return self; }); - py::class_(m, "D") - .def(py::init<>()); -} diff --git a/pybind11/tests/test_multiple_inheritance.py b/pybind11/tests/test_multiple_inheritance.py deleted file mode 100644 index 475dd3b..0000000 --- a/pybind11/tests/test_multiple_inheritance.py +++ /dev/null @@ -1,349 +0,0 @@ -import pytest -from pybind11_tests import ConstructorStats -from pybind11_tests import multiple_inheritance as m - - -def test_multiple_inheritance_cpp(): - mt = m.MIType(3, 4) - - assert mt.foo() == 3 - assert mt.bar() == 4 - - -def test_multiple_inheritance_mix1(): - class Base1: - def __init__(self, i): - self.i = i - - def foo(self): - return self.i - - class MITypePy(Base1, m.Base2): - def __init__(self, i, j): - Base1.__init__(self, i) - m.Base2.__init__(self, j) - - mt = MITypePy(3, 4) - - assert mt.foo() == 3 - assert mt.bar() == 4 - - -def test_multiple_inheritance_mix2(): - - class Base2: - def __init__(self, i): - self.i = i - - def bar(self): - return self.i - - class MITypePy(m.Base1, Base2): - def __init__(self, i, j): - m.Base1.__init__(self, i) - Base2.__init__(self, j) - - mt = MITypePy(3, 4) - - assert mt.foo() == 3 - assert mt.bar() == 4 - - -def test_multiple_inheritance_python(): - - class MI1(m.Base1, m.Base2): - def __init__(self, i, j): - m.Base1.__init__(self, i) - m.Base2.__init__(self, j) - - class B1(object): - def v(self): - return 1 - - class MI2(B1, m.Base1, m.Base2): - def __init__(self, i, j): - B1.__init__(self) - m.Base1.__init__(self, i) - m.Base2.__init__(self, j) - - class MI3(MI2): - def __init__(self, i, j): - MI2.__init__(self, i, j) - - class MI4(MI3, m.Base2): - def __init__(self, i, j): - MI3.__init__(self, i, j) - # This should be ignored (Base2 is already initialized via MI2): - m.Base2.__init__(self, i + 100) - - class MI5(m.Base2, B1, m.Base1): - def __init__(self, i, j): - B1.__init__(self) - m.Base1.__init__(self, i) - m.Base2.__init__(self, j) - - class MI6(m.Base2, B1): - def __init__(self, i): - m.Base2.__init__(self, i) - B1.__init__(self) - - class B2(B1): - def v(self): - return 2 - - class B3(object): - def v(self): - return 3 - - class B4(B3, B2): - def v(self): - return 4 - - class MI7(B4, MI6): - def __init__(self, i): - B4.__init__(self) - MI6.__init__(self, i) - - class MI8(MI6, B3): - def __init__(self, i): - MI6.__init__(self, i) - B3.__init__(self) - - class MI8b(B3, MI6): - def __init__(self, i): - B3.__init__(self) - MI6.__init__(self, i) - - mi1 = MI1(1, 2) - assert mi1.foo() == 1 - assert mi1.bar() == 2 - - mi2 = MI2(3, 4) - assert mi2.v() == 1 - assert mi2.foo() == 3 - assert mi2.bar() == 4 - - mi3 = MI3(5, 6) - assert mi3.v() == 1 - assert mi3.foo() == 5 - assert mi3.bar() == 6 - - mi4 = MI4(7, 8) - assert mi4.v() == 1 - assert mi4.foo() == 7 - assert mi4.bar() == 8 - - mi5 = MI5(10, 11) - assert mi5.v() == 1 - assert mi5.foo() == 10 - assert mi5.bar() == 11 - - mi6 = MI6(12) - assert mi6.v() == 1 - assert mi6.bar() == 12 - - mi7 = MI7(13) - assert mi7.v() == 4 - assert mi7.bar() == 13 - - mi8 = MI8(14) - assert mi8.v() == 1 - assert mi8.bar() == 14 - - mi8b = MI8b(15) - assert mi8b.v() == 3 - assert mi8b.bar() == 15 - - -def test_multiple_inheritance_python_many_bases(): - - class MIMany14(m.BaseN1, m.BaseN2, m.BaseN3, m.BaseN4): - def __init__(self): - m.BaseN1.__init__(self, 1) - m.BaseN2.__init__(self, 2) - m.BaseN3.__init__(self, 3) - m.BaseN4.__init__(self, 4) - - class MIMany58(m.BaseN5, m.BaseN6, m.BaseN7, m.BaseN8): - def __init__(self): - m.BaseN5.__init__(self, 5) - m.BaseN6.__init__(self, 6) - m.BaseN7.__init__(self, 7) - m.BaseN8.__init__(self, 8) - - class MIMany916(m.BaseN9, m.BaseN10, m.BaseN11, m.BaseN12, m.BaseN13, m.BaseN14, m.BaseN15, - m.BaseN16): - def __init__(self): - m.BaseN9.__init__(self, 9) - m.BaseN10.__init__(self, 10) - m.BaseN11.__init__(self, 11) - m.BaseN12.__init__(self, 12) - m.BaseN13.__init__(self, 13) - m.BaseN14.__init__(self, 14) - m.BaseN15.__init__(self, 15) - m.BaseN16.__init__(self, 16) - - class MIMany19(MIMany14, MIMany58, m.BaseN9): - def __init__(self): - MIMany14.__init__(self) - MIMany58.__init__(self) - m.BaseN9.__init__(self, 9) - - class MIMany117(MIMany14, MIMany58, MIMany916, m.BaseN17): - def __init__(self): - MIMany14.__init__(self) - MIMany58.__init__(self) - MIMany916.__init__(self) - m.BaseN17.__init__(self, 17) - - # Inherits from 4 registered C++ classes: can fit in one pointer on any modern arch: - a = MIMany14() - for i in range(1, 4): - assert getattr(a, "f" + str(i))() == 2 * i - - # Inherits from 8: requires 1/2 pointers worth of holder flags on 32/64-bit arch: - b = MIMany916() - for i in range(9, 16): - assert getattr(b, "f" + str(i))() == 2 * i - - # Inherits from 9: requires >= 2 pointers worth of holder flags - c = MIMany19() - for i in range(1, 9): - assert getattr(c, "f" + str(i))() == 2 * i - - # Inherits from 17: requires >= 3 pointers worth of holder flags - d = MIMany117() - for i in range(1, 17): - assert getattr(d, "f" + str(i))() == 2 * i - - -def test_multiple_inheritance_virtbase(): - - class MITypePy(m.Base12a): - def __init__(self, i, j): - m.Base12a.__init__(self, i, j) - - mt = MITypePy(3, 4) - assert mt.bar() == 4 - assert m.bar_base2a(mt) == 4 - assert m.bar_base2a_sharedptr(mt) == 4 - - -def test_mi_static_properties(): - """Mixing bases with and without static properties should be possible - and the result should be independent of base definition order""" - - for d in (m.VanillaStaticMix1(), m.VanillaStaticMix2()): - assert d.vanilla() == "Vanilla" - assert d.static_func1() == "WithStatic1" - assert d.static_func2() == "WithStatic2" - assert d.static_func() == d.__class__.__name__ - - m.WithStatic1.static_value1 = 1 - m.WithStatic2.static_value2 = 2 - assert d.static_value1 == 1 - assert d.static_value2 == 2 - assert d.static_value == 12 - - d.static_value1 = 0 - assert d.static_value1 == 0 - d.static_value2 = 0 - assert d.static_value2 == 0 - d.static_value = 0 - assert d.static_value == 0 - - -@pytest.unsupported_on_pypy -def test_mi_dynamic_attributes(): - """Mixing bases with and without dynamic attribute support""" - - for d in (m.VanillaDictMix1(), m.VanillaDictMix2()): - d.dynamic = 1 - assert d.dynamic == 1 - - -def test_mi_unaligned_base(): - """Returning an offset (non-first MI) base class pointer should recognize the instance""" - - n_inst = ConstructorStats.detail_reg_inst() - - c = m.I801C() - d = m.I801D() - # + 4 below because we have the two instances, and each instance has offset base I801B2 - assert ConstructorStats.detail_reg_inst() == n_inst + 4 - b1c = m.i801b1_c(c) - assert b1c is c - b2c = m.i801b2_c(c) - assert b2c is c - b1d = m.i801b1_d(d) - assert b1d is d - b2d = m.i801b2_d(d) - assert b2d is d - - assert ConstructorStats.detail_reg_inst() == n_inst + 4 # no extra instances - del c, b1c, b2c - assert ConstructorStats.detail_reg_inst() == n_inst + 2 - del d, b1d, b2d - assert ConstructorStats.detail_reg_inst() == n_inst - - -def test_mi_base_return(): - """Tests returning an offset (non-first MI) base class pointer to a derived instance""" - - n_inst = ConstructorStats.detail_reg_inst() - - c1 = m.i801c_b1() - assert type(c1) is m.I801C - assert c1.a == 1 - assert c1.b == 2 - - d1 = m.i801d_b1() - assert type(d1) is m.I801D - assert d1.a == 1 - assert d1.b == 2 - - assert ConstructorStats.detail_reg_inst() == n_inst + 4 - - c2 = m.i801c_b2() - assert type(c2) is m.I801C - assert c2.a == 1 - assert c2.b == 2 - - d2 = m.i801d_b2() - assert type(d2) is m.I801D - assert d2.a == 1 - assert d2.b == 2 - - assert ConstructorStats.detail_reg_inst() == n_inst + 8 - - del c2 - assert ConstructorStats.detail_reg_inst() == n_inst + 6 - del c1, d1, d2 - assert ConstructorStats.detail_reg_inst() == n_inst - - # Returning an unregistered derived type with a registered base; we won't - # pick up the derived type, obviously, but should still work (as an object - # of whatever type was returned). - e1 = m.i801e_c() - assert type(e1) is m.I801C - assert e1.a == 1 - assert e1.b == 2 - - e2 = m.i801e_b2() - assert type(e2) is m.I801B2 - assert e2.b == 2 - - -def test_diamond_inheritance(): - """Tests that diamond inheritance works as expected (issue #959)""" - - # Issue #959: this shouldn't segfault: - d = m.D() - - # Make sure all the various distinct pointers are all recognized as registered instances: - assert d is d.c0() - assert d is d.c1() - assert d is d.b() - assert d is d.c0().b() - assert d is d.c1().b() - assert d is d.c0().c1().b().c0().b() diff --git a/pybind11/tests/test_numpy_array.cpp b/pybind11/tests/test_numpy_array.cpp deleted file mode 100644 index 156a3bf..0000000 --- a/pybind11/tests/test_numpy_array.cpp +++ /dev/null @@ -1,390 +0,0 @@ -/* - tests/test_numpy_array.cpp -- test core array functionality - - Copyright (c) 2016 Ivan Smirnov - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" - -#include -#include - -#include - -// Size / dtype checks. -struct DtypeCheck { - py::dtype numpy{}; - py::dtype pybind11{}; -}; - -template -DtypeCheck get_dtype_check(const char* name) { - py::module np = py::module::import("numpy"); - DtypeCheck check{}; - check.numpy = np.attr("dtype")(np.attr(name)); - check.pybind11 = py::dtype::of(); - return check; -} - -std::vector get_concrete_dtype_checks() { - return { - // Normalization - get_dtype_check("int8"), - get_dtype_check("uint8"), - get_dtype_check("int16"), - get_dtype_check("uint16"), - get_dtype_check("int32"), - get_dtype_check("uint32"), - get_dtype_check("int64"), - get_dtype_check("uint64") - }; -} - -struct DtypeSizeCheck { - std::string name{}; - int size_cpp{}; - int size_numpy{}; - // For debugging. - py::dtype dtype{}; -}; - -template -DtypeSizeCheck get_dtype_size_check() { - DtypeSizeCheck check{}; - check.name = py::type_id(); - check.size_cpp = sizeof(T); - check.dtype = py::dtype::of(); - check.size_numpy = check.dtype.attr("itemsize").template cast(); - return check; -} - -std::vector get_platform_dtype_size_checks() { - return { - get_dtype_size_check(), - get_dtype_size_check(), - get_dtype_size_check(), - get_dtype_size_check(), - get_dtype_size_check(), - get_dtype_size_check(), - get_dtype_size_check(), - get_dtype_size_check(), - }; -} - -// Arrays. -using arr = py::array; -using arr_t = py::array_t; -static_assert(std::is_same::value, ""); - -template arr data(const arr& a, Ix... index) { - return arr(a.nbytes() - a.offset_at(index...), (const uint8_t *) a.data(index...)); -} - -template arr data_t(const arr_t& a, Ix... index) { - return arr(a.size() - a.index_at(index...), a.data(index...)); -} - -template arr& mutate_data(arr& a, Ix... index) { - auto ptr = (uint8_t *) a.mutable_data(index...); - for (ssize_t i = 0; i < a.nbytes() - a.offset_at(index...); i++) - ptr[i] = (uint8_t) (ptr[i] * 2); - return a; -} - -template arr_t& mutate_data_t(arr_t& a, Ix... index) { - auto ptr = a.mutable_data(index...); - for (ssize_t i = 0; i < a.size() - a.index_at(index...); i++) - ptr[i]++; - return a; -} - -template ssize_t index_at(const arr& a, Ix... idx) { return a.index_at(idx...); } -template ssize_t index_at_t(const arr_t& a, Ix... idx) { return a.index_at(idx...); } -template ssize_t offset_at(const arr& a, Ix... idx) { return a.offset_at(idx...); } -template ssize_t offset_at_t(const arr_t& a, Ix... idx) { return a.offset_at(idx...); } -template ssize_t at_t(const arr_t& a, Ix... idx) { return a.at(idx...); } -template arr_t& mutate_at_t(arr_t& a, Ix... idx) { a.mutable_at(idx...)++; return a; } - -#define def_index_fn(name, type) \ - sm.def(#name, [](type a) { return name(a); }); \ - sm.def(#name, [](type a, int i) { return name(a, i); }); \ - sm.def(#name, [](type a, int i, int j) { return name(a, i, j); }); \ - sm.def(#name, [](type a, int i, int j, int k) { return name(a, i, j, k); }); - -template py::handle auxiliaries(T &&r, T2 &&r2) { - if (r.ndim() != 2) throw std::domain_error("error: ndim != 2"); - py::list l; - l.append(*r.data(0, 0)); - l.append(*r2.mutable_data(0, 0)); - l.append(r.data(0, 1) == r2.mutable_data(0, 1)); - l.append(r.ndim()); - l.append(r.itemsize()); - l.append(r.shape(0)); - l.append(r.shape(1)); - l.append(r.size()); - l.append(r.nbytes()); - return l.release(); -} - -// note: declaration at local scope would create a dangling reference! -static int data_i = 42; - -TEST_SUBMODULE(numpy_array, sm) { - try { py::module::import("numpy"); } - catch (...) { return; } - - // test_dtypes - py::class_(sm, "DtypeCheck") - .def_readonly("numpy", &DtypeCheck::numpy) - .def_readonly("pybind11", &DtypeCheck::pybind11) - .def("__repr__", [](const DtypeCheck& self) { - return py::str("").format( - self.numpy, self.pybind11); - }); - sm.def("get_concrete_dtype_checks", &get_concrete_dtype_checks); - - py::class_(sm, "DtypeSizeCheck") - .def_readonly("name", &DtypeSizeCheck::name) - .def_readonly("size_cpp", &DtypeSizeCheck::size_cpp) - .def_readonly("size_numpy", &DtypeSizeCheck::size_numpy) - .def("__repr__", [](const DtypeSizeCheck& self) { - return py::str("").format( - self.name, self.size_cpp, self.size_numpy, self.dtype); - }); - sm.def("get_platform_dtype_size_checks", &get_platform_dtype_size_checks); - - // test_array_attributes - sm.def("ndim", [](const arr& a) { return a.ndim(); }); - sm.def("shape", [](const arr& a) { return arr(a.ndim(), a.shape()); }); - sm.def("shape", [](const arr& a, ssize_t dim) { return a.shape(dim); }); - sm.def("strides", [](const arr& a) { return arr(a.ndim(), a.strides()); }); - sm.def("strides", [](const arr& a, ssize_t dim) { return a.strides(dim); }); - sm.def("writeable", [](const arr& a) { return a.writeable(); }); - sm.def("size", [](const arr& a) { return a.size(); }); - sm.def("itemsize", [](const arr& a) { return a.itemsize(); }); - sm.def("nbytes", [](const arr& a) { return a.nbytes(); }); - sm.def("owndata", [](const arr& a) { return a.owndata(); }); - - // test_index_offset - def_index_fn(index_at, const arr&); - def_index_fn(index_at_t, const arr_t&); - def_index_fn(offset_at, const arr&); - def_index_fn(offset_at_t, const arr_t&); - // test_data - def_index_fn(data, const arr&); - def_index_fn(data_t, const arr_t&); - // test_mutate_data, test_mutate_readonly - def_index_fn(mutate_data, arr&); - def_index_fn(mutate_data_t, arr_t&); - def_index_fn(at_t, const arr_t&); - def_index_fn(mutate_at_t, arr_t&); - - // test_make_c_f_array - sm.def("make_f_array", [] { return py::array_t({ 2, 2 }, { 4, 8 }); }); - sm.def("make_c_array", [] { return py::array_t({ 2, 2 }, { 8, 4 }); }); - - // test_empty_shaped_array - sm.def("make_empty_shaped_array", [] { return py::array(py::dtype("f"), {}, {}); }); - // test numpy scalars (empty shape, ndim==0) - sm.def("scalar_int", []() { return py::array(py::dtype("i"), {}, {}, &data_i); }); - - // test_wrap - sm.def("wrap", [](py::array a) { - return py::array( - a.dtype(), - {a.shape(), a.shape() + a.ndim()}, - {a.strides(), a.strides() + a.ndim()}, - a.data(), - a - ); - }); - - // test_numpy_view - struct ArrayClass { - int data[2] = { 1, 2 }; - ArrayClass() { py::print("ArrayClass()"); } - ~ArrayClass() { py::print("~ArrayClass()"); } - }; - py::class_(sm, "ArrayClass") - .def(py::init<>()) - .def("numpy_view", [](py::object &obj) { - py::print("ArrayClass::numpy_view()"); - ArrayClass &a = obj.cast(); - return py::array_t({2}, {4}, a.data, obj); - } - ); - - // test_cast_numpy_int64_to_uint64 - sm.def("function_taking_uint64", [](uint64_t) { }); - - // test_isinstance - sm.def("isinstance_untyped", [](py::object yes, py::object no) { - return py::isinstance(yes) && !py::isinstance(no); - }); - sm.def("isinstance_typed", [](py::object o) { - return py::isinstance>(o) && !py::isinstance>(o); - }); - - // test_constructors - sm.def("default_constructors", []() { - return py::dict( - "array"_a=py::array(), - "array_t"_a=py::array_t(), - "array_t"_a=py::array_t() - ); - }); - sm.def("converting_constructors", [](py::object o) { - return py::dict( - "array"_a=py::array(o), - "array_t"_a=py::array_t(o), - "array_t"_a=py::array_t(o) - ); - }); - - // test_overload_resolution - sm.def("overloaded", [](py::array_t) { return "double"; }); - sm.def("overloaded", [](py::array_t) { return "float"; }); - sm.def("overloaded", [](py::array_t) { return "int"; }); - sm.def("overloaded", [](py::array_t) { return "unsigned short"; }); - sm.def("overloaded", [](py::array_t) { return "long long"; }); - sm.def("overloaded", [](py::array_t>) { return "double complex"; }); - sm.def("overloaded", [](py::array_t>) { return "float complex"; }); - - sm.def("overloaded2", [](py::array_t>) { return "double complex"; }); - sm.def("overloaded2", [](py::array_t) { return "double"; }); - sm.def("overloaded2", [](py::array_t>) { return "float complex"; }); - sm.def("overloaded2", [](py::array_t) { return "float"; }); - - // Only accept the exact types: - sm.def("overloaded3", [](py::array_t) { return "int"; }, py::arg().noconvert()); - sm.def("overloaded3", [](py::array_t) { return "double"; }, py::arg().noconvert()); - - // Make sure we don't do unsafe coercion (e.g. float to int) when not using forcecast, but - // rather that float gets converted via the safe (conversion to double) overload: - sm.def("overloaded4", [](py::array_t) { return "long long"; }); - sm.def("overloaded4", [](py::array_t) { return "double"; }); - - // But we do allow conversion to int if forcecast is enabled (but only if no overload matches - // without conversion) - sm.def("overloaded5", [](py::array_t) { return "unsigned int"; }); - sm.def("overloaded5", [](py::array_t) { return "double"; }); - - // test_greedy_string_overload - // Issue 685: ndarray shouldn't go to std::string overload - sm.def("issue685", [](std::string) { return "string"; }); - sm.def("issue685", [](py::array) { return "array"; }); - sm.def("issue685", [](py::object) { return "other"; }); - - // test_array_unchecked_fixed_dims - sm.def("proxy_add2", [](py::array_t a, double v) { - auto r = a.mutable_unchecked<2>(); - for (ssize_t i = 0; i < r.shape(0); i++) - for (ssize_t j = 0; j < r.shape(1); j++) - r(i, j) += v; - }, py::arg().noconvert(), py::arg()); - - sm.def("proxy_init3", [](double start) { - py::array_t a({ 3, 3, 3 }); - auto r = a.mutable_unchecked<3>(); - for (ssize_t i = 0; i < r.shape(0); i++) - for (ssize_t j = 0; j < r.shape(1); j++) - for (ssize_t k = 0; k < r.shape(2); k++) - r(i, j, k) = start++; - return a; - }); - sm.def("proxy_init3F", [](double start) { - py::array_t a({ 3, 3, 3 }); - auto r = a.mutable_unchecked<3>(); - for (ssize_t k = 0; k < r.shape(2); k++) - for (ssize_t j = 0; j < r.shape(1); j++) - for (ssize_t i = 0; i < r.shape(0); i++) - r(i, j, k) = start++; - return a; - }); - sm.def("proxy_squared_L2_norm", [](py::array_t a) { - auto r = a.unchecked<1>(); - double sumsq = 0; - for (ssize_t i = 0; i < r.shape(0); i++) - sumsq += r[i] * r(i); // Either notation works for a 1D array - return sumsq; - }); - - sm.def("proxy_auxiliaries2", [](py::array_t a) { - auto r = a.unchecked<2>(); - auto r2 = a.mutable_unchecked<2>(); - return auxiliaries(r, r2); - }); - - // test_array_unchecked_dyn_dims - // Same as the above, but without a compile-time dimensions specification: - sm.def("proxy_add2_dyn", [](py::array_t a, double v) { - auto r = a.mutable_unchecked(); - if (r.ndim() != 2) throw std::domain_error("error: ndim != 2"); - for (ssize_t i = 0; i < r.shape(0); i++) - for (ssize_t j = 0; j < r.shape(1); j++) - r(i, j) += v; - }, py::arg().noconvert(), py::arg()); - sm.def("proxy_init3_dyn", [](double start) { - py::array_t a({ 3, 3, 3 }); - auto r = a.mutable_unchecked(); - if (r.ndim() != 3) throw std::domain_error("error: ndim != 3"); - for (ssize_t i = 0; i < r.shape(0); i++) - for (ssize_t j = 0; j < r.shape(1); j++) - for (ssize_t k = 0; k < r.shape(2); k++) - r(i, j, k) = start++; - return a; - }); - sm.def("proxy_auxiliaries2_dyn", [](py::array_t a) { - return auxiliaries(a.unchecked(), a.mutable_unchecked()); - }); - - sm.def("array_auxiliaries2", [](py::array_t a) { - return auxiliaries(a, a); - }); - - // test_array_failures - // Issue #785: Uninformative "Unknown internal error" exception when constructing array from empty object: - sm.def("array_fail_test", []() { return py::array(py::object()); }); - sm.def("array_t_fail_test", []() { return py::array_t(py::object()); }); - // Make sure the error from numpy is being passed through: - sm.def("array_fail_test_negative_size", []() { int c = 0; return py::array(-1, &c); }); - - // test_initializer_list - // Issue (unnumbered; reported in #788): regression: initializer lists can be ambiguous - sm.def("array_initializer_list1", []() { return py::array_t(1); }); // { 1 } also works, but clang warns about it - sm.def("array_initializer_list2", []() { return py::array_t({ 1, 2 }); }); - sm.def("array_initializer_list3", []() { return py::array_t({ 1, 2, 3 }); }); - sm.def("array_initializer_list4", []() { return py::array_t({ 1, 2, 3, 4 }); }); - - // test_array_resize - // reshape array to 2D without changing size - sm.def("array_reshape2", [](py::array_t a) { - const ssize_t dim_sz = (ssize_t)std::sqrt(a.size()); - if (dim_sz * dim_sz != a.size()) - throw std::domain_error("array_reshape2: input array total size is not a squared integer"); - a.resize({dim_sz, dim_sz}); - }); - - // resize to 3D array with each dimension = N - sm.def("array_resize3", [](py::array_t a, size_t N, bool refcheck) { - a.resize({N, N, N}, refcheck); - }); - - // test_array_create_and_resize - // return 2D array with Nrows = Ncols = N - sm.def("create_and_resize", [](size_t N) { - py::array_t a; - a.resize({N, N}); - std::fill(a.mutable_data(), a.mutable_data() + a.size(), 42.); - return a; - }); - -#if PY_MAJOR_VERSION >= 3 - sm.def("index_using_ellipsis", [](py::array a) { - return a[py::make_tuple(0, py::ellipsis(), 0)]; - }); -#endif -} diff --git a/pybind11/tests/test_numpy_array.py b/pybind11/tests/test_numpy_array.py deleted file mode 100644 index d0a6324..0000000 --- a/pybind11/tests/test_numpy_array.py +++ /dev/null @@ -1,447 +0,0 @@ -import pytest -from pybind11_tests import numpy_array as m - -pytestmark = pytest.requires_numpy - -with pytest.suppress(ImportError): - import numpy as np - - -def test_dtypes(): - # See issue #1328. - # - Platform-dependent sizes. - for size_check in m.get_platform_dtype_size_checks(): - print(size_check) - assert size_check.size_cpp == size_check.size_numpy, size_check - # - Concrete sizes. - for check in m.get_concrete_dtype_checks(): - print(check) - assert check.numpy == check.pybind11, check - if check.numpy.num != check.pybind11.num: - print("NOTE: typenum mismatch for {}: {} != {}".format( - check, check.numpy.num, check.pybind11.num)) - - -@pytest.fixture(scope='function') -def arr(): - return np.array([[1, 2, 3], [4, 5, 6]], '=u2') - - -def test_array_attributes(): - a = np.array(0, 'f8') - assert m.ndim(a) == 0 - assert all(m.shape(a) == []) - assert all(m.strides(a) == []) - with pytest.raises(IndexError) as excinfo: - m.shape(a, 0) - assert str(excinfo.value) == 'invalid axis: 0 (ndim = 0)' - with pytest.raises(IndexError) as excinfo: - m.strides(a, 0) - assert str(excinfo.value) == 'invalid axis: 0 (ndim = 0)' - assert m.writeable(a) - assert m.size(a) == 1 - assert m.itemsize(a) == 8 - assert m.nbytes(a) == 8 - assert m.owndata(a) - - a = np.array([[1, 2, 3], [4, 5, 6]], 'u2').view() - a.flags.writeable = False - assert m.ndim(a) == 2 - assert all(m.shape(a) == [2, 3]) - assert m.shape(a, 0) == 2 - assert m.shape(a, 1) == 3 - assert all(m.strides(a) == [6, 2]) - assert m.strides(a, 0) == 6 - assert m.strides(a, 1) == 2 - with pytest.raises(IndexError) as excinfo: - m.shape(a, 2) - assert str(excinfo.value) == 'invalid axis: 2 (ndim = 2)' - with pytest.raises(IndexError) as excinfo: - m.strides(a, 2) - assert str(excinfo.value) == 'invalid axis: 2 (ndim = 2)' - assert not m.writeable(a) - assert m.size(a) == 6 - assert m.itemsize(a) == 2 - assert m.nbytes(a) == 12 - assert not m.owndata(a) - - -@pytest.mark.parametrize('args, ret', [([], 0), ([0], 0), ([1], 3), ([0, 1], 1), ([1, 2], 5)]) -def test_index_offset(arr, args, ret): - assert m.index_at(arr, *args) == ret - assert m.index_at_t(arr, *args) == ret - assert m.offset_at(arr, *args) == ret * arr.dtype.itemsize - assert m.offset_at_t(arr, *args) == ret * arr.dtype.itemsize - - -def test_dim_check_fail(arr): - for func in (m.index_at, m.index_at_t, m.offset_at, m.offset_at_t, m.data, m.data_t, - m.mutate_data, m.mutate_data_t): - with pytest.raises(IndexError) as excinfo: - func(arr, 1, 2, 3) - assert str(excinfo.value) == 'too many indices for an array: 3 (ndim = 2)' - - -@pytest.mark.parametrize('args, ret', - [([], [1, 2, 3, 4, 5, 6]), - ([1], [4, 5, 6]), - ([0, 1], [2, 3, 4, 5, 6]), - ([1, 2], [6])]) -def test_data(arr, args, ret): - from sys import byteorder - assert all(m.data_t(arr, *args) == ret) - assert all(m.data(arr, *args)[(0 if byteorder == 'little' else 1)::2] == ret) - assert all(m.data(arr, *args)[(1 if byteorder == 'little' else 0)::2] == 0) - - -@pytest.mark.parametrize('dim', [0, 1, 3]) -def test_at_fail(arr, dim): - for func in m.at_t, m.mutate_at_t: - with pytest.raises(IndexError) as excinfo: - func(arr, *([0] * dim)) - assert str(excinfo.value) == 'index dimension mismatch: {} (ndim = 2)'.format(dim) - - -def test_at(arr): - assert m.at_t(arr, 0, 2) == 3 - assert m.at_t(arr, 1, 0) == 4 - - assert all(m.mutate_at_t(arr, 0, 2).ravel() == [1, 2, 4, 4, 5, 6]) - assert all(m.mutate_at_t(arr, 1, 0).ravel() == [1, 2, 4, 5, 5, 6]) - - -def test_mutate_readonly(arr): - arr.flags.writeable = False - for func, args in (m.mutate_data, ()), (m.mutate_data_t, ()), (m.mutate_at_t, (0, 0)): - with pytest.raises(ValueError) as excinfo: - func(arr, *args) - assert str(excinfo.value) == 'array is not writeable' - - -def test_mutate_data(arr): - assert all(m.mutate_data(arr).ravel() == [2, 4, 6, 8, 10, 12]) - assert all(m.mutate_data(arr).ravel() == [4, 8, 12, 16, 20, 24]) - assert all(m.mutate_data(arr, 1).ravel() == [4, 8, 12, 32, 40, 48]) - assert all(m.mutate_data(arr, 0, 1).ravel() == [4, 16, 24, 64, 80, 96]) - assert all(m.mutate_data(arr, 1, 2).ravel() == [4, 16, 24, 64, 80, 192]) - - assert all(m.mutate_data_t(arr).ravel() == [5, 17, 25, 65, 81, 193]) - assert all(m.mutate_data_t(arr).ravel() == [6, 18, 26, 66, 82, 194]) - assert all(m.mutate_data_t(arr, 1).ravel() == [6, 18, 26, 67, 83, 195]) - assert all(m.mutate_data_t(arr, 0, 1).ravel() == [6, 19, 27, 68, 84, 196]) - assert all(m.mutate_data_t(arr, 1, 2).ravel() == [6, 19, 27, 68, 84, 197]) - - -def test_bounds_check(arr): - for func in (m.index_at, m.index_at_t, m.data, m.data_t, - m.mutate_data, m.mutate_data_t, m.at_t, m.mutate_at_t): - with pytest.raises(IndexError) as excinfo: - func(arr, 2, 0) - assert str(excinfo.value) == 'index 2 is out of bounds for axis 0 with size 2' - with pytest.raises(IndexError) as excinfo: - func(arr, 0, 4) - assert str(excinfo.value) == 'index 4 is out of bounds for axis 1 with size 3' - - -def test_make_c_f_array(): - assert m.make_c_array().flags.c_contiguous - assert not m.make_c_array().flags.f_contiguous - assert m.make_f_array().flags.f_contiguous - assert not m.make_f_array().flags.c_contiguous - - -def test_make_empty_shaped_array(): - m.make_empty_shaped_array() - - # empty shape means numpy scalar, PEP 3118 - assert m.scalar_int().ndim == 0 - assert m.scalar_int().shape == () - assert m.scalar_int() == 42 - - -def test_wrap(): - def assert_references(a, b, base=None): - from distutils.version import LooseVersion - if base is None: - base = a - assert a is not b - assert a.__array_interface__['data'][0] == b.__array_interface__['data'][0] - assert a.shape == b.shape - assert a.strides == b.strides - assert a.flags.c_contiguous == b.flags.c_contiguous - assert a.flags.f_contiguous == b.flags.f_contiguous - assert a.flags.writeable == b.flags.writeable - assert a.flags.aligned == b.flags.aligned - if LooseVersion(np.__version__) >= LooseVersion("1.14.0"): - assert a.flags.writebackifcopy == b.flags.writebackifcopy - else: - assert a.flags.updateifcopy == b.flags.updateifcopy - assert np.all(a == b) - assert not b.flags.owndata - assert b.base is base - if a.flags.writeable and a.ndim == 2: - a[0, 0] = 1234 - assert b[0, 0] == 1234 - - a1 = np.array([1, 2], dtype=np.int16) - assert a1.flags.owndata and a1.base is None - a2 = m.wrap(a1) - assert_references(a1, a2) - - a1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order='F') - assert a1.flags.owndata and a1.base is None - a2 = m.wrap(a1) - assert_references(a1, a2) - - a1 = np.array([[1, 2], [3, 4]], dtype=np.float32, order='C') - a1.flags.writeable = False - a2 = m.wrap(a1) - assert_references(a1, a2) - - a1 = np.random.random((4, 4, 4)) - a2 = m.wrap(a1) - assert_references(a1, a2) - - a1t = a1.transpose() - a2 = m.wrap(a1t) - assert_references(a1t, a2, a1) - - a1d = a1.diagonal() - a2 = m.wrap(a1d) - assert_references(a1d, a2, a1) - - a1m = a1[::-1, ::-1, ::-1] - a2 = m.wrap(a1m) - assert_references(a1m, a2, a1) - - -def test_numpy_view(capture): - with capture: - ac = m.ArrayClass() - ac_view_1 = ac.numpy_view() - ac_view_2 = ac.numpy_view() - assert np.all(ac_view_1 == np.array([1, 2], dtype=np.int32)) - del ac - pytest.gc_collect() - assert capture == """ - ArrayClass() - ArrayClass::numpy_view() - ArrayClass::numpy_view() - """ - ac_view_1[0] = 4 - ac_view_1[1] = 3 - assert ac_view_2[0] == 4 - assert ac_view_2[1] == 3 - with capture: - del ac_view_1 - del ac_view_2 - pytest.gc_collect() - pytest.gc_collect() - assert capture == """ - ~ArrayClass() - """ - - -@pytest.unsupported_on_pypy -def test_cast_numpy_int64_to_uint64(): - m.function_taking_uint64(123) - m.function_taking_uint64(np.uint64(123)) - - -def test_isinstance(): - assert m.isinstance_untyped(np.array([1, 2, 3]), "not an array") - assert m.isinstance_typed(np.array([1.0, 2.0, 3.0])) - - -def test_constructors(): - defaults = m.default_constructors() - for a in defaults.values(): - assert a.size == 0 - assert defaults["array"].dtype == np.array([]).dtype - assert defaults["array_t"].dtype == np.int32 - assert defaults["array_t"].dtype == np.float64 - - results = m.converting_constructors([1, 2, 3]) - for a in results.values(): - np.testing.assert_array_equal(a, [1, 2, 3]) - assert results["array"].dtype == np.int_ - assert results["array_t"].dtype == np.int32 - assert results["array_t"].dtype == np.float64 - - -def test_overload_resolution(msg): - # Exact overload matches: - assert m.overloaded(np.array([1], dtype='float64')) == 'double' - assert m.overloaded(np.array([1], dtype='float32')) == 'float' - assert m.overloaded(np.array([1], dtype='ushort')) == 'unsigned short' - assert m.overloaded(np.array([1], dtype='intc')) == 'int' - assert m.overloaded(np.array([1], dtype='longlong')) == 'long long' - assert m.overloaded(np.array([1], dtype='complex')) == 'double complex' - assert m.overloaded(np.array([1], dtype='csingle')) == 'float complex' - - # No exact match, should call first convertible version: - assert m.overloaded(np.array([1], dtype='uint8')) == 'double' - - with pytest.raises(TypeError) as excinfo: - m.overloaded("not an array") - assert msg(excinfo.value) == """ - overloaded(): incompatible function arguments. The following argument types are supported: - 1. (arg0: numpy.ndarray[float64]) -> str - 2. (arg0: numpy.ndarray[float32]) -> str - 3. (arg0: numpy.ndarray[int32]) -> str - 4. (arg0: numpy.ndarray[uint16]) -> str - 5. (arg0: numpy.ndarray[int64]) -> str - 6. (arg0: numpy.ndarray[complex128]) -> str - 7. (arg0: numpy.ndarray[complex64]) -> str - - Invoked with: 'not an array' - """ - - assert m.overloaded2(np.array([1], dtype='float64')) == 'double' - assert m.overloaded2(np.array([1], dtype='float32')) == 'float' - assert m.overloaded2(np.array([1], dtype='complex64')) == 'float complex' - assert m.overloaded2(np.array([1], dtype='complex128')) == 'double complex' - assert m.overloaded2(np.array([1], dtype='float32')) == 'float' - - assert m.overloaded3(np.array([1], dtype='float64')) == 'double' - assert m.overloaded3(np.array([1], dtype='intc')) == 'int' - expected_exc = """ - overloaded3(): incompatible function arguments. The following argument types are supported: - 1. (arg0: numpy.ndarray[int32]) -> str - 2. (arg0: numpy.ndarray[float64]) -> str - - Invoked with: """ - - with pytest.raises(TypeError) as excinfo: - m.overloaded3(np.array([1], dtype='uintc')) - assert msg(excinfo.value) == expected_exc + repr(np.array([1], dtype='uint32')) - with pytest.raises(TypeError) as excinfo: - m.overloaded3(np.array([1], dtype='float32')) - assert msg(excinfo.value) == expected_exc + repr(np.array([1.], dtype='float32')) - with pytest.raises(TypeError) as excinfo: - m.overloaded3(np.array([1], dtype='complex')) - assert msg(excinfo.value) == expected_exc + repr(np.array([1. + 0.j])) - - # Exact matches: - assert m.overloaded4(np.array([1], dtype='double')) == 'double' - assert m.overloaded4(np.array([1], dtype='longlong')) == 'long long' - # Non-exact matches requiring conversion. Since float to integer isn't a - # save conversion, it should go to the double overload, but short can go to - # either (and so should end up on the first-registered, the long long). - assert m.overloaded4(np.array([1], dtype='float32')) == 'double' - assert m.overloaded4(np.array([1], dtype='short')) == 'long long' - - assert m.overloaded5(np.array([1], dtype='double')) == 'double' - assert m.overloaded5(np.array([1], dtype='uintc')) == 'unsigned int' - assert m.overloaded5(np.array([1], dtype='float32')) == 'unsigned int' - - -def test_greedy_string_overload(): - """Tests fix for #685 - ndarray shouldn't go to std::string overload""" - - assert m.issue685("abc") == "string" - assert m.issue685(np.array([97, 98, 99], dtype='b')) == "array" - assert m.issue685(123) == "other" - - -def test_array_unchecked_fixed_dims(msg): - z1 = np.array([[1, 2], [3, 4]], dtype='float64') - m.proxy_add2(z1, 10) - assert np.all(z1 == [[11, 12], [13, 14]]) - - with pytest.raises(ValueError) as excinfo: - m.proxy_add2(np.array([1., 2, 3]), 5.0) - assert msg(excinfo.value) == "array has incorrect number of dimensions: 1; expected 2" - - expect_c = np.ndarray(shape=(3, 3, 3), buffer=np.array(range(3, 30)), dtype='int') - assert np.all(m.proxy_init3(3.0) == expect_c) - expect_f = np.transpose(expect_c) - assert np.all(m.proxy_init3F(3.0) == expect_f) - - assert m.proxy_squared_L2_norm(np.array(range(6))) == 55 - assert m.proxy_squared_L2_norm(np.array(range(6), dtype="float64")) == 55 - - assert m.proxy_auxiliaries2(z1) == [11, 11, True, 2, 8, 2, 2, 4, 32] - assert m.proxy_auxiliaries2(z1) == m.array_auxiliaries2(z1) - - -def test_array_unchecked_dyn_dims(msg): - z1 = np.array([[1, 2], [3, 4]], dtype='float64') - m.proxy_add2_dyn(z1, 10) - assert np.all(z1 == [[11, 12], [13, 14]]) - - expect_c = np.ndarray(shape=(3, 3, 3), buffer=np.array(range(3, 30)), dtype='int') - assert np.all(m.proxy_init3_dyn(3.0) == expect_c) - - assert m.proxy_auxiliaries2_dyn(z1) == [11, 11, True, 2, 8, 2, 2, 4, 32] - assert m.proxy_auxiliaries2_dyn(z1) == m.array_auxiliaries2(z1) - - -def test_array_failure(): - with pytest.raises(ValueError) as excinfo: - m.array_fail_test() - assert str(excinfo.value) == 'cannot create a pybind11::array from a nullptr' - - with pytest.raises(ValueError) as excinfo: - m.array_t_fail_test() - assert str(excinfo.value) == 'cannot create a pybind11::array_t from a nullptr' - - with pytest.raises(ValueError) as excinfo: - m.array_fail_test_negative_size() - assert str(excinfo.value) == 'negative dimensions are not allowed' - - -def test_initializer_list(): - assert m.array_initializer_list1().shape == (1,) - assert m.array_initializer_list2().shape == (1, 2) - assert m.array_initializer_list3().shape == (1, 2, 3) - assert m.array_initializer_list4().shape == (1, 2, 3, 4) - - -def test_array_resize(msg): - a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype='float64') - m.array_reshape2(a) - assert(a.size == 9) - assert(np.all(a == [[1, 2, 3], [4, 5, 6], [7, 8, 9]])) - - # total size change should succced with refcheck off - m.array_resize3(a, 4, False) - assert(a.size == 64) - # ... and fail with refcheck on - try: - m.array_resize3(a, 3, True) - except ValueError as e: - assert(str(e).startswith("cannot resize an array")) - # transposed array doesn't own data - b = a.transpose() - try: - m.array_resize3(b, 3, False) - except ValueError as e: - assert(str(e).startswith("cannot resize this array: it does not own its data")) - # ... but reshape should be fine - m.array_reshape2(b) - assert(b.shape == (8, 8)) - - -@pytest.unsupported_on_pypy -def test_array_create_and_resize(msg): - a = m.create_and_resize(2) - assert(a.size == 4) - assert(np.all(a == 42.)) - - -@pytest.unsupported_on_py2 -def test_index_using_ellipsis(): - a = m.index_using_ellipsis(np.zeros((5, 6, 7))) - assert a.shape == (6,) - - -@pytest.unsupported_on_pypy -def test_dtype_refcount_leak(): - from sys import getrefcount - dtype = np.dtype(np.float_) - a = np.array([1], dtype=dtype) - before = getrefcount(dtype) - m.ndim(a) - after = getrefcount(dtype) - assert after == before diff --git a/pybind11/tests/test_numpy_dtypes.cpp b/pybind11/tests/test_numpy_dtypes.cpp deleted file mode 100644 index 467e025..0000000 --- a/pybind11/tests/test_numpy_dtypes.cpp +++ /dev/null @@ -1,474 +0,0 @@ -/* - tests/test_numpy_dtypes.cpp -- Structured and compound NumPy dtypes - - Copyright (c) 2016 Ivan Smirnov - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" -#include - -#ifdef __GNUC__ -#define PYBIND11_PACKED(cls) cls __attribute__((__packed__)) -#else -#define PYBIND11_PACKED(cls) __pragma(pack(push, 1)) cls __pragma(pack(pop)) -#endif - -namespace py = pybind11; - -struct SimpleStruct { - bool bool_; - uint32_t uint_; - float float_; - long double ldbl_; -}; - -std::ostream& operator<<(std::ostream& os, const SimpleStruct& v) { - return os << "s:" << v.bool_ << "," << v.uint_ << "," << v.float_ << "," << v.ldbl_; -} - -struct SimpleStructReordered { - bool bool_; - float float_; - uint32_t uint_; - long double ldbl_; -}; - -PYBIND11_PACKED(struct PackedStruct { - bool bool_; - uint32_t uint_; - float float_; - long double ldbl_; -}); - -std::ostream& operator<<(std::ostream& os, const PackedStruct& v) { - return os << "p:" << v.bool_ << "," << v.uint_ << "," << v.float_ << "," << v.ldbl_; -} - -PYBIND11_PACKED(struct NestedStruct { - SimpleStruct a; - PackedStruct b; -}); - -std::ostream& operator<<(std::ostream& os, const NestedStruct& v) { - return os << "n:a=" << v.a << ";b=" << v.b; -} - -struct PartialStruct { - bool bool_; - uint32_t uint_; - float float_; - uint64_t dummy2; - long double ldbl_; -}; - -struct PartialNestedStruct { - uint64_t dummy1; - PartialStruct a; - uint64_t dummy2; -}; - -struct UnboundStruct { }; - -struct StringStruct { - char a[3]; - std::array b; -}; - -struct ComplexStruct { - std::complex cflt; - std::complex cdbl; -}; - -std::ostream& operator<<(std::ostream& os, const ComplexStruct& v) { - return os << "c:" << v.cflt << "," << v.cdbl; -} - -struct ArrayStruct { - char a[3][4]; - int32_t b[2]; - std::array c; - std::array d[4]; -}; - -PYBIND11_PACKED(struct StructWithUglyNames { - int8_t __x__; - uint64_t __y__; -}); - -enum class E1 : int64_t { A = -1, B = 1 }; -enum E2 : uint8_t { X = 1, Y = 2 }; - -PYBIND11_PACKED(struct EnumStruct { - E1 e1; - E2 e2; -}); - -std::ostream& operator<<(std::ostream& os, const StringStruct& v) { - os << "a='"; - for (size_t i = 0; i < 3 && v.a[i]; i++) os << v.a[i]; - os << "',b='"; - for (size_t i = 0; i < 3 && v.b[i]; i++) os << v.b[i]; - return os << "'"; -} - -std::ostream& operator<<(std::ostream& os, const ArrayStruct& v) { - os << "a={"; - for (int i = 0; i < 3; i++) { - if (i > 0) - os << ','; - os << '{'; - for (int j = 0; j < 3; j++) - os << v.a[i][j] << ','; - os << v.a[i][3] << '}'; - } - os << "},b={" << v.b[0] << ',' << v.b[1]; - os << "},c={" << int(v.c[0]) << ',' << int(v.c[1]) << ',' << int(v.c[2]); - os << "},d={"; - for (int i = 0; i < 4; i++) { - if (i > 0) - os << ','; - os << '{' << v.d[i][0] << ',' << v.d[i][1] << '}'; - } - return os << '}'; -} - -std::ostream& operator<<(std::ostream& os, const EnumStruct& v) { - return os << "e1=" << (v.e1 == E1::A ? "A" : "B") << ",e2=" << (v.e2 == E2::X ? "X" : "Y"); -} - -template -py::array mkarray_via_buffer(size_t n) { - return py::array(py::buffer_info(nullptr, sizeof(T), - py::format_descriptor::format(), - 1, { n }, { sizeof(T) })); -} - -#define SET_TEST_VALS(s, i) do { \ - s.bool_ = (i) % 2 != 0; \ - s.uint_ = (uint32_t) (i); \ - s.float_ = (float) (i) * 1.5f; \ - s.ldbl_ = (long double) (i) * -2.5L; } while (0) - -template -py::array_t create_recarray(size_t n) { - auto arr = mkarray_via_buffer(n); - auto req = arr.request(); - auto ptr = static_cast(req.ptr); - for (size_t i = 0; i < n; i++) { - SET_TEST_VALS(ptr[i], i); - } - return arr; -} - -template -py::list print_recarray(py::array_t arr) { - const auto req = arr.request(); - const auto ptr = static_cast(req.ptr); - auto l = py::list(); - for (ssize_t i = 0; i < req.size; i++) { - std::stringstream ss; - ss << ptr[i]; - l.append(py::str(ss.str())); - } - return l; -} - -py::array_t test_array_ctors(int i) { - using arr_t = py::array_t; - - std::vector data { 1, 2, 3, 4, 5, 6 }; - std::vector shape { 3, 2 }; - std::vector strides { 8, 4 }; - - auto ptr = data.data(); - auto vptr = (void *) ptr; - auto dtype = py::dtype("int32"); - - py::buffer_info buf_ndim1(vptr, 4, "i", 6); - py::buffer_info buf_ndim1_null(nullptr, 4, "i", 6); - py::buffer_info buf_ndim2(vptr, 4, "i", 2, shape, strides); - py::buffer_info buf_ndim2_null(nullptr, 4, "i", 2, shape, strides); - - auto fill = [](py::array arr) { - auto req = arr.request(); - for (int i = 0; i < 6; i++) ((int32_t *) req.ptr)[i] = i + 1; - return arr; - }; - - switch (i) { - // shape: (3, 2) - case 10: return arr_t(shape, strides, ptr); - case 11: return py::array(shape, strides, ptr); - case 12: return py::array(dtype, shape, strides, vptr); - case 13: return arr_t(shape, ptr); - case 14: return py::array(shape, ptr); - case 15: return py::array(dtype, shape, vptr); - case 16: return arr_t(buf_ndim2); - case 17: return py::array(buf_ndim2); - // shape: (3, 2) - post-fill - case 20: return fill(arr_t(shape, strides)); - case 21: return py::array(shape, strides, ptr); // can't have nullptr due to templated ctor - case 22: return fill(py::array(dtype, shape, strides)); - case 23: return fill(arr_t(shape)); - case 24: return py::array(shape, ptr); // can't have nullptr due to templated ctor - case 25: return fill(py::array(dtype, shape)); - case 26: return fill(arr_t(buf_ndim2_null)); - case 27: return fill(py::array(buf_ndim2_null)); - // shape: (6, ) - case 30: return arr_t(6, ptr); - case 31: return py::array(6, ptr); - case 32: return py::array(dtype, 6, vptr); - case 33: return arr_t(buf_ndim1); - case 34: return py::array(buf_ndim1); - // shape: (6, ) - case 40: return fill(arr_t(6)); - case 41: return py::array(6, ptr); // can't have nullptr due to templated ctor - case 42: return fill(py::array(dtype, 6)); - case 43: return fill(arr_t(buf_ndim1_null)); - case 44: return fill(py::array(buf_ndim1_null)); - } - return arr_t(); -} - -py::list test_dtype_ctors() { - py::list list; - list.append(py::dtype("int32")); - list.append(py::dtype(std::string("float64"))); - list.append(py::dtype::from_args(py::str("bool"))); - py::list names, offsets, formats; - py::dict dict; - names.append(py::str("a")); names.append(py::str("b")); dict["names"] = names; - offsets.append(py::int_(1)); offsets.append(py::int_(10)); dict["offsets"] = offsets; - formats.append(py::dtype("int32")); formats.append(py::dtype("float64")); dict["formats"] = formats; - dict["itemsize"] = py::int_(20); - list.append(py::dtype::from_args(dict)); - list.append(py::dtype(names, formats, offsets, 20)); - list.append(py::dtype(py::buffer_info((void *) 0, sizeof(unsigned int), "I", 1))); - list.append(py::dtype(py::buffer_info((void *) 0, 0, "T{i:a:f:b:}", 1))); - return list; -} - -struct A {}; -struct B {}; - -TEST_SUBMODULE(numpy_dtypes, m) { - try { py::module::import("numpy"); } - catch (...) { return; } - - // typeinfo may be registered before the dtype descriptor for scalar casts to work... - py::class_(m, "SimpleStruct"); - - PYBIND11_NUMPY_DTYPE(SimpleStruct, bool_, uint_, float_, ldbl_); - PYBIND11_NUMPY_DTYPE(SimpleStructReordered, bool_, uint_, float_, ldbl_); - PYBIND11_NUMPY_DTYPE(PackedStruct, bool_, uint_, float_, ldbl_); - PYBIND11_NUMPY_DTYPE(NestedStruct, a, b); - PYBIND11_NUMPY_DTYPE(PartialStruct, bool_, uint_, float_, ldbl_); - PYBIND11_NUMPY_DTYPE(PartialNestedStruct, a); - PYBIND11_NUMPY_DTYPE(StringStruct, a, b); - PYBIND11_NUMPY_DTYPE(ArrayStruct, a, b, c, d); - PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2); - PYBIND11_NUMPY_DTYPE(ComplexStruct, cflt, cdbl); - - // ... or after - py::class_(m, "PackedStruct"); - - PYBIND11_NUMPY_DTYPE_EX(StructWithUglyNames, __x__, "x", __y__, "y"); - - // If uncommented, this should produce a static_assert failure telling the user that the struct - // is not a POD type -// struct NotPOD { std::string v; NotPOD() : v("hi") {}; }; -// PYBIND11_NUMPY_DTYPE(NotPOD, v); - - // Check that dtypes can be registered programmatically, both from - // initializer lists of field descriptors and from other containers. - py::detail::npy_format_descriptor::register_dtype( - {} - ); - py::detail::npy_format_descriptor::register_dtype( - std::vector{} - ); - - // test_recarray, test_scalar_conversion - m.def("create_rec_simple", &create_recarray); - m.def("create_rec_packed", &create_recarray); - m.def("create_rec_nested", [](size_t n) { // test_signature - py::array_t arr = mkarray_via_buffer(n); - auto req = arr.request(); - auto ptr = static_cast(req.ptr); - for (size_t i = 0; i < n; i++) { - SET_TEST_VALS(ptr[i].a, i); - SET_TEST_VALS(ptr[i].b, i + 1); - } - return arr; - }); - m.def("create_rec_partial", &create_recarray); - m.def("create_rec_partial_nested", [](size_t n) { - py::array_t arr = mkarray_via_buffer(n); - auto req = arr.request(); - auto ptr = static_cast(req.ptr); - for (size_t i = 0; i < n; i++) { - SET_TEST_VALS(ptr[i].a, i); - } - return arr; - }); - m.def("print_rec_simple", &print_recarray); - m.def("print_rec_packed", &print_recarray); - m.def("print_rec_nested", &print_recarray); - - // test_format_descriptors - m.def("get_format_unbound", []() { return py::format_descriptor::format(); }); - m.def("print_format_descriptors", []() { - py::list l; - for (const auto &fmt : { - py::format_descriptor::format(), - py::format_descriptor::format(), - py::format_descriptor::format(), - py::format_descriptor::format(), - py::format_descriptor::format(), - py::format_descriptor::format(), - py::format_descriptor::format(), - py::format_descriptor::format(), - py::format_descriptor::format() - }) { - l.append(py::cast(fmt)); - } - return l; - }); - - // test_dtype - m.def("print_dtypes", []() { - py::list l; - for (const py::handle &d : { - py::dtype::of(), - py::dtype::of(), - py::dtype::of(), - py::dtype::of(), - py::dtype::of(), - py::dtype::of(), - py::dtype::of(), - py::dtype::of(), - py::dtype::of(), - py::dtype::of() - }) - l.append(py::str(d)); - return l; - }); - m.def("test_dtype_ctors", &test_dtype_ctors); - m.def("test_dtype_methods", []() { - py::list list; - auto dt1 = py::dtype::of(); - auto dt2 = py::dtype::of(); - list.append(dt1); list.append(dt2); - list.append(py::bool_(dt1.has_fields())); list.append(py::bool_(dt2.has_fields())); - list.append(py::int_(dt1.itemsize())); list.append(py::int_(dt2.itemsize())); - return list; - }); - struct TrailingPaddingStruct { - int32_t a; - char b; - }; - PYBIND11_NUMPY_DTYPE(TrailingPaddingStruct, a, b); - m.def("trailing_padding_dtype", []() { return py::dtype::of(); }); - - // test_string_array - m.def("create_string_array", [](bool non_empty) { - py::array_t arr = mkarray_via_buffer(non_empty ? 4 : 0); - if (non_empty) { - auto req = arr.request(); - auto ptr = static_cast(req.ptr); - for (ssize_t i = 0; i < req.size * req.itemsize; i++) - static_cast(req.ptr)[i] = 0; - ptr[1].a[0] = 'a'; ptr[1].b[0] = 'a'; - ptr[2].a[0] = 'a'; ptr[2].b[0] = 'a'; - ptr[3].a[0] = 'a'; ptr[3].b[0] = 'a'; - - ptr[2].a[1] = 'b'; ptr[2].b[1] = 'b'; - ptr[3].a[1] = 'b'; ptr[3].b[1] = 'b'; - - ptr[3].a[2] = 'c'; ptr[3].b[2] = 'c'; - } - return arr; - }); - m.def("print_string_array", &print_recarray); - - // test_array_array - m.def("create_array_array", [](size_t n) { - py::array_t arr = mkarray_via_buffer(n); - auto ptr = (ArrayStruct *) arr.mutable_data(); - for (size_t i = 0; i < n; i++) { - for (size_t j = 0; j < 3; j++) - for (size_t k = 0; k < 4; k++) - ptr[i].a[j][k] = char('A' + (i * 100 + j * 10 + k) % 26); - for (size_t j = 0; j < 2; j++) - ptr[i].b[j] = int32_t(i * 1000 + j); - for (size_t j = 0; j < 3; j++) - ptr[i].c[j] = uint8_t(i * 10 + j); - for (size_t j = 0; j < 4; j++) - for (size_t k = 0; k < 2; k++) - ptr[i].d[j][k] = float(i) * 100.0f + float(j) * 10.0f + float(k); - } - return arr; - }); - m.def("print_array_array", &print_recarray); - - // test_enum_array - m.def("create_enum_array", [](size_t n) { - py::array_t arr = mkarray_via_buffer(n); - auto ptr = (EnumStruct *) arr.mutable_data(); - for (size_t i = 0; i < n; i++) { - ptr[i].e1 = static_cast(-1 + ((int) i % 2) * 2); - ptr[i].e2 = static_cast(1 + (i % 2)); - } - return arr; - }); - m.def("print_enum_array", &print_recarray); - - // test_complex_array - m.def("create_complex_array", [](size_t n) { - py::array_t arr = mkarray_via_buffer(n); - auto ptr = (ComplexStruct *) arr.mutable_data(); - for (size_t i = 0; i < n; i++) { - ptr[i].cflt.real(float(i)); - ptr[i].cflt.imag(float(i) + 0.25f); - ptr[i].cdbl.real(double(i) + 0.5); - ptr[i].cdbl.imag(double(i) + 0.75); - } - return arr; - }); - m.def("print_complex_array", &print_recarray); - - // test_array_constructors - m.def("test_array_ctors", &test_array_ctors); - - // test_compare_buffer_info - struct CompareStruct { - bool x; - uint32_t y; - float z; - }; - PYBIND11_NUMPY_DTYPE(CompareStruct, x, y, z); - m.def("compare_buffer_info", []() { - py::list list; - list.append(py::bool_(py::detail::compare_buffer_info::compare(py::buffer_info(nullptr, sizeof(float), "f", 1)))); - list.append(py::bool_(py::detail::compare_buffer_info::compare(py::buffer_info(nullptr, sizeof(int), "I", 1)))); - list.append(py::bool_(py::detail::compare_buffer_info::compare(py::buffer_info(nullptr, sizeof(long), "l", 1)))); - list.append(py::bool_(py::detail::compare_buffer_info::compare(py::buffer_info(nullptr, sizeof(long), sizeof(long) == sizeof(int) ? "i" : "q", 1)))); - list.append(py::bool_(py::detail::compare_buffer_info::compare(py::buffer_info(nullptr, sizeof(CompareStruct), "T{?:x:3xI:y:f:z:}", 1)))); - return list; - }); - m.def("buffer_to_dtype", [](py::buffer& buf) { return py::dtype(buf.request()); }); - - // test_scalar_conversion - m.def("f_simple", [](SimpleStruct s) { return s.uint_ * 10; }); - m.def("f_packed", [](PackedStruct s) { return s.uint_ * 10; }); - m.def("f_nested", [](NestedStruct s) { return s.a.uint_ * 10; }); - - // test_register_dtype - m.def("register_dtype", []() { PYBIND11_NUMPY_DTYPE(SimpleStruct, bool_, uint_, float_, ldbl_); }); - - // test_str_leak - m.def("dtype_wrapper", [](py::object d) { return py::dtype::from_args(std::move(d)); }); -} diff --git a/pybind11/tests/test_numpy_dtypes.py b/pybind11/tests/test_numpy_dtypes.py deleted file mode 100644 index 2e63885..0000000 --- a/pybind11/tests/test_numpy_dtypes.py +++ /dev/null @@ -1,310 +0,0 @@ -import re -import pytest -from pybind11_tests import numpy_dtypes as m - -pytestmark = pytest.requires_numpy - -with pytest.suppress(ImportError): - import numpy as np - - -@pytest.fixture(scope='module') -def simple_dtype(): - ld = np.dtype('longdouble') - return np.dtype({'names': ['bool_', 'uint_', 'float_', 'ldbl_'], - 'formats': ['?', 'u4', 'f4', 'f{}'.format(ld.itemsize)], - 'offsets': [0, 4, 8, (16 if ld.alignment > 4 else 12)]}) - - -@pytest.fixture(scope='module') -def packed_dtype(): - return np.dtype([('bool_', '?'), ('uint_', 'u4'), ('float_', 'f4'), ('ldbl_', 'g')]) - - -def dt_fmt(): - from sys import byteorder - e = '<' if byteorder == 'little' else '>' - return ("{{'names':['bool_','uint_','float_','ldbl_']," - " 'formats':['?','" + e + "u4','" + e + "f4','" + e + "f{}']," - " 'offsets':[0,4,8,{}], 'itemsize':{}}}") - - -def simple_dtype_fmt(): - ld = np.dtype('longdouble') - simple_ld_off = 12 + 4 * (ld.alignment > 4) - return dt_fmt().format(ld.itemsize, simple_ld_off, simple_ld_off + ld.itemsize) - - -def packed_dtype_fmt(): - from sys import byteorder - return "[('bool_', '?'), ('uint_', '{e}u4'), ('float_', '{e}f4'), ('ldbl_', '{e}f{}')]".format( - np.dtype('longdouble').itemsize, e='<' if byteorder == 'little' else '>') - - -def partial_ld_offset(): - return 12 + 4 * (np.dtype('uint64').alignment > 4) + 8 + 8 * ( - np.dtype('longdouble').alignment > 8) - - -def partial_dtype_fmt(): - ld = np.dtype('longdouble') - partial_ld_off = partial_ld_offset() - return dt_fmt().format(ld.itemsize, partial_ld_off, partial_ld_off + ld.itemsize) - - -def partial_nested_fmt(): - ld = np.dtype('longdouble') - partial_nested_off = 8 + 8 * (ld.alignment > 8) - partial_ld_off = partial_ld_offset() - partial_nested_size = partial_nested_off * 2 + partial_ld_off + ld.itemsize - return "{{'names':['a'], 'formats':[{}], 'offsets':[{}], 'itemsize':{}}}".format( - partial_dtype_fmt(), partial_nested_off, partial_nested_size) - - -def assert_equal(actual, expected_data, expected_dtype): - np.testing.assert_equal(actual, np.array(expected_data, dtype=expected_dtype)) - - -def test_format_descriptors(): - with pytest.raises(RuntimeError) as excinfo: - m.get_format_unbound() - assert re.match('^NumPy type info missing for .*UnboundStruct.*$', str(excinfo.value)) - - ld = np.dtype('longdouble') - ldbl_fmt = ('4x' if ld.alignment > 4 else '') + ld.char - ss_fmt = "^T{?:bool_:3xI:uint_:f:float_:" + ldbl_fmt + ":ldbl_:}" - dbl = np.dtype('double') - partial_fmt = ("^T{?:bool_:3xI:uint_:f:float_:" + - str(4 * (dbl.alignment > 4) + dbl.itemsize + 8 * (ld.alignment > 8)) + - "xg:ldbl_:}") - nested_extra = str(max(8, ld.alignment)) - assert m.print_format_descriptors() == [ - ss_fmt, - "^T{?:bool_:I:uint_:f:float_:g:ldbl_:}", - "^T{" + ss_fmt + ":a:^T{?:bool_:I:uint_:f:float_:g:ldbl_:}:b:}", - partial_fmt, - "^T{" + nested_extra + "x" + partial_fmt + ":a:" + nested_extra + "x}", - "^T{3s:a:3s:b:}", - "^T{(3)4s:a:(2)i:b:(3)B:c:1x(4, 2)f:d:}", - '^T{q:e1:B:e2:}', - '^T{Zf:cflt:Zd:cdbl:}' - ] - - -def test_dtype(simple_dtype): - from sys import byteorder - e = '<' if byteorder == 'little' else '>' - - assert m.print_dtypes() == [ - simple_dtype_fmt(), - packed_dtype_fmt(), - "[('a', {}), ('b', {})]".format(simple_dtype_fmt(), packed_dtype_fmt()), - partial_dtype_fmt(), - partial_nested_fmt(), - "[('a', 'S3'), ('b', 'S3')]", - ("{{'names':['a','b','c','d'], " + - "'formats':[('S4', (3,)),('" + e + "i4', (2,)),('u1', (3,)),('" + e + "f4', (4, 2))], " + - "'offsets':[0,12,20,24], 'itemsize':56}}").format(e=e), - "[('e1', '" + e + "i8'), ('e2', 'u1')]", - "[('x', 'i1'), ('y', '" + e + "u8')]", - "[('cflt', '" + e + "c8'), ('cdbl', '" + e + "c16')]" - ] - - d1 = np.dtype({'names': ['a', 'b'], 'formats': ['int32', 'float64'], - 'offsets': [1, 10], 'itemsize': 20}) - d2 = np.dtype([('a', 'i4'), ('b', 'f4')]) - assert m.test_dtype_ctors() == [np.dtype('int32'), np.dtype('float64'), - np.dtype('bool'), d1, d1, np.dtype('uint32'), d2] - - assert m.test_dtype_methods() == [np.dtype('int32'), simple_dtype, False, True, - np.dtype('int32').itemsize, simple_dtype.itemsize] - - assert m.trailing_padding_dtype() == m.buffer_to_dtype(np.zeros(1, m.trailing_padding_dtype())) - - -def test_recarray(simple_dtype, packed_dtype): - elements = [(False, 0, 0.0, -0.0), (True, 1, 1.5, -2.5), (False, 2, 3.0, -5.0)] - - for func, dtype in [(m.create_rec_simple, simple_dtype), (m.create_rec_packed, packed_dtype)]: - arr = func(0) - assert arr.dtype == dtype - assert_equal(arr, [], simple_dtype) - assert_equal(arr, [], packed_dtype) - - arr = func(3) - assert arr.dtype == dtype - assert_equal(arr, elements, simple_dtype) - assert_equal(arr, elements, packed_dtype) - - if dtype == simple_dtype: - assert m.print_rec_simple(arr) == [ - "s:0,0,0,-0", - "s:1,1,1.5,-2.5", - "s:0,2,3,-5" - ] - else: - assert m.print_rec_packed(arr) == [ - "p:0,0,0,-0", - "p:1,1,1.5,-2.5", - "p:0,2,3,-5" - ] - - nested_dtype = np.dtype([('a', simple_dtype), ('b', packed_dtype)]) - - arr = m.create_rec_nested(0) - assert arr.dtype == nested_dtype - assert_equal(arr, [], nested_dtype) - - arr = m.create_rec_nested(3) - assert arr.dtype == nested_dtype - assert_equal(arr, [((False, 0, 0.0, -0.0), (True, 1, 1.5, -2.5)), - ((True, 1, 1.5, -2.5), (False, 2, 3.0, -5.0)), - ((False, 2, 3.0, -5.0), (True, 3, 4.5, -7.5))], nested_dtype) - assert m.print_rec_nested(arr) == [ - "n:a=s:0,0,0,-0;b=p:1,1,1.5,-2.5", - "n:a=s:1,1,1.5,-2.5;b=p:0,2,3,-5", - "n:a=s:0,2,3,-5;b=p:1,3,4.5,-7.5" - ] - - arr = m.create_rec_partial(3) - assert str(arr.dtype) == partial_dtype_fmt() - partial_dtype = arr.dtype - assert '' not in arr.dtype.fields - assert partial_dtype.itemsize > simple_dtype.itemsize - assert_equal(arr, elements, simple_dtype) - assert_equal(arr, elements, packed_dtype) - - arr = m.create_rec_partial_nested(3) - assert str(arr.dtype) == partial_nested_fmt() - assert '' not in arr.dtype.fields - assert '' not in arr.dtype.fields['a'][0].fields - assert arr.dtype.itemsize > partial_dtype.itemsize - np.testing.assert_equal(arr['a'], m.create_rec_partial(3)) - - -def test_array_constructors(): - data = np.arange(1, 7, dtype='int32') - for i in range(8): - np.testing.assert_array_equal(m.test_array_ctors(10 + i), data.reshape((3, 2))) - np.testing.assert_array_equal(m.test_array_ctors(20 + i), data.reshape((3, 2))) - for i in range(5): - np.testing.assert_array_equal(m.test_array_ctors(30 + i), data) - np.testing.assert_array_equal(m.test_array_ctors(40 + i), data) - - -def test_string_array(): - arr = m.create_string_array(True) - assert str(arr.dtype) == "[('a', 'S3'), ('b', 'S3')]" - assert m.print_string_array(arr) == [ - "a='',b=''", - "a='a',b='a'", - "a='ab',b='ab'", - "a='abc',b='abc'" - ] - dtype = arr.dtype - assert arr['a'].tolist() == [b'', b'a', b'ab', b'abc'] - assert arr['b'].tolist() == [b'', b'a', b'ab', b'abc'] - arr = m.create_string_array(False) - assert dtype == arr.dtype - - -def test_array_array(): - from sys import byteorder - e = '<' if byteorder == 'little' else '>' - - arr = m.create_array_array(3) - assert str(arr.dtype) == ( - "{{'names':['a','b','c','d'], " + - "'formats':[('S4', (3,)),('" + e + "i4', (2,)),('u1', (3,)),('{e}f4', (4, 2))], " + - "'offsets':[0,12,20,24], 'itemsize':56}}").format(e=e) - assert m.print_array_array(arr) == [ - "a={{A,B,C,D},{K,L,M,N},{U,V,W,X}},b={0,1}," + - "c={0,1,2},d={{0,1},{10,11},{20,21},{30,31}}", - "a={{W,X,Y,Z},{G,H,I,J},{Q,R,S,T}},b={1000,1001}," + - "c={10,11,12},d={{100,101},{110,111},{120,121},{130,131}}", - "a={{S,T,U,V},{C,D,E,F},{M,N,O,P}},b={2000,2001}," + - "c={20,21,22},d={{200,201},{210,211},{220,221},{230,231}}", - ] - assert arr['a'].tolist() == [[b'ABCD', b'KLMN', b'UVWX'], - [b'WXYZ', b'GHIJ', b'QRST'], - [b'STUV', b'CDEF', b'MNOP']] - assert arr['b'].tolist() == [[0, 1], [1000, 1001], [2000, 2001]] - assert m.create_array_array(0).dtype == arr.dtype - - -def test_enum_array(): - from sys import byteorder - e = '<' if byteorder == 'little' else '>' - - arr = m.create_enum_array(3) - dtype = arr.dtype - assert dtype == np.dtype([('e1', e + 'i8'), ('e2', 'u1')]) - assert m.print_enum_array(arr) == [ - "e1=A,e2=X", - "e1=B,e2=Y", - "e1=A,e2=X" - ] - assert arr['e1'].tolist() == [-1, 1, -1] - assert arr['e2'].tolist() == [1, 2, 1] - assert m.create_enum_array(0).dtype == dtype - - -def test_complex_array(): - from sys import byteorder - e = '<' if byteorder == 'little' else '>' - - arr = m.create_complex_array(3) - dtype = arr.dtype - assert dtype == np.dtype([('cflt', e + 'c8'), ('cdbl', e + 'c16')]) - assert m.print_complex_array(arr) == [ - "c:(0,0.25),(0.5,0.75)", - "c:(1,1.25),(1.5,1.75)", - "c:(2,2.25),(2.5,2.75)" - ] - assert arr['cflt'].tolist() == [0.0 + 0.25j, 1.0 + 1.25j, 2.0 + 2.25j] - assert arr['cdbl'].tolist() == [0.5 + 0.75j, 1.5 + 1.75j, 2.5 + 2.75j] - assert m.create_complex_array(0).dtype == dtype - - -def test_signature(doc): - assert doc(m.create_rec_nested) == \ - "create_rec_nested(arg0: int) -> numpy.ndarray[NestedStruct]" - - -def test_scalar_conversion(): - n = 3 - arrays = [m.create_rec_simple(n), m.create_rec_packed(n), - m.create_rec_nested(n), m.create_enum_array(n)] - funcs = [m.f_simple, m.f_packed, m.f_nested] - - for i, func in enumerate(funcs): - for j, arr in enumerate(arrays): - if i == j and i < 2: - assert [func(arr[k]) for k in range(n)] == [k * 10 for k in range(n)] - else: - with pytest.raises(TypeError) as excinfo: - func(arr[0]) - assert 'incompatible function arguments' in str(excinfo.value) - - -def test_register_dtype(): - with pytest.raises(RuntimeError) as excinfo: - m.register_dtype() - assert 'dtype is already registered' in str(excinfo.value) - - -@pytest.unsupported_on_pypy -def test_str_leak(): - from sys import getrefcount - fmt = "f4" - pytest.gc_collect() - start = getrefcount(fmt) - d = m.dtype_wrapper(fmt) - assert d is np.dtype("f4") - del d - pytest.gc_collect() - assert getrefcount(fmt) == start - - -def test_compare_buffer_info(): - assert all(m.compare_buffer_info()) diff --git a/pybind11/tests/test_numpy_vectorize.cpp b/pybind11/tests/test_numpy_vectorize.cpp deleted file mode 100644 index a875a74..0000000 --- a/pybind11/tests/test_numpy_vectorize.cpp +++ /dev/null @@ -1,89 +0,0 @@ -/* - tests/test_numpy_vectorize.cpp -- auto-vectorize functions over NumPy array - arguments - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" -#include - -double my_func(int x, float y, double z) { - py::print("my_func(x:int={}, y:float={:.0f}, z:float={:.0f})"_s.format(x, y, z)); - return (float) x*y*z; -} - -TEST_SUBMODULE(numpy_vectorize, m) { - try { py::module::import("numpy"); } - catch (...) { return; } - - // test_vectorize, test_docs, test_array_collapse - // Vectorize all arguments of a function (though non-vector arguments are also allowed) - m.def("vectorized_func", py::vectorize(my_func)); - - // Vectorize a lambda function with a capture object (e.g. to exclude some arguments from the vectorization) - m.def("vectorized_func2", - [](py::array_t x, py::array_t y, float z) { - return py::vectorize([z](int x, float y) { return my_func(x, y, z); })(x, y); - } - ); - - // Vectorize a complex-valued function - m.def("vectorized_func3", py::vectorize( - [](std::complex c) { return c * std::complex(2.f); } - )); - - // test_type_selection - // Numpy function which only accepts specific data types - m.def("selective_func", [](py::array_t) { return "Int branch taken."; }); - m.def("selective_func", [](py::array_t) { return "Float branch taken."; }); - m.def("selective_func", [](py::array_t, py::array::c_style>) { return "Complex float branch taken."; }); - - - // test_passthrough_arguments - // Passthrough test: references and non-pod types should be automatically passed through (in the - // function definition below, only `b`, `d`, and `g` are vectorized): - struct NonPODClass { - NonPODClass(int v) : value{v} {} - int value; - }; - py::class_(m, "NonPODClass").def(py::init()); - m.def("vec_passthrough", py::vectorize( - [](double *a, double b, py::array_t c, const int &d, int &e, NonPODClass f, const double g) { - return *a + b + c.at(0) + d + e + f.value + g; - } - )); - - // test_method_vectorization - struct VectorizeTestClass { - VectorizeTestClass(int v) : value{v} {}; - float method(int x, float y) { return y + (float) (x + value); } - int value = 0; - }; - py::class_ vtc(m, "VectorizeTestClass"); - vtc .def(py::init()) - .def_readwrite("value", &VectorizeTestClass::value); - - // Automatic vectorizing of methods - vtc.def("method", py::vectorize(&VectorizeTestClass::method)); - - // test_trivial_broadcasting - // Internal optimization test for whether the input is trivially broadcastable: - py::enum_(m, "trivial") - .value("f_trivial", py::detail::broadcast_trivial::f_trivial) - .value("c_trivial", py::detail::broadcast_trivial::c_trivial) - .value("non_trivial", py::detail::broadcast_trivial::non_trivial); - m.def("vectorized_is_trivial", []( - py::array_t arg1, - py::array_t arg2, - py::array_t arg3 - ) { - ssize_t ndim; - std::vector shape; - std::array buffers {{ arg1.request(), arg2.request(), arg3.request() }}; - return py::detail::broadcast(buffers, ndim, shape); - }); -} diff --git a/pybind11/tests/test_numpy_vectorize.py b/pybind11/tests/test_numpy_vectorize.py deleted file mode 100644 index 0e9c883..0000000 --- a/pybind11/tests/test_numpy_vectorize.py +++ /dev/null @@ -1,196 +0,0 @@ -import pytest -from pybind11_tests import numpy_vectorize as m - -pytestmark = pytest.requires_numpy - -with pytest.suppress(ImportError): - import numpy as np - - -def test_vectorize(capture): - assert np.isclose(m.vectorized_func3(np.array(3 + 7j)), [6 + 14j]) - - for f in [m.vectorized_func, m.vectorized_func2]: - with capture: - assert np.isclose(f(1, 2, 3), 6) - assert capture == "my_func(x:int=1, y:float=2, z:float=3)" - with capture: - assert np.isclose(f(np.array(1), np.array(2), 3), 6) - assert capture == "my_func(x:int=1, y:float=2, z:float=3)" - with capture: - assert np.allclose(f(np.array([1, 3]), np.array([2, 4]), 3), [6, 36]) - assert capture == """ - my_func(x:int=1, y:float=2, z:float=3) - my_func(x:int=3, y:float=4, z:float=3) - """ - with capture: - a = np.array([[1, 2], [3, 4]], order='F') - b = np.array([[10, 20], [30, 40]], order='F') - c = 3 - result = f(a, b, c) - assert np.allclose(result, a * b * c) - assert result.flags.f_contiguous - # All inputs are F order and full or singletons, so we the result is in col-major order: - assert capture == """ - my_func(x:int=1, y:float=10, z:float=3) - my_func(x:int=3, y:float=30, z:float=3) - my_func(x:int=2, y:float=20, z:float=3) - my_func(x:int=4, y:float=40, z:float=3) - """ - with capture: - a, b, c = np.array([[1, 3, 5], [7, 9, 11]]), np.array([[2, 4, 6], [8, 10, 12]]), 3 - assert np.allclose(f(a, b, c), a * b * c) - assert capture == """ - my_func(x:int=1, y:float=2, z:float=3) - my_func(x:int=3, y:float=4, z:float=3) - my_func(x:int=5, y:float=6, z:float=3) - my_func(x:int=7, y:float=8, z:float=3) - my_func(x:int=9, y:float=10, z:float=3) - my_func(x:int=11, y:float=12, z:float=3) - """ - with capture: - a, b, c = np.array([[1, 2, 3], [4, 5, 6]]), np.array([2, 3, 4]), 2 - assert np.allclose(f(a, b, c), a * b * c) - assert capture == """ - my_func(x:int=1, y:float=2, z:float=2) - my_func(x:int=2, y:float=3, z:float=2) - my_func(x:int=3, y:float=4, z:float=2) - my_func(x:int=4, y:float=2, z:float=2) - my_func(x:int=5, y:float=3, z:float=2) - my_func(x:int=6, y:float=4, z:float=2) - """ - with capture: - a, b, c = np.array([[1, 2, 3], [4, 5, 6]]), np.array([[2], [3]]), 2 - assert np.allclose(f(a, b, c), a * b * c) - assert capture == """ - my_func(x:int=1, y:float=2, z:float=2) - my_func(x:int=2, y:float=2, z:float=2) - my_func(x:int=3, y:float=2, z:float=2) - my_func(x:int=4, y:float=3, z:float=2) - my_func(x:int=5, y:float=3, z:float=2) - my_func(x:int=6, y:float=3, z:float=2) - """ - with capture: - a, b, c = np.array([[1, 2, 3], [4, 5, 6]], order='F'), np.array([[2], [3]]), 2 - assert np.allclose(f(a, b, c), a * b * c) - assert capture == """ - my_func(x:int=1, y:float=2, z:float=2) - my_func(x:int=2, y:float=2, z:float=2) - my_func(x:int=3, y:float=2, z:float=2) - my_func(x:int=4, y:float=3, z:float=2) - my_func(x:int=5, y:float=3, z:float=2) - my_func(x:int=6, y:float=3, z:float=2) - """ - with capture: - a, b, c = np.array([[1, 2, 3], [4, 5, 6]])[::, ::2], np.array([[2], [3]]), 2 - assert np.allclose(f(a, b, c), a * b * c) - assert capture == """ - my_func(x:int=1, y:float=2, z:float=2) - my_func(x:int=3, y:float=2, z:float=2) - my_func(x:int=4, y:float=3, z:float=2) - my_func(x:int=6, y:float=3, z:float=2) - """ - with capture: - a, b, c = np.array([[1, 2, 3], [4, 5, 6]], order='F')[::, ::2], np.array([[2], [3]]), 2 - assert np.allclose(f(a, b, c), a * b * c) - assert capture == """ - my_func(x:int=1, y:float=2, z:float=2) - my_func(x:int=3, y:float=2, z:float=2) - my_func(x:int=4, y:float=3, z:float=2) - my_func(x:int=6, y:float=3, z:float=2) - """ - - -def test_type_selection(): - assert m.selective_func(np.array([1], dtype=np.int32)) == "Int branch taken." - assert m.selective_func(np.array([1.0], dtype=np.float32)) == "Float branch taken." - assert m.selective_func(np.array([1.0j], dtype=np.complex64)) == "Complex float branch taken." - - -def test_docs(doc): - assert doc(m.vectorized_func) == """ - vectorized_func(arg0: numpy.ndarray[int32], arg1: numpy.ndarray[float32], arg2: numpy.ndarray[float64]) -> object - """ # noqa: E501 line too long - - -def test_trivial_broadcasting(): - trivial, vectorized_is_trivial = m.trivial, m.vectorized_is_trivial - - assert vectorized_is_trivial(1, 2, 3) == trivial.c_trivial - assert vectorized_is_trivial(np.array(1), np.array(2), 3) == trivial.c_trivial - assert vectorized_is_trivial(np.array([1, 3]), np.array([2, 4]), 3) == trivial.c_trivial - assert trivial.c_trivial == vectorized_is_trivial( - np.array([[1, 3, 5], [7, 9, 11]]), np.array([[2, 4, 6], [8, 10, 12]]), 3) - assert vectorized_is_trivial( - np.array([[1, 2, 3], [4, 5, 6]]), np.array([2, 3, 4]), 2) == trivial.non_trivial - assert vectorized_is_trivial( - np.array([[1, 2, 3], [4, 5, 6]]), np.array([[2], [3]]), 2) == trivial.non_trivial - z1 = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype='int32') - z2 = np.array(z1, dtype='float32') - z3 = np.array(z1, dtype='float64') - assert vectorized_is_trivial(z1, z2, z3) == trivial.c_trivial - assert vectorized_is_trivial(1, z2, z3) == trivial.c_trivial - assert vectorized_is_trivial(z1, 1, z3) == trivial.c_trivial - assert vectorized_is_trivial(z1, z2, 1) == trivial.c_trivial - assert vectorized_is_trivial(z1[::2, ::2], 1, 1) == trivial.non_trivial - assert vectorized_is_trivial(1, 1, z1[::2, ::2]) == trivial.c_trivial - assert vectorized_is_trivial(1, 1, z3[::2, ::2]) == trivial.non_trivial - assert vectorized_is_trivial(z1, 1, z3[1::4, 1::4]) == trivial.c_trivial - - y1 = np.array(z1, order='F') - y2 = np.array(y1) - y3 = np.array(y1) - assert vectorized_is_trivial(y1, y2, y3) == trivial.f_trivial - assert vectorized_is_trivial(y1, 1, 1) == trivial.f_trivial - assert vectorized_is_trivial(1, y2, 1) == trivial.f_trivial - assert vectorized_is_trivial(1, 1, y3) == trivial.f_trivial - assert vectorized_is_trivial(y1, z2, 1) == trivial.non_trivial - assert vectorized_is_trivial(z1[1::4, 1::4], y2, 1) == trivial.f_trivial - assert vectorized_is_trivial(y1[1::4, 1::4], z2, 1) == trivial.c_trivial - - assert m.vectorized_func(z1, z2, z3).flags.c_contiguous - assert m.vectorized_func(y1, y2, y3).flags.f_contiguous - assert m.vectorized_func(z1, 1, 1).flags.c_contiguous - assert m.vectorized_func(1, y2, 1).flags.f_contiguous - assert m.vectorized_func(z1[1::4, 1::4], y2, 1).flags.f_contiguous - assert m.vectorized_func(y1[1::4, 1::4], z2, 1).flags.c_contiguous - - -def test_passthrough_arguments(doc): - assert doc(m.vec_passthrough) == ( - "vec_passthrough(" + ", ".join([ - "arg0: float", - "arg1: numpy.ndarray[float64]", - "arg2: numpy.ndarray[float64]", - "arg3: numpy.ndarray[int32]", - "arg4: int", - "arg5: m.numpy_vectorize.NonPODClass", - "arg6: numpy.ndarray[float64]"]) + ") -> object") - - b = np.array([[10, 20, 30]], dtype='float64') - c = np.array([100, 200]) # NOT a vectorized argument - d = np.array([[1000], [2000], [3000]], dtype='int') - g = np.array([[1000000, 2000000, 3000000]], dtype='int') # requires casting - assert np.all( - m.vec_passthrough(1, b, c, d, 10000, m.NonPODClass(100000), g) == - np.array([[1111111, 2111121, 3111131], - [1112111, 2112121, 3112131], - [1113111, 2113121, 3113131]])) - - -def test_method_vectorization(): - o = m.VectorizeTestClass(3) - x = np.array([1, 2], dtype='int') - y = np.array([[10], [20]], dtype='float32') - assert np.all(o.method(x, y) == [[14, 15], [24, 25]]) - - -def test_array_collapse(): - assert not isinstance(m.vectorized_func(1, 2, 3), np.ndarray) - assert not isinstance(m.vectorized_func(np.array(1), 2, 3), np.ndarray) - z = m.vectorized_func([1], 2, 3) - assert isinstance(z, np.ndarray) - assert z.shape == (1, ) - z = m.vectorized_func(1, [[[2]]], 3) - assert isinstance(z, np.ndarray) - assert z.shape == (1, 1, 1) diff --git a/pybind11/tests/test_opaque_types.cpp b/pybind11/tests/test_opaque_types.cpp deleted file mode 100644 index 0d20d9a..0000000 --- a/pybind11/tests/test_opaque_types.cpp +++ /dev/null @@ -1,67 +0,0 @@ -/* - tests/test_opaque_types.cpp -- opaque types, passing void pointers - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" -#include -#include - -// IMPORTANT: Disable internal pybind11 translation mechanisms for STL data structures -// -// This also deliberately doesn't use the below StringList type alias to test -// that MAKE_OPAQUE can handle a type containing a `,`. (The `std::allocator` -// bit is just the default `std::vector` allocator). -PYBIND11_MAKE_OPAQUE(std::vector>); - -using StringList = std::vector>; - -TEST_SUBMODULE(opaque_types, m) { - // test_string_list - py::class_(m, "StringList") - .def(py::init<>()) - .def("pop_back", &StringList::pop_back) - /* There are multiple versions of push_back(), etc. Select the right ones. */ - .def("push_back", (void (StringList::*)(const std::string &)) &StringList::push_back) - .def("back", (std::string &(StringList::*)()) &StringList::back) - .def("__len__", [](const StringList &v) { return v.size(); }) - .def("__iter__", [](StringList &v) { - return py::make_iterator(v.begin(), v.end()); - }, py::keep_alive<0, 1>()); - - class ClassWithSTLVecProperty { - public: - StringList stringList; - }; - py::class_(m, "ClassWithSTLVecProperty") - .def(py::init<>()) - .def_readwrite("stringList", &ClassWithSTLVecProperty::stringList); - - m.def("print_opaque_list", [](const StringList &l) { - std::string ret = "Opaque list: ["; - bool first = true; - for (auto entry : l) { - if (!first) - ret += ", "; - ret += entry; - first = false; - } - return ret + "]"; - }); - - // test_pointers - m.def("return_void_ptr", []() { return (void *) 0x1234; }); - m.def("get_void_ptr_value", [](void *ptr) { return reinterpret_cast(ptr); }); - m.def("return_null_str", []() { return (char *) nullptr; }); - m.def("get_null_str_value", [](char *ptr) { return reinterpret_cast(ptr); }); - - m.def("return_unique_ptr", []() -> std::unique_ptr { - StringList *result = new StringList(); - result->push_back("some value"); - return std::unique_ptr(result); - }); -} diff --git a/pybind11/tests/test_opaque_types.py b/pybind11/tests/test_opaque_types.py deleted file mode 100644 index 6b3802f..0000000 --- a/pybind11/tests/test_opaque_types.py +++ /dev/null @@ -1,46 +0,0 @@ -import pytest -from pybind11_tests import opaque_types as m -from pybind11_tests import ConstructorStats, UserType - - -def test_string_list(): - lst = m.StringList() - lst.push_back("Element 1") - lst.push_back("Element 2") - assert m.print_opaque_list(lst) == "Opaque list: [Element 1, Element 2]" - assert lst.back() == "Element 2" - - for i, k in enumerate(lst, start=1): - assert k == "Element {}".format(i) - lst.pop_back() - assert m.print_opaque_list(lst) == "Opaque list: [Element 1]" - - cvp = m.ClassWithSTLVecProperty() - assert m.print_opaque_list(cvp.stringList) == "Opaque list: []" - - cvp.stringList = lst - cvp.stringList.push_back("Element 3") - assert m.print_opaque_list(cvp.stringList) == "Opaque list: [Element 1, Element 3]" - - -def test_pointers(msg): - living_before = ConstructorStats.get(UserType).alive() - assert m.get_void_ptr_value(m.return_void_ptr()) == 0x1234 - assert m.get_void_ptr_value(UserType()) # Should also work for other C++ types - assert ConstructorStats.get(UserType).alive() == living_before - - with pytest.raises(TypeError) as excinfo: - m.get_void_ptr_value([1, 2, 3]) # This should not work - assert msg(excinfo.value) == """ - get_void_ptr_value(): incompatible function arguments. The following argument types are supported: - 1. (arg0: capsule) -> int - - Invoked with: [1, 2, 3] - """ # noqa: E501 line too long - - assert m.return_null_str() is None - assert m.get_null_str_value(m.return_null_str()) is not None - - ptr = m.return_unique_ptr() - assert "StringList" in repr(ptr) - assert m.print_opaque_list(ptr) == "Opaque list: [some value]" diff --git a/pybind11/tests/test_operator_overloading.cpp b/pybind11/tests/test_operator_overloading.cpp deleted file mode 100644 index 7b11170..0000000 --- a/pybind11/tests/test_operator_overloading.cpp +++ /dev/null @@ -1,171 +0,0 @@ -/* - tests/test_operator_overloading.cpp -- operator overloading - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" -#include "constructor_stats.h" -#include -#include - -class Vector2 { -public: - Vector2(float x, float y) : x(x), y(y) { print_created(this, toString()); } - Vector2(const Vector2 &v) : x(v.x), y(v.y) { print_copy_created(this); } - Vector2(Vector2 &&v) : x(v.x), y(v.y) { print_move_created(this); v.x = v.y = 0; } - Vector2 &operator=(const Vector2 &v) { x = v.x; y = v.y; print_copy_assigned(this); return *this; } - Vector2 &operator=(Vector2 &&v) { x = v.x; y = v.y; v.x = v.y = 0; print_move_assigned(this); return *this; } - ~Vector2() { print_destroyed(this); } - - std::string toString() const { return "[" + std::to_string(x) + ", " + std::to_string(y) + "]"; } - - Vector2 operator-() const { return Vector2(-x, -y); } - Vector2 operator+(const Vector2 &v) const { return Vector2(x + v.x, y + v.y); } - Vector2 operator-(const Vector2 &v) const { return Vector2(x - v.x, y - v.y); } - Vector2 operator-(float value) const { return Vector2(x - value, y - value); } - Vector2 operator+(float value) const { return Vector2(x + value, y + value); } - Vector2 operator*(float value) const { return Vector2(x * value, y * value); } - Vector2 operator/(float value) const { return Vector2(x / value, y / value); } - Vector2 operator*(const Vector2 &v) const { return Vector2(x * v.x, y * v.y); } - Vector2 operator/(const Vector2 &v) const { return Vector2(x / v.x, y / v.y); } - Vector2& operator+=(const Vector2 &v) { x += v.x; y += v.y; return *this; } - Vector2& operator-=(const Vector2 &v) { x -= v.x; y -= v.y; return *this; } - Vector2& operator*=(float v) { x *= v; y *= v; return *this; } - Vector2& operator/=(float v) { x /= v; y /= v; return *this; } - Vector2& operator*=(const Vector2 &v) { x *= v.x; y *= v.y; return *this; } - Vector2& operator/=(const Vector2 &v) { x /= v.x; y /= v.y; return *this; } - - friend Vector2 operator+(float f, const Vector2 &v) { return Vector2(f + v.x, f + v.y); } - friend Vector2 operator-(float f, const Vector2 &v) { return Vector2(f - v.x, f - v.y); } - friend Vector2 operator*(float f, const Vector2 &v) { return Vector2(f * v.x, f * v.y); } - friend Vector2 operator/(float f, const Vector2 &v) { return Vector2(f / v.x, f / v.y); } -private: - float x, y; -}; - -class C1 { }; -class C2 { }; - -int operator+(const C1 &, const C1 &) { return 11; } -int operator+(const C2 &, const C2 &) { return 22; } -int operator+(const C2 &, const C1 &) { return 21; } -int operator+(const C1 &, const C2 &) { return 12; } - -namespace std { - template<> - struct hash { - // Not a good hash function, but easy to test - size_t operator()(const Vector2 &) { return 4; } - }; -} - -// MSVC warns about unknown pragmas, and warnings are errors. -#ifndef _MSC_VER - #pragma GCC diagnostic push - // clang 7.0.0 and Apple LLVM 10.0.1 introduce `-Wself-assign-overloaded` to - // `-Wall`, which is used here for overloading (e.g. `py::self += py::self `). - // Here, we suppress the warning using `#pragma diagnostic`. - // Taken from: https://github.com/RobotLocomotion/drake/commit/aaf84b46 - // TODO(eric): This could be resolved using a function / functor (e.g. `py::self()`). - #if (__APPLE__) && (__clang__) - #if (__clang_major__ >= 10) && (__clang_minor__ >= 0) && (__clang_patchlevel__ >= 1) - #pragma GCC diagnostic ignored "-Wself-assign-overloaded" - #endif - #elif (__clang__) - #if (__clang_major__ >= 7) - #pragma GCC diagnostic ignored "-Wself-assign-overloaded" - #endif - #endif -#endif - -TEST_SUBMODULE(operators, m) { - - // test_operator_overloading - py::class_(m, "Vector2") - .def(py::init()) - .def(py::self + py::self) - .def(py::self + float()) - .def(py::self - py::self) - .def(py::self - float()) - .def(py::self * float()) - .def(py::self / float()) - .def(py::self * py::self) - .def(py::self / py::self) - .def(py::self += py::self) - .def(py::self -= py::self) - .def(py::self *= float()) - .def(py::self /= float()) - .def(py::self *= py::self) - .def(py::self /= py::self) - .def(float() + py::self) - .def(float() - py::self) - .def(float() * py::self) - .def(float() / py::self) - .def(-py::self) - .def("__str__", &Vector2::toString) - .def(hash(py::self)) - ; - - m.attr("Vector") = m.attr("Vector2"); - - // test_operators_notimplemented - // #393: need to return NotSupported to ensure correct arithmetic operator behavior - py::class_(m, "C1") - .def(py::init<>()) - .def(py::self + py::self); - - py::class_(m, "C2") - .def(py::init<>()) - .def(py::self + py::self) - .def("__add__", [](const C2& c2, const C1& c1) { return c2 + c1; }) - .def("__radd__", [](const C2& c2, const C1& c1) { return c1 + c2; }); - - // test_nested - // #328: first member in a class can't be used in operators - struct NestABase { int value = -2; }; - py::class_(m, "NestABase") - .def(py::init<>()) - .def_readwrite("value", &NestABase::value); - - struct NestA : NestABase { - int value = 3; - NestA& operator+=(int i) { value += i; return *this; } - }; - py::class_(m, "NestA") - .def(py::init<>()) - .def(py::self += int()) - .def("as_base", [](NestA &a) -> NestABase& { - return (NestABase&) a; - }, py::return_value_policy::reference_internal); - m.def("get_NestA", [](const NestA &a) { return a.value; }); - - struct NestB { - NestA a; - int value = 4; - NestB& operator-=(int i) { value -= i; return *this; } - }; - py::class_(m, "NestB") - .def(py::init<>()) - .def(py::self -= int()) - .def_readwrite("a", &NestB::a); - m.def("get_NestB", [](const NestB &b) { return b.value; }); - - struct NestC { - NestB b; - int value = 5; - NestC& operator*=(int i) { value *= i; return *this; } - }; - py::class_(m, "NestC") - .def(py::init<>()) - .def(py::self *= int()) - .def_readwrite("b", &NestC::b); - m.def("get_NestC", [](const NestC &c) { return c.value; }); -} - -#ifndef _MSC_VER - #pragma GCC diagnostic pop -#endif diff --git a/pybind11/tests/test_operator_overloading.py b/pybind11/tests/test_operator_overloading.py deleted file mode 100644 index bd36ac2..0000000 --- a/pybind11/tests/test_operator_overloading.py +++ /dev/null @@ -1,108 +0,0 @@ -import pytest -from pybind11_tests import operators as m -from pybind11_tests import ConstructorStats - - -def test_operator_overloading(): - v1 = m.Vector2(1, 2) - v2 = m.Vector(3, -1) - assert str(v1) == "[1.000000, 2.000000]" - assert str(v2) == "[3.000000, -1.000000]" - - assert str(-v2) == "[-3.000000, 1.000000]" - - assert str(v1 + v2) == "[4.000000, 1.000000]" - assert str(v1 - v2) == "[-2.000000, 3.000000]" - assert str(v1 - 8) == "[-7.000000, -6.000000]" - assert str(v1 + 8) == "[9.000000, 10.000000]" - assert str(v1 * 8) == "[8.000000, 16.000000]" - assert str(v1 / 8) == "[0.125000, 0.250000]" - assert str(8 - v1) == "[7.000000, 6.000000]" - assert str(8 + v1) == "[9.000000, 10.000000]" - assert str(8 * v1) == "[8.000000, 16.000000]" - assert str(8 / v1) == "[8.000000, 4.000000]" - assert str(v1 * v2) == "[3.000000, -2.000000]" - assert str(v2 / v1) == "[3.000000, -0.500000]" - - v1 += 2 * v2 - assert str(v1) == "[7.000000, 0.000000]" - v1 -= v2 - assert str(v1) == "[4.000000, 1.000000]" - v1 *= 2 - assert str(v1) == "[8.000000, 2.000000]" - v1 /= 16 - assert str(v1) == "[0.500000, 0.125000]" - v1 *= v2 - assert str(v1) == "[1.500000, -0.125000]" - v2 /= v1 - assert str(v2) == "[2.000000, 8.000000]" - - assert hash(v1) == 4 - - cstats = ConstructorStats.get(m.Vector2) - assert cstats.alive() == 2 - del v1 - assert cstats.alive() == 1 - del v2 - assert cstats.alive() == 0 - assert cstats.values() == ['[1.000000, 2.000000]', '[3.000000, -1.000000]', - '[-3.000000, 1.000000]', '[4.000000, 1.000000]', - '[-2.000000, 3.000000]', '[-7.000000, -6.000000]', - '[9.000000, 10.000000]', '[8.000000, 16.000000]', - '[0.125000, 0.250000]', '[7.000000, 6.000000]', - '[9.000000, 10.000000]', '[8.000000, 16.000000]', - '[8.000000, 4.000000]', '[3.000000, -2.000000]', - '[3.000000, -0.500000]', '[6.000000, -2.000000]'] - assert cstats.default_constructions == 0 - assert cstats.copy_constructions == 0 - assert cstats.move_constructions >= 10 - assert cstats.copy_assignments == 0 - assert cstats.move_assignments == 0 - - -def test_operators_notimplemented(): - """#393: need to return NotSupported to ensure correct arithmetic operator behavior""" - - c1, c2 = m.C1(), m.C2() - assert c1 + c1 == 11 - assert c2 + c2 == 22 - assert c2 + c1 == 21 - assert c1 + c2 == 12 - - -def test_nested(): - """#328: first member in a class can't be used in operators""" - - a = m.NestA() - b = m.NestB() - c = m.NestC() - - a += 10 - assert m.get_NestA(a) == 13 - b.a += 100 - assert m.get_NestA(b.a) == 103 - c.b.a += 1000 - assert m.get_NestA(c.b.a) == 1003 - b -= 1 - assert m.get_NestB(b) == 3 - c.b -= 3 - assert m.get_NestB(c.b) == 1 - c *= 7 - assert m.get_NestC(c) == 35 - - abase = a.as_base() - assert abase.value == -2 - a.as_base().value += 44 - assert abase.value == 42 - assert c.b.a.as_base().value == -2 - c.b.a.as_base().value += 44 - assert c.b.a.as_base().value == 42 - - del c - pytest.gc_collect() - del a # Shouldn't delete while abase is still alive - pytest.gc_collect() - - assert abase.value == 42 - del abase, b - pytest.gc_collect() diff --git a/pybind11/tests/test_pickling.cpp b/pybind11/tests/test_pickling.cpp deleted file mode 100644 index 9dc63bd..0000000 --- a/pybind11/tests/test_pickling.cpp +++ /dev/null @@ -1,130 +0,0 @@ -/* - tests/test_pickling.cpp -- pickle support - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" - -TEST_SUBMODULE(pickling, m) { - // test_roundtrip - class Pickleable { - public: - Pickleable(const std::string &value) : m_value(value) { } - const std::string &value() const { return m_value; } - - void setExtra1(int extra1) { m_extra1 = extra1; } - void setExtra2(int extra2) { m_extra2 = extra2; } - int extra1() const { return m_extra1; } - int extra2() const { return m_extra2; } - private: - std::string m_value; - int m_extra1 = 0; - int m_extra2 = 0; - }; - - class PickleableNew : public Pickleable { - public: - using Pickleable::Pickleable; - }; - - py::class_(m, "Pickleable") - .def(py::init()) - .def("value", &Pickleable::value) - .def("extra1", &Pickleable::extra1) - .def("extra2", &Pickleable::extra2) - .def("setExtra1", &Pickleable::setExtra1) - .def("setExtra2", &Pickleable::setExtra2) - // For details on the methods below, refer to - // http://docs.python.org/3/library/pickle.html#pickling-class-instances - .def("__getstate__", [](const Pickleable &p) { - /* Return a tuple that fully encodes the state of the object */ - return py::make_tuple(p.value(), p.extra1(), p.extra2()); - }) - .def("__setstate__", [](Pickleable &p, py::tuple t) { - if (t.size() != 3) - throw std::runtime_error("Invalid state!"); - /* Invoke the constructor (need to use in-place version) */ - new (&p) Pickleable(t[0].cast()); - - /* Assign any additional state */ - p.setExtra1(t[1].cast()); - p.setExtra2(t[2].cast()); - }); - - py::class_(m, "PickleableNew") - .def(py::init()) - .def(py::pickle( - [](const PickleableNew &p) { - return py::make_tuple(p.value(), p.extra1(), p.extra2()); - }, - [](py::tuple t) { - if (t.size() != 3) - throw std::runtime_error("Invalid state!"); - auto p = PickleableNew(t[0].cast()); - - p.setExtra1(t[1].cast()); - p.setExtra2(t[2].cast()); - return p; - } - )); - -#if !defined(PYPY_VERSION) - // test_roundtrip_with_dict - class PickleableWithDict { - public: - PickleableWithDict(const std::string &value) : value(value) { } - - std::string value; - int extra; - }; - - class PickleableWithDictNew : public PickleableWithDict { - public: - using PickleableWithDict::PickleableWithDict; - }; - - py::class_(m, "PickleableWithDict", py::dynamic_attr()) - .def(py::init()) - .def_readwrite("value", &PickleableWithDict::value) - .def_readwrite("extra", &PickleableWithDict::extra) - .def("__getstate__", [](py::object self) { - /* Also include __dict__ in state */ - return py::make_tuple(self.attr("value"), self.attr("extra"), self.attr("__dict__")); - }) - .def("__setstate__", [](py::object self, py::tuple t) { - if (t.size() != 3) - throw std::runtime_error("Invalid state!"); - /* Cast and construct */ - auto& p = self.cast(); - new (&p) PickleableWithDict(t[0].cast()); - - /* Assign C++ state */ - p.extra = t[1].cast(); - - /* Assign Python state */ - self.attr("__dict__") = t[2]; - }); - - py::class_(m, "PickleableWithDictNew") - .def(py::init()) - .def(py::pickle( - [](py::object self) { - return py::make_tuple(self.attr("value"), self.attr("extra"), self.attr("__dict__")); - }, - [](const py::tuple &t) { - if (t.size() != 3) - throw std::runtime_error("Invalid state!"); - - auto cpp_state = PickleableWithDictNew(t[0].cast()); - cpp_state.extra = t[1].cast(); - - auto py_state = t[2].cast(); - return std::make_pair(cpp_state, py_state); - } - )); -#endif -} diff --git a/pybind11/tests/test_pickling.py b/pybind11/tests/test_pickling.py deleted file mode 100644 index 5ae05aa..0000000 --- a/pybind11/tests/test_pickling.py +++ /dev/null @@ -1,42 +0,0 @@ -import pytest -from pybind11_tests import pickling as m - -try: - import cPickle as pickle # Use cPickle on Python 2.7 -except ImportError: - import pickle - - -@pytest.mark.parametrize("cls_name", ["Pickleable", "PickleableNew"]) -def test_roundtrip(cls_name): - cls = getattr(m, cls_name) - p = cls("test_value") - p.setExtra1(15) - p.setExtra2(48) - - data = pickle.dumps(p, 2) # Must use pickle protocol >= 2 - p2 = pickle.loads(data) - assert p2.value() == p.value() - assert p2.extra1() == p.extra1() - assert p2.extra2() == p.extra2() - - -@pytest.unsupported_on_pypy -@pytest.mark.parametrize("cls_name", ["PickleableWithDict", "PickleableWithDictNew"]) -def test_roundtrip_with_dict(cls_name): - cls = getattr(m, cls_name) - p = cls("test_value") - p.extra = 15 - p.dynamic = "Attribute" - - data = pickle.dumps(p, pickle.HIGHEST_PROTOCOL) - p2 = pickle.loads(data) - assert p2.value == p.value - assert p2.extra == p.extra - assert p2.dynamic == p.dynamic - - -def test_enum_pickle(): - from pybind11_tests import enums as e - data = pickle.dumps(e.EOne, 2) - assert e.EOne == pickle.loads(data) diff --git a/pybind11/tests/test_pytypes.cpp b/pybind11/tests/test_pytypes.cpp deleted file mode 100644 index 244e1db..0000000 --- a/pybind11/tests/test_pytypes.cpp +++ /dev/null @@ -1,310 +0,0 @@ -/* - tests/test_pytypes.cpp -- Python type casters - - Copyright (c) 2017 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" - - -TEST_SUBMODULE(pytypes, m) { - // test_list - m.def("get_list", []() { - py::list list; - list.append("value"); - py::print("Entry at position 0:", list[0]); - list[0] = py::str("overwritten"); - list.insert(0, "inserted-0"); - list.insert(2, "inserted-2"); - return list; - }); - m.def("print_list", [](py::list list) { - int index = 0; - for (auto item : list) - py::print("list item {}: {}"_s.format(index++, item)); - }); - - // test_set - m.def("get_set", []() { - py::set set; - set.add(py::str("key1")); - set.add("key2"); - set.add(std::string("key3")); - return set; - }); - m.def("print_set", [](py::set set) { - for (auto item : set) - py::print("key:", item); - }); - m.def("set_contains", [](py::set set, py::object key) { - return set.contains(key); - }); - m.def("set_contains", [](py::set set, const char* key) { - return set.contains(key); - }); - - // test_dict - m.def("get_dict", []() { return py::dict("key"_a="value"); }); - m.def("print_dict", [](py::dict dict) { - for (auto item : dict) - py::print("key: {}, value={}"_s.format(item.first, item.second)); - }); - m.def("dict_keyword_constructor", []() { - auto d1 = py::dict("x"_a=1, "y"_a=2); - auto d2 = py::dict("z"_a=3, **d1); - return d2; - }); - m.def("dict_contains", [](py::dict dict, py::object val) { - return dict.contains(val); - }); - m.def("dict_contains", [](py::dict dict, const char* val) { - return dict.contains(val); - }); - - // test_str - m.def("str_from_string", []() { return py::str(std::string("baz")); }); - m.def("str_from_bytes", []() { return py::str(py::bytes("boo", 3)); }); - m.def("str_from_object", [](const py::object& obj) { return py::str(obj); }); - m.def("repr_from_object", [](const py::object& obj) { return py::repr(obj); }); - - m.def("str_format", []() { - auto s1 = "{} + {} = {}"_s.format(1, 2, 3); - auto s2 = "{a} + {b} = {c}"_s.format("a"_a=1, "b"_a=2, "c"_a=3); - return py::make_tuple(s1, s2); - }); - - // test_bytes - m.def("bytes_from_string", []() { return py::bytes(std::string("foo")); }); - m.def("bytes_from_str", []() { return py::bytes(py::str("bar", 3)); }); - - // test_capsule - m.def("return_capsule_with_destructor", []() { - py::print("creating capsule"); - return py::capsule([]() { - py::print("destructing capsule"); - }); - }); - - m.def("return_capsule_with_destructor_2", []() { - py::print("creating capsule"); - return py::capsule((void *) 1234, [](void *ptr) { - py::print("destructing capsule: {}"_s.format((size_t) ptr)); - }); - }); - - m.def("return_capsule_with_name_and_destructor", []() { - auto capsule = py::capsule((void *) 1234, "pointer type description", [](PyObject *ptr) { - if (ptr) { - auto name = PyCapsule_GetName(ptr); - py::print("destructing capsule ({}, '{}')"_s.format( - (size_t) PyCapsule_GetPointer(ptr, name), name - )); - } - }); - void *contents = capsule; - py::print("created capsule ({}, '{}')"_s.format((size_t) contents, capsule.name())); - return capsule; - }); - - // test_accessors - m.def("accessor_api", [](py::object o) { - auto d = py::dict(); - - d["basic_attr"] = o.attr("basic_attr"); - - auto l = py::list(); - for (const auto &item : o.attr("begin_end")) { - l.append(item); - } - d["begin_end"] = l; - - d["operator[object]"] = o.attr("d")["operator[object]"_s]; - d["operator[char *]"] = o.attr("d")["operator[char *]"]; - - d["attr(object)"] = o.attr("sub").attr("attr_obj"); - d["attr(char *)"] = o.attr("sub").attr("attr_char"); - try { - o.attr("sub").attr("missing").ptr(); - } catch (const py::error_already_set &) { - d["missing_attr_ptr"] = "raised"_s; - } - try { - o.attr("missing").attr("doesn't matter"); - } catch (const py::error_already_set &) { - d["missing_attr_chain"] = "raised"_s; - } - - d["is_none"] = o.attr("basic_attr").is_none(); - - d["operator()"] = o.attr("func")(1); - d["operator*"] = o.attr("func")(*o.attr("begin_end")); - - // Test implicit conversion - py::list implicit_list = o.attr("begin_end"); - d["implicit_list"] = implicit_list; - py::dict implicit_dict = o.attr("__dict__"); - d["implicit_dict"] = implicit_dict; - - return d; - }); - - m.def("tuple_accessor", [](py::tuple existing_t) { - try { - existing_t[0] = 1; - } catch (const py::error_already_set &) { - // --> Python system error - // Only new tuples (refcount == 1) are mutable - auto new_t = py::tuple(3); - for (size_t i = 0; i < new_t.size(); ++i) { - new_t[i] = i; - } - return new_t; - } - return py::tuple(); - }); - - m.def("accessor_assignment", []() { - auto l = py::list(1); - l[0] = 0; - - auto d = py::dict(); - d["get"] = l[0]; - auto var = l[0]; - d["deferred_get"] = var; - l[0] = 1; - d["set"] = l[0]; - var = 99; // this assignment should not overwrite l[0] - d["deferred_set"] = l[0]; - d["var"] = var; - - return d; - }); - - // test_constructors - m.def("default_constructors", []() { - return py::dict( - "str"_a=py::str(), - "bool"_a=py::bool_(), - "int"_a=py::int_(), - "float"_a=py::float_(), - "tuple"_a=py::tuple(), - "list"_a=py::list(), - "dict"_a=py::dict(), - "set"_a=py::set() - ); - }); - - m.def("converting_constructors", [](py::dict d) { - return py::dict( - "str"_a=py::str(d["str"]), - "bool"_a=py::bool_(d["bool"]), - "int"_a=py::int_(d["int"]), - "float"_a=py::float_(d["float"]), - "tuple"_a=py::tuple(d["tuple"]), - "list"_a=py::list(d["list"]), - "dict"_a=py::dict(d["dict"]), - "set"_a=py::set(d["set"]), - "memoryview"_a=py::memoryview(d["memoryview"]) - ); - }); - - m.def("cast_functions", [](py::dict d) { - // When converting between Python types, obj.cast() should be the same as T(obj) - return py::dict( - "str"_a=d["str"].cast(), - "bool"_a=d["bool"].cast(), - "int"_a=d["int"].cast(), - "float"_a=d["float"].cast(), - "tuple"_a=d["tuple"].cast(), - "list"_a=d["list"].cast(), - "dict"_a=d["dict"].cast(), - "set"_a=d["set"].cast(), - "memoryview"_a=d["memoryview"].cast() - ); - }); - - m.def("get_implicit_casting", []() { - py::dict d; - d["char*_i1"] = "abc"; - const char *c2 = "abc"; - d["char*_i2"] = c2; - d["char*_e"] = py::cast(c2); - d["char*_p"] = py::str(c2); - - d["int_i1"] = 42; - int i = 42; - d["int_i2"] = i; - i++; - d["int_e"] = py::cast(i); - i++; - d["int_p"] = py::int_(i); - - d["str_i1"] = std::string("str"); - std::string s2("str1"); - d["str_i2"] = s2; - s2[3] = '2'; - d["str_e"] = py::cast(s2); - s2[3] = '3'; - d["str_p"] = py::str(s2); - - py::list l(2); - l[0] = 3; - l[1] = py::cast(6); - l.append(9); - l.append(py::cast(12)); - l.append(py::int_(15)); - - return py::dict( - "d"_a=d, - "l"_a=l - ); - }); - - // test_print - m.def("print_function", []() { - py::print("Hello, World!"); - py::print(1, 2.0, "three", true, std::string("-- multiple args")); - auto args = py::make_tuple("and", "a", "custom", "separator"); - py::print("*args", *args, "sep"_a="-"); - py::print("no new line here", "end"_a=" -- "); - py::print("next print"); - - auto py_stderr = py::module::import("sys").attr("stderr"); - py::print("this goes to stderr", "file"_a=py_stderr); - - py::print("flush", "flush"_a=true); - - py::print("{a} + {b} = {c}"_s.format("a"_a="py::print", "b"_a="str.format", "c"_a="this")); - }); - - m.def("print_failure", []() { py::print(42, UnregisteredType()); }); - - m.def("hash_function", [](py::object obj) { return py::hash(obj); }); - - m.def("test_number_protocol", [](py::object a, py::object b) { - py::list l; - l.append(a.equal(b)); - l.append(a.not_equal(b)); - l.append(a < b); - l.append(a <= b); - l.append(a > b); - l.append(a >= b); - l.append(a + b); - l.append(a - b); - l.append(a * b); - l.append(a / b); - l.append(a | b); - l.append(a & b); - l.append(a ^ b); - l.append(a >> b); - l.append(a << b); - return l; - }); - - m.def("test_list_slicing", [](py::list a) { - return a[py::slice(0, -1, 2)]; - }); -} diff --git a/pybind11/tests/test_pytypes.py b/pybind11/tests/test_pytypes.py deleted file mode 100644 index 0e8d6c3..0000000 --- a/pybind11/tests/test_pytypes.py +++ /dev/null @@ -1,263 +0,0 @@ -from __future__ import division -import pytest -import sys - -from pybind11_tests import pytypes as m -from pybind11_tests import debug_enabled - - -def test_list(capture, doc): - with capture: - lst = m.get_list() - assert lst == ["inserted-0", "overwritten", "inserted-2"] - - lst.append("value2") - m.print_list(lst) - assert capture.unordered == """ - Entry at position 0: value - list item 0: inserted-0 - list item 1: overwritten - list item 2: inserted-2 - list item 3: value2 - """ - - assert doc(m.get_list) == "get_list() -> list" - assert doc(m.print_list) == "print_list(arg0: list) -> None" - - -def test_set(capture, doc): - s = m.get_set() - assert s == {"key1", "key2", "key3"} - - with capture: - s.add("key4") - m.print_set(s) - assert capture.unordered == """ - key: key1 - key: key2 - key: key3 - key: key4 - """ - - assert not m.set_contains(set([]), 42) - assert m.set_contains({42}, 42) - assert m.set_contains({"foo"}, "foo") - - assert doc(m.get_list) == "get_list() -> list" - assert doc(m.print_list) == "print_list(arg0: list) -> None" - - -def test_dict(capture, doc): - d = m.get_dict() - assert d == {"key": "value"} - - with capture: - d["key2"] = "value2" - m.print_dict(d) - assert capture.unordered == """ - key: key, value=value - key: key2, value=value2 - """ - - assert not m.dict_contains({}, 42) - assert m.dict_contains({42: None}, 42) - assert m.dict_contains({"foo": None}, "foo") - - assert doc(m.get_dict) == "get_dict() -> dict" - assert doc(m.print_dict) == "print_dict(arg0: dict) -> None" - - assert m.dict_keyword_constructor() == {"x": 1, "y": 2, "z": 3} - - -def test_str(doc): - assert m.str_from_string().encode().decode() == "baz" - assert m.str_from_bytes().encode().decode() == "boo" - - assert doc(m.str_from_bytes) == "str_from_bytes() -> str" - - class A(object): - def __str__(self): - return "this is a str" - - def __repr__(self): - return "this is a repr" - - assert m.str_from_object(A()) == "this is a str" - assert m.repr_from_object(A()) == "this is a repr" - - s1, s2 = m.str_format() - assert s1 == "1 + 2 = 3" - assert s1 == s2 - - -def test_bytes(doc): - assert m.bytes_from_string().decode() == "foo" - assert m.bytes_from_str().decode() == "bar" - - assert doc(m.bytes_from_str) == "bytes_from_str() -> {}".format( - "bytes" if sys.version_info[0] == 3 else "str" - ) - - -def test_capsule(capture): - pytest.gc_collect() - with capture: - a = m.return_capsule_with_destructor() - del a - pytest.gc_collect() - assert capture.unordered == """ - creating capsule - destructing capsule - """ - - with capture: - a = m.return_capsule_with_destructor_2() - del a - pytest.gc_collect() - assert capture.unordered == """ - creating capsule - destructing capsule: 1234 - """ - - with capture: - a = m.return_capsule_with_name_and_destructor() - del a - pytest.gc_collect() - assert capture.unordered == """ - created capsule (1234, 'pointer type description') - destructing capsule (1234, 'pointer type description') - """ - - -def test_accessors(): - class SubTestObject: - attr_obj = 1 - attr_char = 2 - - class TestObject: - basic_attr = 1 - begin_end = [1, 2, 3] - d = {"operator[object]": 1, "operator[char *]": 2} - sub = SubTestObject() - - def func(self, x, *args): - return self.basic_attr + x + sum(args) - - d = m.accessor_api(TestObject()) - assert d["basic_attr"] == 1 - assert d["begin_end"] == [1, 2, 3] - assert d["operator[object]"] == 1 - assert d["operator[char *]"] == 2 - assert d["attr(object)"] == 1 - assert d["attr(char *)"] == 2 - assert d["missing_attr_ptr"] == "raised" - assert d["missing_attr_chain"] == "raised" - assert d["is_none"] is False - assert d["operator()"] == 2 - assert d["operator*"] == 7 - assert d["implicit_list"] == [1, 2, 3] - assert all(x in TestObject.__dict__ for x in d["implicit_dict"]) - - assert m.tuple_accessor(tuple()) == (0, 1, 2) - - d = m.accessor_assignment() - assert d["get"] == 0 - assert d["deferred_get"] == 0 - assert d["set"] == 1 - assert d["deferred_set"] == 1 - assert d["var"] == 99 - - -def test_constructors(): - """C++ default and converting constructors are equivalent to type calls in Python""" - types = [str, bool, int, float, tuple, list, dict, set] - expected = {t.__name__: t() for t in types} - assert m.default_constructors() == expected - - data = { - str: 42, - bool: "Not empty", - int: "42", - float: "+1e3", - tuple: range(3), - list: range(3), - dict: [("two", 2), ("one", 1), ("three", 3)], - set: [4, 4, 5, 6, 6, 6], - memoryview: b'abc' - } - inputs = {k.__name__: v for k, v in data.items()} - expected = {k.__name__: k(v) for k, v in data.items()} - - assert m.converting_constructors(inputs) == expected - assert m.cast_functions(inputs) == expected - - # Converting constructors and cast functions should just reference rather - # than copy when no conversion is needed: - noconv1 = m.converting_constructors(expected) - for k in noconv1: - assert noconv1[k] is expected[k] - - noconv2 = m.cast_functions(expected) - for k in noconv2: - assert noconv2[k] is expected[k] - - -def test_implicit_casting(): - """Tests implicit casting when assigning or appending to dicts and lists.""" - z = m.get_implicit_casting() - assert z['d'] == { - 'char*_i1': 'abc', 'char*_i2': 'abc', 'char*_e': 'abc', 'char*_p': 'abc', - 'str_i1': 'str', 'str_i2': 'str1', 'str_e': 'str2', 'str_p': 'str3', - 'int_i1': 42, 'int_i2': 42, 'int_e': 43, 'int_p': 44 - } - assert z['l'] == [3, 6, 9, 12, 15] - - -def test_print(capture): - with capture: - m.print_function() - assert capture == """ - Hello, World! - 1 2.0 three True -- multiple args - *args-and-a-custom-separator - no new line here -- next print - flush - py::print + str.format = this - """ - assert capture.stderr == "this goes to stderr" - - with pytest.raises(RuntimeError) as excinfo: - m.print_failure() - assert str(excinfo.value) == "make_tuple(): unable to convert " + ( - "argument of type 'UnregisteredType' to Python object" - if debug_enabled else - "arguments to Python object (compile in debug mode for details)" - ) - - -def test_hash(): - class Hashable(object): - def __init__(self, value): - self.value = value - - def __hash__(self): - return self.value - - class Unhashable(object): - __hash__ = None - - assert m.hash_function(Hashable(42)) == 42 - with pytest.raises(TypeError): - m.hash_function(Unhashable()) - - -def test_number_protocol(): - for a, b in [(1, 1), (3, 5)]: - li = [a == b, a != b, a < b, a <= b, a > b, a >= b, a + b, - a - b, a * b, a / b, a | b, a & b, a ^ b, a >> b, a << b] - assert m.test_number_protocol(a, b) == li - - -def test_list_slicing(): - li = list(range(100)) - assert li[::2] == m.test_list_slicing(li) diff --git a/pybind11/tests/test_sequences_and_iterators.cpp b/pybind11/tests/test_sequences_and_iterators.cpp deleted file mode 100644 index 87ccf99..0000000 --- a/pybind11/tests/test_sequences_and_iterators.cpp +++ /dev/null @@ -1,353 +0,0 @@ -/* - tests/test_sequences_and_iterators.cpp -- supporting Pythons' sequence protocol, iterators, - etc. - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" -#include "constructor_stats.h" -#include -#include - -template -class NonZeroIterator { - const T* ptr_; -public: - NonZeroIterator(const T* ptr) : ptr_(ptr) {} - const T& operator*() const { return *ptr_; } - NonZeroIterator& operator++() { ++ptr_; return *this; } -}; - -class NonZeroSentinel {}; - -template -bool operator==(const NonZeroIterator>& it, const NonZeroSentinel&) { - return !(*it).first || !(*it).second; -} - -template -py::list test_random_access_iterator(PythonType x) { - if (x.size() < 5) - throw py::value_error("Please provide at least 5 elements for testing."); - - auto checks = py::list(); - auto assert_equal = [&checks](py::handle a, py::handle b) { - auto result = PyObject_RichCompareBool(a.ptr(), b.ptr(), Py_EQ); - if (result == -1) { throw py::error_already_set(); } - checks.append(result != 0); - }; - - auto it = x.begin(); - assert_equal(x[0], *it); - assert_equal(x[0], it[0]); - assert_equal(x[1], it[1]); - - assert_equal(x[1], *(++it)); - assert_equal(x[1], *(it++)); - assert_equal(x[2], *it); - assert_equal(x[3], *(it += 1)); - assert_equal(x[2], *(--it)); - assert_equal(x[2], *(it--)); - assert_equal(x[1], *it); - assert_equal(x[0], *(it -= 1)); - - assert_equal(it->attr("real"), x[0].attr("real")); - assert_equal((it + 1)->attr("real"), x[1].attr("real")); - - assert_equal(x[1], *(it + 1)); - assert_equal(x[1], *(1 + it)); - it += 3; - assert_equal(x[1], *(it - 2)); - - checks.append(static_cast(x.end() - x.begin()) == x.size()); - checks.append((x.begin() + static_cast(x.size())) == x.end()); - checks.append(x.begin() < x.end()); - - return checks; -} - -TEST_SUBMODULE(sequences_and_iterators, m) { - // test_sliceable - class Sliceable{ - public: - Sliceable(int n): size(n) {} - int start,stop,step; - int size; - }; - py::class_(m,"Sliceable") - .def(py::init()) - .def("__getitem__",[](const Sliceable &s, py::slice slice) { - ssize_t start, stop, step, slicelength; - if (!slice.compute(s.size, &start, &stop, &step, &slicelength)) - throw py::error_already_set(); - int istart = static_cast(start); - int istop = static_cast(stop); - int istep = static_cast(step); - return std::make_tuple(istart,istop,istep); - }) - ; - - // test_sequence - class Sequence { - public: - Sequence(size_t size) : m_size(size) { - print_created(this, "of size", m_size); - m_data = new float[size]; - memset(m_data, 0, sizeof(float) * size); - } - Sequence(const std::vector &value) : m_size(value.size()) { - print_created(this, "of size", m_size, "from std::vector"); - m_data = new float[m_size]; - memcpy(m_data, &value[0], sizeof(float) * m_size); - } - Sequence(const Sequence &s) : m_size(s.m_size) { - print_copy_created(this); - m_data = new float[m_size]; - memcpy(m_data, s.m_data, sizeof(float)*m_size); - } - Sequence(Sequence &&s) : m_size(s.m_size), m_data(s.m_data) { - print_move_created(this); - s.m_size = 0; - s.m_data = nullptr; - } - - ~Sequence() { print_destroyed(this); delete[] m_data; } - - Sequence &operator=(const Sequence &s) { - if (&s != this) { - delete[] m_data; - m_size = s.m_size; - m_data = new float[m_size]; - memcpy(m_data, s.m_data, sizeof(float)*m_size); - } - print_copy_assigned(this); - return *this; - } - - Sequence &operator=(Sequence &&s) { - if (&s != this) { - delete[] m_data; - m_size = s.m_size; - m_data = s.m_data; - s.m_size = 0; - s.m_data = nullptr; - } - print_move_assigned(this); - return *this; - } - - bool operator==(const Sequence &s) const { - if (m_size != s.size()) return false; - for (size_t i = 0; i < m_size; ++i) - if (m_data[i] != s[i]) - return false; - return true; - } - bool operator!=(const Sequence &s) const { return !operator==(s); } - - float operator[](size_t index) const { return m_data[index]; } - float &operator[](size_t index) { return m_data[index]; } - - bool contains(float v) const { - for (size_t i = 0; i < m_size; ++i) - if (v == m_data[i]) - return true; - return false; - } - - Sequence reversed() const { - Sequence result(m_size); - for (size_t i = 0; i < m_size; ++i) - result[m_size - i - 1] = m_data[i]; - return result; - } - - size_t size() const { return m_size; } - - const float *begin() const { return m_data; } - const float *end() const { return m_data+m_size; } - - private: - size_t m_size; - float *m_data; - }; - py::class_(m, "Sequence") - .def(py::init()) - .def(py::init&>()) - /// Bare bones interface - .def("__getitem__", [](const Sequence &s, size_t i) { - if (i >= s.size()) throw py::index_error(); - return s[i]; - }) - .def("__setitem__", [](Sequence &s, size_t i, float v) { - if (i >= s.size()) throw py::index_error(); - s[i] = v; - }) - .def("__len__", &Sequence::size) - /// Optional sequence protocol operations - .def("__iter__", [](const Sequence &s) { return py::make_iterator(s.begin(), s.end()); }, - py::keep_alive<0, 1>() /* Essential: keep object alive while iterator exists */) - .def("__contains__", [](const Sequence &s, float v) { return s.contains(v); }) - .def("__reversed__", [](const Sequence &s) -> Sequence { return s.reversed(); }) - /// Slicing protocol (optional) - .def("__getitem__", [](const Sequence &s, py::slice slice) -> Sequence* { - size_t start, stop, step, slicelength; - if (!slice.compute(s.size(), &start, &stop, &step, &slicelength)) - throw py::error_already_set(); - Sequence *seq = new Sequence(slicelength); - for (size_t i = 0; i < slicelength; ++i) { - (*seq)[i] = s[start]; start += step; - } - return seq; - }) - .def("__setitem__", [](Sequence &s, py::slice slice, const Sequence &value) { - size_t start, stop, step, slicelength; - if (!slice.compute(s.size(), &start, &stop, &step, &slicelength)) - throw py::error_already_set(); - if (slicelength != value.size()) - throw std::runtime_error("Left and right hand size of slice assignment have different sizes!"); - for (size_t i = 0; i < slicelength; ++i) { - s[start] = value[i]; start += step; - } - }) - /// Comparisons - .def(py::self == py::self) - .def(py::self != py::self) - // Could also define py::self + py::self for concatenation, etc. - ; - - // test_map_iterator - // Interface of a map-like object that isn't (directly) an unordered_map, but provides some basic - // map-like functionality. - class StringMap { - public: - StringMap() = default; - StringMap(std::unordered_map init) - : map(std::move(init)) {} - - void set(std::string key, std::string val) { map[key] = val; } - std::string get(std::string key) const { return map.at(key); } - size_t size() const { return map.size(); } - private: - std::unordered_map map; - public: - decltype(map.cbegin()) begin() const { return map.cbegin(); } - decltype(map.cend()) end() const { return map.cend(); } - }; - py::class_(m, "StringMap") - .def(py::init<>()) - .def(py::init>()) - .def("__getitem__", [](const StringMap &map, std::string key) { - try { return map.get(key); } - catch (const std::out_of_range&) { - throw py::key_error("key '" + key + "' does not exist"); - } - }) - .def("__setitem__", &StringMap::set) - .def("__len__", &StringMap::size) - .def("__iter__", [](const StringMap &map) { return py::make_key_iterator(map.begin(), map.end()); }, - py::keep_alive<0, 1>()) - .def("items", [](const StringMap &map) { return py::make_iterator(map.begin(), map.end()); }, - py::keep_alive<0, 1>()) - ; - - // test_generalized_iterators - class IntPairs { - public: - IntPairs(std::vector> data) : data_(std::move(data)) {} - const std::pair* begin() const { return data_.data(); } - private: - std::vector> data_; - }; - py::class_(m, "IntPairs") - .def(py::init>>()) - .def("nonzero", [](const IntPairs& s) { - return py::make_iterator(NonZeroIterator>(s.begin()), NonZeroSentinel()); - }, py::keep_alive<0, 1>()) - .def("nonzero_keys", [](const IntPairs& s) { - return py::make_key_iterator(NonZeroIterator>(s.begin()), NonZeroSentinel()); - }, py::keep_alive<0, 1>()) - ; - - -#if 0 - // Obsolete: special data structure for exposing custom iterator types to python - // kept here for illustrative purposes because there might be some use cases which - // are not covered by the much simpler py::make_iterator - - struct PySequenceIterator { - PySequenceIterator(const Sequence &seq, py::object ref) : seq(seq), ref(ref) { } - - float next() { - if (index == seq.size()) - throw py::stop_iteration(); - return seq[index++]; - } - - const Sequence &seq; - py::object ref; // keep a reference - size_t index = 0; - }; - - py::class_(seq, "Iterator") - .def("__iter__", [](PySequenceIterator &it) -> PySequenceIterator& { return it; }) - .def("__next__", &PySequenceIterator::next); - - On the actual Sequence object, the iterator would be constructed as follows: - .def("__iter__", [](py::object s) { return PySequenceIterator(s.cast(), s); }) -#endif - - // test_python_iterator_in_cpp - m.def("object_to_list", [](py::object o) { - auto l = py::list(); - for (auto item : o) { - l.append(item); - } - return l; - }); - - m.def("iterator_to_list", [](py::iterator it) { - auto l = py::list(); - while (it != py::iterator::sentinel()) { - l.append(*it); - ++it; - } - return l; - }); - - // Make sure that py::iterator works with std algorithms - m.def("count_none", [](py::object o) { - return std::count_if(o.begin(), o.end(), [](py::handle h) { return h.is_none(); }); - }); - - m.def("find_none", [](py::object o) { - auto it = std::find_if(o.begin(), o.end(), [](py::handle h) { return h.is_none(); }); - return it->is_none(); - }); - - m.def("count_nonzeros", [](py::dict d) { - return std::count_if(d.begin(), d.end(), [](std::pair p) { - return p.second.cast() != 0; - }); - }); - - m.def("tuple_iterator", &test_random_access_iterator); - m.def("list_iterator", &test_random_access_iterator); - m.def("sequence_iterator", &test_random_access_iterator); - - // test_iterator_passthrough - // #181: iterator passthrough did not compile - m.def("iterator_passthrough", [](py::iterator s) -> py::iterator { - return py::make_iterator(std::begin(s), std::end(s)); - }); - - // test_iterator_rvp - // #388: Can't make iterators via make_iterator() with different r/v policies - static std::vector list = { 1, 2, 3 }; - m.def("make_iterator_1", []() { return py::make_iterator(list); }); - m.def("make_iterator_2", []() { return py::make_iterator(list); }); -} diff --git a/pybind11/tests/test_sequences_and_iterators.py b/pybind11/tests/test_sequences_and_iterators.py deleted file mode 100644 index 6bd1606..0000000 --- a/pybind11/tests/test_sequences_and_iterators.py +++ /dev/null @@ -1,171 +0,0 @@ -import pytest -from pybind11_tests import sequences_and_iterators as m -from pybind11_tests import ConstructorStats - - -def isclose(a, b, rel_tol=1e-05, abs_tol=0.0): - """Like math.isclose() from Python 3.5""" - return abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol) - - -def allclose(a_list, b_list, rel_tol=1e-05, abs_tol=0.0): - return all(isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol) for a, b in zip(a_list, b_list)) - - -def test_generalized_iterators(): - assert list(m.IntPairs([(1, 2), (3, 4), (0, 5)]).nonzero()) == [(1, 2), (3, 4)] - assert list(m.IntPairs([(1, 2), (2, 0), (0, 3), (4, 5)]).nonzero()) == [(1, 2)] - assert list(m.IntPairs([(0, 3), (1, 2), (3, 4)]).nonzero()) == [] - - assert list(m.IntPairs([(1, 2), (3, 4), (0, 5)]).nonzero_keys()) == [1, 3] - assert list(m.IntPairs([(1, 2), (2, 0), (0, 3), (4, 5)]).nonzero_keys()) == [1] - assert list(m.IntPairs([(0, 3), (1, 2), (3, 4)]).nonzero_keys()) == [] - - # __next__ must continue to raise StopIteration - it = m.IntPairs([(0, 0)]).nonzero() - for _ in range(3): - with pytest.raises(StopIteration): - next(it) - - it = m.IntPairs([(0, 0)]).nonzero_keys() - for _ in range(3): - with pytest.raises(StopIteration): - next(it) - - -def test_sliceable(): - sliceable = m.Sliceable(100) - assert sliceable[::] == (0, 100, 1) - assert sliceable[10::] == (10, 100, 1) - assert sliceable[:10:] == (0, 10, 1) - assert sliceable[::10] == (0, 100, 10) - assert sliceable[-10::] == (90, 100, 1) - assert sliceable[:-10:] == (0, 90, 1) - assert sliceable[::-10] == (99, -1, -10) - assert sliceable[50:60:1] == (50, 60, 1) - assert sliceable[50:60:-1] == (50, 60, -1) - - -def test_sequence(): - cstats = ConstructorStats.get(m.Sequence) - - s = m.Sequence(5) - assert cstats.values() == ['of size', '5'] - - assert "Sequence" in repr(s) - assert len(s) == 5 - assert s[0] == 0 and s[3] == 0 - assert 12.34 not in s - s[0], s[3] = 12.34, 56.78 - assert 12.34 in s - assert isclose(s[0], 12.34) and isclose(s[3], 56.78) - - rev = reversed(s) - assert cstats.values() == ['of size', '5'] - - rev2 = s[::-1] - assert cstats.values() == ['of size', '5'] - - it = iter(m.Sequence(0)) - for _ in range(3): # __next__ must continue to raise StopIteration - with pytest.raises(StopIteration): - next(it) - assert cstats.values() == ['of size', '0'] - - expected = [0, 56.78, 0, 0, 12.34] - assert allclose(rev, expected) - assert allclose(rev2, expected) - assert rev == rev2 - - rev[0::2] = m.Sequence([2.0, 2.0, 2.0]) - assert cstats.values() == ['of size', '3', 'from std::vector'] - - assert allclose(rev, [2, 56.78, 2, 0, 2]) - - assert cstats.alive() == 4 - del it - assert cstats.alive() == 3 - del s - assert cstats.alive() == 2 - del rev - assert cstats.alive() == 1 - del rev2 - assert cstats.alive() == 0 - - assert cstats.values() == [] - assert cstats.default_constructions == 0 - assert cstats.copy_constructions == 0 - assert cstats.move_constructions >= 1 - assert cstats.copy_assignments == 0 - assert cstats.move_assignments == 0 - - -def test_map_iterator(): - sm = m.StringMap({'hi': 'bye', 'black': 'white'}) - assert sm['hi'] == 'bye' - assert len(sm) == 2 - assert sm['black'] == 'white' - - with pytest.raises(KeyError): - assert sm['orange'] - sm['orange'] = 'banana' - assert sm['orange'] == 'banana' - - expected = {'hi': 'bye', 'black': 'white', 'orange': 'banana'} - for k in sm: - assert sm[k] == expected[k] - for k, v in sm.items(): - assert v == expected[k] - - it = iter(m.StringMap({})) - for _ in range(3): # __next__ must continue to raise StopIteration - with pytest.raises(StopIteration): - next(it) - - -def test_python_iterator_in_cpp(): - t = (1, 2, 3) - assert m.object_to_list(t) == [1, 2, 3] - assert m.object_to_list(iter(t)) == [1, 2, 3] - assert m.iterator_to_list(iter(t)) == [1, 2, 3] - - with pytest.raises(TypeError) as excinfo: - m.object_to_list(1) - assert "object is not iterable" in str(excinfo.value) - - with pytest.raises(TypeError) as excinfo: - m.iterator_to_list(1) - assert "incompatible function arguments" in str(excinfo.value) - - def bad_next_call(): - raise RuntimeError("py::iterator::advance() should propagate errors") - - with pytest.raises(RuntimeError) as excinfo: - m.iterator_to_list(iter(bad_next_call, None)) - assert str(excinfo.value) == "py::iterator::advance() should propagate errors" - - lst = [1, None, 0, None] - assert m.count_none(lst) == 2 - assert m.find_none(lst) is True - assert m.count_nonzeros({"a": 0, "b": 1, "c": 2}) == 2 - - r = range(5) - assert all(m.tuple_iterator(tuple(r))) - assert all(m.list_iterator(list(r))) - assert all(m.sequence_iterator(r)) - - -def test_iterator_passthrough(): - """#181: iterator passthrough did not compile""" - from pybind11_tests.sequences_and_iterators import iterator_passthrough - - assert list(iterator_passthrough(iter([3, 5, 7, 9, 11, 13, 15]))) == [3, 5, 7, 9, 11, 13, 15] - - -def test_iterator_rvp(): - """#388: Can't make iterators via make_iterator() with different r/v policies """ - import pybind11_tests.sequences_and_iterators as m - - assert list(m.make_iterator_1()) == [1, 2, 3] - assert list(m.make_iterator_2()) == [1, 2, 3] - assert not isinstance(m.make_iterator_1(), type(m.make_iterator_2())) diff --git a/pybind11/tests/test_smart_ptr.cpp b/pybind11/tests/test_smart_ptr.cpp deleted file mode 100644 index 87c9be8..0000000 --- a/pybind11/tests/test_smart_ptr.cpp +++ /dev/null @@ -1,366 +0,0 @@ -/* - tests/test_smart_ptr.cpp -- binding classes with custom reference counting, - implicit conversions between types - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#if defined(_MSC_VER) && _MSC_VER < 1910 -# pragma warning(disable: 4702) // unreachable code in system header -#endif - -#include "pybind11_tests.h" -#include "object.h" - -// Make pybind aware of the ref-counted wrapper type (s): - -// ref is a wrapper for 'Object' which uses intrusive reference counting -// It is always possible to construct a ref from an Object* pointer without -// possible inconsistencies, hence the 'true' argument at the end. -PYBIND11_DECLARE_HOLDER_TYPE(T, ref, true); -// Make pybind11 aware of the non-standard getter member function -namespace pybind11 { namespace detail { - template - struct holder_helper> { - static const T *get(const ref &p) { return p.get_ptr(); } - }; -}} - -// The following is not required anymore for std::shared_ptr, but it should compile without error: -PYBIND11_DECLARE_HOLDER_TYPE(T, std::shared_ptr); - -// This is just a wrapper around unique_ptr, but with extra fields to deliberately bloat up the -// holder size to trigger the non-simple-layout internal instance layout for single inheritance with -// large holder type: -template class huge_unique_ptr { - std::unique_ptr ptr; - uint64_t padding[10]; -public: - huge_unique_ptr(T *p) : ptr(p) {}; - T *get() { return ptr.get(); } -}; -PYBIND11_DECLARE_HOLDER_TYPE(T, huge_unique_ptr); - -// Simple custom holder that works like unique_ptr -template -class custom_unique_ptr { - std::unique_ptr impl; -public: - custom_unique_ptr(T* p) : impl(p) { } - T* get() const { return impl.get(); } - T* release_ptr() { return impl.release(); } -}; -PYBIND11_DECLARE_HOLDER_TYPE(T, custom_unique_ptr); - -// Simple custom holder that works like shared_ptr and has operator& overload -// To obtain address of an instance of this holder pybind should use std::addressof -// Attempt to get address via operator& may leads to segmentation fault -template -class shared_ptr_with_addressof_operator { - std::shared_ptr impl; -public: - shared_ptr_with_addressof_operator( ) = default; - shared_ptr_with_addressof_operator(T* p) : impl(p) { } - T* get() const { return impl.get(); } - T** operator&() { throw std::logic_error("Call of overloaded operator& is not expected"); } -}; -PYBIND11_DECLARE_HOLDER_TYPE(T, shared_ptr_with_addressof_operator); - -// Simple custom holder that works like unique_ptr and has operator& overload -// To obtain address of an instance of this holder pybind should use std::addressof -// Attempt to get address via operator& may leads to segmentation fault -template -class unique_ptr_with_addressof_operator { - std::unique_ptr impl; -public: - unique_ptr_with_addressof_operator() = default; - unique_ptr_with_addressof_operator(T* p) : impl(p) { } - T* get() const { return impl.get(); } - T* release_ptr() { return impl.release(); } - T** operator&() { throw std::logic_error("Call of overloaded operator& is not expected"); } -}; -PYBIND11_DECLARE_HOLDER_TYPE(T, unique_ptr_with_addressof_operator); - - -TEST_SUBMODULE(smart_ptr, m) { - - // test_smart_ptr - - // Object implementation in `object.h` - py::class_> obj(m, "Object"); - obj.def("getRefCount", &Object::getRefCount); - - // Custom object with builtin reference counting (see 'object.h' for the implementation) - class MyObject1 : public Object { - public: - MyObject1(int value) : value(value) { print_created(this, toString()); } - std::string toString() const { return "MyObject1[" + std::to_string(value) + "]"; } - protected: - virtual ~MyObject1() { print_destroyed(this); } - private: - int value; - }; - py::class_>(m, "MyObject1", obj) - .def(py::init()); - py::implicitly_convertible(); - - m.def("make_object_1", []() -> Object * { return new MyObject1(1); }); - m.def("make_object_2", []() -> ref { return new MyObject1(2); }); - m.def("make_myobject1_1", []() -> MyObject1 * { return new MyObject1(4); }); - m.def("make_myobject1_2", []() -> ref { return new MyObject1(5); }); - m.def("print_object_1", [](const Object *obj) { py::print(obj->toString()); }); - m.def("print_object_2", [](ref obj) { py::print(obj->toString()); }); - m.def("print_object_3", [](const ref &obj) { py::print(obj->toString()); }); - m.def("print_object_4", [](const ref *obj) { py::print((*obj)->toString()); }); - m.def("print_myobject1_1", [](const MyObject1 *obj) { py::print(obj->toString()); }); - m.def("print_myobject1_2", [](ref obj) { py::print(obj->toString()); }); - m.def("print_myobject1_3", [](const ref &obj) { py::print(obj->toString()); }); - m.def("print_myobject1_4", [](const ref *obj) { py::print((*obj)->toString()); }); - - // Expose constructor stats for the ref type - m.def("cstats_ref", &ConstructorStats::get); - - - // Object managed by a std::shared_ptr<> - class MyObject2 { - public: - MyObject2(const MyObject2 &) = default; - MyObject2(int value) : value(value) { print_created(this, toString()); } - std::string toString() const { return "MyObject2[" + std::to_string(value) + "]"; } - virtual ~MyObject2() { print_destroyed(this); } - private: - int value; - }; - py::class_>(m, "MyObject2") - .def(py::init()); - m.def("make_myobject2_1", []() { return new MyObject2(6); }); - m.def("make_myobject2_2", []() { return std::make_shared(7); }); - m.def("print_myobject2_1", [](const MyObject2 *obj) { py::print(obj->toString()); }); - m.def("print_myobject2_2", [](std::shared_ptr obj) { py::print(obj->toString()); }); - m.def("print_myobject2_3", [](const std::shared_ptr &obj) { py::print(obj->toString()); }); - m.def("print_myobject2_4", [](const std::shared_ptr *obj) { py::print((*obj)->toString()); }); - - // Object managed by a std::shared_ptr<>, additionally derives from std::enable_shared_from_this<> - class MyObject3 : public std::enable_shared_from_this { - public: - MyObject3(const MyObject3 &) = default; - MyObject3(int value) : value(value) { print_created(this, toString()); } - std::string toString() const { return "MyObject3[" + std::to_string(value) + "]"; } - virtual ~MyObject3() { print_destroyed(this); } - private: - int value; - }; - py::class_>(m, "MyObject3") - .def(py::init()); - m.def("make_myobject3_1", []() { return new MyObject3(8); }); - m.def("make_myobject3_2", []() { return std::make_shared(9); }); - m.def("print_myobject3_1", [](const MyObject3 *obj) { py::print(obj->toString()); }); - m.def("print_myobject3_2", [](std::shared_ptr obj) { py::print(obj->toString()); }); - m.def("print_myobject3_3", [](const std::shared_ptr &obj) { py::print(obj->toString()); }); - m.def("print_myobject3_4", [](const std::shared_ptr *obj) { py::print((*obj)->toString()); }); - - // test_smart_ptr_refcounting - m.def("test_object1_refcounting", []() { - ref o = new MyObject1(0); - bool good = o->getRefCount() == 1; - py::object o2 = py::cast(o, py::return_value_policy::reference); - // always request (partial) ownership for objects with intrusive - // reference counting even when using the 'reference' RVP - good &= o->getRefCount() == 2; - return good; - }); - - // test_unique_nodelete - // Object with a private destructor - class MyObject4 { - public: - MyObject4(int value) : value{value} { print_created(this); } - int value; - private: - ~MyObject4() { print_destroyed(this); } - }; - py::class_>(m, "MyObject4") - .def(py::init()) - .def_readwrite("value", &MyObject4::value); - - // test_unique_deleter - // Object with std::unique_ptr where D is not matching the base class - // Object with a protected destructor - class MyObject4a { - public: - MyObject4a(int i) { - value = i; - print_created(this); - }; - int value; - protected: - virtual ~MyObject4a() { print_destroyed(this); } - }; - py::class_>(m, "MyObject4a") - .def(py::init()) - .def_readwrite("value", &MyObject4a::value); - - // Object derived but with public destructor and no Deleter in default holder - class MyObject4b : public MyObject4a { - public: - MyObject4b(int i) : MyObject4a(i) { print_created(this); } - ~MyObject4b() { print_destroyed(this); } - }; - py::class_(m, "MyObject4b") - .def(py::init()); - - // test_large_holder - class MyObject5 { // managed by huge_unique_ptr - public: - MyObject5(int value) : value{value} { print_created(this); } - ~MyObject5() { print_destroyed(this); } - int value; - }; - py::class_>(m, "MyObject5") - .def(py::init()) - .def_readwrite("value", &MyObject5::value); - - // test_shared_ptr_and_references - struct SharedPtrRef { - struct A { - A() { print_created(this); } - A(const A &) { print_copy_created(this); } - A(A &&) { print_move_created(this); } - ~A() { print_destroyed(this); } - }; - - A value = {}; - std::shared_ptr shared = std::make_shared(); - }; - using A = SharedPtrRef::A; - py::class_>(m, "A"); - py::class_(m, "SharedPtrRef") - .def(py::init<>()) - .def_readonly("ref", &SharedPtrRef::value) - .def_property_readonly("copy", [](const SharedPtrRef &s) { return s.value; }, - py::return_value_policy::copy) - .def_readonly("holder_ref", &SharedPtrRef::shared) - .def_property_readonly("holder_copy", [](const SharedPtrRef &s) { return s.shared; }, - py::return_value_policy::copy) - .def("set_ref", [](SharedPtrRef &, const A &) { return true; }) - .def("set_holder", [](SharedPtrRef &, std::shared_ptr) { return true; }); - - // test_shared_ptr_from_this_and_references - struct SharedFromThisRef { - struct B : std::enable_shared_from_this { - B() { print_created(this); } - B(const B &) : std::enable_shared_from_this() { print_copy_created(this); } - B(B &&) : std::enable_shared_from_this() { print_move_created(this); } - ~B() { print_destroyed(this); } - }; - - B value = {}; - std::shared_ptr shared = std::make_shared(); - }; - using B = SharedFromThisRef::B; - py::class_>(m, "B"); - py::class_(m, "SharedFromThisRef") - .def(py::init<>()) - .def_readonly("bad_wp", &SharedFromThisRef::value) - .def_property_readonly("ref", [](const SharedFromThisRef &s) -> const B & { return *s.shared; }) - .def_property_readonly("copy", [](const SharedFromThisRef &s) { return s.value; }, - py::return_value_policy::copy) - .def_readonly("holder_ref", &SharedFromThisRef::shared) - .def_property_readonly("holder_copy", [](const SharedFromThisRef &s) { return s.shared; }, - py::return_value_policy::copy) - .def("set_ref", [](SharedFromThisRef &, const B &) { return true; }) - .def("set_holder", [](SharedFromThisRef &, std::shared_ptr) { return true; }); - - // Issue #865: shared_from_this doesn't work with virtual inheritance - struct SharedFromThisVBase : std::enable_shared_from_this { - SharedFromThisVBase() = default; - SharedFromThisVBase(const SharedFromThisVBase &) = default; - virtual ~SharedFromThisVBase() = default; - }; - struct SharedFromThisVirt : virtual SharedFromThisVBase {}; - static std::shared_ptr sft(new SharedFromThisVirt()); - py::class_>(m, "SharedFromThisVirt") - .def_static("get", []() { return sft.get(); }); - - // test_move_only_holder - struct C { - C() { print_created(this); } - ~C() { print_destroyed(this); } - }; - py::class_>(m, "TypeWithMoveOnlyHolder") - .def_static("make", []() { return custom_unique_ptr(new C); }); - - // test_holder_with_addressof_operator - struct TypeForHolderWithAddressOf { - TypeForHolderWithAddressOf() { print_created(this); } - TypeForHolderWithAddressOf(const TypeForHolderWithAddressOf &) { print_copy_created(this); } - TypeForHolderWithAddressOf(TypeForHolderWithAddressOf &&) { print_move_created(this); } - ~TypeForHolderWithAddressOf() { print_destroyed(this); } - std::string toString() const { - return "TypeForHolderWithAddressOf[" + std::to_string(value) + "]"; - } - int value = 42; - }; - using HolderWithAddressOf = shared_ptr_with_addressof_operator; - py::class_(m, "TypeForHolderWithAddressOf") - .def_static("make", []() { return HolderWithAddressOf(new TypeForHolderWithAddressOf); }) - .def("get", [](const HolderWithAddressOf &self) { return self.get(); }) - .def("print_object_1", [](const TypeForHolderWithAddressOf *obj) { py::print(obj->toString()); }) - .def("print_object_2", [](HolderWithAddressOf obj) { py::print(obj.get()->toString()); }) - .def("print_object_3", [](const HolderWithAddressOf &obj) { py::print(obj.get()->toString()); }) - .def("print_object_4", [](const HolderWithAddressOf *obj) { py::print((*obj).get()->toString()); }); - - // test_move_only_holder_with_addressof_operator - struct TypeForMoveOnlyHolderWithAddressOf { - TypeForMoveOnlyHolderWithAddressOf(int value) : value{value} { print_created(this); } - ~TypeForMoveOnlyHolderWithAddressOf() { print_destroyed(this); } - std::string toString() const { - return "MoveOnlyHolderWithAddressOf[" + std::to_string(value) + "]"; - } - int value; - }; - using MoveOnlyHolderWithAddressOf = unique_ptr_with_addressof_operator; - py::class_(m, "TypeForMoveOnlyHolderWithAddressOf") - .def_static("make", []() { return MoveOnlyHolderWithAddressOf(new TypeForMoveOnlyHolderWithAddressOf(0)); }) - .def_readwrite("value", &TypeForMoveOnlyHolderWithAddressOf::value) - .def("print_object", [](const TypeForMoveOnlyHolderWithAddressOf *obj) { py::print(obj->toString()); }); - - // test_smart_ptr_from_default - struct HeldByDefaultHolder { }; - py::class_(m, "HeldByDefaultHolder") - .def(py::init<>()) - .def_static("load_shared_ptr", [](std::shared_ptr) {}); - - // test_shared_ptr_gc - // #187: issue involving std::shared_ptr<> return value policy & garbage collection - struct ElementBase { - virtual ~ElementBase() { } /* Force creation of virtual table */ - }; - py::class_>(m, "ElementBase"); - - struct ElementA : ElementBase { - ElementA(int v) : v(v) { } - int value() { return v; } - int v; - }; - py::class_>(m, "ElementA") - .def(py::init()) - .def("value", &ElementA::value); - - struct ElementList { - void add(std::shared_ptr e) { l.push_back(e); } - std::vector> l; - }; - py::class_>(m, "ElementList") - .def(py::init<>()) - .def("add", &ElementList::add) - .def("get", [](ElementList &el) { - py::list list; - for (auto &e : el.l) - list.append(py::cast(e)); - return list; - }); -} diff --git a/pybind11/tests/test_smart_ptr.py b/pybind11/tests/test_smart_ptr.py deleted file mode 100644 index c662704..0000000 --- a/pybind11/tests/test_smart_ptr.py +++ /dev/null @@ -1,286 +0,0 @@ -import pytest -from pybind11_tests import smart_ptr as m -from pybind11_tests import ConstructorStats - - -def test_smart_ptr(capture): - # Object1 - for i, o in enumerate([m.make_object_1(), m.make_object_2(), m.MyObject1(3)], start=1): - assert o.getRefCount() == 1 - with capture: - m.print_object_1(o) - m.print_object_2(o) - m.print_object_3(o) - m.print_object_4(o) - assert capture == "MyObject1[{i}]\n".format(i=i) * 4 - - for i, o in enumerate([m.make_myobject1_1(), m.make_myobject1_2(), m.MyObject1(6), 7], - start=4): - print(o) - with capture: - if not isinstance(o, int): - m.print_object_1(o) - m.print_object_2(o) - m.print_object_3(o) - m.print_object_4(o) - m.print_myobject1_1(o) - m.print_myobject1_2(o) - m.print_myobject1_3(o) - m.print_myobject1_4(o) - assert capture == "MyObject1[{i}]\n".format(i=i) * (4 if isinstance(o, int) else 8) - - cstats = ConstructorStats.get(m.MyObject1) - assert cstats.alive() == 0 - expected_values = ['MyObject1[{}]'.format(i) for i in range(1, 7)] + ['MyObject1[7]'] * 4 - assert cstats.values() == expected_values - assert cstats.default_constructions == 0 - assert cstats.copy_constructions == 0 - # assert cstats.move_constructions >= 0 # Doesn't invoke any - assert cstats.copy_assignments == 0 - assert cstats.move_assignments == 0 - - # Object2 - for i, o in zip([8, 6, 7], [m.MyObject2(8), m.make_myobject2_1(), m.make_myobject2_2()]): - print(o) - with capture: - m.print_myobject2_1(o) - m.print_myobject2_2(o) - m.print_myobject2_3(o) - m.print_myobject2_4(o) - assert capture == "MyObject2[{i}]\n".format(i=i) * 4 - - cstats = ConstructorStats.get(m.MyObject2) - assert cstats.alive() == 1 - o = None - assert cstats.alive() == 0 - assert cstats.values() == ['MyObject2[8]', 'MyObject2[6]', 'MyObject2[7]'] - assert cstats.default_constructions == 0 - assert cstats.copy_constructions == 0 - # assert cstats.move_constructions >= 0 # Doesn't invoke any - assert cstats.copy_assignments == 0 - assert cstats.move_assignments == 0 - - # Object3 - for i, o in zip([9, 8, 9], [m.MyObject3(9), m.make_myobject3_1(), m.make_myobject3_2()]): - print(o) - with capture: - m.print_myobject3_1(o) - m.print_myobject3_2(o) - m.print_myobject3_3(o) - m.print_myobject3_4(o) - assert capture == "MyObject3[{i}]\n".format(i=i) * 4 - - cstats = ConstructorStats.get(m.MyObject3) - assert cstats.alive() == 1 - o = None - assert cstats.alive() == 0 - assert cstats.values() == ['MyObject3[9]', 'MyObject3[8]', 'MyObject3[9]'] - assert cstats.default_constructions == 0 - assert cstats.copy_constructions == 0 - # assert cstats.move_constructions >= 0 # Doesn't invoke any - assert cstats.copy_assignments == 0 - assert cstats.move_assignments == 0 - - # Object - cstats = ConstructorStats.get(m.Object) - assert cstats.alive() == 0 - assert cstats.values() == [] - assert cstats.default_constructions == 10 - assert cstats.copy_constructions == 0 - # assert cstats.move_constructions >= 0 # Doesn't invoke any - assert cstats.copy_assignments == 0 - assert cstats.move_assignments == 0 - - # ref<> - cstats = m.cstats_ref() - assert cstats.alive() == 0 - assert cstats.values() == ['from pointer'] * 10 - assert cstats.default_constructions == 30 - assert cstats.copy_constructions == 12 - # assert cstats.move_constructions >= 0 # Doesn't invoke any - assert cstats.copy_assignments == 30 - assert cstats.move_assignments == 0 - - -def test_smart_ptr_refcounting(): - assert m.test_object1_refcounting() - - -def test_unique_nodelete(): - o = m.MyObject4(23) - assert o.value == 23 - cstats = ConstructorStats.get(m.MyObject4) - assert cstats.alive() == 1 - del o - assert cstats.alive() == 1 # Leak, but that's intentional - - -def test_unique_nodelete4a(): - o = m.MyObject4a(23) - assert o.value == 23 - cstats = ConstructorStats.get(m.MyObject4a) - assert cstats.alive() == 1 - del o - assert cstats.alive() == 1 # Leak, but that's intentional - - -def test_unique_deleter(): - o = m.MyObject4b(23) - assert o.value == 23 - cstats4a = ConstructorStats.get(m.MyObject4a) - assert cstats4a.alive() == 2 # Two because of previous test - cstats4b = ConstructorStats.get(m.MyObject4b) - assert cstats4b.alive() == 1 - del o - assert cstats4a.alive() == 1 # Should now only be one leftover from previous test - assert cstats4b.alive() == 0 # Should be deleted - - -def test_large_holder(): - o = m.MyObject5(5) - assert o.value == 5 - cstats = ConstructorStats.get(m.MyObject5) - assert cstats.alive() == 1 - del o - assert cstats.alive() == 0 - - -def test_shared_ptr_and_references(): - s = m.SharedPtrRef() - stats = ConstructorStats.get(m.A) - assert stats.alive() == 2 - - ref = s.ref # init_holder_helper(holder_ptr=false, owned=false) - assert stats.alive() == 2 - assert s.set_ref(ref) - with pytest.raises(RuntimeError) as excinfo: - assert s.set_holder(ref) - assert "Unable to cast from non-held to held instance" in str(excinfo.value) - - copy = s.copy # init_holder_helper(holder_ptr=false, owned=true) - assert stats.alive() == 3 - assert s.set_ref(copy) - assert s.set_holder(copy) - - holder_ref = s.holder_ref # init_holder_helper(holder_ptr=true, owned=false) - assert stats.alive() == 3 - assert s.set_ref(holder_ref) - assert s.set_holder(holder_ref) - - holder_copy = s.holder_copy # init_holder_helper(holder_ptr=true, owned=true) - assert stats.alive() == 3 - assert s.set_ref(holder_copy) - assert s.set_holder(holder_copy) - - del ref, copy, holder_ref, holder_copy, s - assert stats.alive() == 0 - - -def test_shared_ptr_from_this_and_references(): - s = m.SharedFromThisRef() - stats = ConstructorStats.get(m.B) - assert stats.alive() == 2 - - ref = s.ref # init_holder_helper(holder_ptr=false, owned=false, bad_wp=false) - assert stats.alive() == 2 - assert s.set_ref(ref) - assert s.set_holder(ref) # std::enable_shared_from_this can create a holder from a reference - - bad_wp = s.bad_wp # init_holder_helper(holder_ptr=false, owned=false, bad_wp=true) - assert stats.alive() == 2 - assert s.set_ref(bad_wp) - with pytest.raises(RuntimeError) as excinfo: - assert s.set_holder(bad_wp) - assert "Unable to cast from non-held to held instance" in str(excinfo.value) - - copy = s.copy # init_holder_helper(holder_ptr=false, owned=true, bad_wp=false) - assert stats.alive() == 3 - assert s.set_ref(copy) - assert s.set_holder(copy) - - holder_ref = s.holder_ref # init_holder_helper(holder_ptr=true, owned=false, bad_wp=false) - assert stats.alive() == 3 - assert s.set_ref(holder_ref) - assert s.set_holder(holder_ref) - - holder_copy = s.holder_copy # init_holder_helper(holder_ptr=true, owned=true, bad_wp=false) - assert stats.alive() == 3 - assert s.set_ref(holder_copy) - assert s.set_holder(holder_copy) - - del ref, bad_wp, copy, holder_ref, holder_copy, s - assert stats.alive() == 0 - - z = m.SharedFromThisVirt.get() - y = m.SharedFromThisVirt.get() - assert y is z - - -def test_move_only_holder(): - a = m.TypeWithMoveOnlyHolder.make() - stats = ConstructorStats.get(m.TypeWithMoveOnlyHolder) - assert stats.alive() == 1 - del a - assert stats.alive() == 0 - - -def test_holder_with_addressof_operator(): - # this test must not throw exception from c++ - a = m.TypeForHolderWithAddressOf.make() - a.print_object_1() - a.print_object_2() - a.print_object_3() - a.print_object_4() - - stats = ConstructorStats.get(m.TypeForHolderWithAddressOf) - assert stats.alive() == 1 - - np = m.TypeForHolderWithAddressOf.make() - assert stats.alive() == 2 - del a - assert stats.alive() == 1 - del np - assert stats.alive() == 0 - - b = m.TypeForHolderWithAddressOf.make() - c = b - assert b.get() is c.get() - assert stats.alive() == 1 - - del b - assert stats.alive() == 1 - - del c - assert stats.alive() == 0 - - -def test_move_only_holder_with_addressof_operator(): - a = m.TypeForMoveOnlyHolderWithAddressOf.make() - a.print_object() - - stats = ConstructorStats.get(m.TypeForMoveOnlyHolderWithAddressOf) - assert stats.alive() == 1 - - a.value = 42 - assert a.value == 42 - - del a - assert stats.alive() == 0 - - -def test_smart_ptr_from_default(): - instance = m.HeldByDefaultHolder() - with pytest.raises(RuntimeError) as excinfo: - m.HeldByDefaultHolder.load_shared_ptr(instance) - assert "Unable to load a custom holder type from a " \ - "default-holder instance" in str(excinfo.value) - - -def test_shared_ptr_gc(): - """#187: issue involving std::shared_ptr<> return value policy & garbage collection""" - el = m.ElementList() - for i in range(10): - el.add(m.ElementA(i)) - pytest.gc_collect() - for i, v in enumerate(el.get()): - assert i == v.value() diff --git a/pybind11/tests/test_stl.cpp b/pybind11/tests/test_stl.cpp deleted file mode 100644 index 207c9fb..0000000 --- a/pybind11/tests/test_stl.cpp +++ /dev/null @@ -1,284 +0,0 @@ -/* - tests/test_stl.cpp -- STL type casters - - Copyright (c) 2017 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" -#include "constructor_stats.h" -#include - -#include -#include - -// Test with `std::variant` in C++17 mode, or with `boost::variant` in C++11/14 -#if PYBIND11_HAS_VARIANT -using std::variant; -#elif defined(PYBIND11_TEST_BOOST) && (!defined(_MSC_VER) || _MSC_VER >= 1910) -# include -# define PYBIND11_HAS_VARIANT 1 -using boost::variant; - -namespace pybind11 { namespace detail { -template -struct type_caster> : variant_caster> {}; - -template <> -struct visit_helper { - template - static auto call(Args &&...args) -> decltype(boost::apply_visitor(args...)) { - return boost::apply_visitor(args...); - } -}; -}} // namespace pybind11::detail -#endif - -PYBIND11_MAKE_OPAQUE(std::vector>); - -/// Issue #528: templated constructor -struct TplCtorClass { - template TplCtorClass(const T &) { } - bool operator==(const TplCtorClass &) const { return true; } -}; - -namespace std { - template <> - struct hash { size_t operator()(const TplCtorClass &) const { return 0; } }; -} - - -TEST_SUBMODULE(stl, m) { - // test_vector - m.def("cast_vector", []() { return std::vector{1}; }); - m.def("load_vector", [](const std::vector &v) { return v.at(0) == 1 && v.at(1) == 2; }); - // `std::vector` is special because it returns proxy objects instead of references - m.def("cast_bool_vector", []() { return std::vector{true, false}; }); - m.def("load_bool_vector", [](const std::vector &v) { - return v.at(0) == true && v.at(1) == false; - }); - // Unnumbered regression (caused by #936): pointers to stl containers aren't castable - static std::vector lvv{2}; - m.def("cast_ptr_vector", []() { return &lvv; }); - - // test_deque - m.def("cast_deque", []() { return std::deque{1}; }); - m.def("load_deque", [](const std::deque &v) { return v.at(0) == 1 && v.at(1) == 2; }); - - // test_array - m.def("cast_array", []() { return std::array {{1 , 2}}; }); - m.def("load_array", [](const std::array &a) { return a[0] == 1 && a[1] == 2; }); - - // test_valarray - m.def("cast_valarray", []() { return std::valarray{1, 4, 9}; }); - m.def("load_valarray", [](const std::valarray& v) { - return v.size() == 3 && v[0] == 1 && v[1] == 4 && v[2] == 9; - }); - - // test_map - m.def("cast_map", []() { return std::map{{"key", "value"}}; }); - m.def("load_map", [](const std::map &map) { - return map.at("key") == "value" && map.at("key2") == "value2"; - }); - - // test_set - m.def("cast_set", []() { return std::set{"key1", "key2"}; }); - m.def("load_set", [](const std::set &set) { - return set.count("key1") && set.count("key2") && set.count("key3"); - }); - - // test_recursive_casting - m.def("cast_rv_vector", []() { return std::vector{2}; }); - m.def("cast_rv_array", []() { return std::array(); }); - // NB: map and set keys are `const`, so while we technically do move them (as `const Type &&`), - // casters don't typically do anything with that, which means they fall to the `const Type &` - // caster. - m.def("cast_rv_map", []() { return std::unordered_map{{"a", RValueCaster{}}}; }); - m.def("cast_rv_nested", []() { - std::vector>, 2>> v; - v.emplace_back(); // add an array - v.back()[0].emplace_back(); // add a map to the array - v.back()[0].back().emplace("b", RValueCaster{}); - v.back()[0].back().emplace("c", RValueCaster{}); - v.back()[1].emplace_back(); // add a map to the array - v.back()[1].back().emplace("a", RValueCaster{}); - return v; - }); - static std::array lva; - static std::unordered_map lvm{{"a", RValueCaster{}}, {"b", RValueCaster{}}}; - static std::unordered_map>>> lvn; - lvn["a"].emplace_back(); // add a list - lvn["a"].back().emplace_back(); // add an array - lvn["a"].emplace_back(); // another list - lvn["a"].back().emplace_back(); // add an array - lvn["b"].emplace_back(); // add a list - lvn["b"].back().emplace_back(); // add an array - lvn["b"].back().emplace_back(); // add another array - m.def("cast_lv_vector", []() -> const decltype(lvv) & { return lvv; }); - m.def("cast_lv_array", []() -> const decltype(lva) & { return lva; }); - m.def("cast_lv_map", []() -> const decltype(lvm) & { return lvm; }); - m.def("cast_lv_nested", []() -> const decltype(lvn) & { return lvn; }); - // #853: - m.def("cast_unique_ptr_vector", []() { - std::vector> v; - v.emplace_back(new UserType{7}); - v.emplace_back(new UserType{42}); - return v; - }); - - // test_move_out_container - struct MoveOutContainer { - struct Value { int value; }; - std::list move_list() const { return {{0}, {1}, {2}}; } - }; - py::class_(m, "MoveOutContainerValue") - .def_readonly("value", &MoveOutContainer::Value::value); - py::class_(m, "MoveOutContainer") - .def(py::init<>()) - .def_property_readonly("move_list", &MoveOutContainer::move_list); - - // Class that can be move- and copy-constructed, but not assigned - struct NoAssign { - int value; - - explicit NoAssign(int value = 0) : value(value) { } - NoAssign(const NoAssign &) = default; - NoAssign(NoAssign &&) = default; - - NoAssign &operator=(const NoAssign &) = delete; - NoAssign &operator=(NoAssign &&) = delete; - }; - py::class_(m, "NoAssign", "Class with no C++ assignment operators") - .def(py::init<>()) - .def(py::init()); - -#ifdef PYBIND11_HAS_OPTIONAL - // test_optional - m.attr("has_optional") = true; - - using opt_int = std::optional; - using opt_no_assign = std::optional; - m.def("double_or_zero", [](const opt_int& x) -> int { - return x.value_or(0) * 2; - }); - m.def("half_or_none", [](int x) -> opt_int { - return x ? opt_int(x / 2) : opt_int(); - }); - m.def("test_nullopt", [](opt_int x) { - return x.value_or(42); - }, py::arg_v("x", std::nullopt, "None")); - m.def("test_no_assign", [](const opt_no_assign &x) { - return x ? x->value : 42; - }, py::arg_v("x", std::nullopt, "None")); - - m.def("nodefer_none_optional", [](std::optional) { return true; }); - m.def("nodefer_none_optional", [](py::none) { return false; }); -#endif - -#ifdef PYBIND11_HAS_EXP_OPTIONAL - // test_exp_optional - m.attr("has_exp_optional") = true; - - using exp_opt_int = std::experimental::optional; - using exp_opt_no_assign = std::experimental::optional; - m.def("double_or_zero_exp", [](const exp_opt_int& x) -> int { - return x.value_or(0) * 2; - }); - m.def("half_or_none_exp", [](int x) -> exp_opt_int { - return x ? exp_opt_int(x / 2) : exp_opt_int(); - }); - m.def("test_nullopt_exp", [](exp_opt_int x) { - return x.value_or(42); - }, py::arg_v("x", std::experimental::nullopt, "None")); - m.def("test_no_assign_exp", [](const exp_opt_no_assign &x) { - return x ? x->value : 42; - }, py::arg_v("x", std::experimental::nullopt, "None")); -#endif - -#ifdef PYBIND11_HAS_VARIANT - static_assert(std::is_same::value, - "visitor::result_type is required by boost::variant in C++11 mode"); - - struct visitor { - using result_type = const char *; - - result_type operator()(int) { return "int"; } - result_type operator()(std::string) { return "std::string"; } - result_type operator()(double) { return "double"; } - result_type operator()(std::nullptr_t) { return "std::nullptr_t"; } - }; - - // test_variant - m.def("load_variant", [](variant v) { - return py::detail::visit_helper::call(visitor(), v); - }); - m.def("load_variant_2pass", [](variant v) { - return py::detail::visit_helper::call(visitor(), v); - }); - m.def("cast_variant", []() { - using V = variant; - return py::make_tuple(V(5), V("Hello")); - }); -#endif - - // #528: templated constructor - // (no python tests: the test here is that this compiles) - m.def("tpl_ctor_vector", [](std::vector &) {}); - m.def("tpl_ctor_map", [](std::unordered_map &) {}); - m.def("tpl_ctor_set", [](std::unordered_set &) {}); -#if defined(PYBIND11_HAS_OPTIONAL) - m.def("tpl_constr_optional", [](std::optional &) {}); -#elif defined(PYBIND11_HAS_EXP_OPTIONAL) - m.def("tpl_constr_optional", [](std::experimental::optional &) {}); -#endif - - // test_vec_of_reference_wrapper - // #171: Can't return STL structures containing reference wrapper - m.def("return_vec_of_reference_wrapper", [](std::reference_wrapper p4) { - static UserType p1{1}, p2{2}, p3{3}; - return std::vector> { - std::ref(p1), std::ref(p2), std::ref(p3), p4 - }; - }); - - // test_stl_pass_by_pointer - m.def("stl_pass_by_pointer", [](std::vector* v) { return *v; }, "v"_a=nullptr); - - // #1258: pybind11/stl.h converts string to vector - m.def("func_with_string_or_vector_string_arg_overload", [](std::vector) { return 1; }); - m.def("func_with_string_or_vector_string_arg_overload", [](std::list) { return 2; }); - m.def("func_with_string_or_vector_string_arg_overload", [](std::string) { return 3; }); - - class Placeholder { - public: - Placeholder() { print_created(this); } - Placeholder(const Placeholder &) = delete; - ~Placeholder() { print_destroyed(this); } - }; - py::class_(m, "Placeholder"); - - /// test_stl_vector_ownership - m.def("test_stl_ownership", - []() { - std::vector result; - result.push_back(new Placeholder()); - return result; - }, - py::return_value_policy::take_ownership); - - m.def("array_cast_sequence", [](std::array x) { return x; }); - - /// test_issue_1561 - struct Issue1561Inner { std::string data; }; - struct Issue1561Outer { std::vector list; }; - - py::class_(m, "Issue1561Inner") - .def(py::init()) - .def_readwrite("data", &Issue1561Inner::data); - - py::class_(m, "Issue1561Outer") - .def(py::init<>()) - .def_readwrite("list", &Issue1561Outer::list); -} diff --git a/pybind11/tests/test_stl.py b/pybind11/tests/test_stl.py deleted file mode 100644 index 2335cb9..0000000 --- a/pybind11/tests/test_stl.py +++ /dev/null @@ -1,241 +0,0 @@ -import pytest - -from pybind11_tests import stl as m -from pybind11_tests import UserType -from pybind11_tests import ConstructorStats - - -def test_vector(doc): - """std::vector <-> list""" - lst = m.cast_vector() - assert lst == [1] - lst.append(2) - assert m.load_vector(lst) - assert m.load_vector(tuple(lst)) - - assert m.cast_bool_vector() == [True, False] - assert m.load_bool_vector([True, False]) - - assert doc(m.cast_vector) == "cast_vector() -> List[int]" - assert doc(m.load_vector) == "load_vector(arg0: List[int]) -> bool" - - # Test regression caused by 936: pointers to stl containers weren't castable - assert m.cast_ptr_vector() == ["lvalue", "lvalue"] - - -def test_deque(doc): - """std::deque <-> list""" - lst = m.cast_deque() - assert lst == [1] - lst.append(2) - assert m.load_deque(lst) - assert m.load_deque(tuple(lst)) - - -def test_array(doc): - """std::array <-> list""" - lst = m.cast_array() - assert lst == [1, 2] - assert m.load_array(lst) - - assert doc(m.cast_array) == "cast_array() -> List[int[2]]" - assert doc(m.load_array) == "load_array(arg0: List[int[2]]) -> bool" - - -def test_valarray(doc): - """std::valarray <-> list""" - lst = m.cast_valarray() - assert lst == [1, 4, 9] - assert m.load_valarray(lst) - - assert doc(m.cast_valarray) == "cast_valarray() -> List[int]" - assert doc(m.load_valarray) == "load_valarray(arg0: List[int]) -> bool" - - -def test_map(doc): - """std::map <-> dict""" - d = m.cast_map() - assert d == {"key": "value"} - assert "key" in d - d["key2"] = "value2" - assert "key2" in d - assert m.load_map(d) - - assert doc(m.cast_map) == "cast_map() -> Dict[str, str]" - assert doc(m.load_map) == "load_map(arg0: Dict[str, str]) -> bool" - - -def test_set(doc): - """std::set <-> set""" - s = m.cast_set() - assert s == {"key1", "key2"} - s.add("key3") - assert m.load_set(s) - - assert doc(m.cast_set) == "cast_set() -> Set[str]" - assert doc(m.load_set) == "load_set(arg0: Set[str]) -> bool" - - -def test_recursive_casting(): - """Tests that stl casters preserve lvalue/rvalue context for container values""" - assert m.cast_rv_vector() == ["rvalue", "rvalue"] - assert m.cast_lv_vector() == ["lvalue", "lvalue"] - assert m.cast_rv_array() == ["rvalue", "rvalue", "rvalue"] - assert m.cast_lv_array() == ["lvalue", "lvalue"] - assert m.cast_rv_map() == {"a": "rvalue"} - assert m.cast_lv_map() == {"a": "lvalue", "b": "lvalue"} - assert m.cast_rv_nested() == [[[{"b": "rvalue", "c": "rvalue"}], [{"a": "rvalue"}]]] - assert m.cast_lv_nested() == { - "a": [[["lvalue", "lvalue"]], [["lvalue", "lvalue"]]], - "b": [[["lvalue", "lvalue"], ["lvalue", "lvalue"]]] - } - - # Issue #853 test case: - z = m.cast_unique_ptr_vector() - assert z[0].value == 7 and z[1].value == 42 - - -def test_move_out_container(): - """Properties use the `reference_internal` policy by default. If the underlying function - returns an rvalue, the policy is automatically changed to `move` to avoid referencing - a temporary. In case the return value is a container of user-defined types, the policy - also needs to be applied to the elements, not just the container.""" - c = m.MoveOutContainer() - moved_out_list = c.move_list - assert [x.value for x in moved_out_list] == [0, 1, 2] - - -@pytest.mark.skipif(not hasattr(m, "has_optional"), reason='no ') -def test_optional(): - assert m.double_or_zero(None) == 0 - assert m.double_or_zero(42) == 84 - pytest.raises(TypeError, m.double_or_zero, 'foo') - - assert m.half_or_none(0) is None - assert m.half_or_none(42) == 21 - pytest.raises(TypeError, m.half_or_none, 'foo') - - assert m.test_nullopt() == 42 - assert m.test_nullopt(None) == 42 - assert m.test_nullopt(42) == 42 - assert m.test_nullopt(43) == 43 - - assert m.test_no_assign() == 42 - assert m.test_no_assign(None) == 42 - assert m.test_no_assign(m.NoAssign(43)) == 43 - pytest.raises(TypeError, m.test_no_assign, 43) - - assert m.nodefer_none_optional(None) - - -@pytest.mark.skipif(not hasattr(m, "has_exp_optional"), reason='no ') -def test_exp_optional(): - assert m.double_or_zero_exp(None) == 0 - assert m.double_or_zero_exp(42) == 84 - pytest.raises(TypeError, m.double_or_zero_exp, 'foo') - - assert m.half_or_none_exp(0) is None - assert m.half_or_none_exp(42) == 21 - pytest.raises(TypeError, m.half_or_none_exp, 'foo') - - assert m.test_nullopt_exp() == 42 - assert m.test_nullopt_exp(None) == 42 - assert m.test_nullopt_exp(42) == 42 - assert m.test_nullopt_exp(43) == 43 - - assert m.test_no_assign_exp() == 42 - assert m.test_no_assign_exp(None) == 42 - assert m.test_no_assign_exp(m.NoAssign(43)) == 43 - pytest.raises(TypeError, m.test_no_assign_exp, 43) - - -@pytest.mark.skipif(not hasattr(m, "load_variant"), reason='no ') -def test_variant(doc): - assert m.load_variant(1) == "int" - assert m.load_variant("1") == "std::string" - assert m.load_variant(1.0) == "double" - assert m.load_variant(None) == "std::nullptr_t" - - assert m.load_variant_2pass(1) == "int" - assert m.load_variant_2pass(1.0) == "double" - - assert m.cast_variant() == (5, "Hello") - - assert doc(m.load_variant) == "load_variant(arg0: Union[int, str, float, None]) -> str" - - -def test_vec_of_reference_wrapper(): - """#171: Can't return reference wrappers (or STL structures containing them)""" - assert str(m.return_vec_of_reference_wrapper(UserType(4))) == \ - "[UserType(1), UserType(2), UserType(3), UserType(4)]" - - -def test_stl_pass_by_pointer(msg): - """Passing nullptr or None to an STL container pointer is not expected to work""" - with pytest.raises(TypeError) as excinfo: - m.stl_pass_by_pointer() # default value is `nullptr` - assert msg(excinfo.value) == """ - stl_pass_by_pointer(): incompatible function arguments. The following argument types are supported: - 1. (v: List[int] = None) -> List[int] - - Invoked with: - """ # noqa: E501 line too long - - with pytest.raises(TypeError) as excinfo: - m.stl_pass_by_pointer(None) - assert msg(excinfo.value) == """ - stl_pass_by_pointer(): incompatible function arguments. The following argument types are supported: - 1. (v: List[int] = None) -> List[int] - - Invoked with: None - """ # noqa: E501 line too long - - assert m.stl_pass_by_pointer([1, 2, 3]) == [1, 2, 3] - - -def test_missing_header_message(): - """Trying convert `list` to a `std::vector`, or vice versa, without including - should result in a helpful suggestion in the error message""" - import pybind11_cross_module_tests as cm - - expected_message = ("Did you forget to `#include `? Or ,\n" - ", , etc. Some automatic\n" - "conversions are optional and require extra headers to be included\n" - "when compiling your pybind11 module.") - - with pytest.raises(TypeError) as excinfo: - cm.missing_header_arg([1.0, 2.0, 3.0]) - assert expected_message in str(excinfo.value) - - with pytest.raises(TypeError) as excinfo: - cm.missing_header_return() - assert expected_message in str(excinfo.value) - - -def test_function_with_string_and_vector_string_arg(): - """Check if a string is NOT implicitly converted to a list, which was the - behavior before fix of issue #1258""" - assert m.func_with_string_or_vector_string_arg_overload(('A', 'B', )) == 2 - assert m.func_with_string_or_vector_string_arg_overload(['A', 'B']) == 2 - assert m.func_with_string_or_vector_string_arg_overload('A') == 3 - - -def test_stl_ownership(): - cstats = ConstructorStats.get(m.Placeholder) - assert cstats.alive() == 0 - r = m.test_stl_ownership() - assert len(r) == 1 - del r - assert cstats.alive() == 0 - - -def test_array_cast_sequence(): - assert m.array_cast_sequence((1, 2, 3)) == [1, 2, 3] - - -def test_issue_1561(): - """ check fix for issue #1561 """ - bar = m.Issue1561Outer() - bar.list = [m.Issue1561Inner('bar')] - bar.list - assert bar.list[0].data == 'bar' diff --git a/pybind11/tests/test_stl_binders.cpp b/pybind11/tests/test_stl_binders.cpp deleted file mode 100644 index a88b589..0000000 --- a/pybind11/tests/test_stl_binders.cpp +++ /dev/null @@ -1,107 +0,0 @@ -/* - tests/test_stl_binders.cpp -- Usage of stl_binders functions - - Copyright (c) 2016 Sergey Lyskov - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" - -#include -#include -#include -#include -#include - -class El { -public: - El() = delete; - El(int v) : a(v) { } - - int a; -}; - -std::ostream & operator<<(std::ostream &s, El const&v) { - s << "El{" << v.a << '}'; - return s; -} - -/// Issue #487: binding std::vector with E non-copyable -class E_nc { -public: - explicit E_nc(int i) : value{i} {} - E_nc(const E_nc &) = delete; - E_nc &operator=(const E_nc &) = delete; - E_nc(E_nc &&) = default; - E_nc &operator=(E_nc &&) = default; - - int value; -}; - -template Container *one_to_n(int n) { - auto v = new Container(); - for (int i = 1; i <= n; i++) - v->emplace_back(i); - return v; -} - -template Map *times_ten(int n) { - auto m = new Map(); - for (int i = 1; i <= n; i++) - m->emplace(int(i), E_nc(10*i)); - return m; -} - -TEST_SUBMODULE(stl_binders, m) { - // test_vector_int - py::bind_vector>(m, "VectorInt", py::buffer_protocol()); - - // test_vector_custom - py::class_(m, "El") - .def(py::init()); - py::bind_vector>(m, "VectorEl"); - py::bind_vector>>(m, "VectorVectorEl"); - - // test_map_string_double - py::bind_map>(m, "MapStringDouble"); - py::bind_map>(m, "UnorderedMapStringDouble"); - - // test_map_string_double_const - py::bind_map>(m, "MapStringDoubleConst"); - py::bind_map>(m, "UnorderedMapStringDoubleConst"); - - py::class_(m, "ENC") - .def(py::init()) - .def_readwrite("value", &E_nc::value); - - // test_noncopyable_containers - py::bind_vector>(m, "VectorENC"); - m.def("get_vnc", &one_to_n>, py::return_value_policy::reference); - py::bind_vector>(m, "DequeENC"); - m.def("get_dnc", &one_to_n>, py::return_value_policy::reference); - py::bind_map>(m, "MapENC"); - m.def("get_mnc", ×_ten>, py::return_value_policy::reference); - py::bind_map>(m, "UmapENC"); - m.def("get_umnc", ×_ten>, py::return_value_policy::reference); - - // test_vector_buffer - py::bind_vector>(m, "VectorUChar", py::buffer_protocol()); - // no dtype declared for this version: - struct VUndeclStruct { bool w; uint32_t x; double y; bool z; }; - m.def("create_undeclstruct", [m] () mutable { - py::bind_vector>(m, "VectorUndeclStruct", py::buffer_protocol()); - }); - - // The rest depends on numpy: - try { py::module::import("numpy"); } - catch (...) { return; } - - // test_vector_buffer_numpy - struct VStruct { bool w; uint32_t x; double y; bool z; }; - PYBIND11_NUMPY_DTYPE(VStruct, w, x, y, z); - py::class_(m, "VStruct").def_readwrite("x", &VStruct::x); - py::bind_vector>(m, "VectorStruct", py::buffer_protocol()); - m.def("get_vectorstruct", [] {return std::vector {{0, 5, 3.0, 1}, {1, 30, -1e4, 0}};}); -} diff --git a/pybind11/tests/test_stl_binders.py b/pybind11/tests/test_stl_binders.py deleted file mode 100644 index 6d5a159..0000000 --- a/pybind11/tests/test_stl_binders.py +++ /dev/null @@ -1,235 +0,0 @@ -import pytest -import sys -from pybind11_tests import stl_binders as m - -with pytest.suppress(ImportError): - import numpy as np - - -def test_vector_int(): - v_int = m.VectorInt([0, 0]) - assert len(v_int) == 2 - assert bool(v_int) is True - - # test construction from a generator - v_int1 = m.VectorInt(x for x in range(5)) - assert v_int1 == m.VectorInt([0, 1, 2, 3, 4]) - - v_int2 = m.VectorInt([0, 0]) - assert v_int == v_int2 - v_int2[1] = 1 - assert v_int != v_int2 - - v_int2.append(2) - v_int2.insert(0, 1) - v_int2.insert(0, 2) - v_int2.insert(0, 3) - v_int2.insert(6, 3) - assert str(v_int2) == "VectorInt[3, 2, 1, 0, 1, 2, 3]" - with pytest.raises(IndexError): - v_int2.insert(8, 4) - - v_int.append(99) - v_int2[2:-2] = v_int - assert v_int2 == m.VectorInt([3, 2, 0, 0, 99, 2, 3]) - del v_int2[1:3] - assert v_int2 == m.VectorInt([3, 0, 99, 2, 3]) - del v_int2[0] - assert v_int2 == m.VectorInt([0, 99, 2, 3]) - - v_int2.extend(m.VectorInt([4, 5])) - assert v_int2 == m.VectorInt([0, 99, 2, 3, 4, 5]) - - v_int2.extend([6, 7]) - assert v_int2 == m.VectorInt([0, 99, 2, 3, 4, 5, 6, 7]) - - # test error handling, and that the vector is unchanged - with pytest.raises(RuntimeError): - v_int2.extend([8, 'a']) - - assert v_int2 == m.VectorInt([0, 99, 2, 3, 4, 5, 6, 7]) - - # test extending from a generator - v_int2.extend(x for x in range(5)) - assert v_int2 == m.VectorInt([0, 99, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4]) - - # test negative indexing - assert v_int2[-1] == 4 - - # insert with negative index - v_int2.insert(-1, 88) - assert v_int2 == m.VectorInt([0, 99, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 88, 4]) - - # delete negative index - del v_int2[-1] - assert v_int2 == m.VectorInt([0, 99, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 88]) - -# related to the PyPy's buffer protocol. -@pytest.unsupported_on_pypy -def test_vector_buffer(): - b = bytearray([1, 2, 3, 4]) - v = m.VectorUChar(b) - assert v[1] == 2 - v[2] = 5 - mv = memoryview(v) # We expose the buffer interface - if sys.version_info.major > 2: - assert mv[2] == 5 - mv[2] = 6 - else: - assert mv[2] == '\x05' - mv[2] = '\x06' - assert v[2] == 6 - - with pytest.raises(RuntimeError) as excinfo: - m.create_undeclstruct() # Undeclared struct contents, no buffer interface - assert "NumPy type info missing for " in str(excinfo.value) - - -@pytest.unsupported_on_pypy -@pytest.requires_numpy -def test_vector_buffer_numpy(): - a = np.array([1, 2, 3, 4], dtype=np.int32) - with pytest.raises(TypeError): - m.VectorInt(a) - - a = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=np.uintc) - v = m.VectorInt(a[0, :]) - assert len(v) == 4 - assert v[2] == 3 - ma = np.asarray(v) - ma[2] = 5 - assert v[2] == 5 - - v = m.VectorInt(a[:, 1]) - assert len(v) == 3 - assert v[2] == 10 - - v = m.get_vectorstruct() - assert v[0].x == 5 - ma = np.asarray(v) - ma[1]['x'] = 99 - assert v[1].x == 99 - - v = m.VectorStruct(np.zeros(3, dtype=np.dtype([('w', 'bool'), ('x', 'I'), - ('y', 'float64'), ('z', 'bool')], align=True))) - assert len(v) == 3 - - -def test_vector_bool(): - import pybind11_cross_module_tests as cm - - vv_c = cm.VectorBool() - for i in range(10): - vv_c.append(i % 2 == 0) - for i in range(10): - assert vv_c[i] == (i % 2 == 0) - assert str(vv_c) == "VectorBool[1, 0, 1, 0, 1, 0, 1, 0, 1, 0]" - - -def test_vector_custom(): - v_a = m.VectorEl() - v_a.append(m.El(1)) - v_a.append(m.El(2)) - assert str(v_a) == "VectorEl[El{1}, El{2}]" - - vv_a = m.VectorVectorEl() - vv_a.append(v_a) - vv_b = vv_a[0] - assert str(vv_b) == "VectorEl[El{1}, El{2}]" - - -def test_map_string_double(): - mm = m.MapStringDouble() - mm['a'] = 1 - mm['b'] = 2.5 - - assert list(mm) == ['a', 'b'] - assert list(mm.items()) == [('a', 1), ('b', 2.5)] - assert str(mm) == "MapStringDouble{a: 1, b: 2.5}" - - um = m.UnorderedMapStringDouble() - um['ua'] = 1.1 - um['ub'] = 2.6 - - assert sorted(list(um)) == ['ua', 'ub'] - assert sorted(list(um.items())) == [('ua', 1.1), ('ub', 2.6)] - assert "UnorderedMapStringDouble" in str(um) - - -def test_map_string_double_const(): - mc = m.MapStringDoubleConst() - mc['a'] = 10 - mc['b'] = 20.5 - assert str(mc) == "MapStringDoubleConst{a: 10, b: 20.5}" - - umc = m.UnorderedMapStringDoubleConst() - umc['a'] = 11 - umc['b'] = 21.5 - - str(umc) - - -def test_noncopyable_containers(): - # std::vector - vnc = m.get_vnc(5) - for i in range(0, 5): - assert vnc[i].value == i + 1 - - for i, j in enumerate(vnc, start=1): - assert j.value == i - - # std::deque - dnc = m.get_dnc(5) - for i in range(0, 5): - assert dnc[i].value == i + 1 - - i = 1 - for j in dnc: - assert(j.value == i) - i += 1 - - # std::map - mnc = m.get_mnc(5) - for i in range(1, 6): - assert mnc[i].value == 10 * i - - vsum = 0 - for k, v in mnc.items(): - assert v.value == 10 * k - vsum += v.value - - assert vsum == 150 - - # std::unordered_map - mnc = m.get_umnc(5) - for i in range(1, 6): - assert mnc[i].value == 10 * i - - vsum = 0 - for k, v in mnc.items(): - assert v.value == 10 * k - vsum += v.value - - assert vsum == 150 - - -def test_map_delitem(): - mm = m.MapStringDouble() - mm['a'] = 1 - mm['b'] = 2.5 - - assert list(mm) == ['a', 'b'] - assert list(mm.items()) == [('a', 1), ('b', 2.5)] - del mm['a'] - assert list(mm) == ['b'] - assert list(mm.items()) == [('b', 2.5)] - - um = m.UnorderedMapStringDouble() - um['ua'] = 1.1 - um['ub'] = 2.6 - - assert sorted(list(um)) == ['ua', 'ub'] - assert sorted(list(um.items())) == [('ua', 1.1), ('ub', 2.6)] - del um['ua'] - assert sorted(list(um)) == ['ub'] - assert sorted(list(um.items())) == [('ub', 2.6)] diff --git a/pybind11/tests/test_tagbased_polymorphic.cpp b/pybind11/tests/test_tagbased_polymorphic.cpp deleted file mode 100644 index 272e460..0000000 --- a/pybind11/tests/test_tagbased_polymorphic.cpp +++ /dev/null @@ -1,136 +0,0 @@ -/* - tests/test_tagbased_polymorphic.cpp -- test of polymorphic_type_hook - - Copyright (c) 2018 Hudson River Trading LLC - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" -#include - -struct Animal -{ - enum class Kind { - Unknown = 0, - Dog = 100, Labrador, Chihuahua, LastDog = 199, - Cat = 200, Panther, LastCat = 299 - }; - static const std::type_info* type_of_kind(Kind kind); - static std::string name_of_kind(Kind kind); - - const Kind kind; - const std::string name; - - protected: - Animal(const std::string& _name, Kind _kind) - : kind(_kind), name(_name) - {} -}; - -struct Dog : Animal -{ - Dog(const std::string& _name, Kind _kind = Kind::Dog) : Animal(_name, _kind) {} - std::string bark() const { return name_of_kind(kind) + " " + name + " goes " + sound; } - std::string sound = "WOOF!"; -}; - -struct Labrador : Dog -{ - Labrador(const std::string& _name, int _excitement = 9001) - : Dog(_name, Kind::Labrador), excitement(_excitement) {} - int excitement; -}; - -struct Chihuahua : Dog -{ - Chihuahua(const std::string& _name) : Dog(_name, Kind::Chihuahua) { sound = "iyiyiyiyiyi"; } - std::string bark() const { return Dog::bark() + " and runs in circles"; } -}; - -struct Cat : Animal -{ - Cat(const std::string& _name, Kind _kind = Kind::Cat) : Animal(_name, _kind) {} - std::string purr() const { return "mrowr"; } -}; - -struct Panther : Cat -{ - Panther(const std::string& _name) : Cat(_name, Kind::Panther) {} - std::string purr() const { return "mrrrRRRRRR"; } -}; - -std::vector> create_zoo() -{ - std::vector> ret; - ret.emplace_back(new Labrador("Fido", 15000)); - - // simulate some new type of Dog that the Python bindings - // haven't been updated for; it should still be considered - // a Dog, not just an Animal. - ret.emplace_back(new Dog("Ginger", Dog::Kind(150))); - - ret.emplace_back(new Chihuahua("Hertzl")); - ret.emplace_back(new Cat("Tiger", Cat::Kind::Cat)); - ret.emplace_back(new Panther("Leo")); - return ret; -} - -const std::type_info* Animal::type_of_kind(Kind kind) -{ - switch (kind) { - case Kind::Unknown: break; - - case Kind::Dog: break; - case Kind::Labrador: return &typeid(Labrador); - case Kind::Chihuahua: return &typeid(Chihuahua); - case Kind::LastDog: break; - - case Kind::Cat: break; - case Kind::Panther: return &typeid(Panther); - case Kind::LastCat: break; - } - - if (kind >= Kind::Dog && kind <= Kind::LastDog) return &typeid(Dog); - if (kind >= Kind::Cat && kind <= Kind::LastCat) return &typeid(Cat); - return nullptr; -} - -std::string Animal::name_of_kind(Kind kind) -{ - std::string raw_name = type_of_kind(kind)->name(); - py::detail::clean_type_id(raw_name); - return raw_name; -} - -namespace pybind11 { - template - struct polymorphic_type_hook::value>> - { - static const void *get(const itype *src, const std::type_info*& type) - { type = src ? Animal::type_of_kind(src->kind) : nullptr; return src; } - }; -} - -TEST_SUBMODULE(tagbased_polymorphic, m) { - py::class_(m, "Animal") - .def_readonly("name", &Animal::name); - py::class_(m, "Dog") - .def(py::init()) - .def_readwrite("sound", &Dog::sound) - .def("bark", &Dog::bark); - py::class_(m, "Labrador") - .def(py::init(), "name"_a, "excitement"_a = 9001) - .def_readwrite("excitement", &Labrador::excitement); - py::class_(m, "Chihuahua") - .def(py::init()) - .def("bark", &Chihuahua::bark); - py::class_(m, "Cat") - .def(py::init()) - .def("purr", &Cat::purr); - py::class_(m, "Panther") - .def(py::init()) - .def("purr", &Panther::purr); - m.def("create_zoo", &create_zoo); -}; diff --git a/pybind11/tests/test_tagbased_polymorphic.py b/pybind11/tests/test_tagbased_polymorphic.py deleted file mode 100644 index 2574d7d..0000000 --- a/pybind11/tests/test_tagbased_polymorphic.py +++ /dev/null @@ -1,20 +0,0 @@ -from pybind11_tests import tagbased_polymorphic as m - - -def test_downcast(): - zoo = m.create_zoo() - assert [type(animal) for animal in zoo] == [ - m.Labrador, m.Dog, m.Chihuahua, m.Cat, m.Panther - ] - assert [animal.name for animal in zoo] == [ - "Fido", "Ginger", "Hertzl", "Tiger", "Leo" - ] - zoo[1].sound = "woooooo" - assert [dog.bark() for dog in zoo[:3]] == [ - "Labrador Fido goes WOOF!", - "Dog Ginger goes woooooo", - "Chihuahua Hertzl goes iyiyiyiyiyi and runs in circles" - ] - assert [cat.purr() for cat in zoo[3:]] == ["mrowr", "mrrrRRRRRR"] - zoo[0].excitement -= 1000 - assert zoo[0].excitement == 14000 diff --git a/pybind11/tests/test_union.cpp b/pybind11/tests/test_union.cpp deleted file mode 100644 index 7b98ea2..0000000 --- a/pybind11/tests/test_union.cpp +++ /dev/null @@ -1,22 +0,0 @@ -/* - tests/test_class.cpp -- test py::class_ definitions and basic functionality - - Copyright (c) 2019 Roland Dreier - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" - -TEST_SUBMODULE(union_, m) { - union TestUnion { - int value_int; - unsigned value_uint; - }; - - py::class_(m, "TestUnion") - .def(py::init<>()) - .def_readonly("as_int", &TestUnion::value_int) - .def_readwrite("as_uint", &TestUnion::value_uint); -} diff --git a/pybind11/tests/test_union.py b/pybind11/tests/test_union.py deleted file mode 100644 index e1866e7..0000000 --- a/pybind11/tests/test_union.py +++ /dev/null @@ -1,8 +0,0 @@ -from pybind11_tests import union_ as m - - -def test_union(): - instance = m.TestUnion() - - instance.as_uint = 10 - assert instance.as_int == 10 diff --git a/pybind11/tests/test_virtual_functions.cpp b/pybind11/tests/test_virtual_functions.cpp deleted file mode 100644 index ccf018d..0000000 --- a/pybind11/tests/test_virtual_functions.cpp +++ /dev/null @@ -1,479 +0,0 @@ -/* - tests/test_virtual_functions.cpp -- overriding virtual functions from Python - - Copyright (c) 2016 Wenzel Jakob - - All rights reserved. Use of this source code is governed by a - BSD-style license that can be found in the LICENSE file. -*/ - -#include "pybind11_tests.h" -#include "constructor_stats.h" -#include -#include - -/* This is an example class that we'll want to be able to extend from Python */ -class ExampleVirt { -public: - ExampleVirt(int state) : state(state) { print_created(this, state); } - ExampleVirt(const ExampleVirt &e) : state(e.state) { print_copy_created(this); } - ExampleVirt(ExampleVirt &&e) : state(e.state) { print_move_created(this); e.state = 0; } - virtual ~ExampleVirt() { print_destroyed(this); } - - virtual int run(int value) { - py::print("Original implementation of " - "ExampleVirt::run(state={}, value={}, str1={}, str2={})"_s.format(state, value, get_string1(), *get_string2())); - return state + value; - } - - virtual bool run_bool() = 0; - virtual void pure_virtual() = 0; - - // Returning a reference/pointer to a type converted from python (numbers, strings, etc.) is a - // bit trickier, because the actual int& or std::string& or whatever only exists temporarily, so - // we have to handle it specially in the trampoline class (see below). - virtual const std::string &get_string1() { return str1; } - virtual const std::string *get_string2() { return &str2; } - -private: - int state; - const std::string str1{"default1"}, str2{"default2"}; -}; - -/* This is a wrapper class that must be generated */ -class PyExampleVirt : public ExampleVirt { -public: - using ExampleVirt::ExampleVirt; /* Inherit constructors */ - - int run(int value) override { - /* Generate wrapping code that enables native function overloading */ - PYBIND11_OVERLOAD( - int, /* Return type */ - ExampleVirt, /* Parent class */ - run, /* Name of function */ - value /* Argument(s) */ - ); - } - - bool run_bool() override { - PYBIND11_OVERLOAD_PURE( - bool, /* Return type */ - ExampleVirt, /* Parent class */ - run_bool, /* Name of function */ - /* This function has no arguments. The trailing comma - in the previous line is needed for some compilers */ - ); - } - - void pure_virtual() override { - PYBIND11_OVERLOAD_PURE( - void, /* Return type */ - ExampleVirt, /* Parent class */ - pure_virtual, /* Name of function */ - /* This function has no arguments. The trailing comma - in the previous line is needed for some compilers */ - ); - } - - // We can return reference types for compatibility with C++ virtual interfaces that do so, but - // note they have some significant limitations (see the documentation). - const std::string &get_string1() override { - PYBIND11_OVERLOAD( - const std::string &, /* Return type */ - ExampleVirt, /* Parent class */ - get_string1, /* Name of function */ - /* (no arguments) */ - ); - } - - const std::string *get_string2() override { - PYBIND11_OVERLOAD( - const std::string *, /* Return type */ - ExampleVirt, /* Parent class */ - get_string2, /* Name of function */ - /* (no arguments) */ - ); - } - -}; - -class NonCopyable { -public: - NonCopyable(int a, int b) : value{new int(a*b)} { print_created(this, a, b); } - NonCopyable(NonCopyable &&o) { value = std::move(o.value); print_move_created(this); } - NonCopyable(const NonCopyable &) = delete; - NonCopyable() = delete; - void operator=(const NonCopyable &) = delete; - void operator=(NonCopyable &&) = delete; - std::string get_value() const { - if (value) return std::to_string(*value); else return "(null)"; - } - ~NonCopyable() { print_destroyed(this); } - -private: - std::unique_ptr value; -}; - -// This is like the above, but is both copy and movable. In effect this means it should get moved -// when it is not referenced elsewhere, but copied if it is still referenced. -class Movable { -public: - Movable(int a, int b) : value{a+b} { print_created(this, a, b); } - Movable(const Movable &m) { value = m.value; print_copy_created(this); } - Movable(Movable &&m) { value = std::move(m.value); print_move_created(this); } - std::string get_value() const { return std::to_string(value); } - ~Movable() { print_destroyed(this); } -private: - int value; -}; - -class NCVirt { -public: - virtual ~NCVirt() { } - virtual NonCopyable get_noncopyable(int a, int b) { return NonCopyable(a, b); } - virtual Movable get_movable(int a, int b) = 0; - - std::string print_nc(int a, int b) { return get_noncopyable(a, b).get_value(); } - std::string print_movable(int a, int b) { return get_movable(a, b).get_value(); } -}; -class NCVirtTrampoline : public NCVirt { -#if !defined(__INTEL_COMPILER) - NonCopyable get_noncopyable(int a, int b) override { - PYBIND11_OVERLOAD(NonCopyable, NCVirt, get_noncopyable, a, b); - } -#endif - Movable get_movable(int a, int b) override { - PYBIND11_OVERLOAD_PURE(Movable, NCVirt, get_movable, a, b); - } -}; - -struct Base { - /* for some reason MSVC2015 can't compile this if the function is pure virtual */ - virtual std::string dispatch() const { return {}; }; - virtual ~Base() = default; -}; - -struct DispatchIssue : Base { - virtual std::string dispatch() const { - PYBIND11_OVERLOAD_PURE(std::string, Base, dispatch, /* no arguments */); - } -}; - -static void test_gil() { - { - py::gil_scoped_acquire lock; - py::print("1st lock acquired"); - - } - - { - py::gil_scoped_acquire lock; - py::print("2nd lock acquired"); - } - -} - -static void test_gil_from_thread() { - py::gil_scoped_release release; - - std::thread t(test_gil); - t.join(); -} - - -// Forward declaration (so that we can put the main tests here; the inherited virtual approaches are -// rather long). -void initialize_inherited_virtuals(py::module &m); - -TEST_SUBMODULE(virtual_functions, m) { - // test_override - py::class_(m, "ExampleVirt") - .def(py::init()) - /* Reference original class in function definitions */ - .def("run", &ExampleVirt::run) - .def("run_bool", &ExampleVirt::run_bool) - .def("pure_virtual", &ExampleVirt::pure_virtual); - - py::class_(m, "NonCopyable") - .def(py::init()); - - py::class_(m, "Movable") - .def(py::init()); - - // test_move_support -#if !defined(__INTEL_COMPILER) - py::class_(m, "NCVirt") - .def(py::init<>()) - .def("get_noncopyable", &NCVirt::get_noncopyable) - .def("get_movable", &NCVirt::get_movable) - .def("print_nc", &NCVirt::print_nc) - .def("print_movable", &NCVirt::print_movable); -#endif - - m.def("runExampleVirt", [](ExampleVirt *ex, int value) { return ex->run(value); }); - m.def("runExampleVirtBool", [](ExampleVirt* ex) { return ex->run_bool(); }); - m.def("runExampleVirtVirtual", [](ExampleVirt *ex) { ex->pure_virtual(); }); - - m.def("cstats_debug", &ConstructorStats::get); - initialize_inherited_virtuals(m); - - // test_alias_delay_initialization1 - // don't invoke Python dispatch classes by default when instantiating C++ classes - // that were not extended on the Python side - struct A { - virtual ~A() {} - virtual void f() { py::print("A.f()"); } - }; - - struct PyA : A { - PyA() { py::print("PyA.PyA()"); } - ~PyA() { py::print("PyA.~PyA()"); } - - void f() override { - py::print("PyA.f()"); - // This convolution just gives a `void`, but tests that PYBIND11_TYPE() works to protect - // a type containing a , - PYBIND11_OVERLOAD(PYBIND11_TYPE(typename std::enable_if::type), A, f); - } - }; - - py::class_(m, "A") - .def(py::init<>()) - .def("f", &A::f); - - m.def("call_f", [](A *a) { a->f(); }); - - // test_alias_delay_initialization2 - // ... unless we explicitly request it, as in this example: - struct A2 { - virtual ~A2() {} - virtual void f() { py::print("A2.f()"); } - }; - - struct PyA2 : A2 { - PyA2() { py::print("PyA2.PyA2()"); } - ~PyA2() { py::print("PyA2.~PyA2()"); } - void f() override { - py::print("PyA2.f()"); - PYBIND11_OVERLOAD(void, A2, f); - } - }; - - py::class_(m, "A2") - .def(py::init_alias<>()) - .def(py::init([](int) { return new PyA2(); })) - .def("f", &A2::f); - - m.def("call_f", [](A2 *a2) { a2->f(); }); - - // test_dispatch_issue - // #159: virtual function dispatch has problems with similar-named functions - py::class_(m, "DispatchIssue") - .def(py::init<>()) - .def("dispatch", &Base::dispatch); - - m.def("dispatch_issue_go", [](const Base * b) { return b->dispatch(); }); - - // test_override_ref - // #392/397: overriding reference-returning functions - class OverrideTest { - public: - struct A { std::string value = "hi"; }; - std::string v; - A a; - explicit OverrideTest(const std::string &v) : v{v} {} - virtual std::string str_value() { return v; } - virtual std::string &str_ref() { return v; } - virtual A A_value() { return a; } - virtual A &A_ref() { return a; } - virtual ~OverrideTest() = default; - }; - - class PyOverrideTest : public OverrideTest { - public: - using OverrideTest::OverrideTest; - std::string str_value() override { PYBIND11_OVERLOAD(std::string, OverrideTest, str_value); } - // Not allowed (uncommenting should hit a static_assert failure): we can't get a reference - // to a python numeric value, since we only copy values in the numeric type caster: -// std::string &str_ref() override { PYBIND11_OVERLOAD(std::string &, OverrideTest, str_ref); } - // But we can work around it like this: - private: - std::string _tmp; - std::string str_ref_helper() { PYBIND11_OVERLOAD(std::string, OverrideTest, str_ref); } - public: - std::string &str_ref() override { return _tmp = str_ref_helper(); } - - A A_value() override { PYBIND11_OVERLOAD(A, OverrideTest, A_value); } - A &A_ref() override { PYBIND11_OVERLOAD(A &, OverrideTest, A_ref); } - }; - - py::class_(m, "OverrideTest_A") - .def_readwrite("value", &OverrideTest::A::value); - py::class_(m, "OverrideTest") - .def(py::init()) - .def("str_value", &OverrideTest::str_value) -// .def("str_ref", &OverrideTest::str_ref) - .def("A_value", &OverrideTest::A_value) - .def("A_ref", &OverrideTest::A_ref); -} - - -// Inheriting virtual methods. We do two versions here: the repeat-everything version and the -// templated trampoline versions mentioned in docs/advanced.rst. -// -// These base classes are exactly the same, but we technically need distinct -// classes for this example code because we need to be able to bind them -// properly (pybind11, sensibly, doesn't allow us to bind the same C++ class to -// multiple python classes). -class A_Repeat { -#define A_METHODS \ -public: \ - virtual int unlucky_number() = 0; \ - virtual std::string say_something(unsigned times) { \ - std::string s = ""; \ - for (unsigned i = 0; i < times; ++i) \ - s += "hi"; \ - return s; \ - } \ - std::string say_everything() { \ - return say_something(1) + " " + std::to_string(unlucky_number()); \ - } -A_METHODS - virtual ~A_Repeat() = default; -}; -class B_Repeat : public A_Repeat { -#define B_METHODS \ -public: \ - int unlucky_number() override { return 13; } \ - std::string say_something(unsigned times) override { \ - return "B says hi " + std::to_string(times) + " times"; \ - } \ - virtual double lucky_number() { return 7.0; } -B_METHODS -}; -class C_Repeat : public B_Repeat { -#define C_METHODS \ -public: \ - int unlucky_number() override { return 4444; } \ - double lucky_number() override { return 888; } -C_METHODS -}; -class D_Repeat : public C_Repeat { -#define D_METHODS // Nothing overridden. -D_METHODS -}; - -// Base classes for templated inheritance trampolines. Identical to the repeat-everything version: -class A_Tpl { A_METHODS; virtual ~A_Tpl() = default; }; -class B_Tpl : public A_Tpl { B_METHODS }; -class C_Tpl : public B_Tpl { C_METHODS }; -class D_Tpl : public C_Tpl { D_METHODS }; - - -// Inheritance approach 1: each trampoline gets every virtual method (11 in total) -class PyA_Repeat : public A_Repeat { -public: - using A_Repeat::A_Repeat; - int unlucky_number() override { PYBIND11_OVERLOAD_PURE(int, A_Repeat, unlucky_number, ); } - std::string say_something(unsigned times) override { PYBIND11_OVERLOAD(std::string, A_Repeat, say_something, times); } -}; -class PyB_Repeat : public B_Repeat { -public: - using B_Repeat::B_Repeat; - int unlucky_number() override { PYBIND11_OVERLOAD(int, B_Repeat, unlucky_number, ); } - std::string say_something(unsigned times) override { PYBIND11_OVERLOAD(std::string, B_Repeat, say_something, times); } - double lucky_number() override { PYBIND11_OVERLOAD(double, B_Repeat, lucky_number, ); } -}; -class PyC_Repeat : public C_Repeat { -public: - using C_Repeat::C_Repeat; - int unlucky_number() override { PYBIND11_OVERLOAD(int, C_Repeat, unlucky_number, ); } - std::string say_something(unsigned times) override { PYBIND11_OVERLOAD(std::string, C_Repeat, say_something, times); } - double lucky_number() override { PYBIND11_OVERLOAD(double, C_Repeat, lucky_number, ); } -}; -class PyD_Repeat : public D_Repeat { -public: - using D_Repeat::D_Repeat; - int unlucky_number() override { PYBIND11_OVERLOAD(int, D_Repeat, unlucky_number, ); } - std::string say_something(unsigned times) override { PYBIND11_OVERLOAD(std::string, D_Repeat, say_something, times); } - double lucky_number() override { PYBIND11_OVERLOAD(double, D_Repeat, lucky_number, ); } -}; - -// Inheritance approach 2: templated trampoline classes. -// -// Advantages: -// - we have only 2 (template) class and 4 method declarations (one per virtual method, plus one for -// any override of a pure virtual method), versus 4 classes and 6 methods (MI) or 4 classes and 11 -// methods (repeat). -// - Compared to MI, we also don't have to change the non-trampoline inheritance to virtual, and can -// properly inherit constructors. -// -// Disadvantage: -// - the compiler must still generate and compile 14 different methods (more, even, than the 11 -// required for the repeat approach) instead of the 6 required for MI. (If there was no pure -// method (or no pure method override), the number would drop down to the same 11 as the repeat -// approach). -template -class PyA_Tpl : public Base { -public: - using Base::Base; // Inherit constructors - int unlucky_number() override { PYBIND11_OVERLOAD_PURE(int, Base, unlucky_number, ); } - std::string say_something(unsigned times) override { PYBIND11_OVERLOAD(std::string, Base, say_something, times); } -}; -template -class PyB_Tpl : public PyA_Tpl { -public: - using PyA_Tpl::PyA_Tpl; // Inherit constructors (via PyA_Tpl's inherited constructors) - int unlucky_number() override { PYBIND11_OVERLOAD(int, Base, unlucky_number, ); } - double lucky_number() override { PYBIND11_OVERLOAD(double, Base, lucky_number, ); } -}; -// Since C_Tpl and D_Tpl don't declare any new virtual methods, we don't actually need these (we can -// use PyB_Tpl and PyB_Tpl for the trampoline classes instead): -/* -template class PyC_Tpl : public PyB_Tpl { -public: - using PyB_Tpl::PyB_Tpl; -}; -template class PyD_Tpl : public PyC_Tpl { -public: - using PyC_Tpl::PyC_Tpl; -}; -*/ - -void initialize_inherited_virtuals(py::module &m) { - // test_inherited_virtuals - - // Method 1: repeat - py::class_(m, "A_Repeat") - .def(py::init<>()) - .def("unlucky_number", &A_Repeat::unlucky_number) - .def("say_something", &A_Repeat::say_something) - .def("say_everything", &A_Repeat::say_everything); - py::class_(m, "B_Repeat") - .def(py::init<>()) - .def("lucky_number", &B_Repeat::lucky_number); - py::class_(m, "C_Repeat") - .def(py::init<>()); - py::class_(m, "D_Repeat") - .def(py::init<>()); - - // test_ - // Method 2: Templated trampolines - py::class_>(m, "A_Tpl") - .def(py::init<>()) - .def("unlucky_number", &A_Tpl::unlucky_number) - .def("say_something", &A_Tpl::say_something) - .def("say_everything", &A_Tpl::say_everything); - py::class_>(m, "B_Tpl") - .def(py::init<>()) - .def("lucky_number", &B_Tpl::lucky_number); - py::class_>(m, "C_Tpl") - .def(py::init<>()); - py::class_>(m, "D_Tpl") - .def(py::init<>()); - - - // Fix issue #1454 (crash when acquiring/releasing GIL on another thread in Python 2.7) - m.def("test_gil", &test_gil); - m.def("test_gil_from_thread", &test_gil_from_thread); -}; diff --git a/pybind11/tests/test_virtual_functions.py b/pybind11/tests/test_virtual_functions.py deleted file mode 100644 index 5ce9abd..0000000 --- a/pybind11/tests/test_virtual_functions.py +++ /dev/null @@ -1,377 +0,0 @@ -import pytest - -from pybind11_tests import virtual_functions as m -from pybind11_tests import ConstructorStats - - -def test_override(capture, msg): - class ExtendedExampleVirt(m.ExampleVirt): - def __init__(self, state): - super(ExtendedExampleVirt, self).__init__(state + 1) - self.data = "Hello world" - - def run(self, value): - print('ExtendedExampleVirt::run(%i), calling parent..' % value) - return super(ExtendedExampleVirt, self).run(value + 1) - - def run_bool(self): - print('ExtendedExampleVirt::run_bool()') - return False - - def get_string1(self): - return "override1" - - def pure_virtual(self): - print('ExtendedExampleVirt::pure_virtual(): %s' % self.data) - - class ExtendedExampleVirt2(ExtendedExampleVirt): - def __init__(self, state): - super(ExtendedExampleVirt2, self).__init__(state + 1) - - def get_string2(self): - return "override2" - - ex12 = m.ExampleVirt(10) - with capture: - assert m.runExampleVirt(ex12, 20) == 30 - assert capture == """ - Original implementation of ExampleVirt::run(state=10, value=20, str1=default1, str2=default2) - """ # noqa: E501 line too long - - with pytest.raises(RuntimeError) as excinfo: - m.runExampleVirtVirtual(ex12) - assert msg(excinfo.value) == 'Tried to call pure virtual function "ExampleVirt::pure_virtual"' - - ex12p = ExtendedExampleVirt(10) - with capture: - assert m.runExampleVirt(ex12p, 20) == 32 - assert capture == """ - ExtendedExampleVirt::run(20), calling parent.. - Original implementation of ExampleVirt::run(state=11, value=21, str1=override1, str2=default2) - """ # noqa: E501 line too long - with capture: - assert m.runExampleVirtBool(ex12p) is False - assert capture == "ExtendedExampleVirt::run_bool()" - with capture: - m.runExampleVirtVirtual(ex12p) - assert capture == "ExtendedExampleVirt::pure_virtual(): Hello world" - - ex12p2 = ExtendedExampleVirt2(15) - with capture: - assert m.runExampleVirt(ex12p2, 50) == 68 - assert capture == """ - ExtendedExampleVirt::run(50), calling parent.. - Original implementation of ExampleVirt::run(state=17, value=51, str1=override1, str2=override2) - """ # noqa: E501 line too long - - cstats = ConstructorStats.get(m.ExampleVirt) - assert cstats.alive() == 3 - del ex12, ex12p, ex12p2 - assert cstats.alive() == 0 - assert cstats.values() == ['10', '11', '17'] - assert cstats.copy_constructions == 0 - assert cstats.move_constructions >= 0 - - -def test_alias_delay_initialization1(capture): - """`A` only initializes its trampoline class when we inherit from it - - If we just create and use an A instance directly, the trampoline initialization is - bypassed and we only initialize an A() instead (for performance reasons). - """ - class B(m.A): - def __init__(self): - super(B, self).__init__() - - def f(self): - print("In python f()") - - # C++ version - with capture: - a = m.A() - m.call_f(a) - del a - pytest.gc_collect() - assert capture == "A.f()" - - # Python version - with capture: - b = B() - m.call_f(b) - del b - pytest.gc_collect() - assert capture == """ - PyA.PyA() - PyA.f() - In python f() - PyA.~PyA() - """ - - -def test_alias_delay_initialization2(capture): - """`A2`, unlike the above, is configured to always initialize the alias - - While the extra initialization and extra class layer has small virtual dispatch - performance penalty, it also allows us to do more things with the trampoline - class such as defining local variables and performing construction/destruction. - """ - class B2(m.A2): - def __init__(self): - super(B2, self).__init__() - - def f(self): - print("In python B2.f()") - - # No python subclass version - with capture: - a2 = m.A2() - m.call_f(a2) - del a2 - pytest.gc_collect() - a3 = m.A2(1) - m.call_f(a3) - del a3 - pytest.gc_collect() - assert capture == """ - PyA2.PyA2() - PyA2.f() - A2.f() - PyA2.~PyA2() - PyA2.PyA2() - PyA2.f() - A2.f() - PyA2.~PyA2() - """ - - # Python subclass version - with capture: - b2 = B2() - m.call_f(b2) - del b2 - pytest.gc_collect() - assert capture == """ - PyA2.PyA2() - PyA2.f() - In python B2.f() - PyA2.~PyA2() - """ - - -# PyPy: Reference count > 1 causes call with noncopyable instance -# to fail in ncv1.print_nc() -@pytest.unsupported_on_pypy -@pytest.mark.skipif(not hasattr(m, "NCVirt"), reason="NCVirt test broken on ICPC") -def test_move_support(): - class NCVirtExt(m.NCVirt): - def get_noncopyable(self, a, b): - # Constructs and returns a new instance: - nc = m.NonCopyable(a * a, b * b) - return nc - - def get_movable(self, a, b): - # Return a referenced copy - self.movable = m.Movable(a, b) - return self.movable - - class NCVirtExt2(m.NCVirt): - def get_noncopyable(self, a, b): - # Keep a reference: this is going to throw an exception - self.nc = m.NonCopyable(a, b) - return self.nc - - def get_movable(self, a, b): - # Return a new instance without storing it - return m.Movable(a, b) - - ncv1 = NCVirtExt() - assert ncv1.print_nc(2, 3) == "36" - assert ncv1.print_movable(4, 5) == "9" - ncv2 = NCVirtExt2() - assert ncv2.print_movable(7, 7) == "14" - # Don't check the exception message here because it differs under debug/non-debug mode - with pytest.raises(RuntimeError): - ncv2.print_nc(9, 9) - - nc_stats = ConstructorStats.get(m.NonCopyable) - mv_stats = ConstructorStats.get(m.Movable) - assert nc_stats.alive() == 1 - assert mv_stats.alive() == 1 - del ncv1, ncv2 - assert nc_stats.alive() == 0 - assert mv_stats.alive() == 0 - assert nc_stats.values() == ['4', '9', '9', '9'] - assert mv_stats.values() == ['4', '5', '7', '7'] - assert nc_stats.copy_constructions == 0 - assert mv_stats.copy_constructions == 1 - assert nc_stats.move_constructions >= 0 - assert mv_stats.move_constructions >= 0 - - -def test_dispatch_issue(msg): - """#159: virtual function dispatch has problems with similar-named functions""" - class PyClass1(m.DispatchIssue): - def dispatch(self): - return "Yay.." - - class PyClass2(m.DispatchIssue): - def dispatch(self): - with pytest.raises(RuntimeError) as excinfo: - super(PyClass2, self).dispatch() - assert msg(excinfo.value) == 'Tried to call pure virtual function "Base::dispatch"' - - p = PyClass1() - return m.dispatch_issue_go(p) - - b = PyClass2() - assert m.dispatch_issue_go(b) == "Yay.." - - -def test_override_ref(): - """#392/397: overriding reference-returning functions""" - o = m.OverrideTest("asdf") - - # Not allowed (see associated .cpp comment) - # i = o.str_ref() - # assert o.str_ref() == "asdf" - assert o.str_value() == "asdf" - - assert o.A_value().value == "hi" - a = o.A_ref() - assert a.value == "hi" - a.value = "bye" - assert a.value == "bye" - - -def test_inherited_virtuals(): - class AR(m.A_Repeat): - def unlucky_number(self): - return 99 - - class AT(m.A_Tpl): - def unlucky_number(self): - return 999 - - obj = AR() - assert obj.say_something(3) == "hihihi" - assert obj.unlucky_number() == 99 - assert obj.say_everything() == "hi 99" - - obj = AT() - assert obj.say_something(3) == "hihihi" - assert obj.unlucky_number() == 999 - assert obj.say_everything() == "hi 999" - - for obj in [m.B_Repeat(), m.B_Tpl()]: - assert obj.say_something(3) == "B says hi 3 times" - assert obj.unlucky_number() == 13 - assert obj.lucky_number() == 7.0 - assert obj.say_everything() == "B says hi 1 times 13" - - for obj in [m.C_Repeat(), m.C_Tpl()]: - assert obj.say_something(3) == "B says hi 3 times" - assert obj.unlucky_number() == 4444 - assert obj.lucky_number() == 888.0 - assert obj.say_everything() == "B says hi 1 times 4444" - - class CR(m.C_Repeat): - def lucky_number(self): - return m.C_Repeat.lucky_number(self) + 1.25 - - obj = CR() - assert obj.say_something(3) == "B says hi 3 times" - assert obj.unlucky_number() == 4444 - assert obj.lucky_number() == 889.25 - assert obj.say_everything() == "B says hi 1 times 4444" - - class CT(m.C_Tpl): - pass - - obj = CT() - assert obj.say_something(3) == "B says hi 3 times" - assert obj.unlucky_number() == 4444 - assert obj.lucky_number() == 888.0 - assert obj.say_everything() == "B says hi 1 times 4444" - - class CCR(CR): - def lucky_number(self): - return CR.lucky_number(self) * 10 - - obj = CCR() - assert obj.say_something(3) == "B says hi 3 times" - assert obj.unlucky_number() == 4444 - assert obj.lucky_number() == 8892.5 - assert obj.say_everything() == "B says hi 1 times 4444" - - class CCT(CT): - def lucky_number(self): - return CT.lucky_number(self) * 1000 - - obj = CCT() - assert obj.say_something(3) == "B says hi 3 times" - assert obj.unlucky_number() == 4444 - assert obj.lucky_number() == 888000.0 - assert obj.say_everything() == "B says hi 1 times 4444" - - class DR(m.D_Repeat): - def unlucky_number(self): - return 123 - - def lucky_number(self): - return 42.0 - - for obj in [m.D_Repeat(), m.D_Tpl()]: - assert obj.say_something(3) == "B says hi 3 times" - assert obj.unlucky_number() == 4444 - assert obj.lucky_number() == 888.0 - assert obj.say_everything() == "B says hi 1 times 4444" - - obj = DR() - assert obj.say_something(3) == "B says hi 3 times" - assert obj.unlucky_number() == 123 - assert obj.lucky_number() == 42.0 - assert obj.say_everything() == "B says hi 1 times 123" - - class DT(m.D_Tpl): - def say_something(self, times): - return "DT says:" + (' quack' * times) - - def unlucky_number(self): - return 1234 - - def lucky_number(self): - return -4.25 - - obj = DT() - assert obj.say_something(3) == "DT says: quack quack quack" - assert obj.unlucky_number() == 1234 - assert obj.lucky_number() == -4.25 - assert obj.say_everything() == "DT says: quack 1234" - - class DT2(DT): - def say_something(self, times): - return "DT2: " + ('QUACK' * times) - - def unlucky_number(self): - return -3 - - class BT(m.B_Tpl): - def say_something(self, times): - return "BT" * times - - def unlucky_number(self): - return -7 - - def lucky_number(self): - return -1.375 - - obj = BT() - assert obj.say_something(3) == "BTBTBT" - assert obj.unlucky_number() == -7 - assert obj.lucky_number() == -1.375 - assert obj.say_everything() == "BT -7" - - -def test_issue_1454(): - # Fix issue #1454 (crash when acquiring/releasing GIL on another thread in Python 2.7) - m.test_gil() - m.test_gil_from_thread() diff --git a/pybind11/tools/FindCatch.cmake b/pybind11/tools/FindCatch.cmake deleted file mode 100644 index 9d490c5..0000000 --- a/pybind11/tools/FindCatch.cmake +++ /dev/null @@ -1,57 +0,0 @@ -# - Find the Catch test framework or download it (single header) -# -# This is a quick module for internal use. It assumes that Catch is -# REQUIRED and that a minimum version is provided (not EXACT). If -# a suitable version isn't found locally, the single header file -# will be downloaded and placed in the build dir: PROJECT_BINARY_DIR. -# -# This code sets the following variables: -# CATCH_INCLUDE_DIR - path to catch.hpp -# CATCH_VERSION - version number - -if(NOT Catch_FIND_VERSION) - message(FATAL_ERROR "A version number must be specified.") -elseif(Catch_FIND_REQUIRED) - message(FATAL_ERROR "This module assumes Catch is not required.") -elseif(Catch_FIND_VERSION_EXACT) - message(FATAL_ERROR "Exact version numbers are not supported, only minimum.") -endif() - -# Extract the version number from catch.hpp -function(_get_catch_version) - file(STRINGS "${CATCH_INCLUDE_DIR}/catch.hpp" version_line REGEX "Catch v.*" LIMIT_COUNT 1) - if(version_line MATCHES "Catch v([0-9]+)\\.([0-9]+)\\.([0-9]+)") - set(CATCH_VERSION "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}.${CMAKE_MATCH_3}" PARENT_SCOPE) - endif() -endfunction() - -# Download the single-header version of Catch -function(_download_catch version destination_dir) - message(STATUS "Downloading catch v${version}...") - set(url https://github.com/philsquared/Catch/releases/download/v${version}/catch.hpp) - file(DOWNLOAD ${url} "${destination_dir}/catch.hpp" STATUS status) - list(GET status 0 error) - if(error) - message(FATAL_ERROR "Could not download ${url}") - endif() - set(CATCH_INCLUDE_DIR "${destination_dir}" CACHE INTERNAL "") -endfunction() - -# Look for catch locally -find_path(CATCH_INCLUDE_DIR NAMES catch.hpp PATH_SUFFIXES catch) -if(CATCH_INCLUDE_DIR) - _get_catch_version() -endif() - -# Download the header if it wasn't found or if it's outdated -if(NOT CATCH_VERSION OR CATCH_VERSION VERSION_LESS ${Catch_FIND_VERSION}) - if(DOWNLOAD_CATCH) - _download_catch(${Catch_FIND_VERSION} "${PROJECT_BINARY_DIR}/catch/") - _get_catch_version() - else() - set(CATCH_FOUND FALSE) - return() - endif() -endif() - -set(CATCH_FOUND TRUE) diff --git a/pybind11/tools/FindEigen3.cmake b/pybind11/tools/FindEigen3.cmake deleted file mode 100644 index 9c546a0..0000000 --- a/pybind11/tools/FindEigen3.cmake +++ /dev/null @@ -1,81 +0,0 @@ -# - Try to find Eigen3 lib -# -# This module supports requiring a minimum version, e.g. you can do -# find_package(Eigen3 3.1.2) -# to require version 3.1.2 or newer of Eigen3. -# -# Once done this will define -# -# EIGEN3_FOUND - system has eigen lib with correct version -# EIGEN3_INCLUDE_DIR - the eigen include directory -# EIGEN3_VERSION - eigen version - -# Copyright (c) 2006, 2007 Montel Laurent, -# Copyright (c) 2008, 2009 Gael Guennebaud, -# Copyright (c) 2009 Benoit Jacob -# Redistribution and use is allowed according to the terms of the 2-clause BSD license. - -if(NOT Eigen3_FIND_VERSION) - if(NOT Eigen3_FIND_VERSION_MAJOR) - set(Eigen3_FIND_VERSION_MAJOR 2) - endif(NOT Eigen3_FIND_VERSION_MAJOR) - if(NOT Eigen3_FIND_VERSION_MINOR) - set(Eigen3_FIND_VERSION_MINOR 91) - endif(NOT Eigen3_FIND_VERSION_MINOR) - if(NOT Eigen3_FIND_VERSION_PATCH) - set(Eigen3_FIND_VERSION_PATCH 0) - endif(NOT Eigen3_FIND_VERSION_PATCH) - - set(Eigen3_FIND_VERSION "${Eigen3_FIND_VERSION_MAJOR}.${Eigen3_FIND_VERSION_MINOR}.${Eigen3_FIND_VERSION_PATCH}") -endif(NOT Eigen3_FIND_VERSION) - -macro(_eigen3_check_version) - file(READ "${EIGEN3_INCLUDE_DIR}/Eigen/src/Core/util/Macros.h" _eigen3_version_header) - - string(REGEX MATCH "define[ \t]+EIGEN_WORLD_VERSION[ \t]+([0-9]+)" _eigen3_world_version_match "${_eigen3_version_header}") - set(EIGEN3_WORLD_VERSION "${CMAKE_MATCH_1}") - string(REGEX MATCH "define[ \t]+EIGEN_MAJOR_VERSION[ \t]+([0-9]+)" _eigen3_major_version_match "${_eigen3_version_header}") - set(EIGEN3_MAJOR_VERSION "${CMAKE_MATCH_1}") - string(REGEX MATCH "define[ \t]+EIGEN_MINOR_VERSION[ \t]+([0-9]+)" _eigen3_minor_version_match "${_eigen3_version_header}") - set(EIGEN3_MINOR_VERSION "${CMAKE_MATCH_1}") - - set(EIGEN3_VERSION ${EIGEN3_WORLD_VERSION}.${EIGEN3_MAJOR_VERSION}.${EIGEN3_MINOR_VERSION}) - if(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION}) - set(EIGEN3_VERSION_OK FALSE) - else(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION}) - set(EIGEN3_VERSION_OK TRUE) - endif(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION}) - - if(NOT EIGEN3_VERSION_OK) - - message(STATUS "Eigen3 version ${EIGEN3_VERSION} found in ${EIGEN3_INCLUDE_DIR}, " - "but at least version ${Eigen3_FIND_VERSION} is required") - endif(NOT EIGEN3_VERSION_OK) -endmacro(_eigen3_check_version) - -if (EIGEN3_INCLUDE_DIR) - - # in cache already - _eigen3_check_version() - set(EIGEN3_FOUND ${EIGEN3_VERSION_OK}) - -else (EIGEN3_INCLUDE_DIR) - - find_path(EIGEN3_INCLUDE_DIR NAMES signature_of_eigen3_matrix_library - PATHS - ${CMAKE_INSTALL_PREFIX}/include - ${KDE4_INCLUDE_DIR} - PATH_SUFFIXES eigen3 eigen - ) - - if(EIGEN3_INCLUDE_DIR) - _eigen3_check_version() - endif(EIGEN3_INCLUDE_DIR) - - include(FindPackageHandleStandardArgs) - find_package_handle_standard_args(Eigen3 DEFAULT_MSG EIGEN3_INCLUDE_DIR EIGEN3_VERSION_OK) - - mark_as_advanced(EIGEN3_INCLUDE_DIR) - -endif(EIGEN3_INCLUDE_DIR) - diff --git a/pybind11/tools/FindPythonLibsNew.cmake b/pybind11/tools/FindPythonLibsNew.cmake deleted file mode 100644 index e660c5f..0000000 --- a/pybind11/tools/FindPythonLibsNew.cmake +++ /dev/null @@ -1,202 +0,0 @@ -# - Find python libraries -# This module finds the libraries corresponding to the Python interpreter -# FindPythonInterp provides. -# This code sets the following variables: -# -# PYTHONLIBS_FOUND - have the Python libs been found -# PYTHON_PREFIX - path to the Python installation -# PYTHON_LIBRARIES - path to the python library -# PYTHON_INCLUDE_DIRS - path to where Python.h is found -# PYTHON_MODULE_EXTENSION - lib extension, e.g. '.so' or '.pyd' -# PYTHON_MODULE_PREFIX - lib name prefix: usually an empty string -# PYTHON_SITE_PACKAGES - path to installation site-packages -# PYTHON_IS_DEBUG - whether the Python interpreter is a debug build -# -# Thanks to talljimbo for the patch adding the 'LDVERSION' config -# variable usage. - -#============================================================================= -# Copyright 2001-2009 Kitware, Inc. -# Copyright 2012 Continuum Analytics, Inc. -# -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# -# * Neither the names of Kitware, Inc., the Insight Software Consortium, -# nor the names of their contributors may be used to endorse or promote -# products derived from this software without specific prior written -# permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -#============================================================================= - -# Checking for the extension makes sure that `LibsNew` was found and not just `Libs`. -if(PYTHONLIBS_FOUND AND PYTHON_MODULE_EXTENSION) - return() -endif() - -# Use the Python interpreter to find the libs. -if(PythonLibsNew_FIND_REQUIRED) - find_package(PythonInterp ${PythonLibsNew_FIND_VERSION} REQUIRED) -else() - find_package(PythonInterp ${PythonLibsNew_FIND_VERSION}) -endif() - -if(NOT PYTHONINTERP_FOUND) - set(PYTHONLIBS_FOUND FALSE) - set(PythonLibsNew_FOUND FALSE) - return() -endif() - -# According to http://stackoverflow.com/questions/646518/python-how-to-detect-debug-interpreter -# testing whether sys has the gettotalrefcount function is a reliable, cross-platform -# way to detect a CPython debug interpreter. -# -# The library suffix is from the config var LDVERSION sometimes, otherwise -# VERSION. VERSION will typically be like "2.7" on unix, and "27" on windows. -execute_process(COMMAND "${PYTHON_EXECUTABLE}" "-c" - "from distutils import sysconfig as s;import sys;import struct; -print('.'.join(str(v) for v in sys.version_info)); -print(sys.prefix); -print(s.get_python_inc(plat_specific=True)); -print(s.get_python_lib(plat_specific=True)); -print(s.get_config_var('SO')); -print(hasattr(sys, 'gettotalrefcount')+0); -print(struct.calcsize('@P')); -print(s.get_config_var('LDVERSION') or s.get_config_var('VERSION')); -print(s.get_config_var('LIBDIR') or ''); -print(s.get_config_var('MULTIARCH') or ''); -" - RESULT_VARIABLE _PYTHON_SUCCESS - OUTPUT_VARIABLE _PYTHON_VALUES - ERROR_VARIABLE _PYTHON_ERROR_VALUE) - -if(NOT _PYTHON_SUCCESS MATCHES 0) - if(PythonLibsNew_FIND_REQUIRED) - message(FATAL_ERROR - "Python config failure:\n${_PYTHON_ERROR_VALUE}") - endif() - set(PYTHONLIBS_FOUND FALSE) - set(PythonLibsNew_FOUND FALSE) - return() -endif() - -# Convert the process output into a list -if(WIN32) - string(REGEX REPLACE "\\\\" "/" _PYTHON_VALUES ${_PYTHON_VALUES}) -endif() -string(REGEX REPLACE ";" "\\\\;" _PYTHON_VALUES ${_PYTHON_VALUES}) -string(REGEX REPLACE "\n" ";" _PYTHON_VALUES ${_PYTHON_VALUES}) -list(GET _PYTHON_VALUES 0 _PYTHON_VERSION_LIST) -list(GET _PYTHON_VALUES 1 PYTHON_PREFIX) -list(GET _PYTHON_VALUES 2 PYTHON_INCLUDE_DIR) -list(GET _PYTHON_VALUES 3 PYTHON_SITE_PACKAGES) -list(GET _PYTHON_VALUES 4 PYTHON_MODULE_EXTENSION) -list(GET _PYTHON_VALUES 5 PYTHON_IS_DEBUG) -list(GET _PYTHON_VALUES 6 PYTHON_SIZEOF_VOID_P) -list(GET _PYTHON_VALUES 7 PYTHON_LIBRARY_SUFFIX) -list(GET _PYTHON_VALUES 8 PYTHON_LIBDIR) -list(GET _PYTHON_VALUES 9 PYTHON_MULTIARCH) - -# Make sure the Python has the same pointer-size as the chosen compiler -# Skip if CMAKE_SIZEOF_VOID_P is not defined -if(CMAKE_SIZEOF_VOID_P AND (NOT "${PYTHON_SIZEOF_VOID_P}" STREQUAL "${CMAKE_SIZEOF_VOID_P}")) - if(PythonLibsNew_FIND_REQUIRED) - math(EXPR _PYTHON_BITS "${PYTHON_SIZEOF_VOID_P} * 8") - math(EXPR _CMAKE_BITS "${CMAKE_SIZEOF_VOID_P} * 8") - message(FATAL_ERROR - "Python config failure: Python is ${_PYTHON_BITS}-bit, " - "chosen compiler is ${_CMAKE_BITS}-bit") - endif() - set(PYTHONLIBS_FOUND FALSE) - set(PythonLibsNew_FOUND FALSE) - return() -endif() - -# The built-in FindPython didn't always give the version numbers -string(REGEX REPLACE "\\." ";" _PYTHON_VERSION_LIST ${_PYTHON_VERSION_LIST}) -list(GET _PYTHON_VERSION_LIST 0 PYTHON_VERSION_MAJOR) -list(GET _PYTHON_VERSION_LIST 1 PYTHON_VERSION_MINOR) -list(GET _PYTHON_VERSION_LIST 2 PYTHON_VERSION_PATCH) - -# Make sure all directory separators are '/' -string(REGEX REPLACE "\\\\" "/" PYTHON_PREFIX ${PYTHON_PREFIX}) -string(REGEX REPLACE "\\\\" "/" PYTHON_INCLUDE_DIR ${PYTHON_INCLUDE_DIR}) -string(REGEX REPLACE "\\\\" "/" PYTHON_SITE_PACKAGES ${PYTHON_SITE_PACKAGES}) - -if(CMAKE_HOST_WIN32 AND NOT (MSYS OR MINGW)) - set(PYTHON_LIBRARY - "${PYTHON_PREFIX}/libs/Python${PYTHON_LIBRARY_SUFFIX}.lib") - - # when run in a venv, PYTHON_PREFIX points to it. But the libraries remain in the - # original python installation. They may be found relative to PYTHON_INCLUDE_DIR. - if(NOT EXISTS "${PYTHON_LIBRARY}") - get_filename_component(_PYTHON_ROOT ${PYTHON_INCLUDE_DIR} DIRECTORY) - set(PYTHON_LIBRARY - "${_PYTHON_ROOT}/libs/Python${PYTHON_LIBRARY_SUFFIX}.lib") - endif() - - # raise an error if the python libs are still not found. - if(NOT EXISTS "${PYTHON_LIBRARY}") - message(FATAL_ERROR "Python libraries not found") - endif() - -else() - if(PYTHON_MULTIARCH) - set(_PYTHON_LIBS_SEARCH "${PYTHON_LIBDIR}/${PYTHON_MULTIARCH}" "${PYTHON_LIBDIR}") - else() - set(_PYTHON_LIBS_SEARCH "${PYTHON_LIBDIR}") - endif() - #message(STATUS "Searching for Python libs in ${_PYTHON_LIBS_SEARCH}") - # Probably this needs to be more involved. It would be nice if the config - # information the python interpreter itself gave us were more complete. - find_library(PYTHON_LIBRARY - NAMES "python${PYTHON_LIBRARY_SUFFIX}" - PATHS ${_PYTHON_LIBS_SEARCH} - NO_DEFAULT_PATH) - - # If all else fails, just set the name/version and let the linker figure out the path. - if(NOT PYTHON_LIBRARY) - set(PYTHON_LIBRARY python${PYTHON_LIBRARY_SUFFIX}) - endif() -endif() - -MARK_AS_ADVANCED( - PYTHON_LIBRARY - PYTHON_INCLUDE_DIR -) - -# We use PYTHON_INCLUDE_DIR, PYTHON_LIBRARY and PYTHON_DEBUG_LIBRARY for the -# cache entries because they are meant to specify the location of a single -# library. We now set the variables listed by the documentation for this -# module. -SET(PYTHON_INCLUDE_DIRS "${PYTHON_INCLUDE_DIR}") -SET(PYTHON_LIBRARIES "${PYTHON_LIBRARY}") -SET(PYTHON_DEBUG_LIBRARIES "${PYTHON_DEBUG_LIBRARY}") - -find_package_message(PYTHON - "Found PythonLibs: ${PYTHON_LIBRARY}" - "${PYTHON_EXECUTABLE}${PYTHON_VERSION}") - -set(PYTHONLIBS_FOUND TRUE) -set(PythonLibsNew_FOUND TRUE) diff --git a/pybind11/tools/check-style.sh b/pybind11/tools/check-style.sh deleted file mode 100644 index 0a9f7d2..0000000 --- a/pybind11/tools/check-style.sh +++ /dev/null @@ -1,70 +0,0 @@ -#!/bin/bash -# -# Script to check include/test code for common pybind11 code style errors. -# -# This script currently checks for -# -# 1. use of tabs instead of spaces -# 2. MSDOS-style CRLF endings -# 3. trailing spaces -# 4. missing space between keyword and parenthesis, e.g.: for(, if(, while( -# 5. Missing space between right parenthesis and brace, e.g. 'for (...){' -# 6. opening brace on its own line. It should always be on the same line as the -# if/while/for/do statement. -# -# Invoke as: tools/check-style.sh -# - -check_style_errors=0 -IFS=$'\n' - -found="$( GREP_COLORS='mt=41' GREP_COLOR='41' grep $'\t' include tests/*.{cpp,py,h} docs/*.rst -rn --color=always )" -if [ -n "$found" ]; then - # The mt=41 sets a red background for matched tabs: - echo -e '\033[31;01mError: found tab characters in the following files:\033[0m' - check_style_errors=1 - echo "$found" | sed -e 's/^/ /' -fi - - -found="$( grep -IUlr $'\r' include tests/*.{cpp,py,h} docs/*.rst --color=always )" -if [ -n "$found" ]; then - echo -e '\033[31;01mError: found CRLF characters in the following files:\033[0m' - check_style_errors=1 - echo "$found" | sed -e 's/^/ /' -fi - -found="$(GREP_COLORS='mt=41' GREP_COLOR='41' grep '[[:blank:]]\+$' include tests/*.{cpp,py,h} docs/*.rst -rn --color=always )" -if [ -n "$found" ]; then - # The mt=41 sets a red background for matched trailing spaces - echo -e '\033[31;01mError: found trailing spaces in the following files:\033[0m' - check_style_errors=1 - echo "$found" | sed -e 's/^/ /' -fi - -found="$(grep '\<\(if\|for\|while\|catch\)(\|){' include tests/*.{cpp,h} -rn --color=always)" -if [ -n "$found" ]; then - echo -e '\033[31;01mError: found the following coding style problems:\033[0m' - check_style_errors=1 - echo "$found" | sed -e 's/^/ /' -fi - -found="$(awk ' -function prefix(filename, lineno) { - return " \033[35m" filename "\033[36m:\033[32m" lineno "\033[36m:\033[0m" -} -function mark(pattern, string) { sub(pattern, "\033[01;31m&\033[0m", string); return string } -last && /^\s*{/ { - print prefix(FILENAME, FNR-1) mark("\\)\\s*$", last) - print prefix(FILENAME, FNR) mark("^\\s*{", $0) - last="" -} -{ last = /(if|for|while|catch|switch)\s*\(.*\)\s*$/ ? $0 : "" } -' $(find include -type f) tests/*.{cpp,h} docs/*.rst)" -if [ -n "$found" ]; then - check_style_errors=1 - echo -e '\033[31;01mError: braces should occur on the same line as the if/while/.. statement. Found issues in the following files:\033[0m' - echo "$found" -fi - -exit $check_style_errors diff --git a/pybind11/tools/libsize.py b/pybind11/tools/libsize.py deleted file mode 100644 index 5dcb8b0..0000000 --- a/pybind11/tools/libsize.py +++ /dev/null @@ -1,38 +0,0 @@ -from __future__ import print_function, division -import os -import sys - -# Internal build script for generating debugging test .so size. -# Usage: -# python libsize.py file.so save.txt -- displays the size of file.so and, if save.txt exists, compares it to the -# size in it, then overwrites save.txt with the new size for future runs. - -if len(sys.argv) != 3: - sys.exit("Invalid arguments: usage: python libsize.py file.so save.txt") - -lib = sys.argv[1] -save = sys.argv[2] - -if not os.path.exists(lib): - sys.exit("Error: requested file ({}) does not exist".format(lib)) - -libsize = os.path.getsize(lib) - -print("------", os.path.basename(lib), "file size:", libsize, end='') - -if os.path.exists(save): - with open(save) as sf: - oldsize = int(sf.readline()) - - if oldsize > 0: - change = libsize - oldsize - if change == 0: - print(" (no change)") - else: - print(" (change of {:+} bytes = {:+.2%})".format(change, change / oldsize)) -else: - print() - -with open(save, 'w') as sf: - sf.write(str(libsize)) - diff --git a/pybind11/tools/mkdoc.py b/pybind11/tools/mkdoc.py deleted file mode 100644 index 44164af..0000000 --- a/pybind11/tools/mkdoc.py +++ /dev/null @@ -1,379 +0,0 @@ -#!/usr/bin/env python3 -# -# Syntax: mkdoc.py [-I ..] [.. a list of header files ..] -# -# Extract documentation from C++ header files to use it in Python bindings -# - -import os -import sys -import platform -import re -import textwrap - -from clang import cindex -from clang.cindex import CursorKind -from collections import OrderedDict -from glob import glob -from threading import Thread, Semaphore -from multiprocessing import cpu_count - -RECURSE_LIST = [ - CursorKind.TRANSLATION_UNIT, - CursorKind.NAMESPACE, - CursorKind.CLASS_DECL, - CursorKind.STRUCT_DECL, - CursorKind.ENUM_DECL, - CursorKind.CLASS_TEMPLATE -] - -PRINT_LIST = [ - CursorKind.CLASS_DECL, - CursorKind.STRUCT_DECL, - CursorKind.ENUM_DECL, - CursorKind.ENUM_CONSTANT_DECL, - CursorKind.CLASS_TEMPLATE, - CursorKind.FUNCTION_DECL, - CursorKind.FUNCTION_TEMPLATE, - CursorKind.CONVERSION_FUNCTION, - CursorKind.CXX_METHOD, - CursorKind.CONSTRUCTOR, - CursorKind.FIELD_DECL -] - -PREFIX_BLACKLIST = [ - CursorKind.TRANSLATION_UNIT -] - -CPP_OPERATORS = { - '<=': 'le', '>=': 'ge', '==': 'eq', '!=': 'ne', '[]': 'array', - '+=': 'iadd', '-=': 'isub', '*=': 'imul', '/=': 'idiv', '%=': - 'imod', '&=': 'iand', '|=': 'ior', '^=': 'ixor', '<<=': 'ilshift', - '>>=': 'irshift', '++': 'inc', '--': 'dec', '<<': 'lshift', '>>': - 'rshift', '&&': 'land', '||': 'lor', '!': 'lnot', '~': 'bnot', - '&': 'band', '|': 'bor', '+': 'add', '-': 'sub', '*': 'mul', '/': - 'div', '%': 'mod', '<': 'lt', '>': 'gt', '=': 'assign', '()': 'call' -} - -CPP_OPERATORS = OrderedDict( - sorted(CPP_OPERATORS.items(), key=lambda t: -len(t[0]))) - -job_count = cpu_count() -job_semaphore = Semaphore(job_count) - - -class NoFilenamesError(ValueError): - pass - - -def d(s): - return s if isinstance(s, str) else s.decode('utf8') - - -def sanitize_name(name): - name = re.sub(r'type-parameter-0-([0-9]+)', r'T\1', name) - for k, v in CPP_OPERATORS.items(): - name = name.replace('operator%s' % k, 'operator_%s' % v) - name = re.sub('<.*>', '', name) - name = ''.join([ch if ch.isalnum() else '_' for ch in name]) - name = re.sub('_$', '', re.sub('_+', '_', name)) - return '__doc_' + name - - -def process_comment(comment): - result = '' - - # Remove C++ comment syntax - leading_spaces = float('inf') - for s in comment.expandtabs(tabsize=4).splitlines(): - s = s.strip() - if s.startswith('/*'): - s = s[2:].lstrip('*') - elif s.endswith('*/'): - s = s[:-2].rstrip('*') - elif s.startswith('///'): - s = s[3:] - if s.startswith('*'): - s = s[1:] - if len(s) > 0: - leading_spaces = min(leading_spaces, len(s) - len(s.lstrip())) - result += s + '\n' - - if leading_spaces != float('inf'): - result2 = "" - for s in result.splitlines(): - result2 += s[leading_spaces:] + '\n' - result = result2 - - # Doxygen tags - cpp_group = '([\w:]+)' - param_group = '([\[\w:\]]+)' - - s = result - s = re.sub(r'\\c\s+%s' % cpp_group, r'``\1``', s) - s = re.sub(r'\\a\s+%s' % cpp_group, r'*\1*', s) - s = re.sub(r'\\e\s+%s' % cpp_group, r'*\1*', s) - s = re.sub(r'\\em\s+%s' % cpp_group, r'*\1*', s) - s = re.sub(r'\\b\s+%s' % cpp_group, r'**\1**', s) - s = re.sub(r'\\ingroup\s+%s' % cpp_group, r'', s) - s = re.sub(r'\\param%s?\s+%s' % (param_group, cpp_group), - r'\n\n$Parameter ``\2``:\n\n', s) - s = re.sub(r'\\tparam%s?\s+%s' % (param_group, cpp_group), - r'\n\n$Template parameter ``\2``:\n\n', s) - - for in_, out_ in { - 'return': 'Returns', - 'author': 'Author', - 'authors': 'Authors', - 'copyright': 'Copyright', - 'date': 'Date', - 'remark': 'Remark', - 'sa': 'See also', - 'see': 'See also', - 'extends': 'Extends', - 'throw': 'Throws', - 'throws': 'Throws' - }.items(): - s = re.sub(r'\\%s\s*' % in_, r'\n\n$%s:\n\n' % out_, s) - - s = re.sub(r'\\details\s*', r'\n\n', s) - s = re.sub(r'\\brief\s*', r'', s) - s = re.sub(r'\\short\s*', r'', s) - s = re.sub(r'\\ref\s*', r'', s) - - s = re.sub(r'\\code\s?(.*?)\s?\\endcode', - r"```\n\1\n```\n", s, flags=re.DOTALL) - - # HTML/TeX tags - s = re.sub(r'(.*?)', r'``\1``', s, flags=re.DOTALL) - s = re.sub(r'
(.*?)
', r"```\n\1\n```\n", s, flags=re.DOTALL) - s = re.sub(r'(.*?)', r'*\1*', s, flags=re.DOTALL) - s = re.sub(r'(.*?)', r'**\1**', s, flags=re.DOTALL) - s = re.sub(r'\\f\$(.*?)\\f\$', r'$\1$', s, flags=re.DOTALL) - s = re.sub(r'
  • ', r'\n\n* ', s) - s = re.sub(r'', r'', s) - s = re.sub(r'
  • ', r'\n\n', s) - - s = s.replace('``true``', '``True``') - s = s.replace('``false``', '``False``') - - # Re-flow text - wrapper = textwrap.TextWrapper() - wrapper.expand_tabs = True - wrapper.replace_whitespace = True - wrapper.drop_whitespace = True - wrapper.width = 70 - wrapper.initial_indent = wrapper.subsequent_indent = '' - - result = '' - in_code_segment = False - for x in re.split(r'(```)', s): - if x == '```': - if not in_code_segment: - result += '```\n' - else: - result += '\n```\n\n' - in_code_segment = not in_code_segment - elif in_code_segment: - result += x.strip() - else: - for y in re.split(r'(?: *\n *){2,}', x): - wrapped = wrapper.fill(re.sub(r'\s+', ' ', y).strip()) - if len(wrapped) > 0 and wrapped[0] == '$': - result += wrapped[1:] + '\n' - wrapper.initial_indent = \ - wrapper.subsequent_indent = ' ' * 4 - else: - if len(wrapped) > 0: - result += wrapped + '\n\n' - wrapper.initial_indent = wrapper.subsequent_indent = '' - return result.rstrip().lstrip('\n') - - -def extract(filename, node, prefix, output): - if not (node.location.file is None or - os.path.samefile(d(node.location.file.name), filename)): - return 0 - if node.kind in RECURSE_LIST: - sub_prefix = prefix - if node.kind not in PREFIX_BLACKLIST: - if len(sub_prefix) > 0: - sub_prefix += '_' - sub_prefix += d(node.spelling) - for i in node.get_children(): - extract(filename, i, sub_prefix, output) - if node.kind in PRINT_LIST: - comment = d(node.raw_comment) if node.raw_comment is not None else '' - comment = process_comment(comment) - sub_prefix = prefix - if len(sub_prefix) > 0: - sub_prefix += '_' - if len(node.spelling) > 0: - name = sanitize_name(sub_prefix + d(node.spelling)) - output.append((name, filename, comment)) - - -class ExtractionThread(Thread): - def __init__(self, filename, parameters, output): - Thread.__init__(self) - self.filename = filename - self.parameters = parameters - self.output = output - job_semaphore.acquire() - - def run(self): - print('Processing "%s" ..' % self.filename, file=sys.stderr) - try: - index = cindex.Index( - cindex.conf.lib.clang_createIndex(False, True)) - tu = index.parse(self.filename, self.parameters) - extract(self.filename, tu.cursor, '', self.output) - finally: - job_semaphore.release() - - -def read_args(args): - parameters = [] - filenames = [] - if "-x" not in args: - parameters.extend(['-x', 'c++']) - if not any(it.startswith("-std=") for it in args): - parameters.append('-std=c++11') - - if platform.system() == 'Darwin': - dev_path = '/Applications/Xcode.app/Contents/Developer/' - lib_dir = dev_path + 'Toolchains/XcodeDefault.xctoolchain/usr/lib/' - sdk_dir = dev_path + 'Platforms/MacOSX.platform/Developer/SDKs' - libclang = lib_dir + 'libclang.dylib' - - if os.path.exists(libclang): - cindex.Config.set_library_path(os.path.dirname(libclang)) - - if os.path.exists(sdk_dir): - sysroot_dir = os.path.join(sdk_dir, next(os.walk(sdk_dir))[1][0]) - parameters.append('-isysroot') - parameters.append(sysroot_dir) - elif platform.system() == 'Linux': - # clang doesn't find its own base includes by default on Linux, - # but different distros install them in different paths. - # Try to autodetect, preferring the highest numbered version. - def clang_folder_version(d): - return [int(ver) for ver in re.findall(r'(?:${PYBIND11_CPP_STANDARD}>) - endif() - - get_property(_iid TARGET ${PN}::pybind11 PROPERTY INTERFACE_INCLUDE_DIRECTORIES) - get_property(_ill TARGET ${PN}::module PROPERTY INTERFACE_LINK_LIBRARIES) - set(${PN}_INCLUDE_DIRS ${_iid}) - set(${PN}_LIBRARIES ${_ico} ${_ill}) -endif() -endif() diff --git a/pybind11/tools/pybind11Tools.cmake b/pybind11/tools/pybind11Tools.cmake deleted file mode 100644 index c7156c0..0000000 --- a/pybind11/tools/pybind11Tools.cmake +++ /dev/null @@ -1,227 +0,0 @@ -# tools/pybind11Tools.cmake -- Build system for the pybind11 modules -# -# Copyright (c) 2015 Wenzel Jakob -# -# All rights reserved. Use of this source code is governed by a -# BSD-style license that can be found in the LICENSE file. - -cmake_minimum_required(VERSION 2.8.12) - -# Add a CMake parameter for choosing a desired Python version -if(NOT PYBIND11_PYTHON_VERSION) - set(PYBIND11_PYTHON_VERSION "" CACHE STRING "Python version to use for compiling modules") -endif() - -set(Python_ADDITIONAL_VERSIONS 3.7 3.6 3.5 3.4) -find_package(PythonLibsNew ${PYBIND11_PYTHON_VERSION} REQUIRED) - -include(CheckCXXCompilerFlag) -include(CMakeParseArguments) - -if(NOT PYBIND11_CPP_STANDARD AND NOT CMAKE_CXX_STANDARD) - if(NOT MSVC) - check_cxx_compiler_flag("-std=c++14" HAS_CPP14_FLAG) - - if (HAS_CPP14_FLAG) - set(PYBIND11_CPP_STANDARD -std=c++14) - else() - check_cxx_compiler_flag("-std=c++11" HAS_CPP11_FLAG) - if (HAS_CPP11_FLAG) - set(PYBIND11_CPP_STANDARD -std=c++11) - else() - message(FATAL_ERROR "Unsupported compiler -- pybind11 requires C++11 support!") - endif() - endif() - elseif(MSVC) - set(PYBIND11_CPP_STANDARD /std:c++14) - endif() - - set(PYBIND11_CPP_STANDARD ${PYBIND11_CPP_STANDARD} CACHE STRING - "C++ standard flag, e.g. -std=c++11, -std=c++14, /std:c++14. Defaults to C++14 mode." FORCE) -endif() - -# Checks whether the given CXX/linker flags can compile and link a cxx file. cxxflags and -# linkerflags are lists of flags to use. The result variable is a unique variable name for each set -# of flags: the compilation result will be cached base on the result variable. If the flags work, -# sets them in cxxflags_out/linkerflags_out internal cache variables (in addition to ${result}). -function(_pybind11_return_if_cxx_and_linker_flags_work result cxxflags linkerflags cxxflags_out linkerflags_out) - set(CMAKE_REQUIRED_LIBRARIES ${linkerflags}) - check_cxx_compiler_flag("${cxxflags}" ${result}) - if (${result}) - set(${cxxflags_out} "${cxxflags}" CACHE INTERNAL "" FORCE) - set(${linkerflags_out} "${linkerflags}" CACHE INTERNAL "" FORCE) - endif() -endfunction() - -# Internal: find the appropriate link time optimization flags for this compiler -function(_pybind11_add_lto_flags target_name prefer_thin_lto) - if (NOT DEFINED PYBIND11_LTO_CXX_FLAGS) - set(PYBIND11_LTO_CXX_FLAGS "" CACHE INTERNAL "") - set(PYBIND11_LTO_LINKER_FLAGS "" CACHE INTERNAL "") - - if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") - set(cxx_append "") - set(linker_append "") - if (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND NOT APPLE) - # Clang Gold plugin does not support -Os; append -O3 to MinSizeRel builds to override it - set(linker_append ";$<$:-O3>") - elseif(CMAKE_CXX_COMPILER_ID MATCHES "GNU") - set(cxx_append ";-fno-fat-lto-objects") - endif() - - if (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND prefer_thin_lto) - _pybind11_return_if_cxx_and_linker_flags_work(HAS_FLTO_THIN - "-flto=thin${cxx_append}" "-flto=thin${linker_append}" - PYBIND11_LTO_CXX_FLAGS PYBIND11_LTO_LINKER_FLAGS) - endif() - - if (NOT HAS_FLTO_THIN) - _pybind11_return_if_cxx_and_linker_flags_work(HAS_FLTO - "-flto${cxx_append}" "-flto${linker_append}" - PYBIND11_LTO_CXX_FLAGS PYBIND11_LTO_LINKER_FLAGS) - endif() - elseif (CMAKE_CXX_COMPILER_ID MATCHES "Intel") - # Intel equivalent to LTO is called IPO - _pybind11_return_if_cxx_and_linker_flags_work(HAS_INTEL_IPO - "-ipo" "-ipo" PYBIND11_LTO_CXX_FLAGS PYBIND11_LTO_LINKER_FLAGS) - elseif(MSVC) - # cmake only interprets libraries as linker flags when they start with a - (otherwise it - # converts /LTCG to \LTCG as if it was a Windows path). Luckily MSVC supports passing flags - # with - instead of /, even if it is a bit non-standard: - _pybind11_return_if_cxx_and_linker_flags_work(HAS_MSVC_GL_LTCG - "/GL" "-LTCG" PYBIND11_LTO_CXX_FLAGS PYBIND11_LTO_LINKER_FLAGS) - endif() - - if (PYBIND11_LTO_CXX_FLAGS) - message(STATUS "LTO enabled") - else() - message(STATUS "LTO disabled (not supported by the compiler and/or linker)") - endif() - endif() - - # Enable LTO flags if found, except for Debug builds - if (PYBIND11_LTO_CXX_FLAGS) - target_compile_options(${target_name} PRIVATE "$<$>:${PYBIND11_LTO_CXX_FLAGS}>") - endif() - if (PYBIND11_LTO_LINKER_FLAGS) - target_link_libraries(${target_name} PRIVATE "$<$>:${PYBIND11_LTO_LINKER_FLAGS}>") - endif() -endfunction() - -# Build a Python extension module: -# pybind11_add_module( [MODULE | SHARED] [EXCLUDE_FROM_ALL] -# [NO_EXTRAS] [SYSTEM] [THIN_LTO] source1 [source2 ...]) -# -function(pybind11_add_module target_name) - set(options MODULE SHARED EXCLUDE_FROM_ALL NO_EXTRAS SYSTEM THIN_LTO) - cmake_parse_arguments(ARG "${options}" "" "" ${ARGN}) - - if(ARG_MODULE AND ARG_SHARED) - message(FATAL_ERROR "Can't be both MODULE and SHARED") - elseif(ARG_SHARED) - set(lib_type SHARED) - else() - set(lib_type MODULE) - endif() - - if(ARG_EXCLUDE_FROM_ALL) - set(exclude_from_all EXCLUDE_FROM_ALL) - endif() - - add_library(${target_name} ${lib_type} ${exclude_from_all} ${ARG_UNPARSED_ARGUMENTS}) - - if(ARG_SYSTEM) - set(inc_isystem SYSTEM) - endif() - - target_include_directories(${target_name} ${inc_isystem} - PRIVATE ${PYBIND11_INCLUDE_DIR} # from project CMakeLists.txt - PRIVATE ${pybind11_INCLUDE_DIR} # from pybind11Config - PRIVATE ${PYTHON_INCLUDE_DIRS}) - - # Python debug libraries expose slightly different objects - # https://docs.python.org/3.6/c-api/intro.html#debugging-builds - # https://stackoverflow.com/questions/39161202/how-to-work-around-missing-pymodule-create2-in-amd64-win-python35-d-lib - if(PYTHON_IS_DEBUG) - target_compile_definitions(${target_name} PRIVATE Py_DEBUG) - endif() - - # The prefix and extension are provided by FindPythonLibsNew.cmake - set_target_properties(${target_name} PROPERTIES PREFIX "${PYTHON_MODULE_PREFIX}") - set_target_properties(${target_name} PROPERTIES SUFFIX "${PYTHON_MODULE_EXTENSION}") - - # -fvisibility=hidden is required to allow multiple modules compiled against - # different pybind versions to work properly, and for some features (e.g. - # py::module_local). We force it on everything inside the `pybind11` - # namespace; also turning it on for a pybind module compilation here avoids - # potential warnings or issues from having mixed hidden/non-hidden types. - set_target_properties(${target_name} PROPERTIES CXX_VISIBILITY_PRESET "hidden") - set_target_properties(${target_name} PROPERTIES CUDA_VISIBILITY_PRESET "hidden") - - if(WIN32 OR CYGWIN) - # Link against the Python shared library on Windows - target_link_libraries(${target_name} PRIVATE ${PYTHON_LIBRARIES}) - elseif(APPLE) - # It's quite common to have multiple copies of the same Python version - # installed on one's system. E.g.: one copy from the OS and another copy - # that's statically linked into an application like Blender or Maya. - # If we link our plugin library against the OS Python here and import it - # into Blender or Maya later on, this will cause segfaults when multiple - # conflicting Python instances are active at the same time (even when they - # are of the same version). - - # Windows is not affected by this issue since it handles DLL imports - # differently. The solution for Linux and Mac OS is simple: we just don't - # link against the Python library. The resulting shared library will have - # missing symbols, but that's perfectly fine -- they will be resolved at - # import time. - - target_link_libraries(${target_name} PRIVATE "-undefined dynamic_lookup") - - if(ARG_SHARED) - # Suppress CMake >= 3.0 warning for shared libraries - set_target_properties(${target_name} PROPERTIES MACOSX_RPATH ON) - endif() - endif() - - # Make sure C++11/14 are enabled - if(CMAKE_VERSION VERSION_LESS 3.3) - target_compile_options(${target_name} PUBLIC ${PYBIND11_CPP_STANDARD}) - else() - target_compile_options(${target_name} PUBLIC $<$:${PYBIND11_CPP_STANDARD}>) - endif() - - if(ARG_NO_EXTRAS) - return() - endif() - - _pybind11_add_lto_flags(${target_name} ${ARG_THIN_LTO}) - - if (NOT MSVC AND NOT ${CMAKE_BUILD_TYPE} MATCHES Debug|RelWithDebInfo) - # Strip unnecessary sections of the binary on Linux/Mac OS - if(CMAKE_STRIP) - if(APPLE) - add_custom_command(TARGET ${target_name} POST_BUILD - COMMAND ${CMAKE_STRIP} -x $) - else() - add_custom_command(TARGET ${target_name} POST_BUILD - COMMAND ${CMAKE_STRIP} $) - endif() - endif() - endif() - - if(MSVC) - # /MP enables multithreaded builds (relevant when there are many files), /bigobj is - # needed for bigger binding projects due to the limit to 64k addressable sections - target_compile_options(${target_name} PRIVATE /bigobj) - if(CMAKE_VERSION VERSION_LESS 3.11) - target_compile_options(${target_name} PRIVATE $<$>:/MP>) - else() - # Only set these options for C++ files. This is important so that, for - # instance, projects that include other types of source files like CUDA - # .cu files don't get these options propagated to nvcc since that would - # cause the build to fail. - target_compile_options(${target_name} PRIVATE $<$>:$<$:/MP>>) - endif() - endif() -endfunction() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..1508fa9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,9 @@ + +[build-system] +requires = [ + "setuptools>=42", + "wheel", + "pybind11>=2.9.1", +] + +build-backend = "setuptools.build_meta" diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..0147f31 --- /dev/null +++ b/setup.py @@ -0,0 +1,44 @@ +import os +import platform +from glob import glob +from setuptools import setup +from distutils.sysconfig import get_python_inc +from pybind11.setup_helpers import Pybind11Extension, build_ext + +__version__ = "4.0.0" + +include_dirs = [get_python_inc(), 'pybind11/include', 'SEAL/native/src', 'SEAL/build/native/src'] + +extra_objects = sorted(glob('SEAL/build/lib/*.lib') if platform.system() == "Windows" else glob('SEAL/build/lib/*.a')) + +cpp_args = ['/std:c++latest'] if platform.system() == "Windows" else ['-std=c++17'] + +if len(extra_objects) < 1 or not os.path.exists(extra_objects[0]): + print('Not found the seal lib file, check the `SEAL/build/lib`') + exit(0) + +ext_modules = [ + Pybind11Extension( + "seal", + sorted(glob('src/*.cpp')), + include_dirs=include_dirs, + extra_compile_args=cpp_args, + extra_objects=extra_objects, + define_macros = [('VERSION_INFO', __version__)], + ), +] + +setup( + name="seal", + version=__version__, + author="Huelse", + author_email="topmaxz@protonmail.com", + url="https://github.com/Huelse/SEAL-Python", + description="Python wrapper for the Microsoft SEAL", + long_description="", + ext_modules=ext_modules, + cmdclass={"build_ext": build_ext}, + zip_safe=False, + license='MIT', + python_requires=">=3.6", +) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt deleted file mode 100644 index 405bf19..0000000 --- a/src/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -cmake_minimum_required(VERSION 3.12.0) -project(seal) -set(CMAKE_BUILD_TYPE "Release") -set(PYBIND11_CPP_STANDARD -std=c++1z) -set(LIBRARY_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/../tests) - -include_directories("/usr/include/python3.6") -include_directories("../pybind11/include") -include_directories("../SEAL/native/src") - -LINK_LIBRARIES("../SEAL/native/lib/libseal.a") - -find_package(pybind11 REQUIRED) - -pybind11_add_module(seal wrapper.cpp) -target_sources(seal - PRIVATE - base64.cpp -) \ No newline at end of file diff --git a/src/base64.cpp b/src/base64.cpp deleted file mode 100644 index 252224b..0000000 --- a/src/base64.cpp +++ /dev/null @@ -1,112 +0,0 @@ -/* - base64.cpp and base64.h - base64 encoding and decoding with C++. - Version: 1.01.00 - Copyright (C) 2004-2017 René Nyffenegger - This source code is provided 'as-is', without any express or implied - warranty. In no event will the author be held liable for any damages - arising from the use of this software. - Permission is granted to anyone to use this software for any purpose, - including commercial applications, and to alter it and redistribute it - freely, subject to the following restrictions: - 1. The origin of this source code must not be misrepresented; you must not - claim that you wrote the original source code. If you use this source code - in a product, an acknowledgment in the product documentation would be - appreciated but is not required. - 2. Altered source versions must be plainly marked as such, and must not be - misrepresented as being the original source code. - 3. This notice may not be removed or altered from any source distribution. - René Nyffenegger rene.nyffenegger@adp-gmbh.ch -*/ - -#include "base64.h" -#include - -static const std::string base64_chars = - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz" - "0123456789+/"; - - -static inline bool is_base64(unsigned char c) { - return (isalnum(c) || (c == '+') || (c == '/')); -} - -std::string base64_encode(unsigned char const* bytes_to_encode, unsigned int in_len) { - std::string ret; - int i = 0; - int j = 0; - unsigned char char_array_3[3]; - unsigned char char_array_4[4]; - - while (in_len--) { - char_array_3[i++] = *(bytes_to_encode++); - if (i == 3) { - char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; - char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); - char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); - char_array_4[3] = char_array_3[2] & 0x3f; - - for(i = 0; (i <4) ; i++) - ret += base64_chars[char_array_4[i]]; - i = 0; - } - } - - if (i) - { - for(j = i; j < 3; j++) - char_array_3[j] = '\0'; - - char_array_4[0] = ( char_array_3[0] & 0xfc) >> 2; - char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); - char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); - - for (j = 0; (j < i + 1); j++) - ret += base64_chars[char_array_4[j]]; - - while((i++ < 3)) - ret += '='; - - } - - return ret; - -} - -std::string base64_decode(std::string const& encoded_string) { - int in_len = encoded_string.size(); - int i = 0; - int j = 0; - int in_ = 0; - unsigned char char_array_4[4], char_array_3[3]; - std::string ret; - - while (in_len-- && ( encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { - char_array_4[i++] = encoded_string[in_]; in_++; - if (i ==4) { - for (i = 0; i <4; i++) - char_array_4[i] = base64_chars.find(char_array_4[i]); - - char_array_3[0] = ( char_array_4[0] << 2 ) + ((char_array_4[1] & 0x30) >> 4); - char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - - for (i = 0; (i < 3); i++) - ret += char_array_3[i]; - i = 0; - } - } - - if (i) { - for (j = 0; j < i; j++) - char_array_4[j] = base64_chars.find(char_array_4[j]); - - char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); - char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - - for (j = 0; (j < i - 1); j++) ret += char_array_3[j]; - } - - return ret; -} diff --git a/src/base64.h b/src/base64.h deleted file mode 100644 index dd1134c..0000000 --- a/src/base64.h +++ /dev/null @@ -1,14 +0,0 @@ -// -// base64 encoding and decoding with C++. -// Version: 1.01.00 -// - -#ifndef BASE64_H_C0CE2A47_D10E_42C9_A27C_C883944E704A -#define BASE64_H_C0CE2A47_D10E_42C9_A27C_C883944E704A - -#include - -std::string base64_encode(unsigned char const* , unsigned int len); -std::string base64_decode(std::string const& s); - -#endif /* BASE64_H_C0CE2A47_D10E_42C9_A27C_C883944E704A */ diff --git a/src/requirements.txt b/src/requirements.txt deleted file mode 100644 index 96edbc5..0000000 --- a/src/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -pytest -pybind11 -numpy diff --git a/src/setup.py b/src/setup.py deleted file mode 100644 index 2f98163..0000000 --- a/src/setup.py +++ /dev/null @@ -1,45 +0,0 @@ -import os -import sys -import platform -from distutils.core import setup, Extension -from distutils.sysconfig import get_python_inc - - -# python include dir -incdir = os.path.join(get_python_inc()) -# cpp flags -cpp_args = ['-std=c++17'] -# include directories -include_dirs = [incdir, '../pybind11/include', '../SEAL/native/src'] -# library path -extra_objects = ['../SEAL/native/lib/libseal.a'] - -if(platform.system() == "Windows"): - cpp_args[0] = '/std:c++latest' - extra_objects[0] = '../SEAL/native/lib/x64/Release/seal.lib' - -if not os.path.exists(extra_objects[0]): - print('Can not find the seal lib') - exit(1) - -ext_modules = [ - Extension( - name='seal', - sources=['base64.cpp', 'wrapper.cpp'], - include_dirs=include_dirs, - language='c++', - extra_compile_args=cpp_args, - extra_objects=extra_objects, - ), -] - -setup( - name='seal', - version='3.3.2', - author='Huelse', - author_email='huelse@oini.top', - description='Python wrapper for SEAL', - url="https://github.com/Huelse/SEAL-Python", - license="MIT", - ext_modules=ext_modules, -) diff --git a/src/wrapper.cpp b/src/wrapper.cpp index e1039d2..270b7d4 100644 --- a/src/wrapper.cpp +++ b/src/wrapper.cpp @@ -1,328 +1,687 @@ #include -#include +#include #include #include -#include #include "seal/seal.h" -#include "base64.h" - -namespace py = pybind11; +#include -using namespace std; using namespace seal; +namespace py = pybind11; -using pt_coeff_type = std::uint64_t; -using size_type = IntArray::size_type; -using parms_id_type = std::array; - -PYBIND11_MAKE_OPAQUE(std::vector>); PYBIND11_MAKE_OPAQUE(std::vector); -PYBIND11_MAKE_OPAQUE(std::vector); PYBIND11_MAKE_OPAQUE(std::vector); -template -py::tuple serialize(T &c) -{ - std::stringstream output(std::ios::binary | std::ios::out); - c.save(output); - std::string cipherstr = output.str(); - std::string base64_encoded_cipher = base64_encode(reinterpret_cast(cipherstr.c_str()), (unsigned int)cipherstr.length()); - return py::make_tuple(base64_encoded_cipher); -} - -template -T deserialize(py::tuple t) -{ - if (t.size() != 1) - throw std::runtime_error("(Pickle) Invalid input tuple!"); - T c = T(); - std::string cipherstr_encoded = t[0].cast(); - std::string cipherstr_decoded = base64_decode(cipherstr_encoded); - std::stringstream input(std::ios::binary | std::ios::in); - input.str(cipherstr_decoded); - c.unsafe_load(input); - return c; -} - PYBIND11_MODULE(seal, m) { - - m.doc() = "Microsoft SEAL(3.3.2) For Python. From https://github.com/Huelse/SEAL-Python"; - - py::bind_vector>>(m, "ComplexDoubleVector", py::buffer_protocol()); - py::bind_vector>(m, "DoubleVector", py::buffer_protocol()); - py::bind_vector>(m, "uIntVector", py::buffer_protocol()); - py::bind_vector>(m, "IntVector", py::buffer_protocol()); - - // BigUInt - py::class_(m, "BigUInt") - .def(py::init<>()) - .def("to_double", &BigUInt::to_double) - .def("significant_bit_count", (int (BigUInt::*)()) & BigUInt::significant_bit_count); - - // EncryptionParameters - py::class_(m, "EncryptionParameters") - .def(py::init()) - .def("set_poly_modulus_degree", - (void (EncryptionParameters::*)(std::uint64_t)) & EncryptionParameters::set_poly_modulus_degree) - .def("set_coeff_modulus", - (void (EncryptionParameters::*)(const std::vector &)) & EncryptionParameters::set_coeff_modulus) - .def("set_plain_modulus", - (void (EncryptionParameters::*)(const SmallModulus &)) & EncryptionParameters::set_plain_modulus) - .def("set_plain_modulus", - (void (EncryptionParameters::*)(std::uint64_t)) & EncryptionParameters::set_plain_modulus) - .def("scheme", &EncryptionParameters::scheme) - .def("poly_modulus_degree", &EncryptionParameters::poly_modulus_degree) - .def("coeff_modulus", &EncryptionParameters::coeff_modulus) - .def("plain_modulus", &EncryptionParameters::plain_modulus); - - // scheme_type - py::enum_(m, "scheme_type", py::arithmetic()) - .value("BFV", scheme_type::BFV) - .value("CKKS", scheme_type::CKKS); - - // sec_level_type - py::enum_(m, "sec_level_type", py::arithmetic()) - .value("none", sec_level_type::none) - .value("tc128", sec_level_type::tc128) - .value("tc192", sec_level_type::tc192) - .value("tc256", sec_level_type::tc256); - - // EncryptionParameterQualifiers - py::class_>(m, "EncryptionParameterQualifiers") - .def_readwrite("parameters_set", &EncryptionParameterQualifiers::parameters_set) - .def_readwrite("using_fft", &EncryptionParameterQualifiers::using_fft) - .def_readwrite("using_ntt", &EncryptionParameterQualifiers::using_ntt) - .def_readwrite("using_batching", &EncryptionParameterQualifiers::using_batching) - .def_readwrite("using_fast_plain_lift", &EncryptionParameterQualifiers::using_fast_plain_lift) - .def_readwrite("using_descending_modulus_chain", &EncryptionParameterQualifiers::using_descending_modulus_chain) - .def_readwrite("sec_level", &EncryptionParameterQualifiers::sec_level); - - // SEALContext - py::class_>(m, "SEALContext") - .def("Create", [](const EncryptionParameters &parms) { return SEALContext::Create(parms); }) - .def("Create", [](const EncryptionParameters &parms, bool expand_mod_chain) { return SEALContext::Create(parms, expand_mod_chain); }) - .def("get_context_data", &SEALContext::get_context_data, py::return_value_policy::reference) - .def("key_context_data", &SEALContext::key_context_data, py::return_value_policy::reference) - .def("first_context_data", &SEALContext::first_context_data, py::return_value_policy::reference) - .def("first_parms_id", &SEALContext::first_parms_id, py::return_value_policy::reference) - .def("last_parms_id", &SEALContext::last_parms_id, py::return_value_policy::reference) - .def("using_keyswitching", &SEALContext::using_keyswitching); - - // SEALContext::ContextData - py::class_>(m, "SEALContext::ContextData") - .def("parms", &SEALContext::ContextData::parms) - .def("parms_id", &SEALContext::ContextData::parms_id) - .def("qualifiers", &SEALContext::ContextData::qualifiers) - .def("total_coeff_modulus", - (std::uint64_t(SEALContext::ContextData::*)()) & SEALContext::ContextData::total_coeff_modulus) - .def("total_coeff_modulus_bit_count", &SEALContext::ContextData::total_coeff_modulus_bit_count) - .def("next_context_data", &SEALContext::ContextData::next_context_data) - .def("chain_index", &SEALContext::ContextData::chain_index); - - // SmallModulus - py::class_(m, "SmallModulus") - .def(py::init<>()) - .def(py::init()) - .def("bit_count", &SmallModulus::bit_count) - .def("value", (std::uint64_t(SmallModulus::*)()) & SmallModulus::value); - - // CoeffModulus - py::class_(m, "CoeffModulus") - .def("BFVDefault", - [](std::size_t poly_modulus_degree) { return CoeffModulus::BFVDefault(poly_modulus_degree); }) - .def("Create", - [](std::size_t poly_modulus_degree, std::vector bit_sizes) { return CoeffModulus::Create(poly_modulus_degree, bit_sizes); }) - .def("MaxBitCount", - [](std::size_t poly_modulus_degree) { return CoeffModulus::MaxBitCount(poly_modulus_degree); }); - - // PlainModulus - py::class_(m, "PlainModulus") - .def("Batching", - [](std::size_t poly_modulus_degree, int bit_size) { return PlainModulus::Batching(poly_modulus_degree, bit_size); }) - .def("Batching", - [](std::size_t poly_modulus_degree, std::vector bit_sizes) { return PlainModulus::Batching(poly_modulus_degree, bit_sizes); }); - - // SecretKey - py::class_(m, "SecretKey") - .def(py::init<>()) - .def("parms_id", (parms_id_type & (SecretKey::*)()) & SecretKey::parms_id, py::return_value_policy::reference) - .def("save", &SecretKey::python_save) - .def("load", &SecretKey::python_load) - .def(py::pickle(&serialize, &deserialize)); - - // PublicKey - py::class_(m, "PublicKey") - .def(py::init<>()) - .def("parms_id", (parms_id_type & (PublicKey::*)()) & PublicKey::parms_id, py::return_value_policy::reference) - .def("save", &PublicKey::python_save) - .def("load", &PublicKey::python_load) - .def(py::pickle(&serialize, &deserialize)); - - // KSwitchKeys - py::class_(m, "KSwitchKeys") - .def(py::init<>()) - .def("parms_id", (parms_id_type & (KSwitchKeys::*)()) & KSwitchKeys::parms_id, py::return_value_policy::reference) - .def("save", &KSwitchKeys::python_save) - .def("load", &KSwitchKeys::python_load) - .def(py::pickle(&serialize, &deserialize)); - - // RelinKeys - py::class_(m, "RelinKeys") - .def(py::init<>()) - .def("parms_id", (parms_id_type & (RelinKeys::KSwitchKeys::*)()) & RelinKeys::KSwitchKeys::parms_id, py::return_value_policy::reference) - .def("save", &KSwitchKeys::python_save) - .def("load", &KSwitchKeys::python_load) - .def(py::pickle(&serialize, &deserialize)); - - // GaloisKeys - py::class_(m, "GaloisKeys") - .def(py::init<>()) - .def("parms_id", (parms_id_type & (GaloisKeys::KSwitchKeys::*)()) & GaloisKeys::KSwitchKeys::parms_id, py::return_value_policy::reference) - .def("save", &KSwitchKeys::python_save) - .def("load", &KSwitchKeys::python_load) - .def(py::pickle(&serialize, &deserialize)); - - // KeyGenerator - py::class_(m, "KeyGenerator") - .def(py::init>()) - .def(py::init, const SecretKey &>()) - .def(py::init, const SecretKey &, const PublicKey &>()) - .def("secret_key", &KeyGenerator::secret_key) - .def("public_key", &KeyGenerator::public_key) - .def("galois_keys", (GaloisKeys(KeyGenerator::*)(const std::vector &)) & KeyGenerator::galois_keys) - .def("galois_keys", (GaloisKeys(KeyGenerator::*)(const std::vector &)) & KeyGenerator::galois_keys) - .def("galois_keys", (GaloisKeys(KeyGenerator::*)()) & KeyGenerator::galois_keys) - .def("relin_keys", (RelinKeys(KeyGenerator::*)()) & KeyGenerator::relin_keys); - - // MemoryPoolHandle - py::class_(m, "MemoryPoolHandle") - .def(py::init<>()) - .def(py::init>()) - .def_static("New", &MemoryPoolHandle::New) - .def_static("Global", &MemoryPoolHandle::Global); - - // MemoryManager - py::class_(m, "MemoryManager") - .def("GetPool", []() { return MemoryManager::GetPool(); }); - - // Ciphertext - py::class_(m, "Ciphertext") - .def(py::init<>()) - .def(py::init>()) - .def(py::init, parms_id_type>()) - .def(py::init()) - .def("size", &Ciphertext::size) - .def("scale", (double &(Ciphertext::*)()) & Ciphertext::scale) - .def("reserve", (void (Ciphertext::*)(size_type)) & Ciphertext::reserve) - .def("set_scale", (void (Ciphertext::*)(double)) & Ciphertext::set_scale) - .def("parms_id", (parms_id_type & (Ciphertext::*)()) & Ciphertext::parms_id) - .def("save", &Ciphertext::python_save) - .def("load", &Ciphertext::python_load) - .def(py::pickle(&serialize, &deserialize)); - - // Plaintext - py::class_(m, "Plaintext") - .def(py::init<>()) - .def(py::init<size_type>()) - .def(py::init<size_type, size_type>()) - .def(py::init<const std::string &>()) - .def(py::init<const Plaintext &>()) - .def("data", (pt_coeff_type * (Plaintext::*)(size_type)) & Plaintext::data) - .def("significant_coeff_count", &Plaintext::significant_coeff_count) - .def("to_string", &Plaintext::to_string) - .def("coeff_count", &Plaintext::coeff_count) - .def("save", (void (Plaintext::*)(std::ostream &)) & Plaintext::save) - .def("load", (void (Plaintext::*)(std::shared_ptr<SEALContext>, std::istream &)) & Plaintext::load) - .def("scale", (double &(Plaintext::*)()) & Plaintext::scale) - .def("parms_id", (parms_id_type & (Plaintext::*)()) & Plaintext::parms_id) - .def("save", &Plaintext::python_save) - .def("load", &Plaintext::python_load) - .def(py::pickle(&serialize<Plaintext>, &deserialize<Plaintext>)); - - // Encryptor - py::class_<Encryptor>(m, "Encryptor") - .def(py::init<std::shared_ptr<SEALContext>, const PublicKey &>()) - .def("encrypt", (void (Encryptor::*)(const Plaintext &, Ciphertext &, MemoryPoolHandle)) & Encryptor::encrypt, py::arg(), py::arg(), py::arg() = MemoryManager::GetPool()); - - // Evaluator - py::class_<Evaluator>(m, "Evaluator") - .def(py::init<std::shared_ptr<SEALContext>>()) - .def("negate_inplace", (void (Evaluator::*)(Ciphertext &)) & Evaluator::negate_inplace) - .def("negate", (void (Evaluator::*)(const Ciphertext &, Ciphertext &)) & Evaluator::negate) - .def("add_inplace", (void (Evaluator::*)(Ciphertext &, const Ciphertext &)) & Evaluator::add_inplace) - .def("add", (void (Evaluator::*)(const Ciphertext &, const Ciphertext &, Ciphertext &)) & Evaluator::add) - .def("add_many", (void (Evaluator::*)(const std::vector<Ciphertext> &, Ciphertext &)) & Evaluator::add_many) - .def("sub_inplace", (void (Evaluator::*)(Ciphertext &, const Ciphertext &)) & Evaluator::sub_inplace) - .def("sub", (void (Evaluator::*)(const Ciphertext &, const Ciphertext &, Ciphertext &)) & Evaluator::sub) - .def("multiply_inplace", (void (Evaluator::*)(Ciphertext &, const Ciphertext &, MemoryPoolHandle)) & Evaluator::multiply_inplace, py::arg(), py::arg(), py::arg() = MemoryManager::GetPool()) - .def("multiply", (void (Evaluator::*)(Ciphertext &, const Ciphertext &, Ciphertext &, MemoryPoolHandle)) & Evaluator::multiply, py::arg(), py::arg(), py::arg(), py::arg() = MemoryManager::GetPool()) - .def("square_inplace", (void (Evaluator::*)(Ciphertext &, MemoryPoolHandle)) & Evaluator::square_inplace, py::arg(), py::arg() = MemoryManager::GetPool()) - .def("square", (void (Evaluator::*)(const Ciphertext &, Ciphertext &, MemoryPoolHandle)) & Evaluator::square, py::arg(), py::arg(), py::arg() = MemoryManager::GetPool()) - .def("relinearize_inplace", (void (Evaluator::*)(Ciphertext &, const RelinKeys &, MemoryPoolHandle)) & Evaluator::relinearize_inplace, py::arg(), py::arg(), py::arg() = MemoryManager::GetPool()) - .def("relinearize", (void (Evaluator::*)(const Ciphertext &, const RelinKeys &, Ciphertext &)) & Evaluator::relinearize) - .def("mod_switch_to_next_inplace", (void (Evaluator::*)(Ciphertext &, MemoryPoolHandle)) & Evaluator::mod_switch_to_next_inplace, py::arg(), py::arg() = MemoryManager::GetPool()) - .def("mod_switch_to_next_inplace", (void (Evaluator::*)(Plaintext &)) & Evaluator::mod_switch_to_next_inplace) - .def("mod_switch_to_inplace", (void (Evaluator::*)(Ciphertext &, parms_id_type, MemoryPoolHandle)) & Evaluator::mod_switch_to_inplace, py::arg(), py::arg(), py::arg() = MemoryManager::GetPool()) - .def("mod_switch_to_inplace", (void (Evaluator::*)(Plaintext &, parms_id_type)) & Evaluator::mod_switch_to_inplace) - .def("rescale_to_next_inplace", (void (Evaluator::*)(Ciphertext &, MemoryPoolHandle)) & Evaluator::rescale_to_next_inplace, py::arg(), py::arg() = MemoryManager::GetPool()) - .def("multiply_many", (void (Evaluator::*)(std::vector<Ciphertext> &, const RelinKeys &, Ciphertext &)) & Evaluator::multiply_many) - .def("exponentiate_inplace", (void (Evaluator::*)(Ciphertext &, std::uint64_t, const RelinKeys &)) & Evaluator::exponentiate_inplace) - .def("exponentiate", (void (Evaluator::*)(const Ciphertext &, std::uint64_t, const RelinKeys &, Ciphertext &)) & Evaluator::exponentiate) - .def("add_plain_inplace", (void (Evaluator::*)(Ciphertext &, const Plaintext &)) & Evaluator::add_plain_inplace) - .def("add_plain", (void (Evaluator::*)(const Ciphertext &, const Plaintext &, Ciphertext &)) & Evaluator::add_plain) - .def("sub_plain_inplace", (void (Evaluator::*)(Ciphertext &, const Plaintext &)) & Evaluator::sub_plain_inplace) - .def("sub_plain", (void (Evaluator::*)(const Ciphertext &, const Plaintext &)) & Evaluator::sub_plain) - .def("multiply_plain_inplace", (void (Evaluator::*)(Ciphertext &, const Plaintext &, MemoryPoolHandle)) & Evaluator::multiply_plain_inplace, py::arg(), py::arg(), py::arg() = MemoryManager::GetPool()) - .def("multiply_plain", (void (Evaluator::*)(const Ciphertext &, const Plaintext &, Ciphertext &, MemoryPoolHandle)) & Evaluator::multiply_plain, py::arg(), py::arg(), py::arg(), py::arg() = MemoryManager::GetPool()) - .def("transform_to_ntt_inplace", (void (Evaluator::*)(Ciphertext &)) & Evaluator::transform_to_ntt_inplace) - .def("transform_to_ntt", (void (Evaluator::*)(const Ciphertext &, Ciphertext &)) & Evaluator::transform_to_ntt) - .def("transform_from_ntt_inplace", (void (Evaluator::*)(Ciphertext &)) & Evaluator::transform_from_ntt_inplace) - .def("transform_from_ntt", (void (Evaluator::*)(const Ciphertext &, Ciphertext &)) & Evaluator::transform_from_ntt) - .def("apply_galois_inplace", (void (Evaluator::*)(Ciphertext &, std::uint64_t, const GaloisKeys &)) & Evaluator::apply_galois_inplace) - .def("apply_galois", (void (Evaluator::*)(const Ciphertext &, std::uint64_t, const GaloisKeys &, Ciphertext &)) & Evaluator::apply_galois) - .def("rotate_rows_inplace", (void (Evaluator::*)(Ciphertext &, int, GaloisKeys, MemoryPoolHandle)) & Evaluator::rotate_rows_inplace, py::arg(), py::arg(), py::arg(), py::arg() = MemoryManager::GetPool()) - .def("rotate_rows", (void (Evaluator::*)(const Ciphertext &, int, const GaloisKeys &, Ciphertext &)) & Evaluator::rotate_rows) - .def("rotate_columns_inplace", (void (Evaluator::*)(Ciphertext &, const GaloisKeys &, MemoryPoolHandle)) & Evaluator::rotate_columns_inplace, py::arg(), py::arg(), py::arg() = MemoryManager::GetPool()) - .def("rotate_columns", (void (Evaluator::*)(const Ciphertext &, const GaloisKeys &, Ciphertext &)) & Evaluator::rotate_columns) - .def("rotate_vector_inplace", (void (Evaluator::*)(Ciphertext &, int, const GaloisKeys &, MemoryPoolHandle)) & Evaluator::rotate_vector_inplace, py::arg(), py::arg(), py::arg(), py::arg() = MemoryManager::GetPool()) - .def("rotate_vector", (void (Evaluator::*)(const Ciphertext &, int, const GaloisKeys &, Ciphertext &, MemoryPoolHandle)) & Evaluator::rotate_vector, py::arg(), py::arg(), py::arg(), py::arg(), py::arg() = MemoryManager::GetPool()) - .def("complex_conjugate_inplace", (void (Evaluator::*)(Ciphertext &, const GaloisKeys &, MemoryPoolHandle)) & Evaluator::complex_conjugate_inplace, py::arg(), py::arg(), py::arg() = MemoryManager::GetPool()) - .def("complex_conjugate", (void (Evaluator::*)(const Ciphertext &, const GaloisKeys &, Ciphertext &)) & Evaluator::complex_conjugate); - - // CKKSEncoder - py::class_<CKKSEncoder>(m, "CKKSEncoder") - .def(py::init<std::shared_ptr<SEALContext>>()) - .def("encode", (void (CKKSEncoder::*)(const std::vector<double> &, double, Plaintext &, MemoryPoolHandle)) & CKKSEncoder::encode, py::arg(), py::arg(), py::arg(), py::arg() = MemoryManager::GetPool()) - .def("encode", (void (CKKSEncoder::*)(const std::vector<std::complex<double>> &, double, Plaintext &, MemoryPoolHandle)) & CKKSEncoder::encode, py::arg(), py::arg(), py::arg(), py::arg() = MemoryManager::GetPool()) - .def("encode", (void (CKKSEncoder::*)(double, parms_id_type, double, Plaintext &, MemoryPoolHandle)) & CKKSEncoder::encode, py::arg(), py::arg(), py::arg(), py::arg(), py::arg() = MemoryManager::GetPool()) - .def("encode", (void (CKKSEncoder::*)(double, double, Plaintext &, MemoryPoolHandle)) & CKKSEncoder::encode, py::arg(), py::arg(), py::arg(), py::arg() = MemoryManager::GetPool()) - .def("encode", (void (CKKSEncoder::*)(std::int64_t, Plaintext &)) & CKKSEncoder::encode) - .def("decode", (void (CKKSEncoder::*)(const Plaintext &, std::vector<double> &, MemoryPoolHandle)) & CKKSEncoder::decode, py::arg(), py::arg(), py::arg() = MemoryManager::GetPool()) - .def("decode", (void (CKKSEncoder::*)(const Plaintext &, std::vector<std::complex<double>> &, MemoryPoolHandle)) & CKKSEncoder::decode, py::arg(), py::arg(), py::arg() = MemoryManager::GetPool()) - .def("slot_count", &CKKSEncoder::slot_count); - - // Decryptor - py::class_<Decryptor>(m, "Decryptor") - .def(py::init<std::shared_ptr<SEALContext>, const SecretKey &>()) - .def("decrypt", (int (Decryptor::*)(const Ciphertext &, Plaintext &)) & Decryptor::decrypt) - .def("invariant_noise_budget", (int (Decryptor::*)(const Ciphertext &)) & Decryptor::invariant_noise_budget); - - // IntegerEncoder - py::class_<IntegerEncoder>(m, "IntegerEncoder") - .def(py::init<std::shared_ptr<SEALContext>>()) - .def("encode", (Plaintext(IntegerEncoder::*)(std::uint64_t)) & IntegerEncoder::encode) - .def("encode", (Plaintext(IntegerEncoder::*)(std::int64_t)) & IntegerEncoder::encode) - .def("encode", (void (IntegerEncoder::*)(std::uint64_t, Plaintext &)) & IntegerEncoder::encode) - .def("decode_int32", (std::int32_t(IntegerEncoder::*)(const Plaintext &)) & IntegerEncoder::decode_int32); - - // BatchEncoder - py::class_<BatchEncoder>(m, "BatchEncoder") - .def(py::init<std::shared_ptr<SEALContext>>()) - .def("encode", (void (BatchEncoder::*)(const std::vector<std::uint64_t> &, Plaintext &)) & BatchEncoder::encode) - .def("encode", (void (BatchEncoder::*)(const std::vector<std::int64_t> &, Plaintext &)) & BatchEncoder::encode) - .def("decode", (void (BatchEncoder::*)(const Plaintext &, std::vector<std::uint64_t> &, MemoryPoolHandle)) & BatchEncoder::decode, py::arg(), py::arg(), py::arg() = MemoryManager::GetPool()) - .def("decode", (void (BatchEncoder::*)(const Plaintext &, std::vector<std::int64_t> &, MemoryPoolHandle)) & BatchEncoder::decode, py::arg(), py::arg(), py::arg() = MemoryManager::GetPool()) - .def("slot_count", &BatchEncoder::slot_count); + m.doc() = "Microsoft SEAL for Python, from https://github.com/Huelse/SEAL-Python"; + m.attr("__version__") = "4.0.0"; + + py::bind_vector<std::vector<double>>(m, "VectorDouble", py::buffer_protocol()); + py::bind_vector<std::vector<std::int64_t>>(m, "VectorInt", py::buffer_protocol()); + + // encryptionparams.h + py::enum_<scheme_type>(m, "scheme_type") + .value("none", scheme_type::none) + .value("bfv", scheme_type::bfv) + .value("ckks", scheme_type::ckks) + .value("bgv", scheme_type::bgv); + + // encryptionparams.h + py::class_<EncryptionParameters>(m, "EncryptionParameters") + .def(py::init<scheme_type>()) + .def(py::init<EncryptionParameters>()) + .def("set_poly_modulus_degree", &EncryptionParameters::set_poly_modulus_degree) + .def("set_coeff_modulus", &EncryptionParameters::set_coeff_modulus) + .def("set_plain_modulus", py::overload_cast<const Modulus &>(&EncryptionParameters::set_plain_modulus)) + .def("set_plain_modulus", py::overload_cast<std::uint64_t>(&EncryptionParameters::set_plain_modulus)) + .def("scheme", &EncryptionParameters::scheme) + .def("poly_modulus_degree", &EncryptionParameters::poly_modulus_degree) + .def("coeff_modulus", &EncryptionParameters::coeff_modulus) + .def("plain_modulus", &EncryptionParameters::plain_modulus) + .def("save", [](const EncryptionParameters &parms, std::string &path){ + std::ofstream out(path, std::ios::binary); + parms.save(out); + out.close(); + }) + .def("load", [](EncryptionParameters &parms, std::string &path){ + std::ifstream in(path, std::ios::binary); + parms.load(in); + in.close(); + }) + .def(py::pickle( + [](const EncryptionParameters &parms){ + std::stringstream out(std::ios::binary | std::ios::out); + parms.save(out); + return py::make_tuple(py::bytes(out.str())); + }, + [](py::tuple t){ + if (t.size() != 1) + throw std::runtime_error("(Pickle) Invalid input tuple!"); + std::string str = t[0].cast<std::string>(); + std::stringstream in(std::ios::binary | std::ios::in); + EncryptionParameters parms; + parms.load(in); + return parms; + } + )); + + // modulus.h + py::enum_<sec_level_type>(m, "sec_level_type") + .value("none", sec_level_type::none) + .value("tc128", sec_level_type::tc128) + .value("tc192", sec_level_type::tc192) + .value("tc256", sec_level_type::tc256); + + // context.h + py::enum_<EncryptionParameterQualifiers::error_type>(m, "error_type") + .value("none", EncryptionParameterQualifiers::error_type::none) + .value("success", EncryptionParameterQualifiers::error_type::success) + .value("invalid_scheme", EncryptionParameterQualifiers::error_type::invalid_scheme) + .value("invalid_coeff_modulus_size", EncryptionParameterQualifiers::error_type::invalid_coeff_modulus_size) + .value("invalid_coeff_modulus_bit_count", EncryptionParameterQualifiers::error_type::invalid_coeff_modulus_bit_count) + .value("invalid_coeff_modulus_no_ntt", EncryptionParameterQualifiers::error_type::invalid_coeff_modulus_no_ntt) + .value("invalid_poly_modulus_degree", EncryptionParameterQualifiers::error_type::invalid_poly_modulus_degree) + .value("invalid_poly_modulus_degree_non_power_of_two", EncryptionParameterQualifiers::error_type::invalid_poly_modulus_degree_non_power_of_two) + .value("invalid_parameters_too_large", EncryptionParameterQualifiers::error_type::invalid_parameters_too_large) + .value("invalid_parameters_insecure", EncryptionParameterQualifiers::error_type::invalid_parameters_insecure) + .value("failed_creating_rns_base", EncryptionParameterQualifiers::error_type::failed_creating_rns_base) + .value("invalid_plain_modulus_bit_count", EncryptionParameterQualifiers::error_type::invalid_plain_modulus_bit_count) + .value("invalid_plain_modulus_coprimality", EncryptionParameterQualifiers::error_type::invalid_plain_modulus_coprimality) + .value("invalid_plain_modulus_too_large", EncryptionParameterQualifiers::error_type::invalid_plain_modulus_too_large) + .value("invalid_plain_modulus_nonzero", EncryptionParameterQualifiers::error_type::invalid_plain_modulus_nonzero) + .value("failed_creating_rns_tool", EncryptionParameterQualifiers::error_type::failed_creating_rns_tool); + + // context.h + py::class_<EncryptionParameterQualifiers, std::unique_ptr<EncryptionParameterQualifiers, py::nodelete>>(m, "EncryptionParameterQualifiers") + .def("parameters_set", &EncryptionParameterQualifiers::parameters_set) + .def_readwrite("using_fft", &EncryptionParameterQualifiers::using_fft) + .def_readwrite("using_ntt", &EncryptionParameterQualifiers::using_ntt) + .def_readwrite("using_batching", &EncryptionParameterQualifiers::using_batching) + .def_readwrite("using_fast_plain_lift", &EncryptionParameterQualifiers::using_fast_plain_lift) + .def_readwrite("using_descending_modulus_chain", &EncryptionParameterQualifiers::using_descending_modulus_chain) + .def_readwrite("sec_level", &EncryptionParameterQualifiers::sec_level); + + // context.h + py::class_<SEALContext::ContextData, std::shared_ptr<SEALContext::ContextData>>(m, "ContextData") + .def("parms", &SEALContext::ContextData::parms) + .def("parms_id", &SEALContext::ContextData::parms_id) + .def("qualifiers", &SEALContext::ContextData::qualifiers) + .def("total_coeff_modulus", &SEALContext::ContextData::total_coeff_modulus) + .def("total_coeff_modulus_bit_count", &SEALContext::ContextData::total_coeff_modulus_bit_count) + .def("next_context_data", &SEALContext::ContextData::next_context_data) + .def("chain_index", &SEALContext::ContextData::chain_index); + + // context.h + py::class_<SEALContext, std::shared_ptr<SEALContext>>(m, "SEALContext") + .def(py::init<const EncryptionParameters &, bool, sec_level_type>(), py::arg(), py::arg()=true, py::arg()=sec_level_type::tc128) + .def("get_context_data", &SEALContext::get_context_data) + .def("key_context_data", &SEALContext::key_context_data) + .def("first_context_data", &SEALContext::first_context_data) + .def("last_context_data", &SEALContext::last_context_data) + .def("parameters_set", &SEALContext::parameters_set) + .def("first_parms_id", &SEALContext::first_parms_id) + .def("last_parms_id", &SEALContext::last_parms_id) + .def("using_keyswitching", &SEALContext::using_keyswitching) + .def("from_cipher_str", [](const SEALContext &context, const std::string &str){ + Ciphertext cipher; + std::stringstream in(std::ios::binary | std::ios::in); + in.str(str); + cipher.load(context, in); + return cipher; + }) + .def("from_plain_str", [](const SEALContext &context, const std::string &str){ + Plaintext plain; + std::stringstream in(std::ios::binary | std::ios::in); + in.str(str); + plain.load(context, in); + return plain; + }) + .def("from_secret_str", [](const SEALContext &context, const std::string &str){ + SecretKey secret; + std::stringstream in(std::ios::binary | std::ios::in); + in.str(str); + secret.load(context, in); + return secret; + }) + .def("from_public_str", [](const SEALContext &context, const std::string &str){ + PublicKey public_; + std::stringstream in(std::ios::binary | std::ios::in); + in.str(str); + public_.load(context, in); + return public_; + }) + .def("from_relin_str", [](const SEALContext &context, const std::string &str){ + RelinKeys relin; + std::stringstream in(std::ios::binary | std::ios::in); + in.str(str); + relin.load(context, in); + return relin; + }) + .def("from_galois_str", [](const SEALContext &context, const std::string &str){ + GaloisKeys galois; + std::stringstream in(std::ios::binary | std::ios::in); + in.str(str); + galois.load(context, in); + return galois; + }); + + // modulus.h + py::class_<Modulus>(m, "Modulus") + .def(py::init<std::uint64_t>()) + .def("bit_count", &Modulus::bit_count) + .def("value", &Modulus::value) + .def("is_zero", &Modulus::is_zero) + .def("is_prime", &Modulus::is_prime); + //save & load + + // modulus.h + py::class_<CoeffModulus>(m, "CoeffModulus") + .def_static("MaxBitCount", &CoeffModulus::MaxBitCount, py::arg(), py::arg()=sec_level_type::tc128) + .def_static("BFVDefault", &CoeffModulus::BFVDefault, py::arg(), py::arg()=sec_level_type::tc128) + .def_static("Create", py::overload_cast<std::size_t, std::vector<int>>(&CoeffModulus::Create)) + .def_static("Create", py::overload_cast<std::size_t, const Modulus &, std::vector<int>>(&CoeffModulus::Create)); + + // modulus.h + py::class_<PlainModulus>(m, "PlainModulus") + .def_static("Batching", py::overload_cast<std::size_t, int>(&PlainModulus::Batching)) + .def_static("Batching", py::overload_cast<std::size_t, std::vector<int>>(&PlainModulus::Batching)); + + // plaintext.h + py::class_<Plaintext>(m, "Plaintext") + .def(py::init<>()) + .def(py::init<std::size_t>()) + .def(py::init<std::size_t, std::size_t>()) + .def(py::init<const std::string &>()) + .def(py::init<const Plaintext &>()) + .def("set_zero", py::overload_cast<std::size_t, std::size_t>(&Plaintext::set_zero)) + .def("set_zero", py::overload_cast<std::size_t>(&Plaintext::set_zero)) + .def("set_zero", py::overload_cast<>(&Plaintext::set_zero)) + .def("is_zero", &Plaintext::is_zero) + .def("capacity", &Plaintext::capacity) + .def("coeff_count", &Plaintext::coeff_count) + .def("significant_coeff_count", &Plaintext::significant_coeff_count) + .def("nonzero_coeff_count", &Plaintext::nonzero_coeff_count) + .def("to_string", &Plaintext::to_string) + .def("is_ntt_form", &Plaintext::is_ntt_form) + .def("parms_id", py::overload_cast<>(&Plaintext::parms_id, py::const_)) + .def("scale", py::overload_cast<>(&Plaintext::scale, py::const_)) + .def("scale", [](Plaintext &plain, double scale){ + plain.scale() = scale; + }) + .def("save", [](const Plaintext &plain, const std::string &path){ + std::ofstream out(path, std::ios::binary); + plain.save(out); + out.close(); + }) + .def("load", [](Plaintext &plain, const SEALContext &context, const std::string &path){ + std::ifstream in(path, std::ios::binary); + plain.load(context, in); + in.close(); + }) + .def("save_size", [](const Plaintext &plain){ + return plain.save_size(); + }) + .def("to_string", [](const Plaintext &plain){ + std::stringstream out(std::ios::binary | std::ios::out); + plain.save(out); + return py::bytes(out.str()); + }); + + // ciphertext.h + py::class_<Ciphertext>(m, "Ciphertext") + .def(py::init<>()) + .def(py::init<const SEALContext &>()) + .def(py::init<const SEALContext &, parms_id_type>()) + .def(py::init<const SEALContext &, parms_id_type, std::size_t>()) + .def(py::init<const Ciphertext &>()) + .def("coeff_modulus_size", &Ciphertext::coeff_modulus_size) + .def("poly_modulus_degree", &Ciphertext::poly_modulus_degree) + .def("size", &Ciphertext::size) + .def("size_capacity", &Ciphertext::size_capacity) + .def("is_transparent", &Ciphertext::is_transparent) + .def("is_ntt_form", py::overload_cast<>(&Ciphertext::is_ntt_form, py::const_)) + .def("parms_id", py::overload_cast<>(&Ciphertext::parms_id, py::const_)) + .def("scale", py::overload_cast<>(&Ciphertext::scale, py::const_)) + .def("scale", [](Ciphertext &cipher, double scale){ + cipher.scale() = scale; + }) + .def("save", [](const Ciphertext &cipher, const std::string &path){ + std::ofstream out(path, std::ios::binary); + cipher.save(out); + out.close(); + }) + .def("load", [](Ciphertext &cipher, const SEALContext &context, const std::string &path){ + std::ifstream in(path, std::ios::binary); + cipher.load(context, in); + in.close(); + }) + .def("save_size", [](const Ciphertext &cipher){ + return cipher.save_size(); + }) + .def("to_string", [](const Ciphertext &cipher){ + std::stringstream out(std::ios::binary | std::ios::out); + cipher.save(out); + return py::bytes(out.str()); + }); + + // secretkey.h + py::class_<SecretKey>(m, "SecretKey") + .def(py::init<>()) + .def(py::init<const SecretKey &>()) + .def("parms_id", py::overload_cast<>(&SecretKey::parms_id, py::const_)) + .def("save", [](const SecretKey &sk, const std::string &path){ + std::ofstream out(path, std::ios::binary); + sk.save(out); + out.close(); + }) + .def("load", [](SecretKey &sk, const SEALContext &context, const std::string &path){ + std::ifstream in(path, std::ios::binary); + sk.load(context, in); + in.close(); + }) + .def("to_string", [](const SecretKey &secret){ + std::stringstream out(std::ios::binary | std::ios::out); + secret.save(out); + return py::bytes(out.str()); + }); + + // publickey.h + py::class_<PublicKey>(m, "PublicKey") + .def(py::init<>()) + .def(py::init<const PublicKey &>()) + .def("parms_id", py::overload_cast<>(&PublicKey::parms_id, py::const_)) + .def("save", [](const PublicKey &pk, const std::string &path){ + std::ofstream out(path, std::ios::binary); + pk.save(out); + out.close(); + }) + .def("load", [](PublicKey &pk, const SEALContext &context, const std::string &path){ + std::ifstream in(path, std::ios::binary); + pk.load(context, in); + in.close(); + }) + .def("to_string", [](const PublicKey &public_){ + std::stringstream out(std::ios::binary | std::ios::out); + public_.save(out); + return py::bytes(out.str()); + }); + + // kswitchkeys.h + py::class_<KSwitchKeys>(m, "KSwitchKeys") + .def(py::init<>()) + .def(py::init<const KSwitchKeys &>()) + .def("size", &KSwitchKeys::size) + .def("parms_id", py::overload_cast<>(&KSwitchKeys::parms_id, py::const_)) + .def("save", [](const KSwitchKeys &ksk, const std::string &path){ + std::ofstream out(path, std::ios::binary); + ksk.save(out); + out.close(); + }) + .def("load", [](KSwitchKeys &ksk, const SEALContext &context, const std::string &path){ + std::ifstream in(path, std::ios::binary); + ksk.load(context, in); + in.close(); + }); + + // relinkeys.h + py::class_<RelinKeys, KSwitchKeys>(m, "RelinKeys") + .def(py::init<>()) + .def(py::init<const RelinKeys::KSwitchKeys &>()) + .def("size", &RelinKeys::KSwitchKeys::size) + .def("parms_id", py::overload_cast<>(&RelinKeys::KSwitchKeys::parms_id, py::const_)) + .def_static("get_index", &RelinKeys::get_index) + .def("has_key", &RelinKeys::has_key) + .def("save", [](const RelinKeys &rk, const std::string &path){ + std::ofstream out(path, std::ios::binary); + rk.save(out); + out.close(); + }) + .def("load", [](RelinKeys &rk, const SEALContext &context, const std::string &path){ + std::ifstream in(path, std::ios::binary); + rk.load(context, in); + in.close(); + }) + .def("to_string", [](const RelinKeys &relin){ + std::stringstream out(std::ios::binary | std::ios::out); + relin.save(out); + return py::bytes(out.str()); + }); + + // galoiskeys.h + py::class_<GaloisKeys, KSwitchKeys>(m, "GaloisKeys") + .def(py::init<>()) + .def(py::init<const GaloisKeys::KSwitchKeys &>()) + .def("size", &GaloisKeys::KSwitchKeys::size) + .def("parms_id", py::overload_cast<>(&GaloisKeys::KSwitchKeys::parms_id, py::const_)) + .def_static("get_index", &GaloisKeys::get_index) + .def("has_key", &GaloisKeys::has_key) + .def("save", [](const GaloisKeys &gk, const std::string &path){ + std::ofstream out(path, std::ios::binary); + gk.save(out); + out.close(); + }) + .def("load", [](GaloisKeys &gk, const SEALContext &context, const std::string &path){ + std::ifstream in(path, std::ios::binary); + gk.load(context, in); + in.close(); + }) + .def("to_string", [](const GaloisKeys &galois){ + std::stringstream out(std::ios::binary | std::ios::out); + galois.save(out); + return py::bytes(out.str()); + }); + + // keygenerator.h + py::class_<KeyGenerator>(m, "KeyGenerator") + .def(py::init<const SEALContext &>()) + .def(py::init<const SEALContext &, const SecretKey &>()) + .def("secret_key", &KeyGenerator::secret_key) + .def("create_public_key", py::overload_cast<PublicKey &>(&KeyGenerator::create_public_key, py::const_)) + .def("create_relin_keys", py::overload_cast<RelinKeys &>(&KeyGenerator::create_relin_keys)) + .def("create_galois_keys", py::overload_cast<const std::vector<int> &, GaloisKeys &>(&KeyGenerator::create_galois_keys)) + .def("create_galois_keys", py::overload_cast<GaloisKeys &>(&KeyGenerator::create_galois_keys)) + .def("create_public_key", [](KeyGenerator &keygen){ + PublicKey pk; + keygen.create_public_key(pk); + return pk; + }) + .def("create_relin_keys", [](KeyGenerator &keygen){ + RelinKeys rk; + keygen.create_relin_keys(rk); + return rk; + }) + .def("create_galois_keys", [](KeyGenerator &keygen){ + GaloisKeys gk; + keygen.create_galois_keys(gk); + return gk; + }); + + // encryptor.h + py::class_<Encryptor>(m, "Encryptor") + .def(py::init<const SEALContext &, const PublicKey &>()) + .def(py::init<const SEALContext &, const SecretKey &>()) + .def(py::init<const SEALContext &, const PublicKey &, const SecretKey &>()) + .def("set_public_key", &Encryptor::set_public_key) + .def("set_secret_key", &Encryptor::set_secret_key) + .def("encrypt_zero", [](const Encryptor &encryptor){ + Ciphertext encrypted; + encryptor.encrypt_zero(encrypted); + return encrypted; + }) + .def("encrypt", [](const Encryptor &encryptor, const Plaintext &plain){ + Ciphertext encrypted; + encryptor.encrypt(plain, encrypted); + return encrypted; + }); + // symmetric + + // evaluator.h + py::class_<Evaluator>(m, "Evaluator") + .def(py::init<const SEALContext &>()) + .def("negate_inplace", &Evaluator::negate_inplace) + .def("negate", [](Evaluator &evaluator, const Ciphertext &encrypted1){ + Ciphertext destination; + evaluator.negate(encrypted1, destination); + return destination; + }) + .def("add_inplace", &Evaluator::add_inplace) + .def("add", [](Evaluator &evaluator, const Ciphertext &encrypted1, const Ciphertext &encrypted2){ + Ciphertext destination; + evaluator.add(encrypted1, encrypted2, destination); + return destination; + }) + .def("add_many", [](Evaluator &evaluator, const std::vector<Ciphertext> &encrypteds){ + Ciphertext destination; + evaluator.add_many(encrypteds, destination); + return destination; + }) + .def("sub_inplace", &Evaluator::sub_inplace) + .def("sub", [](Evaluator &evaluator, const Ciphertext &encrypted1, const Ciphertext &encrypted2){ + Ciphertext destination; + evaluator.sub(encrypted1, encrypted2, destination); + return destination; + }) + .def("multiply_inplace", [](Evaluator &evaluator, Ciphertext &encrypted1, const Ciphertext &encrypted2){ + evaluator.multiply_inplace(encrypted1, encrypted2); + }) + .def("multiply", [](Evaluator &evaluator, const Ciphertext &encrypted1, const Ciphertext &encrypted2){ + Ciphertext destination; + evaluator.multiply(encrypted1, encrypted2, destination); + return destination; + }) + .def("square_inplace", [](Evaluator &evaluator, Ciphertext &encrypted1){ + evaluator.square_inplace(encrypted1); + }) + .def("square", [](Evaluator &evaluator, const Ciphertext &encrypted1){ + Ciphertext destination; + evaluator.square(encrypted1, destination); + return destination; + }) + .def("relinearize_inplace", [](Evaluator &evaluator, Ciphertext &encrypted1, const RelinKeys &relin_keys){ + evaluator.relinearize_inplace(encrypted1, relin_keys); + }) + .def("relinearize", [](Evaluator &evaluator, const Ciphertext &encrypted1, const RelinKeys &relin_keys){ + Ciphertext destination; + evaluator.relinearize(encrypted1, relin_keys, destination); + return destination; + }) + .def("mod_switch_to_next", [](Evaluator &evaluator, const Ciphertext &encrypted){ + Ciphertext destination; + evaluator.mod_switch_to_next(encrypted, destination); + return destination; + }) + .def("mod_switch_to_next_inplace", [](Evaluator &evaluator, Ciphertext &encrypted){ + evaluator.mod_switch_to_next_inplace(encrypted); + }) + .def("mod_switch_to_next_inplace", py::overload_cast<Plaintext &>(&Evaluator::mod_switch_to_next_inplace, py::const_)) + .def("mod_switch_to_next", [](Evaluator &evaluator, const Plaintext &plain){ + Plaintext destination; + evaluator.mod_switch_to_next(plain, destination); + return destination; + }) + .def("mod_switch_to_inplace", [](Evaluator &evaluator, Ciphertext &encrypted, parms_id_type parms_id){ + evaluator.mod_switch_to_inplace(encrypted, parms_id); + }) + .def("mod_switch_to", [](Evaluator &evaluator, const Ciphertext &encrypted, parms_id_type parms_id){ + Ciphertext destination; + evaluator.mod_switch_to(encrypted, parms_id, destination); + return destination; + }) + .def("mod_switch_to_inplace", py::overload_cast<Plaintext &, parms_id_type>(&Evaluator::mod_switch_to_inplace, py::const_)) + .def("mod_switch_to", [](Evaluator &evaluator, const Plaintext &plain, parms_id_type parms_id){ + Plaintext destination; + evaluator.mod_switch_to(plain, parms_id, destination); + return destination; + }) + .def("rescale_to_next", [](Evaluator &evaluator, const Ciphertext &encrypted){ + Ciphertext destination; + evaluator.rescale_to_next(encrypted, destination); + return destination; + }) + .def("rescale_to_next_inplace", [](Evaluator &evaluator, Ciphertext &encrypted){ + evaluator.rescale_to_next_inplace(encrypted); + }) + .def("rescale_to_inplace", [](Evaluator &evaluator, Ciphertext &encrypted, parms_id_type parms_id){ + evaluator.rescale_to_inplace(encrypted, parms_id); + }) + .def("rescale_to", [](Evaluator &evaluator, const Ciphertext &encrypted, parms_id_type parms_id){ + Ciphertext destination; + evaluator.rescale_to(encrypted, parms_id, destination); + return destination; + }) + .def("multiply_many", [](Evaluator &evaluator, const std::vector<Ciphertext> &encrypteds, const RelinKeys &relin_keys){ + Ciphertext destination; + evaluator.multiply_many(encrypteds, relin_keys, destination); + return destination; + }) + .def("exponentiate_inplace", [](Evaluator &evaluator, Ciphertext &encrypted, std::uint64_t exponent, const RelinKeys &relin_keys){ + evaluator.exponentiate_inplace(encrypted, exponent, relin_keys); + }) + .def("exponentiate", [](Evaluator &evaluator, const Ciphertext &encrypted, std::uint64_t exponent, const RelinKeys &relin_keys){ + Ciphertext destination; + evaluator.exponentiate(encrypted, exponent, relin_keys, destination); + return destination; + }) + .def("add_plain_inplace", &Evaluator::add_plain_inplace) + .def("add_plain", [](Evaluator &evaluator, const Ciphertext &encrypted, const Plaintext &plain){ + Ciphertext destination; + evaluator.add_plain(encrypted, plain, destination); + return destination; + }) + .def("sub_plain_inplace", &Evaluator::sub_plain_inplace) + .def("sub_plain", [](Evaluator &evaluator, const Ciphertext &encrypted, const Plaintext &plain){ + Ciphertext destination; + evaluator.sub_plain(encrypted, plain, destination); + return destination; + }) + .def("multiply_plain_inplace", [](Evaluator &evaluator, Ciphertext &encrypted, const Plaintext &plain){ + evaluator.multiply_plain_inplace(encrypted, plain); + }) + .def("multiply_plain", [](Evaluator &evaluator, const Ciphertext &encrypted, const Plaintext &plain){ + Ciphertext destination; + evaluator.multiply_plain(encrypted, plain, destination); + return destination; + }) + .def("transform_to_ntt_inplace", [](Evaluator &evaluator, Plaintext &plain, parms_id_type parms_id){ + evaluator.transform_to_ntt_inplace(plain,parms_id); + }) + .def("transform_to_ntt", [](Evaluator &evaluator, const Plaintext &plain, parms_id_type parms_id){ + Plaintext destination_ntt; + evaluator.transform_to_ntt(plain, parms_id, destination_ntt); + return destination_ntt; + }) + .def("transform_to_ntt_inplace", py::overload_cast<Ciphertext &>(&Evaluator::transform_to_ntt_inplace, py::const_)) + .def("transform_to_ntt", [](Evaluator &evaluator, const Ciphertext &encrypted){ + Ciphertext destination_ntt; + evaluator.transform_to_ntt(encrypted, destination_ntt); + return destination_ntt; + }) + .def("transform_from_ntt_inplace", &Evaluator::transform_from_ntt_inplace) + .def("transform_from_ntt", [](Evaluator &evaluator, const Ciphertext &encrypted_ntt){ + Ciphertext destination; + evaluator.transform_from_ntt(encrypted_ntt, destination); + return destination; + }) + .def("apply_galois_inplace", [](Evaluator &evaluator, Ciphertext &encrypted, std::uint32_t galois_elt, const GaloisKeys &galois_keys){ + evaluator.apply_galois_inplace(encrypted, galois_elt, galois_keys); + }) + .def("apply_galois", [](Evaluator &evaluator, const Ciphertext &encrypted, std::uint32_t galois_elt, const GaloisKeys &galois_keys){ + Ciphertext destination; + evaluator.apply_galois(encrypted, galois_elt, galois_keys, destination); + return destination; + }) + .def("rotate_rows_inplace", [](Evaluator &evaluator, Ciphertext &encrypted, int steps, const GaloisKeys &galois_keys){ + evaluator.rotate_rows_inplace(encrypted, steps, galois_keys); + }) + .def("rotate_rows", [](Evaluator &evaluator, const Ciphertext &encrypted, int steps, const GaloisKeys &galois_keys){ + Ciphertext destination; + evaluator.rotate_rows(encrypted, steps, galois_keys, destination); + return destination; + }) + .def("rotate_columns_inplace", [](Evaluator &evaluator, Ciphertext &encrypted, const GaloisKeys &galois_keys){ + evaluator.rotate_columns_inplace(encrypted, galois_keys); + }) + .def("rotate_columns", [](Evaluator &evaluator, const Ciphertext &encrypted, const GaloisKeys &galois_keys){ + Ciphertext destination; + evaluator.rotate_columns(encrypted, galois_keys, destination); + return destination; + }) + .def("rotate_vector_inplace", [](Evaluator &evaluator, Ciphertext &encrypted, int steps, const GaloisKeys &galois_keys){ + evaluator.rotate_vector_inplace(encrypted, steps, galois_keys); + }) + .def("rotate_vector", [](Evaluator &evaluator, const Ciphertext &encrypted, int steps, const GaloisKeys &galois_keys){ + Ciphertext destination; + evaluator.rotate_vector(encrypted, steps, galois_keys, destination); + return destination; + }) + .def("complex_conjugate_inplace", [](Evaluator &evaluator, Ciphertext &encrypted, const GaloisKeys &galois_keys){ + evaluator.complex_conjugate_inplace(encrypted, galois_keys); + }) + .def("complex_conjugate", [](Evaluator &evaluator, const Ciphertext &encrypted, const GaloisKeys &galois_keys){ + Ciphertext destination; + evaluator.complex_conjugate(encrypted, galois_keys, destination); + return destination; + }); + + // ckks.h + py::class_<CKKSEncoder>(m, "CKKSEncoder") + .def(py::init<const SEALContext &>()) + .def("slot_count", &CKKSEncoder::slot_count) + .def("encode", [](CKKSEncoder &encoder, py::array_t<double> values, double scale){ + py::buffer_info buf = values.request(); + if (buf.ndim != 1) + throw std::runtime_error("E101: Number of dimensions must be one"); + + double *ptr = (double *)buf.ptr; + std::vector<double> vec(buf.shape[0]); + + for (auto i = 0; i < buf.shape[0]; i++) + vec[i] = ptr[i]; + + Plaintext pt; + encoder.encode(vec, scale, pt); + return pt; + }) + .def("encode", [](CKKSEncoder &encoder, double value, double scale){ + Plaintext pt; + encoder.encode(value, scale, pt); + return pt; + }) + .def("decode", [](CKKSEncoder &encoder, const Plaintext &plain){ + std::vector<double> destination; + encoder.decode(plain, destination); + + py::array_t<double> values(destination.size()); + py::buffer_info buf = values.request(); + double *ptr = (double *)buf.ptr; + + for (auto i = 0; i < buf.shape[0]; i++) + ptr[i] = destination[i]; + + return values; + }); + + // decryptor.h + py::class_<Decryptor>(m, "Decryptor") + .def(py::init<const SEALContext &, const SecretKey &>()) + .def("decrypt", &Decryptor::decrypt) + .def("invariant_noise_budget", &Decryptor::invariant_noise_budget) + .def("decrypt", [](Decryptor &decryptor, const Ciphertext &encrypted){ + Plaintext pt; + decryptor.decrypt(encrypted, pt); + return pt; + }); + + // batchencoder.h + py::class_<BatchEncoder>(m, "BatchEncoder") + .def(py::init<const SEALContext &>()) + .def("slot_count", &BatchEncoder::slot_count) + .def("encode", [](BatchEncoder &encoder, py::array_t<std::int64_t> values){ + py::buffer_info buf = values.request(); + if (buf.ndim != 1) + throw std::runtime_error("E101: Number of dimensions must be one"); + + std::int64_t *ptr = (std::int64_t *)buf.ptr; + std::vector<std::int64_t> vec(buf.shape[0]); + + for (auto i = 0; i < buf.shape[0]; i++) + vec[i] = ptr[i]; + + Plaintext pt; + encoder.encode(vec, pt); + return pt; + }) + .def("decode", [](BatchEncoder &encoder, const Plaintext &plain){ + std::vector<std::int64_t> destination; + encoder.decode(plain, destination); + + py::array_t<std::int64_t> values(destination.size()); + py::buffer_info buf = values.request(); + std::int64_t *ptr = (std::int64_t *)buf.ptr; + + for (auto i = 0; i < buf.shape[0]; i++) + ptr[i] = destination[i]; + + return values; + }); } diff --git a/tests/0_data_type.py b/tests/0_data_type.py deleted file mode 100644 index 3a50632..0000000 --- a/tests/0_data_type.py +++ /dev/null @@ -1,80 +0,0 @@ -from seal import * -from seal_helper import * -import numpy as np -import pickle - - -def example_data_type(): - a = [0.1, 0.3, 1.01, 0.2] - b = DoubleVector(a) - print(a) # [0.1, 0.3, 1.01, 0.2] - c = np.array(b) - print(c) # [0.1 0.3 1.01 0.2 ] - - d = IntVector([0]*10) - print(len(d)) # 10 - d[4] = 1 - e = np.array(d) - print(e) # [0 0 0 0 1 0 0 0 0 0] - - -def example_pickle_save(out, path="temp"): - with open(path, "wb") as write: - pickle.dump(out, write) - - -def example_pickle_load(path="temp"): - with open(path, "rb") as read: - return pickle.load(read) - - -def example_serialize(): - print_example_banner("Example: pickle & save & load") - parms = EncryptionParameters(scheme_type.BFV) - - poly_modulus_degree = 4096 - parms.set_poly_modulus_degree(poly_modulus_degree) - parms.set_coeff_modulus(CoeffModulus.BFVDefault(poly_modulus_degree)) - parms.set_plain_modulus(256) - - context = SEALContext.Create(parms) - - print("-" * 50) - print("Set encryption parameters and print") - print_parameters(context) - - keygen = KeyGenerator(context) - public_key = keygen.public_key() - secret_key = keygen.secret_key() - - encryptor = Encryptor(context, public_key) - evaluator = Evaluator(context) - decryptor = Decryptor(context, secret_key) - - print("-" * 50) - x = "6" - x_plain = Plaintext(x) - print("Express x = " + x + " as a plaintext polynomial 0x" + - x_plain.to_string() + ".") - - print("-" * 50) - x_save = Ciphertext() - print("Encrypt x_plain to x_save.") - encryptor.encrypt(x_plain, x_save) - - print("\nx_save scale: %.1f" % x_save.scale()) - print("x_save parms_id: ", end="") - print(x_save.parms_id()) - - x_save.save("temp") - x_read = Ciphertext() - x_read.load(context, "temp") - - print("\nx_read scale: %.1f" % x_read.scale()) - print("x_read parms_id: ", end="") - print(x_read.parms_id()) - - -if __name__ == '__main__': - example_data_type() - example_serialize() diff --git a/tests/1_bfv_basics.py b/tests/1_bfv_basics.py deleted file mode 100644 index f5be447..0000000 --- a/tests/1_bfv_basics.py +++ /dev/null @@ -1,155 +0,0 @@ -from seal import * -from seal_helper import * - - -def example_bfv_basics(): - print_example_banner("Example: BFV Basics") - parms = EncryptionParameters(scheme_type.BFV) - - poly_modulus_degree = 4096 - parms.set_poly_modulus_degree(poly_modulus_degree) - parms.set_coeff_modulus(CoeffModulus.BFVDefault(poly_modulus_degree)) - parms.set_plain_modulus(256) - - context = SEALContext.Create(parms) - - print("-" * 50) - print("Set encryption parameters and print") - print_parameters(context) - print("~~~~~~ A naive way to calculate 2(x^2+1)(x+1)^2. ~~~~~~") - - keygen = KeyGenerator(context) - public_key = keygen.public_key() - secret_key = keygen.secret_key() - - encryptor = Encryptor(context, public_key) - - evaluator = Evaluator(context) - - decryptor = Decryptor(context, secret_key) - - print("-" * 50) - x = "6" - x_plain = Plaintext(x) - print("Express x = " + x + " as a plaintext polynomial 0x" + - x_plain.to_string() + ".") - - print("-" * 50) - x_encrypted = Ciphertext() - print("Encrypt x_plain to x_encrypted.") - encryptor.encrypt(x_plain, x_encrypted) - - print(" + size of freshly encrypted x: " + str(x_encrypted.size())) - - print(" + noise budget in freshly encrypted x: " + - str(decryptor.invariant_noise_budget(x_encrypted)) + " bits") - - x_decrypted = Plaintext() - print(" + decryption of x_encrypted: ", end="") - decryptor.decrypt(x_encrypted, x_decrypted) - print("0x" + x_decrypted.to_string() + " ...... Correct.") - - print("-"*50) - print("Compute x_sq_plus_one (x^2+1).") - - x_sq_plus_one = Ciphertext() - evaluator.square(x_encrypted, x_sq_plus_one) - plain_one = Plaintext("1") - evaluator.add_plain_inplace(x_sq_plus_one, plain_one) - - print(" + size of x_sq_plus_one: " + str(x_sq_plus_one.size())) - print(" + noise budget in x_sq_plus_one: " + - str(decryptor.invariant_noise_budget(x_sq_plus_one)) + " bits") - - decrypted_result = Plaintext() - print(" + decryption of x_sq_plus_one: ", end="") - decryptor.decrypt(x_sq_plus_one, decrypted_result) - print("0x" + decrypted_result.to_string() + " ...... Correct.") - - ''' - Next, we compute (x + 1)^2. - ''' - print("-"*50) - print("Compute x_plus_one_sq ((x+1)^2).") - x_plus_one_sq = Ciphertext() - evaluator.add_plain(x_encrypted, plain_one, x_plus_one_sq) - evaluator.square_inplace(x_plus_one_sq) - print(" + size of x_plus_one_sq: " + str(x_plus_one_sq.size())) - print(" + noise budget in x_plus_one_sq: " + - str(decryptor.invariant_noise_budget(x_plus_one_sq)) + " bits") - decryptor.decrypt(x_plus_one_sq, decrypted_result) - print(" + decryption of x_plus_one_sq: 0x" + - decrypted_result.to_string() + " ...... Correct.") - - ''' - Finally, we multiply (x^2 + 1) * (x + 1)^2 * 2. - ''' - print("-"*50) - print("Compute encrypted_result (2(x^2+1)(x+1)^2).") - encrypted_result = Ciphertext() - plain_two = Plaintext("2") - evaluator.multiply_plain_inplace(x_sq_plus_one, plain_two) - evaluator.multiply(x_sq_plus_one, x_plus_one_sq, encrypted_result) - print(" + size of encrypted_result: " + str(encrypted_result.size())) - print(" + noise budget in encrypted_result: " + - str(decryptor.invariant_noise_budget(encrypted_result)) + " bits") - print("NOTE: Decryption can be incorrect if noise budget is zero.") - print("\n~~~~~~ A better way to calculate 2(x^2+1)(x+1)^2. ~~~~~~") - - print("-"*50) - print("Generate relinearization keys.") - relin_keys = keygen.relin_keys() - - ''' - We now repeat the computation relinearizing after each multiplication. - ''' - print("-"*50) - print("Compute and relinearize x_squared (x^2),") - print(" "*13 + "then compute x_sq_plus_one (x^2+1)") - x_squared = Ciphertext() - evaluator.square(x_encrypted, x_squared) - print(" + size of x_squared: " + str(x_squared.size())) - evaluator.relinearize_inplace(x_squared, relin_keys) - print(" + size of x_squared (after relinearization): " + str(x_squared.size())) - evaluator.add_plain(x_squared, plain_one, x_sq_plus_one) - print(" + noise budget in x_sq_plus_one: " + - str(decryptor.invariant_noise_budget(x_sq_plus_one)) + " bits") - decryptor.decrypt(x_sq_plus_one, decrypted_result) - print(" + decryption of x_sq_plus_one: 0x" + - decrypted_result.to_string() + " ...... Correct.") - - print("-"*50) - x_plus_one = Ciphertext() - print("Compute x_plus_one (x+1),") - print(" "*13 + "then compute and relinearize x_plus_one_sq ((x+1)^2).") - evaluator.add_plain(x_encrypted, plain_one, x_plus_one) - evaluator.square(x_plus_one, x_plus_one_sq) - print(" + size of x_plus_one_sq: " + str(x_plus_one_sq.size())) - evaluator.relinearize_inplace(x_plus_one_sq, relin_keys) - print(" + noise budget in x_plus_one_sq: " + - str(decryptor.invariant_noise_budget(x_plus_one_sq)) + " bits") - decryptor.decrypt(x_plus_one_sq, decrypted_result) - print(" + decryption of x_plus_one_sq: 0x" + - decrypted_result.to_string() + " ...... Correct.") - - print("-"*50) - print("Compute and relinearize encrypted_result (2(x^2+1)(x+1)^2).") - evaluator.multiply_plain_inplace(x_sq_plus_one, plain_two) - evaluator.multiply(x_sq_plus_one, x_plus_one_sq, encrypted_result) - print(" + size of encrypted_result: " + str(encrypted_result.size())) - evaluator.relinearize_inplace(encrypted_result, relin_keys) - print(" + size of encrypted_result (after relinearization): " + - str(encrypted_result.size())) - print(" + noise budget in encrypted_result: " + - str(decryptor.invariant_noise_budget(encrypted_result)) + " bits") - print("\nNOTE: Notice the increase in remaining noise budget.") - - print("-"*50) - print("Decrypt encrypted_result (2(x^2+1)(x+1)^2).") - decryptor.decrypt(encrypted_result, decrypted_result) - print(" + decryption of 2(x^2+1)(x+1)^2 = 0x" + - decrypted_result.to_string() + " ...... Correct.") - - -if __name__ == '__main__': - example_bfv_basics() diff --git a/tests/2_encoders.py b/tests/2_encoders.py deleted file mode 100644 index d270b2f..0000000 --- a/tests/2_encoders.py +++ /dev/null @@ -1,307 +0,0 @@ -import math -from seal import * -from seal_helper import * - - -def example_integer_encoder(): - print_example_banner("Example: Encoders / Integer Encoder") - parms = EncryptionParameters(scheme_type.BFV) - poly_modulus_degree = 4096 - parms.set_poly_modulus_degree(poly_modulus_degree) - parms.set_coeff_modulus(CoeffModulus.BFVDefault(poly_modulus_degree)) - parms.set_plain_modulus(512) - context = SEALContext.Create(parms) - print_parameters(context) - - keygen = KeyGenerator(context) - public_key = keygen.public_key() - secret_key = keygen.secret_key() - encryptor = Encryptor(context, public_key) - evaluator = Evaluator(context) - decryptor = Decryptor(context, secret_key) - encoder = IntegerEncoder(context) - value1 = 5 - plain1 = Plaintext(encoder.encode(value1)) - print("-" * 50) - print("Encode " + str(value1) + " as polynomial " + - plain1.to_string() + " (plain1),") - value2 = -7 - plain2 = Plaintext(encoder.encode(value2)) - print("encode " + str(value2) + " as polynomial " + - plain2.to_string() + " (plain2).") - - encrypted1 = Ciphertext() - encrypted2 = Ciphertext() - print("-" * 50) - print("Encrypt plain1 to encrypted1 and plain2 to encrypted2.") - encryptor.encrypt(plain1, encrypted1) - encryptor.encrypt(plain2, encrypted2) - print(" + Noise budget in encrypted1: " + - "%.0f" % decryptor.invariant_noise_budget(encrypted1) + " bits") - print(" + Noise budget in encrypted2: " + - "%.0f" % decryptor.invariant_noise_budget(encrypted2) + " bits") - - encryptor.encrypt(plain2, encrypted2) - encrypted_result = Ciphertext() - print("-" * 50) - print("Compute encrypted_result = (-encrypted1 + encrypted2) * encrypted2.") - evaluator.negate(encrypted1, encrypted_result) - evaluator.add_inplace(encrypted_result, encrypted2) - evaluator.multiply_inplace(encrypted_result, encrypted2) - print(" + Noise budget in encrypted_result: " + - "%.0f" % decryptor.invariant_noise_budget(encrypted_result) + " bits") - plain_result = Plaintext() - print("-" * 50) - print("Decrypt encrypted_result to plain_result.") - decryptor.decrypt(encrypted_result, plain_result) - print(" + Plaintext polynomial: " + plain_result.to_string()) - print("-" * 50) - print("Decode plain_result.") - print(" + Decoded integer: " + - str(encoder.decode_int32(plain_result)) + "...... Correct.") - - -def example_batch_encoder(): - print_example_banner("Example: Encoders / Batch Encoder") - parms = EncryptionParameters(scheme_type.BFV) - poly_modulus_degree = 8192 - parms.set_poly_modulus_degree(poly_modulus_degree) - parms.set_coeff_modulus(CoeffModulus.BFVDefault(poly_modulus_degree)) - parms.set_plain_modulus(PlainModulus.Batching(poly_modulus_degree, 20)) - context = SEALContext.Create(parms) - print_parameters(context) - - qualifiers = context.first_context_data().qualifiers() - print("Batching enabled: " + str(qualifiers.using_batching)) - - keygen = KeyGenerator(context) - public_key = keygen.public_key() - secret_key = keygen.secret_key() - relin_keys = keygen.relin_keys() - encryptor = Encryptor(context, public_key) - evaluator = Evaluator(context) - decryptor = Decryptor(context, secret_key) - - batch_encoder = BatchEncoder(context) - slot_count = batch_encoder.slot_count() - row_size = int(slot_count / 2) - print("Plaintext matrix row size: " + str(row_size)) - - pod_matrixs = [0] * slot_count - pod_matrixs[0] = 0 - pod_matrixs[1] = 1 - pod_matrixs[2] = 2 - pod_matrixs[3] = 3 - pod_matrixs[row_size] = 4 - pod_matrixs[row_size + 1] = 5 - pod_matrixs[row_size + 2] = 6 - pod_matrixs[row_size + 3] = 7 - - pod_matrix = uIntVector(pod_matrixs) - - print("Input plaintext matrix:") - print_matrix(pod_matrix, row_size) - - plain_matrix = Plaintext() - print("-" * 50) - print("Encode plaintext matrix:") - batch_encoder.encode(pod_matrix, plain_matrix) - - pod_result = uIntVector() - print(" + Decode plaintext matrix ...... Correct.") - - batch_encoder.decode(plain_matrix, pod_result) - print_matrix(pod_result, row_size) - - encrypted_matrix = Ciphertext() - print("-" * 50) - print("Encrypt plain_matrix to encrypted_matrix.") - encryptor.encrypt(plain_matrix, encrypted_matrix) - print(" + Noise budget in encrypted_matrix: " + - "%.0f" % decryptor.invariant_noise_budget(encrypted_matrix) + " bits") - - pod_matrix2 = uIntVector() - for i in range(slot_count): - pod_matrix2.append((i % 2) + 1) - - plain_matrix2 = Plaintext() - batch_encoder.encode(pod_matrix2, plain_matrix2) - print("Second input plaintext matrix:") - print_matrix(pod_matrix2, row_size) - - print("-" * 50) - print("Sum, square, and relinearize.") - evaluator.add_plain_inplace(encrypted_matrix, plain_matrix2) - evaluator.square_inplace(encrypted_matrix) - evaluator.relinearize_inplace(encrypted_matrix, relin_keys) - - print(" + Noise budget in result: " + - "%.0f" % decryptor.invariant_noise_budget(encrypted_matrix) + " bits") - - plain_result = Plaintext() - print("-" * 50) - print("Decrypt and decode result.") - decryptor.decrypt(encrypted_matrix, plain_result) - batch_encoder.decode(plain_result, pod_result) - print(" + Result plaintext matrix ...... Correct.") - print_matrix(pod_result, row_size) - - -def example_ckks_encoder(): - print_example_banner("Example: Encoders / CKKS Encoder") - - ''' - [CKKSEncoder] (For CKKS scheme only) - - In this example we demonstrate the Cheon-Kim-Kim-Song (CKKS) scheme for - computing on encrypted real or complex numbers. We start by creating - encryption parameters for the CKKS scheme. There are two important - differences compared to the BFV scheme: - - (1) CKKS does not use the plain_modulus encryption parameter; - (2) Selecting the coeff_modulus in a specific way can be very important - when using the CKKS scheme. We will explain this further in the file - `ckks_basics.cpp'. In this example we use CoeffModulus::Create to - generate 5 40-bit prime numbers. - ''' - - parms = EncryptionParameters(scheme_type.CKKS) - poly_modulus_degree = 8192 - parms.set_poly_modulus_degree(poly_modulus_degree) - parms.set_coeff_modulus(CoeffModulus.Create( - poly_modulus_degree, [40, 40, 40, 40, 40])) - - ''' - We create the SEALContext as usual and print the parameters. - ''' - - context = SEALContext.Create(parms) - print_parameters(context) - - ''' - Keys are created the same way as for the BFV scheme. - ''' - - keygen = KeyGenerator(context) - public_key = keygen.public_key() - secret_key = keygen.secret_key() - relin_keys = keygen.relin_keys() - - ''' - We also set up an Encryptor, Evaluator, and Decryptor as usual. - ''' - - encryptor = Encryptor(context, public_key) - evaluator = Evaluator(context) - decryptor = Decryptor(context, secret_key) - - ''' - To create CKKS plaintexts we need a special encoder: there is no other way - to create them. The IntegerEncoder and BatchEncoder cannot be used with the - CKKS scheme. The CKKSEncoder encodes vectors of real or complex numbers into - Plaintext objects, which can subsequently be encrypted. At a high level this - looks a lot like what BatchEncoder does for the BFV scheme, but the theory - behind it is completely different. - ''' - - encoder = CKKSEncoder(context) - - ''' - In CKKS the number of slots is poly_modulus_degree / 2 and each slot encodes - one real or complex number. This should be contrasted with BatchEncoder in - the BFV scheme, where the number of slots is equal to poly_modulus_degree - and they are arranged into a matrix with two rows. - ''' - - slot_count = encoder.slot_count() - print("Number of slots: " + str(slot_count)) - - ''' - We create a small vector to encode; the CKKSEncoder will implicitly pad it - with zeros to full size (poly_modulus_degree / 2) when encoding. - ''' - - inputs = DoubleVector([0.0, 1.1, 2.2, 3.3]) - - print("Input vector: ") - print_vector(inputs) - - ''' - Now we encode it with CKKSEncoder. The floating-point coefficients of `input' - will be scaled up by the parameter `scale'. This is necessary since even in - the CKKS scheme the plaintext elements are fundamentally polynomials with - integer coefficients. It is instructive to think of the scale as determining - the bit-precision of the encoding; naturally it will affect the precision of - the result. - - In CKKS the message is stored modulo coeff_modulus (in BFV it is stored modulo - plain_modulus), so the scaled message must not get too close to the total size - of coeff_modulus. In this case our coeff_modulus is quite large (218 bits) so - we have little to worry about in this regard. For this simple example a 30-bit - scale is more than enough. - ''' - - plain = Plaintext() - scale = pow(2.0, 30) - print("-" * 50) - - print("Encode input vector.") - encoder.encode(inputs, scale, plain) - - ''' - We can instantly decode to check the correctness of encoding. - ''' - - output = DoubleVector() - print(" + Decode input vector ...... Correct.") - encoder.decode(plain, output) - print_vector(output) - - ''' - The vector is encrypted the same was as in BFV. - ''' - - encrypted = Ciphertext() - print("-" * 50) - print("Encrypt input vector, square, and relinearize.") - encryptor.encrypt(plain, encrypted) - - ''' - Basic operations on the ciphertexts are still easy to do. Here we square the - ciphertext, decrypt, decode, and print the result. We note also that decoding - returns a vector of full size (poly_modulus_degree / 2); this is because of - the implicit zero-padding mentioned above. - ''' - - evaluator.square_inplace(encrypted) - evaluator.relinearize_inplace(encrypted, relin_keys) - - ''' - We notice that the scale in the result has increased. In fact, it is now the - square of the original scale: 2^60. - ''' - - print(" + Scale in squared input: " + str(encrypted.scale()), end="") - print(" (" + "%.0f" % math.log(encrypted.scale(), 2) + " bits)") - - print("-" * 50) - print("Decrypt and decode.") - decryptor.decrypt(encrypted, plain) - encoder.decode(plain, output) - print(" + Result vector ...... Correct.") - print_vector(output) - - ''' - The CKKS scheme allows the scale to be reduced between encrypted computations. - This is a fundamental and critical feature that makes CKKS very powerful and - flexible. We will discuss it in great detail in `3_levels.cpp' and later in - `4_ckks_basics.cpp'. - ''' - - -if __name__ == '__main__': - print_example_banner("Example: Encoders") - - example_integer_encoder() - example_batch_encoder() - example_ckks_encoder() diff --git a/tests/3_levels.py b/tests/3_levels.py deleted file mode 100644 index 3601cc6..0000000 --- a/tests/3_levels.py +++ /dev/null @@ -1,183 +0,0 @@ -from seal import * -from seal_helper import * - - -def print_parms_id(parms_id): - for item in parms_id: - print(str(hex(item)) + " ", end="") - print() - - -def example_levels(): - print_example_banner("Example: Levels") - parms = EncryptionParameters(scheme_type.BFV) - poly_modulus_degree = 8192 - parms.set_poly_modulus_degree(poly_modulus_degree) - - parms.set_coeff_modulus(CoeffModulus.Create( - poly_modulus_degree, [50, 30, 30, 50, 50])) - parms.set_plain_modulus(1 << 20) - context = SEALContext.Create(parms) - print_parameters(context) - - print("-" * 50) - print("Print the modulus switching chain.") - - ''' - First print the key level parameter information. - ''' - context_data = context.key_context_data() - print("----> Level (chain index): " + - str(context_data.chain_index()), end="") - print(" ...... key_context_data()") - print(" parms_id: ", end="") - print_parms_id(context_data.parms_id()) - print(" coeff_modulus primes: ", end="") - for item in context_data.parms().coeff_modulus(): - print(str(hex(item.value())) + " ", end="") - print("\n\\\n \\-->", end="") - - ''' - Next iterate over the remaining (data) levels. - ''' - context_data = context.first_context_data() - while context_data: - print(" Level (chain index): " + str(context_data.chain_index()), end="") - if context_data.parms_id() == context.first_parms_id(): - print(" ...... first_context_data()") - elif context_data.parms_id() == context.last_parms_id(): - print(" ...... last_context_data()") - else: - print() - print(" parms_id: ", end="") - print_parms_id(context_data.parms_id()) - print(" coeff_modulus primes: ", end="") - for item in context_data.parms().coeff_modulus(): - print(str(hex(item.value())) + " ", end="") - print("\n\\\n \\-->", end="") - # Step forward in the chain. - context_data = context_data.next_context_data() - print(" End of chain reached\n") - - ''' - We create some keys and check that indeed they appear at the highest level. - ''' - keygen = KeyGenerator(context) - public_key = keygen.public_key() - secret_key = keygen.secret_key() - relin_keys = keygen.relin_keys() - galois_keys = keygen.galois_keys() - print("-" * 50) - print("Print the parameter IDs of generated elements.") - print(" + public_key: ", end="") - print_parms_id(public_key.parms_id()) - print(" + secret_key: ", end="") - print_parms_id(secret_key.parms_id()) - print(" + relin_keys: ", end="") - print_parms_id(relin_keys.parms_id()) - print(" + galois_keys: ", end="") - print_parms_id(galois_keys.parms_id()) - - encryptor = Encryptor(context, public_key) - evaluator = Evaluator(context) - decryptor = Decryptor(context, secret_key) - - ''' - In the BFV scheme plaintexts do not carry a parms_id, but ciphertexts do. Note - how the freshly encrypted ciphertext is at the highest data level. - ''' - plain = Plaintext("1x^3 + 2x^2 + 3x^1 + 4") - encrypted = Ciphertext() - encryptor.encrypt(plain, encrypted) - print(" + plain: ", end="") - print_parms_id(plain.parms_id()) - print(" (not set in BFV)") - print(" + encrypted: ", end="") - print_parms_id(encrypted.parms_id()) - - print("-" * 50) - print("Perform modulus switching on encrypted and print.") - context_data = context.first_context_data() - print("---->", end="") - - while context_data.next_context_data(): - print(" Level (chain index): " + str(context_data.chain_index())) - print(" parms_id of encrypted: ", end="") - print_parms_id(encrypted.parms_id()) - print(" Noise budget at this level: " + - "%.0f" % decryptor.invariant_noise_budget(encrypted) + " bits") - print("\\\n \\-->", end="") - evaluator.mod_switch_to_next_inplace(encrypted) - context_data = context_data.next_context_data() - print(" Level (chain index): " + str(context_data.chain_index())) - print(" parms_id of encrypted: ", end="") - print_parms_id(encrypted.parms_id()) - print(" Noise budget at this level: " + - "%.0f" % decryptor.invariant_noise_budget(encrypted) + " bits") - print("\\\n \\--> End of chain reached\n") - - ''' - At this point it is hard to see any benefit in doing this: we lost a huge - amount of noise budget (i.e., computational power) at each switch and seemed - to get nothing in return. Decryption still works. - ''' - print("-" * 50) - print("Decrypt still works after modulus switching.") - decryptor.decrypt(encrypted, plain) - print(" + Decryption of encrypted: " + - plain.to_string() + " ...... Correct.\n") - - print("Computation is more efficient with modulus switching.") - print("-" * 50) - print("Compute the fourth power.") - encryptor.encrypt(plain, encrypted) - print(" + Noise budget before squaring: " + - "%.0f" % decryptor.invariant_noise_budget(encrypted) + " bits") - evaluator.square_inplace(encrypted) - evaluator.relinearize_inplace(encrypted, relin_keys) - print(" + Noise budget after squaring: " + - "%.0f" % decryptor.invariant_noise_budget(encrypted) + " bits") - - ''' - Surprisingly, in this case modulus switching has no effect at all on the - noise budget. - ''' - evaluator.mod_switch_to_next_inplace(encrypted) - print(" + Noise budget after modulus switching: " + - "%.0f" % decryptor.invariant_noise_budget(encrypted) + " bits") - - evaluator.square_inplace(encrypted) - print(" + Noise budget after squaring: " + - "%.0f" % decryptor.invariant_noise_budget(encrypted) + " bits") - evaluator.mod_switch_to_next_inplace(encrypted) - print(" + Noise budget after modulus switching: " + - "%.0f" % decryptor.invariant_noise_budget(encrypted) + " bits") - decryptor.decrypt(encrypted, plain) - - print(" + Decryption of fourth power (hexadecimal) ...... Correct.") - print(" " + plain.to_string() + "\n") - - ''' - In BFV modulus switching is not necessary and in some cases the user might - not want to create the modulus switching chain, except for the highest two - levels. This can be done by passing a bool `false' to SEALContext::Create. - ''' - context = SEALContext.Create(parms, False) - print("Optionally disable modulus switching chain expansion.") - print("-" * 50) - print("Print the modulus switching chain.\n---->", end="") - context_data = context.key_context_data() - while context_data: - print(" Level (chain index): " + str(context_data.chain_index())) - print(" parms_id: ", end="") - print_parms_id(context_data.parms_id()) - print(" coeff_modulus primes: ", end="") - for item in context_data.parms().coeff_modulus(): - print(str(hex(item.value())) + " ", end="") - print("\n\\\n \\-->", end="") - context_data = context_data.next_context_data() - print(" End of chain reached") - - -if __name__ == '__main__': - example_levels() diff --git a/tests/4_ckks_basics.py b/tests/4_ckks_basics.py deleted file mode 100644 index 5897a69..0000000 --- a/tests/4_ckks_basics.py +++ /dev/null @@ -1,183 +0,0 @@ -import math -from seal import * -from seal_helper import * - - -def example_ckks_basics(): - print_example_banner("Example: CKKS Basics") - - parms = EncryptionParameters(scheme_type.CKKS) - - poly_modulus_degree = 8192 - parms.set_poly_modulus_degree(poly_modulus_degree) - parms.set_coeff_modulus(CoeffModulus.Create( - poly_modulus_degree, [60, 40, 40, 60])) - - scale = pow(2.0, 40) - context = SEALContext.Create(parms) - print_parameters(context) - - keygen = KeyGenerator(context) - public_key = keygen.public_key() - secret_key = keygen.secret_key() - relin_keys = keygen.relin_keys() - - encryptor = Encryptor(context, public_key) - evaluator = Evaluator(context) - decryptor = Decryptor(context, secret_key) - - encoder = CKKSEncoder(context) - slot_count = encoder.slot_count() - print("Number of slots: " + str(slot_count)) - - inputs = DoubleVector() - curr_point = 0.0 - step_size = 1.0 / (slot_count - 1) - - for i in range(slot_count): - inputs.append(curr_point) - curr_point += step_size - - print("Input vector: ") - print_vector(inputs, 3, 7) - - print("Evaluating polynomial PI*x^3 + 0.4x + 1 ...") - - ''' - We create plaintexts for PI, 0.4, and 1 using an overload of CKKSEncoder::encode - that encodes the given floating-point value to every slot in the vector. - ''' - plain_coeff3 = Plaintext() - plain_coeff1 = Plaintext() - plain_coeff0 = Plaintext() - encoder.encode(3.14159265, scale, plain_coeff3) - encoder.encode(0.4, scale, plain_coeff1) - encoder.encode(1.0, scale, plain_coeff0) - - x_plain = Plaintext() - print("-" * 50) - print("Encode input vectors.") - encoder.encode(inputs, scale, x_plain) - x1_encrypted = Ciphertext() - encryptor.encrypt(x_plain, x1_encrypted) - - x3_encrypted = Ciphertext() - print("-" * 50) - print("Compute x^2 and relinearize:") - evaluator.square(x1_encrypted, x3_encrypted) - evaluator.relinearize_inplace(x3_encrypted, relin_keys) - print(" + Scale of x^2 before rescale: " + - "%.0f" % math.log(x3_encrypted.scale(), 2) + " bits") - - print("-" * 50) - print("Rescale x^2.") - evaluator.rescale_to_next_inplace(x3_encrypted) - print(" + Scale of x^2 after rescale: " + - "%.0f" % math.log(x3_encrypted.scale(), 2) + " bits") - - print("-" * 50) - print("Compute and rescale PI*x.") - x1_encrypted_coeff3 = Ciphertext() - evaluator.multiply_plain(x1_encrypted, plain_coeff3, x1_encrypted_coeff3) - print(" + Scale of PI*x before rescale: " + - "%.0f" % math.log(x1_encrypted_coeff3.scale(), 2) + " bits") - evaluator.rescale_to_next_inplace(x1_encrypted_coeff3) - print(" + Scale of PI*x after rescale: " + - "%.0f" % math.log(x1_encrypted_coeff3.scale(), 2) + " bits") - - print("-" * 50) - print("Compute, relinearize, and rescale (PI*x)*x^2.") - evaluator.multiply_inplace(x3_encrypted, x1_encrypted_coeff3) - evaluator.relinearize_inplace(x3_encrypted, relin_keys) - print(" + Scale of PI*x^3 before rescale: " + - "%.0f" % math.log(x3_encrypted.scale(), 2) + " bits") - evaluator.rescale_to_next_inplace(x3_encrypted) - print(" + Scale of PI*x^3 after rescale: " + - "%.0f" % math.log(x3_encrypted.scale(), 2) + " bits") - - print("-" * 50) - print("Compute and rescale 0.4*x.") - evaluator.multiply_plain_inplace(x1_encrypted, plain_coeff1) - print(" + Scale of 0.4*x before rescale: " + - "%.0f" % math.log(x1_encrypted.scale(), 2) + " bits") - evaluator.rescale_to_next_inplace(x1_encrypted) - print(" + Scale of 0.4*x after rescale: " + - "%.0f" % math.log(x1_encrypted.scale(), 2) + " bits") - print() - - print("-" * 50) - print("Parameters used by all three terms are different.") - print(" + Modulus chain index for x3_encrypted: " + - str(context.get_context_data(x3_encrypted.parms_id()).chain_index())) - print(" + Modulus chain index for x1_encrypted: " + - str(context.get_context_data(x1_encrypted.parms_id()).chain_index())) - print(" + Modulus chain index for x1_encrypted: " + - str(context.get_context_data(plain_coeff0.parms_id()).chain_index())) - print() - - print("-" * 50) - print("The exact scales of all three terms are different:") - print(" + Exact scale in PI*x^3: " + "%.10f" % x3_encrypted.scale()) - print(" + Exact scale in 0.4*x: " + "%.10f" % x1_encrypted.scale()) - print(" + Exact scale in 1: " + "%.10f" % plain_coeff0.scale()) - - print("-" * 50) - print("Normalize scales to 2^40.") - - # set_scale() this function should be add to seal/ciphertext.h line 632 - x3_encrypted.set_scale(pow(2.0, 40)) - x1_encrypted.set_scale(pow(2.0, 40)) - - ''' - We still have a problem with mismatching encryption parameters. This is easy - to fix by using traditional modulus switching (no rescaling). CKKS supports - modulus switching just like the BFV scheme, allowing us to switch away parts - of the coefficient modulus when it is simply not needed. - ''' - print("-" * 50) - print("Normalize encryption parameters to the lowest level.") - last_parms_id = x3_encrypted.parms_id() - evaluator.mod_switch_to_inplace(x1_encrypted, last_parms_id) - evaluator.mod_switch_to_inplace(plain_coeff0, last_parms_id) - - ''' - All three ciphertexts are now compatible and can be added. - ''' - print("-" * 50) - print("Compute PI*x^3 + 0.4*x + 1.") - - encrypted_result = Ciphertext() - evaluator.add(x3_encrypted, x1_encrypted, encrypted_result) - evaluator.add_plain_inplace(encrypted_result, plain_coeff0) - - ''' - First print the true result. - ''' - plain_result = Plaintext() - print("-" * 50) - print("Decrypt and decode PI*x^3 + 0.4x + 1.") - print(" + Expected result:") - true_result = [] - for x in inputs: - true_result.append((3.14159265 * x * x + 0.4) * x + 1) - print_vector(true_result, 3, 7) - - ''' - Decrypt, decode, and print the result. - ''' - - decryptor.decrypt(encrypted_result, plain_result) - result = DoubleVector() - encoder.decode(plain_result, result) - print(" + Computed result ...... Correct.") - print_vector(result, 3, 7) - - ''' - While we did not show any computations on complex numbers in these examples, - the CKKSEncoder would allow us to have done that just as easily. Additions - and multiplications of complex numbers behave just as one would expect. - ''' - - -if __name__ == '__main__': - example_ckks_basics() diff --git a/tests/5_rotation.py b/tests/5_rotation.py deleted file mode 100644 index 49bc8d1..0000000 --- a/tests/5_rotation.py +++ /dev/null @@ -1,181 +0,0 @@ -from seal import * -from seal_helper import * - - -def example_rotation_bfv(): - print_example_banner("Example: Rotation / Rotation in BFV") - parms = EncryptionParameters(scheme_type.BFV) - - poly_modulus_degree = 8192 - parms.set_poly_modulus_degree(poly_modulus_degree) - parms.set_coeff_modulus(CoeffModulus.BFVDefault(poly_modulus_degree)) - parms.set_plain_modulus(PlainModulus.Batching(poly_modulus_degree, 20)) - - context = SEALContext.Create(parms) - print_parameters(context) - print("-" * 50) - - keygen = KeyGenerator(context) - public_key = keygen.public_key() - secret_key = keygen.secret_key() - relin_keys = keygen.relin_keys() - - encryptor = Encryptor(context, public_key) - evaluator = Evaluator(context) - decryptor = Decryptor(context, secret_key) - - batch_encoder = BatchEncoder(context) - slot_count = batch_encoder.slot_count() - row_size = int(slot_count / 2) - print("Plaintext matrix row size: " + str(row_size)) - - pod_matrixs = [0] * slot_count - pod_matrixs[0] = 0 - pod_matrixs[1] = 1 - pod_matrixs[2] = 2 - pod_matrixs[3] = 3 - pod_matrixs[row_size] = 4 - pod_matrixs[row_size + 1] = 5 - pod_matrixs[row_size + 2] = 6 - pod_matrixs[row_size + 3] = 7 - - pod_matrix = uIntVector(pod_matrixs) - - print("Input plaintext matrix:") - print_matrix(pod_matrix, row_size) - - ''' - First we use BatchEncoder to encode the matrix into a plaintext. We encrypt - the plaintext as usual. - ''' - plain_matrix = Plaintext() - print("-" * 50) - print("Encode and encrypt.") - batch_encoder.encode(pod_matrix, plain_matrix) - encrypted_matrix = Ciphertext() - encryptor.encrypt(plain_matrix, encrypted_matrix) - print(" + Noise budget in fresh encryption: " + - str(decryptor.invariant_noise_budget(encrypted_matrix)) + " bits") - - ''' - Rotations require yet another type of special key called `Galois keys'. These - are easily obtained from the KeyGenerator. - ''' - gal_keys = keygen.galois_keys() - ''' - Now rotate both matrix rows 3 steps to the left, decrypt, decode, and print. - ''' - print("-" * 50) - print("Rotate rows 3 steps left.") - - evaluator.rotate_rows_inplace(encrypted_matrix, 3, gal_keys) - plain_result = Plaintext() - print(" + Noise budget after rotation: " + - str(decryptor.invariant_noise_budget(encrypted_matrix)) + " bits") - print(" + Decrypt and decode ...... Correct.") - decryptor.decrypt(encrypted_matrix, plain_result) - batch_encoder.decode(plain_result, pod_matrix) - print_matrix(pod_matrix, row_size) - - ''' - We can also rotate the columns, i.e., swap the rows. - ''' - print("-" * 50) - print("Rotate columns.") - evaluator.rotate_columns_inplace(encrypted_matrix, gal_keys) - print(" + Noise budget after rotation: " + - str(decryptor.invariant_noise_budget(encrypted_matrix)) + " bits") - print(" + Decrypt and decode ...... Correct.") - decryptor.decrypt(encrypted_matrix, plain_result) - batch_encoder.decode(plain_result, pod_matrix) - print_matrix(pod_matrix, row_size) - - ''' - Finally, we rotate the rows 4 steps to the right, decrypt, decode, and print. - ''' - print("-" * 50) - print("Rotate rows 4 steps right.") - evaluator.rotate_rows_inplace(encrypted_matrix, -4, gal_keys) - print(" + Noise budget after rotation: " + - str(decryptor.invariant_noise_budget(encrypted_matrix)) + " bits") - print(" + Decrypt and decode ...... Correct.") - decryptor.decrypt(encrypted_matrix, plain_result) - batch_encoder.decode(plain_result, pod_matrix) - print_matrix(pod_matrix, row_size) - - ''' - Note that rotations do not consume any noise budget. However, this is only - the case when the special prime is at least as large as the other primes. The - same holds for relinearization. Microsoft SEAL does not require that the - special prime is of any particular size, so ensuring this is the case is left - for the user to do. - ''' - - -def example_rotation_ckks(): - print_example_banner("Example: Rotation / Rotation in CKKS") - parms = EncryptionParameters(scheme_type.CKKS) - poly_modulus_degree = 8192 - parms.set_poly_modulus_degree(poly_modulus_degree) - parms.set_coeff_modulus(CoeffModulus.Create( - poly_modulus_degree, [40, 40, 40, 40, 40])) - context = SEALContext.Create(parms) - print_parameters(context) - - keygen = KeyGenerator(context) - public_key = keygen.public_key() - secret_key = keygen.secret_key() - relin_keys = keygen.relin_keys() - gal_keys = keygen.galois_keys() - - encryptor = Encryptor(context, public_key) - evaluator = Evaluator(context) - decryptor = Decryptor(context, secret_key) - - ckks_encoder = CKKSEncoder(context) - slot_count = ckks_encoder.slot_count() - print("Number of slots: " + str(slot_count)) - - inputs = DoubleVector() - curr_point = 0.0 - step_size = 1.0 / (slot_count - 1) - - for i in range(slot_count): - inputs.append(curr_point) - curr_point += step_size - - print("Input vector:") - print_vector(inputs, 3, 7) - - scale = pow(2.0, 50) - - print("-" * 50) - print("Encode and encrypt.") - plain = Plaintext() - - ckks_encoder.encode(inputs, scale, plain) - encrypted = Ciphertext() - encryptor.encrypt(plain, encrypted) - - rotated = Ciphertext() - print("-" * 50) - print("Rotate 2 steps left.") - evaluator.rotate_vector(encrypted, 2, gal_keys, rotated) - print(" + Decrypt and decode ...... Correct.") - decryptor.decrypt(rotated, plain) - result = DoubleVector() - ckks_encoder.decode(plain, result) - print_vector(result, 3, 7) - - ''' - With the CKKS scheme it is also possible to evaluate a complex conjugation on - a vector of encrypted complex numbers, using Evaluator::complex_conjugate. - This is in fact a kind of rotation, and requires also Galois keys. - ''' - - -if __name__ == '__main__': - print_example_banner("Example: Rotation") - - example_rotation_bfv() - example_rotation_ckks() diff --git a/tests/6_performance.py b/tests/6_performance.py deleted file mode 100644 index 2693c7f..0000000 --- a/tests/6_performance.py +++ /dev/null @@ -1,585 +0,0 @@ -import time -import math -import random -from seal import * -from seal_helper import * - - -def rand_int(): - return int(random.random()*(10**10)) - - -def bfv_performance_test(context): - print_parameters(context) - - parms = context.first_context_data().parms() - plain_modulus = parms.plain_modulus() - poly_modulus_degree = parms.poly_modulus_degree() - - print("Generating secret/public keys: ", end="") - keygen = KeyGenerator(context) - print("Done") - - secret_key = keygen.secret_key() - public_key = keygen.public_key() - relin_keys = RelinKeys() - gal_keys = GaloisKeys() - - if context.using_keyswitching(): - # Generate relinearization keys. - print("Generating relinearization keys: ", end="") - time_start = time.time() - relin_keys = keygen.relin_keys() - time_end = time.time() - print("Done [" + "%.0f" % - ((time_end-time_start)*1000000) + " microseconds]") - - if not context.key_context_data().qualifiers().using_batching: - print("Given encryption parameters do not support batching.") - return 0 - - print("Generating Galois keys: ", end="") - time_start = time.time() - gal_keys = keygen.galois_keys() - time_end = time.time() - print("Done [" + "%.0f" % - ((time_end-time_start)*1000000) + " microseconds]") - - encryptor = Encryptor(context, public_key) - decryptor = Decryptor(context, secret_key) - evaluator = Evaluator(context) - batch_encoder = BatchEncoder(context) - encoder = IntegerEncoder(context) - - # These will hold the total times used by each operation. - time_batch_sum = 0 - time_unbatch_sum = 0 - time_encrypt_sum = 0 - time_decrypt_sum = 0 - time_add_sum = 0 - time_multiply_sum = 0 - time_multiply_plain_sum = 0 - time_square_sum = 0 - time_relinearize_sum = 0 - time_rotate_rows_one_step_sum = 0 - time_rotate_rows_random_sum = 0 - time_rotate_columns_sum = 0 - - # How many times to run the test? - count = 10 - - # Populate a vector of values to batch. - slot_count = batch_encoder.slot_count() - pod_vector = uIntVector() - for i in range(slot_count): - pod_vector.append(rand_int() % plain_modulus.value()) - print("Running tests ", end="") - - for i in range(count): - ''' - [Batching] - There is nothing unusual here. We batch our random plaintext matrix - into the polynomial. Note how the plaintext we create is of the exactly - right size so unnecessary reallocations are avoided. - ''' - plain = Plaintext(parms.poly_modulus_degree(), 0) - time_start = time.time() - batch_encoder.encode(pod_vector, plain) - time_end = time.time() - time_batch_sum += (time_end-time_start)*1000000 - - ''' - [Unbatching] - We unbatch what we just batched. - ''' - pod_vector2 = uIntVector() - time_start = time.time() - batch_encoder.decode(plain, pod_vector2) - time_end = time.time() - time_unbatch_sum += (time_end-time_start)*1000000 - for j in range(slot_count): - if pod_vector[j] != pod_vector2[j]: - raise Exception("Batch/unbatch failed. Something is wrong.") - - ''' - [Encryption] - We make sure our ciphertext is already allocated and large enough - to hold the encryption with these encryption parameters. We encrypt - our random batched matrix here. - ''' - encrypted = Ciphertext() - time_start = time.time() - encryptor.encrypt(plain, encrypted) - time_end = time.time() - time_encrypt_sum += (time_end-time_start)*1000000 - - ''' - [Decryption] - We decrypt what we just encrypted. - ''' - plain2 = Plaintext(poly_modulus_degree, 0) - time_start = time.time() - decryptor.decrypt(encrypted, plain2) - time_end = time.time() - time_decrypt_sum += (time_end-time_start)*1000000 - if plain.to_string() != plain2.to_string(): - raise Exception("Encrypt/decrypt failed. Something is wrong.") - - ''' - [Add] - We create two ciphertexts and perform a few additions with them. - ''' - encrypted1 = Ciphertext() - encryptor.encrypt(encoder.encode(i), encrypted1) - encrypted2 = Ciphertext(context) - encryptor.encrypt(encoder.encode(i + 1), encrypted2) - time_start = time.time() - evaluator.add_inplace(encrypted1, encrypted1) - evaluator.add_inplace(encrypted2, encrypted2) - evaluator.add_inplace(encrypted1, encrypted2) - time_end = time.time() - time_add_sum += (time_end-time_start)*1000000 - - ''' - [Multiply] - We multiply two ciphertexts. Since the size of the result will be 3, - and will overwrite the first argument, we reserve first enough memory - to avoid reallocating during multiplication. - ''' - encrypted1.reserve(3) - time_start = time.time() - evaluator.multiply_inplace(encrypted1, encrypted2) - time_end = time.time() - time_multiply_sum += (time_end-time_start)*1000000 - - ''' - [Multiply Plain] - We multiply a ciphertext with a random plaintext. Recall that - multiply_plain does not change the size of the ciphertext so we use - encrypted2 here. - ''' - time_start = time.time() - evaluator.multiply_plain_inplace(encrypted2, plain) - time_end = time.time() - time_multiply_plain_sum += (time_end-time_start)*1000000 - - ''' - [Square] - We continue to use encrypted2. Now we square it; this should be - faster than generic homomorphic multiplication. - ''' - time_start = time.time() - evaluator.square_inplace(encrypted2) - time_end = time.time() - time_square_sum += (time_end-time_start)*1000000 - - if context.using_keyswitching(): - ''' - [Relinearize] - Time to get back to encrypted1. We now relinearize it back - to size 2. Since the allocation is currently big enough to - contain a ciphertext of size 3, no costly reallocations are - needed in the process. - ''' - time_start = time.time() - evaluator.relinearize_inplace(encrypted1, relin_keys) - time_end = time.time() - time_relinearize_sum += (time_end-time_start)*1000000 - - ''' - [Rotate Rows One Step] - We rotate matrix rows by one step left and measure the time. - ''' - time_start = time.time() - evaluator.rotate_rows_inplace(encrypted, 1, gal_keys) - evaluator.rotate_rows_inplace(encrypted, -1, gal_keys) - time_end = time.time() - time_rotate_rows_one_step_sum += (time_end-time_start)*1000000 - - ''' - [Rotate Rows Random] - We rotate matrix rows by a random number of steps. This is much more - expensive than rotating by just one step. - ''' - row_size = batch_encoder.slot_count() / 2 - random_rotation = int(rand_int() % row_size) - time_start = time.time() - evaluator.rotate_rows_inplace( - encrypted, random_rotation, gal_keys) - time_end = time.time() - time_rotate_rows_random_sum += (time_end-time_start)*1000000 - - ''' - [Rotate Columns] - Nothing surprising here. - ''' - time_start = time.time() - evaluator.rotate_columns_inplace(encrypted, gal_keys) - time_end = time.time() - time_rotate_columns_sum += (time_end-time_start)*1000000 - - # Print a dot to indicate progress. - print(".", end="", flush=True) - print(" Done", flush=True) - - avg_batch = time_batch_sum / count - avg_unbatch = time_unbatch_sum / count - avg_encrypt = time_encrypt_sum / count - avg_decrypt = time_decrypt_sum / count - avg_add = time_add_sum / (3 * count) - avg_multiply = time_multiply_sum / count - avg_multiply_plain = time_multiply_plain_sum / count - avg_square = time_square_sum / count - avg_relinearize = time_relinearize_sum / count - avg_rotate_rows_one_step = time_rotate_rows_one_step_sum / (2 * count) - avg_rotate_rows_random = time_rotate_rows_random_sum / count - avg_rotate_columns = time_rotate_columns_sum / count - - print("Average batch: " + "%.0f" % avg_batch + " microseconds", flush=True) - print("Average unbatch: " + "%.0f" % - avg_unbatch + " microseconds", flush=True) - print("Average encrypt: " + "%.0f" % - avg_encrypt + " microseconds", flush=True) - print("Average decrypt: " + "%.0f" % - avg_decrypt + " microseconds", flush=True) - print("Average add: " + "%.0f" % avg_add + " microseconds", flush=True) - print("Average multiply: " + "%.0f" % - avg_multiply + " microseconds", flush=True) - print("Average multiply plain: " + "%.0f" % - avg_multiply_plain + " microseconds", flush=True) - print("Average square: " + "%.0f" % - avg_square + " microseconds", flush=True) - if context.using_keyswitching(): - print("Average relinearize: " + "%.0f" % - avg_relinearize + " microseconds", flush=True) - print("Average rotate rows one step: " + "%.0f" % - avg_rotate_rows_one_step + " microseconds", flush=True) - print("Average rotate rows random: " + "%.0f" % - avg_rotate_rows_random + " microseconds", flush=True) - print("Average rotate columns: " + "%.0f" % - avg_rotate_columns + " microseconds", flush=True) - - -def ckks_performance_test(context): - print_parameters(context) - - parms = context.first_context_data().parms() - plain_modulus = parms.plain_modulus() - poly_modulus_degree = parms.poly_modulus_degree() - - print("Generating secret/public keys: ", end="") - keygen = KeyGenerator(context) - print("Done") - - secret_key = keygen.secret_key() - public_key = keygen.public_key() - relin_keys = RelinKeys() - gal_keys = GaloisKeys() - - if context.using_keyswitching(): - print("Generating relinearization keys: ", end="") - time_start = time.time() - relin_keys = keygen.relin_keys() - time_end = time.time() - print("Done [" + "%.0f" % - ((time_end-time_start)*1000000) + " microseconds]") - - if not context.key_context_data().qualifiers().using_batching: - print("Given encryption parameters do not support batching.") - return 0 - - print("Generating Galois keys: ", end="") - time_start = time.time() - gal_keys = keygen.galois_keys() - time_end = time.time() - print("Done [" + "%.0f" % - ((time_end-time_start)*1000000) + " microseconds]") - - encryptor = Encryptor(context, public_key) - decryptor = Decryptor(context, secret_key) - evaluator = Evaluator(context) - ckks_encoder = CKKSEncoder(context) - - time_encode_sum = 0 - time_decode_sum = 0 - time_encrypt_sum = 0 - time_decrypt_sum = 0 - time_add_sum = 0 - time_multiply_sum = 0 - time_multiply_plain_sum = 0 - time_square_sum = 0 - time_relinearize_sum = 0 - time_rescale_sum = 0 - time_rotate_one_step_sum = 0 - time_rotate_random_sum = 0 - time_conjugate_sum = 0 - - # How many times to run the test? - count = 10 - - # Populate a vector of floating-point values to batch. - pod_vector = DoubleVector() - slot_count = ckks_encoder.slot_count() - for i in range(slot_count): - pod_vector.append(1.001 * float(i)) - - print("Running tests ", end="") - for i in range(count): - ''' - [Encoding] - For scale we use the square root of the last coeff_modulus prime - from parms. - ''' - plain = Plaintext(parms.poly_modulus_degree() * - len(parms.coeff_modulus()), 0) - - # [Encoding] - scale = math.sqrt(parms.coeff_modulus()[-1].value()) - time_start = time.time() - ckks_encoder.encode(pod_vector, scale, plain) - time_end = time.time() - time_encode_sum += (time_end-time_start)*1000000 - - # [Decoding] - pod_vector2 = DoubleVector() - time_start = time.time() - ckks_encoder.decode(plain, pod_vector2) - time_end = time.time() - time_decode_sum += (time_end-time_start)*1000000 - - # [Encryption] - encrypted = Ciphertext(context) - time_start = time.time() - encryptor.encrypt(plain, encrypted) - time_end = time.time() - time_encrypt_sum += (time_end-time_start)*1000000 - - # [Decryption] - plain2 = Plaintext(poly_modulus_degree, 0) - time_start = time.time() - decryptor.decrypt(encrypted, plain2) - time_end = time.time() - time_decrypt_sum += (time_end-time_start)*1000000 - - # [Add] - encrypted1 = Ciphertext(context) - ckks_encoder.encode(i + 1, plain) - encryptor.encrypt(plain, encrypted1) - encrypted2 = Ciphertext(context) - ckks_encoder.encode(i + 1, plain2) - encryptor.encrypt(plain2, encrypted2) - time_start = time.time() - evaluator.add_inplace(encrypted1, encrypted1) - evaluator.add_inplace(encrypted2, encrypted2) - evaluator.add_inplace(encrypted1, encrypted2) - time_end = time.time() - time_add_sum += (time_end-time_start)*1000000 - - # [Multiply] - encrypted1.reserve(3) - time_start = time.time() - evaluator.multiply_inplace(encrypted1, encrypted2) - time_end = time.time() - time_multiply_sum += (time_end-time_start)*1000000 - - # [Multiply Plain] - time_start = time.time() - evaluator.multiply_plain_inplace(encrypted2, plain) - time_end = time.time() - time_multiply_plain_sum += (time_end-time_start)*1000000 - - # [Square] - time_start = time.time() - evaluator.square_inplace(encrypted2) - time_end = time.time() - time_square_sum += (time_end-time_start)*1000000 - - if context.using_keyswitching(): - - # [Relinearize] - time_start = time.time() - evaluator.relinearize_inplace(encrypted1, relin_keys) - time_end = time.time() - time_relinearize_sum += (time_end-time_start)*1000000 - - # [Rescale] - time_start = time.time() - evaluator.rescale_to_next_inplace(encrypted1) - time_end = time.time() - time_rescale_sum += (time_end-time_start)*1000000 - - # [Rotate Vector] - time_start = time.time() - evaluator.rotate_vector_inplace(encrypted, 1, gal_keys) - evaluator.rotate_vector_inplace(encrypted, -1, gal_keys) - time_end = time.time() - time_rotate_one_step_sum += (time_end-time_start)*1000000 - - # [Rotate Vector Random] - random_rotation = int(rand_int() % ckks_encoder.slot_count()) - time_start = time.time() - evaluator.rotate_vector_inplace( - encrypted, random_rotation, gal_keys) - time_end = time.time() - time_rotate_random_sum += (time_end-time_start)*1000000 - - # [Complex Conjugate] - time_start = time.time() - evaluator.complex_conjugate_inplace(encrypted, gal_keys) - time_end = time.time() - time_conjugate_sum += (time_end-time_start)*1000000 - print(".", end="", flush=True) - - print(" Done\n", flush=True) - - avg_encode = time_encode_sum / count - avg_decode = time_decode_sum / count - avg_encrypt = time_encrypt_sum / count - avg_decrypt = time_decrypt_sum / count - avg_add = time_add_sum / (3 * count) - avg_multiply = time_multiply_sum / count - avg_multiply_plain = time_multiply_plain_sum / count - avg_square = time_square_sum / count - avg_relinearize = time_relinearize_sum / count - avg_rescale = time_rescale_sum / count - avg_rotate_one_step = time_rotate_one_step_sum / (2 * count) - avg_rotate_random = time_rotate_random_sum / count - avg_conjugate = time_conjugate_sum / count - - print("Average encode: " + "%.0f" % - avg_encode + " microseconds", flush=True) - print("Average decode: " + "%.0f" % - avg_decode + " microseconds", flush=True) - print("Average encrypt: " + "%.0f" % - avg_encrypt + " microseconds", flush=True) - print("Average decrypt: " + "%.0f" % - avg_decrypt + " microseconds", flush=True) - print("Average add: " + "%.0f" % avg_add + " microseconds", flush=True) - print("Average multiply: " + "%.0f" % - avg_multiply + " microseconds", flush=True) - print("Average multiply plain: " + "%.0f" % - avg_multiply_plain + " microseconds", flush=True) - print("Average square: " + "%.0f" % - avg_square + " microseconds", flush=True) - if context.using_keyswitching(): - print("Average relinearize: " + "%.0f" % - avg_relinearize + " microseconds", flush=True) - print("Average rescale: " + "%.0f" % - avg_rescale + " microseconds", flush=True) - print("Average rotate vector one step: " + "%.0f" % - avg_rotate_one_step + " microseconds", flush=True) - print("Average rotate vector random: " + "%.0f" % - avg_rotate_random + " microseconds", flush=True) - print("Average complex conjugate: " + "%.0f" % - avg_conjugate + " microseconds", flush=True) - - -def example_bfv_performance_default(): - print_example_banner( - "BFV Performance Test with Degrees: 4096, 8192, and 16384") - - parms = EncryptionParameters(scheme_type.BFV) - poly_modulus_degree = 4096 - parms.set_poly_modulus_degree(poly_modulus_degree) - parms.set_coeff_modulus(CoeffModulus.BFVDefault(poly_modulus_degree)) - parms.set_plain_modulus(786433) - bfv_performance_test(SEALContext.Create(parms)) - - print() - poly_modulus_degree = 8192 - parms.set_poly_modulus_degree(poly_modulus_degree) - parms.set_coeff_modulus(CoeffModulus.BFVDefault(poly_modulus_degree)) - parms.set_plain_modulus(786433) - bfv_performance_test(SEALContext.Create(parms)) - - print() - poly_modulus_degree = 16384 - parms.set_poly_modulus_degree(poly_modulus_degree) - parms.set_coeff_modulus(CoeffModulus.BFVDefault(poly_modulus_degree)) - parms.set_plain_modulus(786433) - bfv_performance_test(SEALContext.Create(parms)) - - # Comment out the following to run the biggest example. - # poly_modulus_degree = 32768 - - -def example_bfv_performance_custom(): - print("\nSet poly_modulus_degree (1024, 2048, 4096, 8192, 16384, or 32768): ") - poly_modulus_degree = input("Input the poly_modulus_degree: ").strip() - - if len(poly_modulus_degree) < 4 or not poly_modulus_degree.isdigit(): - print("Invalid option.") - return 0 - - poly_modulus_degree = int(poly_modulus_degree) - - if poly_modulus_degree < 1024 or poly_modulus_degree > 32768 or (poly_modulus_degree & (poly_modulus_degree - 1) != 0): - print("Invalid option.") - return 0 - - print("BFV Performance Test with Degree: " + str(poly_modulus_degree)) - - parms = EncryptionParameters(scheme_type.BFV) - parms.set_poly_modulus_degree(poly_modulus_degree) - parms.set_coeff_modulus(CoeffModulus.BFVDefault(poly_modulus_degree)) - if poly_modulus_degree == 1024: - parms.set_plain_modulus(12289) - else: - parms.set_plain_modulus(786433) - bfv_performance_test(SEALContext.Create(parms)) - - -def example_ckks_performance_default(): - print_example_banner( - "CKKS Performance Test with Degrees: 4096, 8192, and 16384") - - parms = EncryptionParameters(scheme_type.CKKS) - poly_modulus_degree = 4096 - parms.set_poly_modulus_degree(poly_modulus_degree) - parms.set_coeff_modulus(CoeffModulus.BFVDefault(poly_modulus_degree)) - ckks_performance_test(SEALContext.Create(parms)) - - print() - poly_modulus_degree = 8192 - parms.set_poly_modulus_degree(poly_modulus_degree) - parms.set_coeff_modulus(CoeffModulus.BFVDefault(poly_modulus_degree)) - ckks_performance_test(SEALContext.Create(parms)) - - poly_modulus_degree = 16384 - parms.set_poly_modulus_degree(poly_modulus_degree) - parms.set_coeff_modulus(CoeffModulus.BFVDefault(poly_modulus_degree)) - ckks_performance_test(SEALContext.Create(parms)) - - # Comment out the following to run the biggest example. - # poly_modulus_degree = 32768 - - -def example_ckks_performance_custom(): - print("\nSet poly_modulus_degree (1024, 2048, 4096, 8192, 16384, or 32768): ") - poly_modulus_degree = input("Input the poly_modulus_degree: ").strip() - - if len(poly_modulus_degree) < 4 or not poly_modulus_degree.isdigit(): - print("Invalid option.") - return 0 - - poly_modulus_degree = int(poly_modulus_degree) - - if poly_modulus_degree < 1024 or poly_modulus_degree > 32768 or (poly_modulus_degree & (poly_modulus_degree - 1) != 0): - print("Invalid option.") - return 0 - - print("CKKS Performance Test with Degree: " + str(poly_modulus_degree)) - - parms = EncryptionParameters(scheme_type.CKKS) - parms.set_poly_modulus_degree(poly_modulus_degree) - parms.set_coeff_modulus(CoeffModulus.BFVDefault(poly_modulus_degree)) - ckks_performance_test(SEALContext.Create(parms)) - - -if __name__ == '__main__': - print_example_banner("Example: Performance Test") - - example_bfv_performance_default() - example_bfv_performance_custom() - example_ckks_performance_default() - example_ckks_performance_custom() diff --git a/tests/seal_helper.py b/tests/seal_helper.py deleted file mode 100644 index 777da45..0000000 --- a/tests/seal_helper.py +++ /dev/null @@ -1,89 +0,0 @@ -# coding: utf-8 -# author: Huelse - -from seal import scheme_type - - -def print_example_banner(title): - title_length = len(title) - banner_length = title_length + 2 * 10 - banner_top = "+" + "-" * (banner_length - 2) + "+" - banner_middle = "|" + ' ' * 9 + title + ' ' * 9 + "|" - print(banner_top) - print(banner_middle) - print(banner_top) - - -def print_parameters(context): - context_data = context.key_context_data() - if context_data.parms().scheme() == scheme_type.BFV: - scheme_name = "BFV" - elif context_data.parms().scheme() == scheme_type.CKKS: - scheme_name = "CKKS" - else: - scheme_name = "unsupported scheme" - print("/") - print("| Encryption parameters:") - print("| scheme: " + scheme_name) - print("| poly_modulus_degree: " + - str(context_data.parms().poly_modulus_degree())) - print("| coeff_modulus size: ", end="") - coeff_modulus = context_data.parms().coeff_modulus() - coeff_modulus_sum = 0 - for j in coeff_modulus: - coeff_modulus_sum += j.bit_count() - print(str(coeff_modulus_sum) + "(", end="") - for i in range(len(coeff_modulus) - 1): - print(str(coeff_modulus[i].bit_count()) + " + ", end="") - print(str(coeff_modulus[-1].bit_count()) + ") bits") - if context_data.parms().scheme() == scheme_type.BFV: - print("| plain_modulus: " + - str(context_data.parms().plain_modulus().value())) - print("\\") - - -def print_matrix(matrix, row_size): - print() - print_size = 5 - current_line = " [ " - for i in range(print_size): - current_line += ((str)(matrix[i]) + ", ") - current_line += ("..., ") - for i in range(row_size - print_size, row_size): - current_line += ((str)(matrix[i])) - if i != row_size-1: - current_line += ", " - else: - current_line += " ]" - print(current_line) - - current_line = " [ " - for i in range(row_size, row_size + print_size): - current_line += ((str)(matrix[i]) + ", ") - current_line += ("..., ") - for i in range(2*row_size - print_size, 2*row_size): - current_line += ((str)(matrix[i])) - if i != 2*row_size-1: - current_line += ", " - else: - current_line += " ]" - print(current_line) - print() - - -def print_vector(vec, print_size=4, prec=3): - slot_count = len(vec) - print() - if slot_count <= 2*print_size: - print(" [", end="") - for i in range(slot_count): - print(" " + (f"%.{prec}f" % vec[i]) + ("," if (i != slot_count - 1) else " ]\n"), end="") - else: - print(" [", end="") - for i in range(print_size): - print(" " + (f"%.{prec}f" % vec[i]) + ",", end="") - if len(vec) > 2*print_size: - print(" ...,", end="") - for i in range(slot_count - print_size, slot_count): - print(" " + (f"%.{prec}f" % vec[i]) + ("," if (i != slot_count - 1) else " ]\n"), end="") - print()