diff --git a/.circleci/config.yml b/.circleci/config.yml deleted file mode 100644 index d3770591d..000000000 --- a/.circleci/config.yml +++ /dev/null @@ -1,39 +0,0 @@ -restore_registry: &RESTORE_REGISTRY - restore_cache: - key: registry -save_registry: &SAVE_REGISTRY - save_cache: - key: registry-{{ .BuildNum }} - paths: - - /usr/local/cargo/registry/index -deps_key: &DEPS_KEY - key: deps-{{ checksum "~/rust-version" }}-{{ checksum "Cargo.lock" }} -restore_deps: &RESTORE_DEPS - restore_cache: - <<: *DEPS_KEY -save_deps: &SAVE_DEPS - save_cache: - <<: *DEPS_KEY - paths: - - target - - /usr/local/cargo/registry/cache - -version: 2 -jobs: - build: - docker: - - image: rust:1.26.2 - environment: - RUSTFLAGS: -D warnings - - image: sfackler/rust-postgres-test:4 - steps: - - checkout - - *RESTORE_REGISTRY - - run: cargo generate-lockfile - - *SAVE_REGISTRY - - run: rustc --version > ~/rust-version - - *RESTORE_DEPS - - run: cargo test --all - - run: cargo test -p postgres --all-features - - run: cargo test -p tokio-postgres --all-features - - *SAVE_DEPS diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000..1332f8eb5 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,8 @@ +version: 2 +updates: +- package-ecosystem: cargo + directory: "/" + schedule: + interval: daily + time: "13:00" + open-pull-requests-limit: 10 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 000000000..3426d624b --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,108 @@ +name: CI + +on: + pull_request: + branches: + - master + push: + branches: + - master + +env: + RUSTFLAGS: -Dwarnings + RUST_BACKTRACE: 1 + +jobs: + rustfmt: + name: rustfmt + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: sfackler/actions/rustup@master + - uses: sfackler/actions/rustfmt@master + + clippy: + name: clippy + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: sfackler/actions/rustup@master + - run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT + id: rust-version + - uses: actions/cache@v3 + with: + path: ~/.cargo/registry/index + key: index-${{ runner.os }}-${{ github.run_number }} + restore-keys: | + index-${{ runner.os }}- + - run: cargo generate-lockfile + - uses: actions/cache@v3 + with: + path: ~/.cargo/registry/cache + key: registry-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }} + - run: cargo fetch + - uses: actions/cache@v3 + with: + path: target + key: clippy-target-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }}y + - run: cargo clippy --all --all-targets + + check-wasm32: + name: check-wasm32 + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: sfackler/actions/rustup@master + - run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT + id: rust-version + - run: rustup target add wasm32-unknown-unknown + - uses: actions/cache@v3 + with: + path: ~/.cargo/registry/index + key: index-${{ runner.os }}-${{ github.run_number }} + restore-keys: | + index-${{ runner.os }}- + - run: cargo generate-lockfile + - uses: actions/cache@v3 + with: + path: ~/.cargo/registry/cache + key: registry-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }} + - run: cargo fetch + - uses: actions/cache@v3 + with: + path: target + key: check-wasm32-target-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }} + - run: cargo check --target wasm32-unknown-unknown --manifest-path tokio-postgres/Cargo.toml --no-default-features --features js + env: + RUSTFLAGS: --cfg getrandom_backend="wasm_js" + + test: + name: test + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - run: docker compose up -d + - uses: sfackler/actions/rustup@master + with: + version: 1.81.0 + - run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT + id: rust-version + - uses: actions/cache@v3 + with: + path: ~/.cargo/registry/index + key: index-${{ runner.os }}-${{ github.run_number }} + restore-keys: | + index-${{ runner.os }}- + - run: cargo generate-lockfile + - uses: actions/cache@v3 + with: + path: ~/.cargo/registry/cache + key: registry-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }} + - run: cargo fetch + - uses: actions/cache@v3 + with: + path: target + key: test-target-${{ runner.os }}-${{ steps.rust-version.outputs.version }}-${{ hashFiles('Cargo.lock') }}y + - run: cargo test --all + - run: cargo test --manifest-path tokio-postgres/Cargo.toml --no-default-features + - run: cargo test --manifest-path tokio-postgres/Cargo.toml --all-features diff --git a/Cargo.toml b/Cargo.toml index d7a9186a8..16e3739dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,10 +2,15 @@ members = [ "codegen", "postgres", - "postgres-protocol", - "postgres-shared", - "postgres-openssl", + "postgres-derive", + "postgres-derive-test", "postgres-native-tls", + "postgres-openssl", + "postgres-protocol", + "postgres-types", "tokio-postgres", - "tokio-postgres-openssl", ] +resolver = "2" + +[profile.release] +debug = 2 diff --git a/LICENSE b/LICENSE deleted file mode 100644 index c7e577c00..000000000 --- a/LICENSE +++ /dev/null @@ -1,20 +0,0 @@ -The MIT License (MIT) - -Copyright (c) 2013-2017 Steven Fackler - -Permission is hereby granted, free of charge, to any person obtaining a copy of -this software and associated documentation files (the "Software"), to deal in -the Software without restriction, including without limitation the rights to -use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of -the Software, and to permit persons to whom the Software is furnished to do so, -subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS -FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR -COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER -IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN -CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/LICENSE-APACHE b/LICENSE-APACHE new file mode 100644 index 000000000..16fe87b06 --- /dev/null +++ b/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/LICENSE-MIT b/LICENSE-MIT new file mode 100644 index 000000000..71803aea1 --- /dev/null +++ b/LICENSE-MIT @@ -0,0 +1,22 @@ +The MIT License (MIT) + +Copyright (c) 2016 Steven Fackler + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + diff --git a/README.md b/README.md index f0090e510..b81a6716f 100644 --- a/README.md +++ b/README.md @@ -1,361 +1,46 @@ # Rust-Postgres -[![CircleCI](https://circleci.com/gh/sfackler/rust-postgres.svg?style=shield)](https://circleci.com/gh/sfackler/rust-postgres) [![Latest Version](https://img.shields.io/crates/v/postgres.svg)](https://crates.io/crates/postgres) -A native PostgreSQL driver for Rust. +PostgreSQL support for Rust. -[Documentation](https://docs.rs/postgres) - -You can integrate Rust-Postgres into your project through the [releases on crates.io](https://crates.io/crates/postgres): -```toml -[dependencies] -postgres = "0.15" -``` - -## Overview -Rust-Postgres is a pure-Rust frontend for the popular PostgreSQL database. -```rust -extern crate postgres; - -use postgres::{Connection, TlsMode}; - -struct Person { - id: i32, - name: String, - data: Option>, -} - -fn main() { - let conn = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None).unwrap(); - conn.execute("CREATE TABLE person ( - id SERIAL PRIMARY KEY, - name VARCHAR NOT NULL, - data BYTEA - )", &[]).unwrap(); - let me = Person { - id: 0, - name: "Steven".to_string(), - data: None, - }; - conn.execute("INSERT INTO person (name, data) VALUES ($1, $2)", - &[&me.name, &me.data]).unwrap(); - for row in &conn.query("SELECT id, name, data FROM person", &[]).unwrap() { - let person = Person { - id: row.get(0), - name: row.get(1), - data: row.get(2), - }; - println!("Found person {}: {}", person.id, person.name); - } -} -``` - -## Requirements -* **Rust** - Rust-Postgres is developed against the 1.18 release of Rust - available on http://www.rust-lang.org. It should also compile against more - recent releases. - -* **PostgreSQL 7.4 or later** - Rust-Postgres speaks version 3 of the - PostgreSQL protocol, which corresponds to versions 7.4 and later. If your - version of Postgres was compiled in the last decade, you should be okay. - -## Usage - -### Connecting -Connect to a Postgres server using the standard URI format: -```rust -let conn = Connection::connect("postgres://user:pass@host:port/database?arg1=val1&arg2=val2", - TlsMode::None)?; -``` -`pass` may be omitted if not needed. `port` defaults to `5432` and `database` -defaults to the value of `user` if not specified. The driver supports `trust`, -`password`, and `md5` authentication. - -Unix domain sockets can be used as well. The `host` portion of the URI should -be set to the absolute path to the directory containing the socket file. Since -`/` is a reserved character in URLs, the path should be URL encoded. If Postgres -stored its socket files in `/run/postgres`, the connection would then look like: -```rust -let conn = Connection::connect("postgres://postgres@%2Frun%2Fpostgres", TlsMode::None)?; -``` -Paths which contain non-UTF8 characters can be handled in a different manner; -see the documentation for details. - -### Querying -SQL statements can be executed with the `query` and `execute` methods. Both -methods take a query string as well as a slice of parameters to bind to the -query. The `i`th query parameter is specified in the query string by `$i`. Note -that query parameters are 1-indexed rather than the more common 0-indexing. - -`execute` returns the number of rows affected by the query (or 0 if not -applicable): -```rust -let updates = conn.execute("UPDATE foo SET bar = $1 WHERE baz = $2", &[&1i32, &"biz"])?; -println!("{} rows were updated", updates); -``` - -`query` returns an iterable object holding the rows returned from the database. -The fields in a row can be accessed either by their indices or their column -names, though access by index is more efficient. Unlike statement parameters, -result columns are zero-indexed. -```rust -for row in &conn.query("SELECT bar, baz FROM foo WHERE buz = $1", &[&1i32])? { - let bar: i32 = row.get(0); - let baz: String = row.get("baz"); - println!("bar: {}, baz: {}", bar, baz); -} -``` - -### Statement Preparation -If the same statement will be executed repeatedly (possibly with different -parameters), explicitly preparing it can improve performance: - -```rust -let stmt = conn.prepare("UPDATE foo SET bar = $1 WHERE baz = $2")?; -for (bar, baz) in updates { - stmt.execute(&[bar, baz])?; -} -``` +## postgres [![Latest Version](https://img.shields.io/crates/v/postgres.svg)](https://crates.io/crates/postgres) -### Transactions -The `transaction` method will start a new transaction. It returns a -`Transaction` object which has the functionality of a -`Connection` as well as methods to control the result of the -transaction: -```rust -let trans = conn.transaction()?; - -trans.execute(...)?; -let stmt = trans.prepare(...)?; -// ... - -trans.commit()?; -``` -The transaction will be active until the `Transaction` object falls out of -scope. A transaction will roll back by default. Nested transactions are -supported via savepoints. - -### Type Correspondence -Rust-Postgres enforces a strict correspondence between Rust types and Postgres -types. The driver currently supports the following conversions: - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
Rust TypePostgres Type
boolBOOL
i8"char"
i16SMALLINT, SMALLSERIAL
i32INT, SERIAL
u32OID
i64BIGINT, BIGSERIAL
f32REAL
f64DOUBLE PRECISION
str/StringVARCHAR, CHAR(n), TEXT, CITEXT, NAME
[u8]/Vec<u8>BYTEA
- serialize::json::Json - and - serde_json::Value - (optional) - JSON, JSONB
- time::Timespec - and - chrono::NaiveDateTime - (optional) - TIMESTAMP
- time::Timespec, - chrono::DateTime<Utc>, - chrono::DateTime<Local>, - and - chrono::DateTime<FixedOffset> - (optional) - TIMESTAMP WITH TIME ZONE
- chrono::NaiveDate - (optional) - DATE
- chrono::NaiveTime - (optional) - TIME
- uuid::Uuid - (optional) - UUID
- bit_vec::BitVec - (optional) - BIT, VARBIT
HashMap<String, Option<String>>HSTORE
- eui48::MacAddress - (optional) - MACADDR
- geo::Point<f64> - (optional) - POINT
- geo::Bbox<f64> - (optional) - BOX
- geo::LineString<f64> - (optional) - PATH
- -`Option` implements `FromSql` where `T: FromSql` and `ToSql` where `T: -ToSql`, and represents nullable Postgres values. - -`&[T]` and `Vec` implement `ToSql` where `T: ToSql`, and `Vec` -additionally implements `FromSql` where `T: FromSql`, which represent -one-dimensional Postgres arrays. - -More conversions can be defined by implementing the `ToSql` and `FromSql` -traits. - -The [postgres-derive](https://github.com/sfackler/rust-postgres-derive) -crate will synthesize `ToSql` and `FromSql` implementations for enum, domain, -and composite Postgres types. - -Full support for array types is located in the -[postgres-array](https://github.com/sfackler/rust-postgres-array) crate. - -Support for range types is located in the -[postgres-range](https://github.com/sfackler/rust-postgres-range) crate. - -Support for the large object API is located in the -[postgres-large-object](https://github.com/sfackler/rust-postgres-large-object) -crate. - -## Optional features - -### UUID type - -[UUID](http://www.postgresql.org/docs/9.4/static/datatype-uuid.html) support is -provided optionally by the `with-uuid` feature, which adds `ToSql` and `FromSql` -implementations for `uuid`'s `Uuid` type. Requires `uuid` version 0.5. +[Documentation](https://docs.rs/postgres) -### JSON/JSONB types +A native, synchronous PostgreSQL client. -[JSON and JSONB](http://www.postgresql.org/docs/9.4/static/datatype-json.html) -support is provided optionally by the `with-rustc-serialize` feature, which adds -`ToSql` and `FromSql` implementations for `rustc-serialize`'s `Json` type, and -the `with-serde_json` feature, which adds implementations for `serde_json`'s -`Value` type. Requires `serde_json` version 1.0, `rustc-serialize` version 0.3. +## tokio-postgres [![Latest Version](https://img.shields.io/crates/v/tokio-postgres.svg)](https://crates.io/crates/tokio-postgres) -### TIMESTAMP/TIMESTAMPTZ/DATE/TIME types +[Documentation](https://docs.rs/tokio-postgres) -[Date and Time](http://www.postgresql.org/docs/9.1/static/datatype-datetime.html) -support is provided optionally by the `with-time` feature, which adds `ToSql` -and `FromSql` implementations for `time`'s `Timespec` type, or the `with-chrono` -feature, which adds `ToSql` and `FromSql` implementations for `chrono`'s -`DateTime`, `NaiveDateTime`, `NaiveDate` and `NaiveTime` types. Requires `time` version 0.1.14. +A native, asynchronous PostgreSQL client. -### BIT/VARBIT types +## postgres-types [![Latest Version](https://img.shields.io/crates/v/postgres-types.svg)](https://crates.io/crates/postgres-types) -[BIT and VARBIT](http://www.postgresql.org/docs/9.4/static/datatype-bit.html) -support is provided optionally by the `with-bit-vec` feature, which adds `ToSql` -and `FromSql` implementations for `bit-vec`'s `BitVec` type. Requires `bit-vec` version 0.4. +[Documentation](https://docs.rs/postgres-types) -### MACADDR type +Conversions between Rust and Postgres types. -[MACADDR](http://www.postgresql.org/docs/9.4/static/datatype-net-types.html#DATATYPE-MACADDR) -support is provided optionally by the `with-eui48` feature, which adds `ToSql` -and `FromSql` implementations for `eui48`'s `MacAddress` type. Requires `eui48` version 0.3. +## postgres-native-tls [![Latest Version](https://img.shields.io/crates/v/postgres-native-tls.svg)](https://crates.io/crates/postgres-native-tls) -### POINT type +[Documentation](https://docs.rs/postgres-native-tls) -[POINT](https://www.postgresql.org/docs/9.4/static/datatype-geometric.html#AEN6799) -support is provided optionally by the `with-geo` feature, which adds `ToSql` and `FromSql` implementations for `geo`'s `Point` type. Requires `geo` version 0.4. +TLS support for postgres and tokio-postgres via native-tls. -### BOX type +## postgres-openssl [![Latest Version](https://img.shields.io/crates/v/postgres-openssl.svg)](https://crates.io/crates/postgres-openssl) -[BOX](https://www.postgresql.org/docs/9.4/static/datatype-geometric.html#AEN6883) -support is provided optionally by the `with-geo` feature, which adds `ToSql` and `FromSql` implementations for `geo`'s `Bbox` type. Requires `geo` version 0.4. +[Documentation](https://docs.rs/postgres-openssl) -### PATH type +TLS support for postgres and tokio-postgres via openssl. -[PATH](https://www.postgresql.org/docs/9.4/static/datatype-geometric.html#AEN6912) -support is provided optionally by the `with-geo` feature, which adds `ToSql` and `FromSql` implementations for `geo`'s `LineString` type. -Paths converted from LineString are always treated as "open" paths. Requires `geo` version 0.4. Use the -[pclose](https://www.postgresql.org/docs/8.2/static/functions-geometry.html#FUNCTIONS-GEOMETRY-FUNC-TABLE) -geometric function to insert a closed path. +# Running test suite -## See Also +The test suite requires postgres to be running in the correct configuration. The easiest way to do this is with docker: -- [r2d2-postgres](https://github.com/sfackler/r2d2-postgres) for connection pool support. +1. Install `docker` and `docker-compose`. + 1. On ubuntu: `sudo apt install docker.io docker-compose`. +1. Make sure your user has permissions for docker. + 1. On ubuntu: ``sudo usermod -aG docker $USER`` +1. Change to top-level directory of `rust-postgres` repo. +1. Run `docker-compose up -d`. +1. Run `cargo test`. +1. Run `docker-compose stop`. diff --git a/THIRD_PARTY b/THIRD_PARTY index 80336ea0f..05e5ac435 100644 --- a/THIRD_PARTY +++ b/THIRD_PARTY @@ -27,33 +27,3 @@ BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE SOFTWARE PROVIDED HEREUNDER IS ON AN "AS IS" BASIS, AND THE UNIVERSITY OF CALIFORNIA HAS NO OBLIGATIONS TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. - -------------------------------------------------------------------------------- - -* src/url.rs has been copied from Rust - -Copyright (c) 2014 The Rust Project Developers - -Permission is hereby granted, free of charge, to any -person obtaining a copy of this software and associated -documentation files (the "Software"), to deal in the -Software without restriction, including without -limitation the rights to use, copy, modify, merge, -publish, distribute, sublicense, and/or sell copies of -the Software, and to permit persons to whom the Software -is furnished to do so, subject to the following -conditions: - -The above copyright notice and this permission notice -shall be included in all copies or substantial portions -of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF -ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED -TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A -PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT -SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. diff --git a/codegen/Cargo.toml b/codegen/Cargo.toml index 2ce54a7c0..bbe6b789c 100644 --- a/codegen/Cargo.toml +++ b/codegen/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" authors = ["Steven Fackler "] [dependencies] -phf_codegen = "=0.7.22" -regex = "0.1" +phf_codegen = "0.11" +regex = "1.0" marksman_escape = "0.1" -linked-hash-map = "0.4" +linked-hash-map = "0.5" diff --git a/codegen/src/errcodes.txt b/codegen/src/errcodes.txt index 4f3547176..62418a051 100644 --- a/codegen/src/errcodes.txt +++ b/codegen/src/errcodes.txt @@ -2,7 +2,7 @@ # errcodes.txt # PostgreSQL error codes # -# Copyright (c) 2003-2017, PostgreSQL Global Development Group +# Copyright (c) 2003-2022, PostgreSQL Global Development Group # # This list serves as the basis for generating source files containing error # codes. It is kept in a common format to make sure all these source files have @@ -18,7 +18,7 @@ # src/pl/tcl/pltclerrcodes.h # the same, for PL/Tcl # -# doc/src/sgml/errcodes-list.sgml +# doc/src/sgml/errcodes-table.sgml # a SGML table of error codes for inclusion in the documentation # # The format of this file is one error code per line, with the following @@ -177,6 +177,7 @@ Section: Class 22 - Data Exception 22P06 E ERRCODE_NONSTANDARD_USE_OF_ESCAPE_CHARACTER nonstandard_use_of_escape_character 22010 E ERRCODE_INVALID_INDICATOR_PARAMETER_VALUE invalid_indicator_parameter_value 22023 E ERRCODE_INVALID_PARAMETER_VALUE invalid_parameter_value +22013 E ERRCODE_INVALID_PRECEDING_OR_FOLLOWING_SIZE invalid_preceding_or_following_size 2201B E ERRCODE_INVALID_REGULAR_EXPRESSION invalid_regular_expression 2201W E ERRCODE_INVALID_ROW_COUNT_IN_LIMIT_CLAUSE invalid_row_count_in_limit_clause 2201X E ERRCODE_INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE invalid_row_count_in_result_offset_clause @@ -205,6 +206,23 @@ Section: Class 22 - Data Exception 2200N E ERRCODE_INVALID_XML_CONTENT invalid_xml_content 2200S E ERRCODE_INVALID_XML_COMMENT invalid_xml_comment 2200T E ERRCODE_INVALID_XML_PROCESSING_INSTRUCTION invalid_xml_processing_instruction +22030 E ERRCODE_DUPLICATE_JSON_OBJECT_KEY_VALUE duplicate_json_object_key_value +22031 E ERRCODE_INVALID_ARGUMENT_FOR_SQL_JSON_DATETIME_FUNCTION invalid_argument_for_sql_json_datetime_function +22032 E ERRCODE_INVALID_JSON_TEXT invalid_json_text +22033 E ERRCODE_INVALID_SQL_JSON_SUBSCRIPT invalid_sql_json_subscript +22034 E ERRCODE_MORE_THAN_ONE_SQL_JSON_ITEM more_than_one_sql_json_item +22035 E ERRCODE_NO_SQL_JSON_ITEM no_sql_json_item +22036 E ERRCODE_NON_NUMERIC_SQL_JSON_ITEM non_numeric_sql_json_item +22037 E ERRCODE_NON_UNIQUE_KEYS_IN_A_JSON_OBJECT non_unique_keys_in_a_json_object +22038 E ERRCODE_SINGLETON_SQL_JSON_ITEM_REQUIRED singleton_sql_json_item_required +22039 E ERRCODE_SQL_JSON_ARRAY_NOT_FOUND sql_json_array_not_found +2203A E ERRCODE_SQL_JSON_MEMBER_NOT_FOUND sql_json_member_not_found +2203B E ERRCODE_SQL_JSON_NUMBER_NOT_FOUND sql_json_number_not_found +2203C E ERRCODE_SQL_JSON_OBJECT_NOT_FOUND sql_json_object_not_found +2203D E ERRCODE_TOO_MANY_JSON_ARRAY_ELEMENTS too_many_json_array_elements +2203E E ERRCODE_TOO_MANY_JSON_OBJECT_MEMBERS too_many_json_object_members +2203F E ERRCODE_SQL_JSON_SCALAR_REQUIRED sql_json_scalar_required +2203G E ERRCODE_SQL_JSON_ITEM_CANNOT_BE_CAST_TO_TARGET_TYPE sql_json_item_cannot_be_cast_to_target_type Section: Class 23 - Integrity Constraint Violation @@ -411,6 +429,7 @@ Section: Class 57 - Operator Intervention 57P02 E ERRCODE_CRASH_SHUTDOWN crash_shutdown 57P03 E ERRCODE_CANNOT_CONNECT_NOW cannot_connect_now 57P04 E ERRCODE_DATABASE_DROPPED database_dropped +57P05 E ERRCODE_IDLE_SESSION_TIMEOUT idle_session_timeout Section: Class 58 - System Error (errors external to PostgreSQL itself) diff --git a/codegen/src/main.rs b/codegen/src/main.rs index 9aa9a9744..d0a74f2dd 100644 --- a/codegen/src/main.rs +++ b/codegen/src/main.rs @@ -1,17 +1,17 @@ +#![warn(clippy::all)] +#![allow(clippy::write_with_newline)] + extern crate linked_hash_map; extern crate marksman_escape; extern crate phf_codegen; extern crate regex; -use std::path::Path; - mod sqlstate; mod type_gen; fn main() { - let path = Path::new("../postgres-shared/src"); - sqlstate::build(path); - type_gen::build(path); + sqlstate::build(); + type_gen::build(); } fn snake_to_camel(s: &str) -> String { diff --git a/codegen/src/pg_range.dat b/codegen/src/pg_range.dat new file mode 100644 index 000000000..74d6de0cf --- /dev/null +++ b/codegen/src/pg_range.dat @@ -0,0 +1,34 @@ +#---------------------------------------------------------------------- +# +# pg_range.dat +# Initial contents of the pg_range system catalog. +# +# Portions Copyright (c) 1996-2022, PostgreSQL Global Development Group +# Portions Copyright (c) 1994, Regents of the University of California +# +# src/include/catalog/pg_range.dat +# +#---------------------------------------------------------------------- + +[ + +{ rngtypid => 'int4range', rngsubtype => 'int4', + rngmultitypid => 'int4multirange', rngsubopc => 'btree/int4_ops', + rngcanonical => 'int4range_canonical', rngsubdiff => 'int4range_subdiff' }, +{ rngtypid => 'numrange', rngsubtype => 'numeric', + rngmultitypid => 'nummultirange', rngsubopc => 'btree/numeric_ops', + rngcanonical => '-', rngsubdiff => 'numrange_subdiff' }, +{ rngtypid => 'tsrange', rngsubtype => 'timestamp', + rngmultitypid => 'tsmultirange', rngsubopc => 'btree/timestamp_ops', + rngcanonical => '-', rngsubdiff => 'tsrange_subdiff' }, +{ rngtypid => 'tstzrange', rngsubtype => 'timestamptz', + rngmultitypid => 'tstzmultirange', rngsubopc => 'btree/timestamptz_ops', + rngcanonical => '-', rngsubdiff => 'tstzrange_subdiff' }, +{ rngtypid => 'daterange', rngsubtype => 'date', + rngmultitypid => 'datemultirange', rngsubopc => 'btree/date_ops', + rngcanonical => 'daterange_canonical', rngsubdiff => 'daterange_subdiff' }, +{ rngtypid => 'int8range', rngsubtype => 'int8', + rngmultitypid => 'int8multirange', rngsubopc => 'btree/int8_ops', + rngcanonical => 'int8range_canonical', rngsubdiff => 'int8range_subdiff' }, + +] diff --git a/codegen/src/pg_range.h b/codegen/src/pg_range.h deleted file mode 100644 index 4ed57fe2e..000000000 --- a/codegen/src/pg_range.h +++ /dev/null @@ -1,85 +0,0 @@ -/*------------------------------------------------------------------------- - * - * pg_range.h - * definition of the system "range" relation (pg_range) - * along with the relation's initial contents. - * - * - * Portions Copyright (c) 1996-2017, PostgreSQL Global Development Group - * Portions Copyright (c) 1994, Regents of the University of California - * - * src/include/catalog/pg_range.h - * - * NOTES - * the genbki.pl script reads this file and generates .bki - * information from the DATA() statements. - * - * XXX do NOT break up DATA() statements into multiple lines! - * the scripts are not as smart as you might think... - * - *------------------------------------------------------------------------- - */ -#ifndef PG_RANGE_H -#define PG_RANGE_H - -#include "catalog/genbki.h" - -/* ---------------- - * pg_range definition. cpp turns this into - * typedef struct FormData_pg_range - * ---------------- - */ -#define RangeRelationId 3541 - -CATALOG(pg_range,3541) BKI_WITHOUT_OIDS -{ - Oid rngtypid; /* OID of owning range type */ - Oid rngsubtype; /* OID of range's element type (subtype) */ - Oid rngcollation; /* collation for this range type, or 0 */ - Oid rngsubopc; /* subtype's btree opclass */ - regproc rngcanonical; /* canonicalize range, or 0 */ - regproc rngsubdiff; /* subtype difference as a float8, or 0 */ -} FormData_pg_range; - -/* ---------------- - * Form_pg_range corresponds to a pointer to a tuple with - * the format of pg_range relation. - * ---------------- - */ -typedef FormData_pg_range *Form_pg_range; - -/* ---------------- - * compiler constants for pg_range - * ---------------- - */ -#define Natts_pg_range 6 -#define Anum_pg_range_rngtypid 1 -#define Anum_pg_range_rngsubtype 2 -#define Anum_pg_range_rngcollation 3 -#define Anum_pg_range_rngsubopc 4 -#define Anum_pg_range_rngcanonical 5 -#define Anum_pg_range_rngsubdiff 6 - - -/* ---------------- - * initial contents of pg_range - * ---------------- - */ -DATA(insert ( 3904 23 0 1978 int4range_canonical int4range_subdiff)); -DATA(insert ( 3906 1700 0 3125 - numrange_subdiff)); -DATA(insert ( 3908 1114 0 3128 - tsrange_subdiff)); -DATA(insert ( 3910 1184 0 3127 - tstzrange_subdiff)); -DATA(insert ( 3912 1082 0 3122 daterange_canonical daterange_subdiff)); -DATA(insert ( 3926 20 0 3124 int8range_canonical int8range_subdiff)); - - -/* - * prototypes for functions in pg_range.c - */ - -extern void RangeCreate(Oid rangeTypeOid, Oid rangeSubType, Oid rangeCollation, - Oid rangeSubOpclass, RegProcedure rangeCanonical, - RegProcedure rangeSubDiff); -extern void RangeDelete(Oid rangeTypeOid); - -#endif /* PG_RANGE_H */ diff --git a/codegen/src/pg_type.dat b/codegen/src/pg_type.dat new file mode 100644 index 000000000..df4587946 --- /dev/null +++ b/codegen/src/pg_type.dat @@ -0,0 +1,695 @@ +#---------------------------------------------------------------------- +# +# pg_type.dat +# Initial contents of the pg_type system catalog. +# +# Portions Copyright (c) 1996-2022, PostgreSQL Global Development Group +# Portions Copyright (c) 1994, Regents of the University of California +# +# src/include/catalog/pg_type.dat +# +#---------------------------------------------------------------------- + +[ + +# For types used in the system catalogs, make sure the values here match +# TypInfo[] in bootstrap.c. + +# OID symbol macro names for pg_type OIDs are not specified here because +# they are generated by genbki.pl according to the following rule: +# foo_bar -> FOO_BAROID +# _foo_bar -> FOO_BARARRAYOID + +# To autogenerate an array type, add 'array_type_oid => 'nnnn' to the element +# type, which will instruct genbki.pl to generate a BKI entry for it. +# In a few cases, the array type's properties don't match the normal pattern +# so it can't be autogenerated; in such cases do not write array_type_oid. + +# Once upon a time these entries were ordered by OID. Lately it's often +# been the custom to insert new entries adjacent to related older entries. +# Try to do one or the other though, don't just insert entries at random. + +# OIDS 1 - 99 + +{ oid => '16', array_type_oid => '1000', + descr => 'boolean, \'true\'/\'false\'', + typname => 'bool', typlen => '1', typbyval => 't', typcategory => 'B', + typispreferred => 't', typinput => 'boolin', typoutput => 'boolout', + typreceive => 'boolrecv', typsend => 'boolsend', typalign => 'c' }, +{ oid => '17', array_type_oid => '1001', + descr => 'variable-length string, binary values escaped', + typname => 'bytea', typlen => '-1', typbyval => 'f', typcategory => 'U', + typinput => 'byteain', typoutput => 'byteaout', typreceive => 'bytearecv', + typsend => 'byteasend', typalign => 'i', typstorage => 'x' }, +{ oid => '18', array_type_oid => '1002', descr => 'single character', + typname => 'char', typlen => '1', typbyval => 't', typcategory => 'Z', + typinput => 'charin', typoutput => 'charout', typreceive => 'charrecv', + typsend => 'charsend', typalign => 'c' }, +{ oid => '19', array_type_oid => '1003', + descr => '63-byte type for storing system identifiers', + typname => 'name', typlen => 'NAMEDATALEN', typbyval => 'f', + typcategory => 'S', typsubscript => 'raw_array_subscript_handler', + typelem => 'char', typinput => 'namein', typoutput => 'nameout', + typreceive => 'namerecv', typsend => 'namesend', typalign => 'c', + typcollation => 'C' }, +{ oid => '20', array_type_oid => '1016', + descr => '~18 digit integer, 8-byte storage', + typname => 'int8', typlen => '8', typbyval => 'FLOAT8PASSBYVAL', + typcategory => 'N', typinput => 'int8in', typoutput => 'int8out', + typreceive => 'int8recv', typsend => 'int8send', typalign => 'd' }, +{ oid => '21', array_type_oid => '1005', + descr => '-32 thousand to 32 thousand, 2-byte storage', + typname => 'int2', typlen => '2', typbyval => 't', typcategory => 'N', + typinput => 'int2in', typoutput => 'int2out', typreceive => 'int2recv', + typsend => 'int2send', typalign => 's' }, +{ oid => '22', array_type_oid => '1006', + descr => 'array of int2, used in system tables', + typname => 'int2vector', typlen => '-1', typbyval => 'f', typcategory => 'A', + typsubscript => 'array_subscript_handler', typelem => 'int2', + typinput => 'int2vectorin', typoutput => 'int2vectorout', + typreceive => 'int2vectorrecv', typsend => 'int2vectorsend', + typalign => 'i' }, +{ oid => '23', array_type_oid => '1007', + descr => '-2 billion to 2 billion integer, 4-byte storage', + typname => 'int4', typlen => '4', typbyval => 't', typcategory => 'N', + typinput => 'int4in', typoutput => 'int4out', typreceive => 'int4recv', + typsend => 'int4send', typalign => 'i' }, +{ oid => '24', array_type_oid => '1008', descr => 'registered procedure', + typname => 'regproc', typlen => '4', typbyval => 't', typcategory => 'N', + typinput => 'regprocin', typoutput => 'regprocout', + typreceive => 'regprocrecv', typsend => 'regprocsend', typalign => 'i' }, +{ oid => '25', array_type_oid => '1009', + descr => 'variable-length string, no limit specified', + typname => 'text', typlen => '-1', typbyval => 'f', typcategory => 'S', + typispreferred => 't', typinput => 'textin', typoutput => 'textout', + typreceive => 'textrecv', typsend => 'textsend', typalign => 'i', + typstorage => 'x', typcollation => 'default' }, +{ oid => '26', array_type_oid => '1028', + descr => 'object identifier(oid), maximum 4 billion', + typname => 'oid', typlen => '4', typbyval => 't', typcategory => 'N', + typispreferred => 't', typinput => 'oidin', typoutput => 'oidout', + typreceive => 'oidrecv', typsend => 'oidsend', typalign => 'i' }, +{ oid => '27', array_type_oid => '1010', + descr => '(block, offset), physical location of tuple', + typname => 'tid', typlen => '6', typbyval => 'f', typcategory => 'U', + typinput => 'tidin', typoutput => 'tidout', typreceive => 'tidrecv', + typsend => 'tidsend', typalign => 's' }, +{ oid => '28', array_type_oid => '1011', descr => 'transaction id', + typname => 'xid', typlen => '4', typbyval => 't', typcategory => 'U', + typinput => 'xidin', typoutput => 'xidout', typreceive => 'xidrecv', + typsend => 'xidsend', typalign => 'i' }, +{ oid => '29', array_type_oid => '1012', + descr => 'command identifier type, sequence in transaction id', + typname => 'cid', typlen => '4', typbyval => 't', typcategory => 'U', + typinput => 'cidin', typoutput => 'cidout', typreceive => 'cidrecv', + typsend => 'cidsend', typalign => 'i' }, +{ oid => '30', array_type_oid => '1013', + descr => 'array of oids, used in system tables', + typname => 'oidvector', typlen => '-1', typbyval => 'f', typcategory => 'A', + typsubscript => 'array_subscript_handler', typelem => 'oid', + typinput => 'oidvectorin', typoutput => 'oidvectorout', + typreceive => 'oidvectorrecv', typsend => 'oidvectorsend', typalign => 'i' }, + +# hand-built rowtype entries for bootstrapped catalogs +# NB: OIDs assigned here must match the BKI_ROWTYPE_OID declarations +{ oid => '71', array_type_oid => '210', + typname => 'pg_type', typlen => '-1', typbyval => 'f', typtype => 'c', + typcategory => 'C', typrelid => 'pg_type', typinput => 'record_in', + typoutput => 'record_out', typreceive => 'record_recv', + typsend => 'record_send', typalign => 'd', typstorage => 'x' }, +{ oid => '75', array_type_oid => '270', + typname => 'pg_attribute', typlen => '-1', typbyval => 'f', typtype => 'c', + typcategory => 'C', typrelid => 'pg_attribute', typinput => 'record_in', + typoutput => 'record_out', typreceive => 'record_recv', + typsend => 'record_send', typalign => 'd', typstorage => 'x' }, +{ oid => '81', array_type_oid => '272', + typname => 'pg_proc', typlen => '-1', typbyval => 'f', typtype => 'c', + typcategory => 'C', typrelid => 'pg_proc', typinput => 'record_in', + typoutput => 'record_out', typreceive => 'record_recv', + typsend => 'record_send', typalign => 'd', typstorage => 'x' }, +{ oid => '83', array_type_oid => '273', + typname => 'pg_class', typlen => '-1', typbyval => 'f', typtype => 'c', + typcategory => 'C', typrelid => 'pg_class', typinput => 'record_in', + typoutput => 'record_out', typreceive => 'record_recv', + typsend => 'record_send', typalign => 'd', typstorage => 'x' }, + +# OIDS 100 - 199 + +{ oid => '114', array_type_oid => '199', descr => 'JSON stored as text', + typname => 'json', typlen => '-1', typbyval => 'f', typcategory => 'U', + typinput => 'json_in', typoutput => 'json_out', typreceive => 'json_recv', + typsend => 'json_send', typalign => 'i', typstorage => 'x' }, +{ oid => '142', array_type_oid => '143', descr => 'XML content', + typname => 'xml', typlen => '-1', typbyval => 'f', typcategory => 'U', + typinput => 'xml_in', typoutput => 'xml_out', typreceive => 'xml_recv', + typsend => 'xml_send', typalign => 'i', typstorage => 'x' }, +{ oid => '194', descr => 'string representing an internal node tree', + typname => 'pg_node_tree', typlen => '-1', typbyval => 'f', + typcategory => 'Z', typinput => 'pg_node_tree_in', + typoutput => 'pg_node_tree_out', typreceive => 'pg_node_tree_recv', + typsend => 'pg_node_tree_send', typalign => 'i', typstorage => 'x', + typcollation => 'default' }, +{ oid => '3361', descr => 'multivariate ndistinct coefficients', + typname => 'pg_ndistinct', typlen => '-1', typbyval => 'f', + typcategory => 'Z', typinput => 'pg_ndistinct_in', + typoutput => 'pg_ndistinct_out', typreceive => 'pg_ndistinct_recv', + typsend => 'pg_ndistinct_send', typalign => 'i', typstorage => 'x', + typcollation => 'default' }, +{ oid => '3402', descr => 'multivariate dependencies', + typname => 'pg_dependencies', typlen => '-1', typbyval => 'f', + typcategory => 'Z', typinput => 'pg_dependencies_in', + typoutput => 'pg_dependencies_out', typreceive => 'pg_dependencies_recv', + typsend => 'pg_dependencies_send', typalign => 'i', typstorage => 'x', + typcollation => 'default' }, +{ oid => '5017', descr => 'multivariate MCV list', + typname => 'pg_mcv_list', typlen => '-1', typbyval => 'f', typcategory => 'Z', + typinput => 'pg_mcv_list_in', typoutput => 'pg_mcv_list_out', + typreceive => 'pg_mcv_list_recv', typsend => 'pg_mcv_list_send', + typalign => 'i', typstorage => 'x', typcollation => 'default' }, +{ oid => '32', descr => 'internal type for passing CollectedCommand', + typname => 'pg_ddl_command', typlen => 'SIZEOF_POINTER', typbyval => 't', + typtype => 'p', typcategory => 'P', typinput => 'pg_ddl_command_in', + typoutput => 'pg_ddl_command_out', typreceive => 'pg_ddl_command_recv', + typsend => 'pg_ddl_command_send', typalign => 'ALIGNOF_POINTER' }, +{ oid => '5069', array_type_oid => '271', descr => 'full transaction id', + typname => 'xid8', typlen => '8', typbyval => 'FLOAT8PASSBYVAL', + typcategory => 'U', typinput => 'xid8in', typoutput => 'xid8out', + typreceive => 'xid8recv', typsend => 'xid8send', typalign => 'd' }, + +# OIDS 600 - 699 + +{ oid => '600', array_type_oid => '1017', + descr => 'geometric point \'(x, y)\'', + typname => 'point', typlen => '16', typbyval => 'f', typcategory => 'G', + typsubscript => 'raw_array_subscript_handler', typelem => 'float8', + typinput => 'point_in', typoutput => 'point_out', typreceive => 'point_recv', + typsend => 'point_send', typalign => 'd' }, +{ oid => '601', array_type_oid => '1018', + descr => 'geometric line segment \'(pt1,pt2)\'', + typname => 'lseg', typlen => '32', typbyval => 'f', typcategory => 'G', + typsubscript => 'raw_array_subscript_handler', typelem => 'point', + typinput => 'lseg_in', typoutput => 'lseg_out', typreceive => 'lseg_recv', + typsend => 'lseg_send', typalign => 'd' }, +{ oid => '602', array_type_oid => '1019', + descr => 'geometric path \'(pt1,...)\'', + typname => 'path', typlen => '-1', typbyval => 'f', typcategory => 'G', + typinput => 'path_in', typoutput => 'path_out', typreceive => 'path_recv', + typsend => 'path_send', typalign => 'd', typstorage => 'x' }, +{ oid => '603', array_type_oid => '1020', + descr => 'geometric box \'(lower left,upper right)\'', + typname => 'box', typlen => '32', typbyval => 'f', typcategory => 'G', + typdelim => ';', typsubscript => 'raw_array_subscript_handler', + typelem => 'point', typinput => 'box_in', typoutput => 'box_out', + typreceive => 'box_recv', typsend => 'box_send', typalign => 'd' }, +{ oid => '604', array_type_oid => '1027', + descr => 'geometric polygon \'(pt1,...)\'', + typname => 'polygon', typlen => '-1', typbyval => 'f', typcategory => 'G', + typinput => 'poly_in', typoutput => 'poly_out', typreceive => 'poly_recv', + typsend => 'poly_send', typalign => 'd', typstorage => 'x' }, +{ oid => '628', array_type_oid => '629', descr => 'geometric line', + typname => 'line', typlen => '24', typbyval => 'f', typcategory => 'G', + typsubscript => 'raw_array_subscript_handler', typelem => 'float8', + typinput => 'line_in', typoutput => 'line_out', typreceive => 'line_recv', + typsend => 'line_send', typalign => 'd' }, + +# OIDS 700 - 799 + +{ oid => '700', array_type_oid => '1021', + descr => 'single-precision floating point number, 4-byte storage', + typname => 'float4', typlen => '4', typbyval => 't', typcategory => 'N', + typinput => 'float4in', typoutput => 'float4out', typreceive => 'float4recv', + typsend => 'float4send', typalign => 'i' }, +{ oid => '701', array_type_oid => '1022', + descr => 'double-precision floating point number, 8-byte storage', + typname => 'float8', typlen => '8', typbyval => 'FLOAT8PASSBYVAL', + typcategory => 'N', typispreferred => 't', typinput => 'float8in', + typoutput => 'float8out', typreceive => 'float8recv', typsend => 'float8send', + typalign => 'd' }, +{ oid => '705', descr => 'pseudo-type representing an undetermined type', + typname => 'unknown', typlen => '-2', typbyval => 'f', typtype => 'p', + typcategory => 'X', typinput => 'unknownin', typoutput => 'unknownout', + typreceive => 'unknownrecv', typsend => 'unknownsend', typalign => 'c' }, +{ oid => '718', array_type_oid => '719', + descr => 'geometric circle \'(center,radius)\'', + typname => 'circle', typlen => '24', typbyval => 'f', typcategory => 'G', + typinput => 'circle_in', typoutput => 'circle_out', + typreceive => 'circle_recv', typsend => 'circle_send', typalign => 'd' }, +{ oid => '790', array_type_oid => '791', + descr => 'monetary amounts, $d,ddd.cc', + typname => 'money', typlen => '8', typbyval => 'FLOAT8PASSBYVAL', + typcategory => 'N', typinput => 'cash_in', typoutput => 'cash_out', + typreceive => 'cash_recv', typsend => 'cash_send', typalign => 'd' }, + +# OIDS 800 - 899 + +{ oid => '829', array_type_oid => '1040', + descr => 'XX:XX:XX:XX:XX:XX, MAC address', + typname => 'macaddr', typlen => '6', typbyval => 'f', typcategory => 'U', + typinput => 'macaddr_in', typoutput => 'macaddr_out', + typreceive => 'macaddr_recv', typsend => 'macaddr_send', typalign => 'i' }, +{ oid => '869', array_type_oid => '1041', + descr => 'IP address/netmask, host address, netmask optional', + typname => 'inet', typlen => '-1', typbyval => 'f', typcategory => 'I', + typispreferred => 't', typinput => 'inet_in', typoutput => 'inet_out', + typreceive => 'inet_recv', typsend => 'inet_send', typalign => 'i', + typstorage => 'm' }, +{ oid => '650', array_type_oid => '651', + descr => 'network IP address/netmask, network address', + typname => 'cidr', typlen => '-1', typbyval => 'f', typcategory => 'I', + typinput => 'cidr_in', typoutput => 'cidr_out', typreceive => 'cidr_recv', + typsend => 'cidr_send', typalign => 'i', typstorage => 'm' }, +{ oid => '774', array_type_oid => '775', + descr => 'XX:XX:XX:XX:XX:XX:XX:XX, MAC address', + typname => 'macaddr8', typlen => '8', typbyval => 'f', typcategory => 'U', + typinput => 'macaddr8_in', typoutput => 'macaddr8_out', + typreceive => 'macaddr8_recv', typsend => 'macaddr8_send', typalign => 'i' }, + +# OIDS 1000 - 1099 + +{ oid => '1033', array_type_oid => '1034', descr => 'access control list', + typname => 'aclitem', typlen => '12', typbyval => 'f', typcategory => 'U', + typinput => 'aclitemin', typoutput => 'aclitemout', typreceive => '-', + typsend => '-', typalign => 'i' }, +{ oid => '1042', array_type_oid => '1014', + descr => 'char(length), blank-padded string, fixed storage length', + typname => 'bpchar', typlen => '-1', typbyval => 'f', typcategory => 'S', + typinput => 'bpcharin', typoutput => 'bpcharout', typreceive => 'bpcharrecv', + typsend => 'bpcharsend', typmodin => 'bpchartypmodin', + typmodout => 'bpchartypmodout', typalign => 'i', typstorage => 'x', + typcollation => 'default' }, +{ oid => '1043', array_type_oid => '1015', + descr => 'varchar(length), non-blank-padded string, variable storage length', + typname => 'varchar', typlen => '-1', typbyval => 'f', typcategory => 'S', + typinput => 'varcharin', typoutput => 'varcharout', + typreceive => 'varcharrecv', typsend => 'varcharsend', + typmodin => 'varchartypmodin', typmodout => 'varchartypmodout', + typalign => 'i', typstorage => 'x', typcollation => 'default' }, +{ oid => '1082', array_type_oid => '1182', descr => 'date', + typname => 'date', typlen => '4', typbyval => 't', typcategory => 'D', + typinput => 'date_in', typoutput => 'date_out', typreceive => 'date_recv', + typsend => 'date_send', typalign => 'i' }, +{ oid => '1083', array_type_oid => '1183', descr => 'time of day', + typname => 'time', typlen => '8', typbyval => 'FLOAT8PASSBYVAL', + typcategory => 'D', typinput => 'time_in', typoutput => 'time_out', + typreceive => 'time_recv', typsend => 'time_send', typmodin => 'timetypmodin', + typmodout => 'timetypmodout', typalign => 'd' }, + +# OIDS 1100 - 1199 + +{ oid => '1114', array_type_oid => '1115', descr => 'date and time', + typname => 'timestamp', typlen => '8', typbyval => 'FLOAT8PASSBYVAL', + typcategory => 'D', typinput => 'timestamp_in', typoutput => 'timestamp_out', + typreceive => 'timestamp_recv', typsend => 'timestamp_send', + typmodin => 'timestamptypmodin', typmodout => 'timestamptypmodout', + typalign => 'd' }, +{ oid => '1184', array_type_oid => '1185', + descr => 'date and time with time zone', + typname => 'timestamptz', typlen => '8', typbyval => 'FLOAT8PASSBYVAL', + typcategory => 'D', typispreferred => 't', typinput => 'timestamptz_in', + typoutput => 'timestamptz_out', typreceive => 'timestamptz_recv', + typsend => 'timestamptz_send', typmodin => 'timestamptztypmodin', + typmodout => 'timestamptztypmodout', typalign => 'd' }, +{ oid => '1186', array_type_oid => '1187', + descr => '@ , time interval', + typname => 'interval', typlen => '16', typbyval => 'f', typcategory => 'T', + typispreferred => 't', typinput => 'interval_in', typoutput => 'interval_out', + typreceive => 'interval_recv', typsend => 'interval_send', + typmodin => 'intervaltypmodin', typmodout => 'intervaltypmodout', + typalign => 'd' }, + +# OIDS 1200 - 1299 + +{ oid => '1266', array_type_oid => '1270', + descr => 'time of day with time zone', + typname => 'timetz', typlen => '12', typbyval => 'f', typcategory => 'D', + typinput => 'timetz_in', typoutput => 'timetz_out', + typreceive => 'timetz_recv', typsend => 'timetz_send', + typmodin => 'timetztypmodin', typmodout => 'timetztypmodout', + typalign => 'd' }, + +# OIDS 1500 - 1599 + +{ oid => '1560', array_type_oid => '1561', descr => 'fixed-length bit string', + typname => 'bit', typlen => '-1', typbyval => 'f', typcategory => 'V', + typinput => 'bit_in', typoutput => 'bit_out', typreceive => 'bit_recv', + typsend => 'bit_send', typmodin => 'bittypmodin', typmodout => 'bittypmodout', + typalign => 'i', typstorage => 'x' }, +{ oid => '1562', array_type_oid => '1563', + descr => 'variable-length bit string', + typname => 'varbit', typlen => '-1', typbyval => 'f', typcategory => 'V', + typispreferred => 't', typinput => 'varbit_in', typoutput => 'varbit_out', + typreceive => 'varbit_recv', typsend => 'varbit_send', + typmodin => 'varbittypmodin', typmodout => 'varbittypmodout', typalign => 'i', + typstorage => 'x' }, + +# OIDS 1700 - 1799 + +{ oid => '1700', array_type_oid => '1231', + descr => 'numeric(precision, decimal), arbitrary precision number', + typname => 'numeric', typlen => '-1', typbyval => 'f', typcategory => 'N', + typinput => 'numeric_in', typoutput => 'numeric_out', + typreceive => 'numeric_recv', typsend => 'numeric_send', + typmodin => 'numerictypmodin', typmodout => 'numerictypmodout', + typalign => 'i', typstorage => 'm' }, + +{ oid => '1790', array_type_oid => '2201', + descr => 'reference to cursor (portal name)', + typname => 'refcursor', typlen => '-1', typbyval => 'f', typcategory => 'U', + typinput => 'textin', typoutput => 'textout', typreceive => 'textrecv', + typsend => 'textsend', typalign => 'i', typstorage => 'x' }, + +# OIDS 2200 - 2299 + +{ oid => '2202', array_type_oid => '2207', + descr => 'registered procedure (with args)', + typname => 'regprocedure', typlen => '4', typbyval => 't', typcategory => 'N', + typinput => 'regprocedurein', typoutput => 'regprocedureout', + typreceive => 'regprocedurerecv', typsend => 'regproceduresend', + typalign => 'i' }, +{ oid => '2203', array_type_oid => '2208', descr => 'registered operator', + typname => 'regoper', typlen => '4', typbyval => 't', typcategory => 'N', + typinput => 'regoperin', typoutput => 'regoperout', + typreceive => 'regoperrecv', typsend => 'regopersend', typalign => 'i' }, +{ oid => '2204', array_type_oid => '2209', + descr => 'registered operator (with args)', + typname => 'regoperator', typlen => '4', typbyval => 't', typcategory => 'N', + typinput => 'regoperatorin', typoutput => 'regoperatorout', + typreceive => 'regoperatorrecv', typsend => 'regoperatorsend', + typalign => 'i' }, +{ oid => '2205', array_type_oid => '2210', descr => 'registered class', + typname => 'regclass', typlen => '4', typbyval => 't', typcategory => 'N', + typinput => 'regclassin', typoutput => 'regclassout', + typreceive => 'regclassrecv', typsend => 'regclasssend', typalign => 'i' }, +{ oid => '4191', array_type_oid => '4192', descr => 'registered collation', + typname => 'regcollation', typlen => '4', typbyval => 't', typcategory => 'N', + typinput => 'regcollationin', typoutput => 'regcollationout', + typreceive => 'regcollationrecv', typsend => 'regcollationsend', + typalign => 'i' }, +{ oid => '2206', array_type_oid => '2211', descr => 'registered type', + typname => 'regtype', typlen => '4', typbyval => 't', typcategory => 'N', + typinput => 'regtypein', typoutput => 'regtypeout', + typreceive => 'regtyperecv', typsend => 'regtypesend', typalign => 'i' }, +{ oid => '4096', array_type_oid => '4097', descr => 'registered role', + typname => 'regrole', typlen => '4', typbyval => 't', typcategory => 'N', + typinput => 'regrolein', typoutput => 'regroleout', + typreceive => 'regrolerecv', typsend => 'regrolesend', typalign => 'i' }, +{ oid => '4089', array_type_oid => '4090', descr => 'registered namespace', + typname => 'regnamespace', typlen => '4', typbyval => 't', typcategory => 'N', + typinput => 'regnamespacein', typoutput => 'regnamespaceout', + typreceive => 'regnamespacerecv', typsend => 'regnamespacesend', + typalign => 'i' }, + +# uuid +{ oid => '2950', array_type_oid => '2951', descr => 'UUID datatype', + typname => 'uuid', typlen => '16', typbyval => 'f', typcategory => 'U', + typinput => 'uuid_in', typoutput => 'uuid_out', typreceive => 'uuid_recv', + typsend => 'uuid_send', typalign => 'c' }, + +# pg_lsn +{ oid => '3220', array_type_oid => '3221', descr => 'PostgreSQL LSN datatype', + typname => 'pg_lsn', typlen => '8', typbyval => 'FLOAT8PASSBYVAL', + typcategory => 'U', typinput => 'pg_lsn_in', typoutput => 'pg_lsn_out', + typreceive => 'pg_lsn_recv', typsend => 'pg_lsn_send', typalign => 'd' }, + +# text search +{ oid => '3614', array_type_oid => '3643', + descr => 'text representation for text search', + typname => 'tsvector', typlen => '-1', typbyval => 'f', typcategory => 'U', + typinput => 'tsvectorin', typoutput => 'tsvectorout', + typreceive => 'tsvectorrecv', typsend => 'tsvectorsend', + typanalyze => 'ts_typanalyze', typalign => 'i', typstorage => 'x' }, +{ oid => '3642', array_type_oid => '3644', + descr => 'GiST index internal text representation for text search', + typname => 'gtsvector', typlen => '-1', typbyval => 'f', typcategory => 'U', + typinput => 'gtsvectorin', typoutput => 'gtsvectorout', typreceive => '-', + typsend => '-', typalign => 'i' }, +{ oid => '3615', array_type_oid => '3645', + descr => 'query representation for text search', + typname => 'tsquery', typlen => '-1', typbyval => 'f', typcategory => 'U', + typinput => 'tsqueryin', typoutput => 'tsqueryout', + typreceive => 'tsqueryrecv', typsend => 'tsquerysend', typalign => 'i' }, +{ oid => '3734', array_type_oid => '3735', + descr => 'registered text search configuration', + typname => 'regconfig', typlen => '4', typbyval => 't', typcategory => 'N', + typinput => 'regconfigin', typoutput => 'regconfigout', + typreceive => 'regconfigrecv', typsend => 'regconfigsend', typalign => 'i' }, +{ oid => '3769', array_type_oid => '3770', + descr => 'registered text search dictionary', + typname => 'regdictionary', typlen => '4', typbyval => 't', + typcategory => 'N', typinput => 'regdictionaryin', + typoutput => 'regdictionaryout', typreceive => 'regdictionaryrecv', + typsend => 'regdictionarysend', typalign => 'i' }, + +# jsonb +{ oid => '3802', array_type_oid => '3807', descr => 'Binary JSON', + typname => 'jsonb', typlen => '-1', typbyval => 'f', typcategory => 'U', + typsubscript => 'jsonb_subscript_handler', typinput => 'jsonb_in', + typoutput => 'jsonb_out', typreceive => 'jsonb_recv', typsend => 'jsonb_send', + typalign => 'i', typstorage => 'x' }, +{ oid => '4072', array_type_oid => '4073', descr => 'JSON path', + typname => 'jsonpath', typlen => '-1', typbyval => 'f', typcategory => 'U', + typinput => 'jsonpath_in', typoutput => 'jsonpath_out', + typreceive => 'jsonpath_recv', typsend => 'jsonpath_send', typalign => 'i', + typstorage => 'x' }, + +{ oid => '2970', array_type_oid => '2949', descr => 'txid snapshot', + typname => 'txid_snapshot', typlen => '-1', typbyval => 'f', + typcategory => 'U', typinput => 'txid_snapshot_in', + typoutput => 'txid_snapshot_out', typreceive => 'txid_snapshot_recv', + typsend => 'txid_snapshot_send', typalign => 'd', typstorage => 'x' }, +{ oid => '5038', array_type_oid => '5039', descr => 'snapshot', + typname => 'pg_snapshot', typlen => '-1', typbyval => 'f', typcategory => 'U', + typinput => 'pg_snapshot_in', typoutput => 'pg_snapshot_out', + typreceive => 'pg_snapshot_recv', typsend => 'pg_snapshot_send', + typalign => 'd', typstorage => 'x' }, + +# range types +{ oid => '3904', array_type_oid => '3905', descr => 'range of integers', + typname => 'int4range', typlen => '-1', typbyval => 'f', typtype => 'r', + typcategory => 'R', typinput => 'range_in', typoutput => 'range_out', + typreceive => 'range_recv', typsend => 'range_send', + typanalyze => 'range_typanalyze', typalign => 'i', typstorage => 'x' }, +{ oid => '3906', array_type_oid => '3907', descr => 'range of numerics', + typname => 'numrange', typlen => '-1', typbyval => 'f', typtype => 'r', + typcategory => 'R', typinput => 'range_in', typoutput => 'range_out', + typreceive => 'range_recv', typsend => 'range_send', + typanalyze => 'range_typanalyze', typalign => 'i', typstorage => 'x' }, +{ oid => '3908', array_type_oid => '3909', + descr => 'range of timestamps without time zone', + typname => 'tsrange', typlen => '-1', typbyval => 'f', typtype => 'r', + typcategory => 'R', typinput => 'range_in', typoutput => 'range_out', + typreceive => 'range_recv', typsend => 'range_send', + typanalyze => 'range_typanalyze', typalign => 'd', typstorage => 'x' }, +{ oid => '3910', array_type_oid => '3911', + descr => 'range of timestamps with time zone', + typname => 'tstzrange', typlen => '-1', typbyval => 'f', typtype => 'r', + typcategory => 'R', typinput => 'range_in', typoutput => 'range_out', + typreceive => 'range_recv', typsend => 'range_send', + typanalyze => 'range_typanalyze', typalign => 'd', typstorage => 'x' }, +{ oid => '3912', array_type_oid => '3913', descr => 'range of dates', + typname => 'daterange', typlen => '-1', typbyval => 'f', typtype => 'r', + typcategory => 'R', typinput => 'range_in', typoutput => 'range_out', + typreceive => 'range_recv', typsend => 'range_send', + typanalyze => 'range_typanalyze', typalign => 'i', typstorage => 'x' }, +{ oid => '3926', array_type_oid => '3927', descr => 'range of bigints', + typname => 'int8range', typlen => '-1', typbyval => 'f', typtype => 'r', + typcategory => 'R', typinput => 'range_in', typoutput => 'range_out', + typreceive => 'range_recv', typsend => 'range_send', + typanalyze => 'range_typanalyze', typalign => 'd', typstorage => 'x' }, + +# multirange types +{ oid => '4451', array_type_oid => '6150', descr => 'multirange of integers', + typname => 'int4multirange', typlen => '-1', typbyval => 'f', typtype => 'm', + typcategory => 'R', typinput => 'multirange_in', + typoutput => 'multirange_out', typreceive => 'multirange_recv', + typsend => 'multirange_send', typanalyze => 'multirange_typanalyze', + typalign => 'i', typstorage => 'x' }, +{ oid => '4532', array_type_oid => '6151', descr => 'multirange of numerics', + typname => 'nummultirange', typlen => '-1', typbyval => 'f', typtype => 'm', + typcategory => 'R', typinput => 'multirange_in', + typoutput => 'multirange_out', typreceive => 'multirange_recv', + typsend => 'multirange_send', typanalyze => 'multirange_typanalyze', + typalign => 'i', typstorage => 'x' }, +{ oid => '4533', array_type_oid => '6152', + descr => 'multirange of timestamps without time zone', + typname => 'tsmultirange', typlen => '-1', typbyval => 'f', typtype => 'm', + typcategory => 'R', typinput => 'multirange_in', + typoutput => 'multirange_out', typreceive => 'multirange_recv', + typsend => 'multirange_send', typanalyze => 'multirange_typanalyze', + typalign => 'd', typstorage => 'x' }, +{ oid => '4534', array_type_oid => '6153', + descr => 'multirange of timestamps with time zone', + typname => 'tstzmultirange', typlen => '-1', typbyval => 'f', typtype => 'm', + typcategory => 'R', typinput => 'multirange_in', + typoutput => 'multirange_out', typreceive => 'multirange_recv', + typsend => 'multirange_send', typanalyze => 'multirange_typanalyze', + typalign => 'd', typstorage => 'x' }, +{ oid => '4535', array_type_oid => '6155', descr => 'multirange of dates', + typname => 'datemultirange', typlen => '-1', typbyval => 'f', typtype => 'm', + typcategory => 'R', typinput => 'multirange_in', + typoutput => 'multirange_out', typreceive => 'multirange_recv', + typsend => 'multirange_send', typanalyze => 'multirange_typanalyze', + typalign => 'i', typstorage => 'x' }, +{ oid => '4536', array_type_oid => '6157', descr => 'multirange of bigints', + typname => 'int8multirange', typlen => '-1', typbyval => 'f', typtype => 'm', + typcategory => 'R', typinput => 'multirange_in', + typoutput => 'multirange_out', typreceive => 'multirange_recv', + typsend => 'multirange_send', typanalyze => 'multirange_typanalyze', + typalign => 'd', typstorage => 'x' }, + +# pseudo-types +# types with typtype='p' represent various special cases in the type system. +# These cannot be used to define table columns, but are valid as function +# argument and result types (if supported by the function's implementation +# language). +# Note: cstring is a borderline case; it is still considered a pseudo-type, +# but there is now support for it in records and arrays. Perhaps we should +# just treat it as a regular base type? + +{ oid => '2249', descr => 'pseudo-type representing any composite type', + typname => 'record', typlen => '-1', typbyval => 'f', typtype => 'p', + typcategory => 'P', typarray => '_record', typinput => 'record_in', + typoutput => 'record_out', typreceive => 'record_recv', + typsend => 'record_send', typalign => 'd', typstorage => 'x' }, +# Arrays of records have typcategory P, so they can't be autogenerated. +{ oid => '2287', + typname => '_record', typlen => '-1', typbyval => 'f', typtype => 'p', + typcategory => 'P', typsubscript => 'array_subscript_handler', + typelem => 'record', typinput => 'array_in', typoutput => 'array_out', + typreceive => 'array_recv', typsend => 'array_send', + typanalyze => 'array_typanalyze', typalign => 'd', typstorage => 'x' }, +{ oid => '2275', array_type_oid => '1263', descr => 'C-style string', + typname => 'cstring', typlen => '-2', typbyval => 'f', typtype => 'p', + typcategory => 'P', typinput => 'cstring_in', typoutput => 'cstring_out', + typreceive => 'cstring_recv', typsend => 'cstring_send', typalign => 'c' }, +{ oid => '2276', descr => 'pseudo-type representing any type', + typname => 'any', typlen => '4', typbyval => 't', typtype => 'p', + typcategory => 'P', typinput => 'any_in', typoutput => 'any_out', + typreceive => '-', typsend => '-', typalign => 'i' }, +{ oid => '2277', descr => 'pseudo-type representing a polymorphic array type', + typname => 'anyarray', typlen => '-1', typbyval => 'f', typtype => 'p', + typcategory => 'P', typinput => 'anyarray_in', typoutput => 'anyarray_out', + typreceive => 'anyarray_recv', typsend => 'anyarray_send', typalign => 'd', + typstorage => 'x' }, +{ oid => '2278', + descr => 'pseudo-type for the result of a function with no real result', + typname => 'void', typlen => '4', typbyval => 't', typtype => 'p', + typcategory => 'P', typinput => 'void_in', typoutput => 'void_out', + typreceive => 'void_recv', typsend => 'void_send', typalign => 'i' }, +{ oid => '2279', descr => 'pseudo-type for the result of a trigger function', + typname => 'trigger', typlen => '4', typbyval => 't', typtype => 'p', + typcategory => 'P', typinput => 'trigger_in', typoutput => 'trigger_out', + typreceive => '-', typsend => '-', typalign => 'i' }, +{ oid => '3838', + descr => 'pseudo-type for the result of an event trigger function', + typname => 'event_trigger', typlen => '4', typbyval => 't', typtype => 'p', + typcategory => 'P', typinput => 'event_trigger_in', + typoutput => 'event_trigger_out', typreceive => '-', typsend => '-', + typalign => 'i' }, +{ oid => '2280', + descr => 'pseudo-type for the result of a language handler function', + typname => 'language_handler', typlen => '4', typbyval => 't', typtype => 'p', + typcategory => 'P', typinput => 'language_handler_in', + typoutput => 'language_handler_out', typreceive => '-', typsend => '-', + typalign => 'i' }, +{ oid => '2281', + descr => 'pseudo-type representing an internal data structure', + typname => 'internal', typlen => 'SIZEOF_POINTER', typbyval => 't', + typtype => 'p', typcategory => 'P', typinput => 'internal_in', + typoutput => 'internal_out', typreceive => '-', typsend => '-', + typalign => 'ALIGNOF_POINTER' }, +{ oid => '2283', descr => 'pseudo-type representing a polymorphic base type', + typname => 'anyelement', typlen => '4', typbyval => 't', typtype => 'p', + typcategory => 'P', typinput => 'anyelement_in', + typoutput => 'anyelement_out', typreceive => '-', typsend => '-', + typalign => 'i' }, +{ oid => '2776', + descr => 'pseudo-type representing a polymorphic base type that is not an array', + typname => 'anynonarray', typlen => '4', typbyval => 't', typtype => 'p', + typcategory => 'P', typinput => 'anynonarray_in', + typoutput => 'anynonarray_out', typreceive => '-', typsend => '-', + typalign => 'i' }, +{ oid => '3500', + descr => 'pseudo-type representing a polymorphic base type that is an enum', + typname => 'anyenum', typlen => '4', typbyval => 't', typtype => 'p', + typcategory => 'P', typinput => 'anyenum_in', typoutput => 'anyenum_out', + typreceive => '-', typsend => '-', typalign => 'i' }, +{ oid => '3115', + descr => 'pseudo-type for the result of an FDW handler function', + typname => 'fdw_handler', typlen => '4', typbyval => 't', typtype => 'p', + typcategory => 'P', typinput => 'fdw_handler_in', + typoutput => 'fdw_handler_out', typreceive => '-', typsend => '-', + typalign => 'i' }, +{ oid => '325', + descr => 'pseudo-type for the result of an index AM handler function', + typname => 'index_am_handler', typlen => '4', typbyval => 't', typtype => 'p', + typcategory => 'P', typinput => 'index_am_handler_in', + typoutput => 'index_am_handler_out', typreceive => '-', typsend => '-', + typalign => 'i' }, +{ oid => '3310', + descr => 'pseudo-type for the result of a tablesample method function', + typname => 'tsm_handler', typlen => '4', typbyval => 't', typtype => 'p', + typcategory => 'P', typinput => 'tsm_handler_in', + typoutput => 'tsm_handler_out', typreceive => '-', typsend => '-', + typalign => 'i' }, +{ oid => '269', + typname => 'table_am_handler', typlen => '4', typbyval => 't', typtype => 'p', + typcategory => 'P', typinput => 'table_am_handler_in', + typoutput => 'table_am_handler_out', typreceive => '-', typsend => '-', + typalign => 'i' }, +{ oid => '3831', + descr => 'pseudo-type representing a range over a polymorphic base type', + typname => 'anyrange', typlen => '-1', typbyval => 'f', typtype => 'p', + typcategory => 'P', typinput => 'anyrange_in', typoutput => 'anyrange_out', + typreceive => '-', typsend => '-', typalign => 'd', typstorage => 'x' }, +{ oid => '5077', + descr => 'pseudo-type representing a polymorphic common type', + typname => 'anycompatible', typlen => '4', typbyval => 't', typtype => 'p', + typcategory => 'P', typinput => 'anycompatible_in', + typoutput => 'anycompatible_out', typreceive => '-', typsend => '-', + typalign => 'i' }, +{ oid => '5078', + descr => 'pseudo-type representing an array of polymorphic common type elements', + typname => 'anycompatiblearray', typlen => '-1', typbyval => 'f', + typtype => 'p', typcategory => 'P', typinput => 'anycompatiblearray_in', + typoutput => 'anycompatiblearray_out', + typreceive => 'anycompatiblearray_recv', typsend => 'anycompatiblearray_send', + typalign => 'd', typstorage => 'x' }, +{ oid => '5079', + descr => 'pseudo-type representing a polymorphic common type that is not an array', + typname => 'anycompatiblenonarray', typlen => '4', typbyval => 't', + typtype => 'p', typcategory => 'P', typinput => 'anycompatiblenonarray_in', + typoutput => 'anycompatiblenonarray_out', typreceive => '-', typsend => '-', + typalign => 'i' }, +{ oid => '5080', + descr => 'pseudo-type representing a range over a polymorphic common type', + typname => 'anycompatiblerange', typlen => '-1', typbyval => 'f', + typtype => 'p', typcategory => 'P', typinput => 'anycompatiblerange_in', + typoutput => 'anycompatiblerange_out', typreceive => '-', typsend => '-', + typalign => 'd', typstorage => 'x' }, +{ oid => '4537', + descr => 'pseudo-type representing a polymorphic base type that is a multirange', + typname => 'anymultirange', typlen => '-1', typbyval => 'f', typtype => 'p', + typcategory => 'P', typinput => 'anymultirange_in', + typoutput => 'anymultirange_out', typreceive => '-', typsend => '-', + typalign => 'd', typstorage => 'x' }, +{ oid => '4538', + descr => 'pseudo-type representing a multirange over a polymorphic common type', + typname => 'anycompatiblemultirange', typlen => '-1', typbyval => 'f', + typtype => 'p', typcategory => 'P', typinput => 'anycompatiblemultirange_in', + typoutput => 'anycompatiblemultirange_out', typreceive => '-', typsend => '-', + typalign => 'd', typstorage => 'x' }, +{ oid => '4600', descr => 'BRIN bloom summary', + typname => 'pg_brin_bloom_summary', typlen => '-1', typbyval => 'f', + typcategory => 'Z', typinput => 'brin_bloom_summary_in', + typoutput => 'brin_bloom_summary_out', + typreceive => 'brin_bloom_summary_recv', typsend => 'brin_bloom_summary_send', + typalign => 'i', typstorage => 'x', typcollation => 'default' }, +{ oid => '4601', descr => 'BRIN minmax-multi summary', + typname => 'pg_brin_minmax_multi_summary', typlen => '-1', typbyval => 'f', + typcategory => 'Z', typinput => 'brin_minmax_multi_summary_in', + typoutput => 'brin_minmax_multi_summary_out', + typreceive => 'brin_minmax_multi_summary_recv', + typsend => 'brin_minmax_multi_summary_send', typalign => 'i', + typstorage => 'x', typcollation => 'default' }, +] diff --git a/codegen/src/pg_type.h b/codegen/src/pg_type.h deleted file mode 100644 index 345e91640..000000000 --- a/codegen/src/pg_type.h +++ /dev/null @@ -1,752 +0,0 @@ -/*------------------------------------------------------------------------- - * - * pg_type.h - * definition of the system "type" relation (pg_type) - * along with the relation's initial contents. - * - * - * Portions Copyright (c) 1996-2017, PostgreSQL Global Development Group - * Portions Copyright (c) 1994, Regents of the University of California - * - * src/include/catalog/pg_type.h - * - * NOTES - * the genbki.pl script reads this file and generates .bki - * information from the DATA() statements. - * - *------------------------------------------------------------------------- - */ -#ifndef PG_TYPE_H -#define PG_TYPE_H - -#include "catalog/genbki.h" - -/* ---------------- - * pg_type definition. cpp turns this into - * typedef struct FormData_pg_type - * - * Some of the values in a pg_type instance are copied into - * pg_attribute instances. Some parts of Postgres use the pg_type copy, - * while others use the pg_attribute copy, so they must match. - * See struct FormData_pg_attribute for details. - * ---------------- - */ -#define TypeRelationId 1247 -#define TypeRelation_Rowtype_Id 71 - -CATALOG(pg_type,1247) BKI_BOOTSTRAP BKI_ROWTYPE_OID(71) BKI_SCHEMA_MACRO -{ - NameData typname; /* type name */ - Oid typnamespace; /* OID of namespace containing this type */ - Oid typowner; /* type owner */ - - /* - * For a fixed-size type, typlen is the number of bytes we use to - * represent a value of this type, e.g. 4 for an int4. But for a - * variable-length type, typlen is negative. We use -1 to indicate a - * "varlena" type (one that has a length word), -2 to indicate a - * null-terminated C string. - */ - int16 typlen; - - /* - * typbyval determines whether internal Postgres routines pass a value of - * this type by value or by reference. typbyval had better be FALSE if - * the length is not 1, 2, or 4 (or 8 on 8-byte-Datum machines). - * Variable-length types are always passed by reference. Note that - * typbyval can be false even if the length would allow pass-by-value; - * this is currently true for type float4, for example. - */ - bool typbyval; - - /* - * typtype is 'b' for a base type, 'c' for a composite type (e.g., a - * table's rowtype), 'd' for a domain, 'e' for an enum type, 'p' for a - * pseudo-type, or 'r' for a range type. (Use the TYPTYPE macros below.) - * - * If typtype is 'c', typrelid is the OID of the class' entry in pg_class. - */ - char typtype; - - /* - * typcategory and typispreferred help the parser distinguish preferred - * and non-preferred coercions. The category can be any single ASCII - * character (but not \0). The categories used for built-in types are - * identified by the TYPCATEGORY macros below. - */ - char typcategory; /* arbitrary type classification */ - - bool typispreferred; /* is type "preferred" within its category? */ - - /* - * If typisdefined is false, the entry is only a placeholder (forward - * reference). We know the type name, but not yet anything else about it. - */ - bool typisdefined; - - char typdelim; /* delimiter for arrays of this type */ - - Oid typrelid; /* 0 if not a composite type */ - - /* - * If typelem is not 0 then it identifies another row in pg_type. The - * current type can then be subscripted like an array yielding values of - * type typelem. A non-zero typelem does not guarantee this type to be a - * "real" array type; some ordinary fixed-length types can also be - * subscripted (e.g., name, point). Variable-length types can *not* be - * turned into pseudo-arrays like that. Hence, the way to determine - * whether a type is a "true" array type is if: - * - * typelem != 0 and typlen == -1. - */ - Oid typelem; - - /* - * If there is a "true" array type having this type as element type, - * typarray links to it. Zero if no associated "true" array type. - */ - Oid typarray; - - /* - * I/O conversion procedures for the datatype. - */ - regproc typinput; /* text format (required) */ - regproc typoutput; - regproc typreceive; /* binary format (optional) */ - regproc typsend; - - /* - * I/O functions for optional type modifiers. - */ - regproc typmodin; - regproc typmodout; - - /* - * Custom ANALYZE procedure for the datatype (0 selects the default). - */ - regproc typanalyze; - - /* ---------------- - * typalign is the alignment required when storing a value of this - * type. It applies to storage on disk as well as most - * representations of the value inside Postgres. When multiple values - * are stored consecutively, such as in the representation of a - * complete row on disk, padding is inserted before a datum of this - * type so that it begins on the specified boundary. The alignment - * reference is the beginning of the first datum in the sequence. - * - * 'c' = CHAR alignment, ie no alignment needed. - * 's' = SHORT alignment (2 bytes on most machines). - * 'i' = INT alignment (4 bytes on most machines). - * 'd' = DOUBLE alignment (8 bytes on many machines, but by no means all). - * - * See include/access/tupmacs.h for the macros that compute these - * alignment requirements. Note also that we allow the nominal alignment - * to be violated when storing "packed" varlenas; the TOAST mechanism - * takes care of hiding that from most code. - * - * NOTE: for types used in system tables, it is critical that the - * size and alignment defined in pg_type agree with the way that the - * compiler will lay out the field in a struct representing a table row. - * ---------------- - */ - char typalign; - - /* ---------------- - * typstorage tells if the type is prepared for toasting and what - * the default strategy for attributes of this type should be. - * - * 'p' PLAIN type not prepared for toasting - * 'e' EXTERNAL external storage possible, don't try to compress - * 'x' EXTENDED try to compress and store external if required - * 'm' MAIN like 'x' but try to keep in main tuple - * ---------------- - */ - char typstorage; - - /* - * This flag represents a "NOT NULL" constraint against this datatype. - * - * If true, the attnotnull column for a corresponding table column using - * this datatype will always enforce the NOT NULL constraint. - * - * Used primarily for domain types. - */ - bool typnotnull; - - /* - * Domains use typbasetype to show the base (or domain) type that the - * domain is based on. Zero if the type is not a domain. - */ - Oid typbasetype; - - /* - * Domains use typtypmod to record the typmod to be applied to their base - * type (-1 if base type does not use a typmod). -1 if this type is not a - * domain. - */ - int32 typtypmod; - - /* - * typndims is the declared number of dimensions for an array domain type - * (i.e., typbasetype is an array type). Otherwise zero. - */ - int32 typndims; - - /* - * Collation: 0 if type cannot use collations, DEFAULT_COLLATION_OID for - * collatable base types, possibly other OID for domains - */ - Oid typcollation; - -#ifdef CATALOG_VARLEN /* variable-length fields start here */ - - /* - * If typdefaultbin is not NULL, it is the nodeToString representation of - * a default expression for the type. Currently this is only used for - * domains. - */ - pg_node_tree typdefaultbin; - - /* - * typdefault is NULL if the type has no associated default value. If - * typdefaultbin is not NULL, typdefault must contain a human-readable - * version of the default expression represented by typdefaultbin. If - * typdefaultbin is NULL and typdefault is not, then typdefault is the - * external representation of the type's default value, which may be fed - * to the type's input converter to produce a constant. - */ - text typdefault; - - /* - * Access permissions - */ - aclitem typacl[1]; -#endif -} FormData_pg_type; - -/* ---------------- - * Form_pg_type corresponds to a pointer to a row with - * the format of pg_type relation. - * ---------------- - */ -typedef FormData_pg_type *Form_pg_type; - -/* ---------------- - * compiler constants for pg_type - * ---------------- - */ -#define Natts_pg_type 30 -#define Anum_pg_type_typname 1 -#define Anum_pg_type_typnamespace 2 -#define Anum_pg_type_typowner 3 -#define Anum_pg_type_typlen 4 -#define Anum_pg_type_typbyval 5 -#define Anum_pg_type_typtype 6 -#define Anum_pg_type_typcategory 7 -#define Anum_pg_type_typispreferred 8 -#define Anum_pg_type_typisdefined 9 -#define Anum_pg_type_typdelim 10 -#define Anum_pg_type_typrelid 11 -#define Anum_pg_type_typelem 12 -#define Anum_pg_type_typarray 13 -#define Anum_pg_type_typinput 14 -#define Anum_pg_type_typoutput 15 -#define Anum_pg_type_typreceive 16 -#define Anum_pg_type_typsend 17 -#define Anum_pg_type_typmodin 18 -#define Anum_pg_type_typmodout 19 -#define Anum_pg_type_typanalyze 20 -#define Anum_pg_type_typalign 21 -#define Anum_pg_type_typstorage 22 -#define Anum_pg_type_typnotnull 23 -#define Anum_pg_type_typbasetype 24 -#define Anum_pg_type_typtypmod 25 -#define Anum_pg_type_typndims 26 -#define Anum_pg_type_typcollation 27 -#define Anum_pg_type_typdefaultbin 28 -#define Anum_pg_type_typdefault 29 -#define Anum_pg_type_typacl 30 - - -/* ---------------- - * initial contents of pg_type - * ---------------- - */ - -/* - * Keep the following ordered by OID so that later changes can be made more - * easily. - * - * For types used in the system catalogs, make sure the values here match - * TypInfo[] in bootstrap.c. - */ - -/* OIDS 1 - 99 */ -DATA(insert OID = 16 ( bool PGNSP PGUID 1 t b B t t \054 0 0 1000 boolin boolout boolrecv boolsend - - - c p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("boolean, 'true'/'false'"); -#define BOOLOID 16 - -DATA(insert OID = 17 ( bytea PGNSP PGUID -1 f b U f t \054 0 0 1001 byteain byteaout bytearecv byteasend - - - i x f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("variable-length string, binary values escaped"); -#define BYTEAOID 17 - -DATA(insert OID = 18 ( char PGNSP PGUID 1 t b S f t \054 0 0 1002 charin charout charrecv charsend - - - c p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("single character"); -#define CHAROID 18 - -DATA(insert OID = 19 ( name PGNSP PGUID NAMEDATALEN f b S f t \054 0 18 1003 namein nameout namerecv namesend - - - c p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("63-byte type for storing system identifiers"); -#define NAMEOID 19 - -DATA(insert OID = 20 ( int8 PGNSP PGUID 8 FLOAT8PASSBYVAL b N f t \054 0 0 1016 int8in int8out int8recv int8send - - - d p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("~18 digit integer, 8-byte storage"); -#define INT8OID 20 - -DATA(insert OID = 21 ( int2 PGNSP PGUID 2 t b N f t \054 0 0 1005 int2in int2out int2recv int2send - - - s p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("-32 thousand to 32 thousand, 2-byte storage"); -#define INT2OID 21 - -DATA(insert OID = 22 ( int2vector PGNSP PGUID -1 f b A f t \054 0 21 1006 int2vectorin int2vectorout int2vectorrecv int2vectorsend - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("array of int2, used in system tables"); -#define INT2VECTOROID 22 - -DATA(insert OID = 23 ( int4 PGNSP PGUID 4 t b N f t \054 0 0 1007 int4in int4out int4recv int4send - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("-2 billion to 2 billion integer, 4-byte storage"); -#define INT4OID 23 - -DATA(insert OID = 24 ( regproc PGNSP PGUID 4 t b N f t \054 0 0 1008 regprocin regprocout regprocrecv regprocsend - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("registered procedure"); -#define REGPROCOID 24 - -DATA(insert OID = 25 ( text PGNSP PGUID -1 f b S t t \054 0 0 1009 textin textout textrecv textsend - - - i x f 0 -1 0 100 _null_ _null_ _null_ )); -DESCR("variable-length string, no limit specified"); -#define TEXTOID 25 - -DATA(insert OID = 26 ( oid PGNSP PGUID 4 t b N t t \054 0 0 1028 oidin oidout oidrecv oidsend - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("object identifier(oid), maximum 4 billion"); -#define OIDOID 26 - -DATA(insert OID = 27 ( tid PGNSP PGUID 6 f b U f t \054 0 0 1010 tidin tidout tidrecv tidsend - - - s p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("(block, offset), physical location of tuple"); -#define TIDOID 27 - -DATA(insert OID = 28 ( xid PGNSP PGUID 4 t b U f t \054 0 0 1011 xidin xidout xidrecv xidsend - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("transaction id"); -#define XIDOID 28 - -DATA(insert OID = 29 ( cid PGNSP PGUID 4 t b U f t \054 0 0 1012 cidin cidout cidrecv cidsend - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("command identifier type, sequence in transaction id"); -#define CIDOID 29 - -DATA(insert OID = 30 ( oidvector PGNSP PGUID -1 f b A f t \054 0 26 1013 oidvectorin oidvectorout oidvectorrecv oidvectorsend - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("array of oids, used in system tables"); -#define OIDVECTOROID 30 - -/* hand-built rowtype entries for bootstrapped catalogs */ -/* NB: OIDs assigned here must match the BKI_ROWTYPE_OID declarations */ - -DATA(insert OID = 71 ( pg_type PGNSP PGUID -1 f c C f t \054 1247 0 0 record_in record_out record_recv record_send - - - d x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 75 ( pg_attribute PGNSP PGUID -1 f c C f t \054 1249 0 0 record_in record_out record_recv record_send - - - d x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 81 ( pg_proc PGNSP PGUID -1 f c C f t \054 1255 0 0 record_in record_out record_recv record_send - - - d x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 83 ( pg_class PGNSP PGUID -1 f c C f t \054 1259 0 0 record_in record_out record_recv record_send - - - d x f 0 -1 0 0 _null_ _null_ _null_ )); - -/* OIDS 100 - 199 */ -DATA(insert OID = 114 ( json PGNSP PGUID -1 f b U f t \054 0 0 199 json_in json_out json_recv json_send - - - i x f 0 -1 0 0 _null_ _null_ _null_ )); -#define JSONOID 114 -DATA(insert OID = 142 ( xml PGNSP PGUID -1 f b U f t \054 0 0 143 xml_in xml_out xml_recv xml_send - - - i x f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("XML content"); -#define XMLOID 142 -DATA(insert OID = 143 ( _xml PGNSP PGUID -1 f b A f t \054 0 142 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 199 ( _json PGNSP PGUID -1 f b A f t \054 0 114 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); - -DATA(insert OID = 194 ( pg_node_tree PGNSP PGUID -1 f b S f t \054 0 0 0 pg_node_tree_in pg_node_tree_out pg_node_tree_recv pg_node_tree_send - - - i x f 0 -1 0 100 _null_ _null_ _null_ )); -DESCR("string representing an internal node tree"); -#define PGNODETREEOID 194 - -DATA(insert OID = 3361 ( pg_ndistinct PGNSP PGUID -1 f b S f t \054 0 0 0 pg_ndistinct_in pg_ndistinct_out pg_ndistinct_recv pg_ndistinct_send - - - i x f 0 -1 0 100 _null_ _null_ _null_ )); -DESCR("multivariate ndistinct coefficients"); -#define PGNDISTINCTOID 3361 - -DATA(insert OID = 3402 ( pg_dependencies PGNSP PGUID -1 f b S f t \054 0 0 0 pg_dependencies_in pg_dependencies_out pg_dependencies_recv pg_dependencies_send - - - i x f 0 -1 0 100 _null_ _null_ _null_ )); -DESCR("multivariate dependencies"); -#define PGDEPENDENCIESOID 3402 - -DATA(insert OID = 32 ( pg_ddl_command PGNSP PGUID SIZEOF_POINTER t p P f t \054 0 0 0 pg_ddl_command_in pg_ddl_command_out pg_ddl_command_recv pg_ddl_command_send - - - ALIGNOF_POINTER p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("internal type for passing CollectedCommand"); -#define PGDDLCOMMANDOID 32 - -/* OIDS 200 - 299 */ - -DATA(insert OID = 210 ( smgr PGNSP PGUID 2 t b U f t \054 0 0 0 smgrin smgrout - - - - - s p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("storage manager"); - -/* OIDS 300 - 399 */ - -/* OIDS 400 - 499 */ - -/* OIDS 500 - 599 */ - -/* OIDS 600 - 699 */ -DATA(insert OID = 600 ( point PGNSP PGUID 16 f b G f t \054 0 701 1017 point_in point_out point_recv point_send - - - d p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("geometric point '(x, y)'"); -#define POINTOID 600 -DATA(insert OID = 601 ( lseg PGNSP PGUID 32 f b G f t \054 0 600 1018 lseg_in lseg_out lseg_recv lseg_send - - - d p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("geometric line segment '(pt1,pt2)'"); -#define LSEGOID 601 -DATA(insert OID = 602 ( path PGNSP PGUID -1 f b G f t \054 0 0 1019 path_in path_out path_recv path_send - - - d x f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("geometric path '(pt1,...)'"); -#define PATHOID 602 -DATA(insert OID = 603 ( box PGNSP PGUID 32 f b G f t \073 0 600 1020 box_in box_out box_recv box_send - - - d p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("geometric box '(lower left,upper right)'"); -#define BOXOID 603 -DATA(insert OID = 604 ( polygon PGNSP PGUID -1 f b G f t \054 0 0 1027 poly_in poly_out poly_recv poly_send - - - d x f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("geometric polygon '(pt1,...)'"); -#define POLYGONOID 604 - -DATA(insert OID = 628 ( line PGNSP PGUID 24 f b G f t \054 0 701 629 line_in line_out line_recv line_send - - - d p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("geometric line"); -#define LINEOID 628 -DATA(insert OID = 629 ( _line PGNSP PGUID -1 f b A f t \054 0 628 0 array_in array_out array_recv array_send - - array_typanalyze d x f 0 -1 0 0 _null_ _null_ _null_ )); - -/* OIDS 700 - 799 */ - -DATA(insert OID = 700 ( float4 PGNSP PGUID 4 FLOAT4PASSBYVAL b N f t \054 0 0 1021 float4in float4out float4recv float4send - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("single-precision floating point number, 4-byte storage"); -#define FLOAT4OID 700 -DATA(insert OID = 701 ( float8 PGNSP PGUID 8 FLOAT8PASSBYVAL b N t t \054 0 0 1022 float8in float8out float8recv float8send - - - d p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("double-precision floating point number, 8-byte storage"); -#define FLOAT8OID 701 -DATA(insert OID = 702 ( abstime PGNSP PGUID 4 t b D f t \054 0 0 1023 abstimein abstimeout abstimerecv abstimesend - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("absolute, limited-range date and time (Unix system time)"); -#define ABSTIMEOID 702 -DATA(insert OID = 703 ( reltime PGNSP PGUID 4 t b T f t \054 0 0 1024 reltimein reltimeout reltimerecv reltimesend - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("relative, limited-range time interval (Unix delta time)"); -#define RELTIMEOID 703 -DATA(insert OID = 704 ( tinterval PGNSP PGUID 12 f b T f t \054 0 0 1025 tintervalin tintervalout tintervalrecv tintervalsend - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("(abstime,abstime), time interval"); -#define TINTERVALOID 704 -DATA(insert OID = 705 ( unknown PGNSP PGUID -2 f p X f t \054 0 0 0 unknownin unknownout unknownrecv unknownsend - - - c p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR(""); -#define UNKNOWNOID 705 - -DATA(insert OID = 718 ( circle PGNSP PGUID 24 f b G f t \054 0 0 719 circle_in circle_out circle_recv circle_send - - - d p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("geometric circle '(center,radius)'"); -#define CIRCLEOID 718 -DATA(insert OID = 719 ( _circle PGNSP PGUID -1 f b A f t \054 0 718 0 array_in array_out array_recv array_send - - array_typanalyze d x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 790 ( money PGNSP PGUID 8 FLOAT8PASSBYVAL b N f t \054 0 0 791 cash_in cash_out cash_recv cash_send - - - d p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("monetary amounts, $d,ddd.cc"); -#define CASHOID 790 -DATA(insert OID = 791 ( _money PGNSP PGUID -1 f b A f t \054 0 790 0 array_in array_out array_recv array_send - - array_typanalyze d x f 0 -1 0 0 _null_ _null_ _null_ )); - -/* OIDS 800 - 899 */ -DATA(insert OID = 829 ( macaddr PGNSP PGUID 6 f b U f t \054 0 0 1040 macaddr_in macaddr_out macaddr_recv macaddr_send - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("XX:XX:XX:XX:XX:XX, MAC address"); -#define MACADDROID 829 -DATA(insert OID = 869 ( inet PGNSP PGUID -1 f b I t t \054 0 0 1041 inet_in inet_out inet_recv inet_send - - - i m f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("IP address/netmask, host address, netmask optional"); -#define INETOID 869 -DATA(insert OID = 650 ( cidr PGNSP PGUID -1 f b I f t \054 0 0 651 cidr_in cidr_out cidr_recv cidr_send - - - i m f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("network IP address/netmask, network address"); -#define CIDROID 650 -DATA(insert OID = 774 ( macaddr8 PGNSP PGUID 8 f b U f t \054 0 0 775 macaddr8_in macaddr8_out macaddr8_recv macaddr8_send - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("XX:XX:XX:XX:XX:XX:XX:XX, MAC address"); -#define MACADDR8OID 774 - -/* OIDS 900 - 999 */ - -/* OIDS 1000 - 1099 */ -DATA(insert OID = 1000 ( _bool PGNSP PGUID -1 f b A f t \054 0 16 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1001 ( _bytea PGNSP PGUID -1 f b A f t \054 0 17 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1002 ( _char PGNSP PGUID -1 f b A f t \054 0 18 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1003 ( _name PGNSP PGUID -1 f b A f t \054 0 19 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1005 ( _int2 PGNSP PGUID -1 f b A f t \054 0 21 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -#define INT2ARRAYOID 1005 -DATA(insert OID = 1006 ( _int2vector PGNSP PGUID -1 f b A f t \054 0 22 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1007 ( _int4 PGNSP PGUID -1 f b A f t \054 0 23 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -#define INT4ARRAYOID 1007 -DATA(insert OID = 1008 ( _regproc PGNSP PGUID -1 f b A f t \054 0 24 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1009 ( _text PGNSP PGUID -1 f b A f t \054 0 25 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 100 _null_ _null_ _null_ )); -#define TEXTARRAYOID 1009 -DATA(insert OID = 1028 ( _oid PGNSP PGUID -1 f b A f t \054 0 26 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -#define OIDARRAYOID 1028 -DATA(insert OID = 1010 ( _tid PGNSP PGUID -1 f b A f t \054 0 27 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1011 ( _xid PGNSP PGUID -1 f b A f t \054 0 28 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1012 ( _cid PGNSP PGUID -1 f b A f t \054 0 29 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1013 ( _oidvector PGNSP PGUID -1 f b A f t \054 0 30 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1014 ( _bpchar PGNSP PGUID -1 f b A f t \054 0 1042 0 array_in array_out array_recv array_send bpchartypmodin bpchartypmodout array_typanalyze i x f 0 -1 0 100 _null_ _null_ _null_ )); -DATA(insert OID = 1015 ( _varchar PGNSP PGUID -1 f b A f t \054 0 1043 0 array_in array_out array_recv array_send varchartypmodin varchartypmodout array_typanalyze i x f 0 -1 0 100 _null_ _null_ _null_ )); -DATA(insert OID = 1016 ( _int8 PGNSP PGUID -1 f b A f t \054 0 20 0 array_in array_out array_recv array_send - - array_typanalyze d x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1017 ( _point PGNSP PGUID -1 f b A f t \054 0 600 0 array_in array_out array_recv array_send - - array_typanalyze d x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1018 ( _lseg PGNSP PGUID -1 f b A f t \054 0 601 0 array_in array_out array_recv array_send - - array_typanalyze d x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1019 ( _path PGNSP PGUID -1 f b A f t \054 0 602 0 array_in array_out array_recv array_send - - array_typanalyze d x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1020 ( _box PGNSP PGUID -1 f b A f t \073 0 603 0 array_in array_out array_recv array_send - - array_typanalyze d x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1021 ( _float4 PGNSP PGUID -1 f b A f t \054 0 700 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -#define FLOAT4ARRAYOID 1021 -DATA(insert OID = 1022 ( _float8 PGNSP PGUID -1 f b A f t \054 0 701 0 array_in array_out array_recv array_send - - array_typanalyze d x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1023 ( _abstime PGNSP PGUID -1 f b A f t \054 0 702 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1024 ( _reltime PGNSP PGUID -1 f b A f t \054 0 703 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1025 ( _tinterval PGNSP PGUID -1 f b A f t \054 0 704 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1027 ( _polygon PGNSP PGUID -1 f b A f t \054 0 604 0 array_in array_out array_recv array_send - - array_typanalyze d x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1033 ( aclitem PGNSP PGUID 12 f b U f t \054 0 0 1034 aclitemin aclitemout - - - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("access control list"); -#define ACLITEMOID 1033 -DATA(insert OID = 1034 ( _aclitem PGNSP PGUID -1 f b A f t \054 0 1033 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1040 ( _macaddr PGNSP PGUID -1 f b A f t \054 0 829 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 775 ( _macaddr8 PGNSP PGUID -1 f b A f t \054 0 774 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1041 ( _inet PGNSP PGUID -1 f b A f t \054 0 869 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 651 ( _cidr PGNSP PGUID -1 f b A f t \054 0 650 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1263 ( _cstring PGNSP PGUID -1 f b A f t \054 0 2275 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -#define CSTRINGARRAYOID 1263 - -DATA(insert OID = 1042 ( bpchar PGNSP PGUID -1 f b S f t \054 0 0 1014 bpcharin bpcharout bpcharrecv bpcharsend bpchartypmodin bpchartypmodout - i x f 0 -1 0 100 _null_ _null_ _null_ )); -DESCR("char(length), blank-padded string, fixed storage length"); -#define BPCHAROID 1042 -DATA(insert OID = 1043 ( varchar PGNSP PGUID -1 f b S f t \054 0 0 1015 varcharin varcharout varcharrecv varcharsend varchartypmodin varchartypmodout - i x f 0 -1 0 100 _null_ _null_ _null_ )); -DESCR("varchar(length), non-blank-padded string, variable storage length"); -#define VARCHAROID 1043 - -DATA(insert OID = 1082 ( date PGNSP PGUID 4 t b D f t \054 0 0 1182 date_in date_out date_recv date_send - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("date"); -#define DATEOID 1082 -DATA(insert OID = 1083 ( time PGNSP PGUID 8 FLOAT8PASSBYVAL b D f t \054 0 0 1183 time_in time_out time_recv time_send timetypmodin timetypmodout - d p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("time of day"); -#define TIMEOID 1083 - -/* OIDS 1100 - 1199 */ -DATA(insert OID = 1114 ( timestamp PGNSP PGUID 8 FLOAT8PASSBYVAL b D f t \054 0 0 1115 timestamp_in timestamp_out timestamp_recv timestamp_send timestamptypmodin timestamptypmodout - d p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("date and time"); -#define TIMESTAMPOID 1114 -DATA(insert OID = 1115 ( _timestamp PGNSP PGUID -1 f b A f t \054 0 1114 0 array_in array_out array_recv array_send timestamptypmodin timestamptypmodout array_typanalyze d x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1182 ( _date PGNSP PGUID -1 f b A f t \054 0 1082 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1183 ( _time PGNSP PGUID -1 f b A f t \054 0 1083 0 array_in array_out array_recv array_send timetypmodin timetypmodout array_typanalyze d x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1184 ( timestamptz PGNSP PGUID 8 FLOAT8PASSBYVAL b D t t \054 0 0 1185 timestamptz_in timestamptz_out timestamptz_recv timestamptz_send timestamptztypmodin timestamptztypmodout - d p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("date and time with time zone"); -#define TIMESTAMPTZOID 1184 -DATA(insert OID = 1185 ( _timestamptz PGNSP PGUID -1 f b A f t \054 0 1184 0 array_in array_out array_recv array_send timestamptztypmodin timestamptztypmodout array_typanalyze d x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1186 ( interval PGNSP PGUID 16 f b T t t \054 0 0 1187 interval_in interval_out interval_recv interval_send intervaltypmodin intervaltypmodout - d p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("@ , time interval"); -#define INTERVALOID 1186 -DATA(insert OID = 1187 ( _interval PGNSP PGUID -1 f b A f t \054 0 1186 0 array_in array_out array_recv array_send intervaltypmodin intervaltypmodout array_typanalyze d x f 0 -1 0 0 _null_ _null_ _null_ )); - -/* OIDS 1200 - 1299 */ -DATA(insert OID = 1231 ( _numeric PGNSP PGUID -1 f b A f t \054 0 1700 0 array_in array_out array_recv array_send numerictypmodin numerictypmodout array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1266 ( timetz PGNSP PGUID 12 f b D f t \054 0 0 1270 timetz_in timetz_out timetz_recv timetz_send timetztypmodin timetztypmodout - d p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("time of day with time zone"); -#define TIMETZOID 1266 -DATA(insert OID = 1270 ( _timetz PGNSP PGUID -1 f b A f t \054 0 1266 0 array_in array_out array_recv array_send timetztypmodin timetztypmodout array_typanalyze d x f 0 -1 0 0 _null_ _null_ _null_ )); - -/* OIDS 1500 - 1599 */ -DATA(insert OID = 1560 ( bit PGNSP PGUID -1 f b V f t \054 0 0 1561 bit_in bit_out bit_recv bit_send bittypmodin bittypmodout - i x f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("fixed-length bit string"); -#define BITOID 1560 -DATA(insert OID = 1561 ( _bit PGNSP PGUID -1 f b A f t \054 0 1560 0 array_in array_out array_recv array_send bittypmodin bittypmodout array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 1562 ( varbit PGNSP PGUID -1 f b V t t \054 0 0 1563 varbit_in varbit_out varbit_recv varbit_send varbittypmodin varbittypmodout - i x f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("variable-length bit string"); -#define VARBITOID 1562 -DATA(insert OID = 1563 ( _varbit PGNSP PGUID -1 f b A f t \054 0 1562 0 array_in array_out array_recv array_send varbittypmodin varbittypmodout array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); - -/* OIDS 1600 - 1699 */ - -/* OIDS 1700 - 1799 */ -DATA(insert OID = 1700 ( numeric PGNSP PGUID -1 f b N f t \054 0 0 1231 numeric_in numeric_out numeric_recv numeric_send numerictypmodin numerictypmodout - i m f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("numeric(precision, decimal), arbitrary precision number"); -#define NUMERICOID 1700 - -DATA(insert OID = 1790 ( refcursor PGNSP PGUID -1 f b U f t \054 0 0 2201 textin textout textrecv textsend - - - i x f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("reference to cursor (portal name)"); -#define REFCURSOROID 1790 - -/* OIDS 2200 - 2299 */ -DATA(insert OID = 2201 ( _refcursor PGNSP PGUID -1 f b A f t \054 0 1790 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); - -DATA(insert OID = 2202 ( regprocedure PGNSP PGUID 4 t b N f t \054 0 0 2207 regprocedurein regprocedureout regprocedurerecv regproceduresend - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("registered procedure (with args)"); -#define REGPROCEDUREOID 2202 - -DATA(insert OID = 2203 ( regoper PGNSP PGUID 4 t b N f t \054 0 0 2208 regoperin regoperout regoperrecv regopersend - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("registered operator"); -#define REGOPEROID 2203 - -DATA(insert OID = 2204 ( regoperator PGNSP PGUID 4 t b N f t \054 0 0 2209 regoperatorin regoperatorout regoperatorrecv regoperatorsend - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("registered operator (with args)"); -#define REGOPERATOROID 2204 - -DATA(insert OID = 2205 ( regclass PGNSP PGUID 4 t b N f t \054 0 0 2210 regclassin regclassout regclassrecv regclasssend - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("registered class"); -#define REGCLASSOID 2205 - -DATA(insert OID = 2206 ( regtype PGNSP PGUID 4 t b N f t \054 0 0 2211 regtypein regtypeout regtyperecv regtypesend - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("registered type"); -#define REGTYPEOID 2206 - -DATA(insert OID = 4096 ( regrole PGNSP PGUID 4 t b N f t \054 0 0 4097 regrolein regroleout regrolerecv regrolesend - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("registered role"); -#define REGROLEOID 4096 - -DATA(insert OID = 4089 ( regnamespace PGNSP PGUID 4 t b N f t \054 0 0 4090 regnamespacein regnamespaceout regnamespacerecv regnamespacesend - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("registered namespace"); -#define REGNAMESPACEOID 4089 - -DATA(insert OID = 2207 ( _regprocedure PGNSP PGUID -1 f b A f t \054 0 2202 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 2208 ( _regoper PGNSP PGUID -1 f b A f t \054 0 2203 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 2209 ( _regoperator PGNSP PGUID -1 f b A f t \054 0 2204 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 2210 ( _regclass PGNSP PGUID -1 f b A f t \054 0 2205 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 2211 ( _regtype PGNSP PGUID -1 f b A f t \054 0 2206 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -#define REGTYPEARRAYOID 2211 -DATA(insert OID = 4097 ( _regrole PGNSP PGUID -1 f b A f t \054 0 4096 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 4090 ( _regnamespace PGNSP PGUID -1 f b A f t \054 0 4089 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); - -/* uuid */ -DATA(insert OID = 2950 ( uuid PGNSP PGUID 16 f b U f t \054 0 0 2951 uuid_in uuid_out uuid_recv uuid_send - - - c p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("UUID datatype"); -#define UUIDOID 2950 -DATA(insert OID = 2951 ( _uuid PGNSP PGUID -1 f b A f t \054 0 2950 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); - -/* pg_lsn */ -DATA(insert OID = 3220 ( pg_lsn PGNSP PGUID 8 FLOAT8PASSBYVAL b U f t \054 0 0 3221 pg_lsn_in pg_lsn_out pg_lsn_recv pg_lsn_send - - - d p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("PostgreSQL LSN datatype"); -#define LSNOID 3220 -DATA(insert OID = 3221 ( _pg_lsn PGNSP PGUID -1 f b A f t \054 0 3220 0 array_in array_out array_recv array_send - - array_typanalyze d x f 0 -1 0 0 _null_ _null_ _null_ )); - -/* text search */ -DATA(insert OID = 3614 ( tsvector PGNSP PGUID -1 f b U f t \054 0 0 3643 tsvectorin tsvectorout tsvectorrecv tsvectorsend - - ts_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("text representation for text search"); -#define TSVECTOROID 3614 -DATA(insert OID = 3642 ( gtsvector PGNSP PGUID -1 f b U f t \054 0 0 3644 gtsvectorin gtsvectorout - - - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("GiST index internal text representation for text search"); -#define GTSVECTOROID 3642 -DATA(insert OID = 3615 ( tsquery PGNSP PGUID -1 f b U f t \054 0 0 3645 tsqueryin tsqueryout tsqueryrecv tsquerysend - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("query representation for text search"); -#define TSQUERYOID 3615 -DATA(insert OID = 3734 ( regconfig PGNSP PGUID 4 t b N f t \054 0 0 3735 regconfigin regconfigout regconfigrecv regconfigsend - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("registered text search configuration"); -#define REGCONFIGOID 3734 -DATA(insert OID = 3769 ( regdictionary PGNSP PGUID 4 t b N f t \054 0 0 3770 regdictionaryin regdictionaryout regdictionaryrecv regdictionarysend - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("registered text search dictionary"); -#define REGDICTIONARYOID 3769 - -DATA(insert OID = 3643 ( _tsvector PGNSP PGUID -1 f b A f t \054 0 3614 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 3644 ( _gtsvector PGNSP PGUID -1 f b A f t \054 0 3642 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 3645 ( _tsquery PGNSP PGUID -1 f b A f t \054 0 3615 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 3735 ( _regconfig PGNSP PGUID -1 f b A f t \054 0 3734 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 3770 ( _regdictionary PGNSP PGUID -1 f b A f t \054 0 3769 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); - -/* jsonb */ -DATA(insert OID = 3802 ( jsonb PGNSP PGUID -1 f b U f t \054 0 0 3807 jsonb_in jsonb_out jsonb_recv jsonb_send - - - i x f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("Binary JSON"); -#define JSONBOID 3802 -DATA(insert OID = 3807 ( _jsonb PGNSP PGUID -1 f b A f t \054 0 3802 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); - -DATA(insert OID = 2970 ( txid_snapshot PGNSP PGUID -1 f b U f t \054 0 0 2949 txid_snapshot_in txid_snapshot_out txid_snapshot_recv txid_snapshot_send - - - d x f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("txid snapshot"); -DATA(insert OID = 2949 ( _txid_snapshot PGNSP PGUID -1 f b A f t \054 0 2970 0 array_in array_out array_recv array_send - - array_typanalyze d x f 0 -1 0 0 _null_ _null_ _null_ )); - -/* range types */ -DATA(insert OID = 3904 ( int4range PGNSP PGUID -1 f r R f t \054 0 0 3905 range_in range_out range_recv range_send - - range_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("range of integers"); -#define INT4RANGEOID 3904 -DATA(insert OID = 3905 ( _int4range PGNSP PGUID -1 f b A f t \054 0 3904 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 3906 ( numrange PGNSP PGUID -1 f r R f t \054 0 0 3907 range_in range_out range_recv range_send - - range_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("range of numerics"); -DATA(insert OID = 3907 ( _numrange PGNSP PGUID -1 f b A f t \054 0 3906 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 3908 ( tsrange PGNSP PGUID -1 f r R f t \054 0 0 3909 range_in range_out range_recv range_send - - range_typanalyze d x f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("range of timestamps without time zone"); -DATA(insert OID = 3909 ( _tsrange PGNSP PGUID -1 f b A f t \054 0 3908 0 array_in array_out array_recv array_send - - array_typanalyze d x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 3910 ( tstzrange PGNSP PGUID -1 f r R f t \054 0 0 3911 range_in range_out range_recv range_send - - range_typanalyze d x f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("range of timestamps with time zone"); -DATA(insert OID = 3911 ( _tstzrange PGNSP PGUID -1 f b A f t \054 0 3910 0 array_in array_out array_recv array_send - - array_typanalyze d x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 3912 ( daterange PGNSP PGUID -1 f r R f t \054 0 0 3913 range_in range_out range_recv range_send - - range_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("range of dates"); -DATA(insert OID = 3913 ( _daterange PGNSP PGUID -1 f b A f t \054 0 3912 0 array_in array_out array_recv array_send - - array_typanalyze i x f 0 -1 0 0 _null_ _null_ _null_ )); -DATA(insert OID = 3926 ( int8range PGNSP PGUID -1 f r R f t \054 0 0 3927 range_in range_out range_recv range_send - - range_typanalyze d x f 0 -1 0 0 _null_ _null_ _null_ )); -DESCR("range of bigints"); -DATA(insert OID = 3927 ( _int8range PGNSP PGUID -1 f b A f t \054 0 3926 0 array_in array_out array_recv array_send - - array_typanalyze d x f 0 -1 0 0 _null_ _null_ _null_ )); - -/* - * pseudo-types - * - * types with typtype='p' represent various special cases in the type system. - * - * These cannot be used to define table columns, but are valid as function - * argument and result types (if supported by the function's implementation - * language). - * - * Note: cstring is a borderline case; it is still considered a pseudo-type, - * but there is now support for it in records and arrays. Perhaps we should - * just treat it as a regular base type? - */ -DATA(insert OID = 2249 ( record PGNSP PGUID -1 f p P f t \054 0 0 2287 record_in record_out record_recv record_send - - - d x f 0 -1 0 0 _null_ _null_ _null_ )); -#define RECORDOID 2249 -DATA(insert OID = 2287 ( _record PGNSP PGUID -1 f p P f t \054 0 2249 0 array_in array_out array_recv array_send - - array_typanalyze d x f 0 -1 0 0 _null_ _null_ _null_ )); -#define RECORDARRAYOID 2287 -DATA(insert OID = 2275 ( cstring PGNSP PGUID -2 f p P f t \054 0 0 1263 cstring_in cstring_out cstring_recv cstring_send - - - c p f 0 -1 0 0 _null_ _null_ _null_ )); -#define CSTRINGOID 2275 -DATA(insert OID = 2276 ( any PGNSP PGUID 4 t p P f t \054 0 0 0 any_in any_out - - - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -#define ANYOID 2276 -DATA(insert OID = 2277 ( anyarray PGNSP PGUID -1 f p P f t \054 0 0 0 anyarray_in anyarray_out anyarray_recv anyarray_send - - - d x f 0 -1 0 0 _null_ _null_ _null_ )); -#define ANYARRAYOID 2277 -DATA(insert OID = 2278 ( void PGNSP PGUID 4 t p P f t \054 0 0 0 void_in void_out void_recv void_send - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -#define VOIDOID 2278 -DATA(insert OID = 2279 ( trigger PGNSP PGUID 4 t p P f t \054 0 0 0 trigger_in trigger_out - - - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -#define TRIGGEROID 2279 -DATA(insert OID = 3838 ( event_trigger PGNSP PGUID 4 t p P f t \054 0 0 0 event_trigger_in event_trigger_out - - - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -#define EVTTRIGGEROID 3838 -DATA(insert OID = 2280 ( language_handler PGNSP PGUID 4 t p P f t \054 0 0 0 language_handler_in language_handler_out - - - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -#define LANGUAGE_HANDLEROID 2280 -DATA(insert OID = 2281 ( internal PGNSP PGUID SIZEOF_POINTER t p P f t \054 0 0 0 internal_in internal_out - - - - - ALIGNOF_POINTER p f 0 -1 0 0 _null_ _null_ _null_ )); -#define INTERNALOID 2281 -DATA(insert OID = 2282 ( opaque PGNSP PGUID 4 t p P f t \054 0 0 0 opaque_in opaque_out - - - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -#define OPAQUEOID 2282 -DATA(insert OID = 2283 ( anyelement PGNSP PGUID 4 t p P f t \054 0 0 0 anyelement_in anyelement_out - - - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -#define ANYELEMENTOID 2283 -DATA(insert OID = 2776 ( anynonarray PGNSP PGUID 4 t p P f t \054 0 0 0 anynonarray_in anynonarray_out - - - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -#define ANYNONARRAYOID 2776 -DATA(insert OID = 3500 ( anyenum PGNSP PGUID 4 t p P f t \054 0 0 0 anyenum_in anyenum_out - - - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -#define ANYENUMOID 3500 -DATA(insert OID = 3115 ( fdw_handler PGNSP PGUID 4 t p P f t \054 0 0 0 fdw_handler_in fdw_handler_out - - - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -#define FDW_HANDLEROID 3115 -DATA(insert OID = 325 ( index_am_handler PGNSP PGUID 4 t p P f t \054 0 0 0 index_am_handler_in index_am_handler_out - - - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -#define INDEX_AM_HANDLEROID 325 -DATA(insert OID = 3310 ( tsm_handler PGNSP PGUID 4 t p P f t \054 0 0 0 tsm_handler_in tsm_handler_out - - - - - i p f 0 -1 0 0 _null_ _null_ _null_ )); -#define TSM_HANDLEROID 3310 -DATA(insert OID = 3831 ( anyrange PGNSP PGUID -1 f p P f t \054 0 0 0 anyrange_in anyrange_out - - - - - d x f 0 -1 0 0 _null_ _null_ _null_ )); -#define ANYRANGEOID 3831 - - -/* - * macros - */ -#define TYPTYPE_BASE 'b' /* base type (ordinary scalar type) */ -#define TYPTYPE_COMPOSITE 'c' /* composite (e.g., table's rowtype) */ -#define TYPTYPE_DOMAIN 'd' /* domain over another type */ -#define TYPTYPE_ENUM 'e' /* enumerated type */ -#define TYPTYPE_PSEUDO 'p' /* pseudo-type */ -#define TYPTYPE_RANGE 'r' /* range type */ - -#define TYPCATEGORY_INVALID '\0' /* not an allowed category */ -#define TYPCATEGORY_ARRAY 'A' -#define TYPCATEGORY_BOOLEAN 'B' -#define TYPCATEGORY_COMPOSITE 'C' -#define TYPCATEGORY_DATETIME 'D' -#define TYPCATEGORY_ENUM 'E' -#define TYPCATEGORY_GEOMETRIC 'G' -#define TYPCATEGORY_NETWORK 'I' /* think INET */ -#define TYPCATEGORY_NUMERIC 'N' -#define TYPCATEGORY_PSEUDOTYPE 'P' -#define TYPCATEGORY_RANGE 'R' -#define TYPCATEGORY_STRING 'S' -#define TYPCATEGORY_TIMESPAN 'T' -#define TYPCATEGORY_USER 'U' -#define TYPCATEGORY_BITSTRING 'V' /* er ... "varbit"? */ -#define TYPCATEGORY_UNKNOWN 'X' - -/* Is a type OID a polymorphic pseudotype? (Beware of multiple evaluation) */ -#define IsPolymorphicType(typid) \ - ((typid) == ANYELEMENTOID || \ - (typid) == ANYARRAYOID || \ - (typid) == ANYNONARRAYOID || \ - (typid) == ANYENUMOID || \ - (typid) == ANYRANGEOID) - -#endif /* PG_TYPE_H */ diff --git a/codegen/src/sqlstate.rs b/codegen/src/sqlstate.rs index 791e47b11..d21b92eec 100644 --- a/codegen/src/sqlstate.rs +++ b/codegen/src/sqlstate.rs @@ -1,18 +1,18 @@ use linked_hash_map::LinkedHashMap; -use phf_codegen; use std::fs::File; use std::io::{BufWriter, Write}; -use std::path::Path; -const ERRCODES_TXT: &'static str = include_str!("errcodes.txt"); +const ERRCODES_TXT: &str = include_str!("errcodes.txt"); -pub fn build(path: &Path) { - let mut file = BufWriter::new(File::create(path.join("error/sqlstate.rs")).unwrap()); +pub fn build() { + let mut file = BufWriter::new(File::create("../tokio-postgres/src/error/sqlstate.rs").unwrap()); let codes = parse_codes(); make_type(&mut file); + make_code(&codes, &mut file); make_consts(&codes, &mut file); + make_inner(&codes, &mut file); make_map(&codes, &mut file); } @@ -20,7 +20,7 @@ fn parse_codes() -> LinkedHashMap> { let mut codes = LinkedHashMap::new(); for line in ERRCODES_TXT.lines() { - if line.starts_with("#") || line.starts_with("Section") || line.trim().is_empty() { + if line.starts_with('#') || line.starts_with("Section") || line.trim().is_empty() { continue; } @@ -39,28 +39,53 @@ fn make_type(file: &mut BufWriter) { write!( file, "// Autogenerated file - DO NOT EDIT -use phf; -use std::borrow::Cow; /// A SQLSTATE error code #[derive(PartialEq, Eq, Clone, Debug)] -pub struct SqlState(Cow<'static, str>); +pub struct SqlState(Inner); impl SqlState {{ /// Creates a `SqlState` from its error code. pub fn from_code(s: &str) -> SqlState {{ match SQLSTATE_MAP.get(s) {{ Some(state) => state.clone(), - None => SqlState(Cow::Owned(s.to_string())), + None => SqlState(Inner::Other(s.into())), }} }} +" + ) + .unwrap(); +} +fn make_code(codes: &LinkedHashMap>, file: &mut BufWriter) { + write!( + file, + r#" /// Returns the error code corresponding to the `SqlState`. pub fn code(&self) -> &str {{ - &self.0 + match &self.0 {{"#, + ) + .unwrap(); + + for code in codes.keys() { + write!( + file, + r#" + Inner::E{code} => "{code}","#, + code = code, + ) + .unwrap(); + } + + write!( + file, + r#" + Inner::Other(code) => code, + }} }} -" - ).unwrap(); + "# + ) + .unwrap(); } fn make_consts(codes: &LinkedHashMap>, file: &mut BufWriter) { @@ -70,28 +95,58 @@ fn make_consts(codes: &LinkedHashMap>, file: &mut BufWriter< file, r#" /// {code} - pub const {name}: SqlState = SqlState(Cow::Borrowed("{code}")); + pub const {name}: SqlState = SqlState(Inner::E{code}); "#, name = name, code = code, - ).unwrap(); + ) + .unwrap(); } } write!(file, "}}").unwrap(); } -fn make_map(codes: &LinkedHashMap>, file: &mut BufWriter) { +fn make_inner(codes: &LinkedHashMap>, file: &mut BufWriter) { write!( file, - " -#[cfg_attr(rustfmt, rustfmt_skip)] -static SQLSTATE_MAP: phf::Map<&'static str, SqlState> = " - ).unwrap(); + r#" + +#[derive(PartialEq, Eq, Clone, Debug)] +#[allow(clippy::upper_case_acronyms)] +enum Inner {{"#, + ) + .unwrap(); + for code in codes.keys() { + write!( + file, + r#" + E{},"#, + code, + ) + .unwrap(); + } + write!( + file, + r#" + Other(Box), +}} + "#, + ) + .unwrap(); +} + +fn make_map(codes: &LinkedHashMap>, file: &mut BufWriter) { let mut builder = phf_codegen::Map::new(); for (code, names) in codes { builder.entry(&**code, &format!("SqlState::{}", &names[0])); } - builder.build(file).unwrap(); - write!(file, ";\n").unwrap(); + write!( + file, + " +#[rustfmt::skip] +static SQLSTATE_MAP: phf::Map<&'static str, SqlState> = \n{};\n", + builder.build() + ) + .unwrap(); } diff --git a/codegen/src/type_gen.rs b/codegen/src/type_gen.rs index 0f742d696..fd7a56450 100644 --- a/codegen/src/type_gen.rs +++ b/codegen/src/type_gen.rs @@ -1,29 +1,30 @@ use marksman_escape::Escape; use regex::Regex; -use std::collections::BTreeMap; +use std::collections::{BTreeMap, HashMap}; +use std::fmt::Write as _; use std::fs::File; use std::io::{BufWriter, Write}; -use std::path::Path; +use std::iter; +use std::str; -use snake_to_camel; +use crate::snake_to_camel; -const PG_TYPE_H: &'static str = include_str!("pg_type.h"); -const PG_RANGE_H: &'static str = include_str!("pg_range.h"); +const PG_TYPE_DAT: &str = include_str!("pg_type.dat"); +const PG_RANGE_DAT: &str = include_str!("pg_range.dat"); struct Type { - name: &'static str, + name: String, variant: String, ident: String, - kind: &'static str, + kind: String, + typtype: Option, element: u32, doc: String, } -pub fn build(path: &Path) { - let mut file = BufWriter::new(File::create(path.join("types/type_gen.rs")).unwrap()); - - let ranges = parse_ranges(); - let types = parse_types(&ranges); +pub fn build() { + let mut file = BufWriter::new(File::create("../postgres-types/src/type_gen.rs").unwrap()); + let types = parse_types(); make_header(&mut file); make_enum(&mut file, &types); @@ -31,85 +32,245 @@ pub fn build(path: &Path) { make_consts(&mut file, &types); } -fn parse_ranges() -> BTreeMap { - let mut ranges = BTreeMap::new(); +struct DatParser<'a> { + it: iter::Peekable>, + s: &'a str, +} - for line in PG_RANGE_H.lines() { - if !line.starts_with("DATA") { - continue; +impl<'a> DatParser<'a> { + fn new(s: &'a str) -> DatParser<'a> { + DatParser { + it: s.char_indices().peekable(), + s, } + } - let split = line.split_whitespace().collect::>(); + fn parse_array(&mut self) -> Vec> { + self.eat('['); + let mut vec = vec![]; + while !self.try_eat(']') { + let object = self.parse_object(); + vec.push(object); + } + self.eof(); - let oid = split[2].parse().unwrap(); - let element = split[3].parse().unwrap(); + vec + } + + fn parse_object(&mut self) -> HashMap { + let mut object = HashMap::new(); + + self.eat('{'); + loop { + let key = self.parse_ident(); + self.eat('='); + self.eat('>'); + let value = self.parse_string(); + object.insert(key, value); + if !self.try_eat(',') { + break; + } + } + self.eat('}'); + self.eat(','); - ranges.insert(oid, element); + object + } + + fn parse_ident(&mut self) -> String { + self.skip_ws(); + + let start = match self.it.peek() { + Some((i, _)) => *i, + None => return "".to_string(), + }; + + loop { + match self.it.peek() { + Some((_, 'a'..='z')) | Some((_, '_')) => { + self.it.next(); + } + Some((i, _)) => return self.s[start..*i].to_string(), + None => return self.s[start..].to_string(), + } + } + } + + fn parse_string(&mut self) -> String { + self.skip_ws(); + + let mut s = String::new(); + + self.eat('\''); + loop { + match self.it.next() { + Some((_, '\'')) => return s, + Some((_, '\\')) => { + let (_, ch) = self.it.next().expect("unexpected eof"); + s.push(ch); + } + Some((_, ch)) => s.push(ch), + None => panic!("unexpected eof"), + } + } } - ranges + fn eat(&mut self, target: char) { + self.skip_ws(); + + match self.it.next() { + Some((_, ch)) if ch == target => {} + Some((_, ch)) => panic!("expected {} but got {}", target, ch), + None => panic!("expected {} but got eof", target), + } + } + + fn try_eat(&mut self, target: char) -> bool { + if self.peek(target) { + self.eat(target); + true + } else { + false + } + } + + fn peek(&mut self, target: char) -> bool { + self.skip_ws(); + + matches!(self.it.peek(), Some((_, ch)) if *ch == target) + } + + fn eof(&mut self) { + self.skip_ws(); + if let Some((_, ch)) = self.it.next() { + panic!("expected eof but got {}", ch); + } + } + + fn skip_ws(&mut self) { + loop { + match self.it.peek() { + Some(&(_, '#')) => self.skip_to('\n'), + Some(&(_, '\n')) | Some(&(_, ' ')) | Some(&(_, '\t')) => { + self.it.next(); + } + _ => break, + } + } + } + + fn skip_to(&mut self, target: char) { + for (_, ch) in &mut self.it { + if ch == target { + break; + } + } + } } -fn parse_types(ranges: &BTreeMap) -> BTreeMap { - let doc_re = Regex::new(r#"DESCR\("([^"]+)"\)"#).unwrap(); +fn parse_types() -> BTreeMap { + let raw_types = DatParser::new(PG_TYPE_DAT).parse_array(); + let raw_ranges = DatParser::new(PG_RANGE_DAT).parse_array(); + + let oids_by_name = raw_types + .iter() + .map(|m| (m["typname"].clone(), m["oid"].parse::().unwrap())) + .collect::>(); + + let range_elements = raw_ranges + .iter() + .map(|m| { + ( + oids_by_name[&*m["rngtypid"]], + oids_by_name[&*m["rngsubtype"]], + ) + }) + .collect::>(); + let multi_range_elements = raw_ranges + .iter() + .map(|m| { + ( + oids_by_name[&*m["rngmultitypid"]], + oids_by_name[&*m["rngsubtype"]], + ) + }) + .collect::>(); + let range_vector_re = Regex::new("(range|vector)$").unwrap(); let array_re = Regex::new("^_(.*)").unwrap(); let mut types = BTreeMap::new(); - let mut lines = PG_TYPE_H.lines().peekable(); - while let Some(line) = lines.next() { - if !line.starts_with("DATA") { - continue; - } - - let split = line.split_whitespace().collect::>(); - - let oid = split[3].parse().unwrap(); + for raw_type in raw_types { + let oid = raw_type["oid"].parse::().unwrap(); - let name = split[5]; + let name = raw_type["typname"].clone(); - let ident = range_vector_re.replace(name, "_$1"); - let ident = array_re.replace(&ident, "$1_array"); + let ident = range_vector_re.replace(&name, "_$1"); + let ident = array_re.replace(&ident, "${1}_array"); let variant = snake_to_camel(&ident); let ident = ident.to_ascii_uppercase(); - let kind = split[11]; + let kind = raw_type["typcategory"].clone(); // we need to be able to pull composite fields and enum variants at runtime if kind == "C" || kind == "E" { continue; } - let element = if let Some(&element) = ranges.get(&oid) { - element - } else { - split[16].parse().unwrap() + let typtype = raw_type.get("typtype").cloned(); + + let element = match &*kind { + "R" => match typtype + .as_ref() + .expect("range type must have typtype") + .as_str() + { + "r" => range_elements[&oid], + "m" => multi_range_elements[&oid], + typtype => panic!("invalid range typtype {}", typtype), + }, + "A" => oids_by_name[&raw_type["typelem"]], + _ => 0, }; - let doc = array_re.replace(name, "$1[]"); - let mut doc = doc.to_ascii_uppercase(); - - let descr = lines - .peek() - .and_then(|line| doc_re.captures(line)) - .and_then(|captures| captures.at(1)); - if let Some(descr) = descr { - doc.push_str(" - "); - doc.push_str(descr); + let doc_name = array_re.replace(&name, "$1[]").to_ascii_uppercase(); + let mut doc = doc_name.clone(); + if let Some(descr) = raw_type.get("descr") { + write!(doc, " - {}", descr).unwrap(); } let doc = Escape::new(doc.as_bytes().iter().cloned()).collect(); let doc = String::from_utf8(doc).unwrap(); + if let Some(array_type_oid) = raw_type.get("array_type_oid") { + let array_type_oid = array_type_oid.parse::().unwrap(); + + let name = format!("_{}", name); + let variant = format!("{}Array", variant); + let doc = format!("{}[]", doc_name); + let ident = format!("{}_ARRAY", ident); + + let type_ = Type { + name, + variant, + ident, + kind: "A".to_string(), + typtype: None, + element: oid, + doc, + }; + types.insert(array_type_oid, type_); + } + let type_ = Type { name, variant, ident, kind, + typtype, element, doc, }; - types.insert(oid, type_); } @@ -122,9 +283,9 @@ fn make_header(w: &mut BufWriter) { "// Autogenerated file - DO NOT EDIT use std::sync::Arc; -use types::{{Type, Oid, Kind}}; +use crate::{{Type, Oid, Kind}}; -#[derive(PartialEq, Eq, Debug)] +#[derive(PartialEq, Eq, Debug, Hash)] pub struct Other {{ pub name: String, pub oid: Oid, @@ -132,16 +293,18 @@ pub struct Other {{ pub schema: String, }} " - ).unwrap(); + ) + .unwrap(); } fn make_enum(w: &mut BufWriter, types: &BTreeMap) { write!( w, " -#[derive(PartialEq, Eq, Clone, Debug)] +#[derive(PartialEq, Eq, Clone, Debug, Hash)] pub enum Inner {{" - ).unwrap(); + ) + .unwrap(); for type_ in types.values() { write!( @@ -149,7 +312,8 @@ pub enum Inner {{" " {},", type_.variant - ).unwrap(); + ) + .unwrap(); } write!( @@ -159,7 +323,8 @@ pub enum Inner {{" }} " - ).unwrap(); + ) + .unwrap(); } fn make_impl(w: &mut BufWriter, types: &BTreeMap) { @@ -169,110 +334,109 @@ fn make_impl(w: &mut BufWriter, types: &BTreeMap) { pub fn from_oid(oid: Oid) -> Option {{ match oid {{ ", - ).unwrap(); + ) + .unwrap(); for (oid, type_) in types { - write!( - w, - " {} => Some(Inner::{}), -", - oid, type_.variant - ).unwrap(); + writeln!(w, " {} => Some(Inner::{}),", oid, type_.variant).unwrap(); } - write!( + writeln!( w, " _ => None, }} }} pub fn oid(&self) -> Oid {{ - match *self {{ -", - ).unwrap(); + match *self {{", + ) + .unwrap(); for (oid, type_) in types { - write!( - w, - " Inner::{} => {}, -", - type_.variant, oid - ).unwrap(); + writeln!(w, " Inner::{} => {},", type_.variant, oid).unwrap(); } - write!( + writeln!( w, " Inner::Other(ref u) => u.oid, }} }} pub fn kind(&self) -> &Kind {{ - match *self {{ -", - ).unwrap(); + match *self {{", + ) + .unwrap(); for type_ in types.values() { - let kind = match type_.kind { + let kind = match &*type_.kind { "P" => "Pseudo".to_owned(), "A" => format!("Array(Type(Inner::{}))", types[&type_.element].variant), - "R" => format!("Range(Type(Inner::{}))", types[&type_.element].variant), + "R" => match type_ + .typtype + .as_ref() + .expect("range type must have typtype") + .as_str() + { + "r" => format!("Range(Type(Inner::{}))", types[&type_.element].variant), + "m" => format!("Multirange(Type(Inner::{}))", types[&type_.element].variant), + typtype => panic!("invalid range typtype {}", typtype), + }, _ => "Simple".to_owned(), }; - write!( + writeln!( w, " Inner::{} => {{ - const V: &'static Kind = &Kind::{}; - V - }} -", + &Kind::{} + }}", type_.variant, kind - ).unwrap(); + ) + .unwrap(); } - write!( + writeln!( w, r#" Inner::Other(ref u) => &u.kind, }} }} pub fn name(&self) -> &str {{ - match *self {{ -"#, - ).unwrap(); + match *self {{"#, + ) + .unwrap(); for type_ in types.values() { - write!( + writeln!( w, - r#" Inner::{} => "{}", -"#, + r#" Inner::{} => "{}","#, type_.variant, type_.name - ).unwrap(); + ) + .unwrap(); } - write!( + writeln!( w, " Inner::Other(ref u) => &u.name, }} }} -}} -" - ).unwrap(); +}}" + ) + .unwrap(); } fn make_consts(w: &mut BufWriter, types: &BTreeMap) { write!(w, "impl Type {{").unwrap(); for type_ in types.values() { - write!( + writeln!( w, " /// {docs} - pub const {ident}: Type = Type(Inner::{variant}); -", + pub const {ident}: Type = Type(Inner::{variant});", docs = type_.doc, ident = type_.ident, variant = type_.variant - ).unwrap(); + ) + .unwrap(); } write!(w, "}}").unwrap(); diff --git a/docker-compose.yml b/docker-compose.yml index c149834e2..991df2d01 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,6 +1,10 @@ version: '2' services: postgres: - image: "sfackler/rust-postgres-test:4" + image: docker.io/postgres:17 ports: - - 5433:5433 + - 5433:5433 + volumes: + - ./docker/sql_setup.sh:/docker-entrypoint-initdb.d/sql_setup.sh + environment: + POSTGRES_PASSWORD: postgres diff --git a/docker/Dockerfile b/docker/Dockerfile deleted file mode 100644 index 9e2642ba1..000000000 --- a/docker/Dockerfile +++ /dev/null @@ -1,3 +0,0 @@ -FROM postgres:11-beta1 - -COPY sql_setup.sh /docker-entrypoint-initdb.d/ diff --git a/docker/sql_setup.sh b/docker/sql_setup.sh index 422dcbda9..0315ac805 100755 --- a/docker/sql_setup.sh +++ b/docker/sql_setup.sh @@ -96,4 +96,5 @@ psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" <<-EOSQL CREATE ROLE ssl_user LOGIN; CREATE EXTENSION hstore; CREATE EXTENSION citext; + CREATE EXTENSION ltree; EOSQL diff --git a/postgres-derive-test/Cargo.toml b/postgres-derive-test/Cargo.toml new file mode 100644 index 000000000..24fd1614f --- /dev/null +++ b/postgres-derive-test/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "postgres-derive-test" +version = "0.1.0" +authors = ["Steven Fackler "] +edition = "2018" + +[dev-dependencies] +trybuild = "1.0" + +postgres-types = { path = "../postgres-types", features = ["derive"] } +postgres = { path = "../postgres" } diff --git a/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.rs b/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.rs new file mode 100644 index 000000000..52d0ba8f6 --- /dev/null +++ b/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.rs @@ -0,0 +1,31 @@ +use postgres_types::{FromSql, ToSql}; + +#[derive(ToSql, Debug)] +#[postgres(allow_mismatch)] +struct ToSqlAllowMismatchStruct { + a: i32, +} + +#[derive(FromSql, Debug)] +#[postgres(allow_mismatch)] +struct FromSqlAllowMismatchStruct { + a: i32, +} + +#[derive(ToSql, Debug)] +#[postgres(allow_mismatch)] +struct ToSqlAllowMismatchTupleStruct(i32, i32); + +#[derive(FromSql, Debug)] +#[postgres(allow_mismatch)] +struct FromSqlAllowMismatchTupleStruct(i32, i32); + +#[derive(FromSql, Debug)] +#[postgres(transparent, allow_mismatch)] +struct TransparentFromSqlAllowMismatchStruct(i32); + +#[derive(FromSql, Debug)] +#[postgres(allow_mismatch, transparent)] +struct AllowMismatchFromSqlTransparentStruct(i32); + +fn main() {} diff --git a/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.stderr b/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.stderr new file mode 100644 index 000000000..a8e573248 --- /dev/null +++ b/postgres-derive-test/src/compile-fail/invalid-allow-mismatch.stderr @@ -0,0 +1,43 @@ +error: #[postgres(allow_mismatch)] may only be applied to enums + --> src/compile-fail/invalid-allow-mismatch.rs:4:1 + | +4 | / #[postgres(allow_mismatch)] +5 | | struct ToSqlAllowMismatchStruct { +6 | | a: i32, +7 | | } + | |_^ + +error: #[postgres(allow_mismatch)] may only be applied to enums + --> src/compile-fail/invalid-allow-mismatch.rs:10:1 + | +10 | / #[postgres(allow_mismatch)] +11 | | struct FromSqlAllowMismatchStruct { +12 | | a: i32, +13 | | } + | |_^ + +error: #[postgres(allow_mismatch)] may only be applied to enums + --> src/compile-fail/invalid-allow-mismatch.rs:16:1 + | +16 | / #[postgres(allow_mismatch)] +17 | | struct ToSqlAllowMismatchTupleStruct(i32, i32); + | |_______________________________________________^ + +error: #[postgres(allow_mismatch)] may only be applied to enums + --> src/compile-fail/invalid-allow-mismatch.rs:20:1 + | +20 | / #[postgres(allow_mismatch)] +21 | | struct FromSqlAllowMismatchTupleStruct(i32, i32); + | |_________________________________________________^ + +error: #[postgres(transparent)] is not allowed with #[postgres(allow_mismatch)] + --> src/compile-fail/invalid-allow-mismatch.rs:24:25 + | +24 | #[postgres(transparent, allow_mismatch)] + | ^^^^^^^^^^^^^^ + +error: #[postgres(allow_mismatch)] is not allowed with #[postgres(transparent)] + --> src/compile-fail/invalid-allow-mismatch.rs:28:28 + | +28 | #[postgres(allow_mismatch, transparent)] + | ^^^^^^^^^^^ diff --git a/postgres-derive-test/src/compile-fail/invalid-transparent.rs b/postgres-derive-test/src/compile-fail/invalid-transparent.rs new file mode 100644 index 000000000..43bd48266 --- /dev/null +++ b/postgres-derive-test/src/compile-fail/invalid-transparent.rs @@ -0,0 +1,35 @@ +use postgres_types::{FromSql, ToSql}; + +#[derive(ToSql, Debug)] +#[postgres(transparent)] +struct ToSqlTransparentStruct { + a: i32 +} + +#[derive(FromSql, Debug)] +#[postgres(transparent)] +struct FromSqlTransparentStruct { + a: i32 +} + +#[derive(ToSql, Debug)] +#[postgres(transparent)] +enum ToSqlTransparentEnum { + Foo +} + +#[derive(FromSql, Debug)] +#[postgres(transparent)] +enum FromSqlTransparentEnum { + Foo +} + +#[derive(ToSql, Debug)] +#[postgres(transparent)] +struct ToSqlTransparentTwoFieldTupleStruct(i32, i32); + +#[derive(FromSql, Debug)] +#[postgres(transparent)] +struct FromSqlTransparentTwoFieldTupleStruct(i32, i32); + +fn main() {} diff --git a/postgres-derive-test/src/compile-fail/invalid-transparent.stderr b/postgres-derive-test/src/compile-fail/invalid-transparent.stderr new file mode 100644 index 000000000..42e49f874 --- /dev/null +++ b/postgres-derive-test/src/compile-fail/invalid-transparent.stderr @@ -0,0 +1,49 @@ +error: #[postgres(transparent)] may only be applied to single field tuple structs + --> src/compile-fail/invalid-transparent.rs:4:1 + | +4 | / #[postgres(transparent)] +5 | | struct ToSqlTransparentStruct { +6 | | a: i32 +7 | | } + | |_^ + +error: #[postgres(transparent)] may only be applied to single field tuple structs + --> src/compile-fail/invalid-transparent.rs:10:1 + | +10 | / #[postgres(transparent)] +11 | | struct FromSqlTransparentStruct { +12 | | a: i32 +13 | | } + | |_^ + +error: #[postgres(transparent)] may only be applied to single field tuple structs + --> src/compile-fail/invalid-transparent.rs:16:1 + | +16 | / #[postgres(transparent)] +17 | | enum ToSqlTransparentEnum { +18 | | Foo +19 | | } + | |_^ + +error: #[postgres(transparent)] may only be applied to single field tuple structs + --> src/compile-fail/invalid-transparent.rs:22:1 + | +22 | / #[postgres(transparent)] +23 | | enum FromSqlTransparentEnum { +24 | | Foo +25 | | } + | |_^ + +error: #[postgres(transparent)] may only be applied to single field tuple structs + --> src/compile-fail/invalid-transparent.rs:28:1 + | +28 | / #[postgres(transparent)] +29 | | struct ToSqlTransparentTwoFieldTupleStruct(i32, i32); + | |_____________________________________________________^ + +error: #[postgres(transparent)] may only be applied to single field tuple structs + --> src/compile-fail/invalid-transparent.rs:32:1 + | +32 | / #[postgres(transparent)] +33 | | struct FromSqlTransparentTwoFieldTupleStruct(i32, i32); + | |_______________________________________________________^ diff --git a/postgres-derive-test/src/compile-fail/invalid-types.rs b/postgres-derive-test/src/compile-fail/invalid-types.rs new file mode 100644 index 000000000..ef41ac820 --- /dev/null +++ b/postgres-derive-test/src/compile-fail/invalid-types.rs @@ -0,0 +1,25 @@ +use postgres_types::{FromSql, ToSql}; + +#[derive(ToSql)] +struct ToSqlUnit; + +#[derive(FromSql)] +struct FromSqlUnit; + +#[derive(ToSql)] +struct ToSqlTuple(i32, i32); + +#[derive(FromSql)] +struct FromSqlTuple(i32, i32); + +#[derive(ToSql)] +enum ToSqlEnum { + Foo(i32), +} + +#[derive(FromSql)] +enum FromSqlEnum { + Foo(i32), +} + +fn main() {} diff --git a/postgres-derive-test/src/compile-fail/invalid-types.stderr b/postgres-derive-test/src/compile-fail/invalid-types.stderr new file mode 100644 index 000000000..9b563d58b --- /dev/null +++ b/postgres-derive-test/src/compile-fail/invalid-types.stderr @@ -0,0 +1,35 @@ +error: #[derive(ToSql)] may only be applied to structs, single field tuple structs, and enums + --> $DIR/invalid-types.rs:4:1 + | +4 | struct ToSqlUnit; + | ^^^^^^^^^^^^^^^^^ + +error: #[derive(FromSql)] may only be applied to structs, single field tuple structs, and enums + --> $DIR/invalid-types.rs:7:1 + | +7 | struct FromSqlUnit; + | ^^^^^^^^^^^^^^^^^^^ + +error: #[derive(ToSql)] may only be applied to structs, single field tuple structs, and enums + --> $DIR/invalid-types.rs:10:1 + | +10 | struct ToSqlTuple(i32, i32); + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +error: #[derive(FromSql)] may only be applied to structs, single field tuple structs, and enums + --> $DIR/invalid-types.rs:13:1 + | +13 | struct FromSqlTuple(i32, i32); + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +error: non-C-like enums are not supported + --> $DIR/invalid-types.rs:17:5 + | +17 | Foo(i32), + | ^^^^^^^^ + +error: non-C-like enums are not supported + --> $DIR/invalid-types.rs:22:5 + | +22 | Foo(i32), + | ^^^^^^^^ diff --git a/postgres-derive-test/src/compile-fail/unknown-override.rs b/postgres-derive-test/src/compile-fail/unknown-override.rs new file mode 100644 index 000000000..e4fffd540 --- /dev/null +++ b/postgres-derive-test/src/compile-fail/unknown-override.rs @@ -0,0 +1,15 @@ +use postgres_types::{FromSql, ToSql}; + +#[derive(FromSql)] +#[postgres(foo = "bar")] +struct Foo { + a: i32, +} + +#[derive(ToSql)] +#[postgres(foo = "bar")] +struct Bar { + a: i32, +} + +fn main() {} diff --git a/postgres-derive-test/src/compile-fail/unknown-override.stderr b/postgres-derive-test/src/compile-fail/unknown-override.stderr new file mode 100644 index 000000000..b7719e3c2 --- /dev/null +++ b/postgres-derive-test/src/compile-fail/unknown-override.stderr @@ -0,0 +1,11 @@ +error: unknown override + --> $DIR/unknown-override.rs:4:12 + | +4 | #[postgres(foo = "bar")] + | ^^^ + +error: unknown override + --> $DIR/unknown-override.rs:10:12 + | +10 | #[postgres(foo = "bar")] + | ^^^ diff --git a/postgres-derive-test/src/composites.rs b/postgres-derive-test/src/composites.rs new file mode 100644 index 000000000..50a22790d --- /dev/null +++ b/postgres-derive-test/src/composites.rs @@ -0,0 +1,348 @@ +use crate::{test_type, test_type_asymmetric}; +use postgres::{Client, NoTls}; +use postgres_types::{FromSql, ToSql, WrongType}; +use std::error::Error; + +#[test] +fn defaults() { + #[derive(FromSql, ToSql, Debug, PartialEq)] + struct InventoryItem { + name: String, + supplier_id: i32, + price: Option, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.batch_execute( + "CREATE TYPE pg_temp.\"InventoryItem\" AS ( + name TEXT, + supplier_id INT, + price DOUBLE PRECISION + );", + ) + .unwrap(); + + let item = InventoryItem { + name: "foobar".to_owned(), + supplier_id: 100, + price: Some(15.50), + }; + + let item_null = InventoryItem { + name: "foobar".to_owned(), + supplier_id: 100, + price: None, + }; + + test_type( + &mut conn, + "\"InventoryItem\"", + &[ + (item, "ROW('foobar', 100, 15.50)"), + (item_null, "ROW('foobar', 100, NULL)"), + ], + ); +} + +#[test] +fn name_overrides() { + #[derive(FromSql, ToSql, Debug, PartialEq)] + #[postgres(name = "inventory_item")] + struct InventoryItem { + #[postgres(name = "name")] + _name: String, + #[postgres(name = "supplier_id")] + _supplier_id: i32, + #[postgres(name = "price")] + _price: Option, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.batch_execute( + "CREATE TYPE pg_temp.inventory_item AS ( + name TEXT, + supplier_id INT, + price DOUBLE PRECISION + );", + ) + .unwrap(); + + let item = InventoryItem { + _name: "foobar".to_owned(), + _supplier_id: 100, + _price: Some(15.50), + }; + + let item_null = InventoryItem { + _name: "foobar".to_owned(), + _supplier_id: 100, + _price: None, + }; + + test_type( + &mut conn, + "inventory_item", + &[ + (item, "ROW('foobar', 100, 15.50)"), + (item_null, "ROW('foobar', 100, NULL)"), + ], + ); +} + +#[test] +fn rename_all_overrides() { + #[derive(FromSql, ToSql, Debug, PartialEq)] + #[postgres(name = "inventory_item", rename_all = "SCREAMING_SNAKE_CASE")] + struct InventoryItem { + name: String, + supplier_id: i32, + #[postgres(name = "Price")] + price: Option, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.batch_execute( + "CREATE TYPE pg_temp.inventory_item AS ( + \"NAME\" TEXT, + \"SUPPLIER_ID\" INT, + \"Price\" DOUBLE PRECISION + );", + ) + .unwrap(); + + let item = InventoryItem { + name: "foobar".to_owned(), + supplier_id: 100, + price: Some(15.50), + }; + + let item_null = InventoryItem { + name: "foobar".to_owned(), + supplier_id: 100, + price: None, + }; + + test_type( + &mut conn, + "inventory_item", + &[ + (item, "ROW('foobar', 100, 15.50)"), + (item_null, "ROW('foobar', 100, NULL)"), + ], + ); +} + +#[test] +fn wrong_name() { + #[derive(FromSql, ToSql, Debug, PartialEq)] + struct InventoryItem { + name: String, + supplier_id: i32, + price: Option, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.batch_execute( + "CREATE TYPE pg_temp.inventory_item AS ( + name TEXT, + supplier_id INT, + price DOUBLE PRECISION + );", + ) + .unwrap(); + + let item = InventoryItem { + name: "foobar".to_owned(), + supplier_id: 100, + price: Some(15.50), + }; + + let err = conn + .execute("SELECT $1::inventory_item", &[&item]) + .unwrap_err(); + assert!(err.source().unwrap().is::()); +} + +#[test] +fn extra_field() { + #[derive(FromSql, ToSql, Debug, PartialEq)] + #[postgres(name = "inventory_item")] + struct InventoryItem { + name: String, + supplier_id: i32, + price: Option, + foo: i32, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.batch_execute( + "CREATE TYPE pg_temp.inventory_item AS ( + name TEXT, + supplier_id INT, + price DOUBLE PRECISION + );", + ) + .unwrap(); + + let item = InventoryItem { + name: "foobar".to_owned(), + supplier_id: 100, + price: Some(15.50), + foo: 0, + }; + + let err = conn + .execute("SELECT $1::inventory_item", &[&item]) + .unwrap_err(); + assert!(err.source().unwrap().is::()); +} + +#[test] +fn missing_field() { + #[derive(FromSql, ToSql, Debug, PartialEq)] + #[postgres(name = "inventory_item")] + struct InventoryItem { + name: String, + supplier_id: i32, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.batch_execute( + "CREATE TYPE pg_temp.inventory_item AS ( + name TEXT, + supplier_id INT, + price DOUBLE PRECISION + );", + ) + .unwrap(); + + let item = InventoryItem { + name: "foobar".to_owned(), + supplier_id: 100, + }; + + let err = conn + .execute("SELECT $1::inventory_item", &[&item]) + .unwrap_err(); + assert!(err.source().unwrap().is::()); +} + +#[test] +fn wrong_type() { + #[derive(FromSql, ToSql, Debug, PartialEq)] + #[postgres(name = "inventory_item")] + struct InventoryItem { + name: String, + supplier_id: i32, + price: i32, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.batch_execute( + "CREATE TYPE pg_temp.inventory_item AS ( + name TEXT, + supplier_id INT, + price DOUBLE PRECISION + );", + ) + .unwrap(); + + let item = InventoryItem { + name: "foobar".to_owned(), + supplier_id: 100, + price: 0, + }; + + let err = conn + .execute("SELECT $1::inventory_item", &[&item]) + .unwrap_err(); + assert!(err.source().unwrap().is::()); +} + +#[test] +fn raw_ident_field() { + #[derive(FromSql, ToSql, Debug, PartialEq)] + #[postgres(name = "inventory_item")] + struct InventoryItem { + r#type: String, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.batch_execute( + "CREATE TYPE pg_temp.inventory_item AS ( + type TEXT + )", + ) + .unwrap(); + + let item = InventoryItem { + r#type: "foo".to_owned(), + }; + + test_type(&mut conn, "inventory_item", &[(item, "ROW('foo')")]); +} + +#[test] +fn generics() { + #[derive(FromSql, Debug, PartialEq)] + struct InventoryItem + where + U: Clone, + { + name: String, + supplier_id: T, + price: Option, + } + + // doesn't make sense to implement derived FromSql on a type with borrows + #[derive(ToSql, Debug, PartialEq)] + #[postgres(name = "InventoryItem")] + struct InventoryItemRef<'a, T: 'a + Clone, U> + where + U: 'a + Clone, + { + name: &'a str, + supplier_id: &'a T, + price: Option<&'a U>, + } + + const NAME: &str = "foobar"; + const SUPPLIER_ID: i32 = 100; + const PRICE: f64 = 15.50; + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.batch_execute( + "CREATE TYPE pg_temp.\"InventoryItem\" AS ( + name TEXT, + supplier_id INT, + price DOUBLE PRECISION + );", + ) + .unwrap(); + + let item = InventoryItemRef { + name: NAME, + supplier_id: &SUPPLIER_ID, + price: Some(&PRICE), + }; + + let item_null = InventoryItemRef { + name: NAME, + supplier_id: &SUPPLIER_ID, + price: None, + }; + + test_type_asymmetric( + &mut conn, + "\"InventoryItem\"", + &[ + (item, "ROW('foobar', 100, 15.50)"), + (item_null, "ROW('foobar', 100, NULL)"), + ], + |t: &InventoryItemRef, f: &InventoryItem| { + t.name == f.name.as_str() + && t.supplier_id == &f.supplier_id + && t.price == f.price.as_ref() + }, + ); +} diff --git a/postgres-derive-test/src/domains.rs b/postgres-derive-test/src/domains.rs new file mode 100644 index 000000000..25674f75e --- /dev/null +++ b/postgres-derive-test/src/domains.rs @@ -0,0 +1,121 @@ +use crate::test_type; +use postgres::{Client, NoTls}; +use postgres_types::{FromSql, ToSql, WrongType}; +use std::error::Error; + +#[test] +fn defaults() { + #[derive(FromSql, ToSql, Debug, PartialEq)] + struct SessionId(Vec); + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute( + "CREATE DOMAIN pg_temp.\"SessionId\" AS bytea CHECK(octet_length(VALUE) = 16);", + &[], + ) + .unwrap(); + + test_type( + &mut conn, + "\"SessionId\"", + &[( + SessionId(b"0123456789abcdef".to_vec()), + "'0123456789abcdef'", + )], + ); +} + +#[test] +fn name_overrides() { + #[derive(FromSql, ToSql, Debug, PartialEq)] + #[postgres(name = "session_id")] + struct SessionId(Vec); + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute( + "CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16);", + &[], + ) + .unwrap(); + + test_type( + &mut conn, + "session_id", + &[( + SessionId(b"0123456789abcdef".to_vec()), + "'0123456789abcdef'", + )], + ); +} + +#[test] +fn wrong_name() { + #[derive(FromSql, ToSql, Debug, PartialEq)] + struct SessionId(Vec); + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute( + "CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16);", + &[], + ) + .unwrap(); + + let err = conn + .execute("SELECT $1::session_id", &[&SessionId(vec![])]) + .unwrap_err(); + assert!(err.source().unwrap().is::()); +} + +#[test] +fn wrong_type() { + #[derive(FromSql, ToSql, Debug, PartialEq)] + #[postgres(name = "session_id")] + struct SessionId(i32); + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute( + "CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16);", + &[], + ) + .unwrap(); + + let err = conn + .execute("SELECT $1::session_id", &[&SessionId(0)]) + .unwrap_err(); + assert!(err.source().unwrap().is::()); +} + +#[test] +fn domain_in_composite() { + #[derive(FromSql, ToSql, Debug, PartialEq)] + #[postgres(name = "domain")] + struct Domain(String); + + #[derive(FromSql, ToSql, Debug, PartialEq)] + #[postgres(name = "composite")] + struct Composite { + domain: Domain, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.batch_execute( + " + CREATE DOMAIN pg_temp.domain AS TEXT;\ + CREATE TYPE pg_temp.composite AS ( + domain domain + ); + ", + ) + .unwrap(); + + test_type( + &mut conn, + "composite", + &[( + Composite { + domain: Domain("hello".to_string()), + }, + "ROW('hello')", + )], + ); +} diff --git a/postgres-derive-test/src/enums.rs b/postgres-derive-test/src/enums.rs new file mode 100644 index 000000000..f3e6c488c --- /dev/null +++ b/postgres-derive-test/src/enums.rs @@ -0,0 +1,203 @@ +use crate::test_type; +use postgres::{error::DbError, Client, NoTls}; +use postgres_types::{FromSql, ToSql, WrongType}; +use std::error::Error; + +#[test] +fn defaults() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + enum Foo { + Bar, + Baz, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute("CREATE TYPE pg_temp.\"Foo\" AS ENUM ('Bar', 'Baz')", &[]) + .unwrap(); + + test_type( + &mut conn, + "\"Foo\"", + &[(Foo::Bar, "'Bar'"), (Foo::Baz, "'Baz'")], + ); +} + +#[test] +fn name_overrides() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(name = "mood")] + enum Mood { + #[postgres(name = "sad")] + Sad, + #[postgres(name = "ok")] + Ok, + #[postgres(name = "happy")] + Happy, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute( + "CREATE TYPE pg_temp.mood AS ENUM ('sad', 'ok', 'happy')", + &[], + ) + .unwrap(); + + test_type( + &mut conn, + "mood", + &[ + (Mood::Sad, "'sad'"), + (Mood::Ok, "'ok'"), + (Mood::Happy, "'happy'"), + ], + ); +} + +#[test] +fn rename_all_overrides() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(name = "mood", rename_all = "snake_case")] + enum Mood { + VerySad, + #[postgres(name = "okay")] + Ok, + VeryHappy, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute( + "CREATE TYPE pg_temp.mood AS ENUM ('very_sad', 'okay', 'very_happy')", + &[], + ) + .unwrap(); + + test_type( + &mut conn, + "mood", + &[ + (Mood::VerySad, "'very_sad'"), + (Mood::Ok, "'okay'"), + (Mood::VeryHappy, "'very_happy'"), + ], + ); +} + +#[test] +fn wrong_name() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + enum Foo { + Bar, + Baz, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('Bar', 'Baz')", &[]) + .unwrap(); + + let err = conn.execute("SELECT $1::foo", &[&Foo::Bar]).unwrap_err(); + assert!(err.source().unwrap().is::()); +} + +#[test] +fn extra_variant() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(name = "foo")] + enum Foo { + Bar, + Baz, + Buz, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('Bar', 'Baz')", &[]) + .unwrap(); + + let err = conn.execute("SELECT $1::foo", &[&Foo::Bar]).unwrap_err(); + assert!(err.source().unwrap().is::()); +} + +#[test] +fn missing_variant() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(name = "foo")] + enum Foo { + Bar, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('Bar', 'Baz')", &[]) + .unwrap(); + + let err = conn.execute("SELECT $1::foo", &[&Foo::Bar]).unwrap_err(); + assert!(err.source().unwrap().is::()); +} + +#[test] +fn allow_mismatch_enums() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(allow_mismatch)] + enum Foo { + Bar, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute("CREATE TYPE pg_temp.\"Foo\" AS ENUM ('Bar', 'Baz')", &[]) + .unwrap(); + + let row = conn.query_one("SELECT $1::\"Foo\"", &[&Foo::Bar]).unwrap(); + assert_eq!(row.get::<_, Foo>(0), Foo::Bar); +} + +#[test] +fn missing_enum_variant() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(allow_mismatch)] + enum Foo { + Bar, + Buz, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute("CREATE TYPE pg_temp.\"Foo\" AS ENUM ('Bar', 'Baz')", &[]) + .unwrap(); + + let err = conn + .query_one("SELECT $1::\"Foo\"", &[&Foo::Buz]) + .unwrap_err(); + assert!(err.source().unwrap().is::()); +} + +#[test] +fn allow_mismatch_and_renaming() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(name = "foo", allow_mismatch)] + enum Foo { + #[postgres(name = "bar")] + Bar, + #[postgres(name = "buz")] + Buz, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('bar', 'baz', 'buz')", &[]) + .unwrap(); + + let row = conn.query_one("SELECT $1::foo", &[&Foo::Buz]).unwrap(); + assert_eq!(row.get::<_, Foo>(0), Foo::Buz); +} + +#[test] +fn wrong_name_and_allow_mismatch() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(allow_mismatch)] + enum Foo { + Bar, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('Bar', 'Baz')", &[]) + .unwrap(); + + let err = conn.query_one("SELECT $1::foo", &[&Foo::Bar]).unwrap_err(); + assert!(err.source().unwrap().is::()); +} diff --git a/postgres-derive-test/src/lib.rs b/postgres-derive-test/src/lib.rs new file mode 100644 index 000000000..f0534f32c --- /dev/null +++ b/postgres-derive-test/src/lib.rs @@ -0,0 +1,57 @@ +#![cfg(test)] + +use postgres::Client; +use postgres_types::{FromSqlOwned, ToSql}; +use std::fmt; + +mod composites; +mod domains; +mod enums; +mod transparent; + +pub fn test_type(conn: &mut Client, sql_type: &str, checks: &[(T, S)]) +where + T: PartialEq + FromSqlOwned + ToSql + Sync, + S: fmt::Display, +{ + for (val, repr) in checks.iter() { + let stmt = conn + .prepare(&format!("SELECT {}::{}", *repr, sql_type)) + .unwrap(); + let result = conn.query_one(&stmt, &[]).unwrap().get(0); + assert_eq!(val, &result); + + let stmt = conn.prepare(&format!("SELECT $1::{}", sql_type)).unwrap(); + let result = conn.query_one(&stmt, &[val]).unwrap().get(0); + assert_eq!(val, &result); + } +} + +pub fn test_type_asymmetric( + conn: &mut Client, + sql_type: &str, + checks: &[(T, S)], + cmp: C, +) where + T: ToSql + Sync, + F: FromSqlOwned, + S: fmt::Display, + C: Fn(&T, &F) -> bool, +{ + for (val, repr) in checks.iter() { + let stmt = conn + .prepare(&format!("SELECT {}::{}", *repr, sql_type)) + .unwrap(); + let result: F = conn.query_one(&stmt, &[]).unwrap().get(0); + assert!(cmp(val, &result)); + + let stmt = conn.prepare(&format!("SELECT $1::{}", sql_type)).unwrap(); + let result: F = conn.query_one(&stmt, &[val]).unwrap().get(0); + assert!(cmp(val, &result)); + } +} + +#[test] +fn compile_fail() { + trybuild::TestCases::new().compile_fail("src/compile-fail/*.rs"); +} diff --git a/postgres-derive-test/src/transparent.rs b/postgres-derive-test/src/transparent.rs new file mode 100644 index 000000000..1614553d2 --- /dev/null +++ b/postgres-derive-test/src/transparent.rs @@ -0,0 +1,18 @@ +use postgres::{Client, NoTls}; +use postgres_types::{FromSql, ToSql}; + +#[test] +fn round_trip() { + #[derive(FromSql, ToSql, Debug, PartialEq)] + #[postgres(transparent)] + struct UserId(i32); + + assert_eq!( + Client::connect("user=postgres host=localhost port=5433", NoTls) + .unwrap() + .query_one("SELECT $1::integer", &[&UserId(123)]) + .unwrap() + .get::<_, UserId>(0), + UserId(123) + ); +} diff --git a/postgres-derive/CHANGELOG.md b/postgres-derive/CHANGELOG.md new file mode 100644 index 000000000..1532b307c --- /dev/null +++ b/postgres-derive/CHANGELOG.md @@ -0,0 +1,46 @@ +# Change Log + +## v0.4.6 - 2024-09-15 + +### Changed + +* Upgraded `heck`. + +## v0.4.5 - 2023-08-19 + +### Added + +* Added a `rename_all` option for enum and struct derives. +* Added an `allow_mismatch` option to disable strict enum variant checks against the Postgres type. + +## v0.4.4 - 2023-03-27 + +### Changed + +* Upgraded `syn`. + +## v0.4.3 - 2022-09-07 + +### Added + +* Added support for parameterized structs. + +## v0.4.2 - 2022-04-30 + +### Added + +* Added support for transparent wrapper types. + +## v0.4.1 - 2021-11-23 + +### Fixed + +* Fixed handling of struct fields using raw identifiers. + +## v0.4.0 - 2019-12-23 + +No changes + +## v0.4.0-alpha.1 - 2019-10-14 + +* Initial release diff --git a/postgres-derive/Cargo.toml b/postgres-derive/Cargo.toml new file mode 100644 index 000000000..96600f124 --- /dev/null +++ b/postgres-derive/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "postgres-derive" +version = "0.4.6" +authors = ["Steven Fackler "] +license = "MIT OR Apache-2.0" +edition = "2018" +description = "An internal crate used by postgres-types" +repository = "https://github.com/sfackler/rust-postgres" + +[lib] +proc-macro = true +test = false + +[dependencies] +syn = "2.0" +proc-macro2 = "1.0" +quote = "1.0" +heck = "0.5" diff --git a/postgres-derive/LICENSE-APACHE b/postgres-derive/LICENSE-APACHE new file mode 120000 index 000000000..965b606f3 --- /dev/null +++ b/postgres-derive/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/postgres-derive/LICENSE-MIT b/postgres-derive/LICENSE-MIT new file mode 120000 index 000000000..76219eb72 --- /dev/null +++ b/postgres-derive/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/postgres-derive/src/accepts.rs b/postgres-derive/src/accepts.rs new file mode 100644 index 000000000..a68538dcc --- /dev/null +++ b/postgres-derive/src/accepts.rs @@ -0,0 +1,101 @@ +use proc_macro2::{Span, TokenStream}; +use quote::quote; +use std::iter; +use syn::Ident; + +use crate::composites::Field; +use crate::enums::Variant; + +pub fn transparent_body(field: &syn::Field) -> TokenStream { + let ty = &field.ty; + + quote! { + <#ty as ::postgres_types::ToSql>::accepts(type_) + } +} + +pub fn domain_body(name: &str, field: &syn::Field) -> TokenStream { + let ty = &field.ty; + + quote! { + if type_.name() != #name { + return false; + } + + match *type_.kind() { + ::postgres_types::Kind::Domain(ref type_) => { + <#ty as ::postgres_types::ToSql>::accepts(type_) + } + _ => false, + } + } +} + +pub fn enum_body(name: &str, variants: &[Variant], allow_mismatch: bool) -> TokenStream { + let num_variants = variants.len(); + let variant_names = variants.iter().map(|v| &v.name); + + if allow_mismatch { + quote! { + type_.name() == #name + } + } else { + quote! { + if type_.name() != #name { + return false; + } + + match *type_.kind() { + ::postgres_types::Kind::Enum(ref variants) => { + if variants.len() != #num_variants { + return false; + } + + variants.iter().all(|v| { + match &**v { + #( + #variant_names => true, + )* + _ => false, + } + }) + } + _ => false, + } + } + } +} + +pub fn composite_body(name: &str, trait_: &str, fields: &[Field]) -> TokenStream { + let num_fields = fields.len(); + let trait_ = Ident::new(trait_, Span::call_site()); + let traits = iter::repeat(&trait_); + let field_names = fields.iter().map(|f| &f.name); + let field_types = fields.iter().map(|f| &f.type_); + + quote! { + if type_.name() != #name { + return false; + } + + match *type_.kind() { + ::postgres_types::Kind::Composite(ref fields) => { + if fields.len() != #num_fields { + return false; + } + + fields.iter().all(|f| { + match f.name() { + #( + #field_names => { + <#field_types as ::postgres_types::#traits>::accepts(f.type_()) + } + )* + _ => false, + } + }) + } + _ => false, + } + } +} diff --git a/postgres-derive/src/case.rs b/postgres-derive/src/case.rs new file mode 100644 index 000000000..20ecc8eed --- /dev/null +++ b/postgres-derive/src/case.rs @@ -0,0 +1,110 @@ +#[allow(deprecated, unused_imports)] +use std::ascii::AsciiExt; + +use heck::{ + ToKebabCase, ToLowerCamelCase, ToShoutyKebabCase, ToShoutySnakeCase, ToSnakeCase, ToTrainCase, + ToUpperCamelCase, +}; + +use self::RenameRule::*; + +/// The different possible ways to change case of fields in a struct, or variants in an enum. +#[allow(clippy::enum_variant_names)] +#[derive(Copy, Clone, PartialEq)] +pub enum RenameRule { + /// Rename direct children to "lowercase" style. + LowerCase, + /// Rename direct children to "UPPERCASE" style. + UpperCase, + /// Rename direct children to "PascalCase" style, as typically used for + /// enum variants. + PascalCase, + /// Rename direct children to "camelCase" style. + CamelCase, + /// Rename direct children to "snake_case" style, as commonly used for + /// fields. + SnakeCase, + /// Rename direct children to "SCREAMING_SNAKE_CASE" style, as commonly + /// used for constants. + ScreamingSnakeCase, + /// Rename direct children to "kebab-case" style. + KebabCase, + /// Rename direct children to "SCREAMING-KEBAB-CASE" style. + ScreamingKebabCase, + + /// Rename direct children to "Train-Case" style. + TrainCase, +} + +pub const RENAME_RULES: &[&str] = &[ + "lowercase", + "UPPERCASE", + "PascalCase", + "camelCase", + "snake_case", + "SCREAMING_SNAKE_CASE", + "kebab-case", + "SCREAMING-KEBAB-CASE", + "Train-Case", +]; + +impl RenameRule { + pub fn from_str(rule: &str) -> Option { + match rule { + "lowercase" => Some(LowerCase), + "UPPERCASE" => Some(UpperCase), + "PascalCase" => Some(PascalCase), + "camelCase" => Some(CamelCase), + "snake_case" => Some(SnakeCase), + "SCREAMING_SNAKE_CASE" => Some(ScreamingSnakeCase), + "kebab-case" => Some(KebabCase), + "SCREAMING-KEBAB-CASE" => Some(ScreamingKebabCase), + "Train-Case" => Some(TrainCase), + _ => None, + } + } + /// Apply a renaming rule to an enum or struct field, returning the version expected in the source. + pub fn apply_to_field(&self, variant: &str) -> String { + match *self { + LowerCase => variant.to_lowercase(), + UpperCase => variant.to_uppercase(), + PascalCase => variant.to_upper_camel_case(), + CamelCase => variant.to_lower_camel_case(), + SnakeCase => variant.to_snake_case(), + ScreamingSnakeCase => variant.to_shouty_snake_case(), + KebabCase => variant.to_kebab_case(), + ScreamingKebabCase => variant.to_shouty_kebab_case(), + TrainCase => variant.to_train_case(), + } + } +} + +#[test] +fn rename_field() { + for &(original, lower, upper, camel, snake, screaming, kebab, screaming_kebab) in &[ + ( + "Outcome", "outcome", "OUTCOME", "outcome", "outcome", "OUTCOME", "outcome", "OUTCOME", + ), + ( + "VeryTasty", + "verytasty", + "VERYTASTY", + "veryTasty", + "very_tasty", + "VERY_TASTY", + "very-tasty", + "VERY-TASTY", + ), + ("A", "a", "A", "a", "a", "A", "a", "A"), + ("Z42", "z42", "Z42", "z42", "z42", "Z42", "z42", "Z42"), + ] { + assert_eq!(LowerCase.apply_to_field(original), lower); + assert_eq!(UpperCase.apply_to_field(original), upper); + assert_eq!(PascalCase.apply_to_field(original), original); + assert_eq!(CamelCase.apply_to_field(original), camel); + assert_eq!(SnakeCase.apply_to_field(original), snake); + assert_eq!(ScreamingSnakeCase.apply_to_field(original), screaming); + assert_eq!(KebabCase.apply_to_field(original), kebab); + assert_eq!(ScreamingKebabCase.apply_to_field(original), screaming_kebab); + } +} diff --git a/postgres-derive/src/composites.rs b/postgres-derive/src/composites.rs new file mode 100644 index 000000000..b6aad8ab3 --- /dev/null +++ b/postgres-derive/src/composites.rs @@ -0,0 +1,60 @@ +use proc_macro2::Span; +use syn::{ + punctuated::Punctuated, Error, GenericParam, Generics, Ident, Path, PathSegment, Type, + TypeParamBound, +}; + +use crate::{case::RenameRule, overrides::Overrides}; + +pub struct Field { + pub name: String, + pub ident: Ident, + pub type_: Type, +} + +impl Field { + pub fn parse(raw: &syn::Field, rename_all: Option) -> Result { + let overrides = Overrides::extract(&raw.attrs, false)?; + let ident = raw.ident.as_ref().unwrap().clone(); + + // field level name override takes precendence over container level rename_all override + let name = match overrides.name { + Some(n) => n, + None => { + let name = ident.to_string(); + let stripped = name.strip_prefix("r#").map(String::from).unwrap_or(name); + + match rename_all { + Some(rule) => rule.apply_to_field(&stripped), + None => stripped, + } + } + }; + + Ok(Field { + name, + ident, + type_: raw.ty.clone(), + }) + } +} + +pub(crate) fn append_generic_bound(mut generics: Generics, bound: &TypeParamBound) -> Generics { + for param in &mut generics.params { + if let GenericParam::Type(param) = param { + param.bounds.push(bound.to_owned()) + } + } + generics +} + +pub(crate) fn new_derive_path(last: PathSegment) -> Path { + let mut path = Path { + leading_colon: None, + segments: Punctuated::new(), + }; + path.segments + .push(Ident::new("postgres_types", Span::call_site()).into()); + path.segments.push(last); + path +} diff --git a/postgres-derive/src/enums.rs b/postgres-derive/src/enums.rs new file mode 100644 index 000000000..9a6dfa926 --- /dev/null +++ b/postgres-derive/src/enums.rs @@ -0,0 +1,33 @@ +use syn::{Error, Fields, Ident}; + +use crate::{case::RenameRule, overrides::Overrides}; + +pub struct Variant { + pub ident: Ident, + pub name: String, +} + +impl Variant { + pub fn parse(raw: &syn::Variant, rename_all: Option) -> Result { + match raw.fields { + Fields::Unit => {} + _ => { + return Err(Error::new_spanned( + raw, + "non-C-like enums are not supported", + )) + } + } + let overrides = Overrides::extract(&raw.attrs, false)?; + + // variant level name override takes precendence over container level rename_all override + let name = overrides.name.unwrap_or_else(|| match rename_all { + Some(rule) => rule.apply_to_field(&raw.ident.to_string()), + None => raw.ident.to_string(), + }); + Ok(Variant { + ident: raw.ident.clone(), + name, + }) + } +} diff --git a/postgres-derive/src/fromsql.rs b/postgres-derive/src/fromsql.rs new file mode 100644 index 000000000..d3ac47f4f --- /dev/null +++ b/postgres-derive/src/fromsql.rs @@ -0,0 +1,266 @@ +use proc_macro2::{Span, TokenStream}; +use quote::{format_ident, quote}; +use std::iter; +use syn::{ + punctuated::Punctuated, token, AngleBracketedGenericArguments, Data, DataStruct, DeriveInput, + Error, Fields, GenericArgument, GenericParam, Generics, Ident, Lifetime, PathArguments, + PathSegment, +}; +use syn::{LifetimeParam, TraitBound, TraitBoundModifier, TypeParamBound}; + +use crate::accepts; +use crate::composites::Field; +use crate::composites::{append_generic_bound, new_derive_path}; +use crate::enums::Variant; +use crate::overrides::Overrides; + +pub fn expand_derive_fromsql(input: DeriveInput) -> Result { + let overrides = Overrides::extract(&input.attrs, true)?; + + if (overrides.name.is_some() || overrides.rename_all.is_some()) && overrides.transparent { + return Err(Error::new_spanned( + &input, + "#[postgres(transparent)] is not allowed with #[postgres(name = \"...\")] or #[postgres(rename_all = \"...\")]", + )); + } + + let name = overrides + .name + .clone() + .unwrap_or_else(|| input.ident.to_string()); + + let (accepts_body, to_sql_body) = if overrides.transparent { + match input.data { + Data::Struct(DataStruct { + fields: Fields::Unnamed(ref fields), + .. + }) if fields.unnamed.len() == 1 => { + let field = fields.unnamed.first().unwrap(); + ( + accepts::transparent_body(field), + transparent_body(&input.ident, field), + ) + } + _ => { + return Err(Error::new_spanned( + input, + "#[postgres(transparent)] may only be applied to single field tuple structs", + )) + } + } + } else if overrides.allow_mismatch { + match input.data { + Data::Enum(ref data) => { + let variants = data + .variants + .iter() + .map(|variant| Variant::parse(variant, overrides.rename_all)) + .collect::, _>>()?; + ( + accepts::enum_body(&name, &variants, overrides.allow_mismatch), + enum_body(&input.ident, &variants), + ) + } + _ => { + return Err(Error::new_spanned( + input, + "#[postgres(allow_mismatch)] may only be applied to enums", + )); + } + } + } else { + match input.data { + Data::Enum(ref data) => { + let variants = data + .variants + .iter() + .map(|variant| Variant::parse(variant, overrides.rename_all)) + .collect::, _>>()?; + ( + accepts::enum_body(&name, &variants, overrides.allow_mismatch), + enum_body(&input.ident, &variants), + ) + } + Data::Struct(DataStruct { + fields: Fields::Unnamed(ref fields), + .. + }) if fields.unnamed.len() == 1 => { + let field = fields.unnamed.first().unwrap(); + ( + domain_accepts_body(&name, field), + domain_body(&input.ident, field), + ) + } + Data::Struct(DataStruct { + fields: Fields::Named(ref fields), + .. + }) => { + let fields = fields + .named + .iter() + .map(|field| Field::parse(field, overrides.rename_all)) + .collect::, _>>()?; + ( + accepts::composite_body(&name, "FromSql", &fields), + composite_body(&input.ident, &fields), + ) + } + _ => { + return Err(Error::new_spanned( + input, + "#[derive(FromSql)] may only be applied to structs, single field tuple structs, and enums", + )) + } + } + }; + + let ident = &input.ident; + let (generics, lifetime) = build_generics(&input.generics); + let (impl_generics, _, _) = generics.split_for_impl(); + let (_, ty_generics, where_clause) = input.generics.split_for_impl(); + let out = quote! { + impl #impl_generics postgres_types::FromSql<#lifetime> for #ident #ty_generics #where_clause { + fn from_sql(_type: &postgres_types::Type, buf: &#lifetime [u8]) + -> std::result::Result<#ident #ty_generics, + std::boxed::Box> { + #to_sql_body + } + + fn accepts(type_: &postgres_types::Type) -> bool { + #accepts_body + } + } + }; + + Ok(out) +} + +fn transparent_body(ident: &Ident, field: &syn::Field) -> TokenStream { + let ty = &field.ty; + quote! { + <#ty as postgres_types::FromSql>::from_sql(_type, buf).map(#ident) + } +} + +fn enum_body(ident: &Ident, variants: &[Variant]) -> TokenStream { + let variant_names = variants.iter().map(|v| &v.name); + let idents = iter::repeat(ident); + let variant_idents = variants.iter().map(|v| &v.ident); + + quote! { + match std::str::from_utf8(buf)? { + #( + #variant_names => std::result::Result::Ok(#idents::#variant_idents), + )* + s => { + std::result::Result::Err( + std::convert::Into::into(format!("invalid variant `{}`", s))) + } + } + } +} + +// Domains are sometimes but not always just represented by the bare type (!?) +fn domain_accepts_body(name: &str, field: &syn::Field) -> TokenStream { + let ty = &field.ty; + let normal_body = accepts::domain_body(name, field); + + quote! { + if <#ty as postgres_types::FromSql>::accepts(type_) { + return true; + } + + #normal_body + } +} + +fn domain_body(ident: &Ident, field: &syn::Field) -> TokenStream { + let ty = &field.ty; + quote! { + <#ty as postgres_types::FromSql>::from_sql(_type, buf).map(#ident) + } +} + +fn composite_body(ident: &Ident, fields: &[Field]) -> TokenStream { + let temp_vars = &fields + .iter() + .map(|f| format_ident!("__{}", f.ident)) + .collect::>(); + let field_names = &fields.iter().map(|f| &f.name).collect::>(); + let field_idents = &fields.iter().map(|f| &f.ident).collect::>(); + + quote! { + let fields = match *_type.kind() { + postgres_types::Kind::Composite(ref fields) => fields, + _ => unreachable!(), + }; + + let mut buf = buf; + let num_fields = postgres_types::private::read_be_i32(&mut buf)?; + if num_fields as usize != fields.len() { + return std::result::Result::Err( + std::convert::Into::into(format!("invalid field count: {} vs {}", num_fields, fields.len()))); + } + + #( + let mut #temp_vars = std::option::Option::None; + )* + + for field in fields { + let oid = postgres_types::private::read_be_i32(&mut buf)? as u32; + if oid != field.type_().oid() { + return std::result::Result::Err(std::convert::Into::into("unexpected OID")); + } + + match field.name() { + #( + #field_names => { + #temp_vars = std::option::Option::Some( + postgres_types::private::read_value(field.type_(), &mut buf)?); + } + )* + _ => unreachable!(), + } + } + + std::result::Result::Ok(#ident { + #( + #field_idents: #temp_vars.unwrap(), + )* + }) + } +} + +fn build_generics(source: &Generics) -> (Generics, Lifetime) { + // don't worry about lifetime name collisions, it doesn't make sense to derive FromSql on a struct with a lifetime + let lifetime = Lifetime::new("'a", Span::call_site()); + + let mut out = append_generic_bound(source.to_owned(), &new_fromsql_bound(&lifetime)); + out.params.insert( + 0, + GenericParam::Lifetime(LifetimeParam::new(lifetime.to_owned())), + ); + + (out, lifetime) +} + +fn new_fromsql_bound(lifetime: &Lifetime) -> TypeParamBound { + let mut path_segment: PathSegment = Ident::new("FromSql", Span::call_site()).into(); + let mut seg_args = Punctuated::new(); + seg_args.push(GenericArgument::Lifetime(lifetime.to_owned())); + path_segment.arguments = PathArguments::AngleBracketed(AngleBracketedGenericArguments { + colon2_token: None, + lt_token: token::Lt::default(), + args: seg_args, + gt_token: token::Gt::default(), + }); + + TypeParamBound::Trait(TraitBound { + lifetimes: None, + modifier: TraitBoundModifier::None, + paren_token: None, + path: new_derive_path(path_segment), + }) +} diff --git a/postgres-derive/src/lib.rs b/postgres-derive/src/lib.rs new file mode 100644 index 000000000..b849096c9 --- /dev/null +++ b/postgres-derive/src/lib.rs @@ -0,0 +1,33 @@ +//! An internal crate for `postgres-types`. + +#![recursion_limit = "256"] +extern crate proc_macro; + +use proc_macro::TokenStream; +use syn::parse_macro_input; + +mod accepts; +mod case; +mod composites; +mod enums; +mod fromsql; +mod overrides; +mod tosql; + +#[proc_macro_derive(ToSql, attributes(postgres))] +pub fn derive_tosql(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input); + + tosql::expand_derive_tosql(input) + .unwrap_or_else(|e| e.to_compile_error()) + .into() +} + +#[proc_macro_derive(FromSql, attributes(postgres))] +pub fn derive_fromsql(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input); + + fromsql::expand_derive_fromsql(input) + .unwrap_or_else(|e| e.to_compile_error()) + .into() +} diff --git a/postgres-derive/src/overrides.rs b/postgres-derive/src/overrides.rs new file mode 100644 index 000000000..d50550bee --- /dev/null +++ b/postgres-derive/src/overrides.rs @@ -0,0 +1,106 @@ +use syn::punctuated::Punctuated; +use syn::{Attribute, Error, Expr, ExprLit, Lit, Meta, Token}; + +use crate::case::{RenameRule, RENAME_RULES}; + +pub struct Overrides { + pub name: Option, + pub rename_all: Option, + pub transparent: bool, + pub allow_mismatch: bool, +} + +impl Overrides { + pub fn extract(attrs: &[Attribute], container_attr: bool) -> Result { + let mut overrides = Overrides { + name: None, + rename_all: None, + transparent: false, + allow_mismatch: false, + }; + + for attr in attrs { + if !attr.path().is_ident("postgres") { + continue; + } + + let list = match &attr.meta { + Meta::List(ref list) => list, + bad => return Err(Error::new_spanned(bad, "expected a #[postgres(...)]")), + }; + + let nested = list.parse_args_with(Punctuated::::parse_terminated)?; + + for item in nested { + match item { + Meta::NameValue(meta) => { + let name_override = meta.path.is_ident("name"); + let rename_all_override = meta.path.is_ident("rename_all"); + if !container_attr && rename_all_override { + return Err(Error::new_spanned( + &meta.path, + "rename_all is a container attribute", + )); + } + if !name_override && !rename_all_override { + return Err(Error::new_spanned(&meta.path, "unknown override")); + } + + let value = match &meta.value { + Expr::Lit(ExprLit { + lit: Lit::Str(lit), .. + }) => lit.value(), + bad => { + return Err(Error::new_spanned(bad, "expected a string literal")) + } + }; + + if name_override { + overrides.name = Some(value); + } else if rename_all_override { + let rename_rule = RenameRule::from_str(&value).ok_or_else(|| { + Error::new_spanned( + &meta.value, + format!( + "invalid rename_all rule, expected one of: {}", + RENAME_RULES + .iter() + .map(|rule| format!("\"{}\"", rule)) + .collect::>() + .join(", ") + ), + ) + })?; + + overrides.rename_all = Some(rename_rule); + } + } + Meta::Path(path) => { + if path.is_ident("transparent") { + if overrides.allow_mismatch { + return Err(Error::new_spanned( + path, + "#[postgres(allow_mismatch)] is not allowed with #[postgres(transparent)]", + )); + } + overrides.transparent = true; + } else if path.is_ident("allow_mismatch") { + if overrides.transparent { + return Err(Error::new_spanned( + path, + "#[postgres(transparent)] is not allowed with #[postgres(allow_mismatch)]", + )); + } + overrides.allow_mismatch = true; + } else { + return Err(Error::new_spanned(path, "unknown override")); + } + } + bad => return Err(Error::new_spanned(bad, "unknown attribute")), + } + } + } + + Ok(overrides) + } +} diff --git a/postgres-derive/src/tosql.rs b/postgres-derive/src/tosql.rs new file mode 100644 index 000000000..81d4834bf --- /dev/null +++ b/postgres-derive/src/tosql.rs @@ -0,0 +1,221 @@ +use proc_macro2::{Span, TokenStream}; +use quote::quote; +use std::iter; +use syn::{ + Data, DataStruct, DeriveInput, Error, Fields, Ident, TraitBound, TraitBoundModifier, + TypeParamBound, +}; + +use crate::accepts; +use crate::composites::Field; +use crate::composites::{append_generic_bound, new_derive_path}; +use crate::enums::Variant; +use crate::overrides::Overrides; + +pub fn expand_derive_tosql(input: DeriveInput) -> Result { + let overrides = Overrides::extract(&input.attrs, true)?; + + if (overrides.name.is_some() || overrides.rename_all.is_some()) && overrides.transparent { + return Err(Error::new_spanned( + &input, + "#[postgres(transparent)] is not allowed with #[postgres(name = \"...\")] or #[postgres(rename_all = \"...\")]", + )); + } + + let name = overrides + .name + .clone() + .unwrap_or_else(|| input.ident.to_string()); + + let (accepts_body, to_sql_body) = if overrides.transparent { + match input.data { + Data::Struct(DataStruct { + fields: Fields::Unnamed(ref fields), + .. + }) if fields.unnamed.len() == 1 => { + let field = fields.unnamed.first().unwrap(); + + (accepts::transparent_body(field), transparent_body()) + } + _ => { + return Err(Error::new_spanned( + input, + "#[postgres(transparent)] may only be applied to single field tuple structs", + )); + } + } + } else if overrides.allow_mismatch { + match input.data { + Data::Enum(ref data) => { + let variants = data + .variants + .iter() + .map(|variant| Variant::parse(variant, overrides.rename_all)) + .collect::, _>>()?; + ( + accepts::enum_body(&name, &variants, overrides.allow_mismatch), + enum_body(&input.ident, &variants), + ) + } + _ => { + return Err(Error::new_spanned( + input, + "#[postgres(allow_mismatch)] may only be applied to enums", + )); + } + } + } else { + match input.data { + Data::Enum(ref data) => { + let variants = data + .variants + .iter() + .map(|variant| Variant::parse(variant, overrides.rename_all)) + .collect::, _>>()?; + ( + accepts::enum_body(&name, &variants, overrides.allow_mismatch), + enum_body(&input.ident, &variants), + ) + } + Data::Struct(DataStruct { + fields: Fields::Unnamed(ref fields), + .. + }) if fields.unnamed.len() == 1 => { + let field = fields.unnamed.first().unwrap(); + + (accepts::domain_body(&name, field), domain_body()) + } + Data::Struct(DataStruct { + fields: Fields::Named(ref fields), + .. + }) => { + let fields = fields + .named + .iter() + .map(|field| Field::parse(field, overrides.rename_all)) + .collect::, _>>()?; + ( + accepts::composite_body(&name, "ToSql", &fields), + composite_body(&fields), + ) + } + _ => { + return Err(Error::new_spanned( + input, + "#[derive(ToSql)] may only be applied to structs, single field tuple structs, and enums", + )); + } + } + }; + + let ident = &input.ident; + let generics = append_generic_bound(input.generics.to_owned(), &new_tosql_bound()); + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let out = quote! { + impl#impl_generics postgres_types::ToSql for #ident#ty_generics #where_clause { + fn to_sql(&self, + _type: &postgres_types::Type, + buf: &mut postgres_types::private::BytesMut) + -> std::result::Result> { + #to_sql_body + } + + fn accepts(type_: &postgres_types::Type) -> bool { + #accepts_body + } + + postgres_types::to_sql_checked!(); + } + }; + + Ok(out) +} + +fn transparent_body() -> TokenStream { + quote! { + postgres_types::ToSql::to_sql(&self.0, _type, buf) + } +} + +fn enum_body(ident: &Ident, variants: &[Variant]) -> TokenStream { + let idents = iter::repeat(ident); + let variant_idents = variants.iter().map(|v| &v.ident); + let variant_names = variants.iter().map(|v| &v.name); + + quote! { + let s = match *self { + #( + #idents::#variant_idents => #variant_names, + )* + }; + + buf.extend_from_slice(s.as_bytes()); + std::result::Result::Ok(postgres_types::IsNull::No) + } +} + +fn domain_body() -> TokenStream { + quote! { + let type_ = match *_type.kind() { + postgres_types::Kind::Domain(ref type_) => type_, + _ => unreachable!(), + }; + + postgres_types::ToSql::to_sql(&self.0, type_, buf) + } +} + +fn composite_body(fields: &[Field]) -> TokenStream { + let field_names = fields.iter().map(|f| &f.name); + let field_idents = fields.iter().map(|f| &f.ident); + + quote! { + let fields = match *_type.kind() { + postgres_types::Kind::Composite(ref fields) => fields, + _ => unreachable!(), + }; + + buf.extend_from_slice(&(fields.len() as i32).to_be_bytes()); + + for field in fields { + buf.extend_from_slice(&field.type_().oid().to_be_bytes()); + + let base = buf.len(); + buf.extend_from_slice(&[0; 4]); + let r = match field.name() { + #( + #field_names => postgres_types::ToSql::to_sql(&self.#field_idents, field.type_(), buf), + )* + _ => unreachable!(), + }; + + let count = match r? { + postgres_types::IsNull::Yes => -1, + postgres_types::IsNull::No => { + let len = buf.len() - base - 4; + if len > i32::max_value() as usize { + return std::result::Result::Err( + std::convert::Into::into("value too large to transmit")); + } + len as i32 + } + }; + + buf[base..base + 4].copy_from_slice(&count.to_be_bytes()); + } + + std::result::Result::Ok(postgres_types::IsNull::No) + } +} + +fn new_tosql_bound() -> TypeParamBound { + TypeParamBound::Trait(TraitBound { + lifetimes: None, + modifier: TraitBoundModifier::None, + paren_token: None, + path: new_derive_path(Ident::new("ToSql", Span::call_site()).into()), + }) +} diff --git a/postgres-native-tls/CHANGELOG.md b/postgres-native-tls/CHANGELOG.md new file mode 100644 index 000000000..5fe0a9c7a --- /dev/null +++ b/postgres-native-tls/CHANGELOG.md @@ -0,0 +1,43 @@ +# Change Log + +## v0.5.1 - 2025-02-02 + +### Added + +* Added `set_postgresql_alpn`. + +## v0.5.0 - 2020-12-25 + +### Changed + +* Upgraded to `tokio-postgres` 0.7. + +## v0.4.0 - 2020-10-17 + +### Changed + +* Upgraded to `tokio-postgres` 0.6. + +## v0.3.0 - 2019-12-23 + +### Changed + +* Upgraded to `tokio-postgres` 0.5. + +## v0.3.0-alpha.2 - 2019-11-27 + +### Changed + +* Upgraded to `tokio-postgres` v0.5.0-alpha.2. + +## v0.3.0-alpha.1 - 2019-10-14 + +### Changed + +* Updated to `tokio-postgres` v0.5.0-alpha.1. + +## v0.2.0-rc.1 - 2019-06-29 + +### Changed + +* Updated to `tokio-postgres` v0.4.0-rc. diff --git a/postgres-native-tls/Cargo.toml b/postgres-native-tls/Cargo.toml index ada50ccff..f79ae5491 100644 --- a/postgres-native-tls/Cargo.toml +++ b/postgres-native-tls/Cargo.toml @@ -1,9 +1,27 @@ [package] name = "postgres-native-tls" -version = "0.1.0" +version = "0.5.1" authors = ["Steven Fackler "] +edition = "2018" +license = "MIT OR Apache-2.0" +description = "TLS support for tokio-postgres via native-tls" +repository = "https://github.com/sfackler/rust-postgres" +readme = "../README.md" + +[badges] +circle-ci = { repository = "sfackler/rust-postgres" } + +[features] +default = ["runtime"] +runtime = ["tokio-postgres/runtime"] [dependencies] -native-tls = "0.2" +native-tls = { version = "0.2", features = ["alpn"] } +tokio = "1.0" +tokio-native-tls = "0.3" +tokio-postgres = { version = "0.7.11", path = "../tokio-postgres", default-features = false } -postgres = { version = "0.15", path = "../postgres" } +[dev-dependencies] +futures-util = "0.3" +tokio = { version = "1.0", features = ["macros", "net", "rt"] } +postgres = { version = "0.19.8", path = "../postgres" } diff --git a/postgres-native-tls/LICENSE-APACHE b/postgres-native-tls/LICENSE-APACHE new file mode 120000 index 000000000..965b606f3 --- /dev/null +++ b/postgres-native-tls/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/postgres-native-tls/LICENSE-MIT b/postgres-native-tls/LICENSE-MIT new file mode 120000 index 000000000..76219eb72 --- /dev/null +++ b/postgres-native-tls/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/postgres-native-tls/src/lib.rs b/postgres-native-tls/src/lib.rs index ea887ed70..9ee7da653 100644 --- a/postgres-native-tls/src/lib.rs +++ b/postgres-native-tls/src/lib.rs @@ -1,76 +1,190 @@ -pub extern crate native_tls; -extern crate postgres; +//! TLS support for `tokio-postgres` and `postgres` via `native-tls`. +//! +//! # Examples +//! +//! ```no_run +//! use native_tls::{Certificate, TlsConnector}; +//! # #[cfg(feature = "runtime")] +//! use postgres_native_tls::MakeTlsConnector; +//! use std::fs; +//! +//! # fn main() -> Result<(), Box> { +//! # #[cfg(feature = "runtime")] { +//! let cert = fs::read("database_cert.pem")?; +//! let cert = Certificate::from_pem(&cert)?; +//! let connector = TlsConnector::builder() +//! .add_root_certificate(cert) +//! .build()?; +//! let connector = MakeTlsConnector::new(connector); +//! +//! let connect_future = tokio_postgres::connect( +//! "host=localhost user=postgres sslmode=require", +//! connector, +//! ); +//! # } +//! +//! // ... +//! # Ok(()) +//! # } +//! ``` +//! +//! ```no_run +//! use native_tls::{Certificate, TlsConnector}; +//! # #[cfg(feature = "runtime")] +//! use postgres_native_tls::MakeTlsConnector; +//! use std::fs; +//! +//! # fn main() -> Result<(), Box> { +//! # #[cfg(feature = "runtime")] { +//! let cert = fs::read("database_cert.pem")?; +//! let cert = Certificate::from_pem(&cert)?; +//! let connector = TlsConnector::builder() +//! .add_root_certificate(cert) +//! .build()?; +//! let connector = MakeTlsConnector::new(connector); +//! +//! let client = postgres::Client::connect( +//! "host=localhost user=postgres sslmode=require", +//! connector, +//! )?; +//! # } +//! # Ok(()) +//! # } +//! ``` +#![warn(rust_2018_idioms, clippy::all, missing_docs)] -use native_tls::TlsConnector; -use postgres::tls::{Stream, TlsHandshake, TlsStream}; -use std::error::Error; -use std::fmt::{self, Debug}; -use std::io::{self, Read, Write}; +use native_tls::TlsConnectorBuilder; +use std::future::Future; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, BufReader, ReadBuf}; +use tokio_postgres::tls; +#[cfg(feature = "runtime")] +use tokio_postgres::tls::MakeTlsConnect; +use tokio_postgres::tls::{ChannelBinding, TlsConnect}; #[cfg(test)] mod test; -pub struct NativeTls { - connector: TlsConnector, -} +/// A `MakeTlsConnect` implementation using the `native-tls` crate. +/// +/// Requires the `runtime` Cargo feature (enabled by default). +#[cfg(feature = "runtime")] +#[derive(Clone)] +pub struct MakeTlsConnector(native_tls::TlsConnector); -impl Debug for NativeTls { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - fmt.debug_struct("NativeTls").finish() +#[cfg(feature = "runtime")] +impl MakeTlsConnector { + /// Creates a new connector. + pub fn new(connector: native_tls::TlsConnector) -> MakeTlsConnector { + MakeTlsConnector(connector) } } -impl NativeTls { - pub fn new() -> Result { - let connector = TlsConnector::builder().build()?; - Ok(NativeTls::with_connector(connector)) - } +#[cfg(feature = "runtime")] +impl MakeTlsConnect for MakeTlsConnector +where + S: AsyncRead + AsyncWrite + Unpin + 'static + Send, +{ + type Stream = TlsStream; + type TlsConnect = TlsConnector; + type Error = native_tls::Error; - pub fn with_connector(connector: TlsConnector) -> NativeTls { - NativeTls { connector } + fn make_tls_connect(&mut self, domain: &str) -> Result { + Ok(TlsConnector::new(self.0.clone(), domain)) } } -impl TlsHandshake for NativeTls { - fn tls_handshake( - &self, - domain: &str, - stream: Stream, - ) -> Result, Box> { - let stream = self.connector.connect(domain, stream)?; - Ok(Box::new(NativeTlsStream(stream))) +/// A `TlsConnect` implementation using the `native-tls` crate. +pub struct TlsConnector { + connector: tokio_native_tls::TlsConnector, + domain: String, +} + +impl TlsConnector { + /// Creates a new connector configured to connect to the specified domain. + pub fn new(connector: native_tls::TlsConnector, domain: &str) -> TlsConnector { + TlsConnector { + connector: tokio_native_tls::TlsConnector::from(connector), + domain: domain.to_string(), + } } } -#[derive(Debug)] -struct NativeTlsStream(native_tls::TlsStream); +impl TlsConnect for TlsConnector +where + S: AsyncRead + AsyncWrite + Unpin + 'static + Send, +{ + type Stream = TlsStream; + type Error = native_tls::Error; + #[allow(clippy::type_complexity)] + type Future = Pin, native_tls::Error>> + Send>>; + + fn connect(self, stream: S) -> Self::Future { + let stream = BufReader::with_capacity(8192, stream); + let future = async move { + let stream = self.connector.connect(&self.domain, stream).await?; + + Ok(TlsStream(stream)) + }; -impl Read for NativeTlsStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.0.read(buf) + Box::pin(future) } } -impl Write for NativeTlsStream { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.0.write(buf) - } +/// The stream returned by `TlsConnector`. +pub struct TlsStream(tokio_native_tls::TlsStream>); - fn flush(&mut self) -> io::Result<()> { - self.0.flush() +impl AsyncRead for TlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_read(cx, buf) } } -impl TlsStream for NativeTlsStream { - fn get_ref(&self) -> &Stream { - self.0.get_ref() +impl AsyncWrite for TlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) } - fn get_mut(&mut self) -> &mut Stream { - self.0.get_mut() + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_shutdown(cx) } +} - fn tls_server_end_point(&self) -> Option> { - self.0.tls_server_end_point().ok().and_then(|o| o) +impl tls::TlsStream for TlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn channel_binding(&self) -> ChannelBinding { + match self.0.get_ref().tls_server_end_point().ok().flatten() { + Some(buf) => ChannelBinding::tls_server_end_point(buf), + None => ChannelBinding::none(), + } } } + +/// Set ALPN for `TlsConnectorBuilder` +/// +/// This is required when using `sslnegotiation=direct` +pub fn set_postgresql_alpn(builder: &mut TlsConnectorBuilder) { + builder.request_alpns(&["postgresql"]); +} diff --git a/postgres-native-tls/src/test.rs b/postgres-native-tls/src/test.rs index a84798d46..738c04bd7 100644 --- a/postgres-native-tls/src/test.rs +++ b/postgres-native-tls/src/test.rs @@ -1,38 +1,115 @@ -use native_tls::{Certificate, TlsConnector}; -use postgres::{Connection, TlsMode}; +use futures_util::FutureExt; +use native_tls::{self, Certificate}; +use tokio::net::TcpStream; +use tokio_postgres::tls::TlsConnect; -use NativeTls; +#[cfg(feature = "runtime")] +use crate::MakeTlsConnector; +use crate::{set_postgresql_alpn, TlsConnector}; -#[test] -fn connect() { - let cert = include_bytes!("../../test/server.crt"); - let cert = Certificate::from_pem(cert).unwrap(); +async fn smoke_test(s: &str, tls: T) +where + T: TlsConnect, + T::Stream: 'static + Send, +{ + let stream = TcpStream::connect("127.0.0.1:5433").await.unwrap(); - let mut builder = TlsConnector::builder(); - builder.add_root_certificate(cert); - let connector = builder.build().unwrap(); + let builder = s.parse::().unwrap(); + let (client, connection) = builder.connect_raw(stream, tls).await.unwrap(); + + let connection = connection.map(|r| r.unwrap()); + tokio::spawn(connection); + + let stmt = client.prepare("SELECT $1::INT4").await.unwrap(); + let rows = client.query(&stmt, &[&1i32]).await.unwrap(); - let handshake = NativeTls::with_connector(connector); - let conn = Connection::connect( - "postgres://ssl_user@localhost:5433/postgres", - TlsMode::Require(&handshake), - ).unwrap(); - conn.execute("SELECT 1::VARCHAR", &[]).unwrap(); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].get::<_, i32>(0), 1); } -#[test] -fn scram_user() { - let cert = include_bytes!("../../test/server.crt"); - let cert = Certificate::from_pem(cert).unwrap(); +#[tokio::test] +async fn require() { + let connector = native_tls::TlsConnector::builder() + .add_root_certificate( + Certificate::from_pem(include_bytes!("../../test/server.crt")).unwrap(), + ) + .build() + .unwrap(); + smoke_test( + "user=ssl_user dbname=postgres sslmode=require", + TlsConnector::new(connector, "localhost"), + ) + .await; +} - let mut builder = TlsConnector::builder(); - builder.add_root_certificate(cert); +#[tokio::test] +async fn direct() { + let mut builder = native_tls::TlsConnector::builder(); + builder.add_root_certificate( + Certificate::from_pem(include_bytes!("../../test/server.crt")).unwrap(), + ); + set_postgresql_alpn(&mut builder); let connector = builder.build().unwrap(); + smoke_test( + "user=ssl_user dbname=postgres sslmode=require sslnegotiation=direct", + TlsConnector::new(connector, "localhost"), + ) + .await; +} + +#[tokio::test] +async fn prefer() { + let connector = native_tls::TlsConnector::builder() + .add_root_certificate( + Certificate::from_pem(include_bytes!("../../test/server.crt")).unwrap(), + ) + .build() + .unwrap(); + smoke_test( + "user=ssl_user dbname=postgres", + TlsConnector::new(connector, "localhost"), + ) + .await; +} + +#[tokio::test] +async fn scram_user() { + let connector = native_tls::TlsConnector::builder() + .add_root_certificate( + Certificate::from_pem(include_bytes!("../../test/server.crt")).unwrap(), + ) + .build() + .unwrap(); + smoke_test( + "user=scram_user password=password dbname=postgres sslmode=require", + TlsConnector::new(connector, "localhost"), + ) + .await; +} + +#[tokio::test] +#[cfg(feature = "runtime")] +async fn runtime() { + let connector = native_tls::TlsConnector::builder() + .add_root_certificate( + Certificate::from_pem(include_bytes!("../../test/server.crt")).unwrap(), + ) + .build() + .unwrap(); + let connector = MakeTlsConnector::new(connector); + + let (client, connection) = tokio_postgres::connect( + "host=localhost port=5433 user=postgres sslmode=require", + connector, + ) + .await + .unwrap(); + let connection = connection.map(|r| r.unwrap()); + tokio::spawn(connection); + + let stmt = client.prepare("SELECT $1::INT4").await.unwrap(); + let rows = client.query(&stmt, &[&1i32]).await.unwrap(); - let handshake = NativeTls::with_connector(connector); - let conn = Connection::connect( - "postgres://scram_user:password@localhost:5433/postgres", - TlsMode::Require(&handshake), - ).unwrap(); - conn.execute("SELECT 1::VARCHAR", &[]).unwrap(); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].get::<_, i32>(0), 1); } diff --git a/postgres-openssl/CHANGELOG.md b/postgres-openssl/CHANGELOG.md new file mode 100644 index 000000000..33f5a127a --- /dev/null +++ b/postgres-openssl/CHANGELOG.md @@ -0,0 +1,43 @@ +# Change Log + +## v0.5.1 - 2025-02-02 + +### Added + +* Added `set_postgresql_alpn`. + +## v0.5.0 - 2020-12-25 + +### Changed + +* Upgraded to `tokio-postgres` 0.7. + +## v0.4.0 - 2020-10-17 + +### Changed + +* Upgraded to `tokio-postgres` 0.6. + +## v0.3.0 - 2019-12-23 + +### Changed + +* Upgraded to `tokio-postgres` 0.5. + +## v0.3.0-alpha.2 - 2019-11-27 + +### Changed + +* Upgraded `tokio-postgres` v0.5.0-alpha.2. + +## v0.3.0-alpha.1 - 2019-10-14 + +### Changed + +* Updated to `tokio-postgres` v0.5.0-alpha.1. + +## v0.2.0-rc.1 - 2019-03-06 + +### Changed + +* Updated to `tokio-postgres` v0.4.0-rc. diff --git a/postgres-openssl/Cargo.toml b/postgres-openssl/Cargo.toml index 53321fd56..6ebb86bef 100644 --- a/postgres-openssl/Cargo.toml +++ b/postgres-openssl/Cargo.toml @@ -1,9 +1,27 @@ [package] name = "postgres-openssl" -version = "0.1.0" +version = "0.5.1" authors = ["Steven Fackler "] +edition = "2018" +license = "MIT OR Apache-2.0" +description = "TLS support for tokio-postgres via openssl" +repository = "https://github.com/sfackler/rust-postgres" +readme = "../README.md" + +[badges] +circle-ci = { repository = "sfackler/rust-postgres" } + +[features] +default = ["runtime"] +runtime = ["tokio-postgres/runtime"] [dependencies] -openssl = "0.10.9" +openssl = "0.10" +tokio = "1.0" +tokio-openssl = "0.6" +tokio-postgres = { version = "0.7.11", path = "../tokio-postgres", default-features = false } -postgres = { version = "0.15", path = "../postgres" } +[dev-dependencies] +futures-util = "0.3" +tokio = { version = "1.0", features = ["macros", "net", "rt"] } +postgres = { version = "0.19.8", path = "../postgres" } diff --git a/postgres-openssl/LICENSE-APACHE b/postgres-openssl/LICENSE-APACHE new file mode 120000 index 000000000..965b606f3 --- /dev/null +++ b/postgres-openssl/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/postgres-openssl/LICENSE-MIT b/postgres-openssl/LICENSE-MIT new file mode 120000 index 000000000..76219eb72 --- /dev/null +++ b/postgres-openssl/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/postgres-openssl/src/lib.rs b/postgres-openssl/src/lib.rs index 7f09b4708..232cccd05 100644 --- a/postgres-openssl/src/lib.rs +++ b/postgres-openssl/src/lib.rs @@ -1,100 +1,259 @@ -pub extern crate openssl; -extern crate postgres; +//! TLS support for `tokio-postgres` and `postgres` via `openssl`. +//! +//! # Examples +//! +//! ```no_run +//! use openssl::ssl::{SslConnector, SslMethod}; +//! # #[cfg(feature = "runtime")] +//! use postgres_openssl::MakeTlsConnector; +//! +//! # fn main() -> Result<(), Box> { +//! # #[cfg(feature = "runtime")] { +//! let mut builder = SslConnector::builder(SslMethod::tls())?; +//! builder.set_ca_file("database_cert.pem")?; +//! let connector = MakeTlsConnector::new(builder.build()); +//! +//! let connect_future = tokio_postgres::connect( +//! "host=localhost user=postgres sslmode=require", +//! connector, +//! ); +//! # } +//! +//! // ... +//! # Ok(()) +//! # } +//! ``` +//! +//! ```no_run +//! use openssl::ssl::{SslConnector, SslMethod}; +//! # #[cfg(feature = "runtime")] +//! use postgres_openssl::MakeTlsConnector; +//! +//! # fn main() -> Result<(), Box> { +//! # #[cfg(feature = "runtime")] { +//! let mut builder = SslConnector::builder(SslMethod::tls())?; +//! builder.set_ca_file("database_cert.pem")?; +//! let connector = MakeTlsConnector::new(builder.build()); +//! +//! let client = postgres::Client::connect( +//! "host=localhost user=postgres sslmode=require", +//! connector, +//! )?; +//! # } +//! +//! // ... +//! # Ok(()) +//! # } +//! ``` +#![warn(rust_2018_idioms, clippy::all, missing_docs)] +#[cfg(feature = "runtime")] use openssl::error::ErrorStack; -use openssl::ssl::{ConnectConfiguration, SslConnector, SslMethod, SslRef, SslStream}; -use postgres::tls::{Stream, TlsHandshake, TlsStream}; +use openssl::hash::MessageDigest; +use openssl::nid::Nid; +#[cfg(feature = "runtime")] +use openssl::ssl::SslConnector; +use openssl::ssl::{self, ConnectConfiguration, SslConnectorBuilder, SslRef}; +use openssl::x509::X509VerifyResult; use std::error::Error; -use std::fmt; -use std::io::{self, Read, Write}; +use std::fmt::{self, Debug}; +use std::future::Future; +use std::io; +use std::pin::Pin; +#[cfg(feature = "runtime")] +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, BufReader, ReadBuf}; +use tokio_openssl::SslStream; +use tokio_postgres::tls; +#[cfg(feature = "runtime")] +use tokio_postgres::tls::MakeTlsConnect; +use tokio_postgres::tls::{ChannelBinding, TlsConnect}; #[cfg(test)] mod test; -pub struct OpenSsl { - connector: SslConnector, - config: Box Result<(), ErrorStack> + Sync + Send>, -} +type ConfigCallback = + dyn Fn(&mut ConnectConfiguration, &str) -> Result<(), ErrorStack> + Sync + Send; -impl fmt::Debug for OpenSsl { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - fmt.debug_struct("OpenSsl").finish() - } +/// A `MakeTlsConnect` implementation using the `openssl` crate. +/// +/// Requires the `runtime` Cargo feature (enabled by default). +#[cfg(feature = "runtime")] +#[derive(Clone)] +pub struct MakeTlsConnector { + connector: SslConnector, + config: Arc, } -impl OpenSsl { - pub fn new() -> Result { - let connector = SslConnector::builder(SslMethod::tls())?.build(); - Ok(OpenSsl::with_connector(connector)) - } - - pub fn with_connector(connector: SslConnector) -> OpenSsl { - OpenSsl { +#[cfg(feature = "runtime")] +impl MakeTlsConnector { + /// Creates a new connector. + pub fn new(connector: SslConnector) -> MakeTlsConnector { + MakeTlsConnector { connector, - config: Box::new(|_| Ok(())), + config: Arc::new(|_, _| Ok(())), } } - pub fn callback(&mut self, f: F) + /// Sets a callback used to apply per-connection configuration. + /// + /// The the callback is provided the domain name along with the `ConnectConfiguration`. + pub fn set_callback(&mut self, f: F) where - F: Fn(&mut ConnectConfiguration) -> Result<(), ErrorStack> + 'static + Sync + Send, + F: Fn(&mut ConnectConfiguration, &str) -> Result<(), ErrorStack> + 'static + Sync + Send, { - self.config = Box::new(f); + self.config = Arc::new(f); } } -impl TlsHandshake for OpenSsl { - fn tls_handshake( - &self, - domain: &str, - stream: Stream, - ) -> Result, Box> { +#[cfg(feature = "runtime")] +impl MakeTlsConnect for MakeTlsConnector +where + S: AsyncRead + AsyncWrite + Unpin + Debug + 'static + Sync + Send, +{ + type Stream = TlsStream; + type TlsConnect = TlsConnector; + type Error = ErrorStack; + + fn make_tls_connect(&mut self, domain: &str) -> Result { let mut ssl = self.connector.configure()?; - (self.config)(&mut ssl)?; - let stream = ssl.connect(domain, stream)?; + (self.config)(&mut ssl, domain)?; + Ok(TlsConnector::new(ssl, domain)) + } +} + +/// A `TlsConnect` implementation using the `openssl` crate. +pub struct TlsConnector { + ssl: ConnectConfiguration, + domain: String, +} + +impl TlsConnector { + /// Creates a new connector configured to connect to the specified domain. + pub fn new(ssl: ConnectConfiguration, domain: &str) -> TlsConnector { + TlsConnector { + ssl, + domain: domain.to_string(), + } + } +} + +impl TlsConnect for TlsConnector +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + type Stream = TlsStream; + type Error = Box; + #[allow(clippy::type_complexity)] + type Future = Pin, Self::Error>> + Send>>; - Ok(Box::new(OpenSslStream(stream))) + fn connect(self, stream: S) -> Self::Future { + let stream = BufReader::with_capacity(8192, stream); + let future = async move { + let ssl = self.ssl.into_ssl(&self.domain)?; + let mut stream = SslStream::new(ssl, stream)?; + match Pin::new(&mut stream).connect().await { + Ok(()) => Ok(TlsStream(stream)), + Err(error) => Err(Box::new(ConnectError { + error, + verify_result: stream.ssl().verify_result(), + }) as _), + } + }; + + Box::pin(future) } } #[derive(Debug)] -struct OpenSslStream(SslStream); +struct ConnectError { + error: ssl::Error, + verify_result: X509VerifyResult, +} + +impl fmt::Display for ConnectError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.error, fmt)?; -impl Read for OpenSslStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.0.read(buf) + if self.verify_result != X509VerifyResult::OK { + fmt.write_str(": ")?; + fmt::Display::fmt(&self.verify_result, fmt)?; + } + + Ok(()) } } -impl Write for OpenSslStream { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.0.write(buf) +impl Error for ConnectError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + Some(&self.error) } +} + +/// The stream returned by `TlsConnector`. +pub struct TlsStream(SslStream>); - fn flush(&mut self) -> io::Result<()> { - self.0.flush() +impl AsyncRead for TlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_read(cx, buf) } } -impl TlsStream for OpenSslStream { - fn get_ref(&self) -> &Stream { - self.0.get_ref() +impl AsyncWrite for TlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_write(cx, buf) } - fn get_mut(&mut self) -> &mut Stream { - self.0.get_mut() + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) } - fn tls_unique(&self) -> Option> { - let f = if self.0.ssl().session_reused() { - SslRef::peer_finished - } else { - SslRef::finished - }; + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.0).poll_shutdown(cx) + } +} - let len = f(self.0.ssl(), &mut []); - let mut buf = vec![0; len]; - f(self.0.ssl(), &mut buf); - Some(buf) +impl tls::TlsStream for TlsStream +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn channel_binding(&self) -> ChannelBinding { + match tls_server_end_point(self.0.ssl()) { + Some(buf) => ChannelBinding::tls_server_end_point(buf), + None => ChannelBinding::none(), + } } } + +fn tls_server_end_point(ssl: &SslRef) -> Option> { + let cert = ssl.peer_certificate()?; + let algo_nid = cert.signature_algorithm().object().nid(); + let signature_algorithms = algo_nid.signature_algorithms()?; + let md = match signature_algorithms.digest { + Nid::MD5 | Nid::SHA1 => MessageDigest::sha256(), + nid => MessageDigest::from_nid(nid)?, + }; + cert.digest(md).ok().map(|b| b.to_vec()) +} + +/// Set ALPN for `SslConnectorBuilder` +/// +/// This is required when using `sslnegotiation=direct` +pub fn set_postgresql_alpn(builder: &mut SslConnectorBuilder) -> Result<(), ErrorStack> { + builder.set_alpn_protos(b"\x0apostgresql") +} diff --git a/postgres-openssl/src/test.rs b/postgres-openssl/src/test.rs index 8314e179f..66bb22641 100644 --- a/postgres-openssl/src/test.rs +++ b/postgres-openssl/src/test.rs @@ -1,40 +1,124 @@ +use futures_util::FutureExt; use openssl::ssl::{SslConnector, SslMethod}; -use postgres::{Connection, TlsMode}; +use tokio::net::TcpStream; +use tokio_postgres::tls::TlsConnect; -use OpenSsl; +use super::*; -#[test] -fn require() { +async fn smoke_test(s: &str, tls: T) +where + T: TlsConnect, + T::Stream: 'static + Send, +{ + let stream = TcpStream::connect("127.0.0.1:5433").await.unwrap(); + + let builder = s.parse::().unwrap(); + let (client, connection) = builder.connect_raw(stream, tls).await.unwrap(); + + let connection = connection.map(|r| r.unwrap()); + tokio::spawn(connection); + + let stmt = client.prepare("SELECT $1::INT4").await.unwrap(); + let rows = client.query(&stmt, &[&1i32]).await.unwrap(); + + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].get::<_, i32>(0), 1); +} + +#[tokio::test] +async fn require() { + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_ca_file("../test/server.crt").unwrap(); + let ctx = builder.build(); + smoke_test( + "user=ssl_user dbname=postgres sslmode=require", + TlsConnector::new(ctx.configure().unwrap(), "localhost"), + ) + .await; +} + +#[tokio::test] +async fn direct() { + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_ca_file("../test/server.crt").unwrap(); + set_postgresql_alpn(&mut builder).unwrap(); + let ctx = builder.build(); + smoke_test( + "user=ssl_user dbname=postgres sslmode=require sslnegotiation=direct", + TlsConnector::new(ctx.configure().unwrap(), "localhost"), + ) + .await; +} + +#[tokio::test] +async fn prefer() { let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); builder.set_ca_file("../test/server.crt").unwrap(); - let negotiator = OpenSsl::with_connector(builder.build()); - let conn = Connection::connect( - "postgres://ssl_user@localhost:5433/postgres", - TlsMode::Require(&negotiator), - ).unwrap(); - conn.execute("SELECT 1::VARCHAR", &[]).unwrap(); + let ctx = builder.build(); + smoke_test( + "user=ssl_user dbname=postgres", + TlsConnector::new(ctx.configure().unwrap(), "localhost"), + ) + .await; } -#[test] -fn prefer() { +#[tokio::test] +async fn scram_user() { let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); builder.set_ca_file("../test/server.crt").unwrap(); - let negotiator = OpenSsl::with_connector(builder.build()); - let conn = Connection::connect( - "postgres://ssl_user@localhost:5433/postgres", - TlsMode::Require(&negotiator), - ).unwrap(); - conn.execute("SELECT 1::VARCHAR", &[]).unwrap(); + let ctx = builder.build(); + smoke_test( + "user=scram_user password=password dbname=postgres sslmode=require", + TlsConnector::new(ctx.configure().unwrap(), "localhost"), + ) + .await; } -#[test] -fn scram_user() { +#[tokio::test] +async fn require_channel_binding_err() { let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); builder.set_ca_file("../test/server.crt").unwrap(); - let negotiator = OpenSsl::with_connector(builder.build()); - let conn = Connection::connect( - "postgres://scram_user:password@localhost:5433/postgres", - TlsMode::Require(&negotiator), - ).unwrap(); - conn.execute("SELECT 1::VARCHAR", &[]).unwrap(); + let ctx = builder.build(); + let connector = TlsConnector::new(ctx.configure().unwrap(), "localhost"); + + let stream = TcpStream::connect("127.0.0.1:5433").await.unwrap(); + let builder = "user=pass_user password=password dbname=postgres channel_binding=require" + .parse::() + .unwrap(); + builder.connect_raw(stream, connector).await.err().unwrap(); +} + +#[tokio::test] +async fn require_channel_binding_ok() { + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_ca_file("../test/server.crt").unwrap(); + let ctx = builder.build(); + smoke_test( + "user=scram_user password=password dbname=postgres channel_binding=require", + TlsConnector::new(ctx.configure().unwrap(), "localhost"), + ) + .await; +} + +#[tokio::test] +#[cfg(feature = "runtime")] +async fn runtime() { + let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); + builder.set_ca_file("../test/server.crt").unwrap(); + let connector = MakeTlsConnector::new(builder.build()); + + let (client, connection) = tokio_postgres::connect( + "host=localhost port=5433 user=postgres sslmode=require", + connector, + ) + .await + .unwrap(); + let connection = connection.map(|r| r.unwrap()); + tokio::spawn(connection); + + let stmt = client.prepare("SELECT $1::INT4").await.unwrap(); + let rows = client.query(&stmt, &[&1i32]).await.unwrap(); + + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].get::<_, i32>(0), 1); } diff --git a/postgres-protocol/CHANGELOG.md b/postgres-protocol/CHANGELOG.md new file mode 100644 index 000000000..25e717128 --- /dev/null +++ b/postgres-protocol/CHANGELOG.md @@ -0,0 +1,137 @@ +# Change Log + +## v0.6.8 - 2025-02-02 + +### Changed + +* Upgraded `getrandom`. + +## v0.6.7 - 2024-07-21 + +### Deprecated + +* Deprecated `ErrorField::value`. + +### Added + +* Added a `Clone` implementation for `DataRowBody`. +* Added `ErrorField::value_bytes`. + +### Changed + +* Upgraded `base64`. + +## v0.6.6 - 2023-08-19 + +### Added + +* Added the `js` feature for WASM support. + +## v0.6.5 - 2023-03-27 + +### Added + +* Added `message::frontend::flush`. +* Added `DataRowBody::buffer_bytes`. + +### Changed + +* Upgraded `base64`. + +## v0.6.4 - 2022-04-03 + +### Added + +* Added parsing support for `ltree`, `lquery`, and `ltxtquery`. + +## v0.6.3 - 2021-12-10 + +### Changed + +* Upgraded `hmac`, `md-5` and `sha`. + +## v0.6.2 - 2021-09-29 + +### Changed + +* Upgraded `hmac`. + +## v0.6.1 - 2021-04-03 + +### Added + +* Added the `password` module, which can be used to hash passwords before using them in queries like `ALTER USER`. +* Added type conversions for `LSN`. + +### Changed + +* Moved from `md5` to `md-5`. + +## v0.6.0 - 2020-12-25 + +### Changed + +* Upgraded `bytes`, `hmac`, and `rand`. + +### Added + +* Added `escape::{escape_literal, escape_identifier}`. + +## v0.5.3 - 2020-10-17 + +### Changed + +* Upgraded `base64` and `hmac`. + +## v0.5.2 - 2020-07-06 + +### Changed + +* Upgraded `hmac` and `sha2`. + +## v0.5.1 - 2020-03-17 + +### Changed + +* Upgraded `base64` to 0.12. + +## v0.5.0 - 2019-12-23 + +### Changed + +* `frontend::Message` is now a true non-exhaustive enum. + +## v0.5.0-alpha.2 - 2019-11-27 + +### Changed + +* Upgraded `bytes` to 0.5. + +## v0.5.0-alpha.1 - 2019-10-14 + +### Changed + +* Frontend messages and types now serialize to `BytesMut` rather than `Vec`. + +## v0.4.1 - 2019-06-29 + +### Added + +* Added `backend::Framed` to minimally parse the structure of backend messages. + +## v0.4.0 - 2019-03-05 + +### Added + +* Added channel binding support to SCRAM authentication API. + +### Changed + +* Passwords are no longer required to be UTF8 strings. +* `types::array_to_sql` now automatically computes the required flags and no longer takes a has_nulls parameter. + +## Older + +Look at the [release tags] for information about older releases. + +[release tags]: https://github.com/sfackler/rust-postgres/releases diff --git a/postgres-protocol/Cargo.toml b/postgres-protocol/Cargo.toml index 6ab4c0077..9351ea14f 100644 --- a/postgres-protocol/Cargo.toml +++ b/postgres-protocol/Cargo.toml @@ -1,21 +1,26 @@ [package] name = "postgres-protocol" -version = "0.3.2" +version = "0.6.8" authors = ["Steven Fackler "] +edition = "2018" description = "Low level Postgres protocol APIs" -license = "MIT/Apache-2.0" -repository = "https://github.com/sfackler/rust-postgres-protocol" +license = "MIT OR Apache-2.0" +repository = "https://github.com/sfackler/rust-postgres" readme = "../README.md" +[features] +default = [] +js = ["getrandom/wasm_js"] + [dependencies] -base64 = "0.9" +base64 = "0.22" byteorder = "1.0" -bytes = "0.4" -fallible-iterator = "0.1" -generic-array = "0.11" -hmac = "0.6" -md5 = "0.3" +bytes = "1.0" +fallible-iterator = "0.2" +hmac = "0.12" +md-5 = "0.10" memchr = "2.0" -rand = "0.5" -sha2 = "0.7" +rand = "0.9" +sha2 = "0.10" stringprep = "0.1" +getrandom = { version = "0.3", optional = true } diff --git a/postgres-protocol/LICENSE-APACHE b/postgres-protocol/LICENSE-APACHE deleted file mode 100644 index 16fe87b06..000000000 --- a/postgres-protocol/LICENSE-APACHE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - -1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - -4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - -6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS - -APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - -Copyright [yyyy] [name of copyright owner] - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. diff --git a/postgres-protocol/LICENSE-APACHE b/postgres-protocol/LICENSE-APACHE new file mode 120000 index 000000000..965b606f3 --- /dev/null +++ b/postgres-protocol/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/postgres-protocol/LICENSE-MIT b/postgres-protocol/LICENSE-MIT deleted file mode 100644 index 71803aea1..000000000 --- a/postgres-protocol/LICENSE-MIT +++ /dev/null @@ -1,22 +0,0 @@ -The MIT License (MIT) - -Copyright (c) 2016 Steven Fackler - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - diff --git a/postgres-protocol/LICENSE-MIT b/postgres-protocol/LICENSE-MIT new file mode 120000 index 000000000..76219eb72 --- /dev/null +++ b/postgres-protocol/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/postgres-protocol/src/authentication/mod.rs b/postgres-protocol/src/authentication/mod.rs index edacb46e7..71afa4b9b 100644 --- a/postgres-protocol/src/authentication/mod.rs +++ b/postgres-protocol/src/authentication/mod.rs @@ -1,5 +1,5 @@ //! Authentication protocol support. -use md5::Context; +use md5::{Digest, Md5}; pub mod sasl; @@ -10,14 +10,13 @@ pub mod sasl; /// `PasswordMessage` message. #[inline] pub fn md5_hash(username: &[u8], password: &[u8], salt: [u8; 4]) -> String { - let mut context = Context::new(); - context.consume(password); - context.consume(username); - let output = context.compute(); - context = Context::new(); - context.consume(format!("{:x}", output)); - context.consume(&salt); - format!("md5{:x}", context.compute()) + let mut md5 = Md5::new(); + md5.update(password); + md5.update(username); + let output = md5.finalize_reset(); + md5.update(format!("{:x}", output)); + md5.update(salt); + format!("md5{:x}", md5.finalize()) } #[cfg(test)] diff --git a/postgres-protocol/src/authentication/sasl.rs b/postgres-protocol/src/authentication/sasl.rs index dfbb70b26..85a589c99 100644 --- a/postgres-protocol/src/authentication/sasl.rs +++ b/postgres-protocol/src/authentication/sasl.rs @@ -1,24 +1,24 @@ //! SASL-based authentication support. -use base64; -use generic_array::typenum::U32; -use generic_array::GenericArray; +use base64::display::Base64Display; +use base64::engine::general_purpose::STANDARD; +use base64::Engine; use hmac::{Hmac, Mac}; use rand::{self, Rng}; +use sha2::digest::FixedOutput; use sha2::{Digest, Sha256}; use std::fmt::Write; use std::io; use std::iter; use std::mem; use std::str; -use stringprep; const NONCE_LENGTH: usize = 24; /// The identifier of the SCRAM-SHA-256 SASL authentication mechanism. -pub const SCRAM_SHA_256: &'static str = "SCRAM-SHA-256"; +pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256"; /// The identifier of the SCRAM-SHA-256-PLUS SASL authentication mechanism. -pub const SCRAM_SHA_256_PLUS: &'static str = "SCRAM-SHA-256-PLUS"; +pub const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS"; // since postgres passwords are not required to exclude saslprep-prohibited // characters or even be valid UTF8, we run saslprep if possible and otherwise @@ -35,31 +35,31 @@ fn normalize(pass: &[u8]) -> Vec { } } -fn hi(str: &[u8], salt: &[u8], i: u32) -> GenericArray { - let mut hmac = Hmac::::new_varkey(str).expect("HMAC is able to accept all key sizes"); - hmac.input(salt); - hmac.input(&[0, 0, 0, 1]); - let mut prev = hmac.result().code(); +pub(crate) fn hi(str: &[u8], salt: &[u8], i: u32) -> [u8; 32] { + let mut hmac = + Hmac::::new_from_slice(str).expect("HMAC is able to accept all key sizes"); + hmac.update(salt); + hmac.update(&[0, 0, 0, 1]); + let mut prev = hmac.finalize().into_bytes(); - let mut hi = GenericArray::::clone_from_slice(&prev); + let mut hi = prev; for _ in 1..i { - let mut hmac = Hmac::::new_varkey(str).expect("already checked above"); - hmac.input(prev.as_slice()); - prev = hmac.result().code(); + let mut hmac = Hmac::::new_from_slice(str).expect("already checked above"); + hmac.update(&prev); + prev = hmac.finalize().into_bytes(); for (hi, prev) in hi.iter_mut().zip(prev) { *hi ^= prev; } } - hi + hi.into() } enum ChannelBindingInner { Unrequested, Unsupported, - TlsUnique(Vec), TlsServerEndPoint(Vec), } @@ -77,11 +77,6 @@ impl ChannelBinding { ChannelBinding(ChannelBindingInner::Unsupported) } - /// The server requested channel binding and the client will use the `tls-unique` method. - pub fn tls_unique(finished: Vec) -> ChannelBinding { - ChannelBinding(ChannelBindingInner::TlsUnique(finished)) - } - /// The server requested channel binding and the client will use the `tls-server-end-point` /// method. pub fn tls_server_end_point(signature: Vec) -> ChannelBinding { @@ -92,7 +87,6 @@ impl ChannelBinding { match self.0 { ChannelBindingInner::Unrequested => "y,,", ChannelBindingInner::Unsupported => "n,,", - ChannelBindingInner::TlsUnique(_) => "p=tls-unique,,", ChannelBindingInner::TlsServerEndPoint(_) => "p=tls-server-end-point,,", } } @@ -100,8 +94,7 @@ impl ChannelBinding { fn cbind_data(&self) -> &[u8] { match self.0 { ChannelBindingInner::Unrequested | ChannelBindingInner::Unsupported => &[], - ChannelBindingInner::TlsUnique(ref buf) - | ChannelBindingInner::TlsServerEndPoint(ref buf) => buf, + ChannelBindingInner::TlsServerEndPoint(ref buf) => buf, } } } @@ -113,7 +106,7 @@ enum State { channel_binding: ChannelBinding, }, Finish { - salted_password: GenericArray, + salted_password: [u8; 32], auth_message: String, }, Done, @@ -143,15 +136,16 @@ impl ScramSha256 { /// Constructs a new instance which will use the provided password for authentication. pub fn new(password: &[u8], channel_binding: ChannelBinding) -> ScramSha256 { // rand 0.5's ThreadRng is cryptographically secure - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let nonce = (0..NONCE_LENGTH) .map(|_| { - let mut v = rng.gen_range(0x21u8, 0x7e); + let mut v = rng.random_range(0x21u8..0x7e); if v == 0x2c { v = 0x7e } v as char - }).collect::(); + }) + .collect::(); ScramSha256::new_inner(password, channel_binding, nonce) } @@ -162,7 +156,7 @@ impl ScramSha256 { state: State::Update { nonce, password: normalize(password), - channel_binding: channel_binding, + channel_binding, }, } } @@ -186,7 +180,7 @@ impl ScramSha256 { password, channel_binding, } => (nonce, password, channel_binding), - _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")), + _ => return Err(io::Error::other("invalid SCRAM state")), }; let message = @@ -198,47 +192,52 @@ impl ScramSha256 { return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce")); } - let salt = match base64::decode(parsed.salt) { + let salt = match STANDARD.decode(parsed.salt) { Ok(salt) => salt, Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)), }; let salted_password = hi(&password, &salt, parsed.iteration_count); - let mut hmac = Hmac::::new_varkey(&salted_password) + let mut hmac = Hmac::::new_from_slice(&salted_password) .expect("HMAC is able to accept all key sizes"); - hmac.input(b"Client Key"); - let client_key = hmac.result().code(); + hmac.update(b"Client Key"); + let client_key = hmac.finalize().into_bytes(); let mut hash = Sha256::default(); - hash.input(client_key.as_slice()); - let stored_key = hash.result(); + hash.update(client_key.as_slice()); + let stored_key = hash.finalize_fixed(); let mut cbind_input = vec![]; cbind_input.extend(channel_binding.gs2_header().as_bytes()); cbind_input.extend(channel_binding.cbind_data()); - let cbind_input = base64::encode(&cbind_input); + let cbind_input = STANDARD.encode(&cbind_input); self.message.clear(); write!(&mut self.message, "c={},r={}", cbind_input, parsed.nonce).unwrap(); let auth_message = format!("n=,r={},{},{}", client_nonce, message, self.message); - let mut hmac = - Hmac::::new_varkey(&stored_key).expect("HMAC is able to accept all key sizes"); - hmac.input(auth_message.as_bytes()); - let client_signature = hmac.result(); + let mut hmac = Hmac::::new_from_slice(&stored_key) + .expect("HMAC is able to accept all key sizes"); + hmac.update(auth_message.as_bytes()); + let client_signature = hmac.finalize().into_bytes(); - let mut client_proof = GenericArray::::clone_from_slice(&client_key); - for (proof, signature) in client_proof.iter_mut().zip(client_signature.code()) { + let mut client_proof = client_key; + for (proof, signature) in client_proof.iter_mut().zip(client_signature) { *proof ^= signature; } - write!(&mut self.message, ",p={}", base64::encode(&*client_proof)).unwrap(); + write!( + &mut self.message, + ",p={}", + Base64Display::new(&client_proof, &STANDARD) + ) + .unwrap(); self.state = State::Finish { - salted_password: salted_password, - auth_message: auth_message, + salted_password, + auth_message, }; Ok(()) } @@ -253,7 +252,7 @@ impl ScramSha256 { salted_password, auth_message, } => (salted_password, auth_message), - _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")), + _ => return Err(io::Error::other("invalid SCRAM state")), }; let message = @@ -263,28 +262,25 @@ impl ScramSha256 { let verifier = match parsed { ServerFinalMessage::Error(e) => { - return Err(io::Error::new( - io::ErrorKind::Other, - format!("SCRAM error: {}", e), - )) + return Err(io::Error::other(format!("SCRAM error: {}", e))); } ServerFinalMessage::Verifier(verifier) => verifier, }; - let verifier = match base64::decode(verifier) { + let verifier = match STANDARD.decode(verifier) { Ok(verifier) => verifier, Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)), }; - let mut hmac = Hmac::::new_varkey(&salted_password) + let mut hmac = Hmac::::new_from_slice(&salted_password) .expect("HMAC is able to accept all key sizes"); - hmac.input(b"Server Key"); - let server_key = hmac.result(); + hmac.update(b"Server Key"); + let server_key = hmac.finalize().into_bytes(); - let mut hmac = Hmac::::new_varkey(&server_key.code()) + let mut hmac = Hmac::::new_from_slice(&server_key) .expect("HMAC is able to accept all key sizes"); - hmac.input(auth_message.as_bytes()); - hmac.verify(&verifier) + hmac.update(auth_message.as_bytes()); + hmac.verify_slice(&verifier) .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "SCRAM verification error")) } } @@ -297,7 +293,7 @@ struct Parser<'a> { impl<'a> Parser<'a> { fn new(s: &'a str) -> Parser<'a> { Parser { - s: s, + s, it: s.char_indices().peekable(), } } @@ -340,10 +336,7 @@ impl<'a> Parser<'a> { } fn printable(&mut self) -> io::Result<&'a str> { - self.take_while(|c| match c { - '\x21'...'\x2b' | '\x2d'...'\x7e' => true, - _ => false, - }) + self.take_while(|c| matches!(c, '\x21'..='\x2b' | '\x2d'..='\x7e')) } fn nonce(&mut self) -> io::Result<&'a str> { @@ -353,10 +346,7 @@ impl<'a> Parser<'a> { } fn base64(&mut self) -> io::Result<&'a str> { - self.take_while(|c| match c { - 'a'...'z' | 'A'...'Z' | '0'...'9' | '/' | '+' | '=' => true, - _ => false, - }) + self.take_while(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '/' | '+' | '=')) } fn salt(&mut self) -> io::Result<&'a str> { @@ -366,10 +356,7 @@ impl<'a> Parser<'a> { } fn posit_number(&mut self) -> io::Result { - let n = self.take_while(|c| match c { - '0'...'9' => true, - _ => false, - })?; + let n = self.take_while(|c| c.is_ascii_digit())?; n.parse() .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e)) } @@ -399,17 +386,14 @@ impl<'a> Parser<'a> { self.eof()?; Ok(ServerFirstMessage { - nonce: nonce, - salt: salt, - iteration_count: iteration_count, + nonce, + salt, + iteration_count, }) } fn value(&mut self) -> io::Result<&'a str> { - self.take_while(|c| match c { - '\0' | '=' | ',' => false, - _ => true, - }) + self.take_while(|c| matches!(c, '\0' | '=' | ',')) } fn server_error(&mut self) -> io::Result> { diff --git a/postgres-protocol/src/escape/mod.rs b/postgres-protocol/src/escape/mod.rs new file mode 100644 index 000000000..0ba7efdca --- /dev/null +++ b/postgres-protocol/src/escape/mod.rs @@ -0,0 +1,93 @@ +//! Provides functions for escaping literals and identifiers for use +//! in SQL queries. +//! +//! Prefer parameterized queries where possible. Do not escape +//! parameters in a parameterized query. + +#[cfg(test)] +mod test; + +/// Escape a literal and surround result with single quotes. Not +/// recommended in most cases. +/// +/// If input contains backslashes, result will be of the form ` +/// E'...'` so it is safe to use regardless of the setting of +/// standard_conforming_strings. +pub fn escape_literal(input: &str) -> String { + escape_internal(input, false) +} + +/// Escape an identifier and surround result with double quotes. +pub fn escape_identifier(input: &str) -> String { + escape_internal(input, true) +} + +// Translation of PostgreSQL libpq's PQescapeInternal(). Does not +// require a connection because input string is known to be valid +// UTF-8. +// +// Escape arbitrary strings. If as_ident is true, we escape the +// result as an identifier; if false, as a literal. The result is +// returned in a newly allocated buffer. If we fail due to an +// encoding violation or out of memory condition, we return NULL, +// storing an error message into conn. +fn escape_internal(input: &str, as_ident: bool) -> String { + let mut num_backslashes = 0; + let mut num_quotes = 0; + let quote_char = if as_ident { '"' } else { '\'' }; + + // Scan the string for characters that must be escaped. + for ch in input.chars() { + if ch == quote_char { + num_quotes += 1; + } else if ch == '\\' { + num_backslashes += 1; + } + } + + // Allocate output String. + let mut result_size = input.len() + num_quotes + 3; // two quotes, plus a NUL + if !as_ident && num_backslashes > 0 { + result_size += num_backslashes + 2; + } + + let mut output = String::with_capacity(result_size); + + // If we are escaping a literal that contains backslashes, we use + // the escape string syntax so that the result is correct under + // either value of standard_conforming_strings. We also emit a + // leading space in this case, to guard against the possibility + // that the result might be interpolated immediately following an + // identifier. + if !as_ident && num_backslashes > 0 { + output.push(' '); + output.push('E'); + } + + // Opening quote. + output.push(quote_char); + + // Use fast path if possible. + // + // We've already verified that the input string is well-formed in + // the current encoding. If it contains no quotes and, in the + // case of literal-escaping, no backslashes, then we can just copy + // it directly to the output buffer, adding the necessary quotes. + // + // If not, we must rescan the input and process each character + // individually. + if num_quotes == 0 && (num_backslashes == 0 || as_ident) { + output.push_str(input); + } else { + for ch in input.chars() { + if ch == quote_char || (!as_ident && ch == '\\') { + output.push(ch); + } + output.push(ch); + } + } + + output.push(quote_char); + + output +} diff --git a/postgres-protocol/src/escape/test.rs b/postgres-protocol/src/escape/test.rs new file mode 100644 index 000000000..4816a103b --- /dev/null +++ b/postgres-protocol/src/escape/test.rs @@ -0,0 +1,17 @@ +use crate::escape::{escape_identifier, escape_literal}; + +#[test] +fn test_escape_idenifier() { + assert_eq!(escape_identifier("foo"), String::from("\"foo\"")); + assert_eq!(escape_identifier("f\\oo"), String::from("\"f\\oo\"")); + assert_eq!(escape_identifier("f'oo"), String::from("\"f'oo\"")); + assert_eq!(escape_identifier("f\"oo"), String::from("\"f\"\"oo\"")); +} + +#[test] +fn test_escape_literal() { + assert_eq!(escape_literal("foo"), String::from("'foo'")); + assert_eq!(escape_literal("f\\oo"), String::from(" E'f\\\\oo'")); + assert_eq!(escape_literal("f'oo"), String::from("'f''oo'")); + assert_eq!(escape_literal("f\"oo"), String::from("'f\"oo'")); +} diff --git a/postgres-protocol/src/lib.rs b/postgres-protocol/src/lib.rs index f49165ff3..e0de3b6c6 100644 --- a/postgres-protocol/src/lib.rs +++ b/postgres-protocol/src/lib.rs @@ -9,30 +9,24 @@ //! //! This library assumes that the `client_encoding` backend parameter has been //! set to `UTF8`. It will most likely not behave properly if that is not the case. -#![doc(html_root_url="https://docs.rs/postgres-protocol/0.3")] -#![warn(missing_docs)] -extern crate base64; -extern crate byteorder; -extern crate bytes; -extern crate fallible_iterator; -extern crate generic_array; -extern crate hmac; -extern crate md5; -extern crate memchr; -extern crate rand; -extern crate sha2; -extern crate stringprep; +#![warn(missing_docs, rust_2018_idioms, clippy::all)] use byteorder::{BigEndian, ByteOrder}; +use bytes::{BufMut, BytesMut}; use std::io; pub mod authentication; +pub mod escape; pub mod message; +pub mod password; pub mod types; /// A Postgres OID. pub type Oid = u32; +/// A Postgres Log Sequence Number (LSN). +pub type Lsn = u64; + /// An enum indicating if a value is `NULL` or not. pub enum IsNull { /// The value is `NULL`. @@ -41,14 +35,13 @@ pub enum IsNull { No, } -#[inline] -fn write_nullable(serializer: F, buf: &mut Vec) -> Result<(), E> +fn write_nullable(serializer: F, buf: &mut BytesMut) -> Result<(), E> where - F: FnOnce(&mut Vec) -> Result, + F: FnOnce(&mut BytesMut) -> Result, E: From, { let base = buf.len(); - buf.extend_from_slice(&[0; 4]); + buf.put_i32(0); let size = match serializer(buf)? { IsNull::No => i32::from_usize(buf.len() - base - 4)?, IsNull::Yes => -1, @@ -67,14 +60,17 @@ macro_rules! from_usize { impl FromUsize for $t { #[inline] fn from_usize(x: usize) -> io::Result<$t> { - if x > <$t>::max_value() as usize { - Err(io::Error::new(io::ErrorKind::InvalidInput, "value too large to transmit")) + if x > <$t>::MAX as usize { + Err(io::Error::new( + io::ErrorKind::InvalidInput, + "value too large to transmit", + )) } else { Ok(x as $t) } } } - } + }; } from_usize!(i16); diff --git a/postgres-protocol/src/message/backend.rs b/postgres-protocol/src/message/backend.rs index eacb5da47..013bfbb81 100644 --- a/postgres-protocol/src/message/backend.rs +++ b/postgres-protocol/src/message/backend.rs @@ -1,6 +1,6 @@ #![allow(missing_docs)] -use byteorder::{BigEndian, ReadBytesExt}; +use byteorder::{BigEndian, ByteOrder, ReadBytesExt}; use bytes::{Bytes, BytesMut}; use fallible_iterator::FallibleIterator; use memchr::memchr; @@ -9,9 +9,70 @@ use std::io::{self, Read}; use std::ops::Range; use std::str; -use Oid; +use crate::Oid; + +pub const PARSE_COMPLETE_TAG: u8 = b'1'; +pub const BIND_COMPLETE_TAG: u8 = b'2'; +pub const CLOSE_COMPLETE_TAG: u8 = b'3'; +pub const NOTIFICATION_RESPONSE_TAG: u8 = b'A'; +pub const COPY_DONE_TAG: u8 = b'c'; +pub const COMMAND_COMPLETE_TAG: u8 = b'C'; +pub const COPY_DATA_TAG: u8 = b'd'; +pub const DATA_ROW_TAG: u8 = b'D'; +pub const ERROR_RESPONSE_TAG: u8 = b'E'; +pub const COPY_IN_RESPONSE_TAG: u8 = b'G'; +pub const COPY_OUT_RESPONSE_TAG: u8 = b'H'; +pub const EMPTY_QUERY_RESPONSE_TAG: u8 = b'I'; +pub const BACKEND_KEY_DATA_TAG: u8 = b'K'; +pub const NO_DATA_TAG: u8 = b'n'; +pub const NOTICE_RESPONSE_TAG: u8 = b'N'; +pub const AUTHENTICATION_TAG: u8 = b'R'; +pub const PORTAL_SUSPENDED_TAG: u8 = b's'; +pub const PARAMETER_STATUS_TAG: u8 = b'S'; +pub const PARAMETER_DESCRIPTION_TAG: u8 = b't'; +pub const ROW_DESCRIPTION_TAG: u8 = b'T'; +pub const READY_FOR_QUERY_TAG: u8 = b'Z'; + +#[derive(Debug, Copy, Clone)] +pub struct Header { + tag: u8, + len: i32, +} + +#[allow(clippy::len_without_is_empty)] +impl Header { + #[inline] + pub fn parse(buf: &[u8]) -> io::Result> { + if buf.len() < 5 { + return Ok(None); + } + + let tag = buf[0]; + let len = BigEndian::read_i32(&buf[1..]); + + if len < 4 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid message length: header length < 4", + )); + } + + Ok(Some(Header { tag, len })) + } + + #[inline] + pub fn tag(self) -> u8 { + self.tag + } + + #[inline] + pub fn len(self) -> i32 { + self.len + } +} /// An enum representing Postgres backend messages. +#[non_exhaustive] pub enum Message { AuthenticationCleartextPassword, AuthenticationGss, @@ -44,8 +105,6 @@ pub enum Message { PortalSuspended, ReadyForQuery(ReadyForQueryBody), RowDescription(RowDescriptionBody), - #[doc(hidden)] - __ForExtensibility, } impl Message { @@ -63,7 +122,7 @@ impl Message { if len < 4 { return Err(io::Error::new( io::ErrorKind::InvalidInput, - "invalid message length", + "invalid message length: parsing u32", )); } @@ -80,82 +139,79 @@ impl Message { }; let message = match tag { - b'1' => Message::ParseComplete, - b'2' => Message::BindComplete, - b'3' => Message::CloseComplete, - b'A' => { + PARSE_COMPLETE_TAG => Message::ParseComplete, + BIND_COMPLETE_TAG => Message::BindComplete, + CLOSE_COMPLETE_TAG => Message::CloseComplete, + NOTIFICATION_RESPONSE_TAG => { let process_id = buf.read_i32::()?; let channel = buf.read_cstr()?; let message = buf.read_cstr()?; Message::NotificationResponse(NotificationResponseBody { - process_id: process_id, - channel: channel, - message: message, + process_id, + channel, + message, }) } - b'c' => Message::CopyDone, - b'C' => { + COPY_DONE_TAG => Message::CopyDone, + COMMAND_COMPLETE_TAG => { let tag = buf.read_cstr()?; - Message::CommandComplete(CommandCompleteBody { tag: tag }) + Message::CommandComplete(CommandCompleteBody { tag }) } - b'd' => { + COPY_DATA_TAG => { let storage = buf.read_all(); - Message::CopyData(CopyDataBody { storage: storage }) + Message::CopyData(CopyDataBody { storage }) } - b'D' => { + DATA_ROW_TAG => { let len = buf.read_u16::()?; let storage = buf.read_all(); - Message::DataRow(DataRowBody { - storage: storage, - len: len, - }) + Message::DataRow(DataRowBody { storage, len }) } - b'E' => { + ERROR_RESPONSE_TAG => { let storage = buf.read_all(); - Message::ErrorResponse(ErrorResponseBody { storage: storage }) + Message::ErrorResponse(ErrorResponseBody { storage }) } - b'G' => { + COPY_IN_RESPONSE_TAG => { let format = buf.read_u8()?; let len = buf.read_u16::()?; let storage = buf.read_all(); Message::CopyInResponse(CopyInResponseBody { - format: format, - len: len, - storage: storage, + format, + len, + storage, }) } - b'H' => { + COPY_OUT_RESPONSE_TAG => { let format = buf.read_u8()?; let len = buf.read_u16::()?; let storage = buf.read_all(); Message::CopyOutResponse(CopyOutResponseBody { - format: format, - len: len, - storage: storage, + format, + len, + storage, }) } - b'I' => Message::EmptyQueryResponse, - b'K' => { + EMPTY_QUERY_RESPONSE_TAG => Message::EmptyQueryResponse, + BACKEND_KEY_DATA_TAG => { let process_id = buf.read_i32::()?; let secret_key = buf.read_i32::()?; Message::BackendKeyData(BackendKeyDataBody { - process_id: process_id, - secret_key: secret_key, + process_id, + secret_key, }) } - b'n' => Message::NoData, - b'N' => { + NO_DATA_TAG => Message::NoData, + NOTICE_RESPONSE_TAG => { let storage = buf.read_all(); - Message::NoticeResponse(NoticeResponseBody { storage: storage }) + Message::NoticeResponse(NoticeResponseBody { storage }) } - b'R' => match buf.read_i32::()? { + AUTHENTICATION_TAG => match buf.read_i32::()? { 0 => Message::AuthenticationOk, 2 => Message::AuthenticationKerberosV5, 3 => Message::AuthenticationCleartextPassword, 5 => { let mut salt = [0; 4]; buf.read_exact(&mut salt)?; - Message::AuthenticationMd5Password(AuthenticationMd5PasswordBody { salt: salt }) + Message::AuthenticationMd5Password(AuthenticationMd5PasswordBody { salt }) } 6 => Message::AuthenticationScmCredential, 7 => Message::AuthenticationGss, @@ -183,34 +239,25 @@ impl Message { )); } }, - b's' => Message::PortalSuspended, - b'S' => { + PORTAL_SUSPENDED_TAG => Message::PortalSuspended, + PARAMETER_STATUS_TAG => { let name = buf.read_cstr()?; let value = buf.read_cstr()?; - Message::ParameterStatus(ParameterStatusBody { - name: name, - value: value, - }) + Message::ParameterStatus(ParameterStatusBody { name, value }) } - b't' => { + PARAMETER_DESCRIPTION_TAG => { let len = buf.read_u16::()?; let storage = buf.read_all(); - Message::ParameterDescription(ParameterDescriptionBody { - storage: storage, - len: len, - }) + Message::ParameterDescription(ParameterDescriptionBody { storage, len }) } - b'T' => { + ROW_DESCRIPTION_TAG => { let len = buf.read_u16::()?; let storage = buf.read_all(); - Message::RowDescription(RowDescriptionBody { - storage: storage, - len: len, - }) + Message::RowDescription(RowDescriptionBody { storage, len }) } - b'Z' => { + READY_FOR_QUERY_TAG => { let status = buf.read_u8()?; - Message::ReadyForQuery(ReadyForQueryBody { status: status }) + Message::ReadyForQuery(ReadyForQueryBody { status }) } tag => { return Err(io::Error::new( @@ -223,7 +270,7 @@ impl Message { if !buf.is_empty() { return Err(io::Error::new( io::ErrorKind::InvalidInput, - "invalid message length", + "invalid message length: expected buffer to be empty", )); } @@ -237,20 +284,23 @@ struct Buffer { } impl Buffer { + #[inline] fn slice(&self) -> &[u8] { &self.bytes[self.idx..] } + #[inline] fn is_empty(&self) -> bool { self.slice().is_empty() } + #[inline] fn read_cstr(&mut self) -> io::Result { match memchr(0, self.slice()) { Some(pos) => { let start = self.idx; let end = start + pos; - let cstr = self.bytes.slice(start, end); + let cstr = self.bytes.slice(start..end); self.idx = end + 1; Ok(cstr) } @@ -261,14 +311,16 @@ impl Buffer { } } + #[inline] fn read_all(&mut self) -> Bytes { - let buf = self.bytes.slice_from(self.idx); + let buf = self.bytes.slice(self.idx..); self.idx = self.bytes.len(); buf } } impl Read for Buffer { + #[inline] fn read(&mut self, buf: &mut [u8]) -> io::Result { let len = { let slice = self.slice(); @@ -305,7 +357,7 @@ pub struct AuthenticationSaslBody(Bytes); impl AuthenticationSaslBody { #[inline] - pub fn mechanisms<'a>(&'a self) -> SaslMechanisms<'a> { + pub fn mechanisms(&self) -> SaslMechanisms<'_> { SaslMechanisms(&self.0) } } @@ -323,7 +375,7 @@ impl<'a> FallibleIterator for SaslMechanisms<'a> { if self.0.len() != 1 { return Err(io::Error::new( io::ErrorKind::InvalidData, - "invalid message length", + "invalid message length: expected to be at end of iterator for sasl", )); } Ok(None) @@ -398,9 +450,9 @@ impl CopyDataBody { } pub struct CopyInResponseBody { - storage: Bytes, - len: u16, format: u8, + len: u16, + storage: Bytes, } impl CopyInResponseBody { @@ -410,7 +462,7 @@ impl CopyInResponseBody { } #[inline] - pub fn column_formats<'a>(&'a self) -> ColumnFormats<'a> { + pub fn column_formats(&self) -> ColumnFormats<'_> { ColumnFormats { remaining: self.len, buf: &self.storage, @@ -423,7 +475,7 @@ pub struct ColumnFormats<'a> { remaining: u16, } -impl<'a> FallibleIterator for ColumnFormats<'a> { +impl FallibleIterator for ColumnFormats<'_> { type Item = u16; type Error = io::Error; @@ -435,7 +487,7 @@ impl<'a> FallibleIterator for ColumnFormats<'a> { } else { return Err(io::Error::new( io::ErrorKind::InvalidInput, - "invalid message length", + "invalid message length: wrong column formats", )); } } @@ -452,9 +504,9 @@ impl<'a> FallibleIterator for ColumnFormats<'a> { } pub struct CopyOutResponseBody { - storage: Bytes, - len: u16, format: u8, + len: u16, + storage: Bytes, } impl CopyOutResponseBody { @@ -464,7 +516,7 @@ impl CopyOutResponseBody { } #[inline] - pub fn column_formats<'a>(&'a self) -> ColumnFormats<'a> { + pub fn column_formats(&self) -> ColumnFormats<'_> { ColumnFormats { remaining: self.len, buf: &self.storage, @@ -472,6 +524,7 @@ impl CopyOutResponseBody { } } +#[derive(Debug, Clone)] pub struct DataRowBody { storage: Bytes, len: u16, @@ -479,7 +532,7 @@ pub struct DataRowBody { impl DataRowBody { #[inline] - pub fn ranges<'a>(&'a self) -> DataRowRanges<'a> { + pub fn ranges(&self) -> DataRowRanges<'_> { DataRowRanges { buf: &self.storage, len: self.storage.len(), @@ -491,6 +544,11 @@ impl DataRowBody { pub fn buffer(&self) -> &[u8] { &self.storage } + + #[inline] + pub fn buffer_bytes(&self) -> &Bytes { + &self.storage + } } pub struct DataRowRanges<'a> { @@ -499,7 +557,7 @@ pub struct DataRowRanges<'a> { remaining: u16, } -impl<'a> FallibleIterator for DataRowRanges<'a> { +impl FallibleIterator for DataRowRanges<'_> { type Item = Option>; type Error = io::Error; @@ -511,7 +569,7 @@ impl<'a> FallibleIterator for DataRowRanges<'a> { } else { return Err(io::Error::new( io::ErrorKind::InvalidInput, - "invalid message length", + "invalid message length: datarowrange is not empty", )); } } @@ -529,7 +587,7 @@ impl<'a> FallibleIterator for DataRowRanges<'a> { )); } let base = self.len - self.buf.len(); - self.buf = &self.buf[len as usize..]; + self.buf = &self.buf[len..]; Ok(Some(Some(base..base + len))) } } @@ -547,7 +605,7 @@ pub struct ErrorResponseBody { impl ErrorResponseBody { #[inline] - pub fn fields<'a>(&'a self) -> ErrorFields<'a> { + pub fn fields(&self) -> ErrorFields<'_> { ErrorFields { buf: &self.storage } } } @@ -569,35 +627,38 @@ impl<'a> FallibleIterator for ErrorFields<'a> { } else { return Err(io::Error::new( io::ErrorKind::InvalidInput, - "invalid message length", + "invalid message length: error fields is not drained", )); } } let value_end = find_null(self.buf, 0)?; - let value = get_str(&self.buf[..value_end])?; + let value = &self.buf[..value_end]; self.buf = &self.buf[value_end + 1..]; - Ok(Some(ErrorField { - type_: type_, - value: value, - })) + Ok(Some(ErrorField { type_, value })) } } pub struct ErrorField<'a> { type_: u8, - value: &'a str, + value: &'a [u8], } -impl<'a> ErrorField<'a> { +impl ErrorField<'_> { #[inline] pub fn type_(&self) -> u8 { self.type_ } #[inline] + #[deprecated(note = "use value_bytes instead", since = "0.6.7")] pub fn value(&self) -> &str { + str::from_utf8(self.value).expect("error field value contained non-UTF8 bytes") + } + + #[inline] + pub fn value_bytes(&self) -> &[u8] { self.value } } @@ -608,7 +669,7 @@ pub struct NoticeResponseBody { impl NoticeResponseBody { #[inline] - pub fn fields<'a>(&'a self) -> ErrorFields<'a> { + pub fn fields(&self) -> ErrorFields<'_> { ErrorFields { buf: &self.storage } } } @@ -643,7 +704,7 @@ pub struct ParameterDescriptionBody { impl ParameterDescriptionBody { #[inline] - pub fn parameters<'a>(&'a self) -> Parameters<'a> { + pub fn parameters(&self) -> Parameters<'_> { Parameters { buf: &self.storage, remaining: self.len, @@ -656,7 +717,7 @@ pub struct Parameters<'a> { remaining: u16, } -impl<'a> FallibleIterator for Parameters<'a> { +impl FallibleIterator for Parameters<'_> { type Item = Oid; type Error = io::Error; @@ -668,7 +729,7 @@ impl<'a> FallibleIterator for Parameters<'a> { } else { return Err(io::Error::new( io::ErrorKind::InvalidInput, - "invalid message length", + "invalid message length: parameters is not drained", )); } } @@ -719,7 +780,7 @@ pub struct RowDescriptionBody { impl RowDescriptionBody { #[inline] - pub fn fields<'a>(&'a self) -> Fields<'a> { + pub fn fields(&self) -> Fields<'_> { Fields { buf: &self.storage, remaining: self.len, @@ -744,7 +805,7 @@ impl<'a> FallibleIterator for Fields<'a> { } else { return Err(io::Error::new( io::ErrorKind::InvalidInput, - "invalid message length", + "invalid message length: field is not drained", )); } } @@ -761,13 +822,13 @@ impl<'a> FallibleIterator for Fields<'a> { let format = self.buf.read_i16::()?; Ok(Some(Field { - name: name, - table_oid: table_oid, - column_id: column_id, - type_oid: type_oid, - type_size: type_size, - type_modifier: type_modifier, - format: format, + name, + table_oid, + column_id, + type_oid, + type_size, + type_modifier, + format, })) } } diff --git a/postgres-protocol/src/message/frontend.rs b/postgres-protocol/src/message/frontend.rs index a340df0ce..600f7da48 100644 --- a/postgres-protocol/src/message/frontend.rs +++ b/postgres-protocol/src/message/frontend.rs @@ -1,113 +1,19 @@ //! Frontend message serialization. #![allow(missing_docs)] -use byteorder::{WriteBytesExt, BigEndian, ByteOrder}; +use byteorder::{BigEndian, ByteOrder}; +use bytes::{Buf, BufMut, BytesMut}; +use std::convert::TryFrom; use std::error::Error; use std::io; use std::marker; -use {Oid, FromUsize, IsNull, write_nullable}; - -pub enum Message<'a> { - Bind { - portal: &'a str, - statement: &'a str, - formats: &'a [i16], - values: &'a [Option>], - result_formats: &'a [i16], - }, - CancelRequest { process_id: i32, secret_key: i32 }, - Close { variant: u8, name: &'a str }, - CopyData { data: &'a [u8] }, - CopyDone, - CopyFail { message: &'a str }, - Describe { variant: u8, name: &'a str }, - Execute { portal: &'a str, max_rows: i32 }, - Parse { - name: &'a str, - query: &'a str, - param_types: &'a [Oid], - }, - PasswordMessage { password: &'a str }, - Query { query: &'a str }, - SaslInitialResponse { mechanism: &'a str, data: &'a [u8] }, - SaslResponse { data: &'a [u8] }, - SslRequest, - StartupMessage { parameters: &'a [(String, String)] }, - Sync, - Terminate, - #[doc(hidden)] - __ForExtensibility, -} - -impl<'a> Message<'a> { - #[inline] - pub fn serialize(&self, buf: &mut Vec) -> io::Result<()> { - match *self { - Message::Bind { - portal, - statement, - formats, - values, - result_formats, - } => { - let r = bind( - portal, - statement, - formats.iter().cloned(), - values, - |v, buf| match *v { - Some(ref v) => { - buf.extend_from_slice(v); - Ok(IsNull::No) - } - None => Ok(IsNull::Yes), - }, - result_formats.iter().cloned(), - buf, - ); - match r { - Ok(()) => Ok(()), - Err(BindError::Conversion(_)) => unreachable!(), - Err(BindError::Serialization(e)) => Err(e), - } - } - Message::CancelRequest { - process_id, - secret_key, - } => Ok(cancel_request(process_id, secret_key, buf)), - Message::Close { variant, name } => close(variant, name, buf), - Message::CopyData { data } => copy_data(data, buf), - Message::CopyDone => Ok(copy_done(buf)), - Message::CopyFail { message } => copy_fail(message, buf), - Message::Describe { variant, name } => describe(variant, name, buf), - Message::Execute { portal, max_rows } => execute(portal, max_rows, buf), - Message::Parse { - name, - query, - param_types, - } => parse(name, query, param_types.iter().cloned(), buf), - Message::PasswordMessage { password } => password_message(password, buf), - Message::Query { query: q } => query(q, buf), - Message::SaslInitialResponse { mechanism, data } => { - sasl_initial_response(mechanism, data, buf) - } - Message::SaslResponse { data } => sasl_response(data, buf), - Message::SslRequest => Ok(ssl_request(buf)), - Message::StartupMessage { parameters } => { - startup_message(parameters.iter().map(|&(ref k, ref v)| (&**k, &**v)), buf) - } - Message::Sync => Ok(sync(buf)), - Message::Terminate => Ok(terminate(buf)), - Message::__ForExtensibility => unreachable!(), - } - } -} +use crate::{write_nullable, FromUsize, IsNull, Oid}; #[inline] -fn write_body(buf: &mut Vec, f: F) -> Result<(), E> +fn write_body(buf: &mut BytesMut, f: F) -> Result<(), E> where - F: FnOnce(&mut Vec) -> Result<(), E>, + F: FnOnce(&mut BytesMut) -> Result<(), E>, E: From, { let base = buf.len(); @@ -121,13 +27,13 @@ where } pub enum BindError { - Conversion(Box), + Conversion(Box), Serialization(io::Error), } -impl From> for BindError { +impl From> for BindError { #[inline] - fn from(e: Box) -> BindError { + fn from(e: Box) -> BindError { BindError::Conversion(e) } } @@ -147,36 +53,50 @@ pub fn bind( values: J, mut serializer: F, result_formats: K, - buf: &mut Vec, + buf: &mut BytesMut, ) -> Result<(), BindError> where I: IntoIterator, J: IntoIterator, - F: FnMut(T, &mut Vec) -> Result>, + F: FnMut(T, &mut BytesMut) -> Result>, K: IntoIterator, { - buf.push(b'B'); + buf.put_u8(b'B'); write_body(buf, |buf| { - buf.write_cstr(portal)?; - buf.write_cstr(statement)?; - write_counted(formats, |f, buf| buf.write_i16::(f), buf)?; + write_cstr(portal.as_bytes(), buf)?; + write_cstr(statement.as_bytes(), buf)?; + write_counted( + formats, + |f, buf| { + buf.put_i16(f); + Ok::<_, io::Error>(()) + }, + buf, + )?; write_counted( values, |v, buf| write_nullable(|buf| serializer(v, buf), buf), buf, )?; - write_counted(result_formats, |f, buf| buf.write_i16::(f), buf)?; + write_counted( + result_formats, + |f, buf| { + buf.put_i16(f); + Ok::<_, io::Error>(()) + }, + buf, + )?; Ok(()) }) } #[inline] -fn write_counted(items: I, mut serializer: F, buf: &mut Vec) -> Result<(), E> +fn write_counted(items: I, mut serializer: F, buf: &mut BytesMut) -> Result<(), E> where I: IntoIterator, - F: FnMut(T, &mut Vec) -> Result<(), E>, + F: FnMut(T, &mut BytesMut) -> Result<(), E>, E: From, { let base = buf.len(); @@ -193,156 +113,191 @@ where } #[inline] -pub fn cancel_request(process_id: i32, secret_key: i32, buf: &mut Vec) { +pub fn cancel_request(process_id: i32, secret_key: i32, buf: &mut BytesMut) { write_body(buf, |buf| { - buf.write_i32::(80877102).unwrap(); - buf.write_i32::(process_id).unwrap(); - buf.write_i32::(secret_key) - }).unwrap(); + buf.put_i32(80_877_102); + buf.put_i32(process_id); + buf.put_i32(secret_key); + Ok::<_, io::Error>(()) + }) + .unwrap(); } #[inline] -pub fn close(variant: u8, name: &str, buf: &mut Vec) -> io::Result<()> { - buf.push(b'C'); +pub fn close(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'C'); write_body(buf, |buf| { - buf.push(variant); - buf.write_cstr(name) + buf.put_u8(variant); + write_cstr(name.as_bytes(), buf) }) } -// FIXME ideally this'd take a Read but it's unclear what to do at EOF -#[inline] -pub fn copy_data(data: &[u8], buf: &mut Vec) -> io::Result<()> { - buf.push(b'd'); - write_body(buf, |buf| { - buf.extend_from_slice(data); - Ok(()) - }) +pub struct CopyData { + buf: T, + len: i32, +} + +impl CopyData +where + T: Buf, +{ + pub fn new(buf: T) -> io::Result> { + let len = buf + .remaining() + .checked_add(4) + .and_then(|l| i32::try_from(l).ok()) + .ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidInput, "message length overflow") + })?; + + Ok(CopyData { buf, len }) + } + + pub fn write(self, out: &mut BytesMut) { + out.put_u8(b'd'); + out.put_i32(self.len); + out.put(self.buf); + } } #[inline] -pub fn copy_done(buf: &mut Vec) { - buf.push(b'c'); +pub fn copy_done(buf: &mut BytesMut) { + buf.put_u8(b'c'); write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); } #[inline] -pub fn copy_fail(message: &str, buf: &mut Vec) -> io::Result<()> { - buf.push(b'f'); - write_body(buf, |buf| buf.write_cstr(message)) +pub fn copy_fail(message: &str, buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'f'); + write_body(buf, |buf| write_cstr(message.as_bytes(), buf)) } #[inline] -pub fn describe(variant: u8, name: &str, buf: &mut Vec) -> io::Result<()> { - buf.push(b'D'); +pub fn describe(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'D'); write_body(buf, |buf| { - buf.push(variant); - buf.write_cstr(name) + buf.put_u8(variant); + write_cstr(name.as_bytes(), buf) }) } #[inline] -pub fn execute(portal: &str, max_rows: i32, buf: &mut Vec) -> io::Result<()> { - buf.push(b'E'); +pub fn execute(portal: &str, max_rows: i32, buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'E'); write_body(buf, |buf| { - buf.write_cstr(portal)?; - buf.write_i32::(max_rows).unwrap(); + write_cstr(portal.as_bytes(), buf)?; + buf.put_i32(max_rows); Ok(()) }) } #[inline] -pub fn parse(name: &str, query: &str, param_types: I, buf: &mut Vec) -> io::Result<()> +pub fn parse(name: &str, query: &str, param_types: I, buf: &mut BytesMut) -> io::Result<()> where I: IntoIterator, { - buf.push(b'P'); + buf.put_u8(b'P'); write_body(buf, |buf| { - buf.write_cstr(name)?; - buf.write_cstr(query)?; - write_counted(param_types, |t, buf| buf.write_u32::(t), buf)?; + write_cstr(name.as_bytes(), buf)?; + write_cstr(query.as_bytes(), buf)?; + write_counted( + param_types, + |t, buf| { + buf.put_u32(t); + Ok::<_, io::Error>(()) + }, + buf, + )?; Ok(()) }) } #[inline] -pub fn password_message(password: &str, buf: &mut Vec) -> io::Result<()> { - buf.push(b'p'); - write_body(buf, |buf| buf.write_cstr(password)) +pub fn password_message(password: &[u8], buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'p'); + write_body(buf, |buf| write_cstr(password, buf)) } #[inline] -pub fn query(query: &str, buf: &mut Vec) -> io::Result<()> { - buf.push(b'Q'); - write_body(buf, |buf| buf.write_cstr(query)) +pub fn query(query: &str, buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'Q'); + write_body(buf, |buf| write_cstr(query.as_bytes(), buf)) } #[inline] -pub fn sasl_initial_response(mechanism: &str, data: &[u8], buf: &mut Vec) -> io::Result<()> { - buf.push(b'p'); +pub fn sasl_initial_response(mechanism: &str, data: &[u8], buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'p'); write_body(buf, |buf| { - buf.write_cstr(mechanism)?; + write_cstr(mechanism.as_bytes(), buf)?; let len = i32::from_usize(data.len())?; - buf.write_i32::(len)?; - buf.extend_from_slice(data); + buf.put_i32(len); + buf.put_slice(data); Ok(()) }) } #[inline] -pub fn sasl_response(data: &[u8], buf: &mut Vec) -> io::Result<()> { - buf.push(b'p'); - write_body(buf, |buf| Ok(buf.extend_from_slice(data))) +pub fn sasl_response(data: &[u8], buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'p'); + write_body(buf, |buf| { + buf.put_slice(data); + Ok(()) + }) } #[inline] -pub fn ssl_request(buf: &mut Vec) { - write_body(buf, |buf| buf.write_i32::(80877103)).unwrap(); +pub fn ssl_request(buf: &mut BytesMut) { + write_body(buf, |buf| { + buf.put_i32(80_877_103); + Ok::<_, io::Error>(()) + }) + .unwrap(); } #[inline] -pub fn startup_message<'a, I>(parameters: I, buf: &mut Vec) -> io::Result<()> +pub fn startup_message<'a, I>(parameters: I, buf: &mut BytesMut) -> io::Result<()> where I: IntoIterator, { write_body(buf, |buf| { - buf.write_i32::(196608).unwrap(); + // postgres protocol version 3.0(196608) in bigger-endian + buf.put_i32(0x00_03_00_00); for (key, value) in parameters { - buf.write_cstr(key.as_ref())?; - buf.write_cstr(value.as_ref())?; + write_cstr(key.as_bytes(), buf)?; + write_cstr(value.as_bytes(), buf)?; } - buf.push(0); + buf.put_u8(0); Ok(()) }) } #[inline] -pub fn sync(buf: &mut Vec) { - buf.push(b'S'); +pub fn flush(buf: &mut BytesMut) { + buf.put_u8(b'H'); write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); } #[inline] -pub fn terminate(buf: &mut Vec) { - buf.push(b'X'); +pub fn sync(buf: &mut BytesMut) { + buf.put_u8(b'S'); write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); } -trait WriteCStr { - fn write_cstr(&mut self, s: &str) -> Result<(), io::Error>; +#[inline] +pub fn terminate(buf: &mut BytesMut) { + buf.put_u8(b'X'); + write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); } -impl WriteCStr for Vec { - #[inline] - fn write_cstr(&mut self, s: &str) -> Result<(), io::Error> { - if s.as_bytes().contains(&0) { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "string contains embedded null", - )); - } - self.extend_from_slice(s.as_bytes()); - self.push(0); - Ok(()) +#[inline] +fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> { + if s.contains(&0) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "string contains embedded null", + )); } + buf.put_slice(s); + buf.put_u8(0); + Ok(()) } diff --git a/postgres-protocol/src/password/mod.rs b/postgres-protocol/src/password/mod.rs new file mode 100644 index 000000000..445fb0c0e --- /dev/null +++ b/postgres-protocol/src/password/mod.rs @@ -0,0 +1,106 @@ +//! Functions to encrypt a password in the client. +//! +//! This is intended to be used by client applications that wish to +//! send commands like `ALTER USER joe PASSWORD 'pwd'`. The password +//! need not be sent in cleartext if it is encrypted on the client +//! side. This is good because it ensures the cleartext password won't +//! end up in logs pg_stat displays, etc. + +use crate::authentication::sasl; +use base64::display::Base64Display; +use base64::engine::general_purpose::STANDARD; +use hmac::{Hmac, Mac}; +use md5::Md5; +use rand::RngCore; +use sha2::digest::FixedOutput; +use sha2::{Digest, Sha256}; + +#[cfg(test)] +mod test; + +const SCRAM_DEFAULT_ITERATIONS: u32 = 4096; +const SCRAM_DEFAULT_SALT_LEN: usize = 16; + +/// Hash password using SCRAM-SHA-256 with a randomly-generated +/// salt. +/// +/// The client may assume the returned string doesn't contain any +/// special characters that would require escaping in an SQL command. +pub fn scram_sha_256(password: &[u8]) -> String { + let mut salt: [u8; SCRAM_DEFAULT_SALT_LEN] = [0; SCRAM_DEFAULT_SALT_LEN]; + let mut rng = rand::rng(); + rng.fill_bytes(&mut salt); + scram_sha_256_salt(password, salt) +} + +// Internal implementation of scram_sha_256 with a caller-provided +// salt. This is useful for testing. +pub(crate) fn scram_sha_256_salt(password: &[u8], salt: [u8; SCRAM_DEFAULT_SALT_LEN]) -> String { + // Prepare the password, per [RFC + // 4013](https://tools.ietf.org/html/rfc4013), if possible. + // + // Postgres treats passwords as byte strings (without embedded NUL + // bytes), but SASL expects passwords to be valid UTF-8. + // + // Follow the behavior of libpq's PQencryptPasswordConn(), and + // also the backend. If the password is not valid UTF-8, or if it + // contains prohibited characters (such as non-ASCII whitespace), + // just skip the SASLprep step and use the original byte + // sequence. + let prepared: Vec = match std::str::from_utf8(password) { + Ok(password_str) => { + match stringprep::saslprep(password_str) { + Ok(p) => p.into_owned().into_bytes(), + // contains invalid characters; skip saslprep + Err(_) => Vec::from(password), + } + } + // not valid UTF-8; skip saslprep + Err(_) => Vec::from(password), + }; + + // salt password + let salted_password = sasl::hi(&prepared, &salt, SCRAM_DEFAULT_ITERATIONS); + + // client key + let mut hmac = Hmac::::new_from_slice(&salted_password) + .expect("HMAC is able to accept all key sizes"); + hmac.update(b"Client Key"); + let client_key = hmac.finalize().into_bytes(); + + // stored key + let mut hash = Sha256::default(); + hash.update(client_key.as_slice()); + let stored_key = hash.finalize_fixed(); + + // server key + let mut hmac = Hmac::::new_from_slice(&salted_password) + .expect("HMAC is able to accept all key sizes"); + hmac.update(b"Server Key"); + let server_key = hmac.finalize().into_bytes(); + + format!( + "SCRAM-SHA-256${}:{}${}:{}", + SCRAM_DEFAULT_ITERATIONS, + Base64Display::new(&salt, &STANDARD), + Base64Display::new(&stored_key, &STANDARD), + Base64Display::new(&server_key, &STANDARD) + ) +} + +/// **Not recommended, as MD5 is not considered to be secure.** +/// +/// Hash password using MD5 with the username as the salt. +/// +/// The client may assume the returned string doesn't contain any +/// special characters that would require escaping. +pub fn md5(password: &[u8], username: &str) -> String { + // salt password with username + let mut salted_password = Vec::from(password); + salted_password.extend_from_slice(username.as_bytes()); + + let mut hash = Md5::new(); + hash.update(&salted_password); + let digest = hash.finalize(); + format!("md5{:x}", digest) +} diff --git a/postgres-protocol/src/password/test.rs b/postgres-protocol/src/password/test.rs new file mode 100644 index 000000000..1432cb204 --- /dev/null +++ b/postgres-protocol/src/password/test.rs @@ -0,0 +1,19 @@ +use crate::password; + +#[test] +fn test_encrypt_scram_sha_256() { + // Specify the salt to make the test deterministic. Any bytes will do. + let salt: [u8; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; + assert_eq!( + password::scram_sha_256_salt(b"secret", salt), + "SCRAM-SHA-256$4096:AQIDBAUGBwgJCgsMDQ4PEA==$8rrDg00OqaiWXJ7p+sCgHEIaBSHY89ZJl3mfIsf32oY=:05L1f+yZbiN8O0AnO40Og85NNRhvzTS57naKRWCcsIA=" + ); +} + +#[test] +fn test_encrypt_md5() { + assert_eq!( + password::md5(b"secret", "foo"), + "md54ab2c5d00339c4b2a4e921d2dc4edec7" + ); +} diff --git a/postgres-protocol/src/types.rs b/postgres-protocol/src/types/mod.rs similarity index 62% rename from postgres-protocol/src/types.rs rename to postgres-protocol/src/types/mod.rs index 1066ee6a4..03bd90799 100644 --- a/postgres-protocol/src/types.rs +++ b/postgres-protocol/src/types/mod.rs @@ -1,11 +1,17 @@ //! Conversions to and from Postgres's binary format for various types. -use byteorder::{BigEndian, ByteOrder, ReadBytesExt, WriteBytesExt}; +use byteorder::{BigEndian, ByteOrder, ReadBytesExt}; +use bytes::{BufMut, BytesMut}; use fallible_iterator::FallibleIterator; use std::boxed::Box as StdBox; use std::error::Error; +use std::io::Read; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::str; -use {write_nullable, FromUsize, IsNull, Oid}; +use crate::{write_nullable, FromUsize, IsNull, Lsn, Oid}; + +#[cfg(test)] +mod test; const RANGE_UPPER_UNBOUNDED: u8 = 0b0001_0000; const RANGE_LOWER_UNBOUNDED: u8 = 0b0000_1000; @@ -13,15 +19,18 @@ const RANGE_UPPER_INCLUSIVE: u8 = 0b0000_0100; const RANGE_LOWER_INCLUSIVE: u8 = 0b0000_0010; const RANGE_EMPTY: u8 = 0b0000_0001; +const PGSQL_AF_INET: u8 = 2; +const PGSQL_AF_INET6: u8 = 3; + /// Serializes a `BOOL` value. #[inline] -pub fn bool_to_sql(v: bool, buf: &mut Vec) { - buf.push(v as u8); +pub fn bool_to_sql(v: bool, buf: &mut BytesMut) { + buf.put_u8(v as u8); } /// Deserializes a `BOOL` value. #[inline] -pub fn bool_from_sql(buf: &[u8]) -> Result> { +pub fn bool_from_sql(buf: &[u8]) -> Result> { if buf.len() != 1 { return Err("invalid buffer size".into()); } @@ -31,8 +40,8 @@ pub fn bool_from_sql(buf: &[u8]) -> Result> { /// Serializes a `BYTEA` value. #[inline] -pub fn bytea_to_sql(v: &[u8], buf: &mut Vec) { - buf.extend_from_slice(v); +pub fn bytea_to_sql(v: &[u8], buf: &mut BytesMut) { + buf.put_slice(v); } /// Deserializes a `BYTEA value. @@ -43,25 +52,25 @@ pub fn bytea_from_sql(buf: &[u8]) -> &[u8] { /// Serializes a `TEXT`, `VARCHAR`, `CHAR(n)`, `NAME`, or `CITEXT` value. #[inline] -pub fn text_to_sql(v: &str, buf: &mut Vec) { - buf.extend_from_slice(v.as_bytes()); +pub fn text_to_sql(v: &str, buf: &mut BytesMut) { + buf.put_slice(v.as_bytes()); } /// Deserializes a `TEXT`, `VARCHAR`, `CHAR(n)`, `NAME`, or `CITEXT` value. #[inline] -pub fn text_from_sql(buf: &[u8]) -> Result<&str, StdBox> { +pub fn text_from_sql(buf: &[u8]) -> Result<&str, StdBox> { Ok(str::from_utf8(buf)?) } /// Serializes a `"char"` value. #[inline] -pub fn char_to_sql(v: i8, buf: &mut Vec) { - buf.write_i8(v).unwrap(); +pub fn char_to_sql(v: i8, buf: &mut BytesMut) { + buf.put_i8(v); } /// Deserializes a `"char"` value. #[inline] -pub fn char_from_sql(mut buf: &[u8]) -> Result> { +pub fn char_from_sql(mut buf: &[u8]) -> Result> { let v = buf.read_i8()?; if !buf.is_empty() { return Err("invalid buffer size".into()); @@ -71,13 +80,13 @@ pub fn char_from_sql(mut buf: &[u8]) -> Result> /// Serializes an `INT2` value. #[inline] -pub fn int2_to_sql(v: i16, buf: &mut Vec) { - buf.write_i16::(v).unwrap(); +pub fn int2_to_sql(v: i16, buf: &mut BytesMut) { + buf.put_i16(v); } /// Deserializes an `INT2` value. #[inline] -pub fn int2_from_sql(mut buf: &[u8]) -> Result> { +pub fn int2_from_sql(mut buf: &[u8]) -> Result> { let v = buf.read_i16::()?; if !buf.is_empty() { return Err("invalid buffer size".into()); @@ -87,13 +96,13 @@ pub fn int2_from_sql(mut buf: &[u8]) -> Result> /// Serializes an `INT4` value. #[inline] -pub fn int4_to_sql(v: i32, buf: &mut Vec) { - buf.write_i32::(v).unwrap(); +pub fn int4_to_sql(v: i32, buf: &mut BytesMut) { + buf.put_i32(v); } /// Deserializes an `INT4` value. #[inline] -pub fn int4_from_sql(mut buf: &[u8]) -> Result> { +pub fn int4_from_sql(mut buf: &[u8]) -> Result> { let v = buf.read_i32::()?; if !buf.is_empty() { return Err("invalid buffer size".into()); @@ -103,13 +112,13 @@ pub fn int4_from_sql(mut buf: &[u8]) -> Result> /// Serializes an `OID` value. #[inline] -pub fn oid_to_sql(v: Oid, buf: &mut Vec) { - buf.write_u32::(v).unwrap(); +pub fn oid_to_sql(v: Oid, buf: &mut BytesMut) { + buf.put_u32(v); } /// Deserializes an `OID` value. #[inline] -pub fn oid_from_sql(mut buf: &[u8]) -> Result> { +pub fn oid_from_sql(mut buf: &[u8]) -> Result> { let v = buf.read_u32::()?; if !buf.is_empty() { return Err("invalid buffer size".into()); @@ -119,13 +128,13 @@ pub fn oid_from_sql(mut buf: &[u8]) -> Result> /// Serializes an `INT8` value. #[inline] -pub fn int8_to_sql(v: i64, buf: &mut Vec) { - buf.write_i64::(v).unwrap(); +pub fn int8_to_sql(v: i64, buf: &mut BytesMut) { + buf.put_i64(v); } /// Deserializes an `INT8` value. #[inline] -pub fn int8_from_sql(mut buf: &[u8]) -> Result> { +pub fn int8_from_sql(mut buf: &[u8]) -> Result> { let v = buf.read_i64::()?; if !buf.is_empty() { return Err("invalid buffer size".into()); @@ -133,15 +142,31 @@ pub fn int8_from_sql(mut buf: &[u8]) -> Result> Ok(v) } +/// Serializes a `PG_LSN` value. +#[inline] +pub fn lsn_to_sql(v: Lsn, buf: &mut BytesMut) { + buf.put_u64(v); +} + +/// Deserializes a `PG_LSN` value. +#[inline] +pub fn lsn_from_sql(mut buf: &[u8]) -> Result> { + let v = buf.read_u64::()?; + if !buf.is_empty() { + return Err("invalid buffer size".into()); + } + Ok(v) +} + /// Serializes a `FLOAT4` value. #[inline] -pub fn float4_to_sql(v: f32, buf: &mut Vec) { - buf.write_f32::(v).unwrap(); +pub fn float4_to_sql(v: f32, buf: &mut BytesMut) { + buf.put_f32(v); } /// Deserializes a `FLOAT4` value. #[inline] -pub fn float4_from_sql(mut buf: &[u8]) -> Result> { +pub fn float4_from_sql(mut buf: &[u8]) -> Result> { let v = buf.read_f32::()?; if !buf.is_empty() { return Err("invalid buffer size".into()); @@ -151,13 +176,13 @@ pub fn float4_from_sql(mut buf: &[u8]) -> Result) { - buf.write_f64::(v).unwrap(); +pub fn float8_to_sql(v: f64, buf: &mut BytesMut) { + buf.put_f64(v); } /// Deserializes a `FLOAT8` value. #[inline] -pub fn float8_from_sql(mut buf: &[u8]) -> Result> { +pub fn float8_from_sql(mut buf: &[u8]) -> Result> { let v = buf.read_f64::()?; if !buf.is_empty() { return Err("invalid buffer size".into()); @@ -167,12 +192,15 @@ pub fn float8_from_sql(mut buf: &[u8]) -> Result(values: I, buf: &mut Vec) -> Result<(), StdBox> +pub fn hstore_to_sql<'a, I>( + values: I, + buf: &mut BytesMut, +) -> Result<(), StdBox> where I: IntoIterator)>, { let base = buf.len(); - buf.extend_from_slice(&[0; 4]); + buf.put_i32(0); let mut count = 0; for (key, value) in values { @@ -184,7 +212,7 @@ where Some(value) => { write_pascal_string(value, buf)?; } - None => buf.write_i32::(-1).unwrap(), + None => buf.put_i32(-1), } } @@ -194,18 +222,18 @@ where Ok(()) } -fn write_pascal_string(s: &str, buf: &mut Vec) -> Result<(), StdBox> { +fn write_pascal_string(s: &str, buf: &mut BytesMut) -> Result<(), StdBox> { let size = i32::from_usize(s.len())?; - buf.write_i32::(size).unwrap(); - buf.extend_from_slice(s.as_bytes()); + buf.put_i32(size); + buf.put_slice(s.as_bytes()); Ok(()) } /// Deserializes an `HSTORE` value. #[inline] -pub fn hstore_from_sql<'a>( - mut buf: &'a [u8], -) -> Result, StdBox> { +pub fn hstore_from_sql( + mut buf: &[u8], +) -> Result, StdBox> { let count = buf.read_i32::()?; if count < 0 { return Err("invalid entry count".into()); @@ -213,7 +241,7 @@ pub fn hstore_from_sql<'a>( Ok(HstoreEntries { remaining: count, - buf: buf, + buf, }) } @@ -225,10 +253,13 @@ pub struct HstoreEntries<'a> { impl<'a> FallibleIterator for HstoreEntries<'a> { type Item = (&'a str, Option<&'a str>); - type Error = StdBox; + type Error = StdBox; #[inline] - fn next(&mut self) -> Result)>, StdBox> { + #[allow(clippy::type_complexity)] + fn next( + &mut self, + ) -> Result)>, StdBox> { if self.remaining == 0 { if !self.buf.is_empty() { return Err("invalid buffer size".into()); @@ -271,16 +302,16 @@ impl<'a> FallibleIterator for HstoreEntries<'a> { pub fn varbit_to_sql( len: usize, v: I, - buf: &mut Vec, -) -> Result<(), StdBox> + buf: &mut BytesMut, +) -> Result<(), StdBox> where I: Iterator, { let len = i32::from_usize(len)?; - buf.write_i32::(len).unwrap(); + buf.put_i32(len); for byte in v { - buf.push(byte); + buf.put_u8(byte); } Ok(()) @@ -288,14 +319,14 @@ where /// Deserializes a `VARBIT` or `BIT` value. #[inline] -pub fn varbit_from_sql<'a>(mut buf: &'a [u8]) -> Result, StdBox> { +pub fn varbit_from_sql(mut buf: &[u8]) -> Result, StdBox> { let len = buf.read_i32::()?; if len < 0 { - return Err("invalid varbit length".into()); + return Err("invalid varbit length: varbit < 0".into()); } - let bytes = (len as usize + 7) / 8; + let bytes = (len as usize).div_ceil(8); if buf.len() != bytes { - return Err("invalid message length".into()); + return Err("invalid message length: varbit mismatch".into()); } Ok(Varbit { @@ -317,6 +348,12 @@ impl<'a> Varbit<'a> { self.len } + /// Determines if the value has no bits. + #[inline] + pub fn is_empty(&self) -> bool { + self.len == 0 + } + /// Returns the bits as a slice of bytes. #[inline] pub fn bytes(&self) -> &'a [u8] { @@ -328,18 +365,18 @@ impl<'a> Varbit<'a> { /// /// The value should represent the number of microseconds since midnight, January 1st, 2000. #[inline] -pub fn timestamp_to_sql(v: i64, buf: &mut Vec) { - buf.write_i64::(v).unwrap(); +pub fn timestamp_to_sql(v: i64, buf: &mut BytesMut) { + buf.put_i64(v); } /// Deserializes a `TIMESTAMP` or `TIMESTAMPTZ` value. /// /// The value represents the number of microseconds since midnight, January 1st, 2000. #[inline] -pub fn timestamp_from_sql(mut buf: &[u8]) -> Result> { +pub fn timestamp_from_sql(mut buf: &[u8]) -> Result> { let v = buf.read_i64::()?; if !buf.is_empty() { - return Err("invalid message length".into()); + return Err("invalid message length: timestamp not drained".into()); } Ok(v) } @@ -348,18 +385,18 @@ pub fn timestamp_from_sql(mut buf: &[u8]) -> Result) { - buf.write_i32::(v).unwrap(); +pub fn date_to_sql(v: i32, buf: &mut BytesMut) { + buf.put_i32(v); } /// Deserializes a `DATE` value. /// /// The value represents the number of days since January 1st, 2000. #[inline] -pub fn date_from_sql(mut buf: &[u8]) -> Result> { +pub fn date_from_sql(mut buf: &[u8]) -> Result> { let v = buf.read_i32::()?; if !buf.is_empty() { - return Err("invalid message length".into()); + return Err("invalid message length: date not drained".into()); } Ok(v) } @@ -368,33 +405,33 @@ pub fn date_from_sql(mut buf: &[u8]) -> Result> /// /// The value should represent the number of microseconds since midnight. #[inline] -pub fn time_to_sql(v: i64, buf: &mut Vec) { - buf.write_i64::(v).unwrap(); +pub fn time_to_sql(v: i64, buf: &mut BytesMut) { + buf.put_i64(v); } /// Deserializes a `TIME` or `TIMETZ` value. /// /// The value represents the number of microseconds since midnight. #[inline] -pub fn time_from_sql(mut buf: &[u8]) -> Result> { +pub fn time_from_sql(mut buf: &[u8]) -> Result> { let v = buf.read_i64::()?; if !buf.is_empty() { - return Err("invalid message length".into()); + return Err("invalid message length: time not drained".into()); } Ok(v) } /// Serializes a `MACADDR` value. #[inline] -pub fn macaddr_to_sql(v: [u8; 6], buf: &mut Vec) { - buf.extend_from_slice(&v); +pub fn macaddr_to_sql(v: [u8; 6], buf: &mut BytesMut) { + buf.put_slice(&v); } /// Deserializes a `MACADDR` value. #[inline] -pub fn macaddr_from_sql(buf: &[u8]) -> Result<[u8; 6], StdBox> { +pub fn macaddr_from_sql(buf: &[u8]) -> Result<[u8; 6], StdBox> { if buf.len() != 6 { - return Err("invalid message length".into()); + return Err("invalid message length: macaddr length mismatch".into()); } let mut out = [0; 6]; out.copy_from_slice(buf); @@ -403,15 +440,15 @@ pub fn macaddr_from_sql(buf: &[u8]) -> Result<[u8; 6], StdBox) { - buf.extend_from_slice(&v); +pub fn uuid_to_sql(v: [u8; 16], buf: &mut BytesMut) { + buf.put_slice(&v); } /// Deserializes a `UUID` value. #[inline] -pub fn uuid_from_sql(buf: &[u8]) -> Result<[u8; 16], StdBox> { +pub fn uuid_from_sql(buf: &[u8]) -> Result<[u8; 16], StdBox> { if buf.len() != 16 { - return Err("invalid message length".into()); + return Err("invalid message length: uuid size mismatch".into()); } let mut out = [0; 16]; out.copy_from_slice(buf); @@ -425,24 +462,24 @@ pub fn array_to_sql( element_type: Oid, elements: J, mut serializer: F, - buf: &mut Vec, -) -> Result<(), StdBox> + buf: &mut BytesMut, +) -> Result<(), StdBox> where I: IntoIterator, J: IntoIterator, - F: FnMut(T, &mut Vec) -> Result>, + F: FnMut(T, &mut BytesMut) -> Result>, { let dimensions_idx = buf.len(); - buf.extend_from_slice(&[0; 4]); + buf.put_i32(0); let flags_idx = buf.len(); - buf.extend_from_slice(&[0; 4]); - buf.write_u32::(element_type).unwrap(); + buf.put_i32(0); + buf.put_u32(element_type); let mut num_dimensions = 0; for dimension in dimensions { num_dimensions += 1; - buf.write_i32::(dimension.len).unwrap(); - buf.write_i32::(dimension.lower_bound).unwrap(); + buf.put_i32(dimension.len); + buf.put_i32(dimension.lower_bound); } let num_dimensions = i32::from_usize(num_dimensions)?; @@ -469,7 +506,7 @@ where /// Deserializes an array value. #[inline] -pub fn array_from_sql<'a>(mut buf: &'a [u8]) -> Result, StdBox> { +pub fn array_from_sql(mut buf: &[u8]) -> Result, StdBox> { let dimensions = buf.read_i32::()?; if dimensions < 0 { return Err("invalid dimension count".into()); @@ -496,11 +533,11 @@ pub fn array_from_sql<'a>(mut buf: &'a [u8]) -> Result, StdBox Array<'a> { /// An iterator over the dimensions of an array. pub struct ArrayDimensions<'a>(&'a [u8]); -impl<'a> FallibleIterator for ArrayDimensions<'a> { +impl FallibleIterator for ArrayDimensions<'_> { type Item = ArrayDimension; - type Error = StdBox; + type Error = StdBox; #[inline] - fn next(&mut self) -> Result, StdBox> { + fn next(&mut self) -> Result, StdBox> { if self.0.is_empty() { return Ok(None); } @@ -558,10 +595,7 @@ impl<'a> FallibleIterator for ArrayDimensions<'a> { let len = self.0.read_i32::()?; let lower_bound = self.0.read_i32::()?; - Ok(Some(ArrayDimension { - len: len, - lower_bound: lower_bound, - })) + Ok(Some(ArrayDimension { len, lower_bound })) } #[inline] @@ -589,13 +623,13 @@ pub struct ArrayValues<'a> { impl<'a> FallibleIterator for ArrayValues<'a> { type Item = Option<&'a [u8]>; - type Error = StdBox; + type Error = StdBox; #[inline] - fn next(&mut self) -> Result>, StdBox> { + fn next(&mut self) -> Result>, StdBox> { if self.remaining == 0 { if !self.buf.is_empty() { - return Err("invalid message length".into()); + return Err("invalid message length: arrayvalue not drained".into()); } return Ok(None); } @@ -625,22 +659,22 @@ impl<'a> FallibleIterator for ArrayValues<'a> { /// Serializes an empty range. #[inline] -pub fn empty_range_to_sql(buf: &mut Vec) { - buf.push(RANGE_EMPTY); +pub fn empty_range_to_sql(buf: &mut BytesMut) { + buf.put_u8(RANGE_EMPTY); } /// Serializes a range value. pub fn range_to_sql( lower: F, upper: G, - buf: &mut Vec, -) -> Result<(), StdBox> + buf: &mut BytesMut, +) -> Result<(), StdBox> where - F: FnOnce(&mut Vec) -> Result, StdBox>, - G: FnOnce(&mut Vec) -> Result, StdBox>, + F: FnOnce(&mut BytesMut) -> Result, StdBox>, + G: FnOnce(&mut BytesMut) -> Result, StdBox>, { let tag_idx = buf.len(); - buf.push(0); + buf.put_u8(0); let mut tag = 0; match write_bound(lower, buf)? { @@ -662,13 +696,13 @@ where fn write_bound( bound: F, - buf: &mut Vec, -) -> Result, StdBox> + buf: &mut BytesMut, +) -> Result, StdBox> where - F: FnOnce(&mut Vec) -> Result, StdBox>, + F: FnOnce(&mut BytesMut) -> Result, StdBox>, { let base = buf.len(); - buf.extend_from_slice(&[0; 4]); + buf.put_i32(0); let (null, ret) = match bound(buf)? { RangeBound::Inclusive(null) => (Some(null), RangeBound::Inclusive(())), @@ -702,7 +736,7 @@ pub enum RangeBound { /// Deserializes a range value. #[inline] -pub fn range_from_sql<'a>(mut buf: &'a [u8]) -> Result, StdBox> { +pub fn range_from_sql(mut buf: &[u8]) -> Result, StdBox> { let tag = buf.read_u8()?; if tag == RANGE_EMPTY { @@ -728,7 +762,7 @@ fn read_bound<'a>( tag: u8, unbounded: u8, inclusive: u8, -) -> Result>, StdBox> { +) -> Result>, StdBox> { if tag & unbounded != 0 { Ok(RangeBound::Unbounded) } else { @@ -763,20 +797,20 @@ pub enum Range<'a> { /// Serializes a point value. #[inline] -pub fn point_to_sql(x: f64, y: f64, buf: &mut Vec) { - buf.write_f64::(x).unwrap(); - buf.write_f64::(y).unwrap(); +pub fn point_to_sql(x: f64, y: f64, buf: &mut BytesMut) { + buf.put_f64(x); + buf.put_f64(y); } /// Deserializes a point value. #[inline] -pub fn point_from_sql(mut buf: &[u8]) -> Result> { +pub fn point_from_sql(mut buf: &[u8]) -> Result> { let x = buf.read_f64::()?; let y = buf.read_f64::()?; if !buf.is_empty() { return Err("invalid buffer size".into()); } - Ok(Point { x: x, y: y }) + Ok(Point { x, y }) } /// A Postgres point. @@ -802,16 +836,16 @@ impl Point { /// Serializes a box value. #[inline] -pub fn box_to_sql(x1: f64, y1: f64, x2: f64, y2: f64, buf: &mut Vec) { - buf.write_f64::(x1).unwrap(); - buf.write_f64::(y1).unwrap(); - buf.write_f64::(x2).unwrap(); - buf.write_f64::(y2).unwrap(); +pub fn box_to_sql(x1: f64, y1: f64, x2: f64, y2: f64, buf: &mut BytesMut) { + buf.put_f64(x1); + buf.put_f64(y1); + buf.put_f64(x2); + buf.put_f64(y2); } /// Deserializes a box value. #[inline] -pub fn box_from_sql(mut buf: &[u8]) -> Result> { +pub fn box_from_sql(mut buf: &[u8]) -> Result> { let x1 = buf.read_f64::()?; let y1 = buf.read_f64::()?; let x2 = buf.read_f64::()?; @@ -851,20 +885,20 @@ impl Box { pub fn path_to_sql( closed: bool, points: I, - buf: &mut Vec, -) -> Result<(), StdBox> + buf: &mut BytesMut, +) -> Result<(), StdBox> where I: IntoIterator, { - buf.push(closed as u8); + buf.put_u8(closed as u8); let points_idx = buf.len(); - buf.extend_from_slice(&[0; 4]); + buf.put_i32(0); let mut num_points = 0; for (x, y) in points { num_points += 1; - buf.write_f64::(x).unwrap(); - buf.write_f64::(y).unwrap(); + buf.put_f64(x); + buf.put_f64(y); } let num_points = i32::from_usize(num_points)?; @@ -875,14 +909,14 @@ where /// Deserializes a Postgres path. #[inline] -pub fn path_from_sql<'a>(mut buf: &'a [u8]) -> Result, StdBox> { +pub fn path_from_sql(mut buf: &[u8]) -> Result, StdBox> { let closed = buf.read_u8()? != 0; let points = buf.read_i32::()?; Ok(Path { - closed: closed, - points: points, - buf: buf, + closed, + points, + buf, }) } @@ -916,15 +950,15 @@ pub struct PathPoints<'a> { buf: &'a [u8], } -impl<'a> FallibleIterator for PathPoints<'a> { +impl FallibleIterator for PathPoints<'_> { type Item = Point; - type Error = StdBox; + type Error = StdBox; #[inline] - fn next(&mut self) -> Result, StdBox> { + fn next(&mut self) -> Result, StdBox> { if self.remaining == 0 { if !self.buf.is_empty() { - return Err("invalid message length".into()); + return Err("invalid message length: path points not drained".into()); } return Ok(None); } @@ -933,7 +967,7 @@ impl<'a> FallibleIterator for PathPoints<'a> { let x = self.buf.read_f64::()?; let y = self.buf.read_f64::()?; - Ok(Some(Point { x: x, y: y })) + Ok(Some(Point { x, y })) } #[inline] @@ -943,158 +977,142 @@ impl<'a> FallibleIterator for PathPoints<'a> { } } -#[cfg(test)] -mod test { - use fallible_iterator::FallibleIterator; - use std::collections::HashMap; - - use super::*; - use IsNull; - - #[test] - fn bool() { - let mut buf = vec![]; - bool_to_sql(true, &mut buf); - assert_eq!(bool_from_sql(&buf).unwrap(), true); - - let mut buf = vec![]; - bool_to_sql(false, &mut buf); - assert_eq!(bool_from_sql(&buf).unwrap(), false); +/// Serializes a Postgres inet. +#[inline] +pub fn inet_to_sql(addr: IpAddr, netmask: u8, buf: &mut BytesMut) { + let family = match addr { + IpAddr::V4(_) => PGSQL_AF_INET, + IpAddr::V6(_) => PGSQL_AF_INET6, + }; + buf.put_u8(family); + buf.put_u8(netmask); + buf.put_u8(0); // is_cidr + match addr { + IpAddr::V4(addr) => { + buf.put_u8(4); + buf.put_slice(&addr.octets()); + } + IpAddr::V6(addr) => { + buf.put_u8(16); + buf.put_slice(&addr.octets()); + } } +} - #[test] - fn int2() { - let mut buf = vec![]; - int2_to_sql(0x0102, &mut buf); - assert_eq!(int2_from_sql(&buf).unwrap(), 0x0102); - } +/// Deserializes a Postgres inet. +#[inline] +pub fn inet_from_sql(mut buf: &[u8]) -> Result> { + let family = buf.read_u8()?; + let netmask = buf.read_u8()?; + buf.read_u8()?; // is_cidr + let len = buf.read_u8()?; + + let addr = match family { + PGSQL_AF_INET => { + if netmask > 32 { + return Err("invalid IPv4 netmask".into()); + } + if len != 4 { + return Err("invalid IPv4 address length".into()); + } + let mut addr = [0; 4]; + buf.read_exact(&mut addr)?; + IpAddr::V4(Ipv4Addr::from(addr)) + } + PGSQL_AF_INET6 => { + if netmask > 128 { + return Err("invalid IPv6 netmask".into()); + } + if len != 16 { + return Err("invalid IPv6 address length".into()); + } + let mut addr = [0; 16]; + buf.read_exact(&mut addr)?; + IpAddr::V6(Ipv6Addr::from(addr)) + } + _ => return Err("invalid IP family".into()), + }; - #[test] - fn int4() { - let mut buf = vec![]; - int4_to_sql(0x01020304, &mut buf); - assert_eq!(int4_from_sql(&buf).unwrap(), 0x01020304); + if !buf.is_empty() { + return Err("invalid buffer size".into()); } - #[test] - fn int8() { - let mut buf = vec![]; - int8_to_sql(0x0102030405060708, &mut buf); - assert_eq!(int8_from_sql(&buf).unwrap(), 0x0102030405060708); - } + Ok(Inet { addr, netmask }) +} - #[test] - fn float4() { - let mut buf = vec![]; - float4_to_sql(10343.95, &mut buf); - assert_eq!(float4_from_sql(&buf).unwrap(), 10343.95); - } +/// A Postgres network address. +pub struct Inet { + addr: IpAddr, + netmask: u8, +} - #[test] - fn float8() { - let mut buf = vec![]; - float8_to_sql(10343.95, &mut buf); - assert_eq!(float8_from_sql(&buf).unwrap(), 10343.95); +impl Inet { + /// Returns the IP address. + #[inline] + pub fn addr(&self) -> IpAddr { + self.addr } - #[test] - fn hstore() { - let mut map = HashMap::new(); - map.insert("hello", Some("world")); - map.insert("hola", None); - - let mut buf = vec![]; - hstore_to_sql(map.iter().map(|(&k, &v)| (k, v)), &mut buf).unwrap(); - assert_eq!( - hstore_from_sql(&buf) - .unwrap() - .collect::>() - .unwrap(), - map - ); + /// Returns the netmask. + #[inline] + pub fn netmask(&self) -> u8 { + self.netmask } +} - #[test] - fn varbit() { - let len = 12; - let bits = [0b0010_1011, 0b0000_1111]; +/// Serializes a Postgres ltree string +#[inline] +pub fn ltree_to_sql(v: &str, buf: &mut BytesMut) { + // A version number is prepended to an ltree string per spec + buf.put_u8(1); + // Append the rest of the query + buf.put_slice(v.as_bytes()); +} - let mut buf = vec![]; - varbit_to_sql(len, bits.iter().cloned(), &mut buf).unwrap(); - let out = varbit_from_sql(&buf).unwrap(); - assert_eq!(out.len(), len); - assert_eq!(out.bytes(), bits); +/// Deserialize a Postgres ltree string +#[inline] +pub fn ltree_from_sql(buf: &[u8]) -> Result<&str, StdBox> { + match buf { + // Remove the version number from the front of the ltree per spec + [1u8, rest @ ..] => Ok(str::from_utf8(rest)?), + _ => Err("ltree version 1 only supported".into()), } +} - #[test] - fn array() { - let dimensions = [ - ArrayDimension { - len: 1, - lower_bound: 10, - }, - ArrayDimension { - len: 2, - lower_bound: 0, - }, - ]; - let values = [None, Some(&b"hello"[..])]; - - let mut buf = vec![]; - array_to_sql( - dimensions.iter().cloned(), - 10, - values.iter().cloned(), - |v, buf| match v { - Some(v) => { - buf.extend_from_slice(v); - Ok(IsNull::No) - } - None => Ok(IsNull::Yes), - }, - &mut buf, - ).unwrap(); - - let array = array_from_sql(&buf).unwrap(); - assert_eq!(array.has_nulls(), true); - assert_eq!(array.element_type(), 10); - assert_eq!(array.dimensions().collect::>().unwrap(), dimensions); - assert_eq!(array.values().collect::>().unwrap(), values); +/// Serializes a Postgres lquery string +#[inline] +pub fn lquery_to_sql(v: &str, buf: &mut BytesMut) { + // A version number is prepended to an lquery string per spec + buf.put_u8(1); + // Append the rest of the query + buf.put_slice(v.as_bytes()); +} + +/// Deserialize a Postgres lquery string +#[inline] +pub fn lquery_from_sql(buf: &[u8]) -> Result<&str, StdBox> { + match buf { + // Remove the version number from the front of the lquery per spec + [1u8, rest @ ..] => Ok(str::from_utf8(rest)?), + _ => Err("lquery version 1 only supported".into()), } +} - #[test] - fn non_null_array() { - let dimensions = [ - ArrayDimension { - len: 1, - lower_bound: 10, - }, - ArrayDimension { - len: 2, - lower_bound: 0, - }, - ]; - let values = [Some(&b"hola"[..]), Some(&b"hello"[..])]; - - let mut buf = vec![]; - array_to_sql( - dimensions.iter().cloned(), - 10, - values.iter().cloned(), - |v, buf| match v { - Some(v) => { - buf.extend_from_slice(v); - Ok(IsNull::No) - } - None => Ok(IsNull::Yes), - }, - &mut buf, - ).unwrap(); - - let array = array_from_sql(&buf).unwrap(); - assert_eq!(array.has_nulls(), false); - assert_eq!(array.element_type(), 10); - assert_eq!(array.dimensions().collect::>().unwrap(), dimensions); - assert_eq!(array.values().collect::>().unwrap(), values); +/// Serializes a Postgres ltxtquery string +#[inline] +pub fn ltxtquery_to_sql(v: &str, buf: &mut BytesMut) { + // A version number is prepended to an ltxtquery string per spec + buf.put_u8(1); + // Append the rest of the query + buf.put_slice(v.as_bytes()); +} + +/// Deserialize a Postgres ltxtquery string +#[inline] +pub fn ltxtquery_from_sql(buf: &[u8]) -> Result<&str, StdBox> { + match buf { + // Remove the version number from the front of the ltxtquery per spec + [1u8, rest @ ..] => Ok(str::from_utf8(rest)?), + _ => Err("ltxtquery version 1 only supported".into()), } } diff --git a/postgres-protocol/src/types/test.rs b/postgres-protocol/src/types/test.rs new file mode 100644 index 000000000..3e33b08f0 --- /dev/null +++ b/postgres-protocol/src/types/test.rs @@ -0,0 +1,242 @@ +use bytes::{Buf, BytesMut}; +use fallible_iterator::FallibleIterator; +use std::collections::HashMap; + +use super::*; +use crate::IsNull; + +#[test] +#[allow(clippy::bool_assert_comparison)] +fn bool() { + let mut buf = BytesMut::new(); + bool_to_sql(true, &mut buf); + assert_eq!(bool_from_sql(&buf).unwrap(), true); + + let mut buf = BytesMut::new(); + bool_to_sql(false, &mut buf); + assert_eq!(bool_from_sql(&buf).unwrap(), false); +} + +#[test] +fn int2() { + let mut buf = BytesMut::new(); + int2_to_sql(0x0102, &mut buf); + assert_eq!(int2_from_sql(&buf).unwrap(), 0x0102); +} + +#[test] +fn int4() { + let mut buf = BytesMut::new(); + int4_to_sql(0x0102_0304, &mut buf); + assert_eq!(int4_from_sql(&buf).unwrap(), 0x0102_0304); +} + +#[test] +fn int8() { + let mut buf = BytesMut::new(); + int8_to_sql(0x0102_0304_0506_0708, &mut buf); + assert_eq!(int8_from_sql(&buf).unwrap(), 0x0102_0304_0506_0708); +} + +#[test] +#[allow(clippy::float_cmp)] +fn float4() { + let mut buf = BytesMut::new(); + float4_to_sql(10343.95, &mut buf); + assert_eq!(float4_from_sql(&buf).unwrap(), 10343.95); +} + +#[test] +#[allow(clippy::float_cmp)] +fn float8() { + let mut buf = BytesMut::new(); + float8_to_sql(10343.95, &mut buf); + assert_eq!(float8_from_sql(&buf).unwrap(), 10343.95); +} + +#[test] +fn hstore() { + let mut map = HashMap::new(); + map.insert("hello", Some("world")); + map.insert("hola", None); + + let mut buf = BytesMut::new(); + hstore_to_sql(map.iter().map(|(&k, &v)| (k, v)), &mut buf).unwrap(); + assert_eq!( + hstore_from_sql(&buf) + .unwrap() + .collect::>() + .unwrap(), + map + ); +} + +#[test] +fn varbit() { + let len = 12; + let bits = [0b0010_1011, 0b0000_1111]; + + let mut buf = BytesMut::new(); + varbit_to_sql(len, bits.iter().cloned(), &mut buf).unwrap(); + let out = varbit_from_sql(&buf).unwrap(); + assert_eq!(out.len(), len); + assert_eq!(out.bytes(), bits); +} + +#[test] +fn array() { + let dimensions = [ + ArrayDimension { + len: 1, + lower_bound: 10, + }, + ArrayDimension { + len: 2, + lower_bound: 0, + }, + ]; + let values = [None, Some(&b"hello"[..])]; + + let mut buf = BytesMut::new(); + array_to_sql( + dimensions.iter().cloned(), + 10, + values.iter().cloned(), + |v, buf| match v { + Some(v) => { + buf.extend_from_slice(v); + Ok(IsNull::No) + } + None => Ok(IsNull::Yes), + }, + &mut buf, + ) + .unwrap(); + + let array = array_from_sql(&buf).unwrap(); + assert!(array.has_nulls()); + assert_eq!(array.element_type(), 10); + assert_eq!(array.dimensions().collect::>().unwrap(), dimensions); + assert_eq!(array.values().collect::>().unwrap(), values); +} + +#[test] +fn non_null_array() { + let dimensions = [ + ArrayDimension { + len: 1, + lower_bound: 10, + }, + ArrayDimension { + len: 2, + lower_bound: 0, + }, + ]; + let values = [Some(&b"hola"[..]), Some(&b"hello"[..])]; + + let mut buf = BytesMut::new(); + array_to_sql( + dimensions.iter().cloned(), + 10, + values.iter().cloned(), + |v, buf| match v { + Some(v) => { + buf.extend_from_slice(v); + Ok(IsNull::No) + } + None => Ok(IsNull::Yes), + }, + &mut buf, + ) + .unwrap(); + + let array = array_from_sql(&buf).unwrap(); + assert!(!array.has_nulls()); + assert_eq!(array.element_type(), 10); + assert_eq!(array.dimensions().collect::>().unwrap(), dimensions); + assert_eq!(array.values().collect::>().unwrap(), values); +} + +#[test] +fn ltree_sql() { + let mut query = vec![1u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + let mut buf = BytesMut::new(); + + ltree_to_sql("A.B.C", &mut buf); + + assert_eq!(query.as_slice(), buf.chunk()); +} + +#[test] +fn ltree_str() { + let mut query = vec![1u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + assert!(ltree_from_sql(query.as_slice()).is_ok()) +} + +#[test] +fn ltree_wrong_version() { + let mut query = vec![2u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + assert!(ltree_from_sql(query.as_slice()).is_err()) +} + +#[test] +fn lquery_sql() { + let mut query = vec![1u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + let mut buf = BytesMut::new(); + + lquery_to_sql("A.B.C", &mut buf); + + assert_eq!(query.as_slice(), buf.chunk()); +} + +#[test] +fn lquery_str() { + let mut query = vec![1u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + assert!(lquery_from_sql(query.as_slice()).is_ok()) +} + +#[test] +fn lquery_wrong_version() { + let mut query = vec![2u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + assert!(lquery_from_sql(query.as_slice()).is_err()) +} + +#[test] +fn ltxtquery_sql() { + let mut query = vec![1u8]; + query.extend_from_slice("a & b*".as_bytes()); + + let mut buf = BytesMut::new(); + + ltree_to_sql("a & b*", &mut buf); + + assert_eq!(query.as_slice(), buf.chunk()); +} + +#[test] +fn ltxtquery_str() { + let mut query = vec![1u8]; + query.extend_from_slice("a & b*".as_bytes()); + + assert!(ltree_from_sql(query.as_slice()).is_ok()) +} + +#[test] +fn ltxtquery_wrong_version() { + let mut query = vec![2u8]; + query.extend_from_slice("a & b*".as_bytes()); + + assert!(ltree_from_sql(query.as_slice()).is_err()) +} diff --git a/postgres-shared/Cargo.toml b/postgres-shared/Cargo.toml deleted file mode 100644 index 853a9f46e..000000000 --- a/postgres-shared/Cargo.toml +++ /dev/null @@ -1,28 +0,0 @@ -[package] -name = "postgres-shared" -version = "0.4.1" -authors = ["Steven Fackler "] -license = "MIT" -description = "Internal crate used by postgres and postgres-tokio" -repository = "https://github.com/sfackler/rust-postgres" - -[features] -"with-bit-vec-0.5" = ["bit-vec"] -"with-chrono-0.4" = ["chrono"] -"with-eui48-0.3" = ["eui48"] -"with-geo-0.10" = ["geo"] -with-serde_json-1 = ["serde_json"] -"with-uuid-0.6" = ["uuid"] - -[dependencies] -hex = "0.3" -fallible-iterator = "0.1.3" -phf = "=0.7.22" -postgres-protocol = { version = "0.3", path = "../postgres-protocol" } - -bit-vec = { version = "0.5", optional = true } -chrono = { version = "0.4", optional = true } -eui48 = { version = "0.3", optional = true } -geo = { version = "0.10", optional = true } -serde_json = { version = "1.0", optional = true } -uuid = { version = "0.6", optional = true } diff --git a/postgres-shared/src/error/mod.rs b/postgres-shared/src/error/mod.rs deleted file mode 100644 index c437219cb..000000000 --- a/postgres-shared/src/error/mod.rs +++ /dev/null @@ -1,448 +0,0 @@ -//! Errors. - -use fallible_iterator::FallibleIterator; -use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody}; -use std::convert::From; -use std::error; -use std::fmt; -use std::io; - -pub use self::sqlstate::*; - -mod sqlstate; - -/// The severity of a Postgres error or notice. -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub enum Severity { - /// PANIC - Panic, - /// FATAL - Fatal, - /// ERROR - Error, - /// WARNING - Warning, - /// NOTICE - Notice, - /// DEBUG - Debug, - /// INFO - Info, - /// LOG - Log, -} - -impl fmt::Display for Severity { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - let s = match *self { - Severity::Panic => "PANIC", - Severity::Fatal => "FATAL", - Severity::Error => "ERROR", - Severity::Warning => "WARNING", - Severity::Notice => "NOTICE", - Severity::Debug => "DEBUG", - Severity::Info => "INFO", - Severity::Log => "LOG", - }; - fmt.write_str(s) - } -} - -impl Severity { - fn from_str(s: &str) -> Option { - match s { - "PANIC" => Some(Severity::Panic), - "FATAL" => Some(Severity::Fatal), - "ERROR" => Some(Severity::Error), - "WARNING" => Some(Severity::Warning), - "NOTICE" => Some(Severity::Notice), - "DEBUG" => Some(Severity::Debug), - "INFO" => Some(Severity::Info), - "LOG" => Some(Severity::Log), - _ => None, - } - } -} - -/// A Postgres error or notice. -#[derive(Clone, PartialEq, Eq)] -pub struct DbError { - /// The field contents are ERROR, FATAL, or PANIC (in an error message), - /// or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message), or a - /// localized translation of one of these. - pub severity: String, - - /// A parsed, nonlocalized version of `severity`. (PostgreSQL 9.6+) - pub parsed_severity: Option, - - /// The SQLSTATE code for the error. - pub code: SqlState, - - /// The primary human-readable error message. This should be accurate but - /// terse (typically one line). - pub message: String, - - /// An optional secondary error message carrying more detail about the - /// problem. Might run to multiple lines. - pub detail: Option, - - /// An optional suggestion what to do about the problem. This is intended - /// to differ from Detail in that it offers advice (potentially - /// inappropriate) rather than hard facts. Might run to multiple lines. - pub hint: Option, - - /// An optional error cursor position into either the original query string - /// or an internally generated query. - pub position: Option, - - /// An indication of the context in which the error occurred. Presently - /// this includes a call stack traceback of active procedural language - /// functions and internally-generated queries. The trace is one entry per - /// line, most recent first. - pub where_: Option, - - /// If the error was associated with a specific database object, the name - /// of the schema containing that object, if any. (PostgreSQL 9.3+) - pub schema: Option, - - /// If the error was associated with a specific table, the name of the - /// table. (Refer to the schema name field for the name of the table's - /// schema.) (PostgreSQL 9.3+) - pub table: Option, - - /// If the error was associated with a specific table column, the name of - /// the column. (Refer to the schema and table name fields to identify the - /// table.) (PostgreSQL 9.3+) - pub column: Option, - - /// If the error was associated with a specific data type, the name of the - /// data type. (Refer to the schema name field for the name of the data - /// type's schema.) (PostgreSQL 9.3+) - pub datatype: Option, - - /// If the error was associated with a specific constraint, the name of the - /// constraint. Refer to fields listed above for the associated table or - /// domain. (For this purpose, indexes are treated as constraints, even if - /// they weren't created with constraint syntax.) (PostgreSQL 9.3+) - pub constraint: Option, - - /// The file name of the source-code location where the error was reported. - pub file: Option, - - /// The line number of the source-code location where the error was - /// reported. - pub line: Option, - - /// The name of the source-code routine reporting the error. - pub routine: Option, - - _p: (), -} - -impl DbError { - #[doc(hidden)] - pub fn new(fields: &mut ErrorFields) -> io::Result { - let mut severity = None; - let mut parsed_severity = None; - let mut code = None; - let mut message = None; - let mut detail = None; - let mut hint = None; - let mut normal_position = None; - let mut internal_position = None; - let mut internal_query = None; - let mut where_ = None; - let mut schema = None; - let mut table = None; - let mut column = None; - let mut datatype = None; - let mut constraint = None; - let mut file = None; - let mut line = None; - let mut routine = None; - - while let Some(field) = fields.next()? { - match field.type_() { - b'S' => severity = Some(field.value().to_owned()), - b'C' => code = Some(SqlState::from_code(field.value())), - b'M' => message = Some(field.value().to_owned()), - b'D' => detail = Some(field.value().to_owned()), - b'H' => hint = Some(field.value().to_owned()), - b'P' => { - normal_position = Some(field.value().parse::().map_err(|_| { - io::Error::new( - io::ErrorKind::InvalidInput, - "`P` field did not contain an integer", - ) - })?); - } - b'p' => { - internal_position = Some(field.value().parse::().map_err(|_| { - io::Error::new( - io::ErrorKind::InvalidInput, - "`p` field did not contain an integer", - ) - })?); - } - b'q' => internal_query = Some(field.value().to_owned()), - b'W' => where_ = Some(field.value().to_owned()), - b's' => schema = Some(field.value().to_owned()), - b't' => table = Some(field.value().to_owned()), - b'c' => column = Some(field.value().to_owned()), - b'd' => datatype = Some(field.value().to_owned()), - b'n' => constraint = Some(field.value().to_owned()), - b'F' => file = Some(field.value().to_owned()), - b'L' => { - line = Some(field.value().parse::().map_err(|_| { - io::Error::new( - io::ErrorKind::InvalidInput, - "`L` field did not contain an integer", - ) - })?); - } - b'R' => routine = Some(field.value().to_owned()), - b'V' => { - parsed_severity = Some(Severity::from_str(field.value()).ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidInput, - "`V` field contained an invalid value", - ) - })?); - } - _ => {} - } - } - - Ok(DbError { - severity: severity - .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`S` field missing"))?, - parsed_severity: parsed_severity, - code: code - .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`C` field missing"))?, - message: message - .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`M` field missing"))?, - detail: detail, - hint: hint, - position: match normal_position { - Some(position) => Some(ErrorPosition::Normal(position)), - None => match internal_position { - Some(position) => Some(ErrorPosition::Internal { - position: position, - query: internal_query.ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidInput, - "`q` field missing but `p` field present", - ) - })?, - }), - None => None, - }, - }, - where_: where_, - schema: schema, - table: table, - column: column, - datatype: datatype, - constraint: constraint, - file: file, - line: line, - routine: routine, - _p: (), - }) - } -} - -// manual impl to leave out _p -impl fmt::Debug for DbError { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - fmt.debug_struct("DbError") - .field("severity", &self.severity) - .field("parsed_severity", &self.parsed_severity) - .field("code", &self.code) - .field("message", &self.message) - .field("detail", &self.detail) - .field("hint", &self.hint) - .field("position", &self.position) - .field("where_", &self.where_) - .field("schema", &self.schema) - .field("table", &self.table) - .field("column", &self.column) - .field("datatype", &self.datatype) - .field("constraint", &self.constraint) - .field("file", &self.file) - .field("line", &self.line) - .field("routine", &self.routine) - .finish() - } -} - -impl fmt::Display for DbError { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - write!(fmt, "{}: {}", self.severity, self.message) - } -} - -impl error::Error for DbError { - fn description(&self) -> &str { - &self.message - } -} - -/// Represents the position of an error in a query. -#[derive(Clone, PartialEq, Eq, Debug)] -pub enum ErrorPosition { - /// A position in the original query. - Normal(u32), - /// A position in an internally generated query. - Internal { - /// The byte position. - position: u32, - /// A query generated by the Postgres server. - query: String, - }, -} - -#[doc(hidden)] -pub fn connect(e: Box) -> Error { - Error(Box::new(ErrorKind::ConnectParams(e))) -} - -#[doc(hidden)] -pub fn tls(e: Box) -> Error { - Error(Box::new(ErrorKind::Tls(e))) -} - -#[doc(hidden)] -pub fn db(e: DbError) -> Error { - Error(Box::new(ErrorKind::Db(e))) -} - -#[doc(hidden)] -pub fn __db(e: ErrorResponseBody) -> Error { - match DbError::new(&mut e.fields()) { - Ok(e) => Error(Box::new(ErrorKind::Db(e))), - Err(e) => Error(Box::new(ErrorKind::Io(e))), - } -} - -#[doc(hidden)] -pub fn __user(e: T) -> Error -where - T: Into>, -{ - Error(Box::new(ErrorKind::Conversion(e.into()))) -} - -#[doc(hidden)] -pub fn io(e: io::Error) -> Error { - Error(Box::new(ErrorKind::Io(e))) -} - -#[doc(hidden)] -pub fn conversion(e: Box) -> Error { - Error(Box::new(ErrorKind::Conversion(e))) -} - -#[derive(Debug)] -enum ErrorKind { - ConnectParams(Box), - Tls(Box), - Db(DbError), - Io(io::Error), - Conversion(Box), -} - -/// An error communicating with the Postgres server. -#[derive(Debug)] -pub struct Error(Box); - -impl fmt::Display for Error { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - fmt.write_str(error::Error::description(self))?; - match *self.0 { - ErrorKind::ConnectParams(ref err) => write!(fmt, ": {}", err), - ErrorKind::Tls(ref err) => write!(fmt, ": {}", err), - ErrorKind::Db(ref err) => write!(fmt, ": {}", err), - ErrorKind::Io(ref err) => write!(fmt, ": {}", err), - ErrorKind::Conversion(ref err) => write!(fmt, ": {}", err), - } - } -} - -impl error::Error for Error { - fn description(&self) -> &str { - match *self.0 { - ErrorKind::ConnectParams(_) => "invalid connection parameters", - ErrorKind::Tls(_) => "TLS handshake error", - ErrorKind::Db(_) => "database error", - ErrorKind::Io(_) => "IO error", - ErrorKind::Conversion(_) => "type conversion error", - } - } - - fn cause(&self) -> Option<&error::Error> { - match *self.0 { - ErrorKind::ConnectParams(ref err) => Some(&**err), - ErrorKind::Tls(ref err) => Some(&**err), - ErrorKind::Db(ref err) => Some(err), - ErrorKind::Io(ref err) => Some(err), - ErrorKind::Conversion(ref err) => Some(&**err), - } - } -} - -impl Error { - /// Returns the SQLSTATE error code associated with this error if it is a DB - /// error. - pub fn code(&self) -> Option<&SqlState> { - self.as_db().map(|e| &e.code) - } - - /// Returns the inner error if this is a connection parameter error. - pub fn as_connection(&self) -> Option<&(error::Error + 'static + Sync + Send)> { - match *self.0 { - ErrorKind::ConnectParams(ref err) => Some(&**err), - _ => None, - } - } - - /// Returns the `DbError` associated with this error if it is a DB error. - pub fn as_db(&self) -> Option<&DbError> { - match *self.0 { - ErrorKind::Db(ref err) => Some(err), - _ => None, - } - } - - /// Returns the inner error if this is a conversion error. - pub fn as_conversion(&self) -> Option<&(error::Error + 'static + Sync + Send)> { - match *self.0 { - ErrorKind::Conversion(ref err) => Some(&**err), - _ => None, - } - } - - /// Returns the inner `io::Error` associated with this error if it is an IO - /// error. - pub fn as_io(&self) -> Option<&io::Error> { - match *self.0 { - ErrorKind::Io(ref err) => Some(err), - _ => None, - } - } -} - -impl From for Error { - fn from(err: io::Error) -> Error { - Error(Box::new(ErrorKind::Io(err))) - } -} - -impl From for io::Error { - fn from(err: Error) -> io::Error { - match *err.0 { - ErrorKind::Io(e) => e, - _ => io::Error::new(io::ErrorKind::Other, err), - } - } -} diff --git a/postgres-shared/src/error/sqlstate.rs b/postgres-shared/src/error/sqlstate.rs deleted file mode 100644 index c8e3ec2eb..000000000 --- a/postgres-shared/src/error/sqlstate.rs +++ /dev/null @@ -1,1061 +0,0 @@ -// Autogenerated file - DO NOT EDIT -use phf; -use std::borrow::Cow; - -/// A SQLSTATE error code -#[derive(PartialEq, Eq, Clone, Debug)] -pub struct SqlState(Cow<'static, str>); - -impl SqlState { - /// Creates a `SqlState` from its error code. - pub fn from_code(s: &str) -> SqlState { - match SQLSTATE_MAP.get(s) { - Some(state) => state.clone(), - None => SqlState(Cow::Owned(s.to_string())), - } - } - - /// Returns the error code corresponding to the `SqlState`. - pub fn code(&self) -> &str { - &self.0 - } - - /// 00000 - pub const SUCCESSFUL_COMPLETION: SqlState = SqlState(Cow::Borrowed("00000")); - - /// 01000 - pub const WARNING: SqlState = SqlState(Cow::Borrowed("01000")); - - /// 0100C - pub const WARNING_DYNAMIC_RESULT_SETS_RETURNED: SqlState = SqlState(Cow::Borrowed("0100C")); - - /// 01008 - pub const WARNING_IMPLICIT_ZERO_BIT_PADDING: SqlState = SqlState(Cow::Borrowed("01008")); - - /// 01003 - pub const WARNING_NULL_VALUE_ELIMINATED_IN_SET_FUNCTION: SqlState = SqlState(Cow::Borrowed("01003")); - - /// 01007 - pub const WARNING_PRIVILEGE_NOT_GRANTED: SqlState = SqlState(Cow::Borrowed("01007")); - - /// 01006 - pub const WARNING_PRIVILEGE_NOT_REVOKED: SqlState = SqlState(Cow::Borrowed("01006")); - - /// 01004 - pub const WARNING_STRING_DATA_RIGHT_TRUNCATION: SqlState = SqlState(Cow::Borrowed("01004")); - - /// 01P01 - pub const WARNING_DEPRECATED_FEATURE: SqlState = SqlState(Cow::Borrowed("01P01")); - - /// 02000 - pub const NO_DATA: SqlState = SqlState(Cow::Borrowed("02000")); - - /// 02001 - pub const NO_ADDITIONAL_DYNAMIC_RESULT_SETS_RETURNED: SqlState = SqlState(Cow::Borrowed("02001")); - - /// 03000 - pub const SQL_STATEMENT_NOT_YET_COMPLETE: SqlState = SqlState(Cow::Borrowed("03000")); - - /// 08000 - pub const CONNECTION_EXCEPTION: SqlState = SqlState(Cow::Borrowed("08000")); - - /// 08003 - pub const CONNECTION_DOES_NOT_EXIST: SqlState = SqlState(Cow::Borrowed("08003")); - - /// 08006 - pub const CONNECTION_FAILURE: SqlState = SqlState(Cow::Borrowed("08006")); - - /// 08001 - pub const SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION: SqlState = SqlState(Cow::Borrowed("08001")); - - /// 08004 - pub const SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION: SqlState = SqlState(Cow::Borrowed("08004")); - - /// 08007 - pub const TRANSACTION_RESOLUTION_UNKNOWN: SqlState = SqlState(Cow::Borrowed("08007")); - - /// 08P01 - pub const PROTOCOL_VIOLATION: SqlState = SqlState(Cow::Borrowed("08P01")); - - /// 09000 - pub const TRIGGERED_ACTION_EXCEPTION: SqlState = SqlState(Cow::Borrowed("09000")); - - /// 0A000 - pub const FEATURE_NOT_SUPPORTED: SqlState = SqlState(Cow::Borrowed("0A000")); - - /// 0B000 - pub const INVALID_TRANSACTION_INITIATION: SqlState = SqlState(Cow::Borrowed("0B000")); - - /// 0F000 - pub const LOCATOR_EXCEPTION: SqlState = SqlState(Cow::Borrowed("0F000")); - - /// 0F001 - pub const L_E_INVALID_SPECIFICATION: SqlState = SqlState(Cow::Borrowed("0F001")); - - /// 0L000 - pub const INVALID_GRANTOR: SqlState = SqlState(Cow::Borrowed("0L000")); - - /// 0LP01 - pub const INVALID_GRANT_OPERATION: SqlState = SqlState(Cow::Borrowed("0LP01")); - - /// 0P000 - pub const INVALID_ROLE_SPECIFICATION: SqlState = SqlState(Cow::Borrowed("0P000")); - - /// 0Z000 - pub const DIAGNOSTICS_EXCEPTION: SqlState = SqlState(Cow::Borrowed("0Z000")); - - /// 0Z002 - pub const STACKED_DIAGNOSTICS_ACCESSED_WITHOUT_ACTIVE_HANDLER: SqlState = SqlState(Cow::Borrowed("0Z002")); - - /// 20000 - pub const CASE_NOT_FOUND: SqlState = SqlState(Cow::Borrowed("20000")); - - /// 21000 - pub const CARDINALITY_VIOLATION: SqlState = SqlState(Cow::Borrowed("21000")); - - /// 22000 - pub const DATA_EXCEPTION: SqlState = SqlState(Cow::Borrowed("22000")); - - /// 2202E - pub const ARRAY_ELEMENT_ERROR: SqlState = SqlState(Cow::Borrowed("2202E")); - - /// 2202E - pub const ARRAY_SUBSCRIPT_ERROR: SqlState = SqlState(Cow::Borrowed("2202E")); - - /// 22021 - pub const CHARACTER_NOT_IN_REPERTOIRE: SqlState = SqlState(Cow::Borrowed("22021")); - - /// 22008 - pub const DATETIME_FIELD_OVERFLOW: SqlState = SqlState(Cow::Borrowed("22008")); - - /// 22008 - pub const DATETIME_VALUE_OUT_OF_RANGE: SqlState = SqlState(Cow::Borrowed("22008")); - - /// 22012 - pub const DIVISION_BY_ZERO: SqlState = SqlState(Cow::Borrowed("22012")); - - /// 22005 - pub const ERROR_IN_ASSIGNMENT: SqlState = SqlState(Cow::Borrowed("22005")); - - /// 2200B - pub const ESCAPE_CHARACTER_CONFLICT: SqlState = SqlState(Cow::Borrowed("2200B")); - - /// 22022 - pub const INDICATOR_OVERFLOW: SqlState = SqlState(Cow::Borrowed("22022")); - - /// 22015 - pub const INTERVAL_FIELD_OVERFLOW: SqlState = SqlState(Cow::Borrowed("22015")); - - /// 2201E - pub const INVALID_ARGUMENT_FOR_LOG: SqlState = SqlState(Cow::Borrowed("2201E")); - - /// 22014 - pub const INVALID_ARGUMENT_FOR_NTILE: SqlState = SqlState(Cow::Borrowed("22014")); - - /// 22016 - pub const INVALID_ARGUMENT_FOR_NTH_VALUE: SqlState = SqlState(Cow::Borrowed("22016")); - - /// 2201F - pub const INVALID_ARGUMENT_FOR_POWER_FUNCTION: SqlState = SqlState(Cow::Borrowed("2201F")); - - /// 2201G - pub const INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION: SqlState = SqlState(Cow::Borrowed("2201G")); - - /// 22018 - pub const INVALID_CHARACTER_VALUE_FOR_CAST: SqlState = SqlState(Cow::Borrowed("22018")); - - /// 22007 - pub const INVALID_DATETIME_FORMAT: SqlState = SqlState(Cow::Borrowed("22007")); - - /// 22019 - pub const INVALID_ESCAPE_CHARACTER: SqlState = SqlState(Cow::Borrowed("22019")); - - /// 2200D - pub const INVALID_ESCAPE_OCTET: SqlState = SqlState(Cow::Borrowed("2200D")); - - /// 22025 - pub const INVALID_ESCAPE_SEQUENCE: SqlState = SqlState(Cow::Borrowed("22025")); - - /// 22P06 - pub const NONSTANDARD_USE_OF_ESCAPE_CHARACTER: SqlState = SqlState(Cow::Borrowed("22P06")); - - /// 22010 - pub const INVALID_INDICATOR_PARAMETER_VALUE: SqlState = SqlState(Cow::Borrowed("22010")); - - /// 22023 - pub const INVALID_PARAMETER_VALUE: SqlState = SqlState(Cow::Borrowed("22023")); - - /// 2201B - pub const INVALID_REGULAR_EXPRESSION: SqlState = SqlState(Cow::Borrowed("2201B")); - - /// 2201W - pub const INVALID_ROW_COUNT_IN_LIMIT_CLAUSE: SqlState = SqlState(Cow::Borrowed("2201W")); - - /// 2201X - pub const INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE: SqlState = SqlState(Cow::Borrowed("2201X")); - - /// 2202H - pub const INVALID_TABLESAMPLE_ARGUMENT: SqlState = SqlState(Cow::Borrowed("2202H")); - - /// 2202G - pub const INVALID_TABLESAMPLE_REPEAT: SqlState = SqlState(Cow::Borrowed("2202G")); - - /// 22009 - pub const INVALID_TIME_ZONE_DISPLACEMENT_VALUE: SqlState = SqlState(Cow::Borrowed("22009")); - - /// 2200C - pub const INVALID_USE_OF_ESCAPE_CHARACTER: SqlState = SqlState(Cow::Borrowed("2200C")); - - /// 2200G - pub const MOST_SPECIFIC_TYPE_MISMATCH: SqlState = SqlState(Cow::Borrowed("2200G")); - - /// 22004 - pub const NULL_VALUE_NOT_ALLOWED: SqlState = SqlState(Cow::Borrowed("22004")); - - /// 22002 - pub const NULL_VALUE_NO_INDICATOR_PARAMETER: SqlState = SqlState(Cow::Borrowed("22002")); - - /// 22003 - pub const NUMERIC_VALUE_OUT_OF_RANGE: SqlState = SqlState(Cow::Borrowed("22003")); - - /// 2200H - pub const SEQUENCE_GENERATOR_LIMIT_EXCEEDED: SqlState = SqlState(Cow::Borrowed("2200H")); - - /// 22026 - pub const STRING_DATA_LENGTH_MISMATCH: SqlState = SqlState(Cow::Borrowed("22026")); - - /// 22001 - pub const STRING_DATA_RIGHT_TRUNCATION: SqlState = SqlState(Cow::Borrowed("22001")); - - /// 22011 - pub const SUBSTRING_ERROR: SqlState = SqlState(Cow::Borrowed("22011")); - - /// 22027 - pub const TRIM_ERROR: SqlState = SqlState(Cow::Borrowed("22027")); - - /// 22024 - pub const UNTERMINATED_C_STRING: SqlState = SqlState(Cow::Borrowed("22024")); - - /// 2200F - pub const ZERO_LENGTH_CHARACTER_STRING: SqlState = SqlState(Cow::Borrowed("2200F")); - - /// 22P01 - pub const FLOATING_POINT_EXCEPTION: SqlState = SqlState(Cow::Borrowed("22P01")); - - /// 22P02 - pub const INVALID_TEXT_REPRESENTATION: SqlState = SqlState(Cow::Borrowed("22P02")); - - /// 22P03 - pub const INVALID_BINARY_REPRESENTATION: SqlState = SqlState(Cow::Borrowed("22P03")); - - /// 22P04 - pub const BAD_COPY_FILE_FORMAT: SqlState = SqlState(Cow::Borrowed("22P04")); - - /// 22P05 - pub const UNTRANSLATABLE_CHARACTER: SqlState = SqlState(Cow::Borrowed("22P05")); - - /// 2200L - pub const NOT_AN_XML_DOCUMENT: SqlState = SqlState(Cow::Borrowed("2200L")); - - /// 2200M - pub const INVALID_XML_DOCUMENT: SqlState = SqlState(Cow::Borrowed("2200M")); - - /// 2200N - pub const INVALID_XML_CONTENT: SqlState = SqlState(Cow::Borrowed("2200N")); - - /// 2200S - pub const INVALID_XML_COMMENT: SqlState = SqlState(Cow::Borrowed("2200S")); - - /// 2200T - pub const INVALID_XML_PROCESSING_INSTRUCTION: SqlState = SqlState(Cow::Borrowed("2200T")); - - /// 23000 - pub const INTEGRITY_CONSTRAINT_VIOLATION: SqlState = SqlState(Cow::Borrowed("23000")); - - /// 23001 - pub const RESTRICT_VIOLATION: SqlState = SqlState(Cow::Borrowed("23001")); - - /// 23502 - pub const NOT_NULL_VIOLATION: SqlState = SqlState(Cow::Borrowed("23502")); - - /// 23503 - pub const FOREIGN_KEY_VIOLATION: SqlState = SqlState(Cow::Borrowed("23503")); - - /// 23505 - pub const UNIQUE_VIOLATION: SqlState = SqlState(Cow::Borrowed("23505")); - - /// 23514 - pub const CHECK_VIOLATION: SqlState = SqlState(Cow::Borrowed("23514")); - - /// 23P01 - pub const EXCLUSION_VIOLATION: SqlState = SqlState(Cow::Borrowed("23P01")); - - /// 24000 - pub const INVALID_CURSOR_STATE: SqlState = SqlState(Cow::Borrowed("24000")); - - /// 25000 - pub const INVALID_TRANSACTION_STATE: SqlState = SqlState(Cow::Borrowed("25000")); - - /// 25001 - pub const ACTIVE_SQL_TRANSACTION: SqlState = SqlState(Cow::Borrowed("25001")); - - /// 25002 - pub const BRANCH_TRANSACTION_ALREADY_ACTIVE: SqlState = SqlState(Cow::Borrowed("25002")); - - /// 25008 - pub const HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL: SqlState = SqlState(Cow::Borrowed("25008")); - - /// 25003 - pub const INAPPROPRIATE_ACCESS_MODE_FOR_BRANCH_TRANSACTION: SqlState = SqlState(Cow::Borrowed("25003")); - - /// 25004 - pub const INAPPROPRIATE_ISOLATION_LEVEL_FOR_BRANCH_TRANSACTION: SqlState = SqlState(Cow::Borrowed("25004")); - - /// 25005 - pub const NO_ACTIVE_SQL_TRANSACTION_FOR_BRANCH_TRANSACTION: SqlState = SqlState(Cow::Borrowed("25005")); - - /// 25006 - pub const READ_ONLY_SQL_TRANSACTION: SqlState = SqlState(Cow::Borrowed("25006")); - - /// 25007 - pub const SCHEMA_AND_DATA_STATEMENT_MIXING_NOT_SUPPORTED: SqlState = SqlState(Cow::Borrowed("25007")); - - /// 25P01 - pub const NO_ACTIVE_SQL_TRANSACTION: SqlState = SqlState(Cow::Borrowed("25P01")); - - /// 25P02 - pub const IN_FAILED_SQL_TRANSACTION: SqlState = SqlState(Cow::Borrowed("25P02")); - - /// 25P03 - pub const IDLE_IN_TRANSACTION_SESSION_TIMEOUT: SqlState = SqlState(Cow::Borrowed("25P03")); - - /// 26000 - pub const INVALID_SQL_STATEMENT_NAME: SqlState = SqlState(Cow::Borrowed("26000")); - - /// 26000 - pub const UNDEFINED_PSTATEMENT: SqlState = SqlState(Cow::Borrowed("26000")); - - /// 27000 - pub const TRIGGERED_DATA_CHANGE_VIOLATION: SqlState = SqlState(Cow::Borrowed("27000")); - - /// 28000 - pub const INVALID_AUTHORIZATION_SPECIFICATION: SqlState = SqlState(Cow::Borrowed("28000")); - - /// 28P01 - pub const INVALID_PASSWORD: SqlState = SqlState(Cow::Borrowed("28P01")); - - /// 2B000 - pub const DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST: SqlState = SqlState(Cow::Borrowed("2B000")); - - /// 2BP01 - pub const DEPENDENT_OBJECTS_STILL_EXIST: SqlState = SqlState(Cow::Borrowed("2BP01")); - - /// 2D000 - pub const INVALID_TRANSACTION_TERMINATION: SqlState = SqlState(Cow::Borrowed("2D000")); - - /// 2F000 - pub const SQL_ROUTINE_EXCEPTION: SqlState = SqlState(Cow::Borrowed("2F000")); - - /// 2F005 - pub const S_R_E_FUNCTION_EXECUTED_NO_RETURN_STATEMENT: SqlState = SqlState(Cow::Borrowed("2F005")); - - /// 2F002 - pub const S_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Cow::Borrowed("2F002")); - - /// 2F003 - pub const S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED: SqlState = SqlState(Cow::Borrowed("2F003")); - - /// 2F004 - pub const S_R_E_READING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Cow::Borrowed("2F004")); - - /// 34000 - pub const INVALID_CURSOR_NAME: SqlState = SqlState(Cow::Borrowed("34000")); - - /// 34000 - pub const UNDEFINED_CURSOR: SqlState = SqlState(Cow::Borrowed("34000")); - - /// 38000 - pub const EXTERNAL_ROUTINE_EXCEPTION: SqlState = SqlState(Cow::Borrowed("38000")); - - /// 38001 - pub const E_R_E_CONTAINING_SQL_NOT_PERMITTED: SqlState = SqlState(Cow::Borrowed("38001")); - - /// 38002 - pub const E_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Cow::Borrowed("38002")); - - /// 38003 - pub const E_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED: SqlState = SqlState(Cow::Borrowed("38003")); - - /// 38004 - pub const E_R_E_READING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Cow::Borrowed("38004")); - - /// 39000 - pub const EXTERNAL_ROUTINE_INVOCATION_EXCEPTION: SqlState = SqlState(Cow::Borrowed("39000")); - - /// 39001 - pub const E_R_I_E_INVALID_SQLSTATE_RETURNED: SqlState = SqlState(Cow::Borrowed("39001")); - - /// 39004 - pub const E_R_I_E_NULL_VALUE_NOT_ALLOWED: SqlState = SqlState(Cow::Borrowed("39004")); - - /// 39P01 - pub const E_R_I_E_TRIGGER_PROTOCOL_VIOLATED: SqlState = SqlState(Cow::Borrowed("39P01")); - - /// 39P02 - pub const E_R_I_E_SRF_PROTOCOL_VIOLATED: SqlState = SqlState(Cow::Borrowed("39P02")); - - /// 39P03 - pub const E_R_I_E_EVENT_TRIGGER_PROTOCOL_VIOLATED: SqlState = SqlState(Cow::Borrowed("39P03")); - - /// 3B000 - pub const SAVEPOINT_EXCEPTION: SqlState = SqlState(Cow::Borrowed("3B000")); - - /// 3B001 - pub const S_E_INVALID_SPECIFICATION: SqlState = SqlState(Cow::Borrowed("3B001")); - - /// 3D000 - pub const INVALID_CATALOG_NAME: SqlState = SqlState(Cow::Borrowed("3D000")); - - /// 3D000 - pub const UNDEFINED_DATABASE: SqlState = SqlState(Cow::Borrowed("3D000")); - - /// 3F000 - pub const INVALID_SCHEMA_NAME: SqlState = SqlState(Cow::Borrowed("3F000")); - - /// 3F000 - pub const UNDEFINED_SCHEMA: SqlState = SqlState(Cow::Borrowed("3F000")); - - /// 40000 - pub const TRANSACTION_ROLLBACK: SqlState = SqlState(Cow::Borrowed("40000")); - - /// 40002 - pub const T_R_INTEGRITY_CONSTRAINT_VIOLATION: SqlState = SqlState(Cow::Borrowed("40002")); - - /// 40001 - pub const T_R_SERIALIZATION_FAILURE: SqlState = SqlState(Cow::Borrowed("40001")); - - /// 40003 - pub const T_R_STATEMENT_COMPLETION_UNKNOWN: SqlState = SqlState(Cow::Borrowed("40003")); - - /// 40P01 - pub const T_R_DEADLOCK_DETECTED: SqlState = SqlState(Cow::Borrowed("40P01")); - - /// 42000 - pub const SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION: SqlState = SqlState(Cow::Borrowed("42000")); - - /// 42601 - pub const SYNTAX_ERROR: SqlState = SqlState(Cow::Borrowed("42601")); - - /// 42501 - pub const INSUFFICIENT_PRIVILEGE: SqlState = SqlState(Cow::Borrowed("42501")); - - /// 42846 - pub const CANNOT_COERCE: SqlState = SqlState(Cow::Borrowed("42846")); - - /// 42803 - pub const GROUPING_ERROR: SqlState = SqlState(Cow::Borrowed("42803")); - - /// 42P20 - pub const WINDOWING_ERROR: SqlState = SqlState(Cow::Borrowed("42P20")); - - /// 42P19 - pub const INVALID_RECURSION: SqlState = SqlState(Cow::Borrowed("42P19")); - - /// 42830 - pub const INVALID_FOREIGN_KEY: SqlState = SqlState(Cow::Borrowed("42830")); - - /// 42602 - pub const INVALID_NAME: SqlState = SqlState(Cow::Borrowed("42602")); - - /// 42622 - pub const NAME_TOO_LONG: SqlState = SqlState(Cow::Borrowed("42622")); - - /// 42939 - pub const RESERVED_NAME: SqlState = SqlState(Cow::Borrowed("42939")); - - /// 42804 - pub const DATATYPE_MISMATCH: SqlState = SqlState(Cow::Borrowed("42804")); - - /// 42P18 - pub const INDETERMINATE_DATATYPE: SqlState = SqlState(Cow::Borrowed("42P18")); - - /// 42P21 - pub const COLLATION_MISMATCH: SqlState = SqlState(Cow::Borrowed("42P21")); - - /// 42P22 - pub const INDETERMINATE_COLLATION: SqlState = SqlState(Cow::Borrowed("42P22")); - - /// 42809 - pub const WRONG_OBJECT_TYPE: SqlState = SqlState(Cow::Borrowed("42809")); - - /// 428C9 - pub const GENERATED_ALWAYS: SqlState = SqlState(Cow::Borrowed("428C9")); - - /// 42703 - pub const UNDEFINED_COLUMN: SqlState = SqlState(Cow::Borrowed("42703")); - - /// 42883 - pub const UNDEFINED_FUNCTION: SqlState = SqlState(Cow::Borrowed("42883")); - - /// 42P01 - pub const UNDEFINED_TABLE: SqlState = SqlState(Cow::Borrowed("42P01")); - - /// 42P02 - pub const UNDEFINED_PARAMETER: SqlState = SqlState(Cow::Borrowed("42P02")); - - /// 42704 - pub const UNDEFINED_OBJECT: SqlState = SqlState(Cow::Borrowed("42704")); - - /// 42701 - pub const DUPLICATE_COLUMN: SqlState = SqlState(Cow::Borrowed("42701")); - - /// 42P03 - pub const DUPLICATE_CURSOR: SqlState = SqlState(Cow::Borrowed("42P03")); - - /// 42P04 - pub const DUPLICATE_DATABASE: SqlState = SqlState(Cow::Borrowed("42P04")); - - /// 42723 - pub const DUPLICATE_FUNCTION: SqlState = SqlState(Cow::Borrowed("42723")); - - /// 42P05 - pub const DUPLICATE_PSTATEMENT: SqlState = SqlState(Cow::Borrowed("42P05")); - - /// 42P06 - pub const DUPLICATE_SCHEMA: SqlState = SqlState(Cow::Borrowed("42P06")); - - /// 42P07 - pub const DUPLICATE_TABLE: SqlState = SqlState(Cow::Borrowed("42P07")); - - /// 42712 - pub const DUPLICATE_ALIAS: SqlState = SqlState(Cow::Borrowed("42712")); - - /// 42710 - pub const DUPLICATE_OBJECT: SqlState = SqlState(Cow::Borrowed("42710")); - - /// 42702 - pub const AMBIGUOUS_COLUMN: SqlState = SqlState(Cow::Borrowed("42702")); - - /// 42725 - pub const AMBIGUOUS_FUNCTION: SqlState = SqlState(Cow::Borrowed("42725")); - - /// 42P08 - pub const AMBIGUOUS_PARAMETER: SqlState = SqlState(Cow::Borrowed("42P08")); - - /// 42P09 - pub const AMBIGUOUS_ALIAS: SqlState = SqlState(Cow::Borrowed("42P09")); - - /// 42P10 - pub const INVALID_COLUMN_REFERENCE: SqlState = SqlState(Cow::Borrowed("42P10")); - - /// 42611 - pub const INVALID_COLUMN_DEFINITION: SqlState = SqlState(Cow::Borrowed("42611")); - - /// 42P11 - pub const INVALID_CURSOR_DEFINITION: SqlState = SqlState(Cow::Borrowed("42P11")); - - /// 42P12 - pub const INVALID_DATABASE_DEFINITION: SqlState = SqlState(Cow::Borrowed("42P12")); - - /// 42P13 - pub const INVALID_FUNCTION_DEFINITION: SqlState = SqlState(Cow::Borrowed("42P13")); - - /// 42P14 - pub const INVALID_PSTATEMENT_DEFINITION: SqlState = SqlState(Cow::Borrowed("42P14")); - - /// 42P15 - pub const INVALID_SCHEMA_DEFINITION: SqlState = SqlState(Cow::Borrowed("42P15")); - - /// 42P16 - pub const INVALID_TABLE_DEFINITION: SqlState = SqlState(Cow::Borrowed("42P16")); - - /// 42P17 - pub const INVALID_OBJECT_DEFINITION: SqlState = SqlState(Cow::Borrowed("42P17")); - - /// 44000 - pub const WITH_CHECK_OPTION_VIOLATION: SqlState = SqlState(Cow::Borrowed("44000")); - - /// 53000 - pub const INSUFFICIENT_RESOURCES: SqlState = SqlState(Cow::Borrowed("53000")); - - /// 53100 - pub const DISK_FULL: SqlState = SqlState(Cow::Borrowed("53100")); - - /// 53200 - pub const OUT_OF_MEMORY: SqlState = SqlState(Cow::Borrowed("53200")); - - /// 53300 - pub const TOO_MANY_CONNECTIONS: SqlState = SqlState(Cow::Borrowed("53300")); - - /// 53400 - pub const CONFIGURATION_LIMIT_EXCEEDED: SqlState = SqlState(Cow::Borrowed("53400")); - - /// 54000 - pub const PROGRAM_LIMIT_EXCEEDED: SqlState = SqlState(Cow::Borrowed("54000")); - - /// 54001 - pub const STATEMENT_TOO_COMPLEX: SqlState = SqlState(Cow::Borrowed("54001")); - - /// 54011 - pub const TOO_MANY_COLUMNS: SqlState = SqlState(Cow::Borrowed("54011")); - - /// 54023 - pub const TOO_MANY_ARGUMENTS: SqlState = SqlState(Cow::Borrowed("54023")); - - /// 55000 - pub const OBJECT_NOT_IN_PREREQUISITE_STATE: SqlState = SqlState(Cow::Borrowed("55000")); - - /// 55006 - pub const OBJECT_IN_USE: SqlState = SqlState(Cow::Borrowed("55006")); - - /// 55P02 - pub const CANT_CHANGE_RUNTIME_PARAM: SqlState = SqlState(Cow::Borrowed("55P02")); - - /// 55P03 - pub const LOCK_NOT_AVAILABLE: SqlState = SqlState(Cow::Borrowed("55P03")); - - /// 55P04 - pub const UNSAFE_NEW_ENUM_VALUE_USAGE: SqlState = SqlState(Cow::Borrowed("55P04")); - - /// 57000 - pub const OPERATOR_INTERVENTION: SqlState = SqlState(Cow::Borrowed("57000")); - - /// 57014 - pub const QUERY_CANCELED: SqlState = SqlState(Cow::Borrowed("57014")); - - /// 57P01 - pub const ADMIN_SHUTDOWN: SqlState = SqlState(Cow::Borrowed("57P01")); - - /// 57P02 - pub const CRASH_SHUTDOWN: SqlState = SqlState(Cow::Borrowed("57P02")); - - /// 57P03 - pub const CANNOT_CONNECT_NOW: SqlState = SqlState(Cow::Borrowed("57P03")); - - /// 57P04 - pub const DATABASE_DROPPED: SqlState = SqlState(Cow::Borrowed("57P04")); - - /// 58000 - pub const SYSTEM_ERROR: SqlState = SqlState(Cow::Borrowed("58000")); - - /// 58030 - pub const IO_ERROR: SqlState = SqlState(Cow::Borrowed("58030")); - - /// 58P01 - pub const UNDEFINED_FILE: SqlState = SqlState(Cow::Borrowed("58P01")); - - /// 58P02 - pub const DUPLICATE_FILE: SqlState = SqlState(Cow::Borrowed("58P02")); - - /// 72000 - pub const SNAPSHOT_TOO_OLD: SqlState = SqlState(Cow::Borrowed("72000")); - - /// F0000 - pub const CONFIG_FILE_ERROR: SqlState = SqlState(Cow::Borrowed("F0000")); - - /// F0001 - pub const LOCK_FILE_EXISTS: SqlState = SqlState(Cow::Borrowed("F0001")); - - /// HV000 - pub const FDW_ERROR: SqlState = SqlState(Cow::Borrowed("HV000")); - - /// HV005 - pub const FDW_COLUMN_NAME_NOT_FOUND: SqlState = SqlState(Cow::Borrowed("HV005")); - - /// HV002 - pub const FDW_DYNAMIC_PARAMETER_VALUE_NEEDED: SqlState = SqlState(Cow::Borrowed("HV002")); - - /// HV010 - pub const FDW_FUNCTION_SEQUENCE_ERROR: SqlState = SqlState(Cow::Borrowed("HV010")); - - /// HV021 - pub const FDW_INCONSISTENT_DESCRIPTOR_INFORMATION: SqlState = SqlState(Cow::Borrowed("HV021")); - - /// HV024 - pub const FDW_INVALID_ATTRIBUTE_VALUE: SqlState = SqlState(Cow::Borrowed("HV024")); - - /// HV007 - pub const FDW_INVALID_COLUMN_NAME: SqlState = SqlState(Cow::Borrowed("HV007")); - - /// HV008 - pub const FDW_INVALID_COLUMN_NUMBER: SqlState = SqlState(Cow::Borrowed("HV008")); - - /// HV004 - pub const FDW_INVALID_DATA_TYPE: SqlState = SqlState(Cow::Borrowed("HV004")); - - /// HV006 - pub const FDW_INVALID_DATA_TYPE_DESCRIPTORS: SqlState = SqlState(Cow::Borrowed("HV006")); - - /// HV091 - pub const FDW_INVALID_DESCRIPTOR_FIELD_IDENTIFIER: SqlState = SqlState(Cow::Borrowed("HV091")); - - /// HV00B - pub const FDW_INVALID_HANDLE: SqlState = SqlState(Cow::Borrowed("HV00B")); - - /// HV00C - pub const FDW_INVALID_OPTION_INDEX: SqlState = SqlState(Cow::Borrowed("HV00C")); - - /// HV00D - pub const FDW_INVALID_OPTION_NAME: SqlState = SqlState(Cow::Borrowed("HV00D")); - - /// HV090 - pub const FDW_INVALID_STRING_LENGTH_OR_BUFFER_LENGTH: SqlState = SqlState(Cow::Borrowed("HV090")); - - /// HV00A - pub const FDW_INVALID_STRING_FORMAT: SqlState = SqlState(Cow::Borrowed("HV00A")); - - /// HV009 - pub const FDW_INVALID_USE_OF_NULL_POINTER: SqlState = SqlState(Cow::Borrowed("HV009")); - - /// HV014 - pub const FDW_TOO_MANY_HANDLES: SqlState = SqlState(Cow::Borrowed("HV014")); - - /// HV001 - pub const FDW_OUT_OF_MEMORY: SqlState = SqlState(Cow::Borrowed("HV001")); - - /// HV00P - pub const FDW_NO_SCHEMAS: SqlState = SqlState(Cow::Borrowed("HV00P")); - - /// HV00J - pub const FDW_OPTION_NAME_NOT_FOUND: SqlState = SqlState(Cow::Borrowed("HV00J")); - - /// HV00K - pub const FDW_REPLY_HANDLE: SqlState = SqlState(Cow::Borrowed("HV00K")); - - /// HV00Q - pub const FDW_SCHEMA_NOT_FOUND: SqlState = SqlState(Cow::Borrowed("HV00Q")); - - /// HV00R - pub const FDW_TABLE_NOT_FOUND: SqlState = SqlState(Cow::Borrowed("HV00R")); - - /// HV00L - pub const FDW_UNABLE_TO_CREATE_EXECUTION: SqlState = SqlState(Cow::Borrowed("HV00L")); - - /// HV00M - pub const FDW_UNABLE_TO_CREATE_REPLY: SqlState = SqlState(Cow::Borrowed("HV00M")); - - /// HV00N - pub const FDW_UNABLE_TO_ESTABLISH_CONNECTION: SqlState = SqlState(Cow::Borrowed("HV00N")); - - /// P0000 - pub const PLPGSQL_ERROR: SqlState = SqlState(Cow::Borrowed("P0000")); - - /// P0001 - pub const RAISE_EXCEPTION: SqlState = SqlState(Cow::Borrowed("P0001")); - - /// P0002 - pub const NO_DATA_FOUND: SqlState = SqlState(Cow::Borrowed("P0002")); - - /// P0003 - pub const TOO_MANY_ROWS: SqlState = SqlState(Cow::Borrowed("P0003")); - - /// P0004 - pub const ASSERT_FAILURE: SqlState = SqlState(Cow::Borrowed("P0004")); - - /// XX000 - pub const INTERNAL_ERROR: SqlState = SqlState(Cow::Borrowed("XX000")); - - /// XX001 - pub const DATA_CORRUPTED: SqlState = SqlState(Cow::Borrowed("XX001")); - - /// XX002 - pub const INDEX_CORRUPTED: SqlState = SqlState(Cow::Borrowed("XX002")); -} -#[cfg_attr(rustfmt, rustfmt_skip)] -static SQLSTATE_MAP: phf::Map<&'static str, SqlState> = ::phf::Map { - key: 1897749892740154578, - disps: ::phf::Slice::Static(&[ - (1, 99), - (0, 0), - (1, 5), - (0, 3), - (0, 110), - (0, 54), - (0, 3), - (0, 13), - (0, 0), - (0, 24), - (0, 214), - (0, 52), - (1, 34), - (0, 33), - (0, 44), - (0, 130), - (0, 16), - (0, 187), - (0, 3), - (13, 168), - (0, 4), - (0, 19), - (0, 13), - (0, 87), - (0, 0), - (0, 108), - (0, 123), - (7, 181), - (0, 109), - (0, 32), - (0, 0), - (1, 69), - (1, 81), - (1, 219), - (0, 157), - (2, 41), - (8, 141), - (0, 5), - (0, 0), - (1, 6), - (0, 3), - (1, 146), - (1, 227), - (9, 94), - (10, 158), - (29, 65), - (3, 2), - (0, 33), - (1, 94), - ]), - entries: ::phf::Slice::Static(&[ - ("23001", SqlState::RESTRICT_VIOLATION), - ("42830", SqlState::INVALID_FOREIGN_KEY), - ("P0000", SqlState::PLPGSQL_ERROR), - ("58000", SqlState::SYSTEM_ERROR), - ("57P01", SqlState::ADMIN_SHUTDOWN), - ("22P04", SqlState::BAD_COPY_FILE_FORMAT), - ("42P05", SqlState::DUPLICATE_PSTATEMENT), - ("28000", SqlState::INVALID_AUTHORIZATION_SPECIFICATION), - ("2202E", SqlState::ARRAY_ELEMENT_ERROR), - ("2F005", SqlState::S_R_E_FUNCTION_EXECUTED_NO_RETURN_STATEMENT), - ("53400", SqlState::CONFIGURATION_LIMIT_EXCEEDED), - ("20000", SqlState::CASE_NOT_FOUND), - ("25004", SqlState::INAPPROPRIATE_ISOLATION_LEVEL_FOR_BRANCH_TRANSACTION), - ("09000", SqlState::TRIGGERED_ACTION_EXCEPTION), - ("42P10", SqlState::INVALID_COLUMN_REFERENCE), - ("39P03", SqlState::E_R_I_E_EVENT_TRIGGER_PROTOCOL_VIOLATED), - ("08000", SqlState::CONNECTION_EXCEPTION), - ("08006", SqlState::CONNECTION_FAILURE), - ("2201W", SqlState::INVALID_ROW_COUNT_IN_LIMIT_CLAUSE), - ("03000", SqlState::SQL_STATEMENT_NOT_YET_COMPLETE), - ("22014", SqlState::INVALID_ARGUMENT_FOR_NTILE), - ("42611", SqlState::INVALID_COLUMN_DEFINITION), - ("42P11", SqlState::INVALID_CURSOR_DEFINITION), - ("2200N", SqlState::INVALID_XML_CONTENT), - ("57014", SqlState::QUERY_CANCELED), - ("01003", SqlState::WARNING_NULL_VALUE_ELIMINATED_IN_SET_FUNCTION), - ("01000", SqlState::WARNING), - ("55P04", SqlState::UNSAFE_NEW_ENUM_VALUE_USAGE), - ("25003", SqlState::INAPPROPRIATE_ACCESS_MODE_FOR_BRANCH_TRANSACTION), - ("2200L", SqlState::NOT_AN_XML_DOCUMENT), - ("42846", SqlState::CANNOT_COERCE), - ("55P03", SqlState::LOCK_NOT_AVAILABLE), - ("08007", SqlState::TRANSACTION_RESOLUTION_UNKNOWN), - ("XX000", SqlState::INTERNAL_ERROR), - ("22005", SqlState::ERROR_IN_ASSIGNMENT), - ("22P03", SqlState::INVALID_BINARY_REPRESENTATION), - ("2201X", SqlState::INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE), - ("54011", SqlState::TOO_MANY_COLUMNS), - ("HV008", SqlState::FDW_INVALID_COLUMN_NUMBER), - ("HV009", SqlState::FDW_INVALID_USE_OF_NULL_POINTER), - ("0LP01", SqlState::INVALID_GRANT_OPERATION), - ("42704", SqlState::UNDEFINED_OBJECT), - ("25005", SqlState::NO_ACTIVE_SQL_TRANSACTION_FOR_BRANCH_TRANSACTION), - ("25P03", SqlState::IDLE_IN_TRANSACTION_SESSION_TIMEOUT), - ("44000", SqlState::WITH_CHECK_OPTION_VIOLATION), - ("22024", SqlState::UNTERMINATED_C_STRING), - ("0L000", SqlState::INVALID_GRANTOR), - ("40000", SqlState::TRANSACTION_ROLLBACK), - ("42P08", SqlState::AMBIGUOUS_PARAMETER), - ("38000", SqlState::EXTERNAL_ROUTINE_EXCEPTION), - ("42939", SqlState::RESERVED_NAME), - ("40001", SqlState::T_R_SERIALIZATION_FAILURE), - ("HV00K", SqlState::FDW_REPLY_HANDLE), - ("2F002", SqlState::S_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED), - ("HV001", SqlState::FDW_OUT_OF_MEMORY), - ("42P19", SqlState::INVALID_RECURSION), - ("HV002", SqlState::FDW_DYNAMIC_PARAMETER_VALUE_NEEDED), - ("0A000", SqlState::FEATURE_NOT_SUPPORTED), - ("58P02", SqlState::DUPLICATE_FILE), - ("25006", SqlState::READ_ONLY_SQL_TRANSACTION), - ("22009", SqlState::INVALID_TIME_ZONE_DISPLACEMENT_VALUE), - ("0F001", SqlState::L_E_INVALID_SPECIFICATION), - ("P0002", SqlState::NO_DATA_FOUND), - ("2F000", SqlState::SQL_ROUTINE_EXCEPTION), - ("01006", SqlState::WARNING_PRIVILEGE_NOT_REVOKED), - ("22025", SqlState::INVALID_ESCAPE_SEQUENCE), - ("22027", SqlState::TRIM_ERROR), - ("54001", SqlState::STATEMENT_TOO_COMPLEX), - ("42602", SqlState::INVALID_NAME), - ("54023", SqlState::TOO_MANY_ARGUMENTS), - ("2200T", SqlState::INVALID_XML_PROCESSING_INSTRUCTION), - ("01007", SqlState::WARNING_PRIVILEGE_NOT_GRANTED), - ("22000", SqlState::DATA_EXCEPTION), - ("28P01", SqlState::INVALID_PASSWORD), - ("23514", SqlState::CHECK_VIOLATION), - ("39P02", SqlState::E_R_I_E_SRF_PROTOCOL_VIOLATED), - ("57P02", SqlState::CRASH_SHUTDOWN), - ("42P03", SqlState::DUPLICATE_CURSOR), - ("22021", SqlState::CHARACTER_NOT_IN_REPERTOIRE), - ("HV00P", SqlState::FDW_NO_SCHEMAS), - ("42701", SqlState::DUPLICATE_COLUMN), - ("42P15", SqlState::INVALID_SCHEMA_DEFINITION), - ("HV00B", SqlState::FDW_INVALID_HANDLE), - ("34000", SqlState::INVALID_CURSOR_NAME), - ("22P06", SqlState::NONSTANDARD_USE_OF_ESCAPE_CHARACTER), - ("P0001", SqlState::RAISE_EXCEPTION), - ("08P01", SqlState::PROTOCOL_VIOLATION), - ("42723", SqlState::DUPLICATE_FUNCTION), - ("08001", SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION), - ("HV006", SqlState::FDW_INVALID_DATA_TYPE_DESCRIPTORS), - ("23000", SqlState::INTEGRITY_CONSTRAINT_VIOLATION), - ("42712", SqlState::DUPLICATE_ALIAS), - ("2201G", SqlState::INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION), - ("2200F", SqlState::ZERO_LENGTH_CHARACTER_STRING), - ("XX002", SqlState::INDEX_CORRUPTED), - ("53300", SqlState::TOO_MANY_CONNECTIONS), - ("38002", SqlState::E_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED), - ("22015", SqlState::INTERVAL_FIELD_OVERFLOW), - ("22P01", SqlState::FLOATING_POINT_EXCEPTION), - ("22012", SqlState::DIVISION_BY_ZERO), - ("XX001", SqlState::DATA_CORRUPTED), - ("0100C", SqlState::WARNING_DYNAMIC_RESULT_SETS_RETURNED), - ("42P01", SqlState::UNDEFINED_TABLE), - ("25002", SqlState::BRANCH_TRANSACTION_ALREADY_ACTIVE), - ("2D000", SqlState::INVALID_TRANSACTION_TERMINATION), - ("P0004", SqlState::ASSERT_FAILURE), - ("2200C", SqlState::INVALID_USE_OF_ESCAPE_CHARACTER), - ("HV00R", SqlState::FDW_TABLE_NOT_FOUND), - ("22016", SqlState::INVALID_ARGUMENT_FOR_NTH_VALUE), - ("01P01", SqlState::WARNING_DEPRECATED_FEATURE), - ("F0000", SqlState::CONFIG_FILE_ERROR), - ("0Z000", SqlState::DIAGNOSTICS_EXCEPTION), - ("42P02", SqlState::UNDEFINED_PARAMETER), - ("2200S", SqlState::INVALID_XML_COMMENT), - ("2200H", SqlState::SEQUENCE_GENERATOR_LIMIT_EXCEEDED), - ("HV00C", SqlState::FDW_INVALID_OPTION_INDEX), - ("38004", SqlState::E_R_E_READING_SQL_DATA_NOT_PERMITTED), - ("42703", SqlState::UNDEFINED_COLUMN), - ("23503", SqlState::FOREIGN_KEY_VIOLATION), - ("42000", SqlState::SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION), - ("22004", SqlState::NULL_VALUE_NOT_ALLOWED), - ("25008", SqlState::HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL), - ("22018", SqlState::INVALID_CHARACTER_VALUE_FOR_CAST), - ("22023", SqlState::INVALID_PARAMETER_VALUE), - ("22011", SqlState::SUBSTRING_ERROR), - ("40002", SqlState::T_R_INTEGRITY_CONSTRAINT_VIOLATION), - ("42803", SqlState::GROUPING_ERROR), - ("72000", SqlState::SNAPSHOT_TOO_OLD), - ("HV010", SqlState::FDW_FUNCTION_SEQUENCE_ERROR), - ("42809", SqlState::WRONG_OBJECT_TYPE), - ("42P16", SqlState::INVALID_TABLE_DEFINITION), - ("HV00D", SqlState::FDW_INVALID_OPTION_NAME), - ("39000", SqlState::EXTERNAL_ROUTINE_INVOCATION_EXCEPTION), - ("2202G", SqlState::INVALID_TABLESAMPLE_REPEAT), - ("42601", SqlState::SYNTAX_ERROR), - ("42622", SqlState::NAME_TOO_LONG), - ("HV00L", SqlState::FDW_UNABLE_TO_CREATE_EXECUTION), - ("25000", SqlState::INVALID_TRANSACTION_STATE), - ("3B000", SqlState::SAVEPOINT_EXCEPTION), - ("42P21", SqlState::COLLATION_MISMATCH), - ("23505", SqlState::UNIQUE_VIOLATION), - ("22001", SqlState::STRING_DATA_RIGHT_TRUNCATION), - ("02001", SqlState::NO_ADDITIONAL_DYNAMIC_RESULT_SETS_RETURNED), - ("21000", SqlState::CARDINALITY_VIOLATION), - ("58P01", SqlState::UNDEFINED_FILE), - ("HV091", SqlState::FDW_INVALID_DESCRIPTOR_FIELD_IDENTIFIER), - ("25P01", SqlState::NO_ACTIVE_SQL_TRANSACTION), - ("40P01", SqlState::T_R_DEADLOCK_DETECTED), - ("HV021", SqlState::FDW_INCONSISTENT_DESCRIPTOR_INFORMATION), - ("42P09", SqlState::AMBIGUOUS_ALIAS), - ("25007", SqlState::SCHEMA_AND_DATA_STATEMENT_MIXING_NOT_SUPPORTED), - ("23P01", SqlState::EXCLUSION_VIOLATION), - ("HV00J", SqlState::FDW_OPTION_NAME_NOT_FOUND), - ("58030", SqlState::IO_ERROR), - ("HV004", SqlState::FDW_INVALID_DATA_TYPE), - ("42710", SqlState::DUPLICATE_OBJECT), - ("HV090", SqlState::FDW_INVALID_STRING_LENGTH_OR_BUFFER_LENGTH), - ("42P18", SqlState::INDETERMINATE_DATATYPE), - ("HV00M", SqlState::FDW_UNABLE_TO_CREATE_REPLY), - ("42804", SqlState::DATATYPE_MISMATCH), - ("24000", SqlState::INVALID_CURSOR_STATE), - ("HV007", SqlState::FDW_INVALID_COLUMN_NAME), - ("2201E", SqlState::INVALID_ARGUMENT_FOR_LOG), - ("42P22", SqlState::INDETERMINATE_COLLATION), - ("22P05", SqlState::UNTRANSLATABLE_CHARACTER), - ("42P07", SqlState::DUPLICATE_TABLE), - ("2F004", SqlState::S_R_E_READING_SQL_DATA_NOT_PERMITTED), - ("23502", SqlState::NOT_NULL_VIOLATION), - ("57000", SqlState::OPERATOR_INTERVENTION), - ("HV000", SqlState::FDW_ERROR), - ("42883", SqlState::UNDEFINED_FUNCTION), - ("2201B", SqlState::INVALID_REGULAR_EXPRESSION), - ("2200D", SqlState::INVALID_ESCAPE_OCTET), - ("42P06", SqlState::DUPLICATE_SCHEMA), - ("38003", SqlState::E_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED), - ("22026", SqlState::STRING_DATA_LENGTH_MISMATCH), - ("P0003", SqlState::TOO_MANY_ROWS), - ("3D000", SqlState::INVALID_CATALOG_NAME), - ("0B000", SqlState::INVALID_TRANSACTION_INITIATION), - ("55006", SqlState::OBJECT_IN_USE), - ("53200", SqlState::OUT_OF_MEMORY), - ("3F000", SqlState::INVALID_SCHEMA_NAME), - ("53100", SqlState::DISK_FULL), - ("2F003", SqlState::S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED), - ("55P02", SqlState::CANT_CHANGE_RUNTIME_PARAM), - ("01004", SqlState::WARNING_STRING_DATA_RIGHT_TRUNCATION), - ("3B001", SqlState::S_E_INVALID_SPECIFICATION), - ("2200G", SqlState::MOST_SPECIFIC_TYPE_MISMATCH), - ("428C9", SqlState::GENERATED_ALWAYS), - ("HV005", SqlState::FDW_COLUMN_NAME_NOT_FOUND), - ("2201F", SqlState::INVALID_ARGUMENT_FOR_POWER_FUNCTION), - ("22022", SqlState::INDICATOR_OVERFLOW), - ("HV00Q", SqlState::FDW_SCHEMA_NOT_FOUND), - ("0F000", SqlState::LOCATOR_EXCEPTION), - ("22002", SqlState::NULL_VALUE_NO_INDICATOR_PARAMETER), - ("02000", SqlState::NO_DATA), - ("2202H", SqlState::INVALID_TABLESAMPLE_ARGUMENT), - ("27000", SqlState::TRIGGERED_DATA_CHANGE_VIOLATION), - ("2BP01", SqlState::DEPENDENT_OBJECTS_STILL_EXIST), - ("55000", SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE), - ("39001", SqlState::E_R_I_E_INVALID_SQLSTATE_RETURNED), - ("08004", SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION), - ("42P13", SqlState::INVALID_FUNCTION_DEFINITION), - ("HV024", SqlState::FDW_INVALID_ATTRIBUTE_VALUE), - ("22019", SqlState::INVALID_ESCAPE_CHARACTER), - ("54000", SqlState::PROGRAM_LIMIT_EXCEEDED), - ("42501", SqlState::INSUFFICIENT_PRIVILEGE), - ("HV00A", SqlState::FDW_INVALID_STRING_FORMAT), - ("42702", SqlState::AMBIGUOUS_COLUMN), - ("53000", SqlState::INSUFFICIENT_RESOURCES), - ("25P02", SqlState::IN_FAILED_SQL_TRANSACTION), - ("22010", SqlState::INVALID_INDICATOR_PARAMETER_VALUE), - ("01008", SqlState::WARNING_IMPLICIT_ZERO_BIT_PADDING), - ("HV014", SqlState::FDW_TOO_MANY_HANDLES), - ("42P20", SqlState::WINDOWING_ERROR), - ("42725", SqlState::AMBIGUOUS_FUNCTION), - ("F0001", SqlState::LOCK_FILE_EXISTS), - ("08003", SqlState::CONNECTION_DOES_NOT_EXIST), - ("2200M", SqlState::INVALID_XML_DOCUMENT), - ("22003", SqlState::NUMERIC_VALUE_OUT_OF_RANGE), - ("39004", SqlState::E_R_I_E_NULL_VALUE_NOT_ALLOWED), - ("2200B", SqlState::ESCAPE_CHARACTER_CONFLICT), - ("0P000", SqlState::INVALID_ROLE_SPECIFICATION), - ("00000", SqlState::SUCCESSFUL_COMPLETION), - ("22P02", SqlState::INVALID_TEXT_REPRESENTATION), - ("25001", SqlState::ACTIVE_SQL_TRANSACTION), - ("HV00N", SqlState::FDW_UNABLE_TO_ESTABLISH_CONNECTION), - ("39P01", SqlState::E_R_I_E_TRIGGER_PROTOCOL_VIOLATED), - ("2B000", SqlState::DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST), - ("22008", SqlState::DATETIME_FIELD_OVERFLOW), - ("42P14", SqlState::INVALID_PSTATEMENT_DEFINITION), - ("57P04", SqlState::DATABASE_DROPPED), - ("26000", SqlState::INVALID_SQL_STATEMENT_NAME), - ("42P17", SqlState::INVALID_OBJECT_DEFINITION), - ("42P04", SqlState::DUPLICATE_DATABASE), - ("38001", SqlState::E_R_E_CONTAINING_SQL_NOT_PERMITTED), - ("0Z002", SqlState::STACKED_DIAGNOSTICS_ACCESSED_WITHOUT_ACTIVE_HANDLER), - ("22007", SqlState::INVALID_DATETIME_FORMAT), - ("40003", SqlState::T_R_STATEMENT_COMPLETION_UNKNOWN), - ("42P12", SqlState::INVALID_DATABASE_DEFINITION), - ("57P03", SqlState::CANNOT_CONNECT_NOW), - ]), -}; diff --git a/postgres-shared/src/lib.rs b/postgres-shared/src/lib.rs deleted file mode 100644 index 521779b34..000000000 --- a/postgres-shared/src/lib.rs +++ /dev/null @@ -1,32 +0,0 @@ -#![allow(unknown_lints)] // for clippy - -extern crate hex; -extern crate fallible_iterator; -extern crate phf; -extern crate postgres_protocol; - -pub mod error; -pub mod params; -pub mod types; -pub mod rows; -pub mod stmt; - -/// Contains information necessary to cancel queries for a session. -#[derive(Copy, Clone, Debug)] -pub struct CancelData { - /// The process ID of the session. - pub process_id: i32, - /// The secret key for the session. - pub secret_key: i32, -} - -/// An asynchronous notification. -#[derive(Clone, Debug)] -pub struct Notification { - /// The process ID of the notifying backend process. - pub process_id: i32, - /// The name of the channel that the notify has been raised on. - pub channel: String, - /// The "payload" string passed from the notifying process. - pub payload: String, -} diff --git a/postgres-shared/src/params/mod.rs b/postgres-shared/src/params/mod.rs deleted file mode 100644 index 296483f9d..000000000 --- a/postgres-shared/src/params/mod.rs +++ /dev/null @@ -1,295 +0,0 @@ -//! Connection parameters -use std::error::Error; -use std::mem; -use std::path::PathBuf; -use std::str::FromStr; -use std::time::Duration; - -use error; -use params::url::Url; - -mod url; - -/// The host. -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub enum Host { - /// A TCP hostname. - Tcp(String), - /// The path to a directory containing the server's Unix socket. - Unix(PathBuf), -} - -/// Authentication information. -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub struct User { - name: String, - password: Option, -} - -impl User { - /// The username. - pub fn name(&self) -> &str { - &self.name - } - - /// An optional password. - pub fn password(&self) -> Option<&str> { - self.password.as_ref().map(|p| &**p) - } -} - -/// Information necessary to open a new connection to a Postgres server. -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub struct ConnectParams { - host: Host, - port: u16, - user: Option, - database: Option, - options: Vec<(String, String)>, - connect_timeout: Option, - keepalive: Option, -} - -impl ConnectParams { - /// Returns a new builder. - pub fn builder() -> Builder { - Builder::new() - } - - /// The target host. - pub fn host(&self) -> &Host { - &self.host - } - - /// The target port. - /// - /// Defaults to 5432. - pub fn port(&self) -> u16 { - self.port - } - - /// The user to log in as. - /// - /// A user is required to open a new connection but not to cancel a query. - pub fn user(&self) -> Option<&User> { - self.user.as_ref() - } - - /// The database to connect to. - pub fn database(&self) -> Option<&str> { - self.database.as_ref().map(|d| &**d) - } - - /// Runtime parameters to be passed to the Postgres backend. - pub fn options(&self) -> &[(String, String)] { - &self.options - } - - /// A timeout to apply to each socket-level connection attempt. - pub fn connect_timeout(&self) -> Option { - self.connect_timeout - } - - /// The interval at which TCP keepalive messages are sent on the socket. - /// - /// This is ignored for Unix sockets. - pub fn keepalive(&self) -> Option { - self.keepalive - } -} - -impl FromStr for ConnectParams { - type Err = error::Error; - - fn from_str(s: &str) -> Result { - s.into_connect_params().map_err(error::connect) - } -} - -/// A builder for `ConnectParams`. -pub struct Builder { - port: u16, - user: Option, - database: Option, - options: Vec<(String, String)>, - connect_timeout: Option, - keepalive: Option, -} - -impl Builder { - /// Creates a new builder. - pub fn new() -> Builder { - Builder { - port: 5432, - user: None, - database: None, - options: vec![], - connect_timeout: None, - keepalive: None, - } - } - - /// Sets the port. - pub fn port(&mut self, port: u16) -> &mut Builder { - self.port = port; - self - } - - /// Sets the user. - pub fn user(&mut self, name: &str, password: Option<&str>) -> &mut Builder { - self.user = Some(User { - name: name.to_string(), - password: password.map(ToString::to_string), - }); - self - } - - /// Sets the database. - pub fn database(&mut self, database: &str) -> &mut Builder { - self.database = Some(database.to_string()); - self - } - - /// Adds a runtime parameter. - pub fn option(&mut self, name: &str, value: &str) -> &mut Builder { - self.options.push((name.to_string(), value.to_string())); - self - } - - /// Sets the connection timeout. - pub fn connect_timeout(&mut self, connect_timeout: Option) -> &mut Builder { - self.connect_timeout = connect_timeout; - self - } - - /// Sets the keepalive interval. - pub fn keepalive(&mut self, keepalive: Option) -> &mut Builder { - self.keepalive = keepalive; - self - } - - /// Constructs a `ConnectParams` from the builder. - pub fn build(&mut self, host: Host) -> ConnectParams { - ConnectParams { - host: host, - port: self.port, - user: self.user.take(), - database: self.database.take(), - options: mem::replace(&mut self.options, vec![]), - connect_timeout: self.connect_timeout, - keepalive: self.keepalive, - } - } -} - -/// A trait implemented by types that can be converted into a `ConnectParams`. -pub trait IntoConnectParams { - /// Converts the value of `self` into a `ConnectParams`. - fn into_connect_params(self) -> Result>; -} - -impl IntoConnectParams for ConnectParams { - fn into_connect_params(self) -> Result> { - Ok(self) - } -} - -impl<'a> IntoConnectParams for &'a str { - fn into_connect_params(self) -> Result> { - match Url::parse(self) { - Ok(url) => url.into_connect_params(), - Err(err) => Err(err.into()), - } - } -} - -impl IntoConnectParams for String { - fn into_connect_params(self) -> Result> { - self.as_str().into_connect_params() - } -} - -impl IntoConnectParams for Url { - fn into_connect_params(self) -> Result> { - let Url { - host, - port, - user, - path: - url::Path { - path, - query: options, - .. - }, - .. - } = self; - - let mut builder = ConnectParams::builder(); - - if let Some(port) = port { - builder.port(port); - } - - if let Some(info) = user { - builder.user(&info.user, info.pass.as_ref().map(|p| &**p)); - } - - if !path.is_empty() { - // path contains the leading / - builder.database(&path[1..]); - } - - for (name, value) in options { - match &*name { - "connect_timeout" => { - let timeout = value.parse().map_err(|_| "invalid connect_timeout")?; - let timeout = Duration::from_secs(timeout); - builder.connect_timeout(Some(timeout)); - } - "keepalive" => { - let keepalive = value.parse().map_err(|_| "invalid keepalive")?; - let keepalive = Duration::from_secs(keepalive); - builder.keepalive(Some(keepalive)); - } - _ => { - builder.option(&name, &value); - } - } - } - - let maybe_path = url::decode_component(&host)?; - let host = if maybe_path.starts_with('/') { - Host::Unix(maybe_path.into()) - } else { - Host::Tcp(maybe_path) - }; - - Ok(builder.build(host)) - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn parse_url() { - let params = "postgres://user@host:44/dbname?connect_timeout=10&application_name=foo"; - let params = params.into_connect_params().unwrap(); - assert_eq!( - params.user(), - Some(&User { - name: "user".to_string(), - password: None, - }) - ); - assert_eq!(params.host(), &Host::Tcp("host".to_string())); - assert_eq!(params.port(), 44); - assert_eq!(params.database(), Some("dbname")); - assert_eq!( - params.options(), - &[("application_name".to_string(), "foo".to_string())][..] - ); - assert_eq!(params.connect_timeout(), Some(Duration::from_secs(10))); - } -} diff --git a/postgres-shared/src/params/url.rs b/postgres-shared/src/params/url.rs deleted file mode 100644 index 549beebfb..000000000 --- a/postgres-shared/src/params/url.rs +++ /dev/null @@ -1,438 +0,0 @@ -// Copyright 2012-2014 The Rust Project Developers. See the COPYRIGHT -// file at the top-level directory of this distribution and at -// http://rust-lang.org/COPYRIGHT. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. -use std::str::FromStr; -use hex::FromHex; - -pub struct Url { - pub scheme: String, - pub user: Option, - pub host: String, - pub port: Option, - pub path: Path, -} - -pub struct Path { - pub path: String, - pub query: Query, - pub fragment: Option, -} - -pub struct UserInfo { - pub user: String, - pub pass: Option, -} - -pub type Query = Vec<(String, String)>; - -impl Url { - pub fn new( - scheme: String, - user: Option, - host: String, - port: Option, - path: String, - query: Query, - fragment: Option, - ) -> Url { - Url { - scheme: scheme, - user: user, - host: host, - port: port, - path: Path::new(path, query, fragment), - } - } - - pub fn parse(rawurl: &str) -> DecodeResult { - // scheme - let (scheme, rest) = get_scheme(rawurl)?; - - // authority - let (userinfo, host, port, rest) = get_authority(rest)?; - - // path - let has_authority = !host.is_empty(); - let (path, rest) = get_path(rest, has_authority)?; - - // query and fragment - let (query, fragment) = get_query_fragment(rest)?; - - let url = Url::new( - scheme.to_owned(), - userinfo, - host.to_owned(), - port, - path, - query, - fragment, - ); - Ok(url) - } -} - -impl Path { - pub fn new(path: String, query: Query, fragment: Option) -> Path { - Path { - path: path, - query: query, - fragment: fragment, - } - } - - pub fn parse(rawpath: &str) -> DecodeResult { - let (path, rest) = get_path(rawpath, false)?; - - // query and fragment - let (query, fragment) = get_query_fragment(&rest)?; - - Ok(Path { - path: path, - query: query, - fragment: fragment, - }) - } -} - -impl UserInfo { - #[inline] - pub fn new(user: String, pass: Option) -> UserInfo { - UserInfo { - user: user, - pass: pass, - } - } -} - -pub type DecodeResult = Result; - -pub fn decode_component(container: &str) -> DecodeResult { - decode_inner(container, false) -} - -fn decode_inner(c: &str, full_url: bool) -> DecodeResult { - let mut out = String::new(); - let mut iter = c.as_bytes().iter().cloned(); - - loop { - match iter.next() { - Some(b) => { - match b as char { - '%' => { - let bytes = match (iter.next(), iter.next()) { - (Some(one), Some(two)) => [one, two], - _ => { - return Err( - "Malformed input: found '%' without two \ - trailing bytes" - .to_owned(), - ) - } - }; - - let bytes_from_hex = match Vec::::from_hex(&bytes) { - Ok(b) => b, - _ => { - return Err( - "Malformed input: found '%' followed by \ - invalid hex values. Character '%' must \ - escaped." - .to_owned(), - ) - } - }; - - // Only decode some characters if full_url: - match bytes_from_hex[0] as char { - // gen-delims: - ':' | '/' | '?' | '#' | '[' | ']' | '@' | '!' | '$' | '&' | '"' | - '(' | ')' | '*' | '+' | ',' | ';' | '=' if full_url => { - out.push('%'); - out.push(bytes[0] as char); - out.push(bytes[1] as char); - } - - ch => out.push(ch), - } - } - ch => out.push(ch), - } - } - None => return Ok(out), - } - } -} - -fn split_char_first(s: &str, c: char) -> (&str, &str) { - let mut iter = s.splitn(2, c); - - match (iter.next(), iter.next()) { - (Some(a), Some(b)) => (a, b), - (Some(a), None) => (a, ""), - (None, _) => unreachable!(), - } -} - -fn query_from_str(rawquery: &str) -> DecodeResult { - let mut query: Query = vec![]; - if !rawquery.is_empty() { - for p in rawquery.split('&') { - let (k, v) = split_char_first(p, '='); - query.push((decode_component(k)?, decode_component(v)?)); - } - } - - Ok(query) -} - -pub fn get_scheme(rawurl: &str) -> DecodeResult<(&str, &str)> { - for (i, c) in rawurl.chars().enumerate() { - let result = match c { - 'A'...'Z' | 'a'...'z' => continue, - '0'...'9' | '+' | '-' | '.' => { - if i != 0 { - continue; - } - - Err("url: Scheme must begin with a letter.".to_owned()) - } - ':' => { - if i == 0 { - Err("url: Scheme cannot be empty.".to_owned()) - } else { - Ok((&rawurl[0..i], &rawurl[i + 1..rawurl.len()])) - } - } - _ => Err("url: Invalid character in scheme.".to_owned()), - }; - - return result; - } - - Err("url: Scheme must be terminated with a colon.".to_owned()) -} - -// returns userinfo, host, port, and unparsed part, or an error -fn get_authority(rawurl: &str) -> DecodeResult<(Option, &str, Option, &str)> { - enum State { - Start, // starting state - PassHostPort, // could be in user or port - Ip6Port, // either in ipv6 host or port - Ip6Host, // are in an ipv6 host - InHost, // are in a host - may be ipv6, but don't know yet - InPort, // are in port - } - - #[derive(Clone, PartialEq)] - enum Input { - Digit, // all digits - Hex, // digits and letters a-f - Unreserved, // all other legal characters - } - - if !rawurl.starts_with("//") { - // there is no authority. - return Ok((None, "", None, rawurl)); - } - - let len = rawurl.len(); - let mut st = State::Start; - let mut input = Input::Digit; // most restricted, start here. - - let mut userinfo = None; - let mut host = ""; - let mut port = None; - - let mut colon_count = 0usize; - let mut pos = 0; - let mut begin = 2; - let mut end = len; - - for (i, c) in rawurl.chars().enumerate().skip(2) { - // deal with input class first - match c { - '0'...'9' => (), - 'A'...'F' | 'a'...'f' => { - if input == Input::Digit { - input = Input::Hex; - } - } - 'G'...'Z' | 'g'...'z' | '-' | '.' | '_' | '~' | '%' | '&' | '\'' | '(' | ')' | - '+' | '!' | '*' | ',' | ';' | '=' => input = Input::Unreserved, - ':' | '@' | '?' | '#' | '/' => { - // separators, don't change anything - } - _ => return Err("Illegal character in authority".to_owned()), - } - - // now process states - match c { - ':' => { - colon_count += 1; - match st { - State::Start => { - pos = i; - st = State::PassHostPort; - } - State::PassHostPort => { - // multiple colons means ipv6 address. - if input == Input::Unreserved { - return Err("Illegal characters in IPv6 address.".to_owned()); - } - st = State::Ip6Host; - } - State::InHost => { - pos = i; - if input == Input::Unreserved { - // must be port - host = &rawurl[begin..i]; - st = State::InPort; - } else { - // can't be sure whether this is an ipv6 address or a port - st = State::Ip6Port; - } - } - State::Ip6Port => { - if input == Input::Unreserved { - return Err("Illegal characters in authority.".to_owned()); - } - st = State::Ip6Host; - } - State::Ip6Host => { - if colon_count > 7 { - host = &rawurl[begin..i]; - pos = i; - st = State::InPort; - } - } - _ => return Err("Invalid ':' in authority.".to_owned()), - } - input = Input::Digit; // reset input class - } - - '@' => { - input = Input::Digit; // reset input class - colon_count = 0; // reset count - match st { - State::Start => { - let user = decode_component(&rawurl[begin..i])?; - userinfo = Some(UserInfo::new(user, None)); - st = State::InHost; - } - State::PassHostPort => { - let user = decode_component(&rawurl[begin..pos])?; - let pass = decode_component(&rawurl[pos + 1..i])?; - userinfo = Some(UserInfo::new(user, Some(pass))); - st = State::InHost; - } - _ => return Err("Invalid '@' in authority.".to_owned()), - } - begin = i + 1; - } - - '?' | '#' | '/' => { - end = i; - break; - } - _ => (), - } - } - - // finish up - match st { - State::PassHostPort | State::Ip6Port => { - if input != Input::Digit { - return Err("Non-digit characters in port.".to_owned()); - } - host = &rawurl[begin..pos]; - port = Some(&rawurl[pos + 1..end]); - } - State::Ip6Host | State::InHost | State::Start => host = &rawurl[begin..end], - State::InPort => { - if input != Input::Digit { - return Err("Non-digit characters in port.".to_owned()); - } - port = Some(&rawurl[pos + 1..end]); - } - } - - let rest = &rawurl[end..len]; - // If we have a port string, ensure it parses to u16. - let port = match port { - None => None, - opt => { - match opt.and_then(|p| FromStr::from_str(p).ok()) { - None => return Err(format!("Failed to parse port: {:?}", port)), - opt => opt, - } - } - }; - - Ok((userinfo, host, port, rest)) -} - - -// returns the path and unparsed part of url, or an error -fn get_path(rawurl: &str, is_authority: bool) -> DecodeResult<(String, &str)> { - let len = rawurl.len(); - let mut end = len; - for (i, c) in rawurl.chars().enumerate() { - match c { - 'A'...'Z' | 'a'...'z' | '0'...'9' | '&' | '\'' | '(' | ')' | '.' | '@' | ':' | - '%' | '/' | '+' | '!' | '*' | ',' | ';' | '=' | '_' | '-' | '~' => continue, - '?' | '#' => { - end = i; - break; - } - _ => return Err("Invalid character in path.".to_owned()), - } - } - - if is_authority && end != 0 && !rawurl.starts_with('/') { - Err( - "Non-empty path must begin with '/' in presence of authority.".to_owned(), - ) - } else { - Ok((decode_component(&rawurl[0..end])?, &rawurl[end..len])) - } -} - -// returns the parsed query and the fragment, if present -fn get_query_fragment(rawurl: &str) -> DecodeResult<(Query, Option)> { - let (before_fragment, raw_fragment) = split_char_first(rawurl, '#'); - - // Parse the fragment if available - let fragment = match raw_fragment { - "" => None, - raw => Some(decode_component(raw)?), - }; - - match before_fragment.chars().next() { - Some('?') => Ok((query_from_str(&before_fragment[1..])?, fragment)), - None => Ok((vec![], fragment)), - _ => Err(format!( - "Query didn't start with '?': '{}..'", - before_fragment - )), - } -} - -impl FromStr for Url { - type Err = String; - fn from_str(s: &str) -> Result { - Url::parse(s) - } -} - -impl FromStr for Path { - type Err = String; - fn from_str(s: &str) -> Result { - Path::parse(s) - } -} diff --git a/postgres-shared/src/rows.rs b/postgres-shared/src/rows.rs deleted file mode 100644 index 48de53601..000000000 --- a/postgres-shared/src/rows.rs +++ /dev/null @@ -1,93 +0,0 @@ -use fallible_iterator::FallibleIterator; -use postgres_protocol::message::backend::DataRowBody; -use std::io; -use std::ops::Range; - -use rows::sealed::Sealed; -use stmt::Column; - -mod sealed { - use stmt::Column; - - pub trait Sealed { - fn __idx(&self, stmt: &[Column]) -> Option; - } -} - -/// A trait implemented by types that can index into columns of a row. -/// -/// This cannot be implemented outside of this crate. -pub trait RowIndex: Sealed {} - -impl Sealed for usize { - #[inline] - fn __idx(&self, stmt: &[Column]) -> Option { - if *self >= stmt.len() { - None - } else { - Some(*self) - } - } -} - -impl RowIndex for usize {} - -impl Sealed for str { - #[inline] - fn __idx(&self, stmt: &[Column]) -> Option { - if let Some(idx) = stmt.iter().position(|d| d.name() == self) { - return Some(idx); - }; - - // FIXME ASCII-only case insensitivity isn't really the right thing to - // do. Postgres itself uses a dubious wrapper around tolower and JDBC - // uses the US locale. - stmt.iter() - .position(|d| d.name().eq_ignore_ascii_case(self)) - } -} - -impl RowIndex for str {} - -impl<'a, T> Sealed for &'a T -where - T: ?Sized + Sealed, -{ - #[inline] - fn __idx(&self, columns: &[Column]) -> Option { - T::__idx(*self, columns) - } -} - -impl<'a, T> RowIndex for &'a T -where - T: ?Sized + Sealed, -{ -} - -#[doc(hidden)] -pub struct RowData { - body: DataRowBody, - ranges: Vec>>, -} - -impl RowData { - pub fn new(body: DataRowBody) -> io::Result { - let ranges = body.ranges().collect()?; - Ok(RowData { - body: body, - ranges: ranges, - }) - } - - pub fn len(&self) -> usize { - self.ranges.len() - } - - pub fn get(&self, index: usize) -> Option<&[u8]> { - match &self.ranges[index] { - &Some(ref range) => Some(&self.body.buffer()[range.clone()]), - &None => None, - } - } -} diff --git a/postgres-shared/src/stmt.rs b/postgres-shared/src/stmt.rs deleted file mode 100644 index 85a993daf..000000000 --- a/postgres-shared/src/stmt.rs +++ /dev/null @@ -1,28 +0,0 @@ -use types::Type; - -/// Information about a column of a Postgres query. -#[derive(Debug)] -pub struct Column { - name: String, - type_: Type, -} - -impl Column { - #[doc(hidden)] - pub fn new(name: String, type_: Type) -> Column { - Column { - name: name, - type_: type_, - } - } - - /// Returns the name of the column. - pub fn name(&self) -> &str { - &self.name - } - - /// Returns the type of the column. - pub fn type_(&self) -> &Type { - &self.type_ - } -} diff --git a/postgres-shared/src/types/geo.rs b/postgres-shared/src/types/geo.rs deleted file mode 100644 index c09b0c566..000000000 --- a/postgres-shared/src/types/geo.rs +++ /dev/null @@ -1,70 +0,0 @@ -extern crate geo; - -use self::geo::{Coordinate, LineString, Point, Rect}; -use fallible_iterator::FallibleIterator; -use postgres_protocol::types; -use std::error::Error; - -use types::{FromSql, IsNull, ToSql, Type}; - -impl<'a> FromSql<'a> for Point { - fn from_sql(_: &Type, raw: &[u8]) -> Result> { - let point = types::point_from_sql(raw)?; - Ok(Point::new(point.x(), point.y())) - } - - accepts!(POINT); -} - -impl ToSql for Point { - fn to_sql(&self, _: &Type, out: &mut Vec) -> Result> { - types::point_to_sql(self.x(), self.y(), out); - Ok(IsNull::No) - } - - accepts!(POINT); - to_sql_checked!(); -} - -impl<'a> FromSql<'a> for Rect { - fn from_sql(_: &Type, raw: &[u8]) -> Result> { - let rect = types::box_from_sql(raw)?; - Ok(Rect { - min: Coordinate { x: rect.lower_left().x(), y: rect.lower_left().y(), }, - max: Coordinate { x: rect.upper_right().x(), y: rect.upper_right().y(), }, - }) - } - - accepts!(BOX); -} - -impl ToSql for Rect { - fn to_sql(&self, _: &Type, out: &mut Vec) -> Result> { - types::box_to_sql(self.min.x, self.min.y, self.max.x, self.max.y, out); - Ok(IsNull::No) - } - - accepts!(BOX); - to_sql_checked!(); -} - -impl<'a> FromSql<'a> for LineString { - fn from_sql(_: &Type, raw: &[u8]) -> Result> { - let path = types::path_from_sql(raw)?; - let points = path.points().map(|p| Coordinate { x: p.x(), y: p.y() }).collect()?; - Ok(LineString(points)) - } - - accepts!(PATH); -} - -impl ToSql for LineString { - fn to_sql(&self, _: &Type, out: &mut Vec) -> Result> { - let closed = false; // always encode an open path from LineString - types::path_to_sql(closed, self.0.iter().map(|p| (p.x, p.y)), out)?; - Ok(IsNull::No) - } - - accepts!(PATH); - to_sql_checked!(); -} diff --git a/postgres-shared/src/types/mod.rs b/postgres-shared/src/types/mod.rs deleted file mode 100644 index 457248164..000000000 --- a/postgres-shared/src/types/mod.rs +++ /dev/null @@ -1,786 +0,0 @@ -//! Types. - -use fallible_iterator::FallibleIterator; -use postgres_protocol; -use postgres_protocol::types::{self, ArrayDimension}; -use std::borrow::Cow; -use std::collections::HashMap; -use std::error::Error; -use std::fmt; -use std::sync::Arc; -use std::time::{Duration, SystemTime, UNIX_EPOCH}; - -use types::type_gen::{Inner, Other}; - -#[doc(inline)] -pub use postgres_protocol::Oid; - -pub use types::special::{Date, Timestamp}; - -// Number of seconds from 1970-01-01 to 2000-01-01 -const TIME_SEC_CONVERSION: u64 = 946684800; -const USEC_PER_SEC: u64 = 1_000_000; -const NSEC_PER_USEC: u64 = 1_000; - -/// Generates a simple implementation of `ToSql::accepts` which accepts the -/// types passed to it. -#[macro_export] -macro_rules! accepts { - ($($expected:ident),+) => ( - fn accepts(ty: &$crate::types::Type) -> bool { - match *ty { - $($crate::types::Type::$expected)|+ => true, - _ => false - } - } - ) -} - -/// Generates an implementation of `ToSql::to_sql_checked`. -/// -/// All `ToSql` implementations should use this macro. -#[macro_export] -macro_rules! to_sql_checked { - () => { - fn to_sql_checked(&self, - ty: &$crate::types::Type, - out: &mut ::std::vec::Vec) - -> ::std::result::Result<$crate::types::IsNull, - Box<::std::error::Error + - ::std::marker::Sync + - ::std::marker::Send>> { - $crate::types::__to_sql_checked(self, ty, out) - } - } -} - -// WARNING: this function is not considered part of this crate's public API. -// It is subject to change at any time. -#[doc(hidden)] -pub fn __to_sql_checked( - v: &T, - ty: &Type, - out: &mut Vec, -) -> Result> -where - T: ToSql, -{ - if !T::accepts(ty) { - return Err(Box::new(WrongType(ty.clone()))); - } - v.to_sql(ty, out) -} - -#[cfg(feature = "with-bit-vec-0.5")] -mod bit_vec; -#[cfg(feature = "with-chrono-0.4")] -mod chrono; -#[cfg(feature = "with-eui48-0.3")] -mod eui48; -#[cfg(feature = "with-geo-0.10")] -mod geo; -#[cfg(feature = "with-serde_json-1")] -mod serde_json; -#[cfg(feature = "with-uuid-0.6")] -mod uuid; - -mod special; -mod type_gen; - -/// A Postgres type. -#[derive(PartialEq, Eq, Clone, Debug)] -pub struct Type(Inner); - -impl fmt::Display for Type { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - match self.schema() { - "public" | "pg_catalog" => {} - schema => write!(fmt, "{}.", schema)?, - } - fmt.write_str(self.name()) - } -} - -impl Type { - // WARNING: this is not considered public API - #[doc(hidden)] - pub fn _new(name: String, oid: Oid, kind: Kind, schema: String) -> Type { - Type(Inner::Other(Arc::new(Other { - name: name, - oid: oid, - kind: kind, - schema: schema, - }))) - } - - /// Returns the `Type` corresponding to the provided `Oid` if it - /// corresponds to a built-in type. - pub fn from_oid(oid: Oid) -> Option { - Inner::from_oid(oid).map(Type) - } - - /// Returns the OID of the `Type`. - pub fn oid(&self) -> Oid { - self.0.oid() - } - - /// Returns the kind of this type. - pub fn kind(&self) -> &Kind { - self.0.kind() - } - - /// Returns the schema of this type. - pub fn schema(&self) -> &str { - match self.0 { - Inner::Other(ref u) => &u.schema, - _ => "pg_catalog", - } - } - - /// Returns the name of this type. - pub fn name(&self) -> &str { - self.0.name() - } -} - -/// Represents the kind of a Postgres type. -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum Kind { - /// A simple type like `VARCHAR` or `INTEGER`. - Simple, - /// An enumerated type along with its variants. - Enum(Vec), - /// A pseudo-type. - Pseudo, - /// An array type along with the type of its elements. - Array(Type), - /// A range type along with the type of its elements. - Range(Type), - /// A domain type along with its underlying type. - Domain(Type), - /// A composite type along with information about its fields. - Composite(Vec), - #[doc(hidden)] - __PseudoPrivateForExtensibility, -} - -/// Information about a field of a composite type. -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Field { - name: String, - type_: Type, -} - -impl Field { - /// Returns the name of the field. - pub fn name(&self) -> &str { - &self.name - } - - /// Returns the type of the field. - pub fn type_(&self) -> &Type { - &self.type_ - } -} - -impl Field { - #[doc(hidden)] - pub fn new(name: String, type_: Type) -> Field { - Field { - name: name, - type_: type_, - } - } -} - -/// An error indicating that a `NULL` Postgres value was passed to a `FromSql` -/// implementation that does not support `NULL` values. -#[derive(Debug, Clone, Copy)] -pub struct WasNull; - -impl fmt::Display for WasNull { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - fmt.write_str(self.description()) - } -} - -impl Error for WasNull { - fn description(&self) -> &str { - "a Postgres value was `NULL`" - } -} - -/// An error indicating that a conversion was attempted between incompatible -/// Rust and Postgres types. -#[derive(Debug)] -pub struct WrongType(Type); - -impl fmt::Display for WrongType { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - write!( - fmt, - "cannot convert to or from a Postgres value of type `{}`", - self.0 - ) - } -} - -impl Error for WrongType { - fn description(&self) -> &str { - "cannot convert to or from a Postgres value" - } -} - -impl WrongType { - #[doc(hidden)] - pub fn new(ty: Type) -> WrongType { - WrongType(ty) - } -} - -/// A trait for types that can be created from a Postgres value. -/// -/// # Types -/// -/// The following implementations are provided by this crate, along with the -/// corresponding Postgres types: -/// -/// | Rust type | Postgres type(s) | -/// |-----------------------------------|-----------------------------------------------| -/// | `bool` | BOOL | -/// | `i8` | "char" | -/// | `i16` | SMALLINT, SMALLSERIAL | -/// | `i32` | INT, SERIAL | -/// | `u32` | OID | -/// | `i64` | BIGINT, BIGSERIAL | -/// | `f32` | REAL | -/// | `f64` | DOUBLE PRECISION | -/// | `&str`/`String` | VARCHAR, CHAR(n), TEXT, CITEXT, NAME, UNKNOWN | -/// | `&[u8]`/`Vec` | BYTEA | -/// | `HashMap>` | HSTORE | -/// | `SystemTime` | TIMESTAMP, TIMESTAMP WITH TIME ZONE | -/// -/// In addition, some implementations are provided for types in third party -/// crates. These are disabled by default; to opt into one of these -/// implementations, activate the Cargo feature corresponding to the crate's -/// name prefixed by `with-`. For example, the `with-serde_json` feature enables -/// the implementation for the `serde_json::Value` type. -/// -/// | Rust type | Postgres type(s) | -/// |---------------------------------|-------------------------------------| -/// | `serialize::json::Json` | JSON, JSONB | -/// | `serde_json::Value` | JSON, JSONB | -/// | `time::Timespec` | TIMESTAMP, TIMESTAMP WITH TIME ZONE | -/// | `chrono::NaiveDateTime` | TIMESTAMP | -/// | `chrono::DateTime` | TIMESTAMP WITH TIME ZONE | -/// | `chrono::DateTime` | TIMESTAMP WITH TIME ZONE | -/// | `chrono::DateTime` | TIMESTAMP WITH TIME ZONE | -/// | `chrono::NaiveDate` | DATE | -/// | `chrono::NaiveTime` | TIME | -/// | `eui48::MacAddress` | MACADDR | -/// | `uuid::Uuid` | UUID | -/// | `bit_vec::BitVec` | BIT, VARBIT | -/// | `eui48::MacAddress` | MACADDR | -/// -/// # Nullability -/// -/// In addition to the types listed above, `FromSql` is implemented for -/// `Option` where `T` implements `FromSql`. An `Option` represents a -/// nullable Postgres value. -/// -/// # Arrays -/// -/// `FromSql` is implemented for `Vec` where `T` implements `FromSql`, and -/// corresponds to one-dimensional Postgres arrays. -pub trait FromSql<'a>: Sized { - /// Creates a new value of this type from a buffer of data of the specified - /// Postgres `Type` in its binary format. - /// - /// The caller of this method is responsible for ensuring that this type - /// is compatible with the Postgres `Type`. - fn from_sql(ty: &Type, raw: &'a [u8]) -> Result>; - - /// Creates a new value of this type from a `NULL` SQL value. - /// - /// The caller of this method is responsible for ensuring that this type - /// is compatible with the Postgres `Type`. - /// - /// The default implementation returns - /// `Err(Box::new(WasNull))`. - #[allow(unused_variables)] - fn from_sql_null(ty: &Type) -> Result> { - Err(Box::new(WasNull)) - } - - /// A convenience function that delegates to `from_sql` and `from_sql_null` depending on the - /// value of `raw`. - fn from_sql_nullable( - ty: &Type, - raw: Option<&'a [u8]>, - ) -> Result> { - match raw { - Some(raw) => Self::from_sql(ty, raw), - None => Self::from_sql_null(ty), - } - } - - /// Determines if a value of this type can be created from the specified - /// Postgres `Type`. - fn accepts(ty: &Type) -> bool; -} - -/// A trait for types which can be created from a Postgres value without borrowing any data. -/// -/// This is primarily useful for trait bounds on functions. -pub trait FromSqlOwned: for<'a> FromSql<'a> {} - -impl FromSqlOwned for T -where - T: for<'a> FromSql<'a>, -{ -} - -impl<'a, T: FromSql<'a>> FromSql<'a> for Option { - fn from_sql(ty: &Type, raw: &'a [u8]) -> Result, Box> { - ::from_sql(ty, raw).map(Some) - } - - fn from_sql_null(_: &Type) -> Result, Box> { - Ok(None) - } - - fn accepts(ty: &Type) -> bool { - ::accepts(ty) - } -} - -impl<'a, T: FromSql<'a>> FromSql<'a> for Vec { - fn from_sql(ty: &Type, raw: &'a [u8]) -> Result, Box> { - let member_type = match *ty.kind() { - Kind::Array(ref member) => member, - _ => panic!("expected array type"), - }; - - let array = types::array_from_sql(raw)?; - if array.dimensions().count()? > 1 { - return Err("array contains too many dimensions".into()); - } - - array - .values() - .and_then(|v| T::from_sql_nullable(member_type, v)) - .collect() - } - - fn accepts(ty: &Type) -> bool { - match *ty.kind() { - Kind::Array(ref inner) => T::accepts(inner), - _ => false, - } - } -} - -impl<'a> FromSql<'a> for Vec { - fn from_sql(_: &Type, raw: &'a [u8]) -> Result, Box> { - Ok(types::bytea_from_sql(raw).to_owned()) - } - - accepts!(BYTEA); -} - -impl<'a> FromSql<'a> for &'a [u8] { - fn from_sql(_: &Type, raw: &'a [u8]) -> Result<&'a [u8], Box> { - Ok(types::bytea_from_sql(raw)) - } - - accepts!(BYTEA); -} - -impl<'a> FromSql<'a> for String { - fn from_sql(_: &Type, raw: &'a [u8]) -> Result> { - types::text_from_sql(raw).map(|b| b.to_owned()) - } - - fn accepts(ty: &Type) -> bool { - <&str as FromSql>::accepts(ty) - } -} - -impl<'a> FromSql<'a> for &'a str { - fn from_sql(_: &Type, raw: &'a [u8]) -> Result<&'a str, Box> { - types::text_from_sql(raw) - } - - fn accepts(ty: &Type) -> bool { - match *ty { - Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true, - ref ty if ty.name() == "citext" => true, - _ => false, - } - } -} - -macro_rules! simple_from { - ($t:ty, $f:ident, $($expected:ident),+) => { - impl<'a> FromSql<'a> for $t { - fn from_sql(_: &Type, raw: &'a [u8]) -> Result<$t, Box> { - types::$f(raw) - } - - accepts!($($expected),+); - } - } -} - -simple_from!(bool, bool_from_sql, BOOL); -simple_from!(i8, char_from_sql, CHAR); -simple_from!(i16, int2_from_sql, INT2); -simple_from!(i32, int4_from_sql, INT4); -simple_from!(u32, oid_from_sql, OID); -simple_from!(i64, int8_from_sql, INT8); -simple_from!(f32, float4_from_sql, FLOAT4); -simple_from!(f64, float8_from_sql, FLOAT8); - -impl<'a> FromSql<'a> for HashMap> { - fn from_sql( - _: &Type, - raw: &'a [u8], - ) -> Result>, Box> { - types::hstore_from_sql(raw)? - .map(|(k, v)| (k.to_owned(), v.map(str::to_owned))) - .collect() - } - - fn accepts(ty: &Type) -> bool { - ty.name() == "hstore" - } -} - -impl<'a> FromSql<'a> for SystemTime { - fn from_sql(_: &Type, raw: &'a [u8]) -> Result> { - let time = types::timestamp_from_sql(raw)?; - let epoch = UNIX_EPOCH + Duration::from_secs(TIME_SEC_CONVERSION); - - let negative = time < 0; - let time = time.abs() as u64; - - let secs = time / USEC_PER_SEC; - let nsec = (time % USEC_PER_SEC) * NSEC_PER_USEC; - let offset = Duration::new(secs, nsec as u32); - - let time = if negative { - epoch - offset - } else { - epoch + offset - }; - - Ok(time) - } - - accepts!(TIMESTAMP, TIMESTAMPTZ); -} - -/// An enum representing the nullability of a Postgres value. -pub enum IsNull { - /// The value is NULL. - Yes, - /// The value is not NULL. - No, -} - -/// A trait for types that can be converted into Postgres values. -/// -/// # Types -/// -/// The following implementations are provided by this crate, along with the -/// corresponding Postgres types: -/// -/// | Rust type | Postgres type(s) | -/// |-----------------------------------|--------------------------------------| -/// | `bool` | BOOL | -/// | `i8` | "char" | -/// | `i16` | SMALLINT, SMALLSERIAL | -/// | `i32` | INT, SERIAL | -/// | `u32` | OID | -/// | `i64` | BIGINT, BIGSERIAL | -/// | `f32` | REAL | -/// | `f64` | DOUBLE PRECISION | -/// | `&str`/`String` | VARCHAR, CHAR(n), TEXT, CITEXT, NAME | -/// | `&[u8]`/Vec` | BYTEA | -/// | `HashMap>` | HSTORE | -/// | `SystemTime` | TIMESTAMP, TIMESTAMP WITH TIME ZONE | -/// -/// In addition, some implementations are provided for types in third party -/// crates. These are disabled by default; to opt into one of these -/// implementations, activate the Cargo feature corresponding to the crate's -/// name prefixed by `with-`. For example, the `with-serde_json` feature enables -/// the implementation for the `serde_json::Value` type. -/// -/// | Rust type | Postgres type(s) | -/// |---------------------------------|-------------------------------------| -/// | `serialize::json::Json` | JSON, JSONB | -/// | `serde_json::Value` | JSON, JSONB | -/// | `time::Timespec` | TIMESTAMP, TIMESTAMP WITH TIME ZONE | -/// | `chrono::NaiveDateTime` | TIMESTAMP | -/// | `chrono::DateTime` | TIMESTAMP WITH TIME ZONE | -/// | `chrono::DateTime` | TIMESTAMP WITH TIME ZONE | -/// | `chrono::DateTime` | TIMESTAMP WITH TIME ZONE | -/// | `chrono::NaiveDate` | DATE | -/// | `chrono::NaiveTime` | TIME | -/// | `uuid::Uuid` | UUID | -/// | `bit_vec::BitVec` | BIT, VARBIT | -/// | `eui48::MacAddress` | MACADDR | -/// -/// # Nullability -/// -/// In addition to the types listed above, `ToSql` is implemented for -/// `Option` where `T` implements `ToSql`. An `Option` represents a -/// nullable Postgres value. -/// -/// # Arrays -/// -/// `ToSql` is implemented for `Vec` and `&[T]` where `T` implements `ToSql`, -/// and corresponds to one-dimentional Postgres arrays with an index offset of 1. -pub trait ToSql: fmt::Debug { - /// Converts the value of `self` into the binary format of the specified - /// Postgres `Type`, appending it to `out`. - /// - /// The caller of this method is responsible for ensuring that this type - /// is compatible with the Postgres `Type`. - /// - /// The return value indicates if this value should be represented as - /// `NULL`. If this is the case, implementations **must not** write - /// anything to `out`. - fn to_sql(&self, ty: &Type, out: &mut Vec) -> Result> - where - Self: Sized; - - /// Determines if a value of this type can be converted to the specified - /// Postgres `Type`. - fn accepts(ty: &Type) -> bool - where - Self: Sized; - - /// An adaptor method used internally by Rust-Postgres. - /// - /// *All* implementations of this method should be generated by the - /// `to_sql_checked!()` macro. - fn to_sql_checked( - &self, - ty: &Type, - out: &mut Vec, - ) -> Result>; -} - -impl<'a, T> ToSql for &'a T -where - T: ToSql, -{ - fn to_sql(&self, ty: &Type, out: &mut Vec) -> Result> { - (*self).to_sql(ty, out) - } - - fn accepts(ty: &Type) -> bool { - T::accepts(ty) - } - - to_sql_checked!(); -} - -impl ToSql for Option { - fn to_sql(&self, ty: &Type, out: &mut Vec) -> Result> { - match *self { - Some(ref val) => val.to_sql(ty, out), - None => Ok(IsNull::Yes), - } - } - - fn accepts(ty: &Type) -> bool { - ::accepts(ty) - } - - to_sql_checked!(); -} - -impl<'a, T: ToSql> ToSql for &'a [T] { - fn to_sql(&self, ty: &Type, w: &mut Vec) -> Result> { - let member_type = match *ty.kind() { - Kind::Array(ref member) => member, - _ => panic!("expected array type"), - }; - - let dimension = ArrayDimension { - len: downcast(self.len())?, - lower_bound: 1, - }; - - types::array_to_sql( - Some(dimension), - member_type.oid(), - self.iter(), - |e, w| match e.to_sql(member_type, w)? { - IsNull::No => Ok(postgres_protocol::IsNull::No), - IsNull::Yes => Ok(postgres_protocol::IsNull::Yes), - }, - w, - )?; - Ok(IsNull::No) - } - - fn accepts(ty: &Type) -> bool { - match *ty.kind() { - Kind::Array(ref member) => T::accepts(member), - _ => false, - } - } - - to_sql_checked!(); -} - -impl<'a> ToSql for &'a [u8] { - fn to_sql(&self, _: &Type, w: &mut Vec) -> Result> { - types::bytea_to_sql(*self, w); - Ok(IsNull::No) - } - - accepts!(BYTEA); - - to_sql_checked!(); -} - -impl ToSql for Vec { - fn to_sql(&self, ty: &Type, w: &mut Vec) -> Result> { - <&[T] as ToSql>::to_sql(&&**self, ty, w) - } - - fn accepts(ty: &Type) -> bool { - <&[T] as ToSql>::accepts(ty) - } - - to_sql_checked!(); -} - -impl ToSql for Vec { - fn to_sql(&self, ty: &Type, w: &mut Vec) -> Result> { - <&[u8] as ToSql>::to_sql(&&**self, ty, w) - } - - fn accepts(ty: &Type) -> bool { - <&[u8] as ToSql>::accepts(ty) - } - - to_sql_checked!(); -} - -impl<'a> ToSql for &'a str { - fn to_sql(&self, _: &Type, w: &mut Vec) -> Result> { - types::text_to_sql(*self, w); - Ok(IsNull::No) - } - - fn accepts(ty: &Type) -> bool { - match *ty { - Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true, - ref ty if ty.name() == "citext" => true, - _ => false, - } - } - - to_sql_checked!(); -} - -impl<'a> ToSql for Cow<'a, str> { - fn to_sql(&self, ty: &Type, w: &mut Vec) -> Result> { - <&str as ToSql>::to_sql(&&self.as_ref(), ty, w) - } - - fn accepts(ty: &Type) -> bool { - <&str as ToSql>::accepts(ty) - } - - to_sql_checked!(); -} - -impl ToSql for String { - fn to_sql(&self, ty: &Type, w: &mut Vec) -> Result> { - <&str as ToSql>::to_sql(&&**self, ty, w) - } - - fn accepts(ty: &Type) -> bool { - <&str as ToSql>::accepts(ty) - } - - to_sql_checked!(); -} - -macro_rules! simple_to { - ($t:ty, $f:ident, $($expected:ident),+) => { - impl ToSql for $t { - fn to_sql(&self, - _: &Type, - w: &mut Vec) - -> Result> { - types::$f(*self, w); - Ok(IsNull::No) - } - - accepts!($($expected),+); - - to_sql_checked!(); - } - } -} - -simple_to!(bool, bool_to_sql, BOOL); -simple_to!(i8, char_to_sql, CHAR); -simple_to!(i16, int2_to_sql, INT2); -simple_to!(i32, int4_to_sql, INT4); -simple_to!(u32, oid_to_sql, OID); -simple_to!(i64, int8_to_sql, INT8); -simple_to!(f32, float4_to_sql, FLOAT4); -simple_to!(f64, float8_to_sql, FLOAT8); - -impl ToSql for HashMap> { - fn to_sql(&self, _: &Type, w: &mut Vec) -> Result> { - types::hstore_to_sql( - self.iter().map(|(k, v)| (&**k, v.as_ref().map(|v| &**v))), - w, - )?; - Ok(IsNull::No) - } - - fn accepts(ty: &Type) -> bool { - ty.name() == "hstore" - } - - to_sql_checked!(); -} - -impl ToSql for SystemTime { - fn to_sql(&self, _: &Type, w: &mut Vec) -> Result> { - let epoch = UNIX_EPOCH + Duration::from_secs(TIME_SEC_CONVERSION); - - let to_usec = - |d: Duration| d.as_secs() * USEC_PER_SEC + (d.subsec_nanos() as u64) / NSEC_PER_USEC; - - let time = match self.duration_since(epoch) { - Ok(duration) => to_usec(duration) as i64, - Err(e) => -(to_usec(e.duration()) as i64), - }; - - types::timestamp_to_sql(time, w); - Ok(IsNull::No) - } - - accepts!(TIMESTAMP, TIMESTAMPTZ); - - to_sql_checked!(); -} - -fn downcast(len: usize) -> Result> { - if len > i32::max_value() as usize { - Err("value too large to transmit".into()) - } else { - Ok(len as i32) - } -} diff --git a/postgres-shared/src/types/serde_json.rs b/postgres-shared/src/types/serde_json.rs deleted file mode 100644 index 53a63d587..000000000 --- a/postgres-shared/src/types/serde_json.rs +++ /dev/null @@ -1,36 +0,0 @@ -extern crate serde_json; - -use self::serde_json::Value; -use std::error::Error; -use std::io::{Read, Write}; - -use types::{FromSql, IsNull, ToSql, Type}; - -impl<'a> FromSql<'a> for Value { - fn from_sql(ty: &Type, mut raw: &[u8]) -> Result> { - if *ty == Type::JSONB { - let mut b = [0; 1]; - raw.read_exact(&mut b)?; - // We only support version 1 of the jsonb binary format - if b[0] != 1 { - return Err("unsupported JSONB encoding version".into()); - } - } - serde_json::de::from_reader(raw).map_err(Into::into) - } - - accepts!(JSON, JSONB); -} - -impl ToSql for Value { - fn to_sql(&self, ty: &Type, out: &mut Vec) -> Result> { - if *ty == Type::JSONB { - out.push(1); - } - write!(out, "{}", self)?; - Ok(IsNull::No) - } - - accepts!(JSON, JSONB); - to_sql_checked!(); -} diff --git a/postgres-types/CHANGELOG.md b/postgres-types/CHANGELOG.md new file mode 100644 index 000000000..7fa6d6506 --- /dev/null +++ b/postgres-types/CHANGELOG.md @@ -0,0 +1,139 @@ +# Change Log + +## Unreleased + +## v0.2.9 - 2025-02-02 + +### Added + +* Added support for `cidr` 0.3 via the `with-cidr-0_3` feature. + +### Fixed + +* Fixed deserialization of out of bounds inputs to `time` 0.3 types to return an error rather than panic. + +## v0.2.8 - 2024-09-15 + +### Added + +* Added support for `jiff` 0.1 via the `with-jiff-01` feature. + +## v0.2.7 - 2024-07-21 + +### Added + +* Added `Default` implementation for `Json`. +* Added a `js` feature for WASM compatibility. + +### Changed + +* `FromStr` implementation for `PgLsn` no longer allocates a `Vec` when splitting an lsn string on it's `/`. +* The `eui48-1` feature no longer enables default features of the `eui48` library. + +## v0.2.6 - 2023-08-19 + +### Fixed + +* Fixed serialization to `OIDVECTOR` and `INT2VECTOR`. + +### Added + +* Removed the `'static` requirement for the `impl BorrowToSql for Box`. +* Added a `ToSql` implementation for `Cow<[u8]>`. + +## v0.2.5 - 2023-03-27 + +### Added + +* Added support for multi-range types. + +## v0.2.4 - 2022-08-20 + +### Added + +* Added `ToSql` and `FromSql` implementations for `Box<[T]>`. +* Added `ToSql` and `FromSql` implementations for `[u8; N]` via the `array-impls` feature. +* Added support for `smol_str` 0.1 via the `with-smol_str-01` feature. +* Added `ToSql::encode_format` to support text encodings of parameters. + +## v0.2.3 - 2022-04-30 + +### Added + +* Added `ToSql` and `FromSql` implementations for `Box`. +* Added `BorrowToSql` implementations for `Box` and `Box`. +* Added support for `cidr` 0.2 via the `with-cidr-02` feature. +* Added conversions between the `LTREE`, `LQUERY` and `LTXTQUERY` types and Rust strings. +* Added support for `uuid` 1.0 via the `with-uuid-1` feature. + +## v0.2.2 - 2021-09-29 + +### Added + +* Added support for `eui48` 1.0 via the `with-eui48-1` feature. +* Added `ToSql` and `FromSql` implementations for array types via the `array-impls` feature. +* Added support for `time` 0.3 via the `with-time-0_3` feature. + +## v0.2.1 - 2021-04-03 + +### Added + +* Added support for `geo-types` 0.7 via `with-geo-types-0_7` feature. +* Added the `PgLsn` type, corresponding to `PG_LSN`. + +## v0.2.0 - 2020-12-25 + +### Changed + +* Upgraded `bytes` to 1.0. + +### Removed + +* Removed support for `geo-types` 0.4. + +## v0.1.3 - 2020-10-17 + +### Added + +* Implemented `Clone`, `PartialEq`, and `Eq` for `Json`. + +### Fixed + +* Checked for overflow in `NaiveDate` and `NaiveDateTime` conversions. + +## v0.1.2 - 2020-07-03 + +### Added + +* Added support for `geo-types` 0.6. + +## v0.1.1 - 2020-03-05 + +### Added + +* Added support for `time` 0.2. + +## v0.1.0 - 2019-12-23 + +### Changed + +* `Kind` is now a true non-exhaustive enum. + +### Removed + +* Removed `uuid` 0.7 support. + +### Added + +* Added a `Hash` implementation for `Type`. + +## v0.1.0-alpha.2 - 2019-11-27 + +### Changed + +* Upgraded `bytes` to 0.5. +* Upgraded `uuid` to 0.8. + +## v0.1.0-alpha.1 - 2019-10-14 + +Initial release diff --git a/postgres-types/Cargo.toml b/postgres-types/Cargo.toml new file mode 100644 index 000000000..d6527f3b9 --- /dev/null +++ b/postgres-types/Cargo.toml @@ -0,0 +1,61 @@ +[package] +name = "postgres-types" +version = "0.2.9" +authors = ["Steven Fackler "] +edition = "2018" +license = "MIT OR Apache-2.0" +description = "Conversions between Rust and Postgres values" +repository = "https://github.com/sfackler/rust-postgres" +readme = "../README.md" +keywords = ["database", "postgres", "postgresql", "sql"] +categories = ["database"] + +[features] +derive = ["postgres-derive"] +array-impls = ["array-init"] +js = ["postgres-protocol/js"] +with-bit-vec-0_6 = ["bit-vec-06"] +with-cidr-0_2 = ["cidr-02"] +with-cidr-0_3 = ["cidr-03"] +with-chrono-0_4 = ["chrono-04"] +with-eui48-0_4 = ["eui48-04"] +with-eui48-1 = ["eui48-1"] +with-geo-types-0_6 = ["geo-types-06"] +with-geo-types-0_7 = ["geo-types-0_7"] +with-jiff-0_1 = ["jiff-01"] +with-jiff-0_2 = ["jiff-02"] +with-serde_json-1 = ["serde-1", "serde_json-1"] +with-smol_str-01 = ["smol_str-01"] +with-uuid-0_8 = ["uuid-08"] +with-uuid-1 = ["uuid-1"] +with-time-0_2 = ["time-02"] +with-time-0_3 = ["time-03"] + +[dependencies] +bytes = "1.0" +fallible-iterator = "0.2" +postgres-protocol = { version = "0.6.8", path = "../postgres-protocol" } +postgres-derive = { version = "0.4.6", optional = true, path = "../postgres-derive" } + +array-init = { version = "2", optional = true } +bit-vec-06 = { version = "0.6", package = "bit-vec", optional = true } +chrono-04 = { version = "0.4.16", package = "chrono", default-features = false, features = [ + "clock", +], optional = true } +cidr-02 = { version = "0.2", package = "cidr", optional = true } +cidr-03 = { version = "0.3", package = "cidr", optional = true } +# eui48-04 will stop compiling and support will be removed +# See https://github.com/sfackler/rust-postgres/issues/1073 +eui48-04 = { version = "0.4", package = "eui48", optional = true } +eui48-1 = { version = "1.0", package = "eui48", optional = true, default-features = false } +geo-types-06 = { version = "0.6", package = "geo-types", optional = true } +geo-types-0_7 = { version = "0.7", package = "geo-types", optional = true } +jiff-01 = { version = "0.1", package = "jiff", optional = true } +jiff-02 = { version = "0.2", package = "jiff", optional = true } +serde-1 = { version = "1.0", package = "serde", optional = true } +serde_json-1 = { version = "1.0", package = "serde_json", optional = true } +uuid-08 = { version = "0.8", package = "uuid", optional = true } +uuid-1 = { version = "1.0", package = "uuid", optional = true } +time-02 = { version = "0.2", package = "time", optional = true } +time-03 = { version = "0.3", package = "time", default-features = false, optional = true } +smol_str-01 = { version = "0.1.23", package = "smol_str", default-features = false, optional = true } diff --git a/postgres-types/LICENSE-APACHE b/postgres-types/LICENSE-APACHE new file mode 120000 index 000000000..965b606f3 --- /dev/null +++ b/postgres-types/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/postgres-types/LICENSE-MIT b/postgres-types/LICENSE-MIT new file mode 120000 index 000000000..76219eb72 --- /dev/null +++ b/postgres-types/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/postgres-shared/src/types/bit_vec.rs b/postgres-types/src/bit_vec_06.rs similarity index 65% rename from postgres-shared/src/types/bit_vec.rs rename to postgres-types/src/bit_vec_06.rs index cd3f4743c..322472c6f 100644 --- a/postgres-shared/src/types/bit_vec.rs +++ b/postgres-types/src/bit_vec_06.rs @@ -1,13 +1,12 @@ -extern crate bit_vec; - -use self::bit_vec::BitVec; +use bit_vec_06::BitVec; +use bytes::BytesMut; use postgres_protocol::types; use std::error::Error; -use types::{FromSql, IsNull, ToSql, Type}; +use crate::{FromSql, IsNull, ToSql, Type}; impl<'a> FromSql<'a> for BitVec { - fn from_sql(_: &Type, raw: &[u8]) -> Result> { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { let varbit = types::varbit_from_sql(raw)?; let mut bitvec = BitVec::from_bytes(varbit.bytes()); while bitvec.len() > varbit.len() { @@ -21,7 +20,7 @@ impl<'a> FromSql<'a> for BitVec { } impl ToSql for BitVec { - fn to_sql(&self, _: &Type, out: &mut Vec) -> Result> { + fn to_sql(&self, _: &Type, out: &mut BytesMut) -> Result> { types::varbit_to_sql(self.len(), self.to_bytes().into_iter(), out)?; Ok(IsNull::No) } diff --git a/postgres-shared/src/types/chrono.rs b/postgres-types/src/chrono_04.rs similarity index 57% rename from postgres-shared/src/types/chrono.rs rename to postgres-types/src/chrono_04.rs index 0f305ea1e..d599bde02 100644 --- a/postgres-shared/src/types/chrono.rs +++ b/postgres-types/src/chrono_04.rs @@ -1,27 +1,32 @@ -extern crate chrono; - -use self::chrono::{DateTime, Duration, FixedOffset, Local, NaiveDate, NaiveDateTime, NaiveTime, - Utc}; +use bytes::BytesMut; +use chrono_04::{ + DateTime, Duration, FixedOffset, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc, +}; use postgres_protocol::types; use std::error::Error; -use types::{FromSql, IsNull, ToSql, Type}; +use crate::{FromSql, IsNull, ToSql, Type}; fn base() -> NaiveDateTime { - NaiveDate::from_ymd(2000, 1, 1).and_hms(0, 0, 0) + NaiveDate::from_ymd_opt(2000, 1, 1) + .unwrap() + .and_hms_opt(0, 0, 0) + .unwrap() } impl<'a> FromSql<'a> for NaiveDateTime { - fn from_sql(_: &Type, raw: &[u8]) -> Result> { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { let t = types::timestamp_from_sql(raw)?; - Ok(base() + Duration::microseconds(t)) + base() + .checked_add_signed(Duration::microseconds(t)) + .ok_or_else(|| "value too large to decode".into()) } accepts!(TIMESTAMP); } impl ToSql for NaiveDateTime { - fn to_sql(&self, _: &Type, w: &mut Vec) -> Result> { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { let time = match self.signed_duration_since(base()).num_microseconds() { Some(time) => time, None => return Err("value too large to transmit".into()), @@ -35,16 +40,20 @@ impl ToSql for NaiveDateTime { } impl<'a> FromSql<'a> for DateTime { - fn from_sql(type_: &Type, raw: &[u8]) -> Result, Box> { + fn from_sql(type_: &Type, raw: &[u8]) -> Result, Box> { let naive = NaiveDateTime::from_sql(type_, raw)?; - Ok(DateTime::from_utc(naive, Utc)) + Ok(Utc.from_utc_datetime(&naive)) } accepts!(TIMESTAMPTZ); } impl ToSql for DateTime { - fn to_sql(&self, type_: &Type, w: &mut Vec) -> Result> { + fn to_sql( + &self, + type_: &Type, + w: &mut BytesMut, + ) -> Result> { self.naive_utc().to_sql(type_, w) } @@ -53,7 +62,7 @@ impl ToSql for DateTime { } impl<'a> FromSql<'a> for DateTime { - fn from_sql(type_: &Type, raw: &[u8]) -> Result, Box> { + fn from_sql(type_: &Type, raw: &[u8]) -> Result, Box> { let utc = DateTime::::from_sql(type_, raw)?; Ok(utc.with_timezone(&Local)) } @@ -62,7 +71,11 @@ impl<'a> FromSql<'a> for DateTime { } impl ToSql for DateTime { - fn to_sql(&self, type_: &Type, w: &mut Vec) -> Result> { + fn to_sql( + &self, + type_: &Type, + w: &mut BytesMut, + ) -> Result> { self.with_timezone(&Utc).to_sql(type_, w) } @@ -74,16 +87,20 @@ impl<'a> FromSql<'a> for DateTime { fn from_sql( type_: &Type, raw: &[u8], - ) -> Result, Box> { + ) -> Result, Box> { let utc = DateTime::::from_sql(type_, raw)?; - Ok(utc.with_timezone(&FixedOffset::east(0))) + Ok(utc.with_timezone(&FixedOffset::east_opt(0).unwrap())) } accepts!(TIMESTAMPTZ); } impl ToSql for DateTime { - fn to_sql(&self, type_: &Type, w: &mut Vec) -> Result> { + fn to_sql( + &self, + type_: &Type, + w: &mut BytesMut, + ) -> Result> { self.with_timezone(&Utc).to_sql(type_, w) } @@ -92,18 +109,21 @@ impl ToSql for DateTime { } impl<'a> FromSql<'a> for NaiveDate { - fn from_sql(_: &Type, raw: &[u8]) -> Result> { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { let jd = types::date_from_sql(raw)?; - Ok(base().date() + Duration::days(jd as i64)) + base() + .date() + .checked_add_signed(Duration::days(i64::from(jd))) + .ok_or_else(|| "value too large to decode".into()) } accepts!(DATE); } impl ToSql for NaiveDate { - fn to_sql(&self, _: &Type, w: &mut Vec) -> Result> { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { let jd = self.signed_duration_since(base().date()).num_days(); - if jd > i32::max_value() as i64 || jd < i32::min_value() as i64 { + if jd > i64::from(i32::max_value()) || jd < i64::from(i32::min_value()) { return Err("value too large to transmit".into()); } @@ -116,17 +136,17 @@ impl ToSql for NaiveDate { } impl<'a> FromSql<'a> for NaiveTime { - fn from_sql(_: &Type, raw: &[u8]) -> Result> { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { let usec = types::time_from_sql(raw)?; - Ok(NaiveTime::from_hms(0, 0, 0) + Duration::microseconds(usec)) + Ok(NaiveTime::from_hms_opt(0, 0, 0).unwrap() + Duration::microseconds(usec)) } accepts!(TIME); } impl ToSql for NaiveTime { - fn to_sql(&self, _: &Type, w: &mut Vec) -> Result> { - let delta = self.signed_duration_since(NaiveTime::from_hms(0, 0, 0)); + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + let delta = self.signed_duration_since(NaiveTime::from_hms_opt(0, 0, 0).unwrap()); let time = match delta.num_microseconds() { Some(time) => time, None => return Err("value too large to transmit".into()), diff --git a/postgres-types/src/cidr_02.rs b/postgres-types/src/cidr_02.rs new file mode 100644 index 000000000..2de952c3c --- /dev/null +++ b/postgres-types/src/cidr_02.rs @@ -0,0 +1,44 @@ +use bytes::BytesMut; +use cidr_02::{IpCidr, IpInet}; +use postgres_protocol::types; +use std::error::Error; + +use crate::{FromSql, IsNull, ToSql, Type}; + +impl<'a> FromSql<'a> for IpCidr { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let inet = types::inet_from_sql(raw)?; + Ok(IpCidr::new(inet.addr(), inet.netmask())?) + } + + accepts!(CIDR); +} + +impl ToSql for IpCidr { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + types::inet_to_sql(self.first_address(), self.network_length(), w); + Ok(IsNull::No) + } + + accepts!(CIDR); + to_sql_checked!(); +} + +impl<'a> FromSql<'a> for IpInet { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let inet = types::inet_from_sql(raw)?; + Ok(IpInet::new(inet.addr(), inet.netmask())?) + } + + accepts!(INET); +} + +impl ToSql for IpInet { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + types::inet_to_sql(self.address(), self.network_length(), w); + Ok(IsNull::No) + } + + accepts!(INET); + to_sql_checked!(); +} diff --git a/postgres-types/src/cidr_03.rs b/postgres-types/src/cidr_03.rs new file mode 100644 index 000000000..6a0178711 --- /dev/null +++ b/postgres-types/src/cidr_03.rs @@ -0,0 +1,44 @@ +use bytes::BytesMut; +use cidr_03::{IpCidr, IpInet}; +use postgres_protocol::types; +use std::error::Error; + +use crate::{FromSql, IsNull, ToSql, Type}; + +impl<'a> FromSql<'a> for IpCidr { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let inet = types::inet_from_sql(raw)?; + Ok(IpCidr::new(inet.addr(), inet.netmask())?) + } + + accepts!(CIDR); +} + +impl ToSql for IpCidr { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + types::inet_to_sql(self.first_address(), self.network_length(), w); + Ok(IsNull::No) + } + + accepts!(CIDR); + to_sql_checked!(); +} + +impl<'a> FromSql<'a> for IpInet { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let inet = types::inet_from_sql(raw)?; + Ok(IpInet::new(inet.addr(), inet.netmask())?) + } + + accepts!(INET); +} + +impl ToSql for IpInet { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + types::inet_to_sql(self.address(), self.network_length(), w); + Ok(IsNull::No) + } + + accepts!(INET); + to_sql_checked!(); +} diff --git a/postgres-types/src/eui48_04.rs b/postgres-types/src/eui48_04.rs new file mode 100644 index 000000000..45df89a84 --- /dev/null +++ b/postgres-types/src/eui48_04.rs @@ -0,0 +1,27 @@ +use bytes::BytesMut; +use eui48_04::MacAddress; +use postgres_protocol::types; +use std::error::Error; + +use crate::{FromSql, IsNull, ToSql, Type}; + +impl<'a> FromSql<'a> for MacAddress { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let bytes = types::macaddr_from_sql(raw)?; + Ok(MacAddress::new(bytes)) + } + + accepts!(MACADDR); +} + +impl ToSql for MacAddress { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + let mut bytes = [0; 6]; + bytes.copy_from_slice(self.as_bytes()); + types::macaddr_to_sql(bytes, w); + Ok(IsNull::No) + } + + accepts!(MACADDR); + to_sql_checked!(); +} diff --git a/postgres-shared/src/types/eui48.rs b/postgres-types/src/eui48_1.rs similarity index 69% rename from postgres-shared/src/types/eui48.rs rename to postgres-types/src/eui48_1.rs index a4e1bb6b9..4c35e63ce 100644 --- a/postgres-shared/src/types/eui48.rs +++ b/postgres-types/src/eui48_1.rs @@ -1,13 +1,12 @@ -extern crate eui48; - -use self::eui48::MacAddress; +use bytes::BytesMut; +use eui48_1::MacAddress; use postgres_protocol::types; use std::error::Error; -use types::{FromSql, IsNull, ToSql, Type}; +use crate::{FromSql, IsNull, ToSql, Type}; impl<'a> FromSql<'a> for MacAddress { - fn from_sql(_: &Type, raw: &[u8]) -> Result> { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { let bytes = types::macaddr_from_sql(raw)?; Ok(MacAddress::new(bytes)) } @@ -16,7 +15,7 @@ impl<'a> FromSql<'a> for MacAddress { } impl ToSql for MacAddress { - fn to_sql(&self, _: &Type, w: &mut Vec) -> Result> { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { let mut bytes = [0; 6]; bytes.copy_from_slice(self.as_bytes()); types::macaddr_to_sql(bytes, w); diff --git a/postgres-types/src/geo_types_06.rs b/postgres-types/src/geo_types_06.rs new file mode 100644 index 000000000..0f0b14fd9 --- /dev/null +++ b/postgres-types/src/geo_types_06.rs @@ -0,0 +1,72 @@ +use bytes::BytesMut; +use fallible_iterator::FallibleIterator; +use geo_types_06::{Coordinate, LineString, Point, Rect}; +use postgres_protocol::types; +use std::error::Error; + +use crate::{FromSql, IsNull, ToSql, Type}; + +impl<'a> FromSql<'a> for Point { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let point = types::point_from_sql(raw)?; + Ok(Point::new(point.x(), point.y())) + } + + accepts!(POINT); +} + +impl ToSql for Point { + fn to_sql(&self, _: &Type, out: &mut BytesMut) -> Result> { + types::point_to_sql(self.x(), self.y(), out); + Ok(IsNull::No) + } + + accepts!(POINT); + to_sql_checked!(); +} + +impl<'a> FromSql<'a> for Rect { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let rect = types::box_from_sql(raw)?; + Ok(Rect::new( + (rect.lower_left().x(), rect.lower_left().y()), + (rect.upper_right().x(), rect.upper_right().y()), + )) + } + + accepts!(BOX); +} + +impl ToSql for Rect { + fn to_sql(&self, _: &Type, out: &mut BytesMut) -> Result> { + types::box_to_sql(self.min().x, self.min().y, self.max().x, self.max().y, out); + Ok(IsNull::No) + } + + accepts!(BOX); + to_sql_checked!(); +} + +impl<'a> FromSql<'a> for LineString { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let path = types::path_from_sql(raw)?; + let points = path + .points() + .map(|p| Ok(Coordinate { x: p.x(), y: p.y() })) + .collect()?; + Ok(LineString(points)) + } + + accepts!(PATH); +} + +impl ToSql for LineString { + fn to_sql(&self, _: &Type, out: &mut BytesMut) -> Result> { + let closed = false; // always encode an open path from LineString + types::path_to_sql(closed, self.0.iter().map(|p| (p.x, p.y)), out)?; + Ok(IsNull::No) + } + + accepts!(PATH); + to_sql_checked!(); +} diff --git a/postgres-types/src/geo_types_07.rs b/postgres-types/src/geo_types_07.rs new file mode 100644 index 000000000..bf7fa5601 --- /dev/null +++ b/postgres-types/src/geo_types_07.rs @@ -0,0 +1,72 @@ +use bytes::BytesMut; +use fallible_iterator::FallibleIterator; +use geo_types_0_7::{Coord, LineString, Point, Rect}; +use postgres_protocol::types; +use std::error::Error; + +use crate::{FromSql, IsNull, ToSql, Type}; + +impl<'a> FromSql<'a> for Point { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let point = types::point_from_sql(raw)?; + Ok(Point::new(point.x(), point.y())) + } + + accepts!(POINT); +} + +impl ToSql for Point { + fn to_sql(&self, _: &Type, out: &mut BytesMut) -> Result> { + types::point_to_sql(self.x(), self.y(), out); + Ok(IsNull::No) + } + + accepts!(POINT); + to_sql_checked!(); +} + +impl<'a> FromSql<'a> for Rect { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let rect = types::box_from_sql(raw)?; + Ok(Rect::new( + (rect.lower_left().x(), rect.lower_left().y()), + (rect.upper_right().x(), rect.upper_right().y()), + )) + } + + accepts!(BOX); +} + +impl ToSql for Rect { + fn to_sql(&self, _: &Type, out: &mut BytesMut) -> Result> { + types::box_to_sql(self.min().x, self.min().y, self.max().x, self.max().y, out); + Ok(IsNull::No) + } + + accepts!(BOX); + to_sql_checked!(); +} + +impl<'a> FromSql<'a> for LineString { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let path = types::path_from_sql(raw)?; + let points = path + .points() + .map(|p| Ok(Coord { x: p.x(), y: p.y() })) + .collect()?; + Ok(LineString(points)) + } + + accepts!(PATH); +} + +impl ToSql for LineString { + fn to_sql(&self, _: &Type, out: &mut BytesMut) -> Result> { + let closed = false; // always encode an open path from LineString + types::path_to_sql(closed, self.0.iter().map(|p| (p.x, p.y)), out)?; + Ok(IsNull::No) + } + + accepts!(PATH); + to_sql_checked!(); +} diff --git a/postgres-types/src/jiff_01.rs b/postgres-types/src/jiff_01.rs new file mode 100644 index 000000000..d3215c0e6 --- /dev/null +++ b/postgres-types/src/jiff_01.rs @@ -0,0 +1,141 @@ +use bytes::BytesMut; +use jiff_01::{ + civil::{Date, DateTime, Time}, + Span, SpanRound, Timestamp, Unit, +}; +use postgres_protocol::types; +use std::error::Error; + +use crate::{FromSql, IsNull, ToSql, Type}; + +const fn base() -> DateTime { + DateTime::constant(2000, 1, 1, 0, 0, 0, 0) +} + +/// The number of seconds from the Unix epoch to 2000-01-01 00:00:00 UTC. +const PG_EPOCH: i64 = 946684800; + +fn base_ts() -> Timestamp { + Timestamp::new(PG_EPOCH, 0).unwrap() +} + +fn round_us<'a>() -> SpanRound<'a> { + SpanRound::new().largest(Unit::Microsecond) +} + +fn decode_err(_e: E) -> Box +where + E: Error, +{ + "value too large to decode".into() +} + +fn transmit_err(_e: E) -> Box +where + E: Error, +{ + "value too large to transmit".into() +} + +impl<'a> FromSql<'a> for DateTime { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let v = types::timestamp_from_sql(raw)?; + Span::new() + .try_microseconds(v) + .and_then(|s| base().checked_add(s)) + .map_err(decode_err) + } + + accepts!(TIMESTAMP); +} + +impl ToSql for DateTime { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + let v = self + .since(base()) + .and_then(|s| s.round(round_us())) + .map_err(transmit_err)? + .get_microseconds(); + types::timestamp_to_sql(v, w); + Ok(IsNull::No) + } + + accepts!(TIMESTAMP); + to_sql_checked!(); +} + +impl<'a> FromSql<'a> for Timestamp { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let v = types::timestamp_from_sql(raw)?; + Span::new() + .try_microseconds(v) + .and_then(|s| base_ts().checked_add(s)) + .map_err(decode_err) + } + + accepts!(TIMESTAMPTZ); +} + +impl ToSql for Timestamp { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + let v = self + .since(base_ts()) + .and_then(|s| s.round(round_us())) + .map_err(transmit_err)? + .get_microseconds(); + types::timestamp_to_sql(v, w); + Ok(IsNull::No) + } + + accepts!(TIMESTAMPTZ); + to_sql_checked!(); +} + +impl<'a> FromSql<'a> for Date { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let v = types::date_from_sql(raw)?; + Span::new() + .try_days(v) + .and_then(|s| base().date().checked_add(s)) + .map_err(decode_err) + } + accepts!(DATE); +} + +impl ToSql for Date { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + let v = self.since(base().date()).map_err(transmit_err)?.get_days(); + types::date_to_sql(v, w); + Ok(IsNull::No) + } + + accepts!(DATE); + to_sql_checked!(); +} + +impl<'a> FromSql<'a> for Time { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let v = types::time_from_sql(raw)?; + Span::new() + .try_microseconds(v) + .and_then(|s| Time::midnight().checked_add(s)) + .map_err(decode_err) + } + + accepts!(TIME); +} + +impl ToSql for Time { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + let v = self + .since(Time::midnight()) + .and_then(|s| s.round(round_us())) + .map_err(transmit_err)? + .get_microseconds(); + types::time_to_sql(v, w); + Ok(IsNull::No) + } + + accepts!(TIME); + to_sql_checked!(); +} diff --git a/postgres-types/src/jiff_02.rs b/postgres-types/src/jiff_02.rs new file mode 100644 index 000000000..a736dd3eb --- /dev/null +++ b/postgres-types/src/jiff_02.rs @@ -0,0 +1,141 @@ +use bytes::BytesMut; +use jiff_02::{ + civil::{Date, DateTime, Time}, + Span, SpanRound, Timestamp, Unit, +}; +use postgres_protocol::types; +use std::error::Error; + +use crate::{FromSql, IsNull, ToSql, Type}; + +const fn base() -> DateTime { + DateTime::constant(2000, 1, 1, 0, 0, 0, 0) +} + +/// The number of seconds from the Unix epoch to 2000-01-01 00:00:00 UTC. +const PG_EPOCH: i64 = 946684800; + +fn base_ts() -> Timestamp { + Timestamp::new(PG_EPOCH, 0).unwrap() +} + +fn round_us<'a>() -> SpanRound<'a> { + SpanRound::new().largest(Unit::Microsecond) +} + +fn decode_err(_e: E) -> Box +where + E: Error, +{ + "value too large to decode".into() +} + +fn transmit_err(_e: E) -> Box +where + E: Error, +{ + "value too large to transmit".into() +} + +impl<'a> FromSql<'a> for DateTime { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let v = types::timestamp_from_sql(raw)?; + Span::new() + .try_microseconds(v) + .and_then(|s| base().checked_add(s)) + .map_err(decode_err) + } + + accepts!(TIMESTAMP); +} + +impl ToSql for DateTime { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + let v = self + .since(base()) + .and_then(|s| s.round(round_us().relative(base()))) + .map_err(transmit_err)? + .get_microseconds(); + types::timestamp_to_sql(v, w); + Ok(IsNull::No) + } + + accepts!(TIMESTAMP); + to_sql_checked!(); +} + +impl<'a> FromSql<'a> for Timestamp { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let v = types::timestamp_from_sql(raw)?; + Span::new() + .try_microseconds(v) + .and_then(|s| base_ts().checked_add(s)) + .map_err(decode_err) + } + + accepts!(TIMESTAMPTZ); +} + +impl ToSql for Timestamp { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + let v = self + .since(base_ts()) + .and_then(|s| s.round(round_us())) + .map_err(transmit_err)? + .get_microseconds(); + types::timestamp_to_sql(v, w); + Ok(IsNull::No) + } + + accepts!(TIMESTAMPTZ); + to_sql_checked!(); +} + +impl<'a> FromSql<'a> for Date { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let v = types::date_from_sql(raw)?; + Span::new() + .try_days(v) + .and_then(|s| base().date().checked_add(s)) + .map_err(decode_err) + } + accepts!(DATE); +} + +impl ToSql for Date { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + let v = self.since(base().date()).map_err(transmit_err)?.get_days(); + types::date_to_sql(v, w); + Ok(IsNull::No) + } + + accepts!(DATE); + to_sql_checked!(); +} + +impl<'a> FromSql<'a> for Time { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let v = types::time_from_sql(raw)?; + Span::new() + .try_microseconds(v) + .and_then(|s| Time::midnight().checked_add(s)) + .map_err(decode_err) + } + + accepts!(TIME); +} + +impl ToSql for Time { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + let v = self + .since(Time::midnight()) + .and_then(|s| s.round(round_us())) + .map_err(transmit_err)? + .get_microseconds(); + types::time_to_sql(v, w); + Ok(IsNull::No) + } + + accepts!(TIME); + to_sql_checked!(); +} diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs new file mode 100644 index 000000000..51137b6b4 --- /dev/null +++ b/postgres-types/src/lib.rs @@ -0,0 +1,1307 @@ +//! Conversions to and from Postgres types. +//! +//! This crate is used by the `tokio-postgres` and `postgres` crates. You normally don't need to depend directly on it +//! unless you want to define your own `ToSql` or `FromSql` definitions. +//! +//! # Derive +//! +//! If the `derive` cargo feature is enabled, you can derive `ToSql` and `FromSql` implementations for custom Postgres +//! types. Explicitly, modify your `Cargo.toml` file to include the following: +//! +//! ```toml +//! [dependencies] +//! postgres-types = { version = "0.X.X", features = ["derive"] } +//! ``` +//! +//! ## Enums +//! +//! Postgres enums correspond to C-like enums in Rust: +//! +//! ```sql +//! CREATE TYPE "Mood" AS ENUM ( +//! 'Sad', +//! 'Ok', +//! 'Happy' +//! ); +//! ``` +//! +//! ```rust +//! # #[cfg(feature = "derive")] +//! use postgres_types::{ToSql, FromSql}; +//! +//! # #[cfg(feature = "derive")] +//! #[derive(Debug, ToSql, FromSql)] +//! enum Mood { +//! Sad, +//! Ok, +//! Happy, +//! } +//! ``` +//! +//! ## Domains +//! +//! Postgres domains correspond to tuple structs with one member in Rust: +//! +//! ```sql +//! CREATE DOMAIN "SessionId" AS BYTEA CHECK(octet_length(VALUE) = 16); +//! ``` +//! +//! ```rust +//! # #[cfg(feature = "derive")] +//! use postgres_types::{ToSql, FromSql}; +//! +//! # #[cfg(feature = "derive")] +//! #[derive(Debug, ToSql, FromSql)] +//! struct SessionId(Vec); +//! ``` +//! +//! ## Newtypes +//! +//! The `#[postgres(transparent)]` attribute can be used on a single-field tuple struct to create a +//! Rust-only wrapper type that will use the [`ToSql`] & [`FromSql`] implementation of the inner +//! value : +//! ```rust +//! # #[cfg(feature = "derive")] +//! use postgres_types::{ToSql, FromSql}; +//! +//! # #[cfg(feature = "derive")] +//! #[derive(Debug, ToSql, FromSql)] +//! #[postgres(transparent)] +//! struct UserId(i32); +//! ``` +//! +//! ## Composites +//! +//! Postgres composite types correspond to structs in Rust: +//! +//! ```sql +//! CREATE TYPE "InventoryItem" AS ( +//! name TEXT, +//! supplier_id INT, +//! price DOUBLE PRECISION +//! ); +//! ``` +//! +//! ```rust +//! # #[cfg(feature = "derive")] +//! use postgres_types::{ToSql, FromSql}; +//! +//! # #[cfg(feature = "derive")] +//! #[derive(Debug, ToSql, FromSql)] +//! struct InventoryItem { +//! name: String, +//! supplier_id: i32, +//! price: Option, +//! } +//! ``` +//! +//! ## Naming +//! +//! The derived implementations will enforce exact matches of type, field, and variant names between the Rust and +//! Postgres types. The `#[postgres(name = "...")]` attribute can be used to adjust the name on a type, variant, or +//! field: +//! +//! ```sql +//! CREATE TYPE mood AS ENUM ( +//! 'sad', +//! 'ok', +//! 'happy' +//! ); +//! ``` +//! +//! ```rust +//! # #[cfg(feature = "derive")] +//! use postgres_types::{ToSql, FromSql}; +//! +//! # #[cfg(feature = "derive")] +//! #[derive(Debug, ToSql, FromSql)] +//! #[postgres(name = "mood")] +//! enum Mood { +//! #[postgres(name = "sad")] +//! Sad, +//! #[postgres(name = "ok")] +//! Ok, +//! #[postgres(name = "happy")] +//! Happy, +//! } +//! ``` +//! +//! Alternatively, the `#[postgres(rename_all = "...")]` attribute can be used to rename all fields or variants +//! with the chosen casing convention. This will not affect the struct or enum's type name. Note that +//! `#[postgres(name = "...")]` takes precendence when used in conjunction with `#[postgres(rename_all = "...")]`: +//! +//! ```rust +//! # #[cfg(feature = "derive")] +//! use postgres_types::{ToSql, FromSql}; +//! +//! # #[cfg(feature = "derive")] +//! #[derive(Debug, ToSql, FromSql)] +//! #[postgres(name = "mood", rename_all = "snake_case")] +//! enum Mood { +//! #[postgres(name = "ok")] +//! Ok, // ok +//! VeryHappy, // very_happy +//! } +//! ``` +//! +//! The following case conventions are supported: +//! - `"lowercase"` +//! - `"UPPERCASE"` +//! - `"PascalCase"` +//! - `"camelCase"` +//! - `"snake_case"` +//! - `"SCREAMING_SNAKE_CASE"` +//! - `"kebab-case"` +//! - `"SCREAMING-KEBAB-CASE"` +//! - `"Train-Case"` +//! +//! ## Allowing Enum Mismatches +//! +//! By default the generated implementation of [`ToSql`] & [`FromSql`] for enums will require an exact match of the enum +//! variants between the Rust and Postgres types. +//! To allow mismatches, the `#[postgres(allow_mismatch)]` attribute can be used on the enum definition: +//! +//! ```sql +//! CREATE TYPE mood AS ENUM ( +//! 'Sad', +//! 'Ok', +//! 'Happy' +//! ); +//! ``` +//! +//! ```rust +//! # #[cfg(feature = "derive")] +//! use postgres_types::{ToSql, FromSql}; +//! +//! # #[cfg(feature = "derive")] +//! #[derive(Debug, ToSql, FromSql)] +//! #[postgres(allow_mismatch)] +//! enum Mood { +//! Happy, +//! Meh, +//! } +//! ``` +#![warn(clippy::all, rust_2018_idioms, missing_docs)] +use fallible_iterator::FallibleIterator; +use postgres_protocol::types::{self, ArrayDimension}; +use std::any::type_name; +use std::borrow::Cow; +use std::collections::HashMap; +use std::error::Error; +use std::fmt; +use std::hash::BuildHasher; +use std::net::IpAddr; +use std::sync::Arc; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +#[cfg(feature = "derive")] +pub use postgres_derive::{FromSql, ToSql}; + +#[cfg(feature = "with-serde_json-1")] +pub use crate::serde_json_1::Json; +use crate::type_gen::{Inner, Other}; + +#[doc(inline)] +pub use postgres_protocol::Oid; + +#[doc(inline)] +pub use pg_lsn::PgLsn; + +pub use crate::special::{Date, Timestamp}; +use bytes::BytesMut; + +// Number of seconds from 1970-01-01 to 2000-01-01 +const TIME_SEC_CONVERSION: u64 = 946_684_800; +const USEC_PER_SEC: u64 = 1_000_000; +const NSEC_PER_USEC: u64 = 1_000; + +/// Generates a simple implementation of `ToSql::accepts` which accepts the +/// types passed to it. +#[macro_export] +macro_rules! accepts { + ($($expected:ident),+) => ( + fn accepts(ty: &$crate::Type) -> bool { + matches!(*ty, $($crate::Type::$expected)|+) + } + ) +} + +/// Generates an implementation of `ToSql::to_sql_checked`. +/// +/// All `ToSql` implementations should use this macro. +#[macro_export] +macro_rules! to_sql_checked { + () => { + fn to_sql_checked( + &self, + ty: &$crate::Type, + out: &mut $crate::private::BytesMut, + ) -> ::std::result::Result< + $crate::IsNull, + Box, + > { + $crate::__to_sql_checked(self, ty, out) + } + }; +} + +// WARNING: this function is not considered part of this crate's public API. +// It is subject to change at any time. +#[doc(hidden)] +pub fn __to_sql_checked( + v: &T, + ty: &Type, + out: &mut BytesMut, +) -> Result> +where + T: ToSql, +{ + if !T::accepts(ty) { + return Err(Box::new(WrongType::new::(ty.clone()))); + } + v.to_sql(ty, out) +} + +#[cfg(feature = "with-bit-vec-0_6")] +mod bit_vec_06; +#[cfg(feature = "with-chrono-0_4")] +mod chrono_04; +#[cfg(feature = "with-cidr-0_2")] +mod cidr_02; +#[cfg(feature = "with-cidr-0_3")] +mod cidr_03; +#[cfg(feature = "with-eui48-0_4")] +mod eui48_04; +#[cfg(feature = "with-eui48-1")] +mod eui48_1; +#[cfg(feature = "with-geo-types-0_6")] +mod geo_types_06; +#[cfg(feature = "with-geo-types-0_7")] +mod geo_types_07; +#[cfg(feature = "with-jiff-0_1")] +mod jiff_01; +#[cfg(feature = "with-jiff-0_2")] +mod jiff_02; +#[cfg(feature = "with-serde_json-1")] +mod serde_json_1; +#[cfg(feature = "with-smol_str-01")] +mod smol_str_01; +#[cfg(feature = "with-time-0_2")] +mod time_02; +#[cfg(feature = "with-time-0_3")] +mod time_03; +#[cfg(feature = "with-uuid-0_8")] +mod uuid_08; +#[cfg(feature = "with-uuid-1")] +mod uuid_1; + +// The time::{date, time} macros produce compile errors if the crate package is renamed. +#[cfg(feature = "with-time-0_2")] +extern crate time_02 as time; + +mod pg_lsn; +#[doc(hidden)] +pub mod private; +mod special; +mod type_gen; + +/// A Postgres type. +#[derive(PartialEq, Eq, Clone, Hash)] +pub struct Type(Inner); + +impl fmt::Debug for Type { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&self.0, fmt) + } +} + +impl fmt::Display for Type { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.schema() { + "public" | "pg_catalog" => {} + schema => write!(fmt, "{}.", schema)?, + } + fmt.write_str(self.name()) + } +} + +impl Type { + /// Creates a new `Type`. + pub fn new(name: String, oid: Oid, kind: Kind, schema: String) -> Type { + Type(Inner::Other(Arc::new(Other { + name, + oid, + kind, + schema, + }))) + } + + /// Returns the `Type` corresponding to the provided `Oid` if it + /// corresponds to a built-in type. + pub fn from_oid(oid: Oid) -> Option { + Inner::from_oid(oid).map(Type) + } + + /// Returns the OID of the `Type`. + pub fn oid(&self) -> Oid { + self.0.oid() + } + + /// Returns the kind of this type. + pub fn kind(&self) -> &Kind { + self.0.kind() + } + + /// Returns the schema of this type. + pub fn schema(&self) -> &str { + match self.0 { + Inner::Other(ref u) => &u.schema, + _ => "pg_catalog", + } + } + + /// Returns the name of this type. + pub fn name(&self) -> &str { + self.0.name() + } +} + +/// Represents the kind of a Postgres type. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[non_exhaustive] +pub enum Kind { + /// A simple type like `VARCHAR` or `INTEGER`. + Simple, + /// An enumerated type along with its variants. + Enum(Vec), + /// A pseudo-type. + Pseudo, + /// An array type along with the type of its elements. + Array(Type), + /// A range type along with the type of its elements. + Range(Type), + /// A multirange type along with the type of its elements. + Multirange(Type), + /// A domain type along with its underlying type. + Domain(Type), + /// A composite type along with information about its fields. + Composite(Vec), +} + +/// Information about a field of a composite type. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Field { + name: String, + type_: Type, +} + +impl Field { + /// Creates a new `Field`. + pub fn new(name: String, type_: Type) -> Field { + Field { name, type_ } + } + + /// Returns the name of the field. + pub fn name(&self) -> &str { + &self.name + } + + /// Returns the type of the field. + pub fn type_(&self) -> &Type { + &self.type_ + } +} + +/// An error indicating that a `NULL` Postgres value was passed to a `FromSql` +/// implementation that does not support `NULL` values. +#[derive(Debug, Clone, Copy)] +pub struct WasNull; + +impl fmt::Display for WasNull { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.write_str("a Postgres value was `NULL`") + } +} + +impl Error for WasNull {} + +/// An error indicating that a conversion was attempted between incompatible +/// Rust and Postgres types. +#[derive(Debug)] +pub struct WrongType { + postgres: Type, + rust: &'static str, +} + +impl fmt::Display for WrongType { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + fmt, + "cannot convert between the Rust type `{}` and the Postgres type `{}`", + self.rust, self.postgres, + ) + } +} + +impl Error for WrongType {} + +impl WrongType { + /// Creates a new `WrongType` error. + pub fn new(ty: Type) -> WrongType { + WrongType { + postgres: ty, + rust: type_name::(), + } + } +} + +/// A trait for types that can be created from a Postgres value. +/// +/// # Types +/// +/// The following implementations are provided by this crate, along with the +/// corresponding Postgres types: +/// +/// | Rust type | Postgres type(s) | +/// |-----------------------------------|-----------------------------------------------| +/// | `bool` | BOOL | +/// | `i8` | "char" | +/// | `i16` | SMALLINT, SMALLSERIAL | +/// | `i32` | INT, SERIAL | +/// | `u32` | OID | +/// | `i64` | BIGINT, BIGSERIAL | +/// | `f32` | REAL | +/// | `f64` | DOUBLE PRECISION | +/// | `&str`/`String` | VARCHAR, CHAR(n), TEXT, CITEXT, NAME, UNKNOWN | +/// | | LTREE, LQUERY, LTXTQUERY | +/// | `&[u8]`/`Vec` | BYTEA | +/// | `HashMap>` | HSTORE | +/// | `SystemTime` | TIMESTAMP, TIMESTAMP WITH TIME ZONE | +/// | `IpAddr` | INET | +/// +/// In addition, some implementations are provided for types in third party +/// crates. These are disabled by default; to opt into one of these +/// implementations, activate the Cargo feature corresponding to the crate's +/// name prefixed by `with-`. For example, the `with-serde_json-1` feature enables +/// the implementation for the `serde_json::Value` type. +/// +/// | Rust type | Postgres type(s) | +/// |---------------------------------|-------------------------------------| +/// | `chrono::NaiveDateTime` | TIMESTAMP | +/// | `chrono::DateTime` | TIMESTAMP WITH TIME ZONE | +/// | `chrono::DateTime` | TIMESTAMP WITH TIME ZONE | +/// | `chrono::DateTime` | TIMESTAMP WITH TIME ZONE | +/// | `chrono::NaiveDate` | DATE | +/// | `chrono::NaiveTime` | TIME | +/// | `cidr::IpCidr` | CIDR | +/// | `cidr::IpInet` | INET | +/// | `time::PrimitiveDateTime` | TIMESTAMP | +/// | `time::OffsetDateTime` | TIMESTAMP WITH TIME ZONE | +/// | `time::Date` | DATE | +/// | `time::Time` | TIME | +/// | `jiff::civil::Date` | DATE | +/// | `jiff::civil::DateTime` | TIMESTAMP | +/// | `jiff::civil::Time` | TIME | +/// | `jiff::Timestamp` | TIMESTAMP WITH TIME ZONE | +/// | `eui48::MacAddress` | MACADDR | +/// | `geo_types::Point` | POINT | +/// | `geo_types::Rect` | BOX | +/// | `geo_types::LineString` | PATH | +/// | `serde_json::Value` | JSON, JSONB | +/// | `uuid::Uuid` | UUID | +/// | `bit_vec::BitVec` | BIT, VARBIT | +/// | `eui48::MacAddress` | MACADDR | +/// | `cidr::InetCidr` | CIDR | +/// | `cidr::InetAddr` | INET | +/// | `smol_str::SmolStr` | VARCHAR, CHAR(n), TEXT, CITEXT, | +/// | | NAME, UNKNOWN, LTREE, LQUERY, | +/// | | LTXTQUERY | +/// +/// # Nullability +/// +/// In addition to the types listed above, `FromSql` is implemented for +/// `Option` where `T` implements `FromSql`. An `Option` represents a +/// nullable Postgres value. +/// +/// # Arrays +/// +/// `FromSql` is implemented for `Vec`, `Box<[T]>` and `[T; N]` where `T` +/// implements `FromSql`, and corresponds to one-dimensional Postgres arrays. +/// +/// **Note:** the impl for arrays only exist when the Cargo feature `array-impls` +/// is enabled. +pub trait FromSql<'a>: Sized { + /// Creates a new value of this type from a buffer of data of the specified + /// Postgres `Type` in its binary format. + /// + /// The caller of this method is responsible for ensuring that this type + /// is compatible with the Postgres `Type`. + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result>; + + /// Creates a new value of this type from a `NULL` SQL value. + /// + /// The caller of this method is responsible for ensuring that this type + /// is compatible with the Postgres `Type`. + /// + /// The default implementation returns `Err(Box::new(WasNull))`. + #[allow(unused_variables)] + fn from_sql_null(ty: &Type) -> Result> { + Err(Box::new(WasNull)) + } + + /// A convenience function that delegates to `from_sql` and `from_sql_null` depending on the + /// value of `raw`. + fn from_sql_nullable( + ty: &Type, + raw: Option<&'a [u8]>, + ) -> Result> { + match raw { + Some(raw) => Self::from_sql(ty, raw), + None => Self::from_sql_null(ty), + } + } + + /// Determines if a value of this type can be created from the specified + /// Postgres `Type`. + fn accepts(ty: &Type) -> bool; +} + +/// A trait for types which can be created from a Postgres value without borrowing any data. +/// +/// This is primarily useful for trait bounds on functions. +pub trait FromSqlOwned: for<'a> FromSql<'a> {} + +impl FromSqlOwned for T where T: for<'a> FromSql<'a> {} + +impl<'a, T: FromSql<'a>> FromSql<'a> for Option { + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result, Box> { + ::from_sql(ty, raw).map(Some) + } + + fn from_sql_null(_: &Type) -> Result, Box> { + Ok(None) + } + + fn accepts(ty: &Type) -> bool { + ::accepts(ty) + } +} + +impl<'a, T: FromSql<'a>> FromSql<'a> for Vec { + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result, Box> { + let member_type = match *ty.kind() { + Kind::Array(ref member) => member, + _ => panic!("expected array type"), + }; + + let array = types::array_from_sql(raw)?; + if array.dimensions().count()? > 1 { + return Err("array contains too many dimensions".into()); + } + + array + .values() + .map(|v| T::from_sql_nullable(member_type, v)) + .collect() + } + + fn accepts(ty: &Type) -> bool { + match *ty.kind() { + Kind::Array(ref inner) => T::accepts(inner), + _ => false, + } + } +} + +#[cfg(feature = "array-impls")] +impl<'a, T: FromSql<'a>, const N: usize> FromSql<'a> for [T; N] { + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result> { + let member_type = match *ty.kind() { + Kind::Array(ref member) => member, + _ => panic!("expected array type"), + }; + + let array = types::array_from_sql(raw)?; + if array.dimensions().count()? > 1 { + return Err("array contains too many dimensions".into()); + } + + let mut values = array.values(); + let out = array_init::try_array_init(|i| { + let v = values + .next()? + .ok_or_else(|| -> Box { + format!("too few elements in array (expected {}, got {})", N, i).into() + })?; + T::from_sql_nullable(member_type, v) + })?; + if values.next()?.is_some() { + return Err(format!( + "excess elements in array (expected {}, got more than that)", + N, + ) + .into()); + } + + Ok(out) + } + + fn accepts(ty: &Type) -> bool { + match *ty.kind() { + Kind::Array(ref inner) => T::accepts(inner), + _ => false, + } + } +} + +impl<'a, T: FromSql<'a>> FromSql<'a> for Box<[T]> { + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result> { + Vec::::from_sql(ty, raw).map(Vec::into_boxed_slice) + } + + fn accepts(ty: &Type) -> bool { + Vec::::accepts(ty) + } +} + +impl<'a> FromSql<'a> for Vec { + fn from_sql(_: &Type, raw: &'a [u8]) -> Result, Box> { + Ok(types::bytea_from_sql(raw).to_owned()) + } + + accepts!(BYTEA); +} + +impl<'a> FromSql<'a> for &'a [u8] { + fn from_sql(_: &Type, raw: &'a [u8]) -> Result<&'a [u8], Box> { + Ok(types::bytea_from_sql(raw)) + } + + accepts!(BYTEA); +} + +impl<'a> FromSql<'a> for String { + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result> { + <&str as FromSql>::from_sql(ty, raw).map(ToString::to_string) + } + + fn accepts(ty: &Type) -> bool { + <&str as FromSql>::accepts(ty) + } +} + +impl<'a> FromSql<'a> for Box { + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result, Box> { + <&str as FromSql>::from_sql(ty, raw) + .map(ToString::to_string) + .map(String::into_boxed_str) + } + + fn accepts(ty: &Type) -> bool { + <&str as FromSql>::accepts(ty) + } +} + +impl<'a> FromSql<'a> for &'a str { + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<&'a str, Box> { + match *ty { + ref ty if ty.name() == "ltree" => types::ltree_from_sql(raw), + ref ty if ty.name() == "lquery" => types::lquery_from_sql(raw), + ref ty if ty.name() == "ltxtquery" => types::ltxtquery_from_sql(raw), + _ => types::text_from_sql(raw), + } + } + + fn accepts(ty: &Type) -> bool { + match *ty { + Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true, + ref ty + if (ty.name() == "citext" + || ty.name() == "ltree" + || ty.name() == "lquery" + || ty.name() == "ltxtquery") => + { + true + } + _ => false, + } + } +} + +macro_rules! simple_from { + ($t:ty, $f:ident, $($expected:ident),+) => { + impl<'a> FromSql<'a> for $t { + fn from_sql(_: &Type, raw: &'a [u8]) -> Result<$t, Box> { + types::$f(raw) + } + + accepts!($($expected),+); + } + } +} + +simple_from!(bool, bool_from_sql, BOOL); +simple_from!(i8, char_from_sql, CHAR); +simple_from!(i16, int2_from_sql, INT2); +simple_from!(i32, int4_from_sql, INT4); +simple_from!(u32, oid_from_sql, OID); +simple_from!(i64, int8_from_sql, INT8); +simple_from!(f32, float4_from_sql, FLOAT4); +simple_from!(f64, float8_from_sql, FLOAT8); + +impl<'a, S> FromSql<'a> for HashMap, S> +where + S: Default + BuildHasher, +{ + fn from_sql( + _: &Type, + raw: &'a [u8], + ) -> Result, S>, Box> { + types::hstore_from_sql(raw)? + .map(|(k, v)| Ok((k.to_owned(), v.map(str::to_owned)))) + .collect() + } + + fn accepts(ty: &Type) -> bool { + ty.name() == "hstore" + } +} + +impl<'a> FromSql<'a> for SystemTime { + fn from_sql(_: &Type, raw: &'a [u8]) -> Result> { + let time = types::timestamp_from_sql(raw)?; + let epoch = UNIX_EPOCH + Duration::from_secs(TIME_SEC_CONVERSION); + + let negative = time < 0; + let time = time.unsigned_abs(); + + let secs = time / USEC_PER_SEC; + let nsec = (time % USEC_PER_SEC) * NSEC_PER_USEC; + let offset = Duration::new(secs, nsec as u32); + + let time = if negative { + epoch - offset + } else { + epoch + offset + }; + + Ok(time) + } + + accepts!(TIMESTAMP, TIMESTAMPTZ); +} + +impl<'a> FromSql<'a> for IpAddr { + fn from_sql(_: &Type, raw: &'a [u8]) -> Result> { + let inet = types::inet_from_sql(raw)?; + Ok(inet.addr()) + } + + accepts!(INET); +} + +/// An enum representing the nullability of a Postgres value. +pub enum IsNull { + /// The value is NULL. + Yes, + /// The value is not NULL. + No, +} + +/// A trait for types that can be converted into Postgres values. +/// +/// # Types +/// +/// The following implementations are provided by this crate, along with the +/// corresponding Postgres types: +/// +/// | Rust type | Postgres type(s) | +/// |-----------------------------------|--------------------------------------| +/// | `bool` | BOOL | +/// | `i8` | "char" | +/// | `i16` | SMALLINT, SMALLSERIAL | +/// | `i32` | INT, SERIAL | +/// | `u32` | OID | +/// | `i64` | BIGINT, BIGSERIAL | +/// | `f32` | REAL | +/// | `f64` | DOUBLE PRECISION | +/// | `&str`/`String` | VARCHAR, CHAR(n), TEXT, CITEXT, NAME | +/// | | LTREE, LQUERY, LTXTQUERY | +/// | `&[u8]`/`Vec`/`[u8; N]` | BYTEA | +/// | `HashMap>` | HSTORE | +/// | `SystemTime` | TIMESTAMP, TIMESTAMP WITH TIME ZONE | +/// | `IpAddr` | INET | +/// +/// In addition, some implementations are provided for types in third party +/// crates. These are disabled by default; to opt into one of these +/// implementations, activate the Cargo feature corresponding to the crate's +/// name prefixed by `with-`. For example, the `with-serde_json-1` feature enables +/// the implementation for the `serde_json::Value` type. +/// +/// | Rust type | Postgres type(s) | +/// |---------------------------------|-------------------------------------| +/// | `chrono::NaiveDateTime` | TIMESTAMP | +/// | `chrono::DateTime` | TIMESTAMP WITH TIME ZONE | +/// | `chrono::DateTime` | TIMESTAMP WITH TIME ZONE | +/// | `chrono::DateTime` | TIMESTAMP WITH TIME ZONE | +/// | `chrono::NaiveDate` | DATE | +/// | `chrono::NaiveTime` | TIME | +/// | `cidr::IpCidr` | CIDR | +/// | `cidr::IpInet` | INET | +/// | `time::PrimitiveDateTime` | TIMESTAMP | +/// | `time::OffsetDateTime` | TIMESTAMP WITH TIME ZONE | +/// | `time::Date` | DATE | +/// | `time::Time` | TIME | +/// | `eui48::MacAddress` | MACADDR | +/// | `geo_types::Point` | POINT | +/// | `geo_types::Rect` | BOX | +/// | `geo_types::LineString` | PATH | +/// | `serde_json::Value` | JSON, JSONB | +/// | `uuid::Uuid` | UUID | +/// | `bit_vec::BitVec` | BIT, VARBIT | +/// | `eui48::MacAddress` | MACADDR | +/// +/// # Nullability +/// +/// In addition to the types listed above, `ToSql` is implemented for +/// `Option` where `T` implements `ToSql`. An `Option` represents a +/// nullable Postgres value. +/// +/// # Arrays +/// +/// `ToSql` is implemented for `[u8; N]`, `Vec`, `&[T]`, `Box<[T]>` and `[T; N]` +/// where `T` implements `ToSql` and `N` is const usize, and corresponds to one-dimensional +/// Postgres arrays with an index offset of 1. +/// +/// **Note:** the impl for arrays only exist when the Cargo feature `array-impls` +/// is enabled. +pub trait ToSql: fmt::Debug { + /// Converts the value of `self` into the binary format of the specified + /// Postgres `Type`, appending it to `out`. + /// + /// The caller of this method is responsible for ensuring that this type + /// is compatible with the Postgres `Type`. + /// + /// The return value indicates if this value should be represented as + /// `NULL`. If this is the case, implementations **must not** write + /// anything to `out`. + fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> Result> + where + Self: Sized; + + /// Determines if a value of this type can be converted to the specified + /// Postgres `Type`. + fn accepts(ty: &Type) -> bool + where + Self: Sized; + + /// An adaptor method used internally by Rust-Postgres. + /// + /// *All* implementations of this method should be generated by the + /// `to_sql_checked!()` macro. + fn to_sql_checked( + &self, + ty: &Type, + out: &mut BytesMut, + ) -> Result>; + + /// Specify the encode format + fn encode_format(&self, _ty: &Type) -> Format { + Format::Binary + } +} + +/// Supported Postgres message format types +/// +/// Using Text format in a message assumes a Postgres `SERVER_ENCODING` of `UTF8` +#[derive(Clone, Copy, Debug)] +pub enum Format { + /// Text format (UTF-8) + Text, + /// Compact, typed binary format + Binary, +} + +impl ToSql for &T +where + T: ToSql, +{ + fn to_sql( + &self, + ty: &Type, + out: &mut BytesMut, + ) -> Result> { + (*self).to_sql(ty, out) + } + + fn accepts(ty: &Type) -> bool { + T::accepts(ty) + } + + fn encode_format(&self, ty: &Type) -> Format { + (*self).encode_format(ty) + } + + to_sql_checked!(); +} + +impl ToSql for Option { + fn to_sql( + &self, + ty: &Type, + out: &mut BytesMut, + ) -> Result> { + match *self { + Some(ref val) => val.to_sql(ty, out), + None => Ok(IsNull::Yes), + } + } + + fn accepts(ty: &Type) -> bool { + ::accepts(ty) + } + + fn encode_format(&self, ty: &Type) -> Format { + match self { + Some(val) => val.encode_format(ty), + None => Format::Binary, + } + } + + to_sql_checked!(); +} + +impl ToSql for &[T] { + fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { + let member_type = match *ty.kind() { + Kind::Array(ref member) => member, + _ => panic!("expected array type"), + }; + + // Arrays are normally one indexed by default but oidvector and int2vector *require* zero indexing + let lower_bound = match *ty { + Type::OID_VECTOR | Type::INT2_VECTOR => 0, + _ => 1, + }; + + let dimension = ArrayDimension { + len: downcast(self.len())?, + lower_bound, + }; + + types::array_to_sql( + Some(dimension), + member_type.oid(), + self.iter(), + |e, w| match e.to_sql(member_type, w)? { + IsNull::No => Ok(postgres_protocol::IsNull::No), + IsNull::Yes => Ok(postgres_protocol::IsNull::Yes), + }, + w, + )?; + Ok(IsNull::No) + } + + fn accepts(ty: &Type) -> bool { + match *ty.kind() { + Kind::Array(ref member) => T::accepts(member), + _ => false, + } + } + + to_sql_checked!(); +} + +impl ToSql for &[u8] { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + types::bytea_to_sql(self, w); + Ok(IsNull::No) + } + + accepts!(BYTEA); + + to_sql_checked!(); +} + +#[cfg(feature = "array-impls")] +impl ToSql for [u8; N] { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + types::bytea_to_sql(&self[..], w); + Ok(IsNull::No) + } + + accepts!(BYTEA); + + to_sql_checked!(); +} + +#[cfg(feature = "array-impls")] +impl ToSql for [T; N] { + fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { + <&[T] as ToSql>::to_sql(&&self[..], ty, w) + } + + fn accepts(ty: &Type) -> bool { + <&[T] as ToSql>::accepts(ty) + } + + to_sql_checked!(); +} + +impl ToSql for Vec { + fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { + <&[T] as ToSql>::to_sql(&&**self, ty, w) + } + + fn accepts(ty: &Type) -> bool { + <&[T] as ToSql>::accepts(ty) + } + + to_sql_checked!(); +} + +impl ToSql for Box<[T]> { + fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { + <&[T] as ToSql>::to_sql(&&**self, ty, w) + } + + fn accepts(ty: &Type) -> bool { + <&[T] as ToSql>::accepts(ty) + } + + to_sql_checked!(); +} + +impl ToSql for Cow<'_, [u8]> { + fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { + <&[u8] as ToSql>::to_sql(&self.as_ref(), ty, w) + } + + fn accepts(ty: &Type) -> bool { + <&[u8] as ToSql>::accepts(ty) + } + + to_sql_checked!(); +} + +impl ToSql for Vec { + fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { + <&[u8] as ToSql>::to_sql(&&**self, ty, w) + } + + fn accepts(ty: &Type) -> bool { + <&[u8] as ToSql>::accepts(ty) + } + + to_sql_checked!(); +} + +impl ToSql for &str { + fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { + match ty.name() { + "ltree" => types::ltree_to_sql(self, w), + "lquery" => types::lquery_to_sql(self, w), + "ltxtquery" => types::ltxtquery_to_sql(self, w), + _ => types::text_to_sql(self, w), + } + Ok(IsNull::No) + } + + fn accepts(ty: &Type) -> bool { + matches!( + *ty, + Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN + ) || matches!(ty.name(), "citext" | "ltree" | "lquery" | "ltxtquery") + } + + to_sql_checked!(); +} + +impl ToSql for Cow<'_, str> { + fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { + <&str as ToSql>::to_sql(&self.as_ref(), ty, w) + } + + fn accepts(ty: &Type) -> bool { + <&str as ToSql>::accepts(ty) + } + + to_sql_checked!(); +} + +impl ToSql for String { + fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { + <&str as ToSql>::to_sql(&&**self, ty, w) + } + + fn accepts(ty: &Type) -> bool { + <&str as ToSql>::accepts(ty) + } + + to_sql_checked!(); +} + +impl ToSql for Box { + fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { + <&str as ToSql>::to_sql(&&**self, ty, w) + } + + fn accepts(ty: &Type) -> bool { + <&str as ToSql>::accepts(ty) + } + + to_sql_checked!(); +} + +macro_rules! simple_to { + ($t:ty, $f:ident, $($expected:ident),+) => { + impl ToSql for $t { + fn to_sql(&self, + _: &Type, + w: &mut BytesMut) + -> Result> { + types::$f(*self, w); + Ok(IsNull::No) + } + + accepts!($($expected),+); + + to_sql_checked!(); + } + } +} + +simple_to!(bool, bool_to_sql, BOOL); +simple_to!(i8, char_to_sql, CHAR); +simple_to!(i16, int2_to_sql, INT2); +simple_to!(i32, int4_to_sql, INT4); +simple_to!(u32, oid_to_sql, OID); +simple_to!(i64, int8_to_sql, INT8); +simple_to!(f32, float4_to_sql, FLOAT4); +simple_to!(f64, float8_to_sql, FLOAT8); + +impl ToSql for HashMap, H> +where + H: BuildHasher, +{ + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + types::hstore_to_sql( + self.iter().map(|(k, v)| (&**k, v.as_ref().map(|v| &**v))), + w, + )?; + Ok(IsNull::No) + } + + fn accepts(ty: &Type) -> bool { + ty.name() == "hstore" + } + + to_sql_checked!(); +} + +impl ToSql for SystemTime { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + let epoch = UNIX_EPOCH + Duration::from_secs(TIME_SEC_CONVERSION); + + let to_usec = + |d: Duration| d.as_secs() * USEC_PER_SEC + u64::from(d.subsec_nanos()) / NSEC_PER_USEC; + + let time = match self.duration_since(epoch) { + Ok(duration) => to_usec(duration) as i64, + Err(e) => -(to_usec(e.duration()) as i64), + }; + + types::timestamp_to_sql(time, w); + Ok(IsNull::No) + } + + accepts!(TIMESTAMP, TIMESTAMPTZ); + + to_sql_checked!(); +} + +impl ToSql for IpAddr { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + let netmask = match self { + IpAddr::V4(_) => 32, + IpAddr::V6(_) => 128, + }; + types::inet_to_sql(*self, netmask, w); + Ok(IsNull::No) + } + + accepts!(INET); + + to_sql_checked!(); +} + +fn downcast(len: usize) -> Result> { + if len > i32::MAX as usize { + Err("value too large to transmit".into()) + } else { + Ok(len as i32) + } +} + +mod sealed { + pub trait Sealed {} +} + +/// A trait used by clients to abstract over `&dyn ToSql` and `T: ToSql`. +/// +/// This cannot be implemented outside of this crate. +pub trait BorrowToSql: sealed::Sealed { + /// Returns a reference to `self` as a `ToSql` trait object. + fn borrow_to_sql(&self) -> &dyn ToSql; +} + +impl sealed::Sealed for &dyn ToSql {} + +impl BorrowToSql for &dyn ToSql { + #[inline] + fn borrow_to_sql(&self) -> &dyn ToSql { + *self + } +} + +impl sealed::Sealed for Box {} + +impl BorrowToSql for Box { + #[inline] + fn borrow_to_sql(&self) -> &dyn ToSql { + self.as_ref() + } +} + +impl sealed::Sealed for Box {} +impl BorrowToSql for Box { + #[inline] + fn borrow_to_sql(&self) -> &dyn ToSql { + self.as_ref() + } +} + +impl sealed::Sealed for &(dyn ToSql + Sync) {} + +/// In async contexts it is sometimes necessary to have the additional +/// Sync requirement on parameters for queries since this enables the +/// resulting Futures to be Send, hence usable in, e.g., tokio::spawn. +/// This instance is provided for those cases. +impl BorrowToSql for &(dyn ToSql + Sync) { + #[inline] + fn borrow_to_sql(&self) -> &dyn ToSql { + *self + } +} + +impl sealed::Sealed for T where T: ToSql {} + +impl BorrowToSql for T +where + T: ToSql, +{ + #[inline] + fn borrow_to_sql(&self) -> &dyn ToSql { + self + } +} diff --git a/postgres-types/src/pg_lsn.rs b/postgres-types/src/pg_lsn.rs new file mode 100644 index 000000000..f339f9689 --- /dev/null +++ b/postgres-types/src/pg_lsn.rs @@ -0,0 +1,77 @@ +//! Log Sequence Number (LSN) type for PostgreSQL Write-Ahead Log +//! (WAL), also known as the transaction log. + +use bytes::BytesMut; +use postgres_protocol::types; +use std::error::Error; +use std::fmt; +use std::str::FromStr; + +use crate::{FromSql, IsNull, ToSql, Type}; + +/// Postgres `PG_LSN` type. +#[derive(Clone, Copy, Eq, Ord, PartialEq, PartialOrd)] +pub struct PgLsn(u64); + +/// Error parsing LSN. +#[derive(Debug)] +pub struct ParseLsnError(()); + +impl From for PgLsn { + fn from(lsn_u64: u64) -> Self { + PgLsn(lsn_u64) + } +} + +impl From for u64 { + fn from(lsn: PgLsn) -> u64 { + lsn.0 + } +} + +impl FromStr for PgLsn { + type Err = ParseLsnError; + + fn from_str(lsn_str: &str) -> Result { + let Some((split_hi, split_lo)) = lsn_str.split_once('/') else { + return Err(ParseLsnError(())); + }; + let (hi, lo) = ( + u64::from_str_radix(split_hi, 16).map_err(|_| ParseLsnError(()))?, + u64::from_str_radix(split_lo, 16).map_err(|_| ParseLsnError(()))?, + ); + Ok(PgLsn((hi << 32) | lo)) + } +} + +impl fmt::Display for PgLsn { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:X}/{:X}", self.0 >> 32, self.0 & 0x00000000ffffffff) + } +} + +impl fmt::Debug for PgLsn { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_fmt(format_args!("{}", self)) + } +} + +impl<'a> FromSql<'a> for PgLsn { + fn from_sql(_: &Type, raw: &'a [u8]) -> Result> { + let v = types::lsn_from_sql(raw)?; + Ok(v.into()) + } + + accepts!(PG_LSN); +} + +impl ToSql for PgLsn { + fn to_sql(&self, _: &Type, out: &mut BytesMut) -> Result> { + types::lsn_to_sql((*self).into(), out); + Ok(IsNull::No) + } + + accepts!(PG_LSN); + + to_sql_checked!(); +} diff --git a/postgres-types/src/private.rs b/postgres-types/src/private.rs new file mode 100644 index 000000000..774f9a301 --- /dev/null +++ b/postgres-types/src/private.rs @@ -0,0 +1,34 @@ +use crate::{FromSql, Type}; +pub use bytes::BytesMut; +use std::error::Error; + +pub fn read_be_i32(buf: &mut &[u8]) -> Result> { + if buf.len() < 4 { + return Err("invalid buffer size".into()); + } + let mut bytes = [0; 4]; + bytes.copy_from_slice(&buf[..4]); + *buf = &buf[4..]; + Ok(i32::from_be_bytes(bytes)) +} + +pub fn read_value<'a, T>( + type_: &Type, + buf: &mut &'a [u8], +) -> Result> +where + T: FromSql<'a>, +{ + let len = read_be_i32(buf)?; + let value = if len < 0 { + None + } else { + if len as usize > buf.len() { + return Err("invalid buffer size".into()); + } + let (head, tail) = buf.split_at(len as usize); + *buf = tail; + Some(head) + }; + T::from_sql_nullable(type_, value) +} diff --git a/postgres-types/src/serde_json_1.rs b/postgres-types/src/serde_json_1.rs new file mode 100644 index 000000000..715c33f98 --- /dev/null +++ b/postgres-types/src/serde_json_1.rs @@ -0,0 +1,73 @@ +use crate::{FromSql, IsNull, ToSql, Type}; +use bytes::{BufMut, BytesMut}; +use serde_1::{Deserialize, Serialize}; +use serde_json_1::Value; +use std::error::Error; +use std::fmt::Debug; +use std::io::Read; + +/// A wrapper type to allow arbitrary `Serialize`/`Deserialize` types to convert to Postgres JSON values. +#[derive(Clone, Default, Debug, PartialEq, Eq)] +pub struct Json(pub T); + +impl<'a, T> FromSql<'a> for Json +where + T: Deserialize<'a>, +{ + fn from_sql(ty: &Type, mut raw: &'a [u8]) -> Result, Box> { + if *ty == Type::JSONB { + let mut b = [0; 1]; + raw.read_exact(&mut b)?; + // We only support version 1 of the jsonb binary format + if b[0] != 1 { + return Err("unsupported JSONB encoding version".into()); + } + } + serde_json_1::de::from_slice(raw) + .map(Json) + .map_err(Into::into) + } + + accepts!(JSON, JSONB); +} + +impl ToSql for Json +where + T: Serialize + Debug, +{ + fn to_sql( + &self, + ty: &Type, + out: &mut BytesMut, + ) -> Result> { + if *ty == Type::JSONB { + out.put_u8(1); + } + serde_json_1::ser::to_writer(out.writer(), &self.0)?; + Ok(IsNull::No) + } + + accepts!(JSON, JSONB); + to_sql_checked!(); +} + +impl<'a> FromSql<'a> for Value { + fn from_sql(ty: &Type, raw: &[u8]) -> Result> { + Json::::from_sql(ty, raw).map(|json| json.0) + } + + accepts!(JSON, JSONB); +} + +impl ToSql for Value { + fn to_sql( + &self, + ty: &Type, + out: &mut BytesMut, + ) -> Result> { + Json(self).to_sql(ty, out) + } + + accepts!(JSON, JSONB); + to_sql_checked!(); +} diff --git a/postgres-types/src/smol_str_01.rs b/postgres-types/src/smol_str_01.rs new file mode 100644 index 000000000..a0d024ce2 --- /dev/null +++ b/postgres-types/src/smol_str_01.rs @@ -0,0 +1,27 @@ +use bytes::BytesMut; +use smol_str_01::SmolStr; +use std::error::Error; + +use crate::{FromSql, IsNull, ToSql, Type}; + +impl<'a> FromSql<'a> for SmolStr { + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result> { + <&str as FromSql>::from_sql(ty, raw).map(SmolStr::from) + } + + fn accepts(ty: &Type) -> bool { + <&str as FromSql>::accepts(ty) + } +} + +impl ToSql for SmolStr { + fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { + <&str as ToSql>::to_sql(&&**self, ty, w) + } + + fn accepts(ty: &Type) -> bool { + <&str as ToSql>::accepts(ty) + } + + to_sql_checked!(); +} diff --git a/postgres-shared/src/types/special.rs b/postgres-types/src/special.rs similarity index 74% rename from postgres-shared/src/types/special.rs rename to postgres-types/src/special.rs index a0566319e..d8541bf0e 100644 --- a/postgres-shared/src/types/special.rs +++ b/postgres-types/src/special.rs @@ -1,11 +1,11 @@ +use bytes::BytesMut; use postgres_protocol::types; use std::error::Error; -use std::{i32, i64}; -use types::{FromSql, IsNull, ToSql, Type}; +use crate::{FromSql, IsNull, ToSql, Type}; /// A wrapper that can be used to represent infinity with `Type::Date` types. -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Date { /// Represents `infinity`, a date that is later than all other dates. PosInfinity, @@ -16,7 +16,7 @@ pub enum Date { } impl<'a, T: FromSql<'a>> FromSql<'a> for Date { - fn from_sql(ty: &Type, raw: &'a [u8]) -> Result> { + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result> { match types::date_from_sql(raw)? { i32::MAX => Ok(Date::PosInfinity), i32::MIN => Ok(Date::NegInfinity), @@ -30,7 +30,11 @@ impl<'a, T: FromSql<'a>> FromSql<'a> for Date { } impl ToSql for Date { - fn to_sql(&self, ty: &Type, out: &mut Vec) -> Result> { + fn to_sql( + &self, + ty: &Type, + out: &mut BytesMut, + ) -> Result> { let value = match *self { Date::PosInfinity => i32::MAX, Date::NegInfinity => i32::MIN, @@ -50,7 +54,7 @@ impl ToSql for Date { /// A wrapper that can be used to represent infinity with `Type::Timestamp` and `Type::Timestamptz` /// types. -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Timestamp { /// Represents `infinity`, a timestamp that is later than all other timestamps. PosInfinity, @@ -61,7 +65,7 @@ pub enum Timestamp { } impl<'a, T: FromSql<'a>> FromSql<'a> for Timestamp { - fn from_sql(ty: &Type, raw: &'a [u8]) -> Result> { + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result> { match types::timestamp_from_sql(raw)? { i64::MAX => Ok(Timestamp::PosInfinity), i64::MIN => Ok(Timestamp::NegInfinity), @@ -70,15 +74,16 @@ impl<'a, T: FromSql<'a>> FromSql<'a> for Timestamp { } fn accepts(ty: &Type) -> bool { - match *ty { - Type::TIMESTAMP | Type::TIMESTAMPTZ if T::accepts(ty) => true, - _ => false, - } + matches!(*ty, Type::TIMESTAMP | Type::TIMESTAMPTZ if T::accepts(ty)) } } impl ToSql for Timestamp { - fn to_sql(&self, ty: &Type, out: &mut Vec) -> Result> { + fn to_sql( + &self, + ty: &Type, + out: &mut BytesMut, + ) -> Result> { let value = match *self { Timestamp::PosInfinity => i64::MAX, Timestamp::NegInfinity => i64::MIN, @@ -90,10 +95,7 @@ impl ToSql for Timestamp { } fn accepts(ty: &Type) -> bool { - match *ty { - Type::TIMESTAMP | Type::TIMESTAMPTZ if T::accepts(ty) => true, - _ => false, - } + matches!(*ty, Type::TIMESTAMP | Type::TIMESTAMPTZ if T::accepts(ty)) } to_sql_checked!(); diff --git a/postgres-types/src/time_02.rs b/postgres-types/src/time_02.rs new file mode 100644 index 000000000..19a8909e7 --- /dev/null +++ b/postgres-types/src/time_02.rs @@ -0,0 +1,109 @@ +use bytes::BytesMut; +use postgres_protocol::types; +use std::convert::TryFrom; +use std::error::Error; +use time_02::{date, time, Date, Duration, OffsetDateTime, PrimitiveDateTime, Time, UtcOffset}; + +use crate::{FromSql, IsNull, ToSql, Type}; + +#[rustfmt::skip] +const fn base() -> PrimitiveDateTime { + PrimitiveDateTime::new(date!(2000-01-01), time!(00:00:00)) +} + +impl<'a> FromSql<'a> for PrimitiveDateTime { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let t = types::timestamp_from_sql(raw)?; + Ok(base() + Duration::microseconds(t)) + } + + accepts!(TIMESTAMP); +} + +impl ToSql for PrimitiveDateTime { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + let time = match i64::try_from((*self - base()).whole_microseconds()) { + Ok(time) => time, + Err(_) => return Err("value too large to transmit".into()), + }; + types::timestamp_to_sql(time, w); + Ok(IsNull::No) + } + + accepts!(TIMESTAMP); + to_sql_checked!(); +} + +impl<'a> FromSql<'a> for OffsetDateTime { + fn from_sql(type_: &Type, raw: &[u8]) -> Result> { + let primitive = PrimitiveDateTime::from_sql(type_, raw)?; + Ok(primitive.assume_utc()) + } + + accepts!(TIMESTAMPTZ); +} + +impl ToSql for OffsetDateTime { + fn to_sql( + &self, + type_: &Type, + w: &mut BytesMut, + ) -> Result> { + let utc_datetime = self.to_offset(UtcOffset::UTC); + let date = utc_datetime.date(); + let time = utc_datetime.time(); + let primitive = PrimitiveDateTime::new(date, time); + primitive.to_sql(type_, w) + } + + accepts!(TIMESTAMPTZ); + to_sql_checked!(); +} + +impl<'a> FromSql<'a> for Date { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let jd = types::date_from_sql(raw)?; + Ok(base().date() + Duration::days(i64::from(jd))) + } + + accepts!(DATE); +} + +impl ToSql for Date { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + let jd = (*self - base().date()).whole_days(); + if jd > i64::from(i32::max_value()) || jd < i64::from(i32::min_value()) { + return Err("value too large to transmit".into()); + } + + types::date_to_sql(jd as i32, w); + Ok(IsNull::No) + } + + accepts!(DATE); + to_sql_checked!(); +} + +impl<'a> FromSql<'a> for Time { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let usec = types::time_from_sql(raw)?; + Ok(time!(00:00:00) + Duration::microseconds(usec)) + } + + accepts!(TIME); +} + +impl ToSql for Time { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + let delta = *self - time!(00:00:00); + let time = match i64::try_from(delta.whole_microseconds()) { + Ok(time) => time, + Err(_) => return Err("value too large to transmit".into()), + }; + types::time_to_sql(time, w); + Ok(IsNull::No) + } + + accepts!(TIME); + to_sql_checked!(); +} diff --git a/postgres-types/src/time_03.rs b/postgres-types/src/time_03.rs new file mode 100644 index 000000000..4deea663f --- /dev/null +++ b/postgres-types/src/time_03.rs @@ -0,0 +1,113 @@ +use bytes::BytesMut; +use postgres_protocol::types; +use std::convert::TryFrom; +use std::error::Error; +use time_03::{Date, Duration, OffsetDateTime, PrimitiveDateTime, Time, UtcOffset}; + +use crate::{FromSql, IsNull, ToSql, Type}; + +fn base() -> PrimitiveDateTime { + PrimitiveDateTime::new(Date::from_ordinal_date(2000, 1).unwrap(), Time::MIDNIGHT) +} + +impl<'a> FromSql<'a> for PrimitiveDateTime { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let t = types::timestamp_from_sql(raw)?; + Ok(base() + .checked_add(Duration::microseconds(t)) + .ok_or("value too large to decode")?) + } + + accepts!(TIMESTAMP); +} + +impl ToSql for PrimitiveDateTime { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + let time = match i64::try_from((*self - base()).whole_microseconds()) { + Ok(time) => time, + Err(_) => return Err("value too large to transmit".into()), + }; + types::timestamp_to_sql(time, w); + Ok(IsNull::No) + } + + accepts!(TIMESTAMP); + to_sql_checked!(); +} + +impl<'a> FromSql<'a> for OffsetDateTime { + fn from_sql(type_: &Type, raw: &[u8]) -> Result> { + let primitive = PrimitiveDateTime::from_sql(type_, raw)?; + Ok(primitive.assume_utc()) + } + + accepts!(TIMESTAMPTZ); +} + +impl ToSql for OffsetDateTime { + fn to_sql( + &self, + type_: &Type, + w: &mut BytesMut, + ) -> Result> { + let utc_datetime = self.to_offset(UtcOffset::UTC); + let date = utc_datetime.date(); + let time = utc_datetime.time(); + let primitive = PrimitiveDateTime::new(date, time); + primitive.to_sql(type_, w) + } + + accepts!(TIMESTAMPTZ); + to_sql_checked!(); +} + +impl<'a> FromSql<'a> for Date { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let jd = types::date_from_sql(raw)?; + Ok(base() + .date() + .checked_add(Duration::days(i64::from(jd))) + .ok_or("value too large to decode")?) + } + + accepts!(DATE); +} + +impl ToSql for Date { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + let jd = (*self - base().date()).whole_days(); + if jd > i64::from(i32::max_value()) || jd < i64::from(i32::min_value()) { + return Err("value too large to transmit".into()); + } + + types::date_to_sql(jd as i32, w); + Ok(IsNull::No) + } + + accepts!(DATE); + to_sql_checked!(); +} + +impl<'a> FromSql<'a> for Time { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let usec = types::time_from_sql(raw)?; + Ok(Time::MIDNIGHT + Duration::microseconds(usec)) + } + + accepts!(TIME); +} + +impl ToSql for Time { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + let delta = *self - Time::MIDNIGHT; + let time = match i64::try_from(delta.whole_microseconds()) { + Ok(time) => time, + Err(_) => return Err("value too large to transmit".into()), + }; + types::time_to_sql(time, w); + Ok(IsNull::No) + } + + accepts!(TIME); + to_sql_checked!(); +} diff --git a/postgres-shared/src/types/type_gen.rs b/postgres-types/src/type_gen.rs similarity index 60% rename from postgres-shared/src/types/type_gen.rs rename to postgres-types/src/type_gen.rs index 3992112ea..a1bc3f85c 100644 --- a/postgres-shared/src/types/type_gen.rs +++ b/postgres-types/src/type_gen.rs @@ -1,9 +1,9 @@ // Autogenerated file - DO NOT EDIT use std::sync::Arc; -use types::{Type, Oid, Kind}; +use crate::{Kind, Oid, Type}; -#[derive(PartialEq, Eq, Debug)] +#[derive(PartialEq, Eq, Debug, Hash)] pub struct Other { pub name: String, pub oid: Oid, @@ -11,7 +11,7 @@ pub struct Other { pub schema: String, } -#[derive(PartialEq, Eq, Clone, Debug)] +#[derive(PartialEq, Eq, Clone, Debug, Hash)] pub enum Inner { Bool, Bytea, @@ -34,7 +34,8 @@ pub enum Inner { XmlArray, PgNodeTree, JsonArray, - Smgr, + TableAmHandler, + Xid8Array, IndexAmHandler, Point, Lseg, @@ -47,9 +48,6 @@ pub enum Inner { CidrArray, Float4, Float8, - Abstime, - Reltime, - Tinterval, Unknown, Circle, CircleArray, @@ -81,9 +79,6 @@ pub enum Inner { BoxArray, Float4Array, Float8Array, - AbstimeArray, - ReltimeArray, - TintervalArray, PolygonArray, OidArray, Aclitem, @@ -131,7 +126,6 @@ pub enum Inner { Trigger, LanguageHandler, Internal, - Opaque, Anyelement, RecordArray, Anynonarray, @@ -172,10 +166,38 @@ pub enum Inner { DateRangeArray, Int8Range, Int8RangeArray, + Jsonpath, + JsonpathArray, Regnamespace, RegnamespaceArray, Regrole, RegroleArray, + Regcollation, + RegcollationArray, + Int4multiRange, + NummultiRange, + TsmultiRange, + TstzmultiRange, + DatemultiRange, + Int8multiRange, + AnymultiRange, + AnycompatiblemultiRange, + PgBrinBloomSummary, + PgBrinMinmaxMultiSummary, + PgMcvList, + PgSnapshot, + PgSnapshotArray, + Xid8, + Anycompatible, + Anycompatiblearray, + Anycompatiblenonarray, + AnycompatibleRange, + Int4multiRangeArray, + NummultiRangeArray, + TsmultiRangeArray, + TstzmultiRangeArray, + DatemultiRangeArray, + Int8multiRangeArray, Other(Arc), } @@ -203,7 +225,8 @@ impl Inner { 143 => Some(Inner::XmlArray), 194 => Some(Inner::PgNodeTree), 199 => Some(Inner::JsonArray), - 210 => Some(Inner::Smgr), + 269 => Some(Inner::TableAmHandler), + 271 => Some(Inner::Xid8Array), 325 => Some(Inner::IndexAmHandler), 600 => Some(Inner::Point), 601 => Some(Inner::Lseg), @@ -216,9 +239,6 @@ impl Inner { 651 => Some(Inner::CidrArray), 700 => Some(Inner::Float4), 701 => Some(Inner::Float8), - 702 => Some(Inner::Abstime), - 703 => Some(Inner::Reltime), - 704 => Some(Inner::Tinterval), 705 => Some(Inner::Unknown), 718 => Some(Inner::Circle), 719 => Some(Inner::CircleArray), @@ -250,9 +270,6 @@ impl Inner { 1020 => Some(Inner::BoxArray), 1021 => Some(Inner::Float4Array), 1022 => Some(Inner::Float8Array), - 1023 => Some(Inner::AbstimeArray), - 1024 => Some(Inner::ReltimeArray), - 1025 => Some(Inner::TintervalArray), 1027 => Some(Inner::PolygonArray), 1028 => Some(Inner::OidArray), 1033 => Some(Inner::Aclitem), @@ -300,7 +317,6 @@ impl Inner { 2279 => Some(Inner::Trigger), 2280 => Some(Inner::LanguageHandler), 2281 => Some(Inner::Internal), - 2282 => Some(Inner::Opaque), 2283 => Some(Inner::Anyelement), 2287 => Some(Inner::RecordArray), 2776 => Some(Inner::Anynonarray), @@ -341,10 +357,38 @@ impl Inner { 3913 => Some(Inner::DateRangeArray), 3926 => Some(Inner::Int8Range), 3927 => Some(Inner::Int8RangeArray), + 4072 => Some(Inner::Jsonpath), + 4073 => Some(Inner::JsonpathArray), 4089 => Some(Inner::Regnamespace), 4090 => Some(Inner::RegnamespaceArray), 4096 => Some(Inner::Regrole), 4097 => Some(Inner::RegroleArray), + 4191 => Some(Inner::Regcollation), + 4192 => Some(Inner::RegcollationArray), + 4451 => Some(Inner::Int4multiRange), + 4532 => Some(Inner::NummultiRange), + 4533 => Some(Inner::TsmultiRange), + 4534 => Some(Inner::TstzmultiRange), + 4535 => Some(Inner::DatemultiRange), + 4536 => Some(Inner::Int8multiRange), + 4537 => Some(Inner::AnymultiRange), + 4538 => Some(Inner::AnycompatiblemultiRange), + 4600 => Some(Inner::PgBrinBloomSummary), + 4601 => Some(Inner::PgBrinMinmaxMultiSummary), + 5017 => Some(Inner::PgMcvList), + 5038 => Some(Inner::PgSnapshot), + 5039 => Some(Inner::PgSnapshotArray), + 5069 => Some(Inner::Xid8), + 5077 => Some(Inner::Anycompatible), + 5078 => Some(Inner::Anycompatiblearray), + 5079 => Some(Inner::Anycompatiblenonarray), + 5080 => Some(Inner::AnycompatibleRange), + 6150 => Some(Inner::Int4multiRangeArray), + 6151 => Some(Inner::NummultiRangeArray), + 6152 => Some(Inner::TsmultiRangeArray), + 6153 => Some(Inner::TstzmultiRangeArray), + 6155 => Some(Inner::DatemultiRangeArray), + 6157 => Some(Inner::Int8multiRangeArray), _ => None, } } @@ -372,7 +416,8 @@ impl Inner { Inner::XmlArray => 143, Inner::PgNodeTree => 194, Inner::JsonArray => 199, - Inner::Smgr => 210, + Inner::TableAmHandler => 269, + Inner::Xid8Array => 271, Inner::IndexAmHandler => 325, Inner::Point => 600, Inner::Lseg => 601, @@ -385,9 +430,6 @@ impl Inner { Inner::CidrArray => 651, Inner::Float4 => 700, Inner::Float8 => 701, - Inner::Abstime => 702, - Inner::Reltime => 703, - Inner::Tinterval => 704, Inner::Unknown => 705, Inner::Circle => 718, Inner::CircleArray => 719, @@ -419,9 +461,6 @@ impl Inner { Inner::BoxArray => 1020, Inner::Float4Array => 1021, Inner::Float8Array => 1022, - Inner::AbstimeArray => 1023, - Inner::ReltimeArray => 1024, - Inner::TintervalArray => 1025, Inner::PolygonArray => 1027, Inner::OidArray => 1028, Inner::Aclitem => 1033, @@ -469,7 +508,6 @@ impl Inner { Inner::Trigger => 2279, Inner::LanguageHandler => 2280, Inner::Internal => 2281, - Inner::Opaque => 2282, Inner::Anyelement => 2283, Inner::RecordArray => 2287, Inner::Anynonarray => 2776, @@ -510,668 +548,229 @@ impl Inner { Inner::DateRangeArray => 3913, Inner::Int8Range => 3926, Inner::Int8RangeArray => 3927, + Inner::Jsonpath => 4072, + Inner::JsonpathArray => 4073, Inner::Regnamespace => 4089, Inner::RegnamespaceArray => 4090, Inner::Regrole => 4096, Inner::RegroleArray => 4097, + Inner::Regcollation => 4191, + Inner::RegcollationArray => 4192, + Inner::Int4multiRange => 4451, + Inner::NummultiRange => 4532, + Inner::TsmultiRange => 4533, + Inner::TstzmultiRange => 4534, + Inner::DatemultiRange => 4535, + Inner::Int8multiRange => 4536, + Inner::AnymultiRange => 4537, + Inner::AnycompatiblemultiRange => 4538, + Inner::PgBrinBloomSummary => 4600, + Inner::PgBrinMinmaxMultiSummary => 4601, + Inner::PgMcvList => 5017, + Inner::PgSnapshot => 5038, + Inner::PgSnapshotArray => 5039, + Inner::Xid8 => 5069, + Inner::Anycompatible => 5077, + Inner::Anycompatiblearray => 5078, + Inner::Anycompatiblenonarray => 5079, + Inner::AnycompatibleRange => 5080, + Inner::Int4multiRangeArray => 6150, + Inner::NummultiRangeArray => 6151, + Inner::TsmultiRangeArray => 6152, + Inner::TstzmultiRangeArray => 6153, + Inner::DatemultiRangeArray => 6155, + Inner::Int8multiRangeArray => 6157, Inner::Other(ref u) => u.oid, } } pub fn kind(&self) -> &Kind { match *self { - Inner::Bool => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Bytea => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Char => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Name => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Int8 => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Int2 => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Int2Vector => { - const V: &'static Kind = &Kind::Array(Type(Inner::Int2)); - V - } - Inner::Int4 => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Regproc => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Text => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Oid => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Tid => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Xid => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Cid => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::OidVector => { - const V: &'static Kind = &Kind::Array(Type(Inner::Oid)); - V - } - Inner::PgDdlCommand => { - const V: &'static Kind = &Kind::Pseudo; - V - } - Inner::Json => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Xml => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::XmlArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Xml)); - V - } - Inner::PgNodeTree => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::JsonArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Json)); - V - } - Inner::Smgr => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::IndexAmHandler => { - const V: &'static Kind = &Kind::Pseudo; - V - } - Inner::Point => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Lseg => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Path => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Box => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Polygon => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Line => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::LineArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Line)); - V - } - Inner::Cidr => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::CidrArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Cidr)); - V - } - Inner::Float4 => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Float8 => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Abstime => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Reltime => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Tinterval => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Unknown => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Circle => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::CircleArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Circle)); - V - } - Inner::Macaddr8 => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Macaddr8Array => { - const V: &'static Kind = &Kind::Array(Type(Inner::Macaddr8)); - V - } - Inner::Money => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::MoneyArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Money)); - V - } - Inner::Macaddr => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Inet => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::BoolArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Bool)); - V - } - Inner::ByteaArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Bytea)); - V - } - Inner::CharArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Char)); - V - } - Inner::NameArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Name)); - V - } - Inner::Int2Array => { - const V: &'static Kind = &Kind::Array(Type(Inner::Int2)); - V - } - Inner::Int2VectorArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Int2Vector)); - V - } - Inner::Int4Array => { - const V: &'static Kind = &Kind::Array(Type(Inner::Int4)); - V - } - Inner::RegprocArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Regproc)); - V - } - Inner::TextArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Text)); - V - } - Inner::TidArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Tid)); - V - } - Inner::XidArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Xid)); - V - } - Inner::CidArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Cid)); - V - } - Inner::OidVectorArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::OidVector)); - V - } - Inner::BpcharArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Bpchar)); - V - } - Inner::VarcharArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Varchar)); - V - } - Inner::Int8Array => { - const V: &'static Kind = &Kind::Array(Type(Inner::Int8)); - V - } - Inner::PointArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Point)); - V - } - Inner::LsegArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Lseg)); - V - } - Inner::PathArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Path)); - V - } - Inner::BoxArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Box)); - V - } - Inner::Float4Array => { - const V: &'static Kind = &Kind::Array(Type(Inner::Float4)); - V - } - Inner::Float8Array => { - const V: &'static Kind = &Kind::Array(Type(Inner::Float8)); - V - } - Inner::AbstimeArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Abstime)); - V - } - Inner::ReltimeArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Reltime)); - V - } - Inner::TintervalArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Tinterval)); - V - } - Inner::PolygonArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Polygon)); - V - } - Inner::OidArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Oid)); - V - } - Inner::Aclitem => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::AclitemArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Aclitem)); - V - } - Inner::MacaddrArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Macaddr)); - V - } - Inner::InetArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Inet)); - V - } - Inner::Bpchar => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Varchar => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Date => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Time => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Timestamp => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::TimestampArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Timestamp)); - V - } - Inner::DateArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Date)); - V - } - Inner::TimeArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Time)); - V - } - Inner::Timestamptz => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::TimestamptzArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Timestamptz)); - V - } - Inner::Interval => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::IntervalArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Interval)); - V - } - Inner::NumericArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Numeric)); - V - } - Inner::CstringArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Cstring)); - V - } - Inner::Timetz => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::TimetzArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Timetz)); - V - } - Inner::Bit => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::BitArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Bit)); - V - } - Inner::Varbit => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::VarbitArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Varbit)); - V - } - Inner::Numeric => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Refcursor => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::RefcursorArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Refcursor)); - V - } - Inner::Regprocedure => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Regoper => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Regoperator => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Regclass => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Regtype => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::RegprocedureArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Regprocedure)); - V - } - Inner::RegoperArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Regoper)); - V - } - Inner::RegoperatorArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Regoperator)); - V - } - Inner::RegclassArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Regclass)); - V - } - Inner::RegtypeArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Regtype)); - V - } - Inner::Record => { - const V: &'static Kind = &Kind::Pseudo; - V - } - Inner::Cstring => { - const V: &'static Kind = &Kind::Pseudo; - V - } - Inner::Any => { - const V: &'static Kind = &Kind::Pseudo; - V - } - Inner::Anyarray => { - const V: &'static Kind = &Kind::Pseudo; - V - } - Inner::Void => { - const V: &'static Kind = &Kind::Pseudo; - V - } - Inner::Trigger => { - const V: &'static Kind = &Kind::Pseudo; - V - } - Inner::LanguageHandler => { - const V: &'static Kind = &Kind::Pseudo; - V - } - Inner::Internal => { - const V: &'static Kind = &Kind::Pseudo; - V - } - Inner::Opaque => { - const V: &'static Kind = &Kind::Pseudo; - V - } - Inner::Anyelement => { - const V: &'static Kind = &Kind::Pseudo; - V - } - Inner::RecordArray => { - const V: &'static Kind = &Kind::Pseudo; - V - } - Inner::Anynonarray => { - const V: &'static Kind = &Kind::Pseudo; - V - } - Inner::TxidSnapshotArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::TxidSnapshot)); - V - } - Inner::Uuid => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::UuidArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Uuid)); - V - } - Inner::TxidSnapshot => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::FdwHandler => { - const V: &'static Kind = &Kind::Pseudo; - V - } - Inner::PgLsn => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::PgLsnArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::PgLsn)); - V - } - Inner::TsmHandler => { - const V: &'static Kind = &Kind::Pseudo; - V - } - Inner::PgNdistinct => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::PgDependencies => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Anyenum => { - const V: &'static Kind = &Kind::Pseudo; - V - } - Inner::TsVector => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::Tsquery => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::GtsVector => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::TsVectorArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::TsVector)); - V - } - Inner::GtsVectorArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::GtsVector)); - V - } - Inner::TsqueryArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Tsquery)); - V - } - Inner::Regconfig => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::RegconfigArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Regconfig)); - V - } - Inner::Regdictionary => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::RegdictionaryArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Regdictionary)); - V - } - Inner::Jsonb => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::JsonbArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Jsonb)); - V - } - Inner::AnyRange => { - const V: &'static Kind = &Kind::Pseudo; - V - } - Inner::EventTrigger => { - const V: &'static Kind = &Kind::Pseudo; - V - } - Inner::Int4Range => { - const V: &'static Kind = &Kind::Range(Type(Inner::Int4)); - V - } - Inner::Int4RangeArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Int4Range)); - V - } - Inner::NumRange => { - const V: &'static Kind = &Kind::Range(Type(Inner::Numeric)); - V - } - Inner::NumRangeArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::NumRange)); - V - } - Inner::TsRange => { - const V: &'static Kind = &Kind::Range(Type(Inner::Timestamp)); - V - } - Inner::TsRangeArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::TsRange)); - V - } - Inner::TstzRange => { - const V: &'static Kind = &Kind::Range(Type(Inner::Timestamptz)); - V - } - Inner::TstzRangeArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::TstzRange)); - V - } - Inner::DateRange => { - const V: &'static Kind = &Kind::Range(Type(Inner::Date)); - V - } - Inner::DateRangeArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::DateRange)); - V - } - Inner::Int8Range => { - const V: &'static Kind = &Kind::Range(Type(Inner::Int8)); - V - } - Inner::Int8RangeArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Int8Range)); - V - } - Inner::Regnamespace => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::RegnamespaceArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Regnamespace)); - V - } - Inner::Regrole => { - const V: &'static Kind = &Kind::Simple; - V - } - Inner::RegroleArray => { - const V: &'static Kind = &Kind::Array(Type(Inner::Regrole)); - V - } + Inner::Bool => &Kind::Simple, + Inner::Bytea => &Kind::Simple, + Inner::Char => &Kind::Simple, + Inner::Name => &Kind::Simple, + Inner::Int8 => &Kind::Simple, + Inner::Int2 => &Kind::Simple, + Inner::Int2Vector => &Kind::Array(Type(Inner::Int2)), + Inner::Int4 => &Kind::Simple, + Inner::Regproc => &Kind::Simple, + Inner::Text => &Kind::Simple, + Inner::Oid => &Kind::Simple, + Inner::Tid => &Kind::Simple, + Inner::Xid => &Kind::Simple, + Inner::Cid => &Kind::Simple, + Inner::OidVector => &Kind::Array(Type(Inner::Oid)), + Inner::PgDdlCommand => &Kind::Pseudo, + Inner::Json => &Kind::Simple, + Inner::Xml => &Kind::Simple, + Inner::XmlArray => &Kind::Array(Type(Inner::Xml)), + Inner::PgNodeTree => &Kind::Simple, + Inner::JsonArray => &Kind::Array(Type(Inner::Json)), + Inner::TableAmHandler => &Kind::Pseudo, + Inner::Xid8Array => &Kind::Array(Type(Inner::Xid8)), + Inner::IndexAmHandler => &Kind::Pseudo, + Inner::Point => &Kind::Simple, + Inner::Lseg => &Kind::Simple, + Inner::Path => &Kind::Simple, + Inner::Box => &Kind::Simple, + Inner::Polygon => &Kind::Simple, + Inner::Line => &Kind::Simple, + Inner::LineArray => &Kind::Array(Type(Inner::Line)), + Inner::Cidr => &Kind::Simple, + Inner::CidrArray => &Kind::Array(Type(Inner::Cidr)), + Inner::Float4 => &Kind::Simple, + Inner::Float8 => &Kind::Simple, + Inner::Unknown => &Kind::Simple, + Inner::Circle => &Kind::Simple, + Inner::CircleArray => &Kind::Array(Type(Inner::Circle)), + Inner::Macaddr8 => &Kind::Simple, + Inner::Macaddr8Array => &Kind::Array(Type(Inner::Macaddr8)), + Inner::Money => &Kind::Simple, + Inner::MoneyArray => &Kind::Array(Type(Inner::Money)), + Inner::Macaddr => &Kind::Simple, + Inner::Inet => &Kind::Simple, + Inner::BoolArray => &Kind::Array(Type(Inner::Bool)), + Inner::ByteaArray => &Kind::Array(Type(Inner::Bytea)), + Inner::CharArray => &Kind::Array(Type(Inner::Char)), + Inner::NameArray => &Kind::Array(Type(Inner::Name)), + Inner::Int2Array => &Kind::Array(Type(Inner::Int2)), + Inner::Int2VectorArray => &Kind::Array(Type(Inner::Int2Vector)), + Inner::Int4Array => &Kind::Array(Type(Inner::Int4)), + Inner::RegprocArray => &Kind::Array(Type(Inner::Regproc)), + Inner::TextArray => &Kind::Array(Type(Inner::Text)), + Inner::TidArray => &Kind::Array(Type(Inner::Tid)), + Inner::XidArray => &Kind::Array(Type(Inner::Xid)), + Inner::CidArray => &Kind::Array(Type(Inner::Cid)), + Inner::OidVectorArray => &Kind::Array(Type(Inner::OidVector)), + Inner::BpcharArray => &Kind::Array(Type(Inner::Bpchar)), + Inner::VarcharArray => &Kind::Array(Type(Inner::Varchar)), + Inner::Int8Array => &Kind::Array(Type(Inner::Int8)), + Inner::PointArray => &Kind::Array(Type(Inner::Point)), + Inner::LsegArray => &Kind::Array(Type(Inner::Lseg)), + Inner::PathArray => &Kind::Array(Type(Inner::Path)), + Inner::BoxArray => &Kind::Array(Type(Inner::Box)), + Inner::Float4Array => &Kind::Array(Type(Inner::Float4)), + Inner::Float8Array => &Kind::Array(Type(Inner::Float8)), + Inner::PolygonArray => &Kind::Array(Type(Inner::Polygon)), + Inner::OidArray => &Kind::Array(Type(Inner::Oid)), + Inner::Aclitem => &Kind::Simple, + Inner::AclitemArray => &Kind::Array(Type(Inner::Aclitem)), + Inner::MacaddrArray => &Kind::Array(Type(Inner::Macaddr)), + Inner::InetArray => &Kind::Array(Type(Inner::Inet)), + Inner::Bpchar => &Kind::Simple, + Inner::Varchar => &Kind::Simple, + Inner::Date => &Kind::Simple, + Inner::Time => &Kind::Simple, + Inner::Timestamp => &Kind::Simple, + Inner::TimestampArray => &Kind::Array(Type(Inner::Timestamp)), + Inner::DateArray => &Kind::Array(Type(Inner::Date)), + Inner::TimeArray => &Kind::Array(Type(Inner::Time)), + Inner::Timestamptz => &Kind::Simple, + Inner::TimestamptzArray => &Kind::Array(Type(Inner::Timestamptz)), + Inner::Interval => &Kind::Simple, + Inner::IntervalArray => &Kind::Array(Type(Inner::Interval)), + Inner::NumericArray => &Kind::Array(Type(Inner::Numeric)), + Inner::CstringArray => &Kind::Array(Type(Inner::Cstring)), + Inner::Timetz => &Kind::Simple, + Inner::TimetzArray => &Kind::Array(Type(Inner::Timetz)), + Inner::Bit => &Kind::Simple, + Inner::BitArray => &Kind::Array(Type(Inner::Bit)), + Inner::Varbit => &Kind::Simple, + Inner::VarbitArray => &Kind::Array(Type(Inner::Varbit)), + Inner::Numeric => &Kind::Simple, + Inner::Refcursor => &Kind::Simple, + Inner::RefcursorArray => &Kind::Array(Type(Inner::Refcursor)), + Inner::Regprocedure => &Kind::Simple, + Inner::Regoper => &Kind::Simple, + Inner::Regoperator => &Kind::Simple, + Inner::Regclass => &Kind::Simple, + Inner::Regtype => &Kind::Simple, + Inner::RegprocedureArray => &Kind::Array(Type(Inner::Regprocedure)), + Inner::RegoperArray => &Kind::Array(Type(Inner::Regoper)), + Inner::RegoperatorArray => &Kind::Array(Type(Inner::Regoperator)), + Inner::RegclassArray => &Kind::Array(Type(Inner::Regclass)), + Inner::RegtypeArray => &Kind::Array(Type(Inner::Regtype)), + Inner::Record => &Kind::Pseudo, + Inner::Cstring => &Kind::Pseudo, + Inner::Any => &Kind::Pseudo, + Inner::Anyarray => &Kind::Pseudo, + Inner::Void => &Kind::Pseudo, + Inner::Trigger => &Kind::Pseudo, + Inner::LanguageHandler => &Kind::Pseudo, + Inner::Internal => &Kind::Pseudo, + Inner::Anyelement => &Kind::Pseudo, + Inner::RecordArray => &Kind::Pseudo, + Inner::Anynonarray => &Kind::Pseudo, + Inner::TxidSnapshotArray => &Kind::Array(Type(Inner::TxidSnapshot)), + Inner::Uuid => &Kind::Simple, + Inner::UuidArray => &Kind::Array(Type(Inner::Uuid)), + Inner::TxidSnapshot => &Kind::Simple, + Inner::FdwHandler => &Kind::Pseudo, + Inner::PgLsn => &Kind::Simple, + Inner::PgLsnArray => &Kind::Array(Type(Inner::PgLsn)), + Inner::TsmHandler => &Kind::Pseudo, + Inner::PgNdistinct => &Kind::Simple, + Inner::PgDependencies => &Kind::Simple, + Inner::Anyenum => &Kind::Pseudo, + Inner::TsVector => &Kind::Simple, + Inner::Tsquery => &Kind::Simple, + Inner::GtsVector => &Kind::Simple, + Inner::TsVectorArray => &Kind::Array(Type(Inner::TsVector)), + Inner::GtsVectorArray => &Kind::Array(Type(Inner::GtsVector)), + Inner::TsqueryArray => &Kind::Array(Type(Inner::Tsquery)), + Inner::Regconfig => &Kind::Simple, + Inner::RegconfigArray => &Kind::Array(Type(Inner::Regconfig)), + Inner::Regdictionary => &Kind::Simple, + Inner::RegdictionaryArray => &Kind::Array(Type(Inner::Regdictionary)), + Inner::Jsonb => &Kind::Simple, + Inner::JsonbArray => &Kind::Array(Type(Inner::Jsonb)), + Inner::AnyRange => &Kind::Pseudo, + Inner::EventTrigger => &Kind::Pseudo, + Inner::Int4Range => &Kind::Range(Type(Inner::Int4)), + Inner::Int4RangeArray => &Kind::Array(Type(Inner::Int4Range)), + Inner::NumRange => &Kind::Range(Type(Inner::Numeric)), + Inner::NumRangeArray => &Kind::Array(Type(Inner::NumRange)), + Inner::TsRange => &Kind::Range(Type(Inner::Timestamp)), + Inner::TsRangeArray => &Kind::Array(Type(Inner::TsRange)), + Inner::TstzRange => &Kind::Range(Type(Inner::Timestamptz)), + Inner::TstzRangeArray => &Kind::Array(Type(Inner::TstzRange)), + Inner::DateRange => &Kind::Range(Type(Inner::Date)), + Inner::DateRangeArray => &Kind::Array(Type(Inner::DateRange)), + Inner::Int8Range => &Kind::Range(Type(Inner::Int8)), + Inner::Int8RangeArray => &Kind::Array(Type(Inner::Int8Range)), + Inner::Jsonpath => &Kind::Simple, + Inner::JsonpathArray => &Kind::Array(Type(Inner::Jsonpath)), + Inner::Regnamespace => &Kind::Simple, + Inner::RegnamespaceArray => &Kind::Array(Type(Inner::Regnamespace)), + Inner::Regrole => &Kind::Simple, + Inner::RegroleArray => &Kind::Array(Type(Inner::Regrole)), + Inner::Regcollation => &Kind::Simple, + Inner::RegcollationArray => &Kind::Array(Type(Inner::Regcollation)), + Inner::Int4multiRange => &Kind::Multirange(Type(Inner::Int4)), + Inner::NummultiRange => &Kind::Multirange(Type(Inner::Numeric)), + Inner::TsmultiRange => &Kind::Multirange(Type(Inner::Timestamp)), + Inner::TstzmultiRange => &Kind::Multirange(Type(Inner::Timestamptz)), + Inner::DatemultiRange => &Kind::Multirange(Type(Inner::Date)), + Inner::Int8multiRange => &Kind::Multirange(Type(Inner::Int8)), + Inner::AnymultiRange => &Kind::Pseudo, + Inner::AnycompatiblemultiRange => &Kind::Pseudo, + Inner::PgBrinBloomSummary => &Kind::Simple, + Inner::PgBrinMinmaxMultiSummary => &Kind::Simple, + Inner::PgMcvList => &Kind::Simple, + Inner::PgSnapshot => &Kind::Simple, + Inner::PgSnapshotArray => &Kind::Array(Type(Inner::PgSnapshot)), + Inner::Xid8 => &Kind::Simple, + Inner::Anycompatible => &Kind::Pseudo, + Inner::Anycompatiblearray => &Kind::Pseudo, + Inner::Anycompatiblenonarray => &Kind::Pseudo, + Inner::AnycompatibleRange => &Kind::Pseudo, + Inner::Int4multiRangeArray => &Kind::Array(Type(Inner::Int4multiRange)), + Inner::NummultiRangeArray => &Kind::Array(Type(Inner::NummultiRange)), + Inner::TsmultiRangeArray => &Kind::Array(Type(Inner::TsmultiRange)), + Inner::TstzmultiRangeArray => &Kind::Array(Type(Inner::TstzmultiRange)), + Inner::DatemultiRangeArray => &Kind::Array(Type(Inner::DatemultiRange)), + Inner::Int8multiRangeArray => &Kind::Array(Type(Inner::Int8multiRange)), Inner::Other(ref u) => &u.kind, } } @@ -1199,7 +798,8 @@ impl Inner { Inner::XmlArray => "_xml", Inner::PgNodeTree => "pg_node_tree", Inner::JsonArray => "_json", - Inner::Smgr => "smgr", + Inner::TableAmHandler => "table_am_handler", + Inner::Xid8Array => "_xid8", Inner::IndexAmHandler => "index_am_handler", Inner::Point => "point", Inner::Lseg => "lseg", @@ -1212,9 +812,6 @@ impl Inner { Inner::CidrArray => "_cidr", Inner::Float4 => "float4", Inner::Float8 => "float8", - Inner::Abstime => "abstime", - Inner::Reltime => "reltime", - Inner::Tinterval => "tinterval", Inner::Unknown => "unknown", Inner::Circle => "circle", Inner::CircleArray => "_circle", @@ -1246,9 +843,6 @@ impl Inner { Inner::BoxArray => "_box", Inner::Float4Array => "_float4", Inner::Float8Array => "_float8", - Inner::AbstimeArray => "_abstime", - Inner::ReltimeArray => "_reltime", - Inner::TintervalArray => "_tinterval", Inner::PolygonArray => "_polygon", Inner::OidArray => "_oid", Inner::Aclitem => "aclitem", @@ -1296,7 +890,6 @@ impl Inner { Inner::Trigger => "trigger", Inner::LanguageHandler => "language_handler", Inner::Internal => "internal", - Inner::Opaque => "opaque", Inner::Anyelement => "anyelement", Inner::RecordArray => "_record", Inner::Anynonarray => "anynonarray", @@ -1337,10 +930,38 @@ impl Inner { Inner::DateRangeArray => "_daterange", Inner::Int8Range => "int8range", Inner::Int8RangeArray => "_int8range", + Inner::Jsonpath => "jsonpath", + Inner::JsonpathArray => "_jsonpath", Inner::Regnamespace => "regnamespace", Inner::RegnamespaceArray => "_regnamespace", Inner::Regrole => "regrole", Inner::RegroleArray => "_regrole", + Inner::Regcollation => "regcollation", + Inner::RegcollationArray => "_regcollation", + Inner::Int4multiRange => "int4multirange", + Inner::NummultiRange => "nummultirange", + Inner::TsmultiRange => "tsmultirange", + Inner::TstzmultiRange => "tstzmultirange", + Inner::DatemultiRange => "datemultirange", + Inner::Int8multiRange => "int8multirange", + Inner::AnymultiRange => "anymultirange", + Inner::AnycompatiblemultiRange => "anycompatiblemultirange", + Inner::PgBrinBloomSummary => "pg_brin_bloom_summary", + Inner::PgBrinMinmaxMultiSummary => "pg_brin_minmax_multi_summary", + Inner::PgMcvList => "pg_mcv_list", + Inner::PgSnapshot => "pg_snapshot", + Inner::PgSnapshotArray => "_pg_snapshot", + Inner::Xid8 => "xid8", + Inner::Anycompatible => "anycompatible", + Inner::Anycompatiblearray => "anycompatiblearray", + Inner::Anycompatiblenonarray => "anycompatiblenonarray", + Inner::AnycompatibleRange => "anycompatiblerange", + Inner::Int4multiRangeArray => "_int4multirange", + Inner::NummultiRangeArray => "_nummultirange", + Inner::TsmultiRangeArray => "_tsmultirange", + Inner::TstzmultiRangeArray => "_tstzmultirange", + Inner::DatemultiRangeArray => "_datemultirange", + Inner::Int8multiRangeArray => "_int8multirange", Inner::Other(ref u) => &u.name, } } @@ -1394,7 +1015,7 @@ impl Type { /// PG_DDL_COMMAND - internal type for passing CollectedCommand pub const PG_DDL_COMMAND: Type = Type(Inner::PgDdlCommand); - /// JSON + /// JSON - JSON stored as text pub const JSON: Type = Type(Inner::Json); /// XML - XML content @@ -1409,10 +1030,13 @@ impl Type { /// JSON[] pub const JSON_ARRAY: Type = Type(Inner::JsonArray); - /// SMGR - storage manager - pub const SMGR: Type = Type(Inner::Smgr); + /// TABLE_AM_HANDLER + pub const TABLE_AM_HANDLER: Type = Type(Inner::TableAmHandler); - /// INDEX_AM_HANDLER + /// XID8[] + pub const XID8_ARRAY: Type = Type(Inner::Xid8Array); + + /// INDEX_AM_HANDLER - pseudo-type for the result of an index AM handler function pub const INDEX_AM_HANDLER: Type = Type(Inner::IndexAmHandler); /// POINT - geometric point '(x, y)' @@ -1448,16 +1072,7 @@ impl Type { /// FLOAT8 - double-precision floating point number, 8-byte storage pub const FLOAT8: Type = Type(Inner::Float8); - /// ABSTIME - absolute, limited-range date and time (Unix system time) - pub const ABSTIME: Type = Type(Inner::Abstime); - - /// RELTIME - relative, limited-range time interval (Unix delta time) - pub const RELTIME: Type = Type(Inner::Reltime); - - /// TINTERVAL - (abstime,abstime), time interval - pub const TINTERVAL: Type = Type(Inner::Tinterval); - - /// UNKNOWN + /// UNKNOWN - pseudo-type representing an undetermined type pub const UNKNOWN: Type = Type(Inner::Unknown); /// CIRCLE - geometric circle '(center,radius)' @@ -1550,15 +1165,6 @@ impl Type { /// FLOAT8[] pub const FLOAT8_ARRAY: Type = Type(Inner::Float8Array); - /// ABSTIME[] - pub const ABSTIME_ARRAY: Type = Type(Inner::AbstimeArray); - - /// RELTIME[] - pub const RELTIME_ARRAY: Type = Type(Inner::ReltimeArray); - - /// TINTERVAL[] - pub const TINTERVAL_ARRAY: Type = Type(Inner::TintervalArray); - /// POLYGON[] pub const POLYGON_ARRAY: Type = Type(Inner::PolygonArray); @@ -1676,40 +1282,37 @@ impl Type { /// REGTYPE[] pub const REGTYPE_ARRAY: Type = Type(Inner::RegtypeArray); - /// RECORD + /// RECORD - pseudo-type representing any composite type pub const RECORD: Type = Type(Inner::Record); - /// CSTRING + /// CSTRING - C-style string pub const CSTRING: Type = Type(Inner::Cstring); - /// ANY + /// ANY - pseudo-type representing any type pub const ANY: Type = Type(Inner::Any); - /// ANYARRAY + /// ANYARRAY - pseudo-type representing a polymorphic array type pub const ANYARRAY: Type = Type(Inner::Anyarray); - /// VOID + /// VOID - pseudo-type for the result of a function with no real result pub const VOID: Type = Type(Inner::Void); - /// TRIGGER + /// TRIGGER - pseudo-type for the result of a trigger function pub const TRIGGER: Type = Type(Inner::Trigger); - /// LANGUAGE_HANDLER + /// LANGUAGE_HANDLER - pseudo-type for the result of a language handler function pub const LANGUAGE_HANDLER: Type = Type(Inner::LanguageHandler); - /// INTERNAL + /// INTERNAL - pseudo-type representing an internal data structure pub const INTERNAL: Type = Type(Inner::Internal); - /// OPAQUE - pub const OPAQUE: Type = Type(Inner::Opaque); - - /// ANYELEMENT + /// ANYELEMENT - pseudo-type representing a polymorphic base type pub const ANYELEMENT: Type = Type(Inner::Anyelement); /// RECORD[] pub const RECORD_ARRAY: Type = Type(Inner::RecordArray); - /// ANYNONARRAY + /// ANYNONARRAY - pseudo-type representing a polymorphic base type that is not an array pub const ANYNONARRAY: Type = Type(Inner::Anynonarray); /// TXID_SNAPSHOT[] @@ -1724,7 +1327,7 @@ impl Type { /// TXID_SNAPSHOT - txid snapshot pub const TXID_SNAPSHOT: Type = Type(Inner::TxidSnapshot); - /// FDW_HANDLER + /// FDW_HANDLER - pseudo-type for the result of an FDW handler function pub const FDW_HANDLER: Type = Type(Inner::FdwHandler); /// PG_LSN - PostgreSQL LSN datatype @@ -1733,7 +1336,7 @@ impl Type { /// PG_LSN[] pub const PG_LSN_ARRAY: Type = Type(Inner::PgLsnArray); - /// TSM_HANDLER + /// TSM_HANDLER - pseudo-type for the result of a tablesample method function pub const TSM_HANDLER: Type = Type(Inner::TsmHandler); /// PG_NDISTINCT - multivariate ndistinct coefficients @@ -1742,7 +1345,7 @@ impl Type { /// PG_DEPENDENCIES - multivariate dependencies pub const PG_DEPENDENCIES: Type = Type(Inner::PgDependencies); - /// ANYENUM + /// ANYENUM - pseudo-type representing a polymorphic base type that is an enum pub const ANYENUM: Type = Type(Inner::Anyenum); /// TSVECTOR - text representation for text search @@ -1781,10 +1384,10 @@ impl Type { /// JSONB[] pub const JSONB_ARRAY: Type = Type(Inner::JsonbArray); - /// ANYRANGE + /// ANYRANGE - pseudo-type representing a range over a polymorphic base type pub const ANY_RANGE: Type = Type(Inner::AnyRange); - /// EVENT_TRIGGER + /// EVENT_TRIGGER - pseudo-type for the result of an event trigger function pub const EVENT_TRIGGER: Type = Type(Inner::EventTrigger); /// INT4RANGE - range of integers @@ -1823,6 +1426,12 @@ impl Type { /// INT8RANGE[] pub const INT8_RANGE_ARRAY: Type = Type(Inner::Int8RangeArray); + /// JSONPATH - JSON path + pub const JSONPATH: Type = Type(Inner::Jsonpath); + + /// JSONPATH[] + pub const JSONPATH_ARRAY: Type = Type(Inner::JsonpathArray); + /// REGNAMESPACE - registered namespace pub const REGNAMESPACE: Type = Type(Inner::Regnamespace); @@ -1834,4 +1443,82 @@ impl Type { /// REGROLE[] pub const REGROLE_ARRAY: Type = Type(Inner::RegroleArray); -} \ No newline at end of file + + /// REGCOLLATION - registered collation + pub const REGCOLLATION: Type = Type(Inner::Regcollation); + + /// REGCOLLATION[] + pub const REGCOLLATION_ARRAY: Type = Type(Inner::RegcollationArray); + + /// INT4MULTIRANGE - multirange of integers + pub const INT4MULTI_RANGE: Type = Type(Inner::Int4multiRange); + + /// NUMMULTIRANGE - multirange of numerics + pub const NUMMULTI_RANGE: Type = Type(Inner::NummultiRange); + + /// TSMULTIRANGE - multirange of timestamps without time zone + pub const TSMULTI_RANGE: Type = Type(Inner::TsmultiRange); + + /// TSTZMULTIRANGE - multirange of timestamps with time zone + pub const TSTZMULTI_RANGE: Type = Type(Inner::TstzmultiRange); + + /// DATEMULTIRANGE - multirange of dates + pub const DATEMULTI_RANGE: Type = Type(Inner::DatemultiRange); + + /// INT8MULTIRANGE - multirange of bigints + pub const INT8MULTI_RANGE: Type = Type(Inner::Int8multiRange); + + /// ANYMULTIRANGE - pseudo-type representing a polymorphic base type that is a multirange + pub const ANYMULTI_RANGE: Type = Type(Inner::AnymultiRange); + + /// ANYCOMPATIBLEMULTIRANGE - pseudo-type representing a multirange over a polymorphic common type + pub const ANYCOMPATIBLEMULTI_RANGE: Type = Type(Inner::AnycompatiblemultiRange); + + /// PG_BRIN_BLOOM_SUMMARY - BRIN bloom summary + pub const PG_BRIN_BLOOM_SUMMARY: Type = Type(Inner::PgBrinBloomSummary); + + /// PG_BRIN_MINMAX_MULTI_SUMMARY - BRIN minmax-multi summary + pub const PG_BRIN_MINMAX_MULTI_SUMMARY: Type = Type(Inner::PgBrinMinmaxMultiSummary); + + /// PG_MCV_LIST - multivariate MCV list + pub const PG_MCV_LIST: Type = Type(Inner::PgMcvList); + + /// PG_SNAPSHOT - snapshot + pub const PG_SNAPSHOT: Type = Type(Inner::PgSnapshot); + + /// PG_SNAPSHOT[] + pub const PG_SNAPSHOT_ARRAY: Type = Type(Inner::PgSnapshotArray); + + /// XID8 - full transaction id + pub const XID8: Type = Type(Inner::Xid8); + + /// ANYCOMPATIBLE - pseudo-type representing a polymorphic common type + pub const ANYCOMPATIBLE: Type = Type(Inner::Anycompatible); + + /// ANYCOMPATIBLEARRAY - pseudo-type representing an array of polymorphic common type elements + pub const ANYCOMPATIBLEARRAY: Type = Type(Inner::Anycompatiblearray); + + /// ANYCOMPATIBLENONARRAY - pseudo-type representing a polymorphic common type that is not an array + pub const ANYCOMPATIBLENONARRAY: Type = Type(Inner::Anycompatiblenonarray); + + /// ANYCOMPATIBLERANGE - pseudo-type representing a range over a polymorphic common type + pub const ANYCOMPATIBLE_RANGE: Type = Type(Inner::AnycompatibleRange); + + /// INT4MULTIRANGE[] + pub const INT4MULTI_RANGE_ARRAY: Type = Type(Inner::Int4multiRangeArray); + + /// NUMMULTIRANGE[] + pub const NUMMULTI_RANGE_ARRAY: Type = Type(Inner::NummultiRangeArray); + + /// TSMULTIRANGE[] + pub const TSMULTI_RANGE_ARRAY: Type = Type(Inner::TsmultiRangeArray); + + /// TSTZMULTIRANGE[] + pub const TSTZMULTI_RANGE_ARRAY: Type = Type(Inner::TstzmultiRangeArray); + + /// DATEMULTIRANGE[] + pub const DATEMULTI_RANGE_ARRAY: Type = Type(Inner::DatemultiRangeArray); + + /// INT8MULTIRANGE[] + pub const INT8MULTI_RANGE_ARRAY: Type = Type(Inner::Int8multiRangeArray); +} diff --git a/postgres-shared/src/types/uuid.rs b/postgres-types/src/uuid_08.rs similarity index 50% rename from postgres-shared/src/types/uuid.rs rename to postgres-types/src/uuid_08.rs index d7a190d01..72d5e82fc 100644 --- a/postgres-shared/src/types/uuid.rs +++ b/postgres-types/src/uuid_08.rs @@ -1,22 +1,21 @@ -extern crate uuid; - -use self::uuid::Uuid; +use bytes::BytesMut; use postgres_protocol::types; use std::error::Error; +use uuid_08::Uuid; -use types::{FromSql, IsNull, ToSql, Type}; +use crate::{FromSql, IsNull, ToSql, Type}; impl<'a> FromSql<'a> for Uuid { - fn from_sql(_: &Type, raw: &[u8]) -> Result> { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { let bytes = types::uuid_from_sql(raw)?; - Ok(Uuid::from_bytes(&bytes).unwrap()) + Ok(Uuid::from_bytes(bytes)) } accepts!(UUID); } impl ToSql for Uuid { - fn to_sql(&self, _: &Type, w: &mut Vec) -> Result> { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { types::uuid_to_sql(*self.as_bytes(), w); Ok(IsNull::No) } diff --git a/postgres-types/src/uuid_1.rs b/postgres-types/src/uuid_1.rs new file mode 100644 index 000000000..d9969f60c --- /dev/null +++ b/postgres-types/src/uuid_1.rs @@ -0,0 +1,25 @@ +use bytes::BytesMut; +use postgres_protocol::types; +use std::error::Error; +use uuid_1::Uuid; + +use crate::{FromSql, IsNull, ToSql, Type}; + +impl<'a> FromSql<'a> for Uuid { + fn from_sql(_: &Type, raw: &[u8]) -> Result> { + let bytes = types::uuid_from_sql(raw)?; + Ok(Uuid::from_bytes(bytes)) + } + + accepts!(UUID); +} + +impl ToSql for Uuid { + fn to_sql(&self, _: &Type, w: &mut BytesMut) -> Result> { + types::uuid_to_sql(*self.as_bytes(), w); + Ok(IsNull::No) + } + + accepts!(UUID); + to_sql_checked!(); +} diff --git a/postgres/CHANGELOG.md b/postgres/CHANGELOG.md new file mode 100644 index 000000000..771e2e779 --- /dev/null +++ b/postgres/CHANGELOG.md @@ -0,0 +1,286 @@ +# Change Log + +## Unreleased + +## v0.19.10 - 2025-02-02 + +### Added + +* Added support for direct TLS negotiation. +* Added support for `cidr` 0.3 via the `with-cidr-0_3` feature. + +## v0.19.9 - 2024-09-15 + +### Added + +* Added support for `jiff` 0.1 via the `with-jiff-01` feature. + +## v0.19.8 - 2024-07-21 + +### Added + +* Added `{Client, Transaction, GenericClient}::query_typed`. + +## v0.19.7 - 2023-08-25 + +## Fixed + +* Defered default username lookup to avoid regressing `Config` behavior. + +## v0.19.6 - 2023-08-19 + +### Added + +* Added support for the `hostaddr` config option to bypass DNS lookups. +* Added support for the `load_balance_hosts` config option to randomize connection ordering. +* The `user` config option now defaults to the executing process's user. + +## v0.19.5 - 2023-03-27 + +### Added + +* Added `keepalives_interval` and `keepalives_retries` config options. +* Added the `tcp_user_timeout` config option. +* Added `RowIter::rows_affected`. + +### Changed + +* Passing an incorrect number of parameters to a query method now returns an error instead of panicking. + +## v0.19.4 - 2022-08-21 + +### Added + +* Added `ToSql` and `FromSql` implementations for `[u8; N]` via the `array-impls` feature. +* Added support for `smol_str` 0.1 via the `with-smol_str-01` feature. +* Added `ToSql::encode_format` to support text encodings of parameters. + +## v0.19.3 - 2022-04-30 + +### Added + +* Added support for `uuid` 1.0 via the `with-uuid-1` feature. + +## v0.19.2 - 2021-09-29 + +### Added + +* Added `SimpleQueryRow::columns`. +* Added support for `eui48` 1.0 via the `with-eui48-1` feature. +* Added `FromSql` and `ToSql` implementations for arrays via the `array-impls` feature. +* Added support for `time` 0.3 via the `with-time-0_3` feature. + +## v0.19.1 - 2021-04-03 + +### Added + +* Added support for `geo-types` 0.7 via `with-geo-types-0_7` feature. +* Added `Client::clear_type_cache`. + +## v0.19.0 - 2020-12-25 + +### Changed + +* Upgraded to `tokio-postgres` 0.7. +* Methods taking iterators of `ToSql` values can now take both `&dyn ToSql` and `T: ToSql` values. + +### Added + +* Added `Client::is_valid` which can be used to check that the connection is still alive with a + timeout. + +## v0.18.1 - 2020-10-19 + +### Fixed + +* Restored the `Send` implementation for `Client`. + +## v0.18.0 - 2020-10-17 + +### Changed + +* Upgraded to `tokio-postgres` 0.6. + +### Added + +* Added `Config::notice_callback`, which can be used to provide a custom callback for notices. + +### Fixed + +* Fixed client shutdown to explicitly terminate the database session. + +## v0.17.5 - 2020-07-19 + +### Fixed + +* Fixed transactions to roll back immediately on drop. + +## v0.17.4 - 2020-07-03 + +### Added + +* Added support for `geo-types` 0.6. + +## v0.17.3 - 2020-05-01 + +### Fixed + +* Errors sent by the server will now be returned from `Client` methods rather than just being logged. + +### Added + +* Added `Transaction::savepoint`, which can be used to create a savepoint with a custom name. +* Added `Client::notifications`, which returns an interface to the notifications sent by the server. + +## v0.17.2 - 2020-03-05 + +### Added + +* Added `Debug` implementations for `Client`, `Row`, and `Column`. +* Added `time` 0.2 support. + +## v0.17.1 - 2020-01-31 + +### Added + +* Added `Client::build_transaction` to allow configuration of various transaction options. +* Added `Client::cancel_token`, which returns a separate owned object that can be used to cancel queries. +* Added accessors for `Config` fields. +* Added a `GenericClient` trait implemented for `Client` and `Transaction` and covering shared functionality. + +## v0.17.0 - 2019-12-23 + +### Changed + +* Each `Client` now has its own non-threaded tokio `Runtime` rather than sharing a global threaded `Runtime`. This + significantly improves performance by minimizing context switches and cross-thread synchronization. +* `Client::copy_in` now returns a writer rather than taking in a reader. +* `Client::query_raw` now returns a named type. +* `Client::copy_in` and `Client::copy_out` no longer take query parameters as PostgreSQL doesn't support them in COPY + queries. + +### Removed + +* Removed support for `uuid` 0.7. + +### Added + +* Added `Client::query_opt` for queries that are expected to return zero or one rows. +* Added binary copy support in the `binary_copy` module. +* The `fallible-iterator` crate is now publicly reexported. + +## v0.17.0-alpha.2 - 2019-11-27 + +### Changed + +* Changed `Config::executor` to `Config::spawner`. + +### Added + +* Added support for `uuid` 0.8. +* Added `Transaction::query_one`. + +## v0.17.0-alpha.1 - 2019-10-14 + +### Changed + +* Updated `tokio-postgres` to 0.5.0-alpha.1. + +## v0.16.0-rc.2 - 2019-06-29 + +### Fixed + +* Documentation fixes + +## v0.16.0-rc.1 - 2019-04-06 + +### Changed + +* `Connection` has been renamed to `Client`. +* The `Client` type is now a thin wrapper around the tokio-postgres nonblocking client. By default, this is handled + transparently by spawning connections onto an internal tokio `Runtime`, but this can also be controlled explicitly. +* The `ConnectParams` type and `IntoConnectParams` trait have been replaced by a builder-style `Config` type. + + Before: + ```rust + let params = ConnectParams::builder() + .user("postgres", None) + .build(Host::Tcp("localhost".to_string())) + .build(); + let conn = Connection::connect(params, &TlsMode::None)?; + ``` + After: + ```rust + let client = Client::configure() + .user("postgres") + .host("localhost") + .connect(NoTls)?; + ``` +* The TLS connection mode (e.g. `prefer`) is now part of the connection configuration instead of being passed in + separately. + + Before: + ```rust + let conn = Connection::connect("postgres://postgres@localhost", &TlsMode::Prefer(connector))?; + ``` + After: + ```rust + let client = Client::connect("postgres://postgres@localhost?sslmode=prefer", connector)?; + ``` +* `Client` and `Transaction` methods take `&mut self` rather than `&self`, and correct use of the active transaction is + verified at compile time rather than runtime. +* `Row` no longer borrows any data. +* `Statement` is now a "token" which is passed into methods on `Client` and `Transaction` and does not borrow the + client: + + Before: + ```rust + let statement = conn.prepare("SELECT * FROM foo WHERE bar = $1")?; + let rows = statement.query(&[&1i32])?; + ``` + After: + ```rust + let statement = client.prepare("SELECT * FROM foo WHERE bar = $1")?; + let rows = client.query(&statement, &[1i32])?; + ``` +* `Statement::lazy_query` has been replaced with `Transaction::bind`, which returns a `Portal` type that can be used + with `Transaction::query_portal`. +* `Statement::copy_in` and `Statement::copy_out` have been moved to `Client` and `Transaction`. +* `Client::copy_out` and `Transaction::copy_out` now return a `Read`er rather than consuming in a `Write`r. +* `Connection::batch_execute` and `Transaction::batch_execute` have been replaced with `Client::simple_query` and + `Transaction::simple_query`. +* The Cargo features enabling `ToSql` and `FromSql` implementations for external crates are now versioned. For example, + `with-uuid` is now `with-uuid-0_7`. This enables us to add support for new major versions of the crates in parallel + without breaking backwards compatibility. + +### Added + +* Connection string configuration now more fully mirrors libpq's syntax, and supports both URL-style and key-value style + strings. +* `FromSql` implementations can now borrow from the data buffer. In particular, this means that you can deserialize + values as `&str`. The `FromSqlOwned` trait can be used as a bound to restrict code to deserializing owned values. +* Added support for channel binding with SCRAM authentication. +* Added multi-host support in connection configuration. +* Added support for simple query requests returning row data. +* Added variants of query methods which return fallible iterators of values and avoid fully buffering the response in + memory. + +### Removed + +* The `with-openssl` and `with-native-tls` Cargo features have been removed. Use the `tokio-postgres-openssl` and + `tokio-postgres-native-tls` crates instead. +* The `with-rustc_serialize` and `with-time` Cargo features have been removed. Use `serde` and `SystemTime` or `chrono` + instead. +* The `Transaction::set_commit` and `Transaction::set_rollback` methods have been removed. The only way to commit a + transaction is to explicitly consume it via `Transaction::commit`. +* The `Rows` type has been removed; methods now return `Vec` instead. +* `Connection::prepare_cache` has been removed, as `Statement` is now `'static` and can be more easily cached + externally. +* Some other slightly more obscure features have been removed in the initial release. If you depended on them, please + file an issue and we can find the right design to add them back! + +## Older + +Look at the [release tags] for information about older releases. + +[release tags]: https://github.com/sfackler/rust-postgres/releases diff --git a/postgres/Cargo.toml b/postgres/Cargo.toml index 2a21d33a1..456bfb808 100644 --- a/postgres/Cargo.toml +++ b/postgres/Cargo.toml @@ -1,66 +1,51 @@ [package] name = "postgres" -version = "0.15.2" +version = "0.19.10" authors = ["Steven Fackler "] -license = "MIT" -description = "A native PostgreSQL driver" +edition = "2018" +license = "MIT OR Apache-2.0" +description = "A native, synchronous PostgreSQL client" repository = "https://github.com/sfackler/rust-postgres" readme = "../README.md" keywords = ["database", "postgres", "postgresql", "sql"] -include = ["src/*", "Cargo.toml", "LICENSE", "README.md", "THIRD_PARTY"] categories = ["database"] +[[bench]] +name = "bench" +harness = false + [package.metadata.docs.rs] -features = [ - "with-bit-vec-0.5", - "with-chrono-0.4", - "with-eui48-0.3", - "with-geo-0.10", - "with-serde_json-1", - "with-uuid-0.6", - "with-openssl", - "with-native-tls", -] +all-features = true [badges] circle-ci = { repository = "sfackler/rust-postgres" } -[lib] -name = "postgres" -path = "src/lib.rs" -test = false -bench = false - -[[test]] -name = "test" -path = "tests/test.rs" - [features] -"with-bit-vec-0.5" = ["postgres-shared/with-bit-vec-0.5"] -"with-chrono-0.4" = ["postgres-shared/with-chrono-0.4"] -"with-eui48-0.3" = ["postgres-shared/with-eui48-0.3"] -"with-geo-0.10" = ["postgres-shared/with-geo-0.10"] -"with-serde_json-1" = ["postgres-shared/with-serde_json-1"] -"with-uuid-0.6" = ["postgres-shared/with-uuid-0.6"] - -no-logging = [] +array-impls = ["tokio-postgres/array-impls"] +with-bit-vec-0_6 = ["tokio-postgres/with-bit-vec-0_6"] +with-chrono-0_4 = ["tokio-postgres/with-chrono-0_4"] +with-cidr-0_2 = ["tokio-postgres/with-cidr-0_2"] +with-cidr-0_3 = ["tokio-postgres/with-cidr-0_3"] +with-eui48-0_4 = ["tokio-postgres/with-eui48-0_4"] +with-eui48-1 = ["tokio-postgres/with-eui48-1"] +with-geo-types-0_6 = ["tokio-postgres/with-geo-types-0_6"] +with-geo-types-0_7 = ["tokio-postgres/with-geo-types-0_7"] +with-jiff-0_1 = ["tokio-postgres/with-jiff-0_1"] +with-jiff-0_2 = ["tokio-postgres/with-jiff-0_2"] +with-serde_json-1 = ["tokio-postgres/with-serde_json-1"] +with-smol_str-01 = ["tokio-postgres/with-smol_str-01"] +with-uuid-0_8 = ["tokio-postgres/with-uuid-0_8"] +with-uuid-1 = ["tokio-postgres/with-uuid-1"] +with-time-0_2 = ["tokio-postgres/with-time-0_2"] +with-time-0_3 = ["tokio-postgres/with-time-0_3"] [dependencies] -bytes = "0.4" -fallible-iterator = "0.1.3" +bytes = "1.0" +fallible-iterator = "0.2" +futures-util = { version = "0.3.14", features = ["sink"] } log = "0.4" -socket2 = { version = "0.3.5", features = ["unix"] } - -postgres-protocol = { version = "0.3.0", path = "../postgres-protocol" } -postgres-shared = { version = "0.4.1", path = "../postgres-shared" } +tokio-postgres = { version = "0.7.13", path = "../tokio-postgres" } +tokio = { version = "1.0", features = ["rt", "time"] } [dev-dependencies] -hex = "0.3" -url = "1.0" - -bit-vec = "0.5" -chrono = "0.4" -eui48 = "0.3" -geo = "0.10" -serde_json = "1.0" -uuid = "0.6" +criterion = "0.6" diff --git a/postgres/LICENSE-APACHE b/postgres/LICENSE-APACHE new file mode 120000 index 000000000..965b606f3 --- /dev/null +++ b/postgres/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/postgres/LICENSE-MIT b/postgres/LICENSE-MIT new file mode 120000 index 000000000..76219eb72 --- /dev/null +++ b/postgres/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/postgres/benches/bench.rs b/postgres/benches/bench.rs index 64e214069..474d83591 100644 --- a/postgres/benches/bench.rs +++ b/postgres/benches/bench.rs @@ -1,28 +1,17 @@ -#![feature(test)] -extern crate test; -extern crate postgres; +use criterion::{criterion_group, criterion_main, Criterion}; +use postgres::{Client, NoTls}; -use postgres::{Connection, TlsMode}; +// spawned: 249us 252us 255us +// local: 214us 216us 219us +fn query_prepared(c: &mut Criterion) { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); -#[bench] -fn bench_naiive_execute(b: &mut test::Bencher) { - let conn = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None).unwrap(); - conn.execute("CREATE TEMPORARY TABLE foo (id INT)", &[]) - .unwrap(); + let stmt = client.prepare("SELECT $1::INT8").unwrap(); - b.iter(|| { - let stmt = conn.prepare("UPDATE foo SET id = 1").unwrap(); - let out = stmt.execute(&[]).unwrap(); - stmt.finish().unwrap(); - out + c.bench_function("query_prepared", move |b| { + b.iter(|| client.query(&stmt, &[&1i64]).unwrap()) }); } -#[bench] -fn bench_execute(b: &mut test::Bencher) { - let conn = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None).unwrap(); - conn.execute("CREATE TEMPORARY TABLE foo (id INT)", &[]) - .unwrap(); - - b.iter(|| conn.execute("UPDATE foo SET id = 1", &[]).unwrap()); -} +criterion_group!(group, query_prepared); +criterion_main!(group); diff --git a/postgres/src/binary_copy.rs b/postgres/src/binary_copy.rs new file mode 100644 index 000000000..1c4eb7d3b --- /dev/null +++ b/postgres/src/binary_copy.rs @@ -0,0 +1,97 @@ +//! Utilities for working with the PostgreSQL binary copy format. + +use crate::connection::ConnectionRef; +use crate::types::{BorrowToSql, ToSql, Type}; +use crate::{CopyInWriter, CopyOutReader, Error}; +use fallible_iterator::FallibleIterator; +use futures_util::StreamExt; +use std::pin::Pin; +#[doc(inline)] +pub use tokio_postgres::binary_copy::BinaryCopyOutRow; +use tokio_postgres::binary_copy::{self, BinaryCopyOutStream}; + +/// A type which serializes rows into the PostgreSQL binary copy format. +/// +/// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted. +pub struct BinaryCopyInWriter<'a> { + connection: ConnectionRef<'a>, + sink: Pin>, +} + +impl<'a> BinaryCopyInWriter<'a> { + /// Creates a new writer which will write rows of the provided types. + pub fn new(writer: CopyInWriter<'a>, types: &[Type]) -> BinaryCopyInWriter<'a> { + let stream = writer + .sink + .into_unpinned() + .expect("writer has already been written to"); + + BinaryCopyInWriter { + connection: writer.connection, + sink: Box::pin(binary_copy::BinaryCopyInWriter::new(stream, types)), + } + } + + /// Writes a single row. + /// + /// # Panics + /// + /// Panics if the number of values provided does not match the number expected. + pub fn write(&mut self, values: &[&(dyn ToSql + Sync)]) -> Result<(), Error> { + self.connection.block_on(self.sink.as_mut().write(values)) + } + + /// A maximally-flexible version of `write`. + /// + /// # Panics + /// + /// Panics if the number of values provided does not match the number expected. + pub fn write_raw(&mut self, values: I) -> Result<(), Error> + where + P: BorrowToSql, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + self.connection + .block_on(self.sink.as_mut().write_raw(values)) + } + + /// Completes the copy, returning the number of rows added. + /// + /// This method *must* be used to complete the copy process. If it is not, the copy will be aborted. + pub fn finish(mut self) -> Result { + self.connection.block_on(self.sink.as_mut().finish()) + } +} + +/// An iterator of rows deserialized from the PostgreSQL binary copy format. +pub struct BinaryCopyOutIter<'a> { + connection: ConnectionRef<'a>, + stream: Pin>, +} + +impl<'a> BinaryCopyOutIter<'a> { + /// Creates a new iterator from a raw copy out reader and the types of the columns being returned. + pub fn new(reader: CopyOutReader<'a>, types: &[Type]) -> BinaryCopyOutIter<'a> { + let stream = reader + .stream + .into_unpinned() + .expect("reader has already been read from"); + + BinaryCopyOutIter { + connection: reader.connection, + stream: Box::pin(BinaryCopyOutStream::new(stream, types)), + } + } +} + +impl FallibleIterator for BinaryCopyOutIter<'_> { + type Item = BinaryCopyOutRow; + type Error = Error; + + fn next(&mut self) -> Result, Error> { + let stream = &mut self.stream; + self.connection + .block_on(async { stream.next().await.transpose() }) + } +} diff --git a/postgres/src/cancel_token.rs b/postgres/src/cancel_token.rs new file mode 100644 index 000000000..be24edcc8 --- /dev/null +++ b/postgres/src/cancel_token.rs @@ -0,0 +1,35 @@ +use tokio::runtime; +use tokio_postgres::tls::MakeTlsConnect; +use tokio_postgres::{Error, Socket}; + +/// The capability to request cancellation of in-progress queries on a +/// connection. +#[derive(Clone)] +pub struct CancelToken(tokio_postgres::CancelToken); + +impl CancelToken { + pub(crate) fn new(inner: tokio_postgres::CancelToken) -> CancelToken { + CancelToken(inner) + } + + /// Attempts to cancel the in-progress query on the connection associated + /// with this `CancelToken`. + /// + /// The server provides no information about whether a cancellation attempt was successful or not. An error will + /// only be returned if the client was unable to connect to the database. + /// + /// Cancellation is inherently racy. There is no guarantee that the + /// cancellation request will reach the server before the query terminates + /// normally, or that the connection associated with this token is still + /// active. + pub fn cancel_query(&self, tls: T) -> Result<(), Error> + where + T: MakeTlsConnect, + { + runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap() // FIXME don't unwrap + .block_on(self.0.cancel_query(tls)) + } +} diff --git a/postgres/src/client.rs b/postgres/src/client.rs new file mode 100644 index 000000000..42ce6dec9 --- /dev/null +++ b/postgres/src/client.rs @@ -0,0 +1,624 @@ +use crate::connection::Connection; +use crate::{ + CancelToken, Config, CopyInWriter, CopyOutReader, Notifications, RowIter, Statement, + ToStatement, Transaction, TransactionBuilder, +}; +use std::task::Poll; +use std::time::Duration; +use tokio_postgres::tls::{MakeTlsConnect, TlsConnect}; +use tokio_postgres::types::{BorrowToSql, ToSql, Type}; +use tokio_postgres::{Error, Row, SimpleQueryMessage, Socket}; + +/// A synchronous PostgreSQL client. +pub struct Client { + connection: Connection, + client: tokio_postgres::Client, +} + +impl Drop for Client { + fn drop(&mut self) { + let _ = self.close_inner(); + } +} + +impl Client { + pub(crate) fn new(connection: Connection, client: tokio_postgres::Client) -> Client { + Client { connection, client } + } + + /// A convenience function which parses a configuration string into a `Config` and then connects to the database. + /// + /// See the documentation for [`Config`] for information about the connection syntax. + /// + /// [`Config`]: config/struct.Config.html + pub fn connect(params: &str, tls_mode: T) -> Result + where + T: MakeTlsConnect + 'static + Send, + T::TlsConnect: Send, + T::Stream: Send, + >::Future: Send, + { + params.parse::()?.connect(tls_mode) + } + + /// Returns a new `Config` object which can be used to configure and connect to a database. + pub fn configure() -> Config { + Config::new() + } + + /// Executes a statement, returning the number of rows modified. + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list + /// provided, 1-indexed. + /// + /// If the statement does not modify any rows (e.g. `SELECT`), 0 is returned. + /// + /// The `query` argument can either be a `Statement`, or a raw query string. If the same statement will be + /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front + /// with the `prepare` method. + /// + /// # Example + /// + /// ```no_run + /// use postgres::{Client, NoTls}; + /// + /// # fn main() -> Result<(), postgres::Error> { + /// let mut client = Client::connect("host=localhost user=postgres", NoTls)?; + /// + /// let bar = 1i32; + /// let baz = true; + /// let rows_updated = client.execute( + /// "UPDATE foo SET bar = $1 WHERE baz = $2", + /// &[&bar, &baz], + /// )?; + /// + /// println!("{} rows updated", rows_updated); + /// # Ok(()) + /// # } + /// ``` + pub fn execute(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result + where + T: ?Sized + ToStatement, + { + self.connection.block_on(self.client.execute(query, params)) + } + + /// Executes a statement, returning the resulting rows. + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list + /// provided, 1-indexed. + /// + /// The `query` argument can either be a `Statement`, or a raw query string. If the same statement will be + /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front + /// with the `prepare` method. + /// + /// # Examples + /// + /// ```no_run + /// use postgres::{Client, NoTls}; + /// + /// # fn main() -> Result<(), postgres::Error> { + /// let mut client = Client::connect("host=localhost user=postgres", NoTls)?; + /// + /// let baz = true; + /// for row in client.query("SELECT foo FROM bar WHERE baz = $1", &[&baz])? { + /// let foo: i32 = row.get("foo"); + /// println!("foo: {}", foo); + /// } + /// # Ok(()) + /// # } + /// ``` + pub fn query(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result, Error> + where + T: ?Sized + ToStatement, + { + self.connection.block_on(self.client.query(query, params)) + } + + /// Executes a statement which returns a single row, returning it. + /// + /// Returns an error if the query does not return exactly one row. + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list + /// provided, 1-indexed. + /// + /// The `query` argument can either be a `Statement`, or a raw query string. If the same statement will be + /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front + /// with the `prepare` method. + /// + /// # Examples + /// + /// ```no_run + /// use postgres::{Client, NoTls}; + /// + /// # fn main() -> Result<(), postgres::Error> { + /// let mut client = Client::connect("host=localhost user=postgres", NoTls)?; + /// + /// let baz = true; + /// let row = client.query_one("SELECT foo FROM bar WHERE baz = $1", &[&baz])?; + /// let foo: i32 = row.get("foo"); + /// println!("foo: {}", foo); + /// # Ok(()) + /// # } + /// ``` + pub fn query_one(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result + where + T: ?Sized + ToStatement, + { + self.connection + .block_on(self.client.query_one(query, params)) + } + + /// Executes a statement which returns zero or one rows, returning it. + /// + /// Returns an error if the query returns more than one row. + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list + /// provided, 1-indexed. + /// + /// The `query` argument can either be a `Statement`, or a raw query string. If the same statement will be + /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front + /// with the `prepare` method. + /// + /// # Examples + /// + /// ```no_run + /// use postgres::{Client, NoTls}; + /// + /// # fn main() -> Result<(), postgres::Error> { + /// let mut client = Client::connect("host=localhost user=postgres", NoTls)?; + /// + /// let baz = true; + /// let row = client.query_opt("SELECT foo FROM bar WHERE baz = $1", &[&baz])?; + /// match row { + /// Some(row) => { + /// let foo: i32 = row.get("foo"); + /// println!("foo: {}", foo); + /// } + /// None => println!("no matching foo"), + /// } + /// # Ok(()) + /// # } + /// ``` + pub fn query_opt( + &mut self, + query: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result, Error> + where + T: ?Sized + ToStatement, + { + self.connection + .block_on(self.client.query_opt(query, params)) + } + + /// A maximally-flexible version of `query`. + /// + /// It takes an iterator of parameters rather than a slice, and returns an iterator of rows rather than collecting + /// them into an array. + /// + /// # Examples + /// + /// ```no_run + /// use postgres::{Client, NoTls}; + /// use fallible_iterator::FallibleIterator; + /// use std::iter; + /// + /// # fn main() -> Result<(), postgres::Error> { + /// let mut client = Client::connect("host=localhost user=postgres", NoTls)?; + /// + /// let baz = true; + /// let mut it = client.query_raw("SELECT foo FROM bar WHERE baz = $1", iter::once(baz))?; + /// + /// while let Some(row) = it.next()? { + /// let foo: i32 = row.get("foo"); + /// println!("foo: {}", foo); + /// } + /// # Ok(()) + /// # } + /// ``` + /// + /// If you have a type like `Vec` where `T: ToSql` Rust will not know how to use it as params. To get around + /// this the type must explicitly be converted to `&dyn ToSql`. + /// + /// ```no_run + /// # use postgres::{Client, NoTls}; + /// use postgres::types::ToSql; + /// use fallible_iterator::FallibleIterator; + /// # fn main() -> Result<(), postgres::Error> { + /// # let mut client = Client::connect("host=localhost user=postgres", NoTls)?; + /// + /// let params: Vec = vec![ + /// "first param".into(), + /// "second param".into(), + /// ]; + /// let mut it = client.query_raw( + /// "SELECT foo FROM bar WHERE biz = $1 AND baz = $2", + /// params, + /// )?; + /// + /// while let Some(row) = it.next()? { + /// let foo: i32 = row.get("foo"); + /// println!("foo: {}", foo); + /// } + /// # Ok(()) + /// # } + /// ``` + pub fn query_raw(&mut self, query: &T, params: I) -> Result, Error> + where + T: ?Sized + ToStatement, + P: BorrowToSql, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + let stream = self + .connection + .block_on(self.client.query_raw(query, params))?; + Ok(RowIter::new(self.connection.as_ref(), stream)) + } + + /// Like `query`, but requires the types of query parameters to be explicitly specified. + /// + /// Compared to `query`, this method allows performing queries without three round trips (for + /// prepare, execute, and close) by requiring the caller to specify parameter values along with + /// their Postgres type. Thus, this is suitable in environments where prepared statements aren't + /// supported (such as Cloudflare Workers with Hyperdrive). + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the + /// parameter of the list provided, 1-indexed. + pub fn query_typed( + &mut self, + query: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result, Error> { + self.connection + .block_on(self.client.query_typed(query, params)) + } + + /// The maximally flexible version of [`query_typed`]. + /// + /// Compared to `query`, this method allows performing queries without three round trips (for + /// prepare, execute, and close) by requiring the caller to specify parameter values along with + /// their Postgres type. Thus, this is suitable in environments where prepared statements aren't + /// supported (such as Cloudflare Workers with Hyperdrive). + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the + /// parameter of the list provided, 1-indexed. + /// + /// [`query_typed`]: #method.query_typed + /// + /// # Examples + /// ```no_run + /// # use postgres::{Client, NoTls}; + /// use postgres::types::{ToSql, Type}; + /// use fallible_iterator::FallibleIterator; + /// # fn main() -> Result<(), postgres::Error> { + /// # let mut client = Client::connect("host=localhost user=postgres", NoTls)?; + /// + /// let params: Vec<(String, Type)> = vec![ + /// ("first param".into(), Type::TEXT), + /// ("second param".into(), Type::TEXT), + /// ]; + /// let mut it = client.query_typed_raw( + /// "SELECT foo FROM bar WHERE biz = $1 AND baz = $2", + /// params, + /// )?; + /// + /// while let Some(row) = it.next()? { + /// let foo: i32 = row.get("foo"); + /// println!("foo: {}", foo); + /// } + /// # Ok(()) + /// # } + /// ``` + pub fn query_typed_raw(&mut self, query: &str, params: I) -> Result, Error> + where + P: BorrowToSql, + I: IntoIterator, + { + let stream = self + .connection + .block_on(self.client.query_typed_raw(query, params))?; + Ok(RowIter::new(self.connection.as_ref(), stream)) + } + + /// Creates a new prepared statement. + /// + /// Prepared statements can be executed repeatedly, and may contain query parameters (indicated by `$1`, `$2`, etc), + /// which are set when executed. Prepared statements can only be used with the connection that created them. + /// + /// # Examples + /// + /// ```no_run + /// use postgres::{Client, NoTls}; + /// + /// # fn main() -> Result<(), postgres::Error> { + /// let mut client = Client::connect("host=localhost user=postgres", NoTls)?; + /// + /// let statement = client.prepare("SELECT name FROM people WHERE id = $1")?; + /// + /// for id in 0..10 { + /// let rows = client.query(&statement, &[&id])?; + /// let name: &str = rows[0].get(0); + /// println!("name: {}", name); + /// } + /// # Ok(()) + /// # } + /// ``` + pub fn prepare(&mut self, query: &str) -> Result { + self.connection.block_on(self.client.prepare(query)) + } + + /// Like `prepare`, but allows the types of query parameters to be explicitly specified. + /// + /// The list of types may be smaller than the number of parameters - the types of the remaining parameters will be + /// inferred. For example, `client.prepare_typed(query, &[])` is equivalent to `client.prepare(query)`. + /// + /// # Examples + /// + /// ```no_run + /// use postgres::{Client, NoTls}; + /// use postgres::types::Type; + /// + /// # fn main() -> Result<(), postgres::Error> { + /// let mut client = Client::connect("host=localhost user=postgres", NoTls)?; + /// + /// let statement = client.prepare_typed( + /// "SELECT name FROM people WHERE id = $1", + /// &[Type::INT8], + /// )?; + /// + /// for id in 0..10 { + /// let rows = client.query(&statement, &[&id])?; + /// let name: &str = rows[0].get(0); + /// println!("name: {}", name); + /// } + /// # Ok(()) + /// # } + /// ``` + pub fn prepare_typed(&mut self, query: &str, types: &[Type]) -> Result { + self.connection + .block_on(self.client.prepare_typed(query, types)) + } + + /// Executes a `COPY FROM STDIN` statement, returning the number of rows created. + /// + /// The `query` argument can either be a `Statement`, or a raw query string. The data in the provided reader is + /// passed along to the server verbatim; it is the caller's responsibility to ensure it uses the proper format. + /// PostgreSQL does not support parameters in `COPY` statements, so this method does not take any. + /// + /// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted. + /// + /// # Examples + /// + /// ```no_run + /// use postgres::{Client, NoTls}; + /// use std::io::Write; + /// + /// # fn main() -> Result<(), Box> { + /// let mut client = Client::connect("host=localhost user=postgres", NoTls)?; + /// + /// let mut writer = client.copy_in("COPY people FROM stdin")?; + /// writer.write_all(b"1\tjohn\n2\tjane\n")?; + /// writer.finish()?; + /// # Ok(()) + /// # } + /// ``` + pub fn copy_in(&mut self, query: &T) -> Result, Error> + where + T: ?Sized + ToStatement, + { + let sink = self.connection.block_on(self.client.copy_in(query))?; + Ok(CopyInWriter::new(self.connection.as_ref(), sink)) + } + + /// Executes a `COPY TO STDOUT` statement, returning a reader of the resulting data. + /// + /// The `query` argument can either be a `Statement`, or a raw query string. PostgreSQL does not support parameters + /// in `COPY` statements, so this method does not take any. + /// + /// # Examples + /// + /// ```no_run + /// use postgres::{Client, NoTls}; + /// use std::io::Read; + /// + /// # fn main() -> Result<(), Box> { + /// let mut client = Client::connect("host=localhost user=postgres", NoTls)?; + /// + /// let mut reader = client.copy_out("COPY people TO stdout")?; + /// let mut buf = vec![]; + /// reader.read_to_end(&mut buf)?; + /// # Ok(()) + /// # } + /// ``` + pub fn copy_out(&mut self, query: &T) -> Result, Error> + where + T: ?Sized + ToStatement, + { + let stream = self.connection.block_on(self.client.copy_out(query))?; + Ok(CopyOutReader::new(self.connection.as_ref(), stream)) + } + + /// Executes a sequence of SQL statements using the simple query protocol. + /// + /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that + /// point. The simple query protocol returns the values in rows as strings rather than in their binary encodings, + /// so the associated row type doesn't work with the `FromSql` trait. Rather than simply returning the rows, this + /// method returns a sequence of an enum which indicates either the completion of one of the commands, or a row of + /// data. This preserves the framing between the separate statements in the request. + /// + /// This is a simple convenience method over `simple_query_iter`. + /// + /// # Warning + /// + /// Prepared statements should be used for any query which contains user-specified data, as they provided the + /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass + /// them to this method! + pub fn simple_query(&mut self, query: &str) -> Result, Error> { + self.connection.block_on(self.client.simple_query(query)) + } + + /// Validates the connection by performing a simple no-op query. + /// + /// If the specified timeout is reached before the backend responds, an error will be returned. + pub fn is_valid(&mut self, timeout: Duration) -> Result<(), Error> { + let inner_client = &self.client; + self.connection.block_on(async { + let trivial_query = inner_client.simple_query(""); + tokio::time::timeout(timeout, trivial_query) + .await + .map_err(|_| Error::__private_api_timeout())? + .map(|_| ()) + }) + } + + /// Executes a sequence of SQL statements using the simple query protocol. + /// + /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that + /// point. This is intended for use when, for example, initializing a database schema. + /// + /// # Warning + /// + /// Prepared statements should be use for any query which contains user-specified data, as they provided the + /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass + /// them to this method! + pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> { + self.connection.block_on(self.client.batch_execute(query)) + } + + /// Begins a new database transaction. + /// + /// The transaction will roll back by default - use the `commit` method to commit it. + /// + /// # Examples + /// + /// ```no_run + /// use postgres::{Client, NoTls}; + /// + /// # fn main() -> Result<(), postgres::Error> { + /// let mut client = Client::connect("host=localhost user=postgres", NoTls)?; + /// + /// let mut transaction = client.transaction()?; + /// transaction.execute("UPDATE foo SET bar = 10", &[])?; + /// // ... + /// + /// transaction.commit()?; + /// # Ok(()) + /// # } + /// ``` + pub fn transaction(&mut self) -> Result, Error> { + let transaction = self.connection.block_on(self.client.transaction())?; + Ok(Transaction::new(self.connection.as_ref(), transaction)) + } + + /// Returns a builder for a transaction with custom settings. + /// + /// Unlike the `transaction` method, the builder can be used to control the transaction's isolation level and other + /// attributes. + /// + /// # Examples + /// + /// ```no_run + /// use postgres::{Client, IsolationLevel, NoTls}; + /// + /// # fn main() -> Result<(), postgres::Error> { + /// let mut client = Client::connect("host=localhost user=postgres", NoTls)?; + /// + /// let mut transaction = client.build_transaction() + /// .isolation_level(IsolationLevel::RepeatableRead) + /// .start()?; + /// transaction.execute("UPDATE foo SET bar = 10", &[])?; + /// // ... + /// + /// transaction.commit()?; + /// # Ok(()) + /// # } + /// ``` + pub fn build_transaction(&mut self) -> TransactionBuilder<'_> { + TransactionBuilder::new(self.connection.as_ref(), self.client.build_transaction()) + } + + /// Returns a structure providing access to asynchronous notifications. + /// + /// Use the `LISTEN` command to register this connection for notifications. + pub fn notifications(&mut self) -> Notifications<'_> { + Notifications::new(self.connection.as_ref()) + } + + /// Constructs a cancellation token that can later be used to request cancellation of a query running on this + /// connection. + /// + /// # Examples + /// + /// ```no_run + /// use postgres::{Client, NoTls}; + /// use postgres::error::SqlState; + /// use std::thread; + /// use std::time::Duration; + /// + /// # fn main() -> Result<(), Box> { + /// let mut client = Client::connect("host=localhost user=postgres", NoTls)?; + /// + /// let cancel_token = client.cancel_token(); + /// + /// thread::spawn(move || { + /// // Abort the query after 5s. + /// thread::sleep(Duration::from_secs(5)); + /// let _ = cancel_token.cancel_query(NoTls); + /// }); + /// + /// match client.simple_query("SELECT long_running_query()") { + /// Err(e) if e.code() == Some(&SqlState::QUERY_CANCELED) => { + /// // Handle canceled query. + /// } + /// Err(err) => return Err(err.into()), + /// Ok(rows) => { + /// // ... + /// } + /// } + /// // ... + /// + /// # Ok(()) + /// # } + /// ``` + pub fn cancel_token(&self) -> CancelToken { + CancelToken::new(self.client.cancel_token()) + } + + /// Clears the client's type information cache. + /// + /// When user-defined types are used in a query, the client loads their definitions from the database and caches + /// them for the lifetime of the client. If those definitions are changed in the database, this method can be used + /// to flush the local cache and allow the new, updated definitions to be loaded. + pub fn clear_type_cache(&self) { + self.client.clear_type_cache(); + } + + /// Determines if the client's connection has already closed. + /// + /// If this returns `true`, the client is no longer usable. + pub fn is_closed(&self) -> bool { + self.client.is_closed() + } + + /// Closes the client's connection to the server. + /// + /// This is equivalent to `Client`'s `Drop` implementation, except that it returns any error encountered to the + /// caller. + pub fn close(mut self) -> Result<(), Error> { + self.close_inner() + } + + fn close_inner(&mut self) -> Result<(), Error> { + self.client.__private_api_close(); + + self.connection.poll_block_on(|_, _, done| { + if done { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + }) + } +} diff --git a/postgres/src/config.rs b/postgres/src/config.rs new file mode 100644 index 000000000..c7f932ba7 --- /dev/null +++ b/postgres/src/config.rs @@ -0,0 +1,489 @@ +//! Connection configuration. + +#![allow(clippy::doc_overindented_list_items)] + +use crate::connection::Connection; +use crate::Client; +use log::info; +use std::fmt; +use std::net::IpAddr; +use std::path::Path; +use std::str::FromStr; +use std::sync::Arc; +use std::time::Duration; +use tokio::runtime; +#[doc(inline)] +pub use tokio_postgres::config::{ + ChannelBinding, Host, LoadBalanceHosts, SslMode, SslNegotiation, TargetSessionAttrs, +}; +use tokio_postgres::error::DbError; +use tokio_postgres::tls::{MakeTlsConnect, TlsConnect}; +use tokio_postgres::{Error, Socket}; + +/// Connection configuration. +/// +/// Configuration can be parsed from libpq-style connection strings. These strings come in two formats: +/// +/// # Key-Value +/// +/// This format consists of space-separated key-value pairs. Values which are either the empty string or contain +/// whitespace should be wrapped in `'`. `'` and `\` characters should be backslash-escaped. +/// +/// ## Keys +/// +/// * `user` - The username to authenticate with. Defaults to the user executing this process. +/// * `password` - The password to authenticate with. +/// * `dbname` - The name of the database to connect to. Defaults to the username. +/// * `options` - Command line options used to configure the server. +/// * `application_name` - Sets the `application_name` parameter on the server. +/// * `sslmode` - Controls usage of TLS. If set to `disable`, TLS will not be used. If set to `prefer`, TLS will be used +/// if available, but not used otherwise. If set to `require`, TLS will be forced to be used. Defaults to `prefer`. +/// * `host` - The host to connect to. On Unix platforms, if the host starts with a `/` character it is treated as the +/// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts +/// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting +/// with the `connect` method. +/// * `sslnegotiation` - TLS negotiation method. If set to `direct`, the client will perform direct TLS handshake, this only works for PostgreSQL 17 and newer. +/// Note that you will need to setup ALPN of TLS client configuration to `postgresql` when using direct TLS. +/// If set to `postgres`, the default value, it follows original postgres wire protocol to perform the negotiation. +/// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format, +/// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses. +/// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address, +/// or if host specifies an IP address, that value will be used directly. +/// Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications +/// with time constraints. However, a host name is required for TLS certificate verification. +/// Specifically: +/// * If `hostaddr` is specified without `host`, the value for `hostaddr` gives the server network address. +/// The connection attempt will fail if the authentication method requires a host name; +/// * If `host` is specified without `hostaddr`, a host name lookup occurs; +/// * If both `host` and `hostaddr` are specified, the value for `hostaddr` gives the server network address. +/// The value for `host` is ignored unless the authentication method requires it, +/// in which case it will be used as the host name. +/// * `port` - The port to connect to. Multiple ports can be specified, separated by commas. The number of ports must be +/// either 1, in which case it will be used for all hosts, or the same as the number of hosts. Defaults to 5432 if +/// omitted or the empty string. +/// * `connect_timeout` - The time limit in seconds applied to each socket-level connection attempt. Note that hostnames +/// can resolve to multiple IP addresses, and this limit is applied to each address. Defaults to no timeout. +/// * `tcp_user_timeout` - The time limit that transmitted data may remain unacknowledged before a connection is forcibly closed. +/// This is ignored for Unix domain socket connections. It is only supported on systems where TCP_USER_TIMEOUT is available +/// and will default to the system default if omitted or set to 0; on other systems, it has no effect. +/// * `keepalives` - Controls the use of TCP keepalive. A value of 0 disables keepalive and nonzero integers enable it. +/// This option is ignored when connecting with Unix sockets. Defaults to on. +/// * `keepalives_idle` - The number of seconds of inactivity after which a keepalive message is sent to the server. +/// This option is ignored when connecting with Unix sockets. Defaults to 2 hours. +/// * `keepalives_interval` - The time interval between TCP keepalive probes. +/// This option is ignored when connecting with Unix sockets. +/// * `keepalives_retries` - The maximum number of TCP keepalive probes that will be sent before dropping a connection. +/// This option is ignored when connecting with Unix sockets. +/// * `target_session_attrs` - Specifies requirements of the session. If set to `read-write`, the client will check that +/// the `transaction_read_write` session parameter is set to `on`. This can be used to connect to the primary server +/// in a database cluster as opposed to the secondary read-only mirrors. Defaults to `all`. +/// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel +/// binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise. +/// If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`. +/// * `load_balance_hosts` - Controls the order in which the client tries to connect to the available hosts and +/// addresses. Once a connection attempt is successful no other hosts and addresses will be tried. This parameter +/// is typically used in combination with multiple host names or a DNS record that returns multiple IPs. If set to +/// `disable`, hosts and addresses will be tried in the order provided. If set to `random`, hosts will be tried +/// in a random order, and the IP addresses resolved from a hostname will also be tried in a random order. Defaults +/// to `disable`. +/// +/// ## Examples +/// +/// ```not_rust +/// host=localhost user=postgres connect_timeout=10 keepalives=0 +/// ``` +/// +/// ```not_rust +/// host=/var/lib/postgresql,localhost port=1234 user=postgres password='password with spaces' +/// ``` +/// +/// ```not_rust +/// host=host1,host2,host3 port=1234,,5678 hostaddr=127.0.0.1,127.0.0.2,127.0.0.3 user=postgres target_session_attrs=read-write +/// ``` +/// +/// ```not_rust +/// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write +/// ``` +/// +/// # Url +/// +/// This format resembles a URL with a scheme of either `postgres://` or `postgresql://`. All components are optional, +/// and the format accepts query parameters for all of the key-value pairs described in the section above. Multiple +/// host/port pairs can be comma-separated. Unix socket paths in the host section of the URL should be percent-encoded, +/// as the path component of the URL specifies the database name. +/// +/// ## Examples +/// +/// ```not_rust +/// postgresql://user@localhost +/// ``` +/// +/// ```not_rust +/// postgresql://user:password@%2Fvar%2Flib%2Fpostgresql/mydb?connect_timeout=10 +/// ``` +/// +/// ```not_rust +/// postgresql://user@host1:1234,host2,host3:5678?target_session_attrs=read-write +/// ``` +/// +/// ```not_rust +/// postgresql:///mydb?user=user&host=/var/lib/postgresql +/// ``` +#[derive(Clone)] +pub struct Config { + config: tokio_postgres::Config, + notice_callback: Arc, +} + +impl fmt::Debug for Config { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Config") + .field("config", &self.config) + .finish() + } +} + +impl Default for Config { + fn default() -> Config { + Config::new() + } +} + +impl Config { + /// Creates a new configuration. + pub fn new() -> Config { + tokio_postgres::Config::new().into() + } + + /// Sets the user to authenticate with. + /// + /// If the user is not set, then this defaults to the user executing this process. + pub fn user(&mut self, user: &str) -> &mut Config { + self.config.user(user); + self + } + + /// Gets the user to authenticate with, if one has been configured with + /// the `user` method. + pub fn get_user(&self) -> Option<&str> { + self.config.get_user() + } + + /// Sets the password to authenticate with. + pub fn password(&mut self, password: T) -> &mut Config + where + T: AsRef<[u8]>, + { + self.config.password(password); + self + } + + /// Gets the password to authenticate with, if one has been configured with + /// the `password` method. + pub fn get_password(&self) -> Option<&[u8]> { + self.config.get_password() + } + + /// Sets the name of the database to connect to. + /// + /// Defaults to the user. + pub fn dbname(&mut self, dbname: &str) -> &mut Config { + self.config.dbname(dbname); + self + } + + /// Gets the name of the database to connect to, if one has been configured + /// with the `dbname` method. + pub fn get_dbname(&self) -> Option<&str> { + self.config.get_dbname() + } + + /// Sets command line options used to configure the server. + pub fn options(&mut self, options: &str) -> &mut Config { + self.config.options(options); + self + } + + /// Gets the command line options used to configure the server, if the + /// options have been set with the `options` method. + pub fn get_options(&self) -> Option<&str> { + self.config.get_options() + } + + /// Sets the value of the `application_name` runtime parameter. + pub fn application_name(&mut self, application_name: &str) -> &mut Config { + self.config.application_name(application_name); + self + } + + /// Gets the value of the `application_name` runtime parameter, if it has + /// been set with the `application_name` method. + pub fn get_application_name(&self) -> Option<&str> { + self.config.get_application_name() + } + + /// Sets the SSL configuration. + /// + /// Defaults to `prefer`. + pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config { + self.config.ssl_mode(ssl_mode); + self + } + + /// Gets the SSL configuration. + pub fn get_ssl_mode(&self) -> SslMode { + self.config.get_ssl_mode() + } + + /// Sets the SSL negotiation method + pub fn ssl_negotiation(&mut self, ssl_negotiation: SslNegotiation) -> &mut Config { + self.config.ssl_negotiation(ssl_negotiation); + self + } + + /// Gets the SSL negotiation method + pub fn get_ssl_negotiation(&self) -> SslNegotiation { + self.config.get_ssl_negotiation() + } + + /// Adds a host to the configuration. + /// + /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix + /// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets. + /// There must be either no hosts, or the same number of hosts as hostaddrs. + pub fn host(&mut self, host: &str) -> &mut Config { + self.config.host(host); + self + } + + /// Gets the hosts that have been added to the configuration with `host`. + pub fn get_hosts(&self) -> &[Host] { + self.config.get_hosts() + } + + /// Gets the hostaddrs that have been added to the configuration with `hostaddr`. + pub fn get_hostaddrs(&self) -> &[IpAddr] { + self.config.get_hostaddrs() + } + + /// Adds a Unix socket host to the configuration. + /// + /// Unlike `host`, this method allows non-UTF8 paths. + #[cfg(unix)] + pub fn host_path(&mut self, host: T) -> &mut Config + where + T: AsRef, + { + self.config.host_path(host); + self + } + + /// Adds a hostaddr to the configuration. + /// + /// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order. + /// There must be either no hostaddrs, or the same number of hostaddrs as hosts. + pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config { + self.config.hostaddr(hostaddr); + self + } + + /// Adds a port to the configuration. + /// + /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which + /// case the default of 5432 is used, a single port, in which it is used for all hosts, or the same number of ports + /// as hosts. + pub fn port(&mut self, port: u16) -> &mut Config { + self.config.port(port); + self + } + + /// Gets the ports that have been added to the configuration with `port`. + pub fn get_ports(&self) -> &[u16] { + self.config.get_ports() + } + + /// Sets the timeout applied to socket-level connection attempts. + /// + /// Note that hostnames can resolve to multiple IP addresses, and this timeout will apply to each address of each + /// host separately. Defaults to no limit. + pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config { + self.config.connect_timeout(connect_timeout); + self + } + + /// Gets the connection timeout, if one has been set with the + /// `connect_timeout` method. + pub fn get_connect_timeout(&self) -> Option<&Duration> { + self.config.get_connect_timeout() + } + + /// Sets the TCP user timeout. + /// + /// This is ignored for Unix domain socket connections. It is only supported on systems where + /// TCP_USER_TIMEOUT is available and will default to the system default if omitted or set to 0; + /// on other systems, it has no effect. + pub fn tcp_user_timeout(&mut self, tcp_user_timeout: Duration) -> &mut Config { + self.config.tcp_user_timeout(tcp_user_timeout); + self + } + + /// Gets the TCP user timeout, if one has been set with the + /// `user_timeout` method. + pub fn get_tcp_user_timeout(&self) -> Option<&Duration> { + self.config.get_tcp_user_timeout() + } + + /// Controls the use of TCP keepalive. + /// + /// This is ignored for Unix domain socket connections. Defaults to `true`. + pub fn keepalives(&mut self, keepalives: bool) -> &mut Config { + self.config.keepalives(keepalives); + self + } + + /// Reports whether TCP keepalives will be used. + pub fn get_keepalives(&self) -> bool { + self.config.get_keepalives() + } + + /// Sets the amount of idle time before a keepalive packet is sent on the connection. + /// + /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. Defaults to 2 hours. + pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config { + self.config.keepalives_idle(keepalives_idle); + self + } + + /// Gets the configured amount of idle time before a keepalive packet will + /// be sent on the connection. + pub fn get_keepalives_idle(&self) -> Duration { + self.config.get_keepalives_idle() + } + + /// Sets the time interval between TCP keepalive probes. + /// On Windows, this sets the value of the tcp_keepalive struct’s keepaliveinterval field. + /// + /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. + pub fn keepalives_interval(&mut self, keepalives_interval: Duration) -> &mut Config { + self.config.keepalives_interval(keepalives_interval); + self + } + + /// Gets the time interval between TCP keepalive probes. + pub fn get_keepalives_interval(&self) -> Option { + self.config.get_keepalives_interval() + } + + /// Sets the maximum number of TCP keepalive probes that will be sent before dropping a connection. + /// + /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. + pub fn keepalives_retries(&mut self, keepalives_retries: u32) -> &mut Config { + self.config.keepalives_retries(keepalives_retries); + self + } + + /// Gets the maximum number of TCP keepalive probes that will be sent before dropping a connection. + pub fn get_keepalives_retries(&self) -> Option { + self.config.get_keepalives_retries() + } + + /// Sets the requirements of the session. + /// + /// This can be used to connect to the primary server in a clustered database rather than one of the read-only + /// secondary servers. Defaults to `Any`. + pub fn target_session_attrs( + &mut self, + target_session_attrs: TargetSessionAttrs, + ) -> &mut Config { + self.config.target_session_attrs(target_session_attrs); + self + } + + /// Gets the requirements of the session. + pub fn get_target_session_attrs(&self) -> TargetSessionAttrs { + self.config.get_target_session_attrs() + } + + /// Sets the channel binding behavior. + /// + /// Defaults to `prefer`. + pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config { + self.config.channel_binding(channel_binding); + self + } + + /// Gets the channel binding behavior. + pub fn get_channel_binding(&self) -> ChannelBinding { + self.config.get_channel_binding() + } + + /// Sets the host load balancing behavior. + /// + /// Defaults to `disable`. + pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config { + self.config.load_balance_hosts(load_balance_hosts); + self + } + + /// Gets the host load balancing behavior. + pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts { + self.config.get_load_balance_hosts() + } + + /// Sets the notice callback. + /// + /// This callback will be invoked with the contents of every + /// [`AsyncMessage::Notice`] that is received by the connection. Notices use + /// the same structure as errors, but they are not "errors" per-se. + /// + /// Notices are distinct from notifications, which are instead accessible + /// via the [`Notifications`] API. + /// + /// [`AsyncMessage::Notice`]: tokio_postgres::AsyncMessage::Notice + /// [`Notifications`]: crate::Notifications + pub fn notice_callback(&mut self, f: F) -> &mut Config + where + F: Fn(DbError) + Send + Sync + 'static, + { + self.notice_callback = Arc::new(f); + self + } + + /// Opens a connection to a PostgreSQL database. + pub fn connect(&self, tls: T) -> Result + where + T: MakeTlsConnect + 'static + Send, + T::TlsConnect: Send, + T::Stream: Send, + >::Future: Send, + { + let runtime = runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); // FIXME don't unwrap + + let (client, connection) = runtime.block_on(self.config.connect(tls))?; + + let connection = Connection::new(runtime, connection, self.notice_callback.clone()); + Ok(Client::new(connection, client)) + } +} + +impl FromStr for Config { + type Err = Error; + + fn from_str(s: &str) -> Result { + s.parse::().map(Config::from) + } +} + +impl From for Config { + fn from(config: tokio_postgres::Config) -> Config { + Config { + config, + notice_callback: Arc::new(|notice| { + info!("{}: {}", notice.severity(), notice.message()) + }), + } + } +} diff --git a/postgres/src/connection.rs b/postgres/src/connection.rs new file mode 100644 index 000000000..b91c16555 --- /dev/null +++ b/postgres/src/connection.rs @@ -0,0 +1,137 @@ +use crate::{Error, Notification}; +use futures_util::{future, pin_mut, Stream}; +use std::collections::VecDeque; +use std::future::Future; +use std::ops::{Deref, DerefMut}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::runtime::Runtime; +use tokio_postgres::error::DbError; +use tokio_postgres::AsyncMessage; + +pub struct Connection { + runtime: Runtime, + connection: Pin> + Send>>, + notifications: VecDeque, + notice_callback: Arc, +} + +impl Connection { + pub fn new( + runtime: Runtime, + connection: tokio_postgres::Connection, + notice_callback: Arc, + ) -> Connection + where + S: AsyncRead + AsyncWrite + Unpin + 'static + Send, + T: AsyncRead + AsyncWrite + Unpin + 'static + Send, + { + Connection { + runtime, + connection: Box::pin(ConnectionStream { connection }), + notifications: VecDeque::new(), + notice_callback, + } + } + + pub fn as_ref(&mut self) -> ConnectionRef<'_> { + ConnectionRef { connection: self } + } + + pub fn enter(&self, f: F) -> T + where + F: FnOnce() -> T, + { + let _guard = self.runtime.enter(); + f() + } + + pub fn block_on(&mut self, future: F) -> Result + where + F: Future>, + { + pin_mut!(future); + self.poll_block_on(|cx, _, _| future.as_mut().poll(cx)) + } + + pub fn poll_block_on(&mut self, mut f: F) -> Result + where + F: FnMut(&mut Context<'_>, &mut VecDeque, bool) -> Poll>, + { + let connection = &mut self.connection; + let notifications = &mut self.notifications; + let notice_callback = &mut self.notice_callback; + self.runtime.block_on({ + future::poll_fn(|cx| { + let done = loop { + match connection.as_mut().poll_next(cx) { + Poll::Ready(Some(Ok(AsyncMessage::Notification(notification)))) => { + notifications.push_back(notification); + } + Poll::Ready(Some(Ok(AsyncMessage::Notice(notice)))) => { + notice_callback(notice) + } + Poll::Ready(Some(Ok(_))) => {} + Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)), + Poll::Ready(None) => break true, + Poll::Pending => break false, + } + }; + + f(cx, notifications, done) + }) + }) + } + + pub fn notifications(&self) -> &VecDeque { + &self.notifications + } + + pub fn notifications_mut(&mut self) -> &mut VecDeque { + &mut self.notifications + } +} + +pub struct ConnectionRef<'a> { + connection: &'a mut Connection, +} + +// no-op impl to extend the borrow until drop +impl Drop for ConnectionRef<'_> { + #[inline] + fn drop(&mut self) {} +} + +impl Deref for ConnectionRef<'_> { + type Target = Connection; + + #[inline] + fn deref(&self) -> &Connection { + self.connection + } +} + +impl DerefMut for ConnectionRef<'_> { + #[inline] + fn deref_mut(&mut self) -> &mut Connection { + self.connection + } +} + +struct ConnectionStream { + connection: tokio_postgres::Connection, +} + +impl Stream for ConnectionStream +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.connection.poll_message(cx) + } +} diff --git a/postgres/src/copy_in_writer.rs b/postgres/src/copy_in_writer.rs new file mode 100644 index 000000000..83c642c73 --- /dev/null +++ b/postgres/src/copy_in_writer.rs @@ -0,0 +1,58 @@ +use crate::connection::ConnectionRef; +use crate::lazy_pin::LazyPin; +use bytes::{Bytes, BytesMut}; +use futures_util::SinkExt; +use std::io; +use std::io::Write; +use tokio_postgres::{CopyInSink, Error}; + +/// The writer returned by the `copy_in` method. +/// +/// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted. +pub struct CopyInWriter<'a> { + pub(crate) connection: ConnectionRef<'a>, + pub(crate) sink: LazyPin>, + buf: BytesMut, +} + +impl<'a> CopyInWriter<'a> { + pub(crate) fn new(connection: ConnectionRef<'a>, sink: CopyInSink) -> CopyInWriter<'a> { + CopyInWriter { + connection, + sink: LazyPin::new(sink), + buf: BytesMut::new(), + } + } + + /// Completes the copy, returning the number of rows written. + /// + /// If this is not called, the copy will be aborted. + pub fn finish(mut self) -> Result { + self.flush_inner()?; + self.connection.block_on(self.sink.pinned().finish()) + } + + fn flush_inner(&mut self) -> Result<(), Error> { + if self.buf.is_empty() { + return Ok(()); + } + + self.connection + .block_on(self.sink.pinned().send(self.buf.split().freeze())) + } +} + +impl Write for CopyInWriter<'_> { + fn write(&mut self, buf: &[u8]) -> io::Result { + if self.buf.len() > 4096 { + self.flush()?; + } + + self.buf.extend_from_slice(buf); + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + self.flush_inner().map_err(io::Error::other) + } +} diff --git a/postgres/src/copy_out_reader.rs b/postgres/src/copy_out_reader.rs new file mode 100644 index 000000000..b683ddeec --- /dev/null +++ b/postgres/src/copy_out_reader.rs @@ -0,0 +1,55 @@ +use crate::connection::ConnectionRef; +use crate::lazy_pin::LazyPin; +use bytes::{Buf, Bytes}; +use futures_util::StreamExt; +use std::io::{self, BufRead, Read}; +use tokio_postgres::CopyOutStream; + +/// The reader returned by the `copy_out` method. +pub struct CopyOutReader<'a> { + pub(crate) connection: ConnectionRef<'a>, + pub(crate) stream: LazyPin, + cur: Bytes, +} + +impl<'a> CopyOutReader<'a> { + pub(crate) fn new(connection: ConnectionRef<'a>, stream: CopyOutStream) -> CopyOutReader<'a> { + CopyOutReader { + connection, + stream: LazyPin::new(stream), + cur: Bytes::new(), + } + } +} + +impl Read for CopyOutReader<'_> { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let b = self.fill_buf()?; + let len = usize::min(buf.len(), b.len()); + buf[..len].copy_from_slice(&b[..len]); + self.consume(len); + Ok(len) + } +} + +impl BufRead for CopyOutReader<'_> { + fn fill_buf(&mut self) -> io::Result<&[u8]> { + while !self.cur.has_remaining() { + let mut stream = self.stream.pinned(); + match self + .connection + .block_on(async { stream.next().await.transpose() }) + { + Ok(Some(cur)) => self.cur = cur, + Err(e) => return Err(io::Error::other(e)), + Ok(None) => break, + }; + } + + Ok(&self.cur) + } + + fn consume(&mut self, amt: usize) { + self.cur.advance(amt); + } +} diff --git a/postgres/src/generic_client.rs b/postgres/src/generic_client.rs new file mode 100644 index 000000000..7b534867c --- /dev/null +++ b/postgres/src/generic_client.rs @@ -0,0 +1,276 @@ +use crate::types::{BorrowToSql, ToSql, Type}; +use crate::{ + Client, CopyInWriter, CopyOutReader, Error, Row, RowIter, SimpleQueryMessage, Statement, + ToStatement, Transaction, +}; + +mod private { + pub trait Sealed {} +} + +/// A trait allowing abstraction over connections and transactions. +/// +/// This trait is "sealed", and cannot be implemented outside of this crate. +pub trait GenericClient: private::Sealed { + /// Like `Client::execute`. + fn execute(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result + where + T: ?Sized + ToStatement; + + /// Like `Client::query`. + fn query(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result, Error> + where + T: ?Sized + ToStatement; + + /// Like `Client::query_one`. + fn query_one(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result + where + T: ?Sized + ToStatement; + + /// Like `Client::query_opt`. + fn query_opt( + &mut self, + query: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result, Error> + where + T: ?Sized + ToStatement; + + /// Like `Client::query_raw`. + fn query_raw(&mut self, query: &T, params: I) -> Result, Error> + where + T: ?Sized + ToStatement, + P: BorrowToSql, + I: IntoIterator, + I::IntoIter: ExactSizeIterator; + + /// Like [`Client::query_typed`] + fn query_typed( + &mut self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result, Error>; + + /// Like [`Client::query_typed_raw`] + fn query_typed_raw(&mut self, statement: &str, params: I) -> Result, Error> + where + P: BorrowToSql, + I: IntoIterator + Sync + Send; + + /// Like `Client::prepare`. + fn prepare(&mut self, query: &str) -> Result; + + /// Like `Client::prepare_typed`. + fn prepare_typed(&mut self, query: &str, types: &[Type]) -> Result; + + /// Like `Client::copy_in`. + fn copy_in(&mut self, query: &T) -> Result, Error> + where + T: ?Sized + ToStatement; + + /// Like `Client::copy_out`. + fn copy_out(&mut self, query: &T) -> Result, Error> + where + T: ?Sized + ToStatement; + + /// Like `Client::simple_query`. + fn simple_query(&mut self, query: &str) -> Result, Error>; + + /// Like `Client::batch_execute`. + fn batch_execute(&mut self, query: &str) -> Result<(), Error>; + + /// Like `Client::transaction`. + fn transaction(&mut self) -> Result, Error>; +} + +impl private::Sealed for Client {} + +impl GenericClient for Client { + fn execute(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result + where + T: ?Sized + ToStatement, + { + self.execute(query, params) + } + + fn query(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result, Error> + where + T: ?Sized + ToStatement, + { + self.query(query, params) + } + + fn query_one(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result + where + T: ?Sized + ToStatement, + { + self.query_one(query, params) + } + + fn query_opt( + &mut self, + query: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result, Error> + where + T: ?Sized + ToStatement, + { + self.query_opt(query, params) + } + + fn query_raw(&mut self, query: &T, params: I) -> Result, Error> + where + T: ?Sized + ToStatement, + P: BorrowToSql, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + self.query_raw(query, params) + } + + fn query_typed( + &mut self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result, Error> { + self.query_typed(statement, params) + } + + fn query_typed_raw(&mut self, statement: &str, params: I) -> Result, Error> + where + P: BorrowToSql, + I: IntoIterator + Sync + Send, + { + self.query_typed_raw(statement, params) + } + + fn prepare(&mut self, query: &str) -> Result { + self.prepare(query) + } + + fn prepare_typed(&mut self, query: &str, types: &[Type]) -> Result { + self.prepare_typed(query, types) + } + + fn copy_in(&mut self, query: &T) -> Result, Error> + where + T: ?Sized + ToStatement, + { + self.copy_in(query) + } + + fn copy_out(&mut self, query: &T) -> Result, Error> + where + T: ?Sized + ToStatement, + { + self.copy_out(query) + } + + fn simple_query(&mut self, query: &str) -> Result, Error> { + self.simple_query(query) + } + + fn batch_execute(&mut self, query: &str) -> Result<(), Error> { + self.batch_execute(query) + } + + fn transaction(&mut self) -> Result, Error> { + self.transaction() + } +} + +impl private::Sealed for Transaction<'_> {} + +impl GenericClient for Transaction<'_> { + fn execute(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result + where + T: ?Sized + ToStatement, + { + self.execute(query, params) + } + + fn query(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result, Error> + where + T: ?Sized + ToStatement, + { + self.query(query, params) + } + + fn query_one(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result + where + T: ?Sized + ToStatement, + { + self.query_one(query, params) + } + + fn query_opt( + &mut self, + query: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result, Error> + where + T: ?Sized + ToStatement, + { + self.query_opt(query, params) + } + + fn query_raw(&mut self, query: &T, params: I) -> Result, Error> + where + T: ?Sized + ToStatement, + P: BorrowToSql, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + self.query_raw(query, params) + } + + fn query_typed( + &mut self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result, Error> { + self.query_typed(statement, params) + } + + fn query_typed_raw(&mut self, statement: &str, params: I) -> Result, Error> + where + P: BorrowToSql, + I: IntoIterator + Sync + Send, + { + self.query_typed_raw(statement, params) + } + + fn prepare(&mut self, query: &str) -> Result { + self.prepare(query) + } + + fn prepare_typed(&mut self, query: &str, types: &[Type]) -> Result { + self.prepare_typed(query, types) + } + + fn copy_in(&mut self, query: &T) -> Result, Error> + where + T: ?Sized + ToStatement, + { + self.copy_in(query) + } + + fn copy_out(&mut self, query: &T) -> Result, Error> + where + T: ?Sized + ToStatement, + { + self.copy_out(query) + } + + fn simple_query(&mut self, query: &str) -> Result, Error> { + self.simple_query(query) + } + + fn batch_execute(&mut self, query: &str) -> Result<(), Error> { + self.batch_execute(query) + } + + fn transaction(&mut self) -> Result, Error> { + self.transaction() + } +} diff --git a/postgres/src/lazy_pin.rs b/postgres/src/lazy_pin.rs new file mode 100644 index 000000000..a18b58b84 --- /dev/null +++ b/postgres/src/lazy_pin.rs @@ -0,0 +1,28 @@ +use std::pin::Pin; + +pub(crate) struct LazyPin { + value: Box, + pinned: bool, +} + +impl LazyPin { + pub fn new(value: T) -> LazyPin { + LazyPin { + value: Box::new(value), + pinned: false, + } + } + + pub fn pinned(&mut self) -> Pin<&mut T> { + self.pinned = true; + unsafe { Pin::new_unchecked(&mut *self.value) } + } + + pub fn into_unpinned(self) -> Option { + if self.pinned { + None + } else { + Some(*self.value) + } + } +} diff --git a/postgres/src/lib.rs b/postgres/src/lib.rs index d6046015a..ddf1609ad 100644 --- a/postgres/src/lib.rs +++ b/postgres/src/lib.rs @@ -1,1566 +1,108 @@ -//! A pure-Rust frontend for the popular PostgreSQL database. +//! A synchronous client for the PostgreSQL database. //! -//! ```rust,no_run -//! extern crate postgres; +//! # Example //! -//! use postgres::{Connection, TlsMode}; +//! ```no_run +//! use postgres::{Client, NoTls}; //! -//! struct Person { -//! id: i32, -//! name: String, -//! data: Option> -//! } +//! # fn main() -> Result<(), postgres::Error> { +//! let mut client = Client::connect("host=localhost user=postgres", NoTls)?; +//! +//! client.batch_execute(" +//! CREATE TABLE person ( +//! id SERIAL PRIMARY KEY, +//! name TEXT NOT NULL, +//! data BYTEA +//! ) +//! ")?; //! -//! fn main() { -//! let conn = Connection::connect("postgresql://postgres@localhost:5433", TlsMode::None) -//! .unwrap(); +//! let name = "Ferris"; +//! let data = None::<&[u8]>; +//! client.execute( +//! "INSERT INTO person (name, data) VALUES ($1, $2)", +//! &[&name, &data], +//! )?; //! -//! conn.execute("CREATE TABLE person ( -//! id SERIAL PRIMARY KEY, -//! name VARCHAR NOT NULL, -//! data BYTEA -//! )", &[]).unwrap(); -//! let me = Person { -//! id: 0, -//! name: "Steven".to_owned(), -//! data: None -//! }; -//! conn.execute("INSERT INTO person (name, data) VALUES ($1, $2)", -//! &[&me.name, &me.data]).unwrap(); +//! for row in client.query("SELECT id, name, data FROM person", &[])? { +//! let id: i32 = row.get(0); +//! let name: &str = row.get(1); +//! let data: Option<&[u8]> = row.get(2); //! -//! for row in &conn.query("SELECT id, name, data FROM person", &[]).unwrap() { -//! let person = Person { -//! id: row.get(0), -//! name: row.get(1), -//! data: row.get(2) -//! }; -//! println!("Found person {}", person.name); -//! } +//! println!("found person: {} {} {:?}", id, name, data); //! } +//! # Ok(()) +//! # } //! ``` //! -//! # SSL/TLS +//! # Implementation //! -//! This crate supports TLS secured connections. The `TlsMode` enum is passed to connection methods -//! and indicates if the connection will not, may, or must be secured by TLS. The TLS implementation -//! is pluggable through the `TlsHandshake` trait. Implementations for OpenSSL, Secure Transport, -//! SChannel, and the `native-tls` crate are provided behind the `with-openssl`, -//! `with-security-framework`, `with-schannel`, and `with-native-tls` feature flags respectively. +//! This crate is a lightweight wrapper over tokio-postgres. The `postgres::Client` is simply a wrapper around a +//! `tokio_postgres::Client` along side a tokio `Runtime`. The client simply blocks on the futures provided by the async +//! client. //! -//! ## Examples +//! # SSL/TLS support //! -//! Connecting using `native-tls`: +//! TLS support is implemented via external libraries. `Client::connect` and `Config::connect` take a TLS implementation +//! as an argument. The `NoTls` type in this crate can be used when TLS is not required. Otherwise, the +//! `postgres-openssl` and `postgres-native-tls` crates provide implementations backed by the `openssl` and `native-tls` +//! crates, respectively. //! -//! ```no_run -//! extern crate postgres; +//! # Features //! -//! use postgres::{Connection, TlsMode}; -//! # #[cfg(feature = "with-native-tls")] -//! use postgres::tls::native_tls::NativeTls; +//! The following features can be enabled from `Cargo.toml`: //! -//! # #[cfg(not(feature = "with-native-tls"))] fn main() {} -//! # #[cfg(feature = "with-native-tls")] -//! fn main() { -//! let negotiator = NativeTls::new().unwrap(); -//! let conn = Connection::connect("postgres://postgres@localhost:5433", TlsMode::Require(&negotiator)) -//! .unwrap(); -//! } -//! ``` -#![doc(html_root_url = "https://docs.rs/postgres/0.15.1")] -#![warn(missing_docs)] -#![allow(unknown_lints, needless_lifetimes, doc_markdown)] // for clippy - -extern crate bytes; -extern crate fallible_iterator; -#[cfg(not(feature = "no-logging"))] -#[macro_use] -extern crate log; -extern crate postgres_protocol; -extern crate postgres_shared; -extern crate socket2; - -use fallible_iterator::FallibleIterator; -use postgres_protocol::authentication; -use postgres_protocol::authentication::sasl::{self, ChannelBinding, ScramSha256}; -use postgres_protocol::message::backend::{self, ErrorFields}; -use postgres_protocol::message::frontend; -use postgres_shared::rows::RowData; -use std::cell::{Cell, RefCell}; -use std::collections::{HashMap, VecDeque}; -use std::fmt; -use std::io; -use std::mem; -use std::result; -use std::sync::Arc; -use std::time::Duration; - -use error::{DbError, SqlState}; -use notification::{Notification, Notifications}; -use params::{IntoConnectParams, User}; -use priv_io::MessageStream; -use rows::Rows; -use stmt::{Column, Statement}; -use text_rows::TextRows; -use tls::TlsHandshake; -use transaction::{IsolationLevel, Transaction}; -use types::{Field, FromSql, IsNull, Kind, Oid, ToSql, Type}; - +//! | Feature | Description | Extra dependencies | Default | +//! | ------- | ----------- | ------------------ | ------- | +//! | `with-bit-vec-0_6` | Enable support for the `bit-vec` crate. | [bit-vec](https://crates.io/crates/bit-vec) 0.6 | no | +//! | `with-chrono-0_4` | Enable support for the `chrono` crate. | [chrono](https://crates.io/crates/chrono) 0.4 | no | +//! | `with-eui48-0_4` | Enable support for the 0.4 version of the `eui48` crate. This is deprecated and will be removed. | [eui48](https://crates.io/crates/eui48) 0.4 | no | +//! | `with-eui48-1` | Enable support for the 1.0 version of the `eui48` crate. | [eui48](https://crates.io/crates/eui48) 1.0 | no | +//! | `with-geo-types-0_6` | Enable support for the 0.6 version of the `geo-types` crate. | [geo-types](https://crates.io/crates/geo-types/0.6.0) 0.6 | no | +//! | `with-geo-types-0_7` | Enable support for the 0.7 version of the `geo-types` crate. | [geo-types](https://crates.io/crates/geo-types/0.7.0) 0.7 | no | +//! | `with-serde_json-1` | Enable support for the `serde_json` crate. | [serde_json](https://crates.io/crates/serde_json) 1.0 | no | +//! | `with-uuid-0_8` | Enable support for the `uuid` crate. | [uuid](https://crates.io/crates/uuid) 0.8 | no | +//! | `with-uuid-1` | Enable support for the `uuid` crate. | [uuid](https://crates.io/crates/uuid) 1.0 | no | +//! | `with-time-0_2` | Enable support for the 0.2 version of the `time` crate. | [time](https://crates.io/crates/time/0.2.0) 0.2 | no | +//! | `with-time-0_3` | Enable support for the 0.3 version of the `time` crate. | [time](https://crates.io/crates/time/0.3.0) 0.3 | no | +#![warn(clippy::all, rust_2018_idioms, missing_docs)] + +pub use fallible_iterator; +pub use tokio_postgres::{ + error, row, tls, types, Column, IsolationLevel, Notification, Portal, SimpleQueryMessage, + Socket, Statement, ToStatement, +}; + +pub use crate::cancel_token::CancelToken; +pub use crate::client::*; +pub use crate::config::Config; +pub use crate::copy_in_writer::CopyInWriter; +pub use crate::copy_out_reader::CopyOutReader; +#[doc(no_inline)] +pub use crate::error::Error; +pub use crate::generic_client::GenericClient; #[doc(inline)] -pub use error::Error; -#[doc(inline)] -pub use postgres_shared::CancelData; -#[doc(inline)] -pub use postgres_shared::{error, types}; - -#[macro_use] -mod macros; - -pub mod notification; -pub mod params; -mod priv_io; -pub mod rows; -pub mod stmt; -pub mod text_rows; -pub mod tls; -pub mod transaction; - -const TYPEINFO_QUERY: &'static str = "__typeinfo"; -const TYPEINFO_ENUM_QUERY: &'static str = "__typeinfo_enum"; -const TYPEINFO_COMPOSITE_QUERY: &'static str = "__typeinfo_composite"; - -/// A type alias of the result returned by many methods. -pub type Result = result::Result; - -/// A trait implemented by types that can handle Postgres notice messages. -/// -/// It is implemented for all `Send + FnMut(DbError)` closures. -pub trait HandleNotice: Send { - /// Handle a Postgres notice message - fn handle_notice(&mut self, notice: DbError); -} - -impl HandleNotice for F { - fn handle_notice(&mut self, notice: DbError) { - self(notice) - } -} - -/// A notice handler which logs at the `info` level. -/// -/// This is the default handler used by a `Connection`. -#[derive(Copy, Clone, Debug)] -pub struct LoggingNoticeHandler; - -impl HandleNotice for LoggingNoticeHandler { - fn handle_notice(&mut self, _notice: DbError) { - info!("{}: {}", _notice.severity, _notice.message); - } -} - -/// Attempts to cancel an in-progress query. -/// -/// The backend provides no information about whether a cancellation attempt -/// was successful or not. An error will only be returned if the driver was -/// unable to connect to the database. -/// -/// A `CancelData` object can be created via `Connection::cancel_data`. The -/// object can cancel any query made on that connection. -/// -/// Only the host and port of the connection info are used. See -/// `Connection::connect` for details of the `params` argument. -/// -/// # Example -/// -/// ```rust,no_run -/// # use postgres::{Connection, TlsMode}; -/// # use std::thread; -/// # let url = ""; -/// let conn = Connection::connect(url, TlsMode::None).unwrap(); -/// let cancel_data = conn.cancel_data(); -/// thread::spawn(move || { -/// conn.execute("SOME EXPENSIVE QUERY", &[]).unwrap(); -/// }); -/// postgres::cancel_query(url, TlsMode::None, &cancel_data).unwrap(); -/// ``` -pub fn cancel_query(params: T, tls: TlsMode, data: &CancelData) -> Result<()> -where - T: IntoConnectParams, -{ - let params = params.into_connect_params().map_err(error::connect)?; - let mut socket = priv_io::initialize_stream(¶ms, tls)?; - - let mut buf = vec![]; - frontend::cancel_request(data.process_id, data.secret_key, &mut buf); - socket.write_all(&buf)?; - socket.flush()?; - - Ok(()) -} - -fn bad_response() -> io::Error { - io::Error::new( - io::ErrorKind::InvalidInput, - "the server returned an unexpected response", - ) -} - -fn desynchronized() -> io::Error { - io::Error::new( - io::ErrorKind::Other, - "communication with the server has desynchronized due to an earlier IO error", - ) -} - -/// Specifies the TLS support requested for a new connection. -#[derive(Debug)] -pub enum TlsMode<'a> { - /// The connection will not use TLS. - None, - /// The connection will use TLS if the backend supports it. - Prefer(&'a TlsHandshake), - /// The connection must use TLS. - Require(&'a TlsHandshake), -} - -#[derive(Debug)] -struct StatementInfo { - name: String, - param_types: Vec, - columns: Vec, -} - -struct InnerConnection { - stream: MessageStream, - notice_handler: Box, - notifications: VecDeque, - cancel_data: CancelData, - unknown_types: HashMap, - cached_statements: HashMap>, - parameters: HashMap, - next_stmt_id: u32, - trans_depth: u32, - desynchronized: bool, - finished: bool, - has_typeinfo_query: bool, - has_typeinfo_enum_query: bool, - has_typeinfo_composite_query: bool, -} - -impl Drop for InnerConnection { - fn drop(&mut self) { - if !self.finished { - let _ = self.finish_inner(); - } - } -} - -impl InnerConnection { - fn connect(params: T, tls: TlsMode) -> Result - where - T: IntoConnectParams, - { - let params = params.into_connect_params().map_err(error::connect)?; - let stream = priv_io::initialize_stream(¶ms, tls)?; - - let user = match params.user() { - Some(user) => user, - None => { - return Err(error::connect( - "user missing from connection parameters".into(), - )); - } - }; - - let mut conn = InnerConnection { - stream: MessageStream::new(stream), - next_stmt_id: 0, - notice_handler: Box::new(LoggingNoticeHandler), - notifications: VecDeque::new(), - cancel_data: CancelData { - process_id: 0, - secret_key: 0, - }, - unknown_types: HashMap::new(), - cached_statements: HashMap::new(), - parameters: HashMap::new(), - desynchronized: false, - finished: false, - trans_depth: 0, - has_typeinfo_query: false, - has_typeinfo_enum_query: false, - has_typeinfo_composite_query: false, - }; - - let mut options = params.options().to_owned(); - options.push(("client_encoding".to_owned(), "UTF8".to_owned())); - // Postgres uses the value of TimeZone as the time zone for TIMESTAMP - // WITH TIME ZONE values. Timespec converts to GMT internally. - options.push(("timezone".to_owned(), "GMT".to_owned())); - // We have to clone here since we need the user again for auth - options.push(("user".to_owned(), user.name().to_owned())); - if let Some(database) = params.database() { - options.push(("database".to_owned(), database.to_owned())); - } - - let options = options.iter().map(|&(ref a, ref b)| (&**a, &**b)); - conn.stream - .write_message(|buf| frontend::startup_message(options, buf))?; - conn.stream.flush()?; - - conn.handle_auth(user)?; - - loop { - match conn.read_message()? { - backend::Message::BackendKeyData(body) => { - conn.cancel_data.process_id = body.process_id(); - conn.cancel_data.secret_key = body.secret_key(); - } - backend::Message::ReadyForQuery(_) => break, - backend::Message::ErrorResponse(body) => { - return Err(err(&mut body.fields())); - } - _ => return Err(bad_response().into()), - } - } - - Ok(conn) - } - - fn read_message_with_notification(&mut self) -> io::Result { - debug_assert!(!self.desynchronized); - loop { - match try_desync!(self, self.stream.read_message()) { - backend::Message::NoticeResponse(body) => { - if let Ok(err) = DbError::new(&mut body.fields()) { - self.notice_handler.handle_notice(err); - } - } - backend::Message::ParameterStatus(body) => { - self.parameters - .insert(body.name()?.to_owned(), body.value()?.to_owned()); - } - val => return Ok(val), - } - } - } - - fn read_message_with_notification_timeout( - &mut self, - timeout: Duration, - ) -> io::Result> { - debug_assert!(!self.desynchronized); - loop { - match try_desync!(self, self.stream.read_message_timeout(timeout)) { - Some(backend::Message::NoticeResponse(body)) => { - if let Ok(err) = Err(err(&mut body.fields())) { - self.notice_handler.handle_notice(err); - } - } - Some(backend::Message::ParameterStatus(body)) => { - self.parameters - .insert(body.name()?.to_owned(), body.value()?.to_owned()); - } - val => return Ok(val), - } - } - } - - fn read_message_with_notification_nonblocking( - &mut self, - ) -> io::Result> { - debug_assert!(!self.desynchronized); - loop { - match try_desync!(self, self.stream.read_message_nonblocking()) { - Some(backend::Message::NoticeResponse(body)) => { - if let Ok(err) = Err(err(&mut body.fields())) { - self.notice_handler.handle_notice(err); - } - } - Some(backend::Message::ParameterStatus(body)) => { - self.parameters - .insert(body.name()?.to_owned(), body.value()?.to_owned()); - } - val => return Ok(val), - } - } - } - - fn read_message(&mut self) -> io::Result { - loop { - match self.read_message_with_notification()? { - backend::Message::NotificationResponse(body) => { - self.notifications.push_back(Notification { - process_id: body.process_id(), - channel: body.channel()?.to_owned(), - payload: body.message()?.to_owned(), - }) - } - val => return Ok(val), - } - } - } - - fn handle_auth(&mut self, user: &User) -> Result<()> { - match self.read_message()? { - backend::Message::AuthenticationOk => return Ok(()), - backend::Message::AuthenticationCleartextPassword => { - let pass = user.password().ok_or_else(|| { - error::connect("a password was requested but not provided".into()) - })?; - self.stream - .write_message(|buf| frontend::password_message(pass, buf))?; - self.stream.flush()?; - } - backend::Message::AuthenticationMd5Password(body) => { - let pass = user.password().ok_or_else(|| { - error::connect("a password was requested but not provided".into()) - })?; - let output = - authentication::md5_hash(user.name().as_bytes(), pass.as_bytes(), body.salt()); - self.stream - .write_message(|buf| frontend::password_message(&output, buf))?; - self.stream.flush()?; - } - backend::Message::AuthenticationSasl(body) => { - let mut has_scram = false; - let mut has_scram_plus = false; - let mut mechanisms = body.mechanisms(); - while let Some(mechanism) = mechanisms.next()? { - match mechanism { - sasl::SCRAM_SHA_256 => has_scram = true, - sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true, - _ => {} - } - } - let channel_binding = self - .stream - .get_ref() - .tls_unique() - .map(ChannelBinding::tls_unique) - .or_else(|| { - self.stream - .get_ref() - .tls_server_end_point() - .map(ChannelBinding::tls_server_end_point) - }); - - let (channel_binding, mechanism) = if has_scram_plus { - match channel_binding { - Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS), - None => (ChannelBinding::unsupported(), sasl::SCRAM_SHA_256), - } - } else if has_scram { - match channel_binding { - Some(_) => (ChannelBinding::unrequested(), sasl::SCRAM_SHA_256), - None => (ChannelBinding::unsupported(), sasl::SCRAM_SHA_256), - } - } else { - return Err( - io::Error::new(io::ErrorKind::Other, "unsupported authentication").into(), - ); - }; - - let pass = user.password().ok_or_else(|| { - error::connect("a password was requested but not provided".into()) - })?; - - let mut scram = ScramSha256::new(pass.as_bytes(), channel_binding); - - self.stream.write_message(|buf| { - frontend::sasl_initial_response(mechanism, scram.message(), buf) - })?; - self.stream.flush()?; - - let body = match self.read_message()? { - backend::Message::AuthenticationSaslContinue(body) => body, - backend::Message::ErrorResponse(body) => return Err(err(&mut body.fields())), - _ => return Err(bad_response().into()), - }; - - scram.update(body.data())?; - - self.stream - .write_message(|buf| frontend::sasl_response(scram.message(), buf))?; - self.stream.flush()?; - - let body = match self.read_message()? { - backend::Message::AuthenticationSaslFinal(body) => body, - backend::Message::ErrorResponse(body) => return Err(err(&mut body.fields())), - _ => return Err(bad_response().into()), - }; - - scram.finish(body.data())?; - } - backend::Message::AuthenticationKerberosV5 - | backend::Message::AuthenticationScmCredential - | backend::Message::AuthenticationGss - | backend::Message::AuthenticationSspi => { - return Err( - io::Error::new(io::ErrorKind::Other, "unsupported authentication").into(), - ) - } - backend::Message::ErrorResponse(body) => return Err(err(&mut body.fields())), - _ => return Err(bad_response().into()), - } - - match self.read_message()? { - backend::Message::AuthenticationOk => Ok(()), - backend::Message::ErrorResponse(body) => Err(err(&mut body.fields())), - _ => Err(bad_response().into()), - } - } - - fn set_notice_handler(&mut self, handler: Box) -> Box { - mem::replace(&mut self.notice_handler, handler) - } - - fn raw_prepare( - &mut self, - stmt_name: &str, - query: &str, - types: &[Option], - ) -> Result<(Vec, Vec)> { - debug!("preparing query with name `{}`: {}", stmt_name, query); - - self.stream.write_message(|buf| { - frontend::parse( - stmt_name, - query, - types.iter().map(|t| t.as_ref().map_or(0, |t| t.oid())), - buf, - ) - })?; - self.stream - .write_message(|buf| frontend::describe(b'S', stmt_name, buf))?; - self.stream - .write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))?; - self.stream.flush()?; - - match self.read_message()? { - backend::Message::ParseComplete => {} - backend::Message::ErrorResponse(body) => { - self.wait_for_ready()?; - return Err(err(&mut body.fields())); - } - _ => bad_response!(self), - } - - let raw_param_types = match self.read_message()? { - backend::Message::ParameterDescription(body) => body, - _ => bad_response!(self), - }; - - let raw_columns = match self.read_message()? { - backend::Message::RowDescription(body) => Some(body), - backend::Message::NoData => None, - _ => bad_response!(self), - }; - - self.wait_for_ready()?; - - let param_types = raw_param_types - .parameters() - .map_err(Into::into) - .and_then(|oid| self.get_type(oid)) - .collect()?; - - let columns = self.parse_cols(raw_columns)?; - Ok((param_types, columns)) - } - - fn read_rows(&mut self, mut consumer: F) -> Result - where - F: FnMut(RowData), - { - let more_rows; - loop { - match self.read_message()? { - backend::Message::EmptyQueryResponse | backend::Message::CommandComplete(_) => { - more_rows = false; - break; - } - backend::Message::PortalSuspended => { - more_rows = true; - break; - } - backend::Message::DataRow(body) => consumer(RowData::new(body)?), - backend::Message::ErrorResponse(body) => { - self.wait_for_ready()?; - return Err(err(&mut body.fields())); - } - backend::Message::CopyInResponse(_) => { - self.stream.write_message(|buf| { - frontend::copy_fail("COPY queries cannot be directly executed", buf) - })?; - self.stream - .write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))?; - self.stream.flush()?; - } - backend::Message::CopyOutResponse(_) => { - loop { - if let backend::Message::ReadyForQuery(_) = self.read_message()? { - break; - } - } - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "COPY queries cannot be directly \ - executed", - ).into()); - } - _ => { - self.desynchronized = true; - return Err(bad_response().into()); - } - } - } - self.wait_for_ready()?; - Ok(more_rows) - } - - fn raw_execute( - &mut self, - stmt_name: &str, - portal_name: &str, - row_limit: i32, - param_types: &[Type], - params: &[&ToSql], - ) -> Result<()> { - assert!( - param_types.len() == params.len(), - "expected {} parameters but got {}", - param_types.len(), - params.len() - ); - debug!( - "executing statement {} with parameters: {:?}", - stmt_name, params - ); - - { - let r = self.stream.write_message(|buf| { - frontend::bind( - portal_name, - stmt_name, - Some(1), - params.iter().zip(param_types), - |(param, ty), buf| match param.to_sql_checked(ty, buf) { - Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes), - Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No), - Err(e) => Err(e), - }, - Some(1), - buf, - ) - }); - match r { - Ok(()) => {} - Err(frontend::BindError::Conversion(e)) => { - return Err(error::conversion(e)); - } - Err(frontend::BindError::Serialization(e)) => return Err(e.into()), - } - } - - self.stream - .write_message(|buf| frontend::execute(portal_name, row_limit, buf))?; - self.stream - .write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))?; - self.stream.flush()?; - - match self.read_message()? { - backend::Message::BindComplete => Ok(()), - backend::Message::ErrorResponse(body) => { - self.wait_for_ready()?; - Err(err(&mut body.fields())) - } - _ => { - self.desynchronized = true; - Err(bad_response().into()) - } - } - } - - fn make_stmt_name(&mut self) -> String { - let stmt_name = format!("s{}", self.next_stmt_id); - self.next_stmt_id += 1; - stmt_name - } - - fn prepare_typed<'a>( - &mut self, - query: &str, - types: &[Option], - conn: &'a Connection, - ) -> Result> { - let stmt_name = self.make_stmt_name(); - let (param_types, columns) = self.raw_prepare(&stmt_name, query, types)?; - let info = Arc::new(StatementInfo { - name: stmt_name, - param_types: param_types, - columns: columns, - }); - Ok(Statement::new(conn, info, Cell::new(0), false)) - } - - fn prepare_cached<'a>(&mut self, query: &str, conn: &'a Connection) -> Result> { - let info = self.cached_statements.get(query).cloned(); - - let info = match info { - Some(info) => info, - None => { - let stmt_name = self.make_stmt_name(); - let (param_types, columns) = self.raw_prepare(&stmt_name, query, &[])?; - let info = Arc::new(StatementInfo { - name: stmt_name, - param_types: param_types, - columns: columns, - }); - self.cached_statements - .insert(query.to_owned(), info.clone()); - info - } - }; - - Ok(Statement::new(conn, info, Cell::new(0), true)) - } - - fn close_statement(&mut self, name: &str, type_: u8) -> Result<()> { - self.stream - .write_message(|buf| frontend::close(type_, name, buf))?; - self.stream - .write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))?; - self.stream.flush()?; - let resp = match self.read_message()? { - backend::Message::CloseComplete => Ok(()), - backend::Message::ErrorResponse(body) => Err(err(&mut body.fields())), - _ => bad_response!(self), - }; - self.wait_for_ready()?; - resp - } - - fn get_type(&mut self, oid: Oid) -> Result { - if let Some(ty) = Type::from_oid(oid) { - return Ok(ty); - } - - if let Some(ty) = self.unknown_types.get(&oid) { - return Ok(ty.clone()); - } - - let ty = self.read_type(oid)?; - self.unknown_types.insert(oid, ty.clone()); - Ok(ty) - } - - fn parse_cols(&mut self, raw: Option) -> Result> { - match raw { - Some(body) => body - .fields() - .and_then(|field| { - Ok(Column::new( - field.name().to_owned(), - self.get_type(field.type_oid())?, - )) - }).collect() - .map_err(From::from), - None => Ok(vec![]), - } - } - - fn setup_typeinfo_query(&mut self) -> Result<()> { - if self.has_typeinfo_query { - return Ok(()); - } - - match self.raw_prepare( - TYPEINFO_QUERY, - "SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, \ - t.typbasetype, n.nspname, t.typrelid \ - FROM pg_catalog.pg_type t \ - LEFT OUTER JOIN pg_catalog.pg_range r ON \ - r.rngtypid = t.oid \ - INNER JOIN pg_catalog.pg_namespace n ON \ - t.typnamespace = n.oid \ - WHERE t.oid = $1", - &[], - ) { - Ok(..) => {} - // Range types weren't added until Postgres 9.2, so pg_range may not exist - Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_TABLE) => { - self.raw_prepare( - TYPEINFO_QUERY, - "SELECT t.typname, t.typtype, t.typelem, NULL::OID, \ - t.typbasetype, n.nspname, t.typrelid \ - FROM pg_catalog.pg_type t \ - INNER JOIN pg_catalog.pg_namespace n \ - ON t.typnamespace = n.oid \ - WHERE t.oid = $1", - &[], - )?; - } - Err(e) => return Err(e), - } - - self.has_typeinfo_query = true; - Ok(()) - } - - #[allow(if_not_else)] - fn read_type(&mut self, oid: Oid) -> Result { - self.setup_typeinfo_query()?; - self.raw_execute(TYPEINFO_QUERY, "", 0, &[Type::OID], &[&oid])?; - let mut row = None; - self.read_rows(|r| row = Some(r))?; - - let get_raw = |i: usize| row.as_ref().and_then(|r| r.get(i)); - - let (name, type_, elem_oid, rngsubtype, basetype, schema, relid) = { - let name = - String::from_sql_nullable(&Type::NAME, get_raw(0)).map_err(error::conversion)?; - let type_ = - i8::from_sql_nullable(&Type::CHAR, get_raw(1)).map_err(error::conversion)?; - let elem_oid = - Oid::from_sql_nullable(&Type::OID, get_raw(2)).map_err(error::conversion)?; - let rngsubtype = Option::::from_sql_nullable(&Type::OID, get_raw(3)) - .map_err(error::conversion)?; - let basetype = - Oid::from_sql_nullable(&Type::OID, get_raw(4)).map_err(error::conversion)?; - let schema = - String::from_sql_nullable(&Type::NAME, get_raw(5)).map_err(error::conversion)?; - let relid = - Oid::from_sql_nullable(&Type::OID, get_raw(6)).map_err(error::conversion)?; - (name, type_, elem_oid, rngsubtype, basetype, schema, relid) - }; - - let kind = if type_ == b'e' as i8 { - Kind::Enum(self.read_enum_variants(oid)?) - } else if type_ == b'p' as i8 { - Kind::Pseudo - } else if basetype != 0 { - Kind::Domain(self.get_type(basetype)?) - } else if elem_oid != 0 { - Kind::Array(self.get_type(elem_oid)?) - } else if relid != 0 { - Kind::Composite(self.read_composite_fields(relid)?) - } else { - match rngsubtype { - Some(oid) => Kind::Range(self.get_type(oid)?), - None => Kind::Simple, - } - }; - - Ok(Type::_new(name, oid, kind, schema)) - } - - fn setup_typeinfo_enum_query(&mut self) -> Result<()> { - if self.has_typeinfo_enum_query { - return Ok(()); - } - - match self.raw_prepare( - TYPEINFO_ENUM_QUERY, - "SELECT enumlabel \ - FROM pg_catalog.pg_enum \ - WHERE enumtypid = $1 \ - ORDER BY enumsortorder", - &[], - ) { - Ok(..) => {} - // Postgres 9.0 doesn't have enumsortorder - Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_COLUMN) => { - self.raw_prepare( - TYPEINFO_ENUM_QUERY, - "SELECT enumlabel \ - FROM pg_catalog.pg_enum \ - WHERE enumtypid = $1 \ - ORDER BY oid", - &[], - )?; - } - Err(e) => return Err(e), - } - - self.has_typeinfo_enum_query = true; - Ok(()) - } - - fn read_enum_variants(&mut self, oid: Oid) -> Result> { - self.setup_typeinfo_enum_query()?; - self.raw_execute(TYPEINFO_ENUM_QUERY, "", 0, &[Type::OID], &[&oid])?; - let mut rows = vec![]; - self.read_rows(|row| rows.push(row))?; - - let mut variants = vec![]; - for row in rows { - variants.push( - String::from_sql_nullable(&Type::NAME, row.get(0)).map_err(error::conversion)?, - ); - } - - Ok(variants) - } - - fn setup_typeinfo_composite_query(&mut self) -> Result<()> { - if self.has_typeinfo_composite_query { - return Ok(()); - } - - self.raw_prepare( - TYPEINFO_COMPOSITE_QUERY, - "SELECT attname, atttypid \ - FROM pg_catalog.pg_attribute \ - WHERE attrelid = $1 \ - AND NOT attisdropped \ - AND attnum > 0 \ - ORDER BY attnum", - &[], - )?; - - self.has_typeinfo_composite_query = true; - Ok(()) - } - - fn read_composite_fields(&mut self, relid: Oid) -> Result> { - self.setup_typeinfo_composite_query()?; - self.raw_execute(TYPEINFO_COMPOSITE_QUERY, "", 0, &[Type::OID], &[&relid])?; - let mut rows = vec![]; - self.read_rows(|row| rows.push(row))?; - - let mut fields = vec![]; - for row in rows { - let (name, type_) = { - let name = String::from_sql_nullable(&Type::NAME, row.get(0)) - .map_err(error::conversion)?; - let type_ = - Oid::from_sql_nullable(&Type::OID, row.get(1)).map_err(error::conversion)?; - (name, type_) - }; - let type_ = self.get_type(type_)?; - fields.push(Field::new(name, type_)); - } - - Ok(fields) - } - - fn is_desynchronized(&self) -> bool { - self.desynchronized - } - - #[allow(needless_return)] - fn wait_for_ready(&mut self) -> Result<()> { - match self.read_message()? { - backend::Message::ReadyForQuery(_) => Ok(()), - _ => bad_response!(self), - } - } - - fn simple_query_(&mut self, query: &str) -> Result> { - check_desync!(self); - debug!("executing query: {}", query); - self.stream - .write_message(|buf| frontend::query(query, buf))?; - self.stream.flush()?; - - let mut result = vec![]; - let mut rows = vec![]; - let mut columns = None; - - loop { - match self.read_message()? { - backend::Message::ReadyForQuery(_) => break, - backend::Message::DataRow(body) => { - rows.push(RowData::new(body)?); - } - backend::Message::CopyInResponse(_) => { - self.stream.write_message(|buf| { - frontend::copy_fail("COPY queries cannot be directly executed", buf) - })?; - self.stream - .write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))?; - self.stream.flush()?; - } - backend::Message::ErrorResponse(body) => { - self.wait_for_ready()?; - return Err(err(&mut body.fields())); - } - backend::Message::RowDescription(body) => { - columns = Some(self.parse_cols(Some(body))?); - } - backend::Message::CommandComplete(_) => { - if let Some(cols) = columns.take() { - result.push(TextRows::new(cols, mem::replace(&mut rows, Vec::new()))); - } - } - _ => bad_response!(self), - } - } - Ok(result) - } - - fn quick_query(&mut self, query: &str) -> Result>>> { - check_desync!(self); - debug!("executing query: {}", query); - self.stream - .write_message(|buf| frontend::query(query, buf))?; - self.stream.flush()?; - - let mut result = vec![]; - loop { - match self.read_message()? { - backend::Message::ReadyForQuery(_) => break, - backend::Message::DataRow(body) => { - let row = body - .ranges() - .map(|r| r.map(|r| String::from_utf8_lossy(&body.buffer()[r]).into_owned())) - .collect()?; - result.push(row); - } - backend::Message::CopyInResponse(_) => { - self.stream.write_message(|buf| { - frontend::copy_fail("COPY queries cannot be directly executed", buf) - })?; - self.stream - .write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))?; - self.stream.flush()?; - } - backend::Message::ErrorResponse(body) => { - self.wait_for_ready()?; - return Err(err(&mut body.fields())); - } - _ => {} - } - } - Ok(result) - } - - fn finish_inner(&mut self) -> Result<()> { - check_desync!(self); - self.stream - .write_message(|buf| Ok::<(), io::Error>(frontend::terminate(buf)))?; - self.stream.flush()?; - Ok(()) - } -} - -fn _ensure_send() { - fn _is_send() {} - _is_send::(); -} - -/// A connection to a Postgres database. -pub struct Connection(RefCell); - -impl fmt::Debug for Connection { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - let conn = self.0.borrow(); - fmt.debug_struct("Connection") - .field("stream", &conn.stream.get_ref()) - .field("cancel_data", &conn.cancel_data) - .field("notifications", &conn.notifications.len()) - .field("transaction_depth", &conn.trans_depth) - .field("desynchronized", &conn.desynchronized) - .field("cached_statements", &conn.cached_statements.len()) - .finish() - } -} - -impl Connection { - /// Creates a new connection to a Postgres database. - /// - /// Most applications can use a URL string in the normal format: - /// - /// ```notrust - /// postgresql://user[:password]@host[:port][/database][?param1=val1[[¶m2=val2]...]] - /// ``` - /// - /// The password may be omitted if not required. The default Postgres port - /// (5432) is used if none is specified. The database name defaults to the - /// username if not specified. - /// - /// To connect to the server via Unix sockets, `host` should be set to the - /// absolute path of the directory containing the socket file. Since `/` is - /// a reserved character in URLs, the path should be URL encoded. If the - /// path contains non-UTF 8 characters, a `ConnectParams` struct should be - /// created manually and passed in. Note that Postgres does not support TLS - /// over Unix sockets. - /// - /// # Examples - /// - /// To connect over TCP: - /// - /// ```rust,no_run - /// use postgres::{Connection, TlsMode}; - /// - /// let url = "postgresql://postgres:hunter2@localhost:5433:2994/foodb"; - /// let conn = Connection::connect(url, TlsMode::None).unwrap(); - /// ``` - /// - /// To connect over a Unix socket located in `/run/postgres`: - /// - /// ```rust,no_run - /// use postgres::{Connection, TlsMode}; - /// - /// let url = "postgresql://postgres@%2Frun%2Fpostgres"; - /// let conn = Connection::connect(url, TlsMode::None).unwrap(); - /// ``` - /// - /// To connect with a manually constructed `ConnectParams`: - /// - /// ```rust,no_run - /// use postgres::{Connection, TlsMode}; - /// use postgres::params::{ConnectParams, Host}; - /// # use std::path::PathBuf; - /// - /// # #[cfg(unix)] - /// # fn f() { - /// # let some_crazy_path = PathBuf::new(); - /// let params = ConnectParams::builder() - /// .user("postgres", None) - /// .build(Host::Unix(some_crazy_path)); - /// let conn = Connection::connect(params, TlsMode::None).unwrap(); - /// # } - /// ``` - pub fn connect(params: T, tls: TlsMode) -> Result - where - T: IntoConnectParams, - { - InnerConnection::connect(params, tls).map(|conn| Connection(RefCell::new(conn))) - } - - /// Executes a statement, returning the number of rows modified. - /// - /// A statement may contain parameters, specified by `$n` where `n` is the - /// index of the parameter in the list provided, 1-indexed. - /// - /// If the statement does not modify any rows (e.g. SELECT), 0 is returned. - /// - /// If the same statement will be repeatedly executed (perhaps with - /// different query parameters), consider using the `prepare` and - /// `prepare_cached` methods. - /// - /// # Panics - /// - /// Panics if the number of parameters provided does not match the number - /// expected. - /// - /// # Example - /// - /// ```rust,no_run - /// # use postgres::{Connection, TlsMode}; - /// # let conn = Connection::connect("", TlsMode::None).unwrap(); - /// # let bar = 1i32; - /// # let baz = true; - /// let rows_updated = conn.execute("UPDATE foo SET bar = $1 WHERE baz = $2", &[&bar, &baz]) - /// .unwrap(); - /// println!("{} rows updated", rows_updated); - /// ``` - pub fn execute(&self, query: &str, params: &[&ToSql]) -> Result { - let (param_types, columns) = self.0.borrow_mut().raw_prepare("", query, &[])?; - let info = Arc::new(StatementInfo { - name: String::new(), - param_types: param_types, - columns: columns, - }); - let stmt = Statement::new(self, info, Cell::new(0), true); - stmt.execute(params) - } - - /// Executes a statement, returning the resulting rows. - /// - /// A statement may contain parameters, specified by `$n` where `n` is the - /// index of the parameter in the list provided, 1-indexed. - /// - /// If the same statement will be repeatedly executed (perhaps with - /// different query parameters), consider using the `prepare` and - /// `prepare_cached` methods. - /// - /// # Panics - /// - /// Panics if the number of parameters provided does not match the number - /// expected. - /// - /// # Example - /// - /// ```rust,no_run - /// # use postgres::{Connection, TlsMode}; - /// # let conn = Connection::connect("", TlsMode::None).unwrap(); - /// # let baz = true; - /// for row in &conn.query("SELECT foo FROM bar WHERE baz = $1", &[&baz]).unwrap() { - /// let foo: i32 = row.get("foo"); - /// println!("foo: {}", foo); - /// } - /// ``` - pub fn query(&self, query: &str, params: &[&ToSql]) -> Result { - let (param_types, columns) = self.0.borrow_mut().raw_prepare("", query, &[])?; - let info = Arc::new(StatementInfo { - name: String::new(), - param_types: param_types, - columns: columns, - }); - let stmt = Statement::new(self, info, Cell::new(0), true); - stmt.into_query(params) - } - - /// Begins a new transaction. - /// - /// Returns a `Transaction` object which should be used instead of - /// the connection for the duration of the transaction. The transaction - /// is active until the `Transaction` object falls out of scope. - /// - /// # Note - /// A transaction will roll back by default. The `set_commit`, - /// `set_rollback`, and `commit` methods alter this behavior. - /// - /// # Panics - /// - /// Panics if a transaction is already active. - /// - /// # Example - /// - /// ```rust,no_run - /// # use postgres::{Connection, TlsMode}; - /// # let conn = Connection::connect("", TlsMode::None).unwrap(); - /// let trans = conn.transaction().unwrap(); - /// trans.execute("UPDATE foo SET bar = 10", &[]).unwrap(); - /// // ... - /// - /// trans.commit().unwrap(); - /// ``` - pub fn transaction<'a>(&'a self) -> Result> { - self.transaction_with(&transaction::Config::new()) - } - - /// Begins a new transaction with the specified configuration. - pub fn transaction_with<'a>(&'a self, config: &transaction::Config) -> Result> { - let mut conn = self.0.borrow_mut(); - check_desync!(conn); - assert!( - conn.trans_depth == 0, - "`transaction` must be called on the active transaction" - ); - let mut query = "BEGIN".to_owned(); - config.build_command(&mut query); - conn.quick_query(&query)?; - conn.trans_depth += 1; - Ok(Transaction::new(self, 1)) - } - - /// Creates a new prepared statement. - /// - /// If the same statement will be executed repeatedly, explicitly preparing - /// it can improve performance. - /// - /// The statement is associated with the connection that created it and may - /// not outlive that connection. - /// - /// # Example - /// - /// ```rust,no_run - /// # use postgres::{Connection, TlsMode}; - /// # let x = 10i32; - /// # let conn = Connection::connect("", TlsMode::None).unwrap(); - /// # let (a, b) = (0i32, 1i32); - /// # let updates = vec![(&a, &b)]; - /// let stmt = conn.prepare("UPDATE foo SET bar = $1 WHERE baz = $2").unwrap(); - /// for (bar, baz) in updates { - /// stmt.execute(&[bar, baz]).unwrap(); - /// } - /// ``` - pub fn prepare<'a>(&'a self, query: &str) -> Result> { - self.prepare_typed(query, &[]) - } - - /// Like `prepare`, but allows for the types of query parameters to be explicitly specified. - /// - /// Postgres will normally infer the types of paramters, but this function offers more control - /// of that behavior. `None` will cause Postgres to infer the type. The list of types can be - /// shorter than the number of parameters in the query; it will act as if padded out with `None` - /// values. - /// - /// # Example - /// - /// ```rust,no_run - /// # use postgres::{Connection, TlsMode}; - /// # use postgres::types::Type; - /// # let conn = Connection::connect("", TlsMode::None).unwrap(); - /// // $1 would normally be assigned the type INT4, but we can override that to INT8 - /// let stmt = conn.prepare_typed("SELECT $1::INT4", &[Some(Type::INT8)]).unwrap(); - /// assert_eq!(stmt.param_types()[0], Type::INT8); - /// ``` - pub fn prepare_typed<'a>( - &'a self, - query: &str, - types: &[Option], - ) -> Result> { - self.0.borrow_mut().prepare_typed(query, types, self) - } - - /// Creates a cached prepared statement. - /// - /// Like `prepare`, except that the statement is only prepared once over - /// the lifetime of the connection and then cached. If the same statement - /// is going to be prepared frequently, caching it can improve performance - /// by reducing the number of round trips to the Postgres backend. - /// - /// # Example - /// - /// ```rust,no_run - /// # use postgres::{Connection, TlsMode}; - /// # let x = 10i32; - /// # let conn = Connection::connect("", TlsMode::None).unwrap(); - /// # let (a, b) = (0i32, 1i32); - /// # let updates = vec![(&a, &b)]; - /// let stmt = conn.prepare_cached("UPDATE foo SET bar = $1 WHERE baz = $2").unwrap(); - /// for (bar, baz) in updates { - /// stmt.execute(&[bar, baz]).unwrap(); - /// } - /// ``` - pub fn prepare_cached<'a>(&'a self, query: &str) -> Result> { - self.0.borrow_mut().prepare_cached(query, self) - } - - /// Returns the isolation level which will be used for future transactions. - /// - /// This is a simple wrapper around `SHOW TRANSACTION ISOLATION LEVEL`. - pub fn transaction_isolation(&self) -> Result { - let mut conn = self.0.borrow_mut(); - check_desync!(conn); - let result = conn.quick_query("SHOW TRANSACTION ISOLATION LEVEL")?; - IsolationLevel::new(result[0][0].as_ref().unwrap()) - } - - /// Sets the configuration that will be used for future transactions. - pub fn set_transaction_config(&self, config: &transaction::Config) -> Result<()> { - let mut command = "SET SESSION CHARACTERISTICS AS TRANSACTION".to_owned(); - config.build_command(&mut command); - self.simple_query(&command).map(|_| ()) - } - - /// Execute a sequence of SQL statements. - /// - /// Statements should be separated by `;` characters. If an error occurs, - /// execution of the sequence will stop at that point. This is intended for - /// execution of batches of non-dynamic statements - for example, creation - /// of a schema for a fresh database. - /// - /// # Warning - /// - /// Prepared statements should be used for any SQL statement which contains - /// user-specified data, as it provides functionality to safely embed that - /// data in the statement. Do not form statements via string concatenation - /// and feed them into this method. - /// - /// # Example - /// - /// ```rust,no_run - /// # use postgres::{Connection, TlsMode, Result}; - /// # let conn = Connection::connect("", TlsMode::None).unwrap(); - /// conn.batch_execute(" - /// CREATE TABLE person ( - /// id SERIAL PRIMARY KEY, - /// name NOT NULL - /// ); - /// - /// CREATE TABLE purchase ( - /// id SERIAL PRIMARY KEY, - /// person INT NOT NULL REFERENCES person (id), - /// time TIMESTAMPTZ NOT NULL, - /// ); - /// - /// CREATE INDEX ON purchase (time); - /// ").unwrap(); - /// ``` - #[deprecated(since = "0.15.3", note = "please use `simple_query` instead")] - pub fn batch_execute(&self, query: &str) -> Result<()> { - self.0.borrow_mut().quick_query(query).map(|_| ()) - } - - /// Send a simple, non-prepared query - /// - /// Executes a query without making a prepared statement. All result columns - /// are returned in a UTF-8 text format rather than compact binary - /// representations. This can be useful when communicating with services - /// like _pgbouncer_ which speak "basic" postgres but don't support prepared - /// statements. - /// - /// Because rust-postgres' query parameter substitution relies on prepared - /// statements, it's not possible to pass a separate parameters list with - /// this API. - /// - /// In general, the `query` API should be prefered whenever possible. - /// - /// # Example - /// - /// ```rust,no_run - /// # use postgres::{Connection, TlsMode}; - /// # let conn = Connection::connect("", TlsMode::None).unwrap(); - /// for response in &conn.simple_query("SELECT foo FROM bar WHERE baz = 'quux'").unwrap() { - /// for row in response { - /// let foo: &str = row.get("foo"); - /// println!("foo: {}", foo); - /// } - /// } - /// ``` - pub fn simple_query(&self, query: &str) -> Result> { - self.0.borrow_mut().simple_query_(query) - } - - /// Returns a structure providing access to asynchronous notifications. - /// - /// Use the `LISTEN` command to register this connection for notifications. - pub fn notifications<'a>(&'a self) -> Notifications<'a> { - Notifications::new(self) - } - - /// Returns information used to cancel pending queries. - /// - /// Used with the `cancel_query` function. The object returned can be used - /// to cancel any query executed by the connection it was created from. - pub fn cancel_data(&self) -> CancelData { - self.0.borrow().cancel_data - } - - /// Returns the value of the specified Postgres backend parameter, such as - /// `timezone` or `server_version`. - pub fn parameter(&self, param: &str) -> Option { - self.0.borrow().parameters.get(param).cloned() - } - - /// Sets the notice handler for the connection, returning the old handler. - pub fn set_notice_handler(&self, handler: Box) -> Box { - self.0.borrow_mut().set_notice_handler(handler) - } - - /// Returns whether or not the stream has been desynchronized due to an - /// error in the communication channel with the server. - /// - /// If this has occurred, all further queries will immediately return an - /// error. - pub fn is_desynchronized(&self) -> bool { - self.0.borrow().is_desynchronized() - } - - /// Determines if the `Connection` is currently "active", that is, if there - /// are no active transactions. - /// - /// The `transaction` method can only be called on the active `Connection` - /// or `Transaction`. - pub fn is_active(&self) -> bool { - self.0.borrow().trans_depth == 0 - } - - /// Consumes the connection, closing it. - /// - /// Functionally equivalent to the `Drop` implementation for `Connection` - /// except that it returns any error encountered to the caller. - pub fn finish(self) -> Result<()> { - let mut conn = self.0.borrow_mut(); - conn.finished = true; - conn.finish_inner() - } -} - -/// A trait allowing abstraction over connections and transactions -pub trait GenericConnection { - /// Like `Connection::execute`. - fn execute(&self, query: &str, params: &[&ToSql]) -> Result; - - /// Like `Connection::query`. - fn query<'a>(&'a self, query: &str, params: &[&ToSql]) -> Result; - - /// Like `Connection::prepare`. - fn prepare<'a>(&'a self, query: &str) -> Result>; - - /// Like `Connection::prepare_cached`. - fn prepare_cached<'a>(&'a self, query: &str) -> Result>; - - /// Like `Connection::transaction`. - fn transaction<'a>(&'a self) -> Result>; - - /// Like `Connection::batch_execute`. - #[deprecated(since = "0.15.3", note = "please use `simple_query` instead")] - fn batch_execute(&self, query: &str) -> Result<()>; - - /// Like `Connection::is_active`. - fn is_active(&self) -> bool; - - /// Like `Connection::simple_query`. - fn simple_query(&self, query: &str) -> Result>; -} - -impl GenericConnection for Connection { - fn execute(&self, query: &str, params: &[&ToSql]) -> Result { - self.execute(query, params) - } - - fn query<'a>(&'a self, query: &str, params: &[&ToSql]) -> Result { - self.query(query, params) - } - - fn prepare<'a>(&'a self, query: &str) -> Result> { - self.prepare(query) - } - - fn prepare_cached<'a>(&'a self, query: &str) -> Result> { - self.prepare_cached(query) - } - - fn transaction<'a>(&'a self) -> Result> { - self.transaction() - } - - fn batch_execute(&self, query: &str) -> Result<()> { - self.simple_query(query).map(|_| ()) - } - - fn is_active(&self) -> bool { - self.is_active() - } - - fn simple_query(&self, query: &str) -> Result> { - self.simple_query(query) - } -} - -impl<'a> GenericConnection for Transaction<'a> { - fn execute(&self, query: &str, params: &[&ToSql]) -> Result { - self.execute(query, params) - } - - fn query<'b>(&'b self, query: &str, params: &[&ToSql]) -> Result { - self.query(query, params) - } - - fn prepare<'b>(&'b self, query: &str) -> Result> { - self.prepare(query) - } - - fn prepare_cached<'b>(&'b self, query: &str) -> Result> { - self.prepare_cached(query) - } - - fn transaction<'b>(&'b self) -> Result> { - self.transaction() - } - - fn batch_execute(&self, query: &str) -> Result<()> { - self.simple_query(query).map(|_| ()) - } - - fn simple_query(&self, query: &str) -> Result> { - self.simple_query(query) - } - - fn is_active(&self) -> bool { - self.is_active() - } -} - -fn err(fields: &mut ErrorFields) -> Error { - match DbError::new(fields) { - Ok(err) => error::db(err), - Err(err) => err.into(), - } -} +pub use crate::notifications::Notifications; +#[doc(no_inline)] +pub use crate::row::{Row, SimpleQueryRow}; +pub use crate::row_iter::RowIter; +#[doc(no_inline)] +pub use crate::tls::NoTls; +pub use crate::transaction::*; +pub use crate::transaction_builder::TransactionBuilder; + +pub mod binary_copy; +mod cancel_token; +mod client; +pub mod config; +mod connection; +mod copy_in_writer; +mod copy_out_reader; +mod generic_client; +mod lazy_pin; +pub mod notifications; +mod row_iter; +mod transaction; +mod transaction_builder; + +#[cfg(test)] +mod test; diff --git a/postgres/src/macros.rs b/postgres/src/macros.rs deleted file mode 100644 index ba6d11ff4..000000000 --- a/postgres/src/macros.rs +++ /dev/null @@ -1,69 +0,0 @@ -macro_rules! try_desync { - ($s:expr, $e:expr) => ( - match $e { - Ok(ok) => ok, - Err(err) => { - $s.desynchronized = true; - return Err(::std::convert::From::from(err)); - } - } - ) -} - -macro_rules! check_desync { - ($e:expr) => ({ - if $e.is_desynchronized() { - return Err(::desynchronized().into()); - } - }) -} - -macro_rules! bad_response { - ($s:expr) => ({ - debug!("Bad response at {}:{}", file!(), line!()); - $s.desynchronized = true; - return Err(::bad_response().into()); - }) -} - -#[cfg(feature = "no-logging")] -macro_rules! debug { - ($($t:tt)*) => {} -} - -#[cfg(feature = "no-logging")] -macro_rules! info { - ($($t:tt)*) => {} -} - -/// Generates a simple implementation of `ToSql::accepts` which accepts the -/// types passed to it. -#[macro_export] -macro_rules! accepts { - ($($expected:pat),+) => ( - fn accepts(ty: &$crate::types::Type) -> bool { - match *ty { - $($expected)|+ => true, - _ => false - } - } - ) -} - -/// Generates an implementation of `ToSql::to_sql_checked`. -/// -/// All `ToSql` implementations should use this macro. -#[macro_export] -macro_rules! to_sql_checked { - () => { - fn to_sql_checked(&self, - ty: &$crate::types::Type, - out: &mut ::std::vec::Vec) - -> ::std::result::Result<$crate::types::IsNull, - Box<::std::error::Error + - ::std::marker::Sync + - ::std::marker::Send>> { - $crate::types::__to_sql_checked(self, ty, out) - } - } -} diff --git a/postgres/src/notification.rs b/postgres/src/notification.rs deleted file mode 100644 index addbd748e..000000000 --- a/postgres/src/notification.rs +++ /dev/null @@ -1,209 +0,0 @@ -//! Asynchronous notifications. - -use fallible_iterator::{FallibleIterator, IntoFallibleIterator}; -use std::fmt; -use std::time::Duration; -use postgres_protocol::message::backend::{self, ErrorFields}; -use error::DbError; - -#[doc(inline)] -use postgres_shared; -pub use postgres_shared::Notification; - -use {desynchronized, Result, Connection}; -use error::Error; - -/// Notifications from the Postgres backend. -pub struct Notifications<'conn> { - conn: &'conn Connection, -} - -impl<'a> fmt::Debug for Notifications<'a> { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - fmt.debug_struct("Notifications") - .field("pending", &self.len()) - .finish() - } -} - -impl<'conn> Notifications<'conn> { - pub(crate) fn new(conn: &'conn Connection) -> Notifications<'conn> { - Notifications { conn: conn } - } - - /// Returns the number of pending notifications. - pub fn len(&self) -> usize { - self.conn.0.borrow().notifications.len() - } - - /// Determines if there are any pending notifications. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Returns a fallible iterator over pending notifications. - /// - /// # Note - /// - /// This iterator may start returning `Some` after previously returning - /// `None` if more notifications are received. - pub fn iter<'a>(&'a self) -> Iter<'a> { - Iter { conn: self.conn } - } - - /// Returns a fallible iterator over notifications that blocks until one is - /// received if none are pending. - /// - /// The iterator will never return `None`. - pub fn blocking_iter<'a>(&'a self) -> BlockingIter<'a> { - BlockingIter { conn: self.conn } - } - - /// Returns a fallible iterator over notifications that blocks for a limited - /// time waiting to receive one if none are pending. - /// - /// # Note - /// - /// This iterator may start returning `Some` after previously returning - /// `None` if more notifications are received. - pub fn timeout_iter<'a>(&'a self, timeout: Duration) -> TimeoutIter<'a> { - TimeoutIter { - conn: self.conn, - timeout: timeout, - } - } -} - -impl<'a, 'conn> IntoFallibleIterator for &'a Notifications<'conn> { - type Item = Notification; - type Error = Error; - type IntoIter = Iter<'a>; - - fn into_fallible_iterator(self) -> Iter<'a> { - self.iter() - } -} - -/// A fallible iterator over pending notifications. -pub struct Iter<'a> { - conn: &'a Connection, -} - -impl<'a> FallibleIterator for Iter<'a> { - type Item = Notification; - type Error = Error; - - fn next(&mut self) -> Result> { - let mut conn = self.conn.0.borrow_mut(); - - if let Some(notification) = conn.notifications.pop_front() { - return Ok(Some(notification)); - } - - if conn.is_desynchronized() { - return Err(desynchronized().into()); - } - - match conn.read_message_with_notification_nonblocking() { - Ok(Some(backend::Message::NotificationResponse(body))) => { - Ok(Some(Notification { - process_id: body.process_id(), - channel: body.channel()?.to_owned(), - payload: body.message()?.to_owned(), - })) - } - Ok(Some(backend::Message::ErrorResponse(body))) => Err(err(&mut body.fields())), - Ok(None) => Ok(None), - Err(err) => Err(err.into()), - _ => unreachable!(), - } - } - - fn size_hint(&self) -> (usize, Option) { - (self.conn.0.borrow().notifications.len(), None) - } -} - -/// An iterator over notifications which will block if none are pending. -pub struct BlockingIter<'a> { - conn: &'a Connection, -} - -impl<'a> FallibleIterator for BlockingIter<'a> { - type Item = Notification; - type Error = Error; - - fn next(&mut self) -> Result> { - let mut conn = self.conn.0.borrow_mut(); - - if let Some(notification) = conn.notifications.pop_front() { - return Ok(Some(notification)); - } - - if conn.is_desynchronized() { - return Err(desynchronized().into()); - } - - match conn.read_message_with_notification() { - Ok(backend::Message::NotificationResponse(body)) => { - Ok(Some(Notification { - process_id: body.process_id(), - channel: body.channel()?.to_owned(), - payload: body.message()?.to_owned(), - })) - } - Ok(backend::Message::ErrorResponse(body)) => Err(err(&mut body.fields())), - Err(err) => Err(err.into()), - _ => unreachable!(), - } - } -} - -/// An iterator over notifications which will block for a period of time if -/// none are pending. -pub struct TimeoutIter<'a> { - conn: &'a Connection, - timeout: Duration, -} - -impl<'a> FallibleIterator for TimeoutIter<'a> { - type Item = Notification; - type Error = Error; - - fn next(&mut self) -> Result> { - let mut conn = self.conn.0.borrow_mut(); - - if let Some(notification) = conn.notifications.pop_front() { - return Ok(Some(notification)); - } - - if conn.is_desynchronized() { - return Err(desynchronized().into()); - } - - match conn.read_message_with_notification_timeout(self.timeout) { - Ok(Some(backend::Message::NotificationResponse(body))) => { - Ok(Some(Notification { - process_id: body.process_id(), - channel: body.channel()?.to_owned(), - payload: body.message()?.to_owned(), - })) - } - Ok(Some(backend::Message::ErrorResponse(body))) => Err(err(&mut body.fields())), - Ok(None) => Ok(None), - Err(err) => Err(err.into()), - _ => unreachable!(), - } - } - - fn size_hint(&self) -> (usize, Option) { - (self.conn.0.borrow().notifications.len(), None) - } -} - -fn err(fields: &mut ErrorFields) -> Error { - match DbError::new(fields) { - Ok(err) => postgres_shared::error::db(err), - Err(err) => err.into(), - } -} diff --git a/postgres/src/notifications.rs b/postgres/src/notifications.rs new file mode 100644 index 000000000..0c040dedf --- /dev/null +++ b/postgres/src/notifications.rs @@ -0,0 +1,162 @@ +//! Asynchronous notifications. + +use crate::connection::ConnectionRef; +use crate::{Error, Notification}; +use fallible_iterator::FallibleIterator; +use futures_util::{ready, FutureExt}; +use std::pin::Pin; +use std::task::Poll; +use std::time::Duration; +use tokio::time::{self, Instant, Sleep}; + +/// Notifications from a PostgreSQL backend. +pub struct Notifications<'a> { + connection: ConnectionRef<'a>, +} + +impl<'a> Notifications<'a> { + pub(crate) fn new(connection: ConnectionRef<'a>) -> Notifications<'a> { + Notifications { connection } + } + + /// Returns the number of already buffered pending notifications. + pub fn len(&self) -> usize { + self.connection.notifications().len() + } + + /// Determines if there are any already buffered pending notifications. + pub fn is_empty(&self) -> bool { + self.connection.notifications().is_empty() + } + + /// Returns a nonblocking iterator over notifications. + /// + /// If there are no already buffered pending notifications, this iterator will poll the connection but will not + /// block waiting on notifications over the network. A return value of `None` either indicates that there are no + /// pending notifications or that the server has disconnected. + /// + /// # Note + /// + /// This iterator may start returning `Some` after previously returning `None` if more notifications are received. + pub fn iter(&mut self) -> Iter<'_> { + Iter { + connection: self.connection.as_ref(), + } + } + + /// Returns a blocking iterator over notifications. + /// + /// If there are no already buffered pending notifications, this iterator will block indefinitely waiting on the + /// PostgreSQL backend server to send one. It will only return `None` if the server has disconnected. + pub fn blocking_iter(&mut self) -> BlockingIter<'_> { + BlockingIter { + connection: self.connection.as_ref(), + } + } + + /// Returns an iterator over notifications which blocks a limited amount of time. + /// + /// If there are no already buffered pending notifications, this iterator will block waiting on the PostgreSQL + /// backend server to send one up to the provided timeout. A return value of `None` either indicates that there are + /// no pending notifications or that the server has disconnected. + /// + /// # Note + /// + /// This iterator may start returning `Some` after previously returning `None` if more notifications are received. + pub fn timeout_iter(&mut self, timeout: Duration) -> TimeoutIter<'_> { + TimeoutIter { + delay: Box::pin(self.connection.enter(|| time::sleep(timeout))), + timeout, + connection: self.connection.as_ref(), + } + } +} + +/// A nonblocking iterator over pending notifications. +pub struct Iter<'a> { + connection: ConnectionRef<'a>, +} + +impl FallibleIterator for Iter<'_> { + type Item = Notification; + type Error = Error; + + fn next(&mut self) -> Result, Self::Error> { + if let Some(notification) = self.connection.notifications_mut().pop_front() { + return Ok(Some(notification)); + } + + self.connection + .poll_block_on(|_, notifications, _| Poll::Ready(Ok(notifications.pop_front()))) + } + + fn size_hint(&self) -> (usize, Option) { + (self.connection.notifications().len(), None) + } +} + +/// A blocking iterator over pending notifications. +pub struct BlockingIter<'a> { + connection: ConnectionRef<'a>, +} + +impl FallibleIterator for BlockingIter<'_> { + type Item = Notification; + type Error = Error; + + fn next(&mut self) -> Result, Self::Error> { + if let Some(notification) = self.connection.notifications_mut().pop_front() { + return Ok(Some(notification)); + } + + self.connection + .poll_block_on(|_, notifications, done| match notifications.pop_front() { + Some(notification) => Poll::Ready(Ok(Some(notification))), + None if done => Poll::Ready(Ok(None)), + None => Poll::Pending, + }) + } + + fn size_hint(&self) -> (usize, Option) { + (self.connection.notifications().len(), None) + } +} + +/// A time-limited blocking iterator over pending notifications. +pub struct TimeoutIter<'a> { + connection: ConnectionRef<'a>, + delay: Pin>, + timeout: Duration, +} + +impl FallibleIterator for TimeoutIter<'_> { + type Item = Notification; + type Error = Error; + + fn next(&mut self) -> Result, Self::Error> { + if let Some(notification) = self.connection.notifications_mut().pop_front() { + self.delay.as_mut().reset(Instant::now() + self.timeout); + return Ok(Some(notification)); + } + + let delay = &mut self.delay; + let timeout = self.timeout; + self.connection.poll_block_on(|cx, notifications, done| { + match notifications.pop_front() { + Some(notification) => { + delay.as_mut().reset(Instant::now() + timeout); + return Poll::Ready(Ok(Some(notification))); + } + None if done => return Poll::Ready(Ok(None)), + None => {} + } + + ready!(delay.poll_unpin(cx)); + Poll::Ready(Ok(None)) + }) + } + + fn size_hint(&self) -> (usize, Option) { + (self.connection.notifications().len(), None) + } +} diff --git a/postgres/src/params.rs b/postgres/src/params.rs deleted file mode 100644 index eb521f109..000000000 --- a/postgres/src/params.rs +++ /dev/null @@ -1,3 +0,0 @@ -//! Connection parameters - -pub use postgres_shared::params::{Builder, ConnectParams, User, Host, IntoConnectParams}; diff --git a/postgres/src/priv_io.rs b/postgres/src/priv_io.rs deleted file mode 100644 index b761910bc..000000000 --- a/postgres/src/priv_io.rs +++ /dev/null @@ -1,259 +0,0 @@ -use bytes::{BufMut, BytesMut}; -use postgres_protocol::message::backend; -use postgres_protocol::message::frontend; -use socket2::{Domain, SockAddr, Socket, Type}; -use std::io::{self, BufWriter, Read, Write}; -use std::net::{SocketAddr, ToSocketAddrs}; -#[cfg(unix)] -use std::os::unix::io::{AsRawFd, RawFd}; -#[cfg(windows)] -use std::os::windows::io::{AsRawSocket, RawSocket}; -use std::result; -use std::time::Duration; - -use error; -use params::{ConnectParams, Host}; -use tls::TlsStream; -use {Result, TlsMode}; - -const INITIAL_CAPACITY: usize = 8 * 1024; - -pub struct MessageStream { - stream: BufWriter>, - in_buf: BytesMut, - out_buf: Vec, -} - -impl MessageStream { - pub fn new(stream: Box) -> MessageStream { - MessageStream { - stream: BufWriter::new(stream), - in_buf: BytesMut::with_capacity(INITIAL_CAPACITY), - out_buf: vec![], - } - } - - pub fn get_ref(&self) -> &TlsStream { - &**self.stream.get_ref() - } - - pub fn write_message(&mut self, f: F) -> result::Result<(), E> - where - F: FnOnce(&mut Vec) -> result::Result<(), E>, - E: From, - { - self.out_buf.clear(); - f(&mut self.out_buf)?; - self.stream.write_all(&self.out_buf).map_err(From::from) - } - - pub fn read_message(&mut self) -> io::Result { - loop { - match backend::Message::parse(&mut self.in_buf) { - Ok(Some(message)) => return Ok(message), - Ok(None) => self.read_in()?, - Err(e) => return Err(e), - } - } - } - - fn read_in(&mut self) -> io::Result<()> { - self.in_buf.reserve(1); - match self.stream - .get_mut() - .read(unsafe { self.in_buf.bytes_mut() }) - { - Ok(0) => Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - "unexpected EOF", - )), - Ok(n) => { - unsafe { self.in_buf.advance_mut(n) }; - Ok(()) - } - Err(e) => Err(e), - } - } - - pub fn read_message_timeout( - &mut self, - timeout: Duration, - ) -> io::Result> { - if self.in_buf.is_empty() { - self.set_read_timeout(Some(timeout))?; - let r = self.read_in(); - self.set_read_timeout(None)?; - - match r { - Ok(()) => {} - Err(ref e) - if e.kind() == io::ErrorKind::WouldBlock - || e.kind() == io::ErrorKind::TimedOut => - { - return Ok(None) - } - Err(e) => return Err(e), - } - } - - self.read_message().map(Some) - } - - pub fn read_message_nonblocking(&mut self) -> io::Result> { - if self.in_buf.is_empty() { - self.set_nonblocking(true)?; - let r = self.read_in(); - self.set_nonblocking(false)?; - - match r { - Ok(()) => {} - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(None), - Err(e) => return Err(e), - } - } - - self.read_message().map(Some) - } - - pub fn flush(&mut self) -> io::Result<()> { - self.stream.flush() - } - - fn set_read_timeout(&self, timeout: Option) -> io::Result<()> { - self.stream.get_ref().get_ref().0.set_read_timeout(timeout) - } - - fn set_nonblocking(&self, nonblock: bool) -> io::Result<()> { - self.stream.get_ref().get_ref().0.set_nonblocking(nonblock) - } -} - -/// A connection to the Postgres server. -/// -/// It implements `Read`, `Write` and `TlsStream`, as well as `AsRawFd` on -/// Unix platforms and `AsRawSocket` on Windows platforms. -#[derive(Debug)] -pub struct Stream(Socket); - -impl Read for Stream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.0.read(buf) - } -} - -impl Write for Stream { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.0.write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - self.0.flush() - } -} - -impl TlsStream for Stream { - fn get_ref(&self) -> &Stream { - self - } - - fn get_mut(&mut self) -> &mut Stream { - self - } -} - -#[cfg(unix)] -impl AsRawFd for Stream { - fn as_raw_fd(&self) -> RawFd { - self.0.as_raw_fd() - } -} - -#[cfg(windows)] -impl AsRawSocket for Stream { - fn as_raw_socket(&self) -> RawSocket { - self.0.as_raw_socket() - } -} - -fn open_socket(params: &ConnectParams) -> Result { - let port = params.port(); - match *params.host() { - Host::Tcp(ref host) => { - let mut error = None; - for addr in (&**host, port).to_socket_addrs()? { - let domain = match addr { - SocketAddr::V4(_) => Domain::ipv4(), - SocketAddr::V6(_) => Domain::ipv6(), - }; - let socket = Socket::new(domain, Type::stream(), None)?; - if let Some(keepalive) = params.keepalive() { - socket.set_keepalive(Some(keepalive))?; - } - let addr = SockAddr::from(addr); - let r = match params.connect_timeout() { - Some(timeout) => socket.connect_timeout(&addr, timeout), - None => socket.connect(&addr), - }; - match r { - Ok(()) => return Ok(socket), - Err(e) => error = Some(e), - } - } - - Err(error - .unwrap_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidInput, - "could not resolve any addresses", - ) - }) - .into()) - } - #[cfg(unix)] - Host::Unix(ref path) => { - let path = path.join(&format!(".s.PGSQL.{}", port)); - let socket = Socket::new(Domain::unix(), Type::stream(), None)?; - let addr = SockAddr::unix(path)?; - socket.connect(&addr)?; - Ok(socket) - } - #[cfg(not(unix))] - Host::Unix(..) => Err(io::Error::new( - io::ErrorKind::InvalidInput, - "unix sockets are not supported on this system", - ).into()), - } -} - -pub fn initialize_stream(params: &ConnectParams, tls: TlsMode) -> Result> { - let mut socket = Stream(open_socket(params)?); - - let (tls_required, handshaker) = match tls { - TlsMode::None => return Ok(Box::new(socket)), - TlsMode::Prefer(handshaker) => (false, handshaker), - TlsMode::Require(handshaker) => (true, handshaker), - }; - - let mut buf = vec![]; - frontend::ssl_request(&mut buf); - socket.write_all(&buf)?; - socket.flush()?; - - let mut b = [0; 1]; - socket.read_exact(&mut b)?; - if b[0] == b'N' { - if tls_required { - return Err(error::tls("the server does not support TLS".into())); - } else { - return Ok(Box::new(socket)); - } - } - - let host = match *params.host() { - Host::Tcp(ref host) => host, - // Postgres doesn't support TLS over unix sockets - Host::Unix(_) => return Err(::bad_response().into()), - }; - - handshaker.tls_handshake(host, socket).map_err(error::tls) -} diff --git a/postgres/src/row_iter.rs b/postgres/src/row_iter.rs new file mode 100644 index 000000000..221fdfc68 --- /dev/null +++ b/postgres/src/row_iter.rs @@ -0,0 +1,38 @@ +use crate::connection::ConnectionRef; +use fallible_iterator::FallibleIterator; +use futures_util::StreamExt; +use std::pin::Pin; +use tokio_postgres::{Error, Row, RowStream}; + +/// The iterator returned by `query_raw`. +pub struct RowIter<'a> { + connection: ConnectionRef<'a>, + it: Pin>, +} + +impl<'a> RowIter<'a> { + pub(crate) fn new(connection: ConnectionRef<'a>, stream: RowStream) -> RowIter<'a> { + RowIter { + connection, + it: Box::pin(stream), + } + } + + /// Returns the number of rows affected by the query. + /// + /// This function will return `None` until the iterator has been exhausted. + pub fn rows_affected(&self) -> Option { + self.it.rows_affected() + } +} + +impl FallibleIterator for RowIter<'_> { + type Item = Row; + type Error = Error; + + fn next(&mut self) -> Result, Error> { + let it = &mut self.it; + self.connection + .block_on(async { it.next().await.transpose() }) + } +} diff --git a/postgres/src/rows.rs b/postgres/src/rows.rs deleted file mode 100644 index 25b006fd8..000000000 --- a/postgres/src/rows.rs +++ /dev/null @@ -1,342 +0,0 @@ -//! Query result rows. - -use fallible_iterator::FallibleIterator; -use postgres_protocol::message::frontend; -use postgres_shared::rows::RowData; -use std::collections::VecDeque; -use std::fmt; -use std::io; -use std::ops::Deref; -use std::slice; -use std::sync::Arc; - -#[doc(inline)] -pub use postgres_shared::rows::RowIndex; - -use error; -use stmt::{Column, Statement}; -use transaction::Transaction; -use types::{FromSql, WrongType}; -use {Error, Result, StatementInfo}; - -enum MaybeOwned<'a, T: 'a> { - Borrowed(&'a T), - Owned(T), -} - -impl<'a, T> Deref for MaybeOwned<'a, T> { - type Target = T; - - fn deref(&self) -> &T { - match *self { - MaybeOwned::Borrowed(s) => s, - MaybeOwned::Owned(ref s) => s, - } - } -} - -/// The resulting rows of a query. -pub struct Rows { - stmt_info: Arc, - data: Vec, -} - -impl fmt::Debug for Rows { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - fmt.debug_struct("Rows") - .field("columns", &self.columns()) - .field("rows", &self.data.len()) - .finish() - } -} - -impl Rows { - pub(crate) fn new(stmt: &Statement, data: Vec) -> Rows { - Rows { - stmt_info: stmt.info().clone(), - data: data, - } - } - - /// Returns a slice describing the columns of the `Rows`. - pub fn columns(&self) -> &[Column] { - &self.stmt_info.columns[..] - } - - /// Returns the number of rows present. - pub fn len(&self) -> usize { - self.data.len() - } - - /// Determines if there are any rows present. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Returns a specific `Row`. - /// - /// # Panics - /// - /// Panics if `idx` is out of bounds. - pub fn get<'a>(&'a self, idx: usize) -> Row<'a> { - Row { - stmt_info: &self.stmt_info, - data: MaybeOwned::Borrowed(&self.data[idx]), - } - } - - /// Returns an iterator over the `Row`s. - pub fn iter<'a>(&'a self) -> Iter<'a> { - Iter { - stmt_info: &self.stmt_info, - iter: self.data.iter(), - } - } -} - -impl<'a> IntoIterator for &'a Rows { - type Item = Row<'a>; - type IntoIter = Iter<'a>; - - fn into_iter(self) -> Iter<'a> { - self.iter() - } -} - -/// An iterator over `Row`s. -pub struct Iter<'a> { - stmt_info: &'a StatementInfo, - iter: slice::Iter<'a, RowData>, -} - -impl<'a> Iterator for Iter<'a> { - type Item = Row<'a>; - - fn next(&mut self) -> Option> { - self.iter.next().map(|row| Row { - stmt_info: self.stmt_info, - data: MaybeOwned::Borrowed(row), - }) - } - - fn size_hint(&self) -> (usize, Option) { - self.iter.size_hint() - } -} - -impl<'a> DoubleEndedIterator for Iter<'a> { - fn next_back(&mut self) -> Option> { - self.iter.next_back().map(|row| Row { - stmt_info: self.stmt_info, - data: MaybeOwned::Borrowed(row), - }) - } -} - -impl<'a> ExactSizeIterator for Iter<'a> {} - -/// A single result row of a query. -pub struct Row<'a> { - stmt_info: &'a StatementInfo, - data: MaybeOwned<'a, RowData>, -} - -impl<'a> fmt::Debug for Row<'a> { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - fmt.debug_struct("Row") - .field("statement", self.stmt_info) - .finish() - } -} - -impl<'a> Row<'a> { - /// Returns the number of values in the row. - pub fn len(&self) -> usize { - self.data.len() - } - - /// Determines if there are any values in the row. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Returns a slice describing the columns of the `Row`. - pub fn columns(&self) -> &[Column] { - &self.stmt_info.columns[..] - } - - /// Retrieves the contents of a field of the row. - /// - /// A field can be accessed by the name or index of its column, though - /// access by index is more efficient. Rows are 0-indexed. - /// - /// # Panics - /// - /// Panics if the index does not reference a column or the return type is - /// not compatible with the Postgres type. - /// - /// # Example - /// - /// ```rust,no_run - /// # use postgres::{Connection, TlsMode}; - /// # let conn = Connection::connect("", TlsMode::None).unwrap(); - /// let stmt = conn.prepare("SELECT foo, bar from BAZ").unwrap(); - /// for row in &stmt.query(&[]).unwrap() { - /// let foo: i32 = row.get(0); - /// let bar: String = row.get("bar"); - /// println!("{}: {}", foo, bar); - /// } - /// ``` - pub fn get<'b, I, T>(&'b self, idx: I) -> T - where - I: RowIndex + fmt::Debug, - T: FromSql<'b>, - { - match self.get_inner(&idx) { - Some(Ok(ok)) => ok, - Some(Err(err)) => panic!("error retrieving column {:?}: {:?}", idx, err), - None => panic!("no such column {:?}", idx), - } - } - - /// Retrieves the contents of a field of the row. - /// - /// A field can be accessed by the name or index of its column, though - /// access by index is more efficient. Rows are 0-indexed. - /// - /// Returns `None` if the index does not reference a column, `Some(Err(..))` - /// if there was an error converting the result value, and `Some(Ok(..))` - /// on success. - pub fn get_opt<'b, I, T>(&'b self, idx: I) -> Option> - where - I: RowIndex, - T: FromSql<'b>, - { - self.get_inner(&idx) - } - - fn get_inner<'b, I, T>(&'b self, idx: &I) -> Option> - where - I: RowIndex, - T: FromSql<'b>, - { - let idx = match idx.__idx(&self.stmt_info.columns) { - Some(idx) => idx, - None => return None, - }; - - let ty = self.stmt_info.columns[idx].type_(); - if !::accepts(ty) { - return Some(Err(error::conversion(Box::new(WrongType::new(ty.clone()))))); - } - let value = FromSql::from_sql_nullable(ty, self.data.get(idx)); - Some(value.map_err(error::conversion)) - } -} - -/// A lazily-loaded iterator over the resulting rows of a query. -pub struct LazyRows<'trans, 'stmt> { - stmt: &'stmt Statement<'stmt>, - data: VecDeque, - name: String, - row_limit: i32, - more_rows: bool, - finished: bool, - _trans: &'trans Transaction<'trans>, -} - -impl<'a, 'b> Drop for LazyRows<'a, 'b> { - fn drop(&mut self) { - if !self.finished { - let _ = self.finish_inner(); - } - } -} - -impl<'a, 'b> fmt::Debug for LazyRows<'a, 'b> { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - fmt.debug_struct("LazyRows") - .field("name", &self.name) - .field("row_limit", &self.row_limit) - .field("remaining_rows", &self.data.len()) - .field("more_rows", &self.more_rows) - .finish() - } -} - -impl<'trans, 'stmt> LazyRows<'trans, 'stmt> { - pub(crate) fn new( - stmt: &'stmt Statement<'stmt>, - data: VecDeque, - name: String, - row_limit: i32, - more_rows: bool, - finished: bool, - trans: &'trans Transaction<'trans>, - ) -> LazyRows<'trans, 'stmt> { - LazyRows { - stmt: stmt, - data: data, - name: name, - row_limit: row_limit, - more_rows: more_rows, - finished: finished, - _trans: trans, - } - } - - fn finish_inner(&mut self) -> Result<()> { - let mut conn = self.stmt.conn().0.borrow_mut(); - check_desync!(conn); - conn.close_statement(&self.name, b'P') - } - - fn execute(&mut self) -> Result<()> { - let mut conn = self.stmt.conn().0.borrow_mut(); - - conn.stream - .write_message(|buf| frontend::execute(&self.name, self.row_limit, buf))?; - conn.stream - .write_message(|buf| Ok::<(), io::Error>(frontend::sync(buf)))?; - conn.stream.flush()?; - conn.read_rows(|row| self.data.push_back(row)) - .map(|more_rows| self.more_rows = more_rows) - } - - /// Returns a slice describing the columns of the `LazyRows`. - pub fn columns(&self) -> &[Column] { - self.stmt.columns() - } - - /// Consumes the `LazyRows`, cleaning up associated state. - /// - /// Functionally identical to the `Drop` implementation on `LazyRows` - /// except that it returns any error to the caller. - pub fn finish(mut self) -> Result<()> { - self.finish_inner() - } -} - -impl<'trans, 'stmt> FallibleIterator for LazyRows<'trans, 'stmt> { - type Item = Row<'stmt>; - type Error = Error; - - fn next(&mut self) -> Result>> { - if self.data.is_empty() && self.more_rows { - self.execute()?; - } - - let row = self.data.pop_front().map(|r| Row { - stmt_info: &**self.stmt.info(), - data: MaybeOwned::Owned(r), - }); - - Ok(row) - } - - fn size_hint(&self) -> (usize, Option) { - let lower = self.data.len(); - let upper = if self.more_rows { None } else { Some(lower) }; - (lower, upper) - } -} diff --git a/postgres/src/stmt.rs b/postgres/src/stmt.rs deleted file mode 100644 index 500da8080..000000000 --- a/postgres/src/stmt.rs +++ /dev/null @@ -1,649 +0,0 @@ -//! Prepared statements - -use fallible_iterator::FallibleIterator; -use std::cell::Cell; -use std::collections::VecDeque; -use std::fmt; -use std::io::{self, Read, Write}; -use std::sync::Arc; -use postgres_protocol::message::{backend, frontend}; -use postgres_shared::rows::RowData; - -#[doc(inline)] -pub use postgres_shared::stmt::Column; - -use types::{Type, ToSql}; -use rows::{Rows, LazyRows}; -use transaction::Transaction; -use {bad_response, err, Connection, Result, StatementInfo}; - -/// A prepared statement. -pub struct Statement<'conn> { - conn: &'conn Connection, - info: Arc, - next_portal_id: Cell, - finished: bool, -} - -impl<'a> fmt::Debug for Statement<'a> { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - fmt::Debug::fmt(&*self.info, fmt) - } -} - -impl<'conn> Drop for Statement<'conn> { - fn drop(&mut self) { - let _ = self.finish_inner(); - } -} - -impl<'conn> Statement<'conn> { - pub(crate) fn new( - conn: &'conn Connection, - info: Arc, - next_portal_id: Cell, - finished: bool, - ) -> Statement<'conn> { - Statement { - conn: conn, - info: info, - next_portal_id: next_portal_id, - finished: finished, - } - } - - pub(crate) fn info(&self) -> &Arc { - &self.info - } - - pub(crate) fn conn(&self) -> &'conn Connection { - self.conn - } - - pub(crate) fn into_query(self, params: &[&ToSql]) -> Result { - check_desync!(self.conn); - let mut rows = vec![]; - self.inner_query("", 0, params, |row| rows.push(row))?; - Ok(Rows::new(&self, rows)) - } - - fn finish_inner(&mut self) -> Result<()> { - if self.finished { - Ok(()) - } else { - self.finished = true; - let mut conn = self.conn.0.borrow_mut(); - check_desync!(conn); - conn.close_statement(&self.info.name, b'S') - } - } - - #[allow(type_complexity)] - fn inner_query( - &self, - portal_name: &str, - row_limit: i32, - params: &[&ToSql], - acceptor: F, - ) -> Result - where - F: FnMut(RowData), - { - let mut conn = self.conn.0.borrow_mut(); - - conn.raw_execute( - &self.info.name, - portal_name, - row_limit, - self.param_types(), - params, - )?; - - conn.read_rows(acceptor) - } - - /// Returns a slice containing the expected parameter types. - pub fn param_types(&self) -> &[Type] { - &self.info.param_types - } - - /// Returns a slice describing the columns of the result of the query. - pub fn columns(&self) -> &[Column] { - &self.info.columns - } - - /// Executes the prepared statement, returning the number of rows modified. - /// - /// If the statement does not modify any rows (e.g. SELECT), 0 is returned. - /// - /// # Panics - /// - /// Panics if the number of parameters provided does not match the number - /// expected. - /// - /// # Example - /// - /// ```rust,no_run - /// # use postgres::{Connection, TlsMode}; - /// # let conn = Connection::connect("", TlsMode::None).unwrap(); - /// # let bar = 1i32; - /// # let baz = true; - /// let stmt = conn.prepare("UPDATE foo SET bar = $1 WHERE baz = $2").unwrap(); - /// let rows_updated = stmt.execute(&[&bar, &baz]).unwrap(); - /// println!("{} rows updated", rows_updated); - /// ``` - pub fn execute(&self, params: &[&ToSql]) -> Result { - let mut conn = self.conn.0.borrow_mut(); - check_desync!(conn); - conn.raw_execute( - &self.info.name, - "", - 0, - self.param_types(), - params, - )?; - - let num; - loop { - match conn.read_message()? { - backend::Message::DataRow(_) => {} - backend::Message::ErrorResponse(body) => { - conn.wait_for_ready()?; - return Err(err(&mut body.fields())); - } - backend::Message::CommandComplete(body) => { - num = parse_update_count(body.tag()?); - break; - } - backend::Message::EmptyQueryResponse => { - num = 0; - break; - } - backend::Message::CopyInResponse(_) => { - conn.stream.write_message(|buf| { - frontend::copy_fail("COPY queries cannot be directly executed", buf) - })?; - conn.stream.write_message( - |buf| Ok::<(), io::Error>(frontend::sync(buf)), - )?; - conn.stream.flush()?; - } - backend::Message::CopyOutResponse(_) => { - loop { - match conn.read_message()? { - backend::Message::CopyDone => break, - backend::Message::ErrorResponse(body) => { - conn.wait_for_ready()?; - return Err(err(&mut body.fields())); - } - _ => {} - } - } - num = 0; - break; - } - _ => { - conn.desynchronized = true; - return Err(bad_response().into()); - } - } - } - conn.wait_for_ready()?; - - Ok(num) - } - - /// Executes the prepared statement, returning the resulting rows. - /// - /// # Panics - /// - /// Panics if the number of parameters provided does not match the number - /// expected. - /// - /// # Example - /// - /// ```rust,no_run - /// # use postgres::{Connection, TlsMode}; - /// # let conn = Connection::connect("", TlsMode::None).unwrap(); - /// let stmt = conn.prepare("SELECT foo FROM bar WHERE baz = $1").unwrap(); - /// # let baz = true; - /// for row in &stmt.query(&[&baz]).unwrap() { - /// let foo: i32 = row.get("foo"); - /// println!("foo: {}", foo); - /// } - /// ``` - pub fn query(&self, params: &[&ToSql]) -> Result { - check_desync!(self.conn); - let mut rows = vec![]; - self.inner_query("", 0, params, |row| rows.push(row))?; - Ok(Rows::new(self, rows)) - } - - /// Executes the prepared statement, returning a lazily loaded iterator - /// over the resulting rows. - /// - /// No more than `row_limit` rows will be stored in memory at a time. Rows - /// will be pulled from the database in batches of `row_limit` as needed. - /// If `row_limit` is less than or equal to 0, `lazy_query` is equivalent - /// to `query`. - /// - /// This can only be called inside of a transaction, and the `Transaction` - /// object representing the active transaction must be passed to - /// `lazy_query`. - /// - /// # Panics - /// - /// Panics if the provided `Transaction` is not associated with the same - /// `Connection` as this `Statement`, if the `Transaction` is not - /// active, or if the number of parameters provided does not match the - /// number of parameters expected. - /// - /// # Examples - /// - /// ```no_run - /// extern crate fallible_iterator; - /// extern crate postgres; - /// - /// use fallible_iterator::FallibleIterator; - /// # use postgres::{Connection, TlsMode}; - /// - /// # fn main() { - /// # let conn = Connection::connect("", TlsMode::None).unwrap(); - /// let stmt = conn.prepare("SELECT foo FROM bar WHERE baz = $1").unwrap(); - /// let trans = conn.transaction().unwrap(); - /// # let baz = true; - /// let mut rows = stmt.lazy_query(&trans, &[&baz], 100).unwrap(); - /// - /// while let Some(row) = rows.next().unwrap() { - /// let foo: i32 = row.get("foo"); - /// println!("foo: {}", foo); - /// } - /// # } - /// ``` - pub fn lazy_query<'trans, 'stmt>( - &'stmt self, - trans: &'trans Transaction, - params: &[&ToSql], - row_limit: i32, - ) -> Result> { - assert!( - self.conn as *const _ == trans.conn() as *const _, - "the `Transaction` passed to `lazy_query` must be associated with the same \ - `Connection` as the `Statement`" - ); - let conn = self.conn.0.borrow(); - check_desync!(conn); - assert!( - conn.trans_depth == trans.depth(), - "`lazy_query` must be passed the active transaction" - ); - drop(conn); - - let id = self.next_portal_id.get(); - self.next_portal_id.set(id + 1); - let portal_name = format!("{}p{}", self.info.name, id); - - let mut rows = VecDeque::new(); - let more_rows = self.inner_query( - &portal_name, - row_limit, - params, - |row| rows.push_back(row), - )?; - Ok(LazyRows::new( - self, - rows, - portal_name, - row_limit, - more_rows, - false, - trans, - )) - } - - /// Executes a `COPY FROM STDIN` statement, returning the number of rows - /// added. - /// - /// The contents of the provided reader are passed to the Postgres server - /// verbatim; it is the caller's responsibility to ensure it uses the - /// proper format. See the - /// [Postgres documentation](http://www.postgresql.org/docs/9.4/static/sql-copy.html) - /// for details. - /// - /// If the statement is not a `COPY FROM STDIN` statement it will still be - /// executed and this method will return an error. - /// - /// # Examples - /// - /// ```rust,no_run - /// # use postgres::{Connection, TlsMode}; - /// # let conn = Connection::connect("", TlsMode::None).unwrap(); - /// conn.batch_execute("CREATE TABLE people (id INT PRIMARY KEY, name VARCHAR)").unwrap(); - /// let stmt = conn.prepare("COPY people FROM STDIN").unwrap(); - /// stmt.copy_in(&[], &mut "1\tjohn\n2\tjane\n".as_bytes()).unwrap(); - /// ``` - pub fn copy_in(&self, params: &[&ToSql], r: &mut R) -> Result { - let mut conn = self.conn.0.borrow_mut(); - conn.raw_execute( - &self.info.name, - "", - 0, - self.param_types(), - params, - )?; - - let (format, column_formats) = match conn.read_message()? { - backend::Message::CopyInResponse(body) => { - let format = body.format(); - let column_formats = body.column_formats().map(|f| Format::from_u16(f)).collect()?; - (format, column_formats) - } - backend::Message::ErrorResponse(body) => { - conn.wait_for_ready()?; - return Err(err(&mut body.fields())); - } - _ => { - loop { - if let backend::Message::ReadyForQuery(_) = conn.read_message()? { - return Err( - io::Error::new( - io::ErrorKind::InvalidInput, - "called `copy_in` on a non-`COPY FROM STDIN` statement", - ).into(), - ); - } - } - } - }; - - let info = CopyInfo { - format: Format::from_u16(format as u16), - column_formats: column_formats, - }; - - let mut buf = [0; 16 * 1024]; - loop { - match fill_copy_buf(&mut buf, r, &info) { - Ok(0) => break, - Ok(len) => { - conn.stream.write_message( - |out| frontend::copy_data(&buf[..len], out), - )?; - } - Err(err) => { - conn.stream.write_message( - |buf| frontend::copy_fail("", buf), - )?; - conn.stream.write_message(|buf| { - Ok::<(), io::Error>(frontend::copy_done(buf)) - })?; - conn.stream.write_message( - |buf| Ok::<(), io::Error>(frontend::sync(buf)), - )?; - conn.stream.flush()?; - match conn.read_message()? { - backend::Message::ErrorResponse(_) => { - // expected from the CopyFail - } - _ => { - conn.desynchronized = true; - return Err(bad_response().into()); - } - } - conn.wait_for_ready()?; - return Err(err.into()); - } - } - } - - conn.stream.write_message(|buf| { - Ok::<(), io::Error>(frontend::copy_done(buf)) - })?; - conn.stream.write_message( - |buf| Ok::<(), io::Error>(frontend::sync(buf)), - )?; - conn.stream.flush()?; - - let num = match conn.read_message()? { - backend::Message::CommandComplete(body) => parse_update_count(body.tag()?), - backend::Message::ErrorResponse(body) => { - conn.wait_for_ready()?; - return Err(err(&mut body.fields())); - } - _ => { - conn.desynchronized = true; - return Err(bad_response().into()); - } - }; - - conn.wait_for_ready()?; - Ok(num) - } - - /// Executes a `COPY TO STDOUT` statement, passing the resulting data to - /// the provided writer and returning the number of rows received. - /// - /// See the [Postgres documentation](http://www.postgresql.org/docs/9.4/static/sql-copy.html) - /// for details on the data format. - /// - /// If the statement is not a `COPY TO STDOUT` statement it will still be - /// executed and this method will return an error. - /// - /// # Examples - /// - /// ```rust,no_run - /// # use postgres::{Connection, TlsMode}; - /// # let conn = Connection::connect("", TlsMode::None).unwrap(); - /// conn.batch_execute(" - /// CREATE TABLE people (id INT PRIMARY KEY, name VARCHAR); - /// INSERT INTO people (id, name) VALUES (1, 'john'), (2, 'jane');").unwrap(); - /// let stmt = conn.prepare("COPY people TO STDOUT").unwrap(); - /// let mut buf = vec![]; - /// stmt.copy_out(&[], &mut buf).unwrap(); - /// assert_eq!(buf, b"1\tjohn\n2\tjane\n"); - /// ``` - pub fn copy_out<'a, W: WriteWithInfo>(&'a self, params: &[&ToSql], w: &mut W) -> Result { - let mut conn = self.conn.0.borrow_mut(); - conn.raw_execute( - &self.info.name, - "", - 0, - self.param_types(), - params, - )?; - - let (format, column_formats) = match conn.read_message()? { - backend::Message::CopyOutResponse(body) => { - let format = body.format(); - let column_formats = body.column_formats().map(|f| Format::from_u16(f)).collect()?; - (format, column_formats) - } - backend::Message::CopyInResponse(_) => { - conn.stream.write_message( - |buf| frontend::copy_fail("", buf), - )?; - conn.stream.write_message(|buf| { - Ok::<(), io::Error>(frontend::copy_done(buf)) - })?; - conn.stream.write_message( - |buf| Ok::<(), io::Error>(frontend::sync(buf)), - )?; - conn.stream.flush()?; - match conn.read_message()? { - backend::Message::ErrorResponse(_) => { - // expected from the CopyFail - } - _ => { - conn.desynchronized = true; - return Err(bad_response().into()); - } - } - conn.wait_for_ready()?; - return Err( - io::Error::new( - io::ErrorKind::InvalidInput, - "called `copy_out` on a non-`COPY TO STDOUT` statement", - ).into(), - ); - } - backend::Message::ErrorResponse(body) => { - conn.wait_for_ready()?; - return Err(err(&mut body.fields())); - } - _ => { - loop { - if let backend::Message::ReadyForQuery(_) = conn.read_message()? { - return Err( - io::Error::new( - io::ErrorKind::InvalidInput, - "called `copy_out` on a non-`COPY TO STDOUT` statement", - ).into(), - ); - } - } - } - }; - - let info = CopyInfo { - format: Format::from_u16(format as u16), - column_formats: column_formats, - }; - - let count; - loop { - match conn.read_message()? { - backend::Message::CopyData(body) => { - let mut data = body.data(); - while !data.is_empty() { - match w.write_with_info(data, &info) { - Ok(n) => data = &data[n..], - Err(e) => { - loop { - if let backend::Message::ReadyForQuery(_) = - conn.read_message()? - { - return Err(e.into()); - } - } - } - } - } - } - backend::Message::CopyDone => {} - backend::Message::CommandComplete(body) => { - count = parse_update_count(body.tag()?); - break; - } - backend::Message::ErrorResponse(body) => { - loop { - if let backend::Message::ReadyForQuery(_) = conn.read_message()? { - return Err(err(&mut body.fields())); - } - } - } - _ => { - loop { - if let backend::Message::ReadyForQuery(_) = conn.read_message()? { - return Err(bad_response().into()); - } - } - } - } - } - - conn.wait_for_ready()?; - Ok(count) - } - - /// Consumes the statement, clearing it from the Postgres session. - /// - /// If this statement was created via the `prepare_cached` method, `finish` - /// does nothing. - /// - /// Functionally identical to the `Drop` implementation of the - /// `Statement` except that it returns any error to the caller. - pub fn finish(mut self) -> Result<()> { - self.finish_inner() - } -} - -fn fill_copy_buf(buf: &mut [u8], r: &mut R, info: &CopyInfo) -> io::Result { - let mut nread = 0; - while nread < buf.len() { - match r.read_with_info(&mut buf[nread..], info) { - Ok(0) => break, - Ok(n) => nread += n, - Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {} - Err(e) => return Err(e), - } - } - Ok(nread) -} - -/// A struct containing information relevant for a `COPY` operation. -pub struct CopyInfo { - format: Format, - column_formats: Vec, -} - -impl CopyInfo { - /// Returns the format of the overall data. - pub fn format(&self) -> Format { - self.format - } - - /// Returns the format of the individual columns. - pub fn column_formats(&self) -> &[Format] { - &self.column_formats - } -} - -/// Like `Read` except that a `CopyInfo` object is provided as well. -/// -/// All types that implement `Read` also implement this trait. -pub trait ReadWithInfo { - /// Like `Read::read`. - fn read_with_info(&mut self, buf: &mut [u8], info: &CopyInfo) -> io::Result; -} - -impl ReadWithInfo for R { - fn read_with_info(&mut self, buf: &mut [u8], _: &CopyInfo) -> io::Result { - self.read(buf) - } -} - -/// Like `Write` except that a `CopyInfo` object is provided as well. -/// -/// All types that implement `Write` also implement this trait. -pub trait WriteWithInfo { - /// Like `Write::write`. - fn write_with_info(&mut self, buf: &[u8], info: &CopyInfo) -> io::Result; -} - -impl WriteWithInfo for W { - fn write_with_info(&mut self, buf: &[u8], _: &CopyInfo) -> io::Result { - self.write(buf) - } -} - -/// The format of a portion of COPY query data. -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub enum Format { - /// A text based format. - Text, - /// A binary format. - Binary, -} - -impl Format { - fn from_u16(value: u16) -> Format { - match value { - 0 => Format::Text, - _ => Format::Binary, - } - } -} - -fn parse_update_count(tag: &str) -> u64 { - tag.split(' ').last().unwrap().parse().unwrap_or(0) -} diff --git a/postgres/src/test.rs b/postgres/src/test.rs new file mode 100644 index 000000000..0fd404574 --- /dev/null +++ b/postgres/src/test.rs @@ -0,0 +1,510 @@ +use std::io::{Read, Write}; +use std::str::FromStr; +use std::sync::mpsc; +use std::thread; +use std::time::Duration; +use tokio_postgres::error::SqlState; +use tokio_postgres::types::Type; +use tokio_postgres::NoTls; + +use super::*; +use crate::binary_copy::{BinaryCopyInWriter, BinaryCopyOutIter}; +use fallible_iterator::FallibleIterator; + +#[test] +fn prepare() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + let stmt = client.prepare("SELECT 1::INT, $1::TEXT").unwrap(); + assert_eq!(stmt.params(), &[Type::TEXT]); + assert_eq!(stmt.columns().len(), 2); + assert_eq!(stmt.columns()[0].type_(), &Type::INT4); + assert_eq!(stmt.columns()[1].type_(), &Type::TEXT); +} + +#[test] +fn query_prepared() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + let stmt = client.prepare("SELECT $1::TEXT").unwrap(); + let rows = client.query(&stmt, &[&"hello"]).unwrap(); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].get::<_, &str>(0), "hello"); +} + +#[test] +fn query_unprepared() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + let rows = client.query("SELECT $1::TEXT", &[&"hello"]).unwrap(); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].get::<_, &str>(0), "hello"); +} + +#[test] +fn transaction_commit() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + client + .simple_query("CREATE TEMPORARY TABLE foo (id SERIAL PRIMARY KEY)") + .unwrap(); + + let mut transaction = client.transaction().unwrap(); + + transaction + .execute("INSERT INTO foo DEFAULT VALUES", &[]) + .unwrap(); + + transaction.commit().unwrap(); + + let rows = client.query("SELECT * FROM foo", &[]).unwrap(); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].get::<_, i32>(0), 1); +} + +#[test] +fn transaction_rollback() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + client + .simple_query("CREATE TEMPORARY TABLE foo (id SERIAL PRIMARY KEY)") + .unwrap(); + + let mut transaction = client.transaction().unwrap(); + + transaction + .execute("INSERT INTO foo DEFAULT VALUES", &[]) + .unwrap(); + + transaction.rollback().unwrap(); + + let rows = client.query("SELECT * FROM foo", &[]).unwrap(); + assert_eq!(rows.len(), 0); +} + +#[test] +fn transaction_drop() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + client + .simple_query("CREATE TEMPORARY TABLE foo (id SERIAL PRIMARY KEY)") + .unwrap(); + + let mut transaction = client.transaction().unwrap(); + + transaction + .execute("INSERT INTO foo DEFAULT VALUES", &[]) + .unwrap(); + + drop(transaction); + + let rows = client.query("SELECT * FROM foo", &[]).unwrap(); + assert_eq!(rows.len(), 0); +} + +#[test] +fn transaction_drop_immediate_rollback() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + let mut client2 = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + client + .simple_query("CREATE TABLE IF NOT EXISTS foo (id SERIAL PRIMARY KEY)") + .unwrap(); + + client + .execute("INSERT INTO foo VALUES (1) ON CONFLICT DO NOTHING", &[]) + .unwrap(); + + let mut transaction = client.transaction().unwrap(); + + transaction + .execute("SELECT * FROM foo FOR UPDATE", &[]) + .unwrap(); + + drop(transaction); + + let rows = client2.query("SELECT * FROM foo FOR UPDATE", &[]).unwrap(); + assert_eq!(rows.len(), 1); +} + +#[test] +fn nested_transactions() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + client + .batch_execute("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)") + .unwrap(); + + let mut transaction = client.transaction().unwrap(); + + transaction + .execute("INSERT INTO foo (id) VALUES (1)", &[]) + .unwrap(); + + let mut transaction2 = transaction.transaction().unwrap(); + + transaction2 + .execute("INSERT INTO foo (id) VALUES (2)", &[]) + .unwrap(); + + transaction2.rollback().unwrap(); + + let rows = transaction + .query("SELECT id FROM foo ORDER BY id", &[]) + .unwrap(); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].get::<_, i32>(0), 1); + + let mut transaction3 = transaction.transaction().unwrap(); + + transaction3 + .execute("INSERT INTO foo (id) VALUES(3)", &[]) + .unwrap(); + + let mut transaction4 = transaction3.transaction().unwrap(); + + transaction4 + .execute("INSERT INTO foo (id) VALUES(4)", &[]) + .unwrap(); + + transaction4.commit().unwrap(); + transaction3.commit().unwrap(); + transaction.commit().unwrap(); + + let rows = client.query("SELECT id FROM foo ORDER BY id", &[]).unwrap(); + assert_eq!(rows.len(), 3); + assert_eq!(rows[0].get::<_, i32>(0), 1); + assert_eq!(rows[1].get::<_, i32>(0), 3); + assert_eq!(rows[2].get::<_, i32>(0), 4); +} + +#[test] +fn savepoints() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + client + .batch_execute("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)") + .unwrap(); + + let mut transaction = client.transaction().unwrap(); + + transaction + .execute("INSERT INTO foo (id) VALUES (1)", &[]) + .unwrap(); + + let mut savepoint1 = transaction.savepoint("savepoint1").unwrap(); + + savepoint1 + .execute("INSERT INTO foo (id) VALUES (2)", &[]) + .unwrap(); + + savepoint1.rollback().unwrap(); + + let rows = transaction + .query("SELECT id FROM foo ORDER BY id", &[]) + .unwrap(); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].get::<_, i32>(0), 1); + + let mut savepoint2 = transaction.savepoint("savepoint2").unwrap(); + + savepoint2 + .execute("INSERT INTO foo (id) VALUES(3)", &[]) + .unwrap(); + + let mut savepoint3 = savepoint2.savepoint("savepoint3").unwrap(); + + savepoint3 + .execute("INSERT INTO foo (id) VALUES(4)", &[]) + .unwrap(); + + savepoint3.commit().unwrap(); + savepoint2.commit().unwrap(); + transaction.commit().unwrap(); + + let rows = client.query("SELECT id FROM foo ORDER BY id", &[]).unwrap(); + assert_eq!(rows.len(), 3); + assert_eq!(rows[0].get::<_, i32>(0), 1); + assert_eq!(rows[1].get::<_, i32>(0), 3); + assert_eq!(rows[2].get::<_, i32>(0), 4); +} + +#[test] +fn copy_in() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + client + .simple_query("CREATE TEMPORARY TABLE foo (id INT, name TEXT)") + .unwrap(); + + let mut writer = client.copy_in("COPY foo FROM stdin").unwrap(); + writer.write_all(b"1\tsteven\n2\ttimothy").unwrap(); + writer.finish().unwrap(); + + let rows = client + .query("SELECT id, name FROM foo ORDER BY id", &[]) + .unwrap(); + + assert_eq!(rows.len(), 2); + assert_eq!(rows[0].get::<_, i32>(0), 1); + assert_eq!(rows[0].get::<_, &str>(1), "steven"); + assert_eq!(rows[1].get::<_, i32>(0), 2); + assert_eq!(rows[1].get::<_, &str>(1), "timothy"); +} + +#[test] +fn copy_in_abort() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + client + .simple_query("CREATE TEMPORARY TABLE foo (id INT, name TEXT)") + .unwrap(); + + let mut writer = client.copy_in("COPY foo FROM stdin").unwrap(); + writer.write_all(b"1\tsteven\n2\ttimothy").unwrap(); + drop(writer); + + let rows = client + .query("SELECT id, name FROM foo ORDER BY id", &[]) + .unwrap(); + + assert_eq!(rows.len(), 0); +} + +#[test] +fn binary_copy_in() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + client + .simple_query("CREATE TEMPORARY TABLE foo (id INT, name TEXT)") + .unwrap(); + + let writer = client.copy_in("COPY foo FROM stdin BINARY").unwrap(); + let mut writer = BinaryCopyInWriter::new(writer, &[Type::INT4, Type::TEXT]); + writer.write(&[&1i32, &"steven"]).unwrap(); + writer.write(&[&2i32, &"timothy"]).unwrap(); + writer.finish().unwrap(); + + let rows = client + .query("SELECT id, name FROM foo ORDER BY id", &[]) + .unwrap(); + + assert_eq!(rows.len(), 2); + assert_eq!(rows[0].get::<_, i32>(0), 1); + assert_eq!(rows[0].get::<_, &str>(1), "steven"); + assert_eq!(rows[1].get::<_, i32>(0), 2); + assert_eq!(rows[1].get::<_, &str>(1), "timothy"); +} + +#[test] +fn copy_out() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + client + .simple_query( + "CREATE TEMPORARY TABLE foo (id INT, name TEXT); + INSERT INTO foo (id, name) VALUES (1, 'steven'), (2, 'timothy');", + ) + .unwrap(); + + let mut reader = client.copy_out("COPY foo (id, name) TO STDOUT").unwrap(); + let mut s = String::new(); + reader.read_to_string(&mut s).unwrap(); + drop(reader); + + assert_eq!(s, "1\tsteven\n2\ttimothy\n"); + + client.simple_query("SELECT 1").unwrap(); +} + +#[test] +fn binary_copy_out() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + client + .simple_query( + "CREATE TEMPORARY TABLE foo (id INT, name TEXT); + INSERT INTO foo (id, name) VALUES (1, 'steven'), (2, 'timothy');", + ) + .unwrap(); + + let reader = client + .copy_out("COPY foo (id, name) TO STDOUT BINARY") + .unwrap(); + let rows = BinaryCopyOutIter::new(reader, &[Type::INT4, Type::TEXT]) + .collect::>() + .unwrap(); + assert_eq!(rows.len(), 2); + assert_eq!(rows[0].get::(0), 1); + assert_eq!(rows[0].get::<&str>(1), "steven"); + assert_eq!(rows[1].get::(0), 2); + assert_eq!(rows[1].get::<&str>(1), "timothy"); + + client.simple_query("SELECT 1").unwrap(); +} + +#[test] +fn portal() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + client + .simple_query( + "CREATE TEMPORARY TABLE foo (id INT); + INSERT INTO foo (id) VALUES (1), (2), (3);", + ) + .unwrap(); + + let mut transaction = client.transaction().unwrap(); + + let portal = transaction + .bind("SELECT * FROM foo ORDER BY id", &[]) + .unwrap(); + + let rows = transaction.query_portal(&portal, 2).unwrap(); + assert_eq!(rows.len(), 2); + assert_eq!(rows[0].get::<_, i32>(0), 1); + assert_eq!(rows[1].get::<_, i32>(0), 2); + + let rows = transaction.query_portal(&portal, 2).unwrap(); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].get::<_, i32>(0), 3); +} + +#[test] +fn cancel_query() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + let cancel_token = client.cancel_token(); + let cancel_thread = thread::spawn(move || { + thread::sleep(Duration::from_millis(100)); + cancel_token.cancel_query(NoTls).unwrap(); + }); + + match client.batch_execute("SELECT pg_sleep(100)") { + Err(e) if e.code() == Some(&SqlState::QUERY_CANCELED) => {} + t => panic!("unexpected return: {:?}", t), + } + + cancel_thread.join().unwrap(); +} + +#[test] +fn notifications_iter() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + client + .batch_execute( + "\ + LISTEN notifications_iter; + NOTIFY notifications_iter, 'hello'; + NOTIFY notifications_iter, 'world'; + ", + ) + .unwrap(); + + let notifications = client.notifications().iter().collect::>().unwrap(); + assert_eq!(notifications.len(), 2); + assert_eq!(notifications[0].payload(), "hello"); + assert_eq!(notifications[1].payload(), "world"); +} + +#[test] +fn notifications_blocking_iter() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + client + .batch_execute( + "\ + LISTEN notifications_blocking_iter; + NOTIFY notifications_blocking_iter, 'hello'; + ", + ) + .unwrap(); + + thread::spawn(|| { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + thread::sleep(Duration::from_secs(1)); + client + .batch_execute("NOTIFY notifications_blocking_iter, 'world'") + .unwrap(); + }); + + let notifications = client + .notifications() + .blocking_iter() + .take(2) + .collect::>() + .unwrap(); + assert_eq!(notifications.len(), 2); + assert_eq!(notifications[0].payload(), "hello"); + assert_eq!(notifications[1].payload(), "world"); +} + +#[test] +fn notifications_timeout_iter() { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + client + .batch_execute( + "\ + LISTEN notifications_timeout_iter; + NOTIFY notifications_timeout_iter, 'hello'; + ", + ) + .unwrap(); + + thread::spawn(|| { + let mut client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + + thread::sleep(Duration::from_secs(1)); + client + .batch_execute("NOTIFY notifications_timeout_iter, 'world'") + .unwrap(); + + thread::sleep(Duration::from_secs(10)); + client + .batch_execute("NOTIFY notifications_timeout_iter, '!'") + .unwrap(); + }); + + let notifications = client + .notifications() + .timeout_iter(Duration::from_secs(2)) + .collect::>() + .unwrap(); + assert_eq!(notifications.len(), 2); + assert_eq!(notifications[0].payload(), "hello"); + assert_eq!(notifications[1].payload(), "world"); +} + +#[test] +fn notice_callback() { + let (notice_tx, notice_rx) = mpsc::sync_channel(64); + let mut client = Config::from_str("host=localhost port=5433 user=postgres") + .unwrap() + .notice_callback(move |n| notice_tx.send(n).unwrap()) + .connect(NoTls) + .unwrap(); + + client + .batch_execute("DO $$BEGIN RAISE NOTICE 'custom'; END$$") + .unwrap(); + + assert_eq!(notice_rx.recv().unwrap().message(), "custom"); +} + +#[test] +fn explicit_close() { + let client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap(); + client.close().unwrap(); +} + +#[test] +fn check_send() { + fn is_send() {} + + is_send::(); + is_send::(); + is_send::>(); +} diff --git a/postgres/src/text_rows.rs b/postgres/src/text_rows.rs deleted file mode 100644 index fd3562c7b..000000000 --- a/postgres/src/text_rows.rs +++ /dev/null @@ -1,194 +0,0 @@ -//! Query result rows. - -use postgres_shared::rows::RowData; -use std::fmt; -use std::slice; -use std::str; - -#[doc(inline)] -pub use postgres_shared::rows::RowIndex; - -use stmt::{Column}; -use {Result, error}; - -/// The resulting rows of a query. -pub struct TextRows { - columns: Vec, - data: Vec, -} - -impl fmt::Debug for TextRows { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - fmt.debug_struct("TextRows") - .field("columns", &self.columns()) - .field("rows", &self.data.len()) - .finish() - } -} - -impl TextRows { - pub(crate) fn new(columns: Vec, data: Vec) -> TextRows { - TextRows { - columns: columns, - data: data, - } - } - - /// Returns a slice describing the columns of the `TextRows`. - pub fn columns(&self) -> &[Column] { - &self.columns[..] - } - - /// Returns the number of rows present. - pub fn len(&self) -> usize { - self.data.len() - } - - /// Determines if there are any rows present. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Returns a specific `TextRow`. - /// - /// # Panics - /// - /// Panics if `idx` is out of bounds. - pub fn get<'a>(&'a self, idx: usize) -> TextRow<'a> { - TextRow { - columns: &self.columns, - data: &self.data[idx], - } - } - - /// Returns an iterator over the `TextRow`s. - pub fn iter<'a>(&'a self) -> Iter<'a> { - Iter { - columns: self.columns(), - iter: self.data.iter(), - } - } -} - -impl<'a> IntoIterator for &'a TextRows { - type Item = TextRow<'a>; - type IntoIter = Iter<'a>; - - fn into_iter(self) -> Iter<'a> { - self.iter() - } -} - -/// An iterator over `TextRow`s. -pub struct Iter<'a> { - columns: &'a [Column], - iter: slice::Iter<'a, RowData>, -} - -impl<'a> Iterator for Iter<'a> { - type Item = TextRow<'a>; - - fn next(&mut self) -> Option> { - self.iter.next().map(|row| { - TextRow { - columns: self.columns, - data: row, - } - }) - } - - fn size_hint(&self) -> (usize, Option) { - self.iter.size_hint() - } -} - -impl<'a> DoubleEndedIterator for Iter<'a> { - fn next_back(&mut self) -> Option> { - self.iter.next_back().map(|row| { - TextRow { - columns: self.columns, - data: row, - } - }) - } -} - -impl<'a> ExactSizeIterator for Iter<'a> {} - -/// A single result row of a query. -pub struct TextRow<'a> { - columns: &'a [Column], - data: &'a RowData, -} - -impl<'a> fmt::Debug for TextRow<'a> { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - fmt.debug_struct("TextRow") - .field("columns", &self.columns) - .finish() - } -} - -impl<'a> TextRow<'a> { - /// Returns the number of values in the row. - pub fn len(&self) -> usize { - self.data.len() - } - - /// Determines if there are any values in the row. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Returns a slice describing the columns of the `TextRow`. - pub fn columns(&self) -> &[Column] { - self.columns - } - - /// Retrieve the contents of a field of a row - /// - /// A field can be accessed by the name or index of its column, though - /// access by index is more efficient. Rows are 0-indexed. - /// - /// # Panics - /// - /// Panics if the index does not reference a column - pub fn get(&self, idx: I) -> &str - where - I: RowIndex + fmt::Debug, - { - match self.get_inner(&idx) { - Some(Ok(value)) => value, - Some(Err(err)) => panic!("error retrieving column {:?}: {:?}", idx, err), - None => panic!("no such column {:?}", idx), - } - } - - /// Retrieves the contents of a field of the row. - /// - /// A field can be accessed by the name or index of its column, though - /// access by index is more efficient. Rows are 0-indexed. - /// - /// Returns None if the index does not reference a column, Some(Err(..)) if - /// there was an error parsing the result as UTF-8, and Some(Ok(..)) on - /// success. - pub fn get_opt(&self, idx: I) -> Option> - where - I: RowIndex, - { - self.get_inner(&idx) - } - - fn get_inner(&self, idx: &I) -> Option> - where - I: RowIndex, - { - let idx = match idx.__idx(self.columns) { - Some(idx) => idx, - None => return None, - }; - - self.data.get(idx) - .map(|s| str::from_utf8(s).map_err(|e| error::conversion(Box::new(e)))) - } -} diff --git a/postgres/src/tls.rs b/postgres/src/tls.rs deleted file mode 100644 index 2c051a61c..000000000 --- a/postgres/src/tls.rs +++ /dev/null @@ -1,60 +0,0 @@ -//! Types and traits for TLS support. -pub use priv_io::Stream; - -use std::error::Error; -use std::fmt; -use std::io::prelude::*; - -/// A trait implemented by TLS streams. -pub trait TlsStream: fmt::Debug + Read + Write + Send { - /// Returns a reference to the underlying `Stream`. - fn get_ref(&self) -> &Stream; - - /// Returns a mutable reference to the underlying `Stream`. - fn get_mut(&mut self) -> &mut Stream; - - /// Returns the data associated with the `tls-unique` channel binding type as described in - /// [RFC 5929], if supported. - /// - /// An implementation only needs to support one of this or `tls_server_end_point`. - /// - /// [RFC 5929]: https://tools.ietf.org/html/rfc5929 - fn tls_unique(&self) -> Option> { - None - } - - /// Returns the data associated with the `tls-server-end-point` channel binding type as - /// described in [RFC 5929], if supported. - /// - /// An implementation only needs to support one of this or `tls_unique`. - /// - /// [RFC 5929]: https://tools.ietf.org/html/rfc5929 - fn tls_server_end_point(&self) -> Option> { - None - } -} - -/// A trait implemented by types that can initiate a TLS session over a Postgres -/// stream. -pub trait TlsHandshake: fmt::Debug { - /// Performs a client-side TLS handshake, returning a wrapper around the - /// provided stream. - /// - /// The host portion of the connection parameters is provided for hostname - /// verification. - fn tls_handshake( - &self, - host: &str, - stream: Stream, - ) -> Result, Box>; -} - -impl TlsHandshake for Box { - fn tls_handshake( - &self, - host: &str, - stream: Stream, - ) -> Result, Box> { - (**self).tls_handshake(host, stream) - } -} diff --git a/postgres/src/transaction.rs b/postgres/src/transaction.rs index 60b1dcb06..8126b1dbe 100644 --- a/postgres/src/transaction.rs +++ b/postgres/src/transaction.rs @@ -1,321 +1,250 @@ -//! Transactions +use crate::connection::ConnectionRef; +use crate::{CancelToken, CopyInWriter, CopyOutReader, Portal, RowIter, Statement, ToStatement}; +use tokio_postgres::types::{BorrowToSql, ToSql, Type}; +use tokio_postgres::{Error, Row, SimpleQueryMessage}; -use std::cell::Cell; -use std::fmt; - -use rows::Rows; -use text_rows::TextRows; -use stmt::Statement; -use types::ToSql; -use {bad_response, Connection, Result}; - -/// An enumeration of transaction isolation levels. -/// -/// See the [Postgres documentation](http://www.postgresql.org/docs/9.4/static/transaction-iso.html) -/// for full details on the semantics of each level. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum IsolationLevel { - /// The "read uncommitted" level. - /// - /// In current versions of Postgres, this behaves identically to - /// `ReadCommitted`. - ReadUncommitted, - /// The "read committed" level. - /// - /// This is the default isolation level in Postgres. - ReadCommitted, - /// The "repeatable read" level. - RepeatableRead, - /// The "serializable" level. - Serializable, -} - -impl IsolationLevel { - pub(crate) fn new(raw: &str) -> Result { - if raw.eq_ignore_ascii_case("READ UNCOMMITTED") { - Ok(IsolationLevel::ReadUncommitted) - } else if raw.eq_ignore_ascii_case("READ COMMITTED") { - Ok(IsolationLevel::ReadCommitted) - } else if raw.eq_ignore_ascii_case("REPEATABLE READ") { - Ok(IsolationLevel::RepeatableRead) - } else if raw.eq_ignore_ascii_case("SERIALIZABLE") { - Ok(IsolationLevel::Serializable) - } else { - Err(bad_response().into()) - } - } - - fn to_sql(&self) -> &'static str { - match *self { - IsolationLevel::ReadUncommitted => "READ UNCOMMITTED", - IsolationLevel::ReadCommitted => "READ COMMITTED", - IsolationLevel::RepeatableRead => "REPEATABLE READ", - IsolationLevel::Serializable => "SERIALIZABLE", - } - } -} - -/// Configuration of a transaction. -#[derive(Debug)] -pub struct Config { - isolation_level: Option, - read_only: Option, - deferrable: Option, -} - -impl Default for Config { - fn default() -> Config { - Config { - isolation_level: None, - read_only: None, - deferrable: None, - } - } -} - -impl Config { - pub(crate) fn build_command(&self, s: &mut String) { - let mut first = true; - - if let Some(isolation_level) = self.isolation_level { - s.push_str(" ISOLATION LEVEL "); - s.push_str(isolation_level.to_sql()); - first = false; - } - - if let Some(read_only) = self.read_only { - if !first { - s.push(','); - } - if read_only { - s.push_str(" READ ONLY"); - } else { - s.push_str(" READ WRITE"); - } - first = false; - } - - if let Some(deferrable) = self.deferrable { - if !first { - s.push(','); - } - if deferrable { - s.push_str(" DEFERRABLE"); - } else { - s.push_str(" NOT DEFERRABLE"); - } - } - } - - /// Creates a new `Config` with no configuration overrides. - pub fn new() -> Config { - Config::default() - } - - /// Sets the isolation level of the configuration. - pub fn isolation_level(&mut self, isolation_level: IsolationLevel) -> &mut Config { - self.isolation_level = Some(isolation_level); - self - } - - /// Sets the read-only property of a transaction. - /// - /// If enabled, a transaction will be unable to modify any persistent - /// database state. - pub fn read_only(&mut self, read_only: bool) -> &mut Config { - self.read_only = Some(read_only); - self - } - - /// Sets the deferrable property of a transaction. - /// - /// If enabled in a read only, serializable transaction, the transaction may - /// block when created, after which it will run without the normal overhead - /// of a serializable transaction and will not be forced to roll back due - /// to serialization failures. - pub fn deferrable(&mut self, deferrable: bool) -> &mut Config { - self.deferrable = Some(deferrable); - self - } -} - -/// A transaction on a database connection. +/// A representation of a PostgreSQL database transaction. /// -/// The transaction will roll back by default. -pub struct Transaction<'conn> { - conn: &'conn Connection, - depth: u32, - savepoint_name: Option, - commit: Cell, - finished: bool, +/// Transactions will implicitly roll back by default when dropped. Use the `commit` method to commit the changes made +/// in the transaction. Transactions can be nested, with inner transactions implemented via savepoints. +pub struct Transaction<'a> { + connection: ConnectionRef<'a>, + transaction: Option>, } -impl<'a> fmt::Debug for Transaction<'a> { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - fmt.debug_struct("Transaction") - .field("commit", &self.commit.get()) - .field("depth", &self.depth) - .finish() - } -} - -impl<'conn> Drop for Transaction<'conn> { +impl Drop for Transaction<'_> { fn drop(&mut self) { - if !self.finished { - let _ = self.finish_inner(); + if let Some(transaction) = self.transaction.take() { + let _ = self.connection.block_on(transaction.rollback()); } } } -impl<'conn> Transaction<'conn> { - pub(crate) fn new(conn: &'conn Connection, depth: u32) -> Transaction<'conn> { +impl<'a> Transaction<'a> { + pub(crate) fn new( + connection: ConnectionRef<'a>, + transaction: tokio_postgres::Transaction<'a>, + ) -> Transaction<'a> { Transaction { - conn: conn, - depth: depth, - savepoint_name: None, - commit: Cell::new(false), - finished: false, + connection, + transaction: Some(transaction), } } - pub(crate) fn conn(&self) -> &'conn Connection { - self.conn - } - - pub(crate) fn depth(&self) -> u32 { - self.depth - } - - fn finish_inner(&mut self) -> Result<()> { - let mut conn = self.conn.0.borrow_mut(); - debug_assert!(self.depth == conn.trans_depth); - conn.trans_depth -= 1; - match (self.commit.get(), &self.savepoint_name) { - (false, &Some(ref sp)) => conn.quick_query(&format!("ROLLBACK TO {}", sp))?, - (false, &None) => conn.quick_query("ROLLBACK")?, - (true, &Some(ref sp)) => conn.quick_query(&format!("RELEASE {}", sp))?, - (true, &None) => conn.quick_query("COMMIT")?, - }; - - Ok(()) - } - - /// Like `Connection::prepare`. - pub fn prepare(&self, query: &str) -> Result> { - self.conn.prepare(query) + /// Consumes the transaction, committing all changes made within it. + pub fn commit(mut self) -> Result<(), Error> { + self.connection + .block_on(self.transaction.take().unwrap().commit()) } - /// Like `Connection::prepare_cached`. + /// Rolls the transaction back, discarding all changes made within it. /// - /// # Note - /// - /// The statement will be cached for the duration of the - /// connection, not just the duration of this transaction. - pub fn prepare_cached(&self, query: &str) -> Result> { - self.conn.prepare_cached(query) - } - - /// Like `Connection::execute`. - pub fn execute(&self, query: &str, params: &[&ToSql]) -> Result { - self.conn.execute(query, params) - } - - /// Like `Connection::query`. - pub fn query<'a>(&'a self, query: &str, params: &[&ToSql]) -> Result { - self.conn.query(query, params) - } - - /// Like `Connection::batch_execute`. - #[deprecated(since="0.15.3", note="please use `simple_query` instead")] - pub fn batch_execute(&self, query: &str) -> Result<()> { - self.simple_query(query) - .map(|_| ()) - } - - /// Like `Connection::simple_query`. - pub fn simple_query(&self, query: &str) -> Result> { - self.conn.simple_query(query) - } - - /// Like `Connection::transaction`, but creates a nested transaction via - /// a savepoint. + /// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller. + pub fn rollback(mut self) -> Result<(), Error> { + self.connection + .block_on(self.transaction.take().unwrap().rollback()) + } + + /// Like `Client::prepare`. + pub fn prepare(&mut self, query: &str) -> Result { + self.connection + .block_on(self.transaction.as_ref().unwrap().prepare(query)) + } + + /// Like `Client::prepare_typed`. + pub fn prepare_typed(&mut self, query: &str, types: &[Type]) -> Result { + self.connection.block_on( + self.transaction + .as_ref() + .unwrap() + .prepare_typed(query, types), + ) + } + + /// Like `Client::execute`. + pub fn execute(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result + where + T: ?Sized + ToStatement, + { + self.connection + .block_on(self.transaction.as_ref().unwrap().execute(query, params)) + } + + /// Like `Client::query`. + pub fn query(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result, Error> + where + T: ?Sized + ToStatement, + { + self.connection + .block_on(self.transaction.as_ref().unwrap().query(query, params)) + } + + /// Like `Client::query_one`. + pub fn query_one(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result + where + T: ?Sized + ToStatement, + { + self.connection + .block_on(self.transaction.as_ref().unwrap().query_one(query, params)) + } + + /// Like `Client::query_opt`. + pub fn query_opt( + &mut self, + query: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result, Error> + where + T: ?Sized + ToStatement, + { + self.connection + .block_on(self.transaction.as_ref().unwrap().query_opt(query, params)) + } + + /// Like `Client::query_raw`. + pub fn query_raw(&mut self, query: &T, params: I) -> Result, Error> + where + T: ?Sized + ToStatement, + P: BorrowToSql, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + let stream = self + .connection + .block_on(self.transaction.as_ref().unwrap().query_raw(query, params))?; + Ok(RowIter::new(self.connection.as_ref(), stream)) + } + + /// Like `Client::query_typed`. + pub fn query_typed( + &mut self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result, Error> { + self.connection.block_on( + self.transaction + .as_ref() + .unwrap() + .query_typed(statement, params), + ) + } + + /// Like `Client::query_typed_raw`. + pub fn query_typed_raw(&mut self, query: &str, params: I) -> Result, Error> + where + P: BorrowToSql, + I: IntoIterator, + { + let stream = self.connection.block_on( + self.transaction + .as_ref() + .unwrap() + .query_typed_raw(query, params), + )?; + Ok(RowIter::new(self.connection.as_ref(), stream)) + } + + /// Binds parameters to a statement, creating a "portal". /// - /// # Panics + /// Portals can be used with the `query_portal` method to page through the results of a query without being forced + /// to consume them all immediately. /// - /// Panics if there is an active nested transaction. - pub fn transaction<'a>(&'a self) -> Result> { - self.savepoint("sp") - } - - /// Like `Connection::transaction`, but creates a nested transaction via - /// a savepoint with the specified name. + /// Portals are automatically closed when the transaction they were created in is closed. /// /// # Panics /// - /// Panics if there is an active nested transaction. - pub fn savepoint<'a>(&'a self, name: &str) -> Result> { - let mut conn = self.conn.0.borrow_mut(); - check_desync!(conn); - assert!( - conn.trans_depth == self.depth, - "`savepoint` may only be called on the active transaction" - ); - conn.quick_query(&format!("SAVEPOINT {}", name))?; - conn.trans_depth += 1; - Ok(Transaction { - conn: self.conn, - depth: self.depth + 1, - savepoint_name: Some(name.to_owned()), - commit: Cell::new(false), - finished: false, - }) - } - - /// Returns a reference to the `Transaction`'s `Connection`. - pub fn connection(&self) -> &'conn Connection { - self.conn - } - - /// Like `Connection::is_active`. - pub fn is_active(&self) -> bool { - self.conn.0.borrow().trans_depth == self.depth - } - - /// Alters the configuration of the active transaction. - pub fn set_config(&self, config: &Config) -> Result<()> { - let mut command = "SET TRANSACTION".to_owned(); - config.build_command(&mut command); - self.simple_query(&command) - .map(|_| ()) - } - - /// Determines if the transaction is currently set to commit or roll back. - pub fn will_commit(&self) -> bool { - self.commit.get() - } - - /// Sets the transaction to commit at its completion. - pub fn set_commit(&self) { - self.commit.set(true); - } - - /// Sets the transaction to roll back at its completion. - pub fn set_rollback(&self) { - self.commit.set(false); - } - - /// A convenience method which consumes and commits a transaction. - pub fn commit(self) -> Result<()> { - self.set_commit(); - self.finish() + /// Panics if the number of parameters provided does not match the number expected. + pub fn bind(&mut self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result + where + T: ?Sized + ToStatement, + { + self.connection + .block_on(self.transaction.as_ref().unwrap().bind(query, params)) } - /// Consumes the transaction, commiting or rolling it back as appropriate. + /// Continues execution of a portal, returning the next set of rows. /// - /// Functionally equivalent to the `Drop` implementation of `Transaction` - /// except that it returns any error to the caller. - pub fn finish(mut self) -> Result<()> { - self.finished = true; - self.finish_inner() + /// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to + /// `query_portal`. If the requested number is negative or 0, all remaining rows will be returned. + pub fn query_portal(&mut self, portal: &Portal, max_rows: i32) -> Result, Error> { + self.connection.block_on( + self.transaction + .as_ref() + .unwrap() + .query_portal(portal, max_rows), + ) + } + + /// The maximally flexible version of `query_portal`. + pub fn query_portal_raw( + &mut self, + portal: &Portal, + max_rows: i32, + ) -> Result, Error> { + let stream = self.connection.block_on( + self.transaction + .as_ref() + .unwrap() + .query_portal_raw(portal, max_rows), + )?; + Ok(RowIter::new(self.connection.as_ref(), stream)) + } + + /// Like `Client::copy_in`. + pub fn copy_in(&mut self, query: &T) -> Result, Error> + where + T: ?Sized + ToStatement, + { + let sink = self + .connection + .block_on(self.transaction.as_ref().unwrap().copy_in(query))?; + Ok(CopyInWriter::new(self.connection.as_ref(), sink)) + } + + /// Like `Client::copy_out`. + pub fn copy_out(&mut self, query: &T) -> Result, Error> + where + T: ?Sized + ToStatement, + { + let stream = self + .connection + .block_on(self.transaction.as_ref().unwrap().copy_out(query))?; + Ok(CopyOutReader::new(self.connection.as_ref(), stream)) + } + + /// Like `Client::simple_query`. + pub fn simple_query(&mut self, query: &str) -> Result, Error> { + self.connection + .block_on(self.transaction.as_ref().unwrap().simple_query(query)) + } + + /// Like `Client::batch_execute`. + pub fn batch_execute(&mut self, query: &str) -> Result<(), Error> { + self.connection + .block_on(self.transaction.as_ref().unwrap().batch_execute(query)) + } + + /// Like `Client::cancel_token`. + pub fn cancel_token(&self) -> CancelToken { + CancelToken::new(self.transaction.as_ref().unwrap().cancel_token()) + } + + /// Like `Client::transaction`, but creates a nested transaction via a savepoint. + pub fn transaction(&mut self) -> Result, Error> { + let transaction = self + .connection + .block_on(self.transaction.as_mut().unwrap().transaction())?; + Ok(Transaction::new(self.connection.as_ref(), transaction)) + } + + /// Like `Client::transaction`, but creates a nested transaction via a savepoint with the specified name. + pub fn savepoint(&mut self, name: I) -> Result, Error> + where + I: Into, + { + let transaction = self + .connection + .block_on(self.transaction.as_mut().unwrap().savepoint(name))?; + Ok(Transaction::new(self.connection.as_ref(), transaction)) } } diff --git a/postgres/src/transaction_builder.rs b/postgres/src/transaction_builder.rs new file mode 100644 index 000000000..e0f8a56e8 --- /dev/null +++ b/postgres/src/transaction_builder.rs @@ -0,0 +1,50 @@ +use crate::connection::ConnectionRef; +use crate::{Error, IsolationLevel, Transaction}; + +/// A builder for database transactions. +pub struct TransactionBuilder<'a> { + connection: ConnectionRef<'a>, + builder: tokio_postgres::TransactionBuilder<'a>, +} + +impl<'a> TransactionBuilder<'a> { + pub(crate) fn new( + connection: ConnectionRef<'a>, + builder: tokio_postgres::TransactionBuilder<'a>, + ) -> TransactionBuilder<'a> { + TransactionBuilder { + connection, + builder, + } + } + + /// Sets the isolation level of the transaction. + pub fn isolation_level(mut self, isolation_level: IsolationLevel) -> Self { + self.builder = self.builder.isolation_level(isolation_level); + self + } + + /// Sets the access mode of the transaction. + pub fn read_only(mut self, read_only: bool) -> Self { + self.builder = self.builder.read_only(read_only); + self + } + + /// Sets the deferrability of the transaction. + /// + /// If the transaction is also serializable and read only, creation of the transaction may block, but when it + /// completes the transaction is able to run with less overhead and a guarantee that it will not be aborted due to + /// serialization failure. + pub fn deferrable(mut self, deferrable: bool) -> Self { + self.builder = self.builder.deferrable(deferrable); + self + } + + /// Begins the transaction. + /// + /// The transaction will roll back by default - use the `commit` method to commit it. + pub fn start(mut self) -> Result, Error> { + let transaction = self.connection.block_on(self.builder.start())?; + Ok(Transaction::new(self.connection, transaction)) + } +} diff --git a/postgres/tests/test.rs b/postgres/tests/test.rs deleted file mode 100644 index 13ef53e7b..000000000 --- a/postgres/tests/test.rs +++ /dev/null @@ -1,1446 +0,0 @@ -extern crate fallible_iterator; -extern crate postgres; -extern crate url; - -#[macro_use] -extern crate postgres_shared; - -use fallible_iterator::FallibleIterator; -use postgres::error::ErrorPosition::Normal; -use postgres::error::{DbError, SqlState}; -use postgres::notification::Notification; -use postgres::params::IntoConnectParams; -use postgres::transaction::{self, IsolationLevel}; -use postgres::types::{Kind, Oid, Type, WrongType}; -use postgres::{Connection, GenericConnection, HandleNotice, TlsMode}; -use std::io; -use std::thread; -use std::time::Duration; - -macro_rules! or_panic { - ($e:expr) => { - match $e { - Ok(ok) => ok, - Err(err) => panic!("{:#?}", err), - } - }; -} - -mod types; - -#[test] -fn test_non_default_database() { - or_panic!(Connection::connect( - "postgres://postgres@localhost:5433/postgres", - TlsMode::None, - )); -} - -#[test] -fn test_url_terminating_slash() { - or_panic!(Connection::connect( - "postgres://postgres@localhost:5433/", - TlsMode::None, - )); -} - -#[test] -fn test_prepare_err() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - let err = conn.prepare("invalid sql database").unwrap_err(); - match err.as_db() { - Some(e) if e.code == SqlState::SYNTAX_ERROR && e.position == Some(Normal(1)) => {} - _ => panic!("Unexpected result {:?}", err), - } -} - -#[test] -fn test_unknown_database() { - match Connection::connect("postgres://postgres@localhost:5433/asdf", TlsMode::None) { - Err(ref e) if e.code() == Some(&SqlState::INVALID_CATALOG_NAME) => {} - Err(resp) => panic!("Unexpected result {:?}", resp), - _ => panic!("Unexpected result"), - } -} - -#[test] -fn test_connection_finish() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - assert!(conn.finish().is_ok()); -} - -#[test] -#[ignore] // doesn't work on our CI setup -fn test_unix_connection() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - let stmt = or_panic!(conn.prepare("SHOW unix_socket_directories")); - let result = or_panic!(stmt.query(&[])); - let unix_socket_directories: String = result.iter().map(|row| row.get(0)).next().unwrap(); - - if unix_socket_directories.is_empty() { - panic!("can't test connect_unix; unix_socket_directories is empty"); - } - - let unix_socket_directory = unix_socket_directories.split(',').next().unwrap(); - - let path = url::percent_encoding::utf8_percent_encode( - unix_socket_directory, - url::percent_encoding::USERINFO_ENCODE_SET, - ); - let url = format!("postgres://postgres@{}", path); - let conn = or_panic!(Connection::connect(&url[..], TlsMode::None)); - assert!(conn.finish().is_ok()); -} - -#[test] -fn test_transaction_commit() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)", &[])); - - let trans = or_panic!(conn.transaction()); - or_panic!(trans.execute("INSERT INTO foo (id) VALUES ($1)", &[&1i32])); - trans.set_commit(); - drop(trans); - - let stmt = or_panic!(conn.prepare("SELECT * FROM foo")); - let result = or_panic!(stmt.query(&[])); - - assert_eq!( - vec![1i32], - result.iter().map(|row| row.get(0)).collect::>() - ); -} - -#[test] -fn test_transaction_commit_finish() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)", &[])); - - let trans = or_panic!(conn.transaction()); - or_panic!(trans.execute("INSERT INTO foo (id) VALUES ($1)", &[&1i32])); - trans.set_commit(); - assert!(trans.finish().is_ok()); - - let stmt = or_panic!(conn.prepare("SELECT * FROM foo")); - let result = or_panic!(stmt.query(&[])); - - assert_eq!( - vec![1i32], - result.iter().map(|row| row.get(0)).collect::>() - ); -} - -#[test] -fn test_transaction_commit_method() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)", &[])); - - let trans = or_panic!(conn.transaction()); - or_panic!(trans.execute("INSERT INTO foo (id) VALUES ($1)", &[&1i32])); - assert!(trans.commit().is_ok()); - - let stmt = or_panic!(conn.prepare("SELECT * FROM foo")); - let result = or_panic!(stmt.query(&[])); - - assert_eq!( - vec![1i32], - result.iter().map(|row| row.get(0)).collect::>() - ); -} - -#[test] -fn test_transaction_rollback() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)", &[])); - - or_panic!(conn.execute("INSERT INTO foo (id) VALUES ($1)", &[&1i32])); - - let trans = or_panic!(conn.transaction()); - or_panic!(trans.execute("INSERT INTO foo (id) VALUES ($1)", &[&2i32])); - drop(trans); - - let stmt = or_panic!(conn.prepare("SELECT * FROM foo")); - let result = or_panic!(stmt.query(&[])); - - assert_eq!( - vec![1i32], - result.iter().map(|row| row.get(0)).collect::>() - ); -} - -#[test] -fn test_transaction_rollback_finish() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)", &[])); - - or_panic!(conn.execute("INSERT INTO foo (id) VALUES ($1)", &[&1i32])); - - let trans = or_panic!(conn.transaction()); - or_panic!(trans.execute("INSERT INTO foo (id) VALUES ($1)", &[&2i32])); - assert!(trans.finish().is_ok()); - - let stmt = or_panic!(conn.prepare("SELECT * FROM foo")); - let result = or_panic!(stmt.query(&[])); - - assert_eq!( - vec![1i32], - result.iter().map(|row| row.get(0)).collect::>() - ); -} - -#[test] -fn test_nested_transactions() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)", &[])); - - or_panic!(conn.execute("INSERT INTO foo (id) VALUES (1)", &[])); - - { - let trans1 = or_panic!(conn.transaction()); - or_panic!(trans1.execute("INSERT INTO foo (id) VALUES (2)", &[])); - - { - let trans2 = or_panic!(trans1.transaction()); - or_panic!(trans2.execute("INSERT INTO foo (id) VALUES (3)", &[])); - } - - { - let trans2 = or_panic!(trans1.transaction()); - or_panic!(trans2.execute("INSERT INTO foo (id) VALUES (4)", &[])); - - { - let trans3 = or_panic!(trans2.transaction()); - or_panic!(trans3.execute("INSERT INTO foo (id) VALUES (5)", &[])); - } - - { - let sp = or_panic!(trans2.savepoint("custom")); - or_panic!(sp.execute("INSERT INTO foo (id) VALUES (6)", &[])); - assert!(sp.commit().is_ok()); - } - - assert!(trans2.commit().is_ok()); - } - - let stmt = or_panic!(trans1.prepare("SELECT * FROM foo ORDER BY id")); - let result = or_panic!(stmt.query(&[])); - - assert_eq!( - vec![1i32, 2, 4, 6], - result.iter().map(|row| row.get(0)).collect::>() - ); - } - - let stmt = or_panic!(conn.prepare("SELECT * FROM foo ORDER BY id")); - let result = or_panic!(stmt.query(&[])); - - assert_eq!( - vec![1i32], - result.iter().map(|row| row.get(0)).collect::>() - ); -} - -#[test] -fn test_nested_transactions_finish() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)", &[])); - - or_panic!(conn.execute("INSERT INTO foo (id) VALUES (1)", &[])); - - { - let trans1 = or_panic!(conn.transaction()); - or_panic!(trans1.execute("INSERT INTO foo (id) VALUES (2)", &[])); - - { - let trans2 = or_panic!(trans1.transaction()); - or_panic!(trans2.execute("INSERT INTO foo (id) VALUES (3)", &[])); - assert!(trans2.finish().is_ok()); - } - - { - let trans2 = or_panic!(trans1.transaction()); - or_panic!(trans2.execute("INSERT INTO foo (id) VALUES (4)", &[])); - - { - let trans3 = or_panic!(trans2.transaction()); - or_panic!(trans3.execute("INSERT INTO foo (id) VALUES (5)", &[])); - assert!(trans3.finish().is_ok()); - } - - { - let sp = or_panic!(trans2.savepoint("custom")); - or_panic!(sp.execute("INSERT INTO foo (id) VALUES (6)", &[])); - sp.set_commit(); - assert!(sp.finish().is_ok()); - } - - trans2.set_commit(); - assert!(trans2.finish().is_ok()); - } - - // in a block to unborrow trans1 for the finish call - { - let stmt = or_panic!(trans1.prepare("SELECT * FROM foo ORDER BY id")); - let result = or_panic!(stmt.query(&[])); - - assert_eq!( - vec![1i32, 2, 4, 6], - result.iter().map(|row| row.get(0)).collect::>() - ); - } - - assert!(trans1.finish().is_ok()); - } - - let stmt = or_panic!(conn.prepare("SELECT * FROM foo ORDER BY id")); - let result = or_panic!(stmt.query(&[])); - - assert_eq!( - vec![1i32], - result.iter().map(|row| row.get(0)).collect::>() - ); -} - -#[test] -#[should_panic(expected = "active transaction")] -fn test_conn_trans_when_nested() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - let _trans = or_panic!(conn.transaction()); - conn.transaction().unwrap(); -} - -#[test] -#[should_panic(expected = "active transaction")] -fn test_trans_with_nested_trans() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - let trans = or_panic!(conn.transaction()); - let _trans2 = or_panic!(trans.transaction()); - trans.transaction().unwrap(); -} - -#[test] -#[should_panic(expected = "active transaction")] -fn test_trans_with_savepoints() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - let trans = or_panic!(conn.transaction()); - let _sp = or_panic!(trans.savepoint("custom")); - trans.savepoint("custom2").unwrap(); -} - -#[test] -fn test_stmt_execute_after_transaction() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - let trans = or_panic!(conn.transaction()); - let stmt = or_panic!(trans.prepare("SELECT 1")); - or_panic!(trans.finish()); - let result = or_panic!(stmt.query(&[])); - assert_eq!(1i32, result.iter().next().unwrap().get::<_, i32>(0)); -} - -#[test] -fn test_stmt_finish() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id BIGINT PRIMARY KEY)", &[])); - let stmt = or_panic!(conn.prepare("SELECT * FROM foo")); - assert!(stmt.finish().is_ok()); -} - -#[test] -#[allow(deprecated)] -fn test_batch_execute() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - let query = "CREATE TEMPORARY TABLE foo (id BIGINT PRIMARY KEY); - INSERT INTO foo (id) VALUES (10);"; - or_panic!(conn.batch_execute(query)); - - let stmt = or_panic!(conn.prepare("SELECT * from foo ORDER BY id")); - let result = or_panic!(stmt.query(&[])); - - assert_eq!( - vec![10i64], - result.iter().map(|row| row.get(0)).collect::>() - ); -} - -#[test] -#[allow(deprecated)] -fn test_batch_execute_error() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - let query = "CREATE TEMPORARY TABLE foo (id BIGINT PRIMARY KEY); - INSERT INTO foo (id) VALUES (10); - asdfa; - INSERT INTO foo (id) VALUES (11)"; - conn.batch_execute(query).err().unwrap(); - - let stmt = conn.prepare("SELECT * FROM foo ORDER BY id"); - match stmt { - Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_TABLE) => {} - Err(e) => panic!("unexpected error {:?}", e), - _ => panic!("unexpected success"), - } -} - -#[test] -#[allow(deprecated)] -fn test_transaction_batch_execute() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - let trans = or_panic!(conn.transaction()); - let query = "CREATE TEMPORARY TABLE foo (id BIGINT PRIMARY KEY); - INSERT INTO foo (id) VALUES (10);"; - or_panic!(trans.batch_execute(query)); - - let stmt = or_panic!(trans.prepare("SELECT * from foo ORDER BY id")); - let result = or_panic!(stmt.query(&[])); - - assert_eq!( - vec![10i64], - result.iter().map(|row| row.get(0)).collect::>() - ); -} - -#[test] -fn test_query() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id BIGINT PRIMARY KEY)", &[])); - or_panic!(conn.execute("INSERT INTO foo (id) VALUES ($1), ($2)", &[&1i64, &2i64])); - let stmt = or_panic!(conn.prepare("SELECT * from foo ORDER BY id")); - let result = or_panic!(stmt.query(&[])); - - assert_eq!( - vec![1i64, 2], - result.iter().map(|row| row.get(0)).collect::>() - ); -} - -#[test] -fn test_error_after_datarow() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - let stmt = or_panic!(conn.prepare( - " -SELECT - (SELECT generate_series(1, ss.i)) -FROM (SELECT gs.i - FROM generate_series(1, 2) gs(i) - ORDER BY gs.i - LIMIT 2) ss", - )); - match stmt.query(&[]) { - Err(ref e) if e.code() == Some(&SqlState::CARDINALITY_VIOLATION) => {} - Err(err) => panic!("Unexpected error {:?}", err), - Ok(_) => panic!("Expected failure"), - }; -} - -#[test] -fn test_lazy_query() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - - let trans = or_panic!(conn.transaction()); - or_panic!(trans.execute("CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY)", &[])); - let stmt = or_panic!(trans.prepare("INSERT INTO foo (id) VALUES ($1)")); - let values = vec![0i32, 1, 2, 3, 4, 5]; - for value in &values { - or_panic!(stmt.execute(&[value])); - } - let stmt = or_panic!(trans.prepare("SELECT id FROM foo ORDER BY id")); - let result = or_panic!(stmt.lazy_query(&trans, &[], 2)); - assert_eq!( - values, - result.map(|row| row.get(0)).collect::>().unwrap() - ); -} - -#[test] -#[should_panic(expected = "same `Connection` as")] -fn test_lazy_query_wrong_conn() { - let conn1 = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - let conn2 = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - - let trans = or_panic!(conn1.transaction()); - let stmt = or_panic!(conn2.prepare("SELECT 1::INT")); - stmt.lazy_query(&trans, &[], 1).unwrap(); -} - -#[test] -fn test_param_types() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - let stmt = or_panic!(conn.prepare("SELECT $1::INT, $2::VARCHAR")); - assert_eq!(stmt.param_types(), &[Type::INT4, Type::VARCHAR][..]); -} - -#[test] -fn test_columns() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - let stmt = or_panic!(conn.prepare("SELECT 1::INT as a, 'hi'::VARCHAR as b")); - let cols = stmt.columns(); - assert_eq!(2, cols.len()); - assert_eq!(cols[0].name(), "a"); - assert_eq!(cols[0].type_(), &Type::INT4); - assert_eq!(cols[1].name(), "b"); - assert_eq!(cols[1].type_(), &Type::VARCHAR); -} - -#[test] -fn test_execute_counts() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - assert_eq!( - 0, - or_panic!(conn.execute( - "CREATE TEMPORARY TABLE foo ( - id SERIAL PRIMARY KEY, - b INT - )", - &[], - )) - ); - assert_eq!( - 3, - or_panic!(conn.execute( - "INSERT INTO foo (b) VALUES ($1), ($2), ($2)", - &[&1i32, &2i32], - )) - ); - assert_eq!( - 2, - or_panic!(conn.execute("UPDATE foo SET b = 0 WHERE b = 2", &[])) - ); - assert_eq!(3, or_panic!(conn.execute("SELECT * FROM foo", &[]))); -} - -#[test] -fn test_wrong_param_type() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - let err = conn.execute("SELECT $1::VARCHAR", &[&1i32]).unwrap_err(); - match err.as_conversion() { - Some(e) if e.is::() => {} - _ => panic!("unexpected result {:?}", err), - } -} - -#[test] -#[should_panic(expected = "expected 2 parameters but got 1")] -fn test_too_few_params() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - let _ = conn.execute("SELECT $1::INT, $2::INT", &[&1i32]); -} - -#[test] -#[should_panic(expected = "expected 2 parameters but got 3")] -fn test_too_many_params() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - let _ = conn.execute("SELECT $1::INT, $2::INT", &[&1i32, &2i32, &3i32]); -} - -#[test] -fn test_index_named() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - let stmt = or_panic!(conn.prepare("SELECT 10::INT as val")); - let result = or_panic!(stmt.query(&[])); - - assert_eq!( - vec![10i32], - result - .iter() - .map(|row| row.get("val")) - .collect::>() - ); -} - -#[test] -#[should_panic] -fn test_index_named_fail() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - let stmt = or_panic!(conn.prepare("SELECT 10::INT as id")); - let result = or_panic!(stmt.query(&[])); - - let _: i32 = result.iter().next().unwrap().get("asdf"); -} - -#[test] -fn test_get_named_err() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - let stmt = or_panic!(conn.prepare("SELECT 10::INT as id")); - let result = or_panic!(stmt.query(&[])); - - match result.iter().next().unwrap().get_opt::<_, i32>("asdf") { - None => {} - res => panic!("unexpected result {:?}", res), - }; -} - -#[test] -fn test_get_was_null() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - let stmt = or_panic!(conn.prepare("SELECT NULL::INT as id")); - let result = or_panic!(stmt.query(&[])); - - match result.iter().next().unwrap().get_opt::<_, i32>(0) { - Some(Err(ref e)) if e.as_conversion().is_some() => {} - res => panic!("unexpected result {:?}", res), - }; -} - -#[test] -fn test_get_off_by_one() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - let stmt = or_panic!(conn.prepare("SELECT 10::INT as id")); - let result = or_panic!(stmt.query(&[])); - - match result.iter().next().unwrap().get_opt::<_, i32>(1) { - None => {} - res => panic!("unexpected result {:?}", res), - }; -} - -#[test] -fn test_custom_notice_handler() { - static mut COUNT: usize = 0; - struct Handler; - - impl HandleNotice for Handler { - fn handle_notice(&mut self, notice: DbError) { - assert_eq!("note", notice.message); - unsafe { - COUNT += 1; - } - } - } - - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433?client_min_messages=NOTICE", - TlsMode::None, - )); - conn.set_notice_handler(Box::new(Handler)); - or_panic!(conn.execute( - "CREATE FUNCTION pg_temp.note() RETURNS INT AS $$ - BEGIN - RAISE NOTICE 'note'; - RETURN 1; - END; $$ LANGUAGE plpgsql", - &[], - )); - or_panic!(conn.execute("SELECT pg_temp.note()", &[])); - - assert_eq!(unsafe { COUNT }, 1); -} - -#[test] -fn test_notification_iterator_none() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - assert!(conn.notifications().iter().next().unwrap().is_none()); -} - -fn check_notification(expected: Notification, actual: Notification) { - assert_eq!(&expected.channel, &actual.channel); - assert_eq!(&expected.payload, &actual.payload); -} - -#[test] -fn test_notification_iterator_some() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - let notifications = conn.notifications(); - let mut it = notifications.iter(); - or_panic!(conn.execute("LISTEN test_notification_iterator_one_channel", &[])); - or_panic!(conn.execute("LISTEN test_notification_iterator_one_channel2", &[])); - or_panic!(conn.execute( - "NOTIFY test_notification_iterator_one_channel, 'hello'", - &[], - )); - or_panic!(conn.execute( - "NOTIFY test_notification_iterator_one_channel2, 'world'", - &[], - )); - - check_notification( - Notification { - process_id: 0, - channel: "test_notification_iterator_one_channel".to_string(), - payload: "hello".to_string(), - }, - it.next().unwrap().unwrap(), - ); - check_notification( - Notification { - process_id: 0, - channel: "test_notification_iterator_one_channel2".to_string(), - payload: "world".to_string(), - }, - it.next().unwrap().unwrap(), - ); - assert!(it.next().unwrap().is_none()); - - or_panic!(conn.execute("NOTIFY test_notification_iterator_one_channel, '!'", &[])); - check_notification( - Notification { - process_id: 0, - channel: "test_notification_iterator_one_channel".to_string(), - payload: "!".to_string(), - }, - it.next().unwrap().unwrap(), - ); - assert!(it.next().unwrap().is_none()); -} - -#[test] -fn test_notifications_next_block() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - or_panic!(conn.execute("LISTEN test_notifications_next_block", &[])); - - let _t = thread::spawn(|| { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - thread::sleep(Duration::from_millis(500)); - or_panic!(conn.execute("NOTIFY test_notifications_next_block, 'foo'", &[])); - }); - - let notifications = conn.notifications(); - check_notification( - Notification { - process_id: 0, - channel: "test_notifications_next_block".to_string(), - payload: "foo".to_string(), - }, - notifications.blocking_iter().next().unwrap().unwrap(), - ); -} - -#[test] -fn test_notification_next_timeout() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - or_panic!(conn.execute("LISTEN test_notifications_next_timeout", &[])); - - let _t = thread::spawn(|| { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - thread::sleep(Duration::from_millis(500)); - or_panic!(conn.execute("NOTIFY test_notifications_next_timeout, 'foo'", &[])); - thread::sleep(Duration::from_millis(1500)); - or_panic!(conn.execute("NOTIFY test_notifications_next_timeout, 'foo'", &[])); - }); - - let notifications = conn.notifications(); - let mut it = notifications.timeout_iter(Duration::from_secs(1)); - check_notification( - Notification { - process_id: 0, - channel: "test_notifications_next_timeout".to_string(), - payload: "foo".to_string(), - }, - it.next().unwrap().unwrap(), - ); - - assert!(it.next().unwrap().is_none()); -} - -#[test] -fn test_notification_disconnect() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - or_panic!(conn.execute("LISTEN test_notifications_disconnect", &[])); - - let _t = thread::spawn(|| { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - thread::sleep(Duration::from_millis(500)); - or_panic!(conn.execute( - "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE query = 'LISTEN test_notifications_disconnect'", - &[], - )); - }); - - let notifications = conn.notifications(); - assert!(notifications.blocking_iter().next().is_err()); -} - -#[test] -// This test is pretty sad, but I don't think there's a better way :( -fn test_cancel_query() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - let cancel_data = conn.cancel_data(); - - let t = thread::spawn(move || { - thread::sleep(Duration::from_millis(500)); - assert!( - postgres::cancel_query( - "postgres://postgres@localhost:5433", - TlsMode::None, - &cancel_data, - ).is_ok() - ); - }); - - match conn.execute("SELECT pg_sleep(10)", &[]) { - Err(ref e) if e.code() == Some(&SqlState::QUERY_CANCELED) => {} - Err(res) => panic!("Unexpected result {:?}", res), - _ => panic!("Unexpected result"), - } - - t.join().unwrap(); -} - -#[test] -fn test_plaintext_pass() { - or_panic!(Connection::connect( - "postgres://pass_user:password@localhost:5433/postgres", - TlsMode::None, - )); -} - -#[test] -fn test_plaintext_pass_no_pass() { - let ret = Connection::connect( - "postgres://pass_user@localhost:5433/postgres", - TlsMode::None, - ); - match ret { - Err(ref e) if e.as_connection().is_some() => (), - Err(err) => panic!("Unexpected error {:?}", err), - _ => panic!("Expected error"), - } -} - -#[test] -fn test_plaintext_pass_wrong_pass() { - let ret = Connection::connect( - "postgres://pass_user:asdf@localhost:5433/postgres", - TlsMode::None, - ); - match ret { - Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {} - Err(err) => panic!("Unexpected error {:?}", err), - _ => panic!("Expected error"), - } -} - -#[test] -fn test_md5_pass() { - or_panic!(Connection::connect( - "postgres://md5_user:password@localhost:5433/postgres", - TlsMode::None, - )); -} - -#[test] -fn test_md5_pass_no_pass() { - let ret = Connection::connect("postgres://md5_user@localhost:5433/postgres", TlsMode::None); - match ret { - Err(ref e) if e.as_connection().is_some() => (), - Err(err) => panic!("Unexpected error {:?}", err), - _ => panic!("Expected error"), - } -} - -#[test] -fn test_md5_pass_wrong_pass() { - let ret = Connection::connect( - "postgres://md5_user:asdf@localhost:5433/postgres", - TlsMode::None, - ); - match ret { - Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {} - Err(err) => panic!("Unexpected error {:?}", err), - _ => panic!("Expected error"), - } -} - -#[test] -fn test_scram_pass() { - or_panic!(Connection::connect( - "postgres://scram_user:password@localhost:5433/postgres", - TlsMode::None, - )); -} - -#[test] -fn test_scram_pass_no_pass() { - let ret = Connection::connect( - "postgres://scram_user@localhost:5433/postgres", - TlsMode::None, - ); - match ret { - Err(ref e) if e.as_connection().is_some() => (), - Err(err) => panic!("Unexpected error {:?}", err), - _ => panic!("Expected error"), - } -} - -#[test] -fn test_scram_pass_wrong_pass() { - let ret = Connection::connect( - "postgres://scram_user:asdf@localhost:5433/postgres", - TlsMode::None, - ); - match ret { - Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {} - Err(err) => panic!("Unexpected error {:?}", err), - _ => panic!("Expected error"), - } -} - -#[test] -fn test_execute_copy_from_err() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT)", &[])); - let stmt = or_panic!(conn.prepare("COPY foo (id) FROM STDIN")); - - let err = stmt.execute(&[]).unwrap_err(); - match err.as_db() { - Some(err) if err.message.contains("COPY") => {} - _ => panic!("Unexpected error {:?}", err), - } - - let err = stmt.execute(&[]).unwrap_err(); - match err.as_db() { - Some(err) if err.message.contains("COPY") => {} - _ => panic!("Unexpected error {:?}", err), - } -} - -#[test] -#[allow(deprecated)] -fn test_batch_execute_copy_from_err() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT)", &[])); - let err = conn.batch_execute("COPY foo (id) FROM STDIN").unwrap_err(); - match err.as_db() { - Some(err) if err.message.contains("COPY") => {} - _ => panic!("Unexpected error {:?}", err), - } -} - -#[test] -fn test_copy_io_error() { - struct ErrorReader; - - impl io::Read for ErrorReader { - fn read(&mut self, _: &mut [u8]) -> io::Result { - Err(io::Error::new(io::ErrorKind::AddrNotAvailable, "boom")) - } - } - - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT)", &[])); - let stmt = or_panic!(conn.prepare("COPY foo (id) FROM STDIN")); - let err = stmt.copy_in(&[], &mut ErrorReader).unwrap_err(); - match err.as_io() { - Some(e) if e.kind() == io::ErrorKind::AddrNotAvailable => {} - _ => panic!("Unexpected error {:?}", err), - } - - or_panic!(conn.execute("SELECT 1", &[])); -} - -#[test] -fn test_copy() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT)", &[])); - let stmt = or_panic!(conn.prepare("COPY foo (id) FROM STDIN")); - let mut data = &b"1\n2\n3\n5\n8\n"[..]; - assert_eq!(5, or_panic!(stmt.copy_in(&[], &mut data))); - let stmt = or_panic!(conn.prepare("SELECT id FROM foo ORDER BY id")); - assert_eq!( - vec![1i32, 2, 3, 5, 8], - stmt.query(&[]) - .unwrap() - .iter() - .map(|r| r.get(0)) - .collect::>() - ); -} - -#[test] -fn test_query_copy_out_err() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - or_panic!(conn.simple_query( - " - CREATE TEMPORARY TABLE foo (id INT); - INSERT INTO foo (id) VALUES (0), (1), (2), (3)", - )); - let stmt = or_panic!(conn.prepare("COPY foo (id) TO STDOUT")); - let err = stmt.query(&[]).unwrap_err(); - match err.as_io() { - Some(e) if e.to_string().contains("COPY") => {} - _ => panic!("unexpected error {:?}", err), - }; -} - -#[test] -fn test_copy_out() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - or_panic!(conn.simple_query( - " - CREATE TEMPORARY TABLE foo (id INT); - INSERT INTO foo (id) VALUES (0), (1), (2), (3)", - )); - let stmt = or_panic!(conn.prepare("COPY (SELECT id FROM foo ORDER BY id) TO STDOUT")); - let mut buf = vec![]; - let count = or_panic!(stmt.copy_out(&[], &mut buf)); - assert_eq!(count, 4); - assert_eq!(buf, b"0\n1\n2\n3\n"); - or_panic!(conn.simple_query("SELECT 1")); -} - -#[test] -fn test_copy_out_error() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - or_panic!(conn.simple_query( - " - CREATE TEMPORARY TABLE foo (id INT); - INSERT INTO foo (id) VALUES (0), (1), (2), (3)", - )); - let stmt = or_panic!(conn.prepare("COPY (SELECT id FROM foo ORDER BY id) TO STDOUT (OIDS)")); - let mut buf = vec![]; - let err = stmt.copy_out(&[], &mut buf).unwrap_err(); - match err.as_db() { - Some(_) => {} - _ => panic!("unexpected error {}", err), - } -} - -#[test] -// Just make sure the impls don't infinite loop -fn test_generic_connection() { - fn f(t: &T) - where - T: GenericConnection, - { - or_panic!(t.execute("SELECT 1", &[])); - } - - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - f(&conn); - let trans = or_panic!(conn.transaction()); - f(&trans); -} - -#[test] -fn test_custom_range_element_type() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - or_panic!(conn.execute( - "CREATE TYPE pg_temp.floatrange AS RANGE ( - subtype = float8, - subtype_diff = float8mi - )", - &[], - )); - let stmt = or_panic!(conn.prepare("SELECT $1::floatrange")); - let ty = &stmt.param_types()[0]; - assert_eq!("floatrange", ty.name()); - assert_eq!(&Kind::Range(Type::FLOAT8), ty.kind()); -} - -#[test] -fn test_prepare_cached() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - or_panic!(conn.execute("CREATE TEMPORARY TABLE foo (id INT)", &[])); - or_panic!(conn.execute("INSERT INTO foo (id) VALUES (1), (2)", &[])); - - let stmt = or_panic!(conn.prepare_cached("SELECT id FROM foo ORDER BY id")); - assert_eq!( - vec![1, 2], - or_panic!(stmt.query(&[])) - .iter() - .map(|r| r.get(0)) - .collect::>() - ); - or_panic!(stmt.finish()); - - let stmt = or_panic!(conn.prepare_cached("SELECT id FROM foo ORDER BY id")); - assert_eq!( - vec![1, 2], - or_panic!(stmt.query(&[])) - .iter() - .map(|r| r.get(0)) - .collect::>() - ); - or_panic!(stmt.finish()); - - let stmt = or_panic!(conn.prepare_cached("SELECT id FROM foo ORDER BY id DESC")); - assert_eq!( - vec![2, 1], - or_panic!(stmt.query(&[])) - .iter() - .map(|r| r.get(0)) - .collect::>() - ); - or_panic!(stmt.finish()); -} - -#[test] -fn test_is_active() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - assert!(conn.is_active()); - let trans = or_panic!(conn.transaction()); - assert!(!conn.is_active()); - assert!(trans.is_active()); - { - let trans2 = or_panic!(trans.transaction()); - assert!(!conn.is_active()); - assert!(!trans.is_active()); - assert!(trans2.is_active()); - or_panic!(trans2.finish()); - } - assert!(!conn.is_active()); - assert!(trans.is_active()); - or_panic!(trans.finish()); - assert!(conn.is_active()); -} - -#[test] -fn test_parameter() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - assert_eq!(Some("UTF8".to_string()), conn.parameter("client_encoding")); - assert_eq!(None, conn.parameter("asdf")); -} - -#[test] -fn url_unencoded_password() { - assert!( - "postgresql://username:password%1*@localhost:5433" - .into_connect_params() - .is_err() - ) -} - -#[test] -fn url_encoded_password() { - let params = "postgresql://username%7b%7c:password%7b%7c@localhost:5433" - .into_connect_params() - .unwrap(); - assert_eq!("username{|", params.user().unwrap().name()); - assert_eq!("password{|", params.user().unwrap().password().unwrap()); -} - -#[test] -fn test_transaction_isolation_level() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - assert_eq!( - IsolationLevel::ReadCommitted, - or_panic!(conn.transaction_isolation()) - ); - or_panic!(conn.set_transaction_config( - transaction::Config::new().isolation_level(IsolationLevel::ReadUncommitted), - )); - assert_eq!( - IsolationLevel::ReadUncommitted, - or_panic!(conn.transaction_isolation()) - ); - or_panic!(conn.set_transaction_config( - transaction::Config::new().isolation_level(IsolationLevel::RepeatableRead), - )); - assert_eq!( - IsolationLevel::RepeatableRead, - or_panic!(conn.transaction_isolation()) - ); - or_panic!(conn.set_transaction_config( - transaction::Config::new().isolation_level(IsolationLevel::Serializable), - )); - assert_eq!( - IsolationLevel::Serializable, - or_panic!(conn.transaction_isolation()) - ); - or_panic!(conn.set_transaction_config( - transaction::Config::new().isolation_level(IsolationLevel::ReadCommitted), - )); - assert_eq!( - IsolationLevel::ReadCommitted, - or_panic!(conn.transaction_isolation()) - ); -} - -#[test] -fn test_rows_index() { - let conn = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None).unwrap(); - conn.simple_query( - " - CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY); - INSERT INTO foo (id) VALUES (1), (2), (3); - ", - ).unwrap(); - let stmt = conn.prepare("SELECT id FROM foo ORDER BY id").unwrap(); - let rows = stmt.query(&[]).unwrap(); - assert_eq!(3, rows.len()); - assert_eq!(2i32, rows.get(1).get::<_, i32>(0)); -} - -#[test] -fn test_type_names() { - let conn = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None).unwrap(); - let stmt = - conn.prepare( - "SELECT t.oid, t.typname - FROM pg_catalog.pg_type t, pg_namespace n - WHERE n.oid = t.typnamespace - AND n.nspname = 'pg_catalog' - AND t.oid < 10000 - AND t.typtype != 'c'", - ).unwrap(); - for row in &stmt.query(&[]).unwrap() { - let id: Oid = row.get(0); - let name: String = row.get(1); - assert_eq!(Type::from_oid(id).unwrap().name(), name); - } -} - -#[test] -fn test_conn_query() { - let conn = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None).unwrap(); - conn.simple_query( - " - CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY); - INSERT INTO foo (id) VALUES (1), (2), (3); - ", - ).unwrap(); - let ids = conn - .query("SELECT id FROM foo ORDER BY id", &[]) - .unwrap() - .iter() - .map(|r| r.get(0)) - .collect::>(); - assert_eq!(ids, [1, 2, 3]); -} - -#[test] -fn transaction_config() { - let conn = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None).unwrap(); - let mut config = transaction::Config::new(); - config - .isolation_level(IsolationLevel::Serializable) - .read_only(true) - .deferrable(true); - conn.set_transaction_config(&config).unwrap(); -} - -#[test] -fn transaction_config_one_setting() { - let conn = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None).unwrap(); - conn.set_transaction_config(transaction::Config::new().read_only(true)) - .unwrap(); - conn.set_transaction_config(transaction::Config::new().deferrable(true)) - .unwrap(); -} - -#[test] -fn transaction_with() { - let conn = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None).unwrap(); - let mut config = transaction::Config::new(); - config - .isolation_level(IsolationLevel::Serializable) - .read_only(true) - .deferrable(true); - conn.transaction_with(&config).unwrap().finish().unwrap(); -} - -#[test] -fn transaction_set_config() { - let conn = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None).unwrap(); - let trans = conn.transaction().unwrap(); - let mut config = transaction::Config::new(); - config - .isolation_level(IsolationLevel::Serializable) - .read_only(true) - .deferrable(true); - trans.set_config(&config).unwrap(); - trans.finish().unwrap(); -} - -#[test] -fn keepalive() { - let params = "postgres://postgres@localhost:5433?keepalive=10" - .into_connect_params() - .unwrap(); - assert_eq!(params.keepalive(), Some(Duration::from_secs(10))); - - Connection::connect(params, TlsMode::None).unwrap(); -} - -#[test] -fn explicit_types() { - let conn = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None).unwrap(); - let stmt = conn - .prepare_typed("SELECT $1::INT4", &[Some(Type::INT8)]) - .unwrap(); - assert_eq!(stmt.param_types()[0], Type::INT8); -} - -#[test] -fn simple_query() { - let conn = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None).unwrap(); - conn.simple_query( - " - CREATE TEMPORARY TABLE foo (id INT PRIMARY KEY); - INSERT INTO foo (id) VALUES (1), (2), (3); - ", - ).unwrap(); - let queries = "SELECT id FROM foo WHERE id = 1 ORDER BY id; \ - SELECT id FROM foo WHERE id != 1 ORDER BY id"; - - let results = conn.simple_query(queries).unwrap(); - assert_eq!(results[0].get(0).get("id"), "1"); - assert_eq!(results[1].get(0).get("id"), "2"); - assert_eq!(results[1].get(1).get("id"), "3"); -} diff --git a/postgres/tests/types/chrono.rs b/postgres/tests/types/chrono.rs deleted file mode 100644 index 22bd7eef3..000000000 --- a/postgres/tests/types/chrono.rs +++ /dev/null @@ -1,152 +0,0 @@ -extern crate chrono; - -use self::chrono::{TimeZone, NaiveDate, NaiveTime, NaiveDateTime, DateTime, Utc}; -use types::test_type; - -use postgres::types::{Date, Timestamp}; - -#[test] -fn test_naive_date_time_params() { - fn make_check<'a>(time: &'a str) -> (Option, &'a str) { - ( - Some( - NaiveDateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'").unwrap(), - ), - time, - ) - } - test_type( - "TIMESTAMP", - &[ - make_check("'1970-01-01 00:00:00.010000000'"), - make_check("'1965-09-25 11:19:33.100314000'"), - make_check("'2010-02-09 23:11:45.120200000'"), - (None, "NULL"), - ], - ); -} - -#[test] -fn test_with_special_naive_date_time_params() { - fn make_check<'a>(time: &'a str) -> (Timestamp, &'a str) { - ( - Timestamp::Value( - NaiveDateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'").unwrap(), - ), - time, - ) - } - test_type( - "TIMESTAMP", - &[ - make_check("'1970-01-01 00:00:00.010000000'"), - make_check("'1965-09-25 11:19:33.100314000'"), - make_check("'2010-02-09 23:11:45.120200000'"), - (Timestamp::PosInfinity, "'infinity'"), - (Timestamp::NegInfinity, "'-infinity'"), - ], - ); -} - -#[test] -fn test_date_time_params() { - fn make_check<'a>(time: &'a str) -> (Option>, &'a str) { - ( - Some( - Utc.datetime_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'") - .unwrap(), - ), - time, - ) - } - test_type( - "TIMESTAMP WITH TIME ZONE", - &[ - make_check("'1970-01-01 00:00:00.010000000'"), - make_check("'1965-09-25 11:19:33.100314000'"), - make_check("'2010-02-09 23:11:45.120200000'"), - (None, "NULL"), - ], - ); -} - -#[test] -fn test_with_special_date_time_params() { - fn make_check<'a>(time: &'a str) -> (Timestamp>, &'a str) { - ( - Timestamp::Value( - Utc.datetime_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'") - .unwrap(), - ), - time, - ) - } - test_type( - "TIMESTAMP WITH TIME ZONE", - &[ - make_check("'1970-01-01 00:00:00.010000000'"), - make_check("'1965-09-25 11:19:33.100314000'"), - make_check("'2010-02-09 23:11:45.120200000'"), - (Timestamp::PosInfinity, "'infinity'"), - (Timestamp::NegInfinity, "'-infinity'"), - ], - ); -} - -#[test] -fn test_date_params() { - fn make_check<'a>(time: &'a str) -> (Option, &'a str) { - ( - Some(NaiveDate::parse_from_str(time, "'%Y-%m-%d'").unwrap()), - time, - ) - } - test_type( - "DATE", - &[ - make_check("'1970-01-01'"), - make_check("'1965-09-25'"), - make_check("'2010-02-09'"), - (None, "NULL"), - ], - ); -} - -#[test] -fn test_with_special_date_params() { - fn make_check<'a>(date: &'a str) -> (Date, &'a str) { - ( - Date::Value(NaiveDate::parse_from_str(date, "'%Y-%m-%d'").unwrap()), - date, - ) - } - test_type( - "DATE", - &[ - make_check("'1970-01-01'"), - make_check("'1965-09-25'"), - make_check("'2010-02-09'"), - (Date::PosInfinity, "'infinity'"), - (Date::NegInfinity, "'-infinity'"), - ], - ); -} - -#[test] -fn test_time_params() { - fn make_check<'a>(time: &'a str) -> (Option, &'a str) { - ( - Some(NaiveTime::parse_from_str(time, "'%H:%M:%S.%f'").unwrap()), - time, - ) - } - test_type( - "TIME", - &[ - make_check("'00:00:00.010000000'"), - make_check("'11:19:33.100314000'"), - make_check("'23:11:45.120200000'"), - (None, "NULL"), - ], - ); -} diff --git a/postgres/tests/types/eui48.rs b/postgres/tests/types/eui48.rs deleted file mode 100644 index dc77078e9..000000000 --- a/postgres/tests/types/eui48.rs +++ /dev/null @@ -1,17 +0,0 @@ -extern crate eui48; - -use types::test_type; - -#[test] -fn test_eui48_params() { - test_type( - "MACADDR", - &[ - ( - Some(eui48::MacAddress::parse_str("12-34-56-AB-CD-EF").unwrap()), - "'12-34-56-ab-cd-ef'", - ), - (None, "NULL"), - ], - ) -} diff --git a/postgres/tests/types/geo.rs b/postgres/tests/types/geo.rs deleted file mode 100644 index bcde561fc..000000000 --- a/postgres/tests/types/geo.rs +++ /dev/null @@ -1,52 +0,0 @@ -extern crate geo; - -use self::geo::{Coordinate, LineString, Point, Rect}; -use types::test_type; - -#[test] -fn test_point_params() { - test_type( - "POINT", - &[ - (Some(Point::new(0.0, 0.0)), "POINT(0, 0)"), - (Some(Point::new(-3.14, 1.618)), "POINT(-3.14, 1.618)"), - (None, "NULL"), - ], - ); -} - -#[test] -fn test_box_params() { - test_type( - "BOX", - &[ - ( - Some(Rect { - min: Coordinate { x: -3.14, y: 1.618, }, - max: Coordinate { x: 160.0, y: 69701.5615, }, - }), - "BOX(POINT(160.0, 69701.5615), POINT(-3.14, 1.618))", - ), - (None, "NULL"), - ], - ); -} - -#[test] -fn test_path_params() { - let points = vec![ - Point::new(0.0, 0.0), - Point::new(-3.14, 1.618), - Point::new(160.0, 69701.5615), - ]; - test_type( - "PATH", - &[ - ( - Some(LineString(points)), - "path '((0, 0), (-3.14, 1.618), (160.0, 69701.5615))'", - ), - (None, "NULL"), - ], - ); -} diff --git a/postgres/tests/types/mod.rs b/postgres/tests/types/mod.rs deleted file mode 100644 index e95945b36..000000000 --- a/postgres/tests/types/mod.rs +++ /dev/null @@ -1,528 +0,0 @@ -use std::collections::HashMap; -use std::error; -use std::f32; -use std::f64; -use std::fmt; -use std::result; -use std::time::{Duration, UNIX_EPOCH}; - -use postgres::types::{FromSql, FromSqlOwned, IsNull, Kind, ToSql, Type, WrongType}; -use postgres::{Connection, TlsMode}; - -#[cfg(feature = "with-bit-vec-0.5")] -mod bit_vec; -#[cfg(feature = "with-chrono-0.4")] -mod chrono; -#[cfg(feature = "with-eui48-0.3")] -mod eui48; -#[cfg(feature = "with-geo-0.10")] -mod geo; -#[cfg(feature = "with-serde_json-1")] -mod serde_json; -#[cfg(feature = "with-uuid-0.6")] -mod uuid; - -fn test_type(sql_type: &str, checks: &[(T, S)]) -where - T: PartialEq + for<'a> FromSqlOwned + ToSql, - S: fmt::Display, -{ - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - for &(ref val, ref repr) in checks.iter() { - let stmt = or_panic!(conn.prepare(&*format!("SELECT {}::{}", *repr, sql_type))); - let rows = or_panic!(stmt.query(&[])); - let row = rows.iter().next().unwrap(); - let result = row.get(0); - assert_eq!(val, &result); - - let stmt = or_panic!(conn.prepare(&*format!("SELECT $1::{}", sql_type))); - let rows = or_panic!(stmt.query(&[val])); - let row = rows.iter().next().unwrap(); - let result = row.get(0); - assert_eq!(val, &result); - } -} - -#[test] -fn test_ref_tosql() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - let stmt = conn.prepare("SELECT $1::Int").unwrap(); - let num: &ToSql = &&7; - stmt.query(&[num]).unwrap(); -} - -#[test] -fn test_bool_params() { - test_type( - "BOOL", - &[(Some(true), "'t'"), (Some(false), "'f'"), (None, "NULL")], - ); -} - -#[test] -fn test_i8_params() { - test_type("\"char\"", &[(Some('a' as i8), "'a'"), (None, "NULL")]); -} - -#[test] -fn test_name_params() { - test_type( - "NAME", - &[ - (Some("hello world".to_owned()), "'hello world'"), - ( - Some("イロハニホヘト チリヌルヲ".to_owned()), - "'イロハニホヘト チリヌルヲ'", - ), - (None, "NULL"), - ], - ); -} - -#[test] -fn test_i16_params() { - test_type( - "SMALLINT", - &[ - (Some(15001i16), "15001"), - (Some(-15001i16), "-15001"), - (None, "NULL"), - ], - ); -} - -#[test] -fn test_i32_params() { - test_type( - "INT", - &[ - (Some(2147483548i32), "2147483548"), - (Some(-2147483548i32), "-2147483548"), - (None, "NULL"), - ], - ); -} - -#[test] -fn test_oid_params() { - test_type( - "OID", - &[ - (Some(2147483548u32), "2147483548"), - (Some(4000000000), "4000000000"), - (None, "NULL"), - ], - ); -} - -#[test] -fn test_i64_params() { - test_type( - "BIGINT", - &[ - (Some(9223372036854775708i64), "9223372036854775708"), - (Some(-9223372036854775708i64), "-9223372036854775708"), - (None, "NULL"), - ], - ); -} - -#[test] -fn test_f32_params() { - test_type( - "REAL", - &[ - (Some(f32::INFINITY), "'infinity'"), - (Some(f32::NEG_INFINITY), "'-infinity'"), - (Some(1000.55), "1000.55"), - (None, "NULL"), - ], - ); -} - -#[test] -fn test_f64_params() { - test_type( - "DOUBLE PRECISION", - &[ - (Some(f64::INFINITY), "'infinity'"), - (Some(f64::NEG_INFINITY), "'-infinity'"), - (Some(10000.55), "10000.55"), - (None, "NULL"), - ], - ); -} - -#[test] -fn test_varchar_params() { - test_type( - "VARCHAR", - &[ - (Some("hello world".to_owned()), "'hello world'"), - ( - Some("イロハニホヘト チリヌルヲ".to_owned()), - "'イロハニホヘト チリヌルヲ'", - ), - (None, "NULL"), - ], - ); -} - -#[test] -fn test_text_params() { - test_type( - "TEXT", - &[ - (Some("hello world".to_owned()), "'hello world'"), - ( - Some("イロハニホヘト チリヌルヲ".to_owned()), - "'イロハニホヘト チリヌルヲ'", - ), - (None, "NULL"), - ], - ); -} - -#[test] -fn test_borrowed_text() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - - let rows = or_panic!(conn.query("SELECT 'foo'", &[])); - let row = rows.get(0); - let s: &str = row.get(0); - assert_eq!(s, "foo"); -} - -#[test] -fn test_bpchar_params() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - or_panic!(conn.execute( - "CREATE TEMPORARY TABLE foo ( - id SERIAL PRIMARY KEY, - b CHAR(5) - )", - &[], - )); - or_panic!(conn.execute( - "INSERT INTO foo (b) VALUES ($1), ($2), ($3)", - &[&Some("12345"), &Some("123"), &None::<&'static str>], - )); - let stmt = or_panic!(conn.prepare("SELECT b FROM foo ORDER BY id")); - let res = or_panic!(stmt.query(&[])); - - assert_eq!( - vec![Some("12345".to_owned()), Some("123 ".to_owned()), None], - res.iter().map(|row| row.get(0)).collect::>() - ); -} - -#[test] -fn test_citext_params() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - or_panic!(conn.execute( - "CREATE TEMPORARY TABLE foo ( - id SERIAL PRIMARY KEY, - b CITEXT - )", - &[], - )); - or_panic!(conn.execute( - "INSERT INTO foo (b) VALUES ($1), ($2), ($3)", - &[&Some("foobar"), &Some("FooBar"), &None::<&'static str>], - )); - let stmt = or_panic!(conn.prepare("SELECT id FROM foo WHERE b = 'FOOBAR' ORDER BY id",)); - let res = or_panic!(stmt.query(&[])); - - assert_eq!( - vec![Some(1i32), Some(2i32)], - res.iter().map(|row| row.get(0)).collect::>() - ); -} - -#[test] -fn test_bytea_params() { - test_type( - "BYTEA", - &[ - (Some(vec![0u8, 1, 2, 3, 254, 255]), "'\\x00010203feff'"), - (None, "NULL"), - ], - ); -} - -#[test] -fn test_borrowed_bytea() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - - let rows = or_panic!(conn.query("SELECT 'foo'::BYTEA", &[])); - let row = rows.get(0); - let s: &[u8] = row.get(0); - assert_eq!(s, b"foo"); -} - -#[test] -fn test_hstore_params() { - macro_rules! make_map { - ($($k:expr => $v:expr),+) => ({ - let mut map = HashMap::new(); - $(map.insert($k, $v);)+ - map - }) - } - test_type( - "hstore", - &[ - ( - Some(make_map!("a".to_owned() => Some("1".to_owned()))), - "'a=>1'", - ), - ( - Some(make_map!("hello".to_owned() => Some("world!".to_owned()), - "hola".to_owned() => Some("mundo!".to_owned()), - "what".to_owned() => None)), - "'hello=>world!,hola=>mundo!,what=>NULL'", - ), - (None, "NULL"), - ], - ); -} - -#[test] -fn test_array_params() { - test_type( - "integer[]", - &[ - (Some(vec![1i32, 2i32]), "ARRAY[1,2]"), - (Some(vec![1i32]), "ARRAY[1]"), - (Some(vec![]), "ARRAY[]"), - (None, "NULL"), - ], - ); -} - -fn test_nan_param(sql_type: &str) -where - T: PartialEq + ToSql + FromSqlOwned, -{ - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - let stmt = or_panic!(conn.prepare(&*format!("SELECT 'NaN'::{}", sql_type))); - let result = or_panic!(stmt.query(&[])); - let val: T = result.iter().next().unwrap().get(0); - assert!(val != val); -} - -#[test] -fn test_f32_nan_param() { - test_nan_param::("REAL"); -} - -#[test] -fn test_f64_nan_param() { - test_nan_param::("DOUBLE PRECISION"); -} - -#[test] -fn test_pg_database_datname() { - let conn = or_panic!(Connection::connect( - "postgres://postgres@localhost:5433", - TlsMode::None, - )); - let stmt = or_panic!(conn.prepare("SELECT datname FROM pg_database")); - let result = or_panic!(stmt.query(&[])); - - let next = result.iter().next().unwrap(); - or_panic!(next.get_opt::<_, String>(0).unwrap()); - or_panic!(next.get_opt::<_, String>("datname").unwrap()); -} - -#[test] -fn test_slice() { - let conn = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None).unwrap(); - conn.simple_query( - "CREATE TEMPORARY TABLE foo (id SERIAL PRIMARY KEY, f VARCHAR); - INSERT INTO foo (f) VALUES ('a'), ('b'), ('c'), ('d');", - ).unwrap(); - - let stmt = conn.prepare("SELECT f FROM foo WHERE id = ANY($1)") - .unwrap(); - let result = stmt.query(&[&&[1i32, 3, 4][..]]).unwrap(); - assert_eq!( - vec!["a".to_owned(), "c".to_owned(), "d".to_owned()], - result - .iter() - .map(|r| r.get::<_, String>(0)) - .collect::>() - ); -} - -#[test] -fn test_slice_wrong_type() { - let conn = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None).unwrap(); - conn.simple_query("CREATE TEMPORARY TABLE foo (id SERIAL PRIMARY KEY)") - .unwrap(); - - let stmt = conn.prepare("SELECT * FROM foo WHERE id = ANY($1)") - .unwrap(); - let err = stmt.query(&[&&["hi"][..]]).unwrap_err(); - match err.as_conversion() { - Some(e) if e.is::() => {} - _ => panic!("Unexpected error {:?}", err), - }; -} - -#[test] -fn test_slice_range() { - let conn = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None).unwrap(); - - let stmt = conn.prepare("SELECT $1::INT8RANGE").unwrap(); - let err = stmt.query(&[&&[1i64][..]]).unwrap_err(); - match err.as_conversion() { - Some(e) if e.is::() => {} - _ => panic!("Unexpected error {:?}", err), - }; -} - -#[test] -fn domain() { - #[derive(Debug, PartialEq)] - struct SessionId(Vec); - - impl ToSql for SessionId { - fn to_sql( - &self, - ty: &Type, - out: &mut Vec, - ) -> result::Result> { - let inner = match *ty.kind() { - Kind::Domain(ref inner) => inner, - _ => unreachable!(), - }; - self.0.to_sql(inner, out) - } - - fn accepts(ty: &Type) -> bool { - ty.name() == "session_id" && match *ty.kind() { - Kind::Domain(_) => true, - _ => false, - } - } - - to_sql_checked!(); - } - - impl<'a> FromSql<'a> for SessionId { - fn from_sql( - ty: &Type, - raw: &[u8], - ) -> result::Result> { - Vec::::from_sql(ty, raw).map(SessionId) - } - - fn accepts(ty: &Type) -> bool { - // This is super weird! - as FromSql>::accepts(ty) - } - } - - let conn = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None).unwrap(); - conn.simple_query( - "CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16); - CREATE TABLE pg_temp.foo (id pg_temp.session_id);", - ).unwrap(); - - let id = SessionId(b"0123456789abcdef".to_vec()); - conn.execute("INSERT INTO pg_temp.foo (id) VALUES ($1)", &[&id]) - .unwrap(); - let rows = conn.query("SELECT id FROM pg_temp.foo", &[]).unwrap(); - assert_eq!(id, rows.get(0).get(0)); -} - -#[test] -fn composite() { - let conn = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None).unwrap(); - conn.simple_query( - "CREATE TYPE pg_temp.inventory_item AS ( - name TEXT, - supplier INTEGER, - price NUMERIC - )", - ).unwrap(); - - let stmt = conn.prepare("SELECT $1::inventory_item").unwrap(); - let type_ = &stmt.param_types()[0]; - assert_eq!(type_.name(), "inventory_item"); - match *type_.kind() { - Kind::Composite(ref fields) => { - assert_eq!(fields[0].name(), "name"); - assert_eq!(fields[0].type_(), &Type::TEXT); - assert_eq!(fields[1].name(), "supplier"); - assert_eq!(fields[1].type_(), &Type::INT4); - assert_eq!(fields[2].name(), "price"); - assert_eq!(fields[2].type_(), &Type::NUMERIC); - } - ref t => panic!("bad type {:?}", t), - } -} - -#[test] -fn enum_() { - let conn = Connection::connect("postgres://postgres@localhost:5433", TlsMode::None).unwrap(); - conn.simple_query("CREATE TYPE pg_temp.mood AS ENUM ('sad', 'ok', 'happy');") - .unwrap(); - - let stmt = conn.prepare("SELECT $1::mood").unwrap(); - let type_ = &stmt.param_types()[0]; - assert_eq!(type_.name(), "mood"); - match *type_.kind() { - Kind::Enum(ref variants) => { - assert_eq!( - variants, - &["sad".to_owned(), "ok".to_owned(), "happy".to_owned()] - ); - } - _ => panic!("bad type"), - } -} - -#[test] -fn system_time() { - test_type( - "TIMESTAMP", - &[ - ( - Some(UNIX_EPOCH + Duration::from_millis(1_010)), - "'1970-01-01 00:00:01.01'", - ), - ( - Some(UNIX_EPOCH - Duration::from_millis(1_010)), - "'1969-12-31 23:59:58.99'", - ), - ( - Some(UNIX_EPOCH + Duration::from_millis(946684800 * 1000 + 1_010)), - "'2000-01-01 00:00:01.01'", - ), - (None, "NULL"), - ], - ); -} diff --git a/postgres/tests/types/serde_json.rs b/postgres/tests/types/serde_json.rs deleted file mode 100644 index bf62a8cbf..000000000 --- a/postgres/tests/types/serde_json.rs +++ /dev/null @@ -1,40 +0,0 @@ -extern crate serde_json; - -use self::serde_json::Value; -use types::test_type; - -#[test] -fn test_json_params() { - test_type( - "JSON", - &[ - ( - Some(serde_json::from_str::("[10, 11, 12]").unwrap()), - "'[10, 11, 12]'", - ), - ( - Some(serde_json::from_str::("{\"f\": \"asd\"}").unwrap()), - "'{\"f\": \"asd\"}'", - ), - (None, "NULL"), - ], - ) -} - -#[test] -fn test_jsonb_params() { - test_type( - "JSONB", - &[ - ( - Some(serde_json::from_str::("[10, 11, 12]").unwrap()), - "'[10, 11, 12]'", - ), - ( - Some(serde_json::from_str::("{\"f\": \"asd\"}").unwrap()), - "'{\"f\": \"asd\"}'", - ), - (None, "NULL"), - ], - ) -} diff --git a/postgres/tests/types/uuid.rs b/postgres/tests/types/uuid.rs deleted file mode 100644 index fd03ca828..000000000 --- a/postgres/tests/types/uuid.rs +++ /dev/null @@ -1,19 +0,0 @@ -extern crate uuid; - -use types::test_type; - -#[test] -fn test_uuid_params() { - test_type( - "UUID", - &[ - ( - Some( - uuid::Uuid::parse_str("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11").unwrap(), - ), - "'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'", - ), - (None, "NULL"), - ], - ) -} diff --git a/tokio-postgres-openssl/Cargo.toml b/tokio-postgres-openssl/Cargo.toml deleted file mode 100644 index 955349200..000000000 --- a/tokio-postgres-openssl/Cargo.toml +++ /dev/null @@ -1,15 +0,0 @@ -[package] -name = "tokio-postgres-openssl" -version = "0.1.0" -authors = ["Steven Fackler "] - -[dependencies] -bytes = "0.4" -futures = "0.1" -openssl = "0.10" -tokio-io = "0.1" -tokio-openssl = "0.2" -tokio-postgres = { version = "0.3", path = "../tokio-postgres" } - -[dev-dependencies] -tokio = "0.1.7" diff --git a/tokio-postgres-openssl/src/lib.rs b/tokio-postgres-openssl/src/lib.rs deleted file mode 100644 index 3a77218de..000000000 --- a/tokio-postgres-openssl/src/lib.rs +++ /dev/null @@ -1,141 +0,0 @@ -extern crate bytes; -extern crate futures; -extern crate openssl; -extern crate tokio_io; -extern crate tokio_openssl; -extern crate tokio_postgres; - -#[cfg(test)] -extern crate tokio; - -use bytes::{Buf, BufMut}; -use futures::{Future, IntoFuture, Poll}; -use openssl::error::ErrorStack; -use openssl::ssl::{ConnectConfiguration, SslConnector, SslMethod, SslRef}; -use std::error::Error; -use std::io::{self, Read, Write}; -use tokio_io::{AsyncRead, AsyncWrite}; -use tokio_openssl::ConnectConfigurationExt; -use tokio_postgres::tls::{Socket, TlsConnect, TlsStream}; - -#[cfg(test)] -mod test; - -pub struct TlsConnector { - connector: SslConnector, - callback: Box Result<(), ErrorStack> + Sync + Send>, -} - -impl TlsConnector { - pub fn new() -> Result { - let connector = SslConnector::builder(SslMethod::tls())?.build(); - Ok(TlsConnector::with_connector(connector)) - } - - pub fn with_connector(connector: SslConnector) -> TlsConnector { - TlsConnector { - connector, - callback: Box::new(|_| Ok(())), - } - } - - pub fn set_callback(&mut self, f: F) - where - F: Fn(&mut ConnectConfiguration) -> Result<(), ErrorStack> + 'static + Sync + Send, - { - self.callback = Box::new(f); - } -} - -impl TlsConnect for TlsConnector { - fn connect( - &self, - domain: &str, - socket: Socket, - ) -> Box, Error = Box> + Sync + Send> { - let f = self - .connector - .configure() - .and_then(|mut ssl| (self.callback)(&mut ssl).map(|_| ssl)) - .map_err(|e| { - let e: Box = Box::new(e); - e - }) - .into_future() - .and_then({ - let domain = domain.to_string(); - move |ssl| { - ssl.connect_async(&domain, socket) - .map(|s| { - let s: Box = Box::new(SslStream(s)); - s - }) - .map_err(|e| { - let e: Box = Box::new(e); - e - }) - } - }); - Box::new(f) - } -} - -struct SslStream(tokio_openssl::SslStream); - -impl Read for SslStream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.0.read(buf) - } -} - -impl AsyncRead for SslStream { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { - self.0.prepare_uninitialized_buffer(buf) - } - - fn read_buf(&mut self, buf: &mut B) -> Poll - where - B: BufMut, - { - self.0.read_buf(buf) - } -} - -impl Write for SslStream { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.0.write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - self.0.flush() - } -} - -impl AsyncWrite for SslStream { - fn shutdown(&mut self) -> Poll<(), io::Error> { - self.0.shutdown() - } - - fn write_buf(&mut self, buf: &mut B) -> Poll - where - B: Buf, - { - self.0.write_buf(buf) - } -} - -impl TlsStream for SslStream { - fn tls_unique(&self) -> Option> { - let f = if self.0.get_ref().ssl().session_reused() { - SslRef::peer_finished - } else { - SslRef::finished - }; - - let len = f(self.0.get_ref().ssl(), &mut []); - let mut buf = vec![0; len]; - f(self.0.get_ref().ssl(), &mut buf); - - Some(buf) - } -} diff --git a/tokio-postgres-openssl/src/test.rs b/tokio-postgres-openssl/src/test.rs deleted file mode 100644 index f5999d148..000000000 --- a/tokio-postgres-openssl/src/test.rs +++ /dev/null @@ -1,60 +0,0 @@ -use futures::{Future, Stream}; -use openssl::ssl::{SslConnector, SslMethod}; -use tokio::runtime::current_thread::Runtime; -use tokio_postgres::{self, TlsMode}; - -use TlsConnector; - -fn smoke_test(url: &str, tls: TlsMode) { - let mut runtime = Runtime::new().unwrap(); - - let handshake = tokio_postgres::connect(url.parse().unwrap(), tls); - let (mut client, connection) = runtime.block_on(handshake).unwrap(); - let connection = connection.map_err(|e| panic!("{}", e)); - runtime.handle().spawn(connection).unwrap(); - - let prepare = client.prepare("SELECT 1::INT4"); - let statement = runtime.block_on(prepare).unwrap(); - let select = client.query(&statement, &[]).collect().map(|rows| { - assert_eq!(rows.len(), 1); - assert_eq!(rows[0].get::<_, i32>(0), 1); - }); - runtime.block_on(select).unwrap(); - - drop(statement); - drop(client); - runtime.run().unwrap(); -} - -#[test] -fn require() { - let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); - builder.set_ca_file("../test/server.crt").unwrap(); - let connector = TlsConnector::with_connector(builder.build()); - smoke_test( - "postgres://ssl_user@localhost:5433/postgres", - TlsMode::Require(Box::new(connector)), - ); -} - -#[test] -fn prefer() { - let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); - builder.set_ca_file("../test/server.crt").unwrap(); - let connector = TlsConnector::with_connector(builder.build()); - smoke_test( - "postgres://ssl_user@localhost:5433/postgres", - TlsMode::Prefer(Box::new(connector)), - ); -} - -#[test] -fn scram_user() { - let mut builder = SslConnector::builder(SslMethod::tls()).unwrap(); - builder.set_ca_file("../test/server.crt").unwrap(); - let connector = TlsConnector::with_connector(builder.build()); - smoke_test( - "postgres://scram_user:password@localhost:5433/postgres", - TlsMode::Require(Box::new(connector)), - ); -} diff --git a/tokio-postgres/CHANGELOG.md b/tokio-postgres/CHANGELOG.md new file mode 100644 index 000000000..a67f69ea7 --- /dev/null +++ b/tokio-postgres/CHANGELOG.md @@ -0,0 +1,314 @@ +# Change Log + +## Unreleased + +## v0.7.13 - 2025-02-02 + +### Added + +* Added support for direct TLS negotiation. +* Added support for `cidr` 0.3 via the `with-cidr-0_3` feature. + +### Fixes + +* Added `load_balance_hosts` to `Config`'s `Debug` implementation. + +### Changes + +* Upgraded `rand`. + +## v0.7.12 - 2024-09-15 + +### Fixed + +* Fixed `query_typed` queries that return no rows. + +### Added + +* Added support for `jiff` 0.1 via the `with-jiff-01` feature. +* Added support for TCP keepalive on AIX. + +## v0.7.11 - 2024-07-21 + +### Fixed + +* Fixed handling of non-UTF8 error fields which can be sent after failed handshakes. +* Fixed cancellation handling of `TransactionBuilder::start` futures. + +### Added + +* Added `table_oid` and `field_id` fields to `Columns` struct of prepared statements. +* Added `GenericClient::simple_query`. +* Added `#[track_caller]` to `Row::get` and `SimpleQueryRow::get`. +* Added `TargetSessionAttrs::ReadOnly`. +* Added `Debug` implementation for `Statement`. +* Added `Clone` implementation for `Row`. +* Added `SimpleQueryMessage::RowDescription`. +* Added `{Client, Transaction, GenericClient}::query_typed`. + +### Changed + +* Disable `rustc-serialize` compatibility of `eui48-1` dependency +* Config setters now take `impl Into`. + +## v0.7.10 - 2023-08-25 + +## Fixed + +* Defered default username lookup to avoid regressing `Config` behavior. + +## v0.7.9 - 2023-08-19 + +## Fixed + +* Fixed builds on OpenBSD. + +## Added + +* Added the `js` feature for WASM support. +* Added support for the `hostaddr` config option to bypass DNS lookups. +* Added support for the `load_balance_hosts` config option to randomize connection ordering. +* The `user` config option now defaults to the executing process's user. + +## v0.7.8 - 2023-05-27 + +## Added + +* Added `keepalives_interval` and `keepalives_retries` config options. +* Added new `SqlState` variants. +* Added more `Debug` impls. +* Added `GenericClient::batch_execute`. +* Added `RowStream::rows_affected`. +* Added the `tcp_user_timeout` config option. + +## Changed + +* Passing an incorrect number of parameters to a query method now returns an error instead of panicking. +* Upgraded `socket2`. + +## v0.7.7 - 2022-08-21 + +## Added + +* Added `ToSql` and `FromSql` implementations for `[u8; N]` via the `array-impls` feature. +* Added support for `smol_str` 0.1 via the `with-smol_str-01` feature. +* Added `ToSql::encode_format` to support text encodings of parameters. + +## v0.7.6 - 2022-04-30 + +### Added + +* Added support for `uuid` 1.0 via the `with-uuid-1` feature. + +### Changed + +* Upgraded to `tokio-util` 0.7. +* Upgraded to `parking_lot` 0.12. + +## v0.7.5 - 2021-10-29 + +### Fixed + +* Fixed a bug where the client could enter into a transaction if the `Client::transaction` future was dropped before completion. + +## v0.7.4 - 2021-10-19 + +### Fixed + +* Fixed reporting of commit-time errors triggered by deferred constraints. + +## v0.7.3 - 2021-09-29 + +### Fixed + +* Fixed a deadlock when pipelined requests concurrently prepare cached typeinfo queries. + +### Added + +* Added `SimpleQueryRow::columns`. +* Added support for `eui48` 1.0 via the `with-eui48-1` feature. +* Added `FromSql` and `ToSql` implementations for arrays via the `array-impls` feature. +* Added support for `time` 0.3 via the `with-time-0_3` feature. + +## v0.7.2 - 2021-04-25 + +### Fixed + +* `SqlState` constants can now be used in `match` patterns. + +## v0.7.1 - 2021-04-03 + +### Added + +* Added support for `geo-types` 0.7 via `with-geo-types-0_7` feature. +* Added `Client::clear_type_cache`. +* Added `Error::as_db_error` and `Error::is_closed`. + +## v0.7.0 - 2020-12-25 + +### Changed + +* Upgraded to `tokio` 1.0. +* Upgraded to `postgres-types` 0.2. + +### Added + +* Methods taking iterators of `ToSql` values can now take both `&dyn ToSql` and `T: ToSql` values. + +## v0.6.0 - 2020-10-17 + +### Changed + +* Upgraded to `tokio` 0.3. +* Added the detail and hint fields to `DbError`'s `Display` implementation. + +## v0.5.5 - 2020-07-03 + +### Added + +* Added support for `geo-types` 0.6. + +## v0.5.4 - 2020-05-01 + +### Added + +* Added `Transaction::savepoint`, which can be used to create a savepoint with a custom name. + +## v0.5.3 - 2020-03-05 + +### Added + +* Added `Debug` implementations for `Client`, `Row`, and `Column`. +* Added `time` 0.2 support. + +## v0.5.2 - 2020-01-31 + +### Fixed + +* Notice messages sent during the initial connection process are now collected and returned first from + `Connection::poll_message`. + +### Deprecated + +* Deprecated `Client::cancel_query` and `Client::cancel_query_raw` in favor of `Client::cancel_token`. + +### Added + +* Added `Client::build_transaction` to allow configuration of various transaction options. +* Added `Client::cancel_token`, which returns a separate owned object that can be used to cancel queries. +* Added accessors for `Config` fields. +* Added a `GenericClient` trait implemented for `Client` and `Transaction` and covering shared functionality. + +## v0.5.1 - 2019-12-25 + +### Fixed + +* Removed some stray `println!`s from `copy_out` internals. + +## v0.5.0 - 2019-12-23 + +### Changed + +* `Client::copy_in` now returns a `Sink` rather than taking in a `Stream`. +* `CopyStream` has been renamed to `CopyOutStream`. +* `Client::copy_in` and `Client::copy_out` no longer take query parameters as PostgreSQL doesn't support parameters in + COPY queries. +* `TargetSessionAttrs`, `SslMode`, and `ChannelBinding` are now true non-exhaustive enums. + +### Added + +* Added `Client::query_opt` for queries expected to return zero or one rows. +* Added binary copy format support to the `binary_copy` module. +* Added back query logging. + +### Removed + +* Removed `uuid` 0.7 support. + +## v0.5.0-alpha.2 - 2019-11-27 + +### Changed + +* Upgraded `bytes` to 0.5. +* Upgraded `tokio` to 0.2. +* The TLS interface uses a trait to obtain channel binding information rather than returning it after the handshake. +* Changed the value of the `timezone` property from `GMT` to `UTC`. +* Returned `Stream` implementations are now `!Unpin`. + +### Added + +* Added support for `uuid` 0.8. +* Added the column to `Row::try_get` errors. + +## v0.5.0-alpha.1 - 2019-10-14 + +### Changed + +* The library now uses `std::futures::Future` and async/await syntax. +* Most methods now take `&self` rather than `&mut self`. +* The transaction API has changed to more closely resemble the synchronous API and is significantly more ergonomic. +* Methods now take `&[&(dyn ToSql + Sync)]` rather than `&[&dyn ToSql]` to allow futures to be `Send`. +* Methods are now "normal" async functions that no longer do work up-front. +* Statements are no longer required to be prepared explicitly before use. Methods taking `&Statement` can now also take + `&str`, and will internally prepare the statement. +* `ToSql` now serializes its value into a `BytesMut` rather than `Vec`. +* Methods that previously returned `Stream`s now return `Vec`. New `*_raw` methods still provide a `Stream` + interface. + +### Added + +* Added the `channel_binding=disable/allow/require` configuration to control use of channel binding. +* Added the `Client::query_one` method to cover the common case of a query that returns exactly one row. + +## v0.4.0-rc.3 - 2019-06-29 + +### Fixed + +* Significantly improved the performance of `query` and `copy_in`. + +### Changed + +* The items of the stream passed to `copy_in` must be `'static`. + +## v0.4.0-rc.2 - 2019-03-05 + +### Fixed + +* Fixed Cargo features to actually enable the functionality they claim to. + +## v0.4.0-rc.1 - 2019-03-05 + +### Changed + +* The client API has been significantly overhauled. It now resembles `hyper`'s, with separate `Connection` and `Client` + objects. See the crate-level documentation for more details. +* The TLS connection mode (e.g. `prefer`) is now part of the connection configuration rather than being passed in + separately. +* The Cargo features enabling `ToSql` and `FromSql` implementations for external crates are now versioned. For example, + `with-uuid` is now `with-uuid-0_7`. This enables us to add support for new major versions of the crates in parallel + without breaking backwards compatibility. +* Upgraded from `tokio-core` to `tokio`. + +### Added + +* Connection string configuration now more fully mirrors libpq's syntax, and supports both URL-style and key-value style + strings. +* `FromSql` implementations can now borrow from the data buffer. In particular, this means that you can deserialize + values as `&str`. The `FromSqlOwned` trait can be used as a bound to restrict code to deserializing owned values. +* Added support for channel binding with SCRAM authentication. +* Added multi-host support in connection configuration. +* The client now supports query pipelining, which can be used as a latency hiding measure. +* While the crate uses `tokio` by default, the base API can be used with any asynchronous stream type on any reactor. +* Added support for simple query requests returning row data. + +### Removed + +* The `with-openssl` feature has been removed. Use the `tokio-postgres-openssl` crate instead. +* The `with-rustc_serialize` and `with-time` features have been removed. Use `serde` and `SystemTime` or `chrono` + instead. + +## Older + +Look at the [release tags] for information about older releases. + +[release tags]: https://github.com/sfackler/rust-postgres/releases diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index f41b69051..f969ae5b7 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -1,56 +1,96 @@ [package] name = "tokio-postgres" -version = "0.3.0" +version = "0.7.13" authors = ["Steven Fackler "] -license = "MIT" -description = "A native PostgreSQL driver using Tokio" +edition = "2018" +license = "MIT OR Apache-2.0" +description = "A native, asynchronous PostgreSQL client" repository = "https://github.com/sfackler/rust-postgres" readme = "../README.md" keywords = ["database", "postgres", "postgresql", "sql", "async"] categories = ["database"] +[lib] +test = false + +[[bench]] +name = "bench" +harness = false + [package.metadata.docs.rs] -features = [ - "with-bit-vec-0.5", - "with-chrono-0.4", - "with-eui48-0.3", - "with-geo-0.10", - "with-serde_json-1", - "with-uuid-0.6", - "with-openssl", -] +all-features = true [badges] circle-ci = { repository = "sfackler/rust-postgres" } [features] -"with-bit-vec-0.5" = ["postgres-shared/with-bit-vec-0.5"] -"with-chrono-0.4" = ["postgres-shared/with-chrono-0.4"] -"with-eui48-0.3" = ["postgres-shared/with-eui48-0.3"] -"with-geo-0.10" = ["postgres-shared/with-geo-0.10"] -"with-serde_json-1" = ["postgres-shared/with-serde_json-1"] -"with-uuid-0.6" = ["postgres-shared/with-uuid-0.6"] +default = ["runtime"] +runtime = ["tokio/net", "tokio/time"] + +array-impls = ["postgres-types/array-impls"] +with-bit-vec-0_6 = ["postgres-types/with-bit-vec-0_6"] +with-chrono-0_4 = ["postgres-types/with-chrono-0_4"] +with-cidr-0_2 = ["postgres-types/with-cidr-0_2"] +with-cidr-0_3 = ["postgres-types/with-cidr-0_3"] +with-eui48-0_4 = ["postgres-types/with-eui48-0_4"] +with-eui48-1 = ["postgres-types/with-eui48-1"] +with-geo-types-0_6 = ["postgres-types/with-geo-types-0_6"] +with-geo-types-0_7 = ["postgres-types/with-geo-types-0_7"] +with-jiff-0_1 = ["postgres-types/with-jiff-0_1"] +with-jiff-0_2 = ["postgres-types/with-jiff-0_2"] +with-serde_json-1 = ["postgres-types/with-serde_json-1"] +with-smol_str-01 = ["postgres-types/with-smol_str-01"] +with-uuid-0_8 = ["postgres-types/with-uuid-0_8"] +with-uuid-1 = ["postgres-types/with-uuid-1"] +with-time-0_2 = ["postgres-types/with-time-0_2"] +with-time-0_3 = ["postgres-types/with-time-0_3"] +js = ["postgres-protocol/js", "postgres-types/js"] [dependencies] -antidote = "1.0" -bytes = "0.4" -fallible-iterator = "0.1.3" -futures = "0.1.7" -futures-cpupool = "0.1" -lazy_static = "1.0" +async-trait = "0.1" +bytes = "1.0" +byteorder = "1.0" +fallible-iterator = "0.2" +futures-channel = { version = "0.3", features = ["sink"] } +futures-util = { version = "0.3", features = ["sink"] } log = "0.4" -phf = "=0.7.22" -postgres-protocol = { version = "0.3.0", path = "../postgres-protocol" } -postgres-shared = { version = "0.4.0", path = "../postgres-shared" } -state_machine_future = "0.1.7" -tokio-codec = "0.1" -tokio-io = "0.1" -tokio-tcp = "0.1" -tokio-timer = "0.2" - -[target.'cfg(unix)'.dependencies] -tokio-uds = "0.2.1" +parking_lot = "0.12" +percent-encoding = "2.0" +pin-project-lite = "0.2" +phf = "0.11" +postgres-protocol = { version = "0.6.8", path = "../postgres-protocol" } +postgres-types = { version = "0.2.9", path = "../postgres-types" } +tokio = { version = "1.27", features = ["io-util"] } +tokio-util = { version = "0.7", features = ["codec"] } +rand = "0.9.0" +whoami = "1.4.1" + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +socket2 = { version = "0.5", features = ["all"] } [dev-dependencies] -tokio = "0.1.7" -env_logger = "0.5" +futures-executor = "0.3" +criterion = "0.6" +env_logger = "0.11" +tokio = { version = "1.0", features = [ + "macros", + "net", + "rt", + "rt-multi-thread", + "time", +] } + +bit-vec-06 = { version = "0.6", package = "bit-vec" } +chrono-04 = { version = "0.4", package = "chrono", default-features = false } +eui48-1 = { version = "1.0", package = "eui48", default-features = false } +geo-types-06 = { version = "0.6", package = "geo-types" } +geo-types-07 = { version = "0.7", package = "geo-types" } +jiff-01 = { version = "0.1", package = "jiff" } +jiff-02 = { version = "0.2", package = "jiff" } +serde-1 = { version = "1.0", package = "serde" } +serde_json-1 = { version = "1.0", package = "serde_json" } +smol_str-01 = { version = "0.1", package = "smol_str" } +uuid-08 = { version = "0.8", package = "uuid" } +uuid-1 = { version = "1.0", package = "uuid" } +time-02 = { version = "0.2", package = "time" } +time-03 = { version = "0.3", package = "time", features = ["parsing"] } diff --git a/tokio-postgres/LICENSE-APACHE b/tokio-postgres/LICENSE-APACHE new file mode 120000 index 000000000..965b606f3 --- /dev/null +++ b/tokio-postgres/LICENSE-APACHE @@ -0,0 +1 @@ +../LICENSE-APACHE \ No newline at end of file diff --git a/tokio-postgres/LICENSE-MIT b/tokio-postgres/LICENSE-MIT new file mode 120000 index 000000000..76219eb72 --- /dev/null +++ b/tokio-postgres/LICENSE-MIT @@ -0,0 +1 @@ +../LICENSE-MIT \ No newline at end of file diff --git a/tokio-postgres/benches/bench.rs b/tokio-postgres/benches/bench.rs new file mode 100644 index 000000000..a8f9b5f1a --- /dev/null +++ b/tokio-postgres/benches/bench.rs @@ -0,0 +1,58 @@ +use criterion::{criterion_group, criterion_main, Criterion}; +use futures_channel::oneshot; +use std::sync::Arc; +use std::time::Instant; +use tokio::runtime::Runtime; +use tokio_postgres::{Client, NoTls}; + +fn setup() -> (Client, Runtime) { + let runtime = Runtime::new().unwrap(); + let (client, conn) = runtime + .block_on(tokio_postgres::connect( + "host=localhost port=5433 user=postgres", + NoTls, + )) + .unwrap(); + runtime.spawn(async { conn.await.unwrap() }); + (client, runtime) +} + +fn query_prepared(c: &mut Criterion) { + let (client, runtime) = setup(); + let statement = runtime.block_on(client.prepare("SELECT $1::INT8")).unwrap(); + c.bench_function("runtime_block_on", move |b| { + b.iter(|| { + runtime + .block_on(client.query(&statement, &[&1i64])) + .unwrap() + }) + }); + + let (client, runtime) = setup(); + let statement = runtime.block_on(client.prepare("SELECT $1::INT8")).unwrap(); + c.bench_function("executor_block_on", move |b| { + b.iter(|| futures_executor::block_on(client.query(&statement, &[&1i64])).unwrap()) + }); + + let (client, runtime) = setup(); + let client = Arc::new(client); + let statement = runtime.block_on(client.prepare("SELECT $1::INT8")).unwrap(); + c.bench_function("spawned", move |b| { + b.iter_custom(|iters| { + let (tx, rx) = oneshot::channel(); + let client = client.clone(); + let statement = statement.clone(); + runtime.spawn(async move { + let start = Instant::now(); + for _ in 0..iters { + client.query(&statement, &[&1i64]).await.unwrap(); + } + tx.send(start.elapsed()).unwrap(); + }); + futures_executor::block_on(rx).unwrap() + }) + }); +} + +criterion_group!(benches, query_prepared); +criterion_main!(benches); diff --git a/tokio-postgres/src/binary_copy.rs b/tokio-postgres/src/binary_copy.rs new file mode 100644 index 000000000..dab141663 --- /dev/null +++ b/tokio-postgres/src/binary_copy.rs @@ -0,0 +1,273 @@ +//! Utilities for working with the PostgreSQL binary copy format. + +use crate::types::{FromSql, IsNull, ToSql, Type, WrongType}; +use crate::{slice_iter, CopyInSink, CopyOutStream, Error}; +use byteorder::{BigEndian, ByteOrder}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use futures_util::{ready, SinkExt, Stream}; +use pin_project_lite::pin_project; +use postgres_types::BorrowToSql; +use std::convert::TryFrom; +use std::io; +use std::io::Cursor; +use std::ops::Range; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +const MAGIC: &[u8] = b"PGCOPY\n\xff\r\n\0"; +const HEADER_LEN: usize = MAGIC.len() + 4 + 4; + +pin_project! { + /// A type which serializes rows into the PostgreSQL binary copy format. + /// + /// The copy *must* be explicitly completed via the `finish` method. If it is not, the copy will be aborted. + pub struct BinaryCopyInWriter { + #[pin] + sink: CopyInSink, + types: Vec, + buf: BytesMut, + } +} + +impl BinaryCopyInWriter { + /// Creates a new writer which will write rows of the provided types to the provided sink. + pub fn new(sink: CopyInSink, types: &[Type]) -> BinaryCopyInWriter { + let mut buf = BytesMut::new(); + buf.put_slice(MAGIC); + buf.put_i32(0); // flags + buf.put_i32(0); // header extension + + BinaryCopyInWriter { + sink, + types: types.to_vec(), + buf, + } + } + + /// Writes a single row. + /// + /// # Panics + /// + /// Panics if the number of values provided does not match the number expected. + pub async fn write(self: Pin<&mut Self>, values: &[&(dyn ToSql + Sync)]) -> Result<(), Error> { + self.write_raw(slice_iter(values)).await + } + + /// A maximally-flexible version of `write`. + /// + /// # Panics + /// + /// Panics if the number of values provided does not match the number expected. + pub async fn write_raw(self: Pin<&mut Self>, values: I) -> Result<(), Error> + where + P: BorrowToSql, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + let mut this = self.project(); + + let values = values.into_iter(); + assert!( + values.len() == this.types.len(), + "expected {} values but got {}", + this.types.len(), + values.len(), + ); + + this.buf.put_i16(this.types.len() as i16); + + for (i, (value, type_)) in values.zip(this.types).enumerate() { + let idx = this.buf.len(); + this.buf.put_i32(0); + let len = match value + .borrow_to_sql() + .to_sql_checked(type_, this.buf) + .map_err(|e| Error::to_sql(e, i))? + { + IsNull::Yes => -1, + IsNull::No => i32::try_from(this.buf.len() - idx - 4) + .map_err(|e| Error::encode(io::Error::new(io::ErrorKind::InvalidInput, e)))?, + }; + BigEndian::write_i32(&mut this.buf[idx..], len); + } + + if this.buf.len() > 4096 { + this.sink.send(this.buf.split().freeze()).await?; + } + + Ok(()) + } + + /// Completes the copy, returning the number of rows added. + /// + /// This method *must* be used to complete the copy process. If it is not, the copy will be aborted. + pub async fn finish(self: Pin<&mut Self>) -> Result { + let mut this = self.project(); + + this.buf.put_i16(-1); + this.sink.send(this.buf.split().freeze()).await?; + this.sink.finish().await + } +} + +struct Header { + has_oids: bool, +} + +pin_project! { + /// A stream of rows deserialized from the PostgreSQL binary copy format. + pub struct BinaryCopyOutStream { + #[pin] + stream: CopyOutStream, + types: Arc>, + header: Option
, + } +} + +impl BinaryCopyOutStream { + /// Creates a stream from a raw copy out stream and the types of the columns being returned. + pub fn new(stream: CopyOutStream, types: &[Type]) -> BinaryCopyOutStream { + BinaryCopyOutStream { + stream, + types: Arc::new(types.to_vec()), + header: None, + } + } +} + +impl Stream for BinaryCopyOutStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + let chunk = match ready!(this.stream.poll_next(cx)) { + Some(Ok(chunk)) => chunk, + Some(Err(e)) => return Poll::Ready(Some(Err(e))), + None => return Poll::Ready(Some(Err(Error::closed()))), + }; + let mut chunk = Cursor::new(chunk); + + let has_oids = match &this.header { + Some(header) => header.has_oids, + None => { + check_remaining(&chunk, HEADER_LEN)?; + if !chunk.chunk().starts_with(MAGIC) { + return Poll::Ready(Some(Err(Error::parse(io::Error::new( + io::ErrorKind::InvalidData, + "invalid magic value", + ))))); + } + chunk.advance(MAGIC.len()); + + let flags = chunk.get_i32(); + let has_oids = (flags & (1 << 16)) != 0; + + let header_extension = chunk.get_u32() as usize; + check_remaining(&chunk, header_extension)?; + chunk.advance(header_extension); + + *this.header = Some(Header { has_oids }); + has_oids + } + }; + + check_remaining(&chunk, 2)?; + let mut len = chunk.get_i16(); + if len == -1 { + return Poll::Ready(None); + } + + if has_oids { + len += 1; + } + if len as usize != this.types.len() { + return Poll::Ready(Some(Err(Error::parse(io::Error::new( + io::ErrorKind::InvalidInput, + format!("expected {} values but got {}", this.types.len(), len), + ))))); + } + + let mut ranges = vec![]; + for _ in 0..len { + check_remaining(&chunk, 4)?; + let len = chunk.get_i32(); + if len == -1 { + ranges.push(None); + } else { + let len = len as usize; + check_remaining(&chunk, len)?; + let start = chunk.position() as usize; + ranges.push(Some(start..start + len)); + chunk.advance(len); + } + } + + Poll::Ready(Some(Ok(BinaryCopyOutRow { + buf: chunk.into_inner(), + ranges, + types: this.types.clone(), + }))) + } +} + +fn check_remaining(buf: &Cursor, len: usize) -> Result<(), Error> { + if buf.remaining() < len { + Err(Error::parse(io::Error::new( + io::ErrorKind::UnexpectedEof, + "unexpected EOF", + ))) + } else { + Ok(()) + } +} + +/// A row of data parsed from a binary copy out stream. +pub struct BinaryCopyOutRow { + buf: Bytes, + ranges: Vec>>, + types: Arc>, +} + +impl BinaryCopyOutRow { + /// Like `get`, but returns a `Result` rather than panicking. + pub fn try_get<'a, T>(&'a self, idx: usize) -> Result + where + T: FromSql<'a>, + { + let type_ = match self.types.get(idx) { + Some(type_) => type_, + None => return Err(Error::column(idx.to_string())), + }; + + if !T::accepts(type_) { + return Err(Error::from_sql( + Box::new(WrongType::new::(type_.clone())), + idx, + )); + } + + let r = match &self.ranges[idx] { + Some(range) => T::from_sql(type_, &self.buf[range.clone()]), + None => T::from_sql_null(type_), + }; + + r.map_err(|e| Error::from_sql(e, idx)) + } + + /// Deserializes a value from the row. + /// + /// # Panics + /// + /// Panics if the index is out of bounds or if the value cannot be converted to the specified type. + pub fn get<'a, T>(&'a self, idx: usize) -> T + where + T: FromSql<'a>, + { + match self.try_get(idx) { + Ok(value) => value, + Err(e) => panic!("error retrieving column {}: {}", idx, e), + } + } +} diff --git a/tokio-postgres/src/bind.rs b/tokio-postgres/src/bind.rs new file mode 100644 index 000000000..9c5c49218 --- /dev/null +++ b/tokio-postgres/src/bind.rs @@ -0,0 +1,38 @@ +use crate::client::InnerClient; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::types::BorrowToSql; +use crate::{query, Error, Portal, Statement}; +use postgres_protocol::message::backend::Message; +use postgres_protocol::message::frontend; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +static NEXT_ID: AtomicUsize = AtomicUsize::new(0); + +pub async fn bind( + client: &Arc, + statement: Statement, + params: I, +) -> Result +where + P: BorrowToSql, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, +{ + let name = format!("p{}", NEXT_ID.fetch_add(1, Ordering::SeqCst)); + let buf = client.with_buf(|buf| { + query::encode_bind(&statement, params, &name, buf)?; + frontend::sync(buf); + Ok(buf.split().freeze()) + })?; + + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } + + Ok(Portal::new(client, name, statement)) +} diff --git a/tokio-postgres/src/cancel_query.rs b/tokio-postgres/src/cancel_query.rs new file mode 100644 index 000000000..2dfd47c06 --- /dev/null +++ b/tokio-postgres/src/cancel_query.rs @@ -0,0 +1,52 @@ +use crate::client::SocketConfig; +use crate::config::{SslMode, SslNegotiation}; +use crate::tls::MakeTlsConnect; +use crate::{cancel_query_raw, connect_socket, Error, Socket}; +use std::io; + +pub(crate) async fn cancel_query( + config: Option, + ssl_mode: SslMode, + ssl_negotiation: SslNegotiation, + mut tls: T, + process_id: i32, + secret_key: i32, +) -> Result<(), Error> +where + T: MakeTlsConnect, +{ + let config = match config { + Some(config) => config, + None => { + return Err(Error::connect(io::Error::new( + io::ErrorKind::InvalidInput, + "unknown host", + ))) + } + }; + + let tls = tls + .make_tls_connect(config.hostname.as_deref().unwrap_or("")) + .map_err(|e| Error::tls(e.into()))?; + let has_hostname = config.hostname.is_some(); + + let socket = connect_socket::connect_socket( + &config.addr, + config.port, + config.connect_timeout, + config.tcp_user_timeout, + config.keepalive.as_ref(), + ) + .await?; + + cancel_query_raw::cancel_query_raw( + socket, + ssl_mode, + ssl_negotiation, + tls, + has_hostname, + process_id, + secret_key, + ) + .await +} diff --git a/tokio-postgres/src/cancel_query_raw.rs b/tokio-postgres/src/cancel_query_raw.rs new file mode 100644 index 000000000..886606497 --- /dev/null +++ b/tokio-postgres/src/cancel_query_raw.rs @@ -0,0 +1,31 @@ +use crate::config::{SslMode, SslNegotiation}; +use crate::tls::TlsConnect; +use crate::{connect_tls, Error}; +use bytes::BytesMut; +use postgres_protocol::message::frontend; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; + +pub async fn cancel_query_raw( + stream: S, + mode: SslMode, + negotiation: SslNegotiation, + tls: T, + has_hostname: bool, + process_id: i32, + secret_key: i32, +) -> Result<(), Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsConnect, +{ + let mut stream = connect_tls::connect_tls(stream, mode, negotiation, tls, has_hostname).await?; + + let mut buf = BytesMut::new(); + frontend::cancel_request(process_id, secret_key, &mut buf); + + stream.write_all(&buf).await.map_err(Error::io)?; + stream.flush().await.map_err(Error::io)?; + stream.shutdown().await.map_err(Error::io)?; + + Ok(()) +} diff --git a/tokio-postgres/src/cancel_token.rs b/tokio-postgres/src/cancel_token.rs new file mode 100644 index 000000000..1652bec72 --- /dev/null +++ b/tokio-postgres/src/cancel_token.rs @@ -0,0 +1,67 @@ +use crate::config::{SslMode, SslNegotiation}; +use crate::tls::TlsConnect; +#[cfg(feature = "runtime")] +use crate::{cancel_query, client::SocketConfig, tls::MakeTlsConnect, Socket}; +use crate::{cancel_query_raw, Error}; +use tokio::io::{AsyncRead, AsyncWrite}; + +/// The capability to request cancellation of in-progress queries on a +/// connection. +#[derive(Clone)] +pub struct CancelToken { + #[cfg(feature = "runtime")] + pub(crate) socket_config: Option, + pub(crate) ssl_mode: SslMode, + pub(crate) ssl_negotiation: SslNegotiation, + pub(crate) process_id: i32, + pub(crate) secret_key: i32, +} + +impl CancelToken { + /// Attempts to cancel the in-progress query on the connection associated + /// with this `CancelToken`. + /// + /// The server provides no information about whether a cancellation attempt was successful or not. An error will + /// only be returned if the client was unable to connect to the database. + /// + /// Cancellation is inherently racy. There is no guarantee that the + /// cancellation request will reach the server before the query terminates + /// normally, or that the connection associated with this token is still + /// active. + /// + /// Requires the `runtime` Cargo feature (enabled by default). + #[cfg(feature = "runtime")] + pub async fn cancel_query(&self, tls: T) -> Result<(), Error> + where + T: MakeTlsConnect, + { + cancel_query::cancel_query( + self.socket_config.clone(), + self.ssl_mode, + self.ssl_negotiation, + tls, + self.process_id, + self.secret_key, + ) + .await + } + + /// Like `cancel_query`, but uses a stream which is already connected to the server rather than opening a new + /// connection itself. + pub async fn cancel_query_raw(&self, stream: S, tls: T) -> Result<(), Error> + where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsConnect, + { + cancel_query_raw::cancel_query_raw( + stream, + self.ssl_mode, + self.ssl_negotiation, + tls, + true, + self.process_id, + self.secret_key, + ) + .await + } +} diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs new file mode 100644 index 000000000..b38bbba37 --- /dev/null +++ b/tokio-postgres/src/client.rs @@ -0,0 +1,614 @@ +use crate::codec::BackendMessages; +use crate::config::{SslMode, SslNegotiation}; +use crate::connection::{Request, RequestMessages}; +use crate::copy_out::CopyOutStream; +#[cfg(feature = "runtime")] +use crate::keepalive::KeepaliveConfig; +use crate::query::RowStream; +use crate::simple_query::SimpleQueryStream; +#[cfg(feature = "runtime")] +use crate::tls::MakeTlsConnect; +use crate::tls::TlsConnect; +use crate::types::{Oid, ToSql, Type}; +#[cfg(feature = "runtime")] +use crate::Socket; +use crate::{ + copy_in, copy_out, prepare, query, simple_query, slice_iter, CancelToken, CopyInSink, Error, + Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder, +}; +use bytes::{Buf, BytesMut}; +use fallible_iterator::FallibleIterator; +use futures_channel::mpsc; +use futures_util::{future, pin_mut, ready, StreamExt, TryStreamExt}; +use parking_lot::Mutex; +use postgres_protocol::message::backend::Message; +use postgres_types::BorrowToSql; +use std::collections::HashMap; +use std::fmt; +#[cfg(feature = "runtime")] +use std::net::IpAddr; +#[cfg(feature = "runtime")] +use std::path::PathBuf; +use std::sync::Arc; +use std::task::{Context, Poll}; +#[cfg(feature = "runtime")] +use std::time::Duration; +use tokio::io::{AsyncRead, AsyncWrite}; + +pub struct Responses { + receiver: mpsc::Receiver, + cur: BackendMessages, +} + +impl Responses { + pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + match self.cur.next().map_err(Error::parse)? { + Some(Message::ErrorResponse(body)) => return Poll::Ready(Err(Error::db(body))), + Some(message) => return Poll::Ready(Ok(message)), + None => {} + } + + match ready!(self.receiver.poll_next_unpin(cx)) { + Some(messages) => self.cur = messages, + None => return Poll::Ready(Err(Error::closed())), + } + } + } + + pub async fn next(&mut self) -> Result { + future::poll_fn(|cx| self.poll_next(cx)).await + } +} + +/// A cache of type info and prepared statements for fetching type info +/// (corresponding to the queries in the [prepare](prepare) module). +#[derive(Default)] +struct CachedTypeInfo { + /// A statement for basic information for a type from its + /// OID. Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_QUERY) (or its + /// fallback). + typeinfo: Option, + /// A statement for getting information for a composite type from its OID. + /// Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_COMPOSITE_QUERY). + typeinfo_composite: Option, + /// A statement for getting information for a composite type from its OID. + /// Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_COMPOSITE_QUERY) (or + /// its fallback). + typeinfo_enum: Option, + + /// Cache of types already looked up. + types: HashMap, +} + +pub struct InnerClient { + sender: mpsc::UnboundedSender, + cached_typeinfo: Mutex, + + /// A buffer to use when writing out postgres commands. + buffer: Mutex, +} + +impl InnerClient { + pub fn send(&self, messages: RequestMessages) -> Result { + let (sender, receiver) = mpsc::channel(1); + let request = Request { messages, sender }; + self.sender + .unbounded_send(request) + .map_err(|_| Error::closed())?; + + Ok(Responses { + receiver, + cur: BackendMessages::empty(), + }) + } + + pub fn typeinfo(&self) -> Option { + self.cached_typeinfo.lock().typeinfo.clone() + } + + pub fn set_typeinfo(&self, statement: &Statement) { + self.cached_typeinfo.lock().typeinfo = Some(statement.clone()); + } + + pub fn typeinfo_composite(&self) -> Option { + self.cached_typeinfo.lock().typeinfo_composite.clone() + } + + pub fn set_typeinfo_composite(&self, statement: &Statement) { + self.cached_typeinfo.lock().typeinfo_composite = Some(statement.clone()); + } + + pub fn typeinfo_enum(&self) -> Option { + self.cached_typeinfo.lock().typeinfo_enum.clone() + } + + pub fn set_typeinfo_enum(&self, statement: &Statement) { + self.cached_typeinfo.lock().typeinfo_enum = Some(statement.clone()); + } + + pub fn type_(&self, oid: Oid) -> Option { + self.cached_typeinfo.lock().types.get(&oid).cloned() + } + + pub fn set_type(&self, oid: Oid, type_: &Type) { + self.cached_typeinfo.lock().types.insert(oid, type_.clone()); + } + + pub fn clear_type_cache(&self) { + self.cached_typeinfo.lock().types.clear(); + } + + /// Call the given function with a buffer to be used when writing out + /// postgres commands. + pub fn with_buf(&self, f: F) -> R + where + F: FnOnce(&mut BytesMut) -> R, + { + let mut buffer = self.buffer.lock(); + let r = f(&mut buffer); + buffer.clear(); + r + } +} + +#[cfg(feature = "runtime")] +#[derive(Clone)] +pub(crate) struct SocketConfig { + pub addr: Addr, + pub hostname: Option, + pub port: u16, + pub connect_timeout: Option, + pub tcp_user_timeout: Option, + pub keepalive: Option, +} + +#[cfg(feature = "runtime")] +#[derive(Clone)] +pub(crate) enum Addr { + Tcp(IpAddr), + #[cfg(unix)] + Unix(PathBuf), +} + +/// An asynchronous PostgreSQL client. +/// +/// The client is one half of what is returned when a connection is established. Users interact with the database +/// through this client object. +pub struct Client { + inner: Arc, + #[cfg(feature = "runtime")] + socket_config: Option, + ssl_mode: SslMode, + ssl_negotiation: SslNegotiation, + process_id: i32, + secret_key: i32, +} + +impl Client { + pub(crate) fn new( + sender: mpsc::UnboundedSender, + ssl_mode: SslMode, + ssl_negotiation: SslNegotiation, + process_id: i32, + secret_key: i32, + ) -> Client { + Client { + inner: Arc::new(InnerClient { + sender, + cached_typeinfo: Default::default(), + buffer: Default::default(), + }), + #[cfg(feature = "runtime")] + socket_config: None, + ssl_mode, + ssl_negotiation, + process_id, + secret_key, + } + } + + pub(crate) fn inner(&self) -> &Arc { + &self.inner + } + + #[cfg(feature = "runtime")] + pub(crate) fn set_socket_config(&mut self, socket_config: SocketConfig) { + self.socket_config = Some(socket_config); + } + + /// Creates a new prepared statement. + /// + /// Prepared statements can be executed repeatedly, and may contain query parameters (indicated by `$1`, `$2`, etc), + /// which are set when executed. Prepared statements can only be used with the connection that created them. + pub async fn prepare(&self, query: &str) -> Result { + self.prepare_typed(query, &[]).await + } + + /// Like `prepare`, but allows the types of query parameters to be explicitly specified. + /// + /// The list of types may be smaller than the number of parameters - the types of the remaining parameters will be + /// inferred. For example, `client.prepare_typed(query, &[])` is equivalent to `client.prepare(query)`. + pub async fn prepare_typed( + &self, + query: &str, + parameter_types: &[Type], + ) -> Result { + prepare::prepare(&self.inner, query, parameter_types).await + } + + /// Executes a statement, returning a vector of the resulting rows. + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list + /// provided, 1-indexed. + /// + /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be + /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front + /// with the `prepare` method. + pub async fn query( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result, Error> + where + T: ?Sized + ToStatement, + { + self.query_raw(statement, slice_iter(params)) + .await? + .try_collect() + .await + } + + /// Executes a statement which returns a single row, returning it. + /// + /// Returns an error if the query does not return exactly one row. + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list + /// provided, 1-indexed. + /// + /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be + /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front + /// with the `prepare` method. + pub async fn query_one( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result + where + T: ?Sized + ToStatement, + { + self.query_opt(statement, params) + .await + .and_then(|res| res.ok_or_else(Error::row_count)) + } + + /// Executes a statements which returns zero or one rows, returning it. + /// + /// Returns an error if the query returns more than one row. + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list + /// provided, 1-indexed. + /// + /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be + /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front + /// with the `prepare` method. + pub async fn query_opt( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result, Error> + where + T: ?Sized + ToStatement, + { + let stream = self.query_raw(statement, slice_iter(params)).await?; + pin_mut!(stream); + + let mut first = None; + + // Originally this was two calls to `try_next().await?`, + // once for the first element, and second to error if more than one. + // + // However, this new form with only one .await in a loop generates + // slightly smaller codegen/stack usage for the resulting future. + while let Some(row) = stream.try_next().await? { + if first.is_some() { + return Err(Error::row_count()); + } + + first = Some(row); + } + + Ok(first) + } + + /// The maximally flexible version of [`query`]. + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list + /// provided, 1-indexed. + /// + /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be + /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front + /// with the `prepare` method. + /// + /// [`query`]: #method.query + /// + /// # Examples + /// + /// ```no_run + /// # async fn async_main(client: &tokio_postgres::Client) -> Result<(), tokio_postgres::Error> { + /// use futures_util::{pin_mut, TryStreamExt}; + /// + /// let params: Vec = vec![ + /// "first param".into(), + /// "second param".into(), + /// ]; + /// let mut it = client.query_raw( + /// "SELECT foo FROM bar WHERE biz = $1 AND baz = $2", + /// params, + /// ).await?; + /// + /// pin_mut!(it); + /// while let Some(row) = it.try_next().await? { + /// let foo: i32 = row.get("foo"); + /// println!("foo: {}", foo); + /// } + /// # Ok(()) + /// # } + /// ``` + pub async fn query_raw(&self, statement: &T, params: I) -> Result + where + T: ?Sized + ToStatement, + P: BorrowToSql, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + let statement = statement.__convert().into_statement(self).await?; + query::query(&self.inner, statement, params).await + } + + /// Like `query`, but requires the types of query parameters to be explicitly specified. + /// + /// Compared to `query`, this method allows performing queries without three round trips (for + /// prepare, execute, and close) by requiring the caller to specify parameter values along with + /// their Postgres type. Thus, this is suitable in environments where prepared statements aren't + /// supported (such as Cloudflare Workers with Hyperdrive). + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the + /// parameter of the list provided, 1-indexed. + pub async fn query_typed( + &self, + query: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result, Error> { + self.query_typed_raw(query, params.iter().map(|(v, t)| (*v, t.clone()))) + .await? + .try_collect() + .await + } + + /// The maximally flexible version of [`query_typed`]. + /// + /// Compared to `query`, this method allows performing queries without three round trips (for + /// prepare, execute, and close) by requiring the caller to specify parameter values along with + /// their Postgres type. Thus, this is suitable in environments where prepared statements aren't + /// supported (such as Cloudflare Workers with Hyperdrive). + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the + /// parameter of the list provided, 1-indexed. + /// + /// [`query_typed`]: #method.query_typed + /// + /// # Examples + /// + /// ```no_run + /// # async fn async_main(client: &tokio_postgres::Client) -> Result<(), tokio_postgres::Error> { + /// use futures_util::{pin_mut, TryStreamExt}; + /// use tokio_postgres::types::Type; + /// + /// let params: Vec<(String, Type)> = vec![ + /// ("first param".into(), Type::TEXT), + /// ("second param".into(), Type::TEXT), + /// ]; + /// let mut it = client.query_typed_raw( + /// "SELECT foo FROM bar WHERE biz = $1 AND baz = $2", + /// params, + /// ).await?; + /// + /// pin_mut!(it); + /// while let Some(row) = it.try_next().await? { + /// let foo: i32 = row.get("foo"); + /// println!("foo: {}", foo); + /// } + /// # Ok(()) + /// # } + /// ``` + pub async fn query_typed_raw(&self, query: &str, params: I) -> Result + where + P: BorrowToSql, + I: IntoIterator, + { + query::query_typed(&self.inner, query, params).await + } + + /// Executes a statement, returning the number of rows modified. + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list + /// provided, 1-indexed. + /// + /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be + /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front + /// with the `prepare` method. + /// + /// If the statement does not modify any rows (e.g. `SELECT`), 0 is returned. + pub async fn execute( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result + where + T: ?Sized + ToStatement, + { + self.execute_raw(statement, slice_iter(params)).await + } + + /// The maximally flexible version of [`execute`]. + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list + /// provided, 1-indexed. + /// + /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be + /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front + /// with the `prepare` method. + /// + /// [`execute`]: #method.execute + pub async fn execute_raw(&self, statement: &T, params: I) -> Result + where + T: ?Sized + ToStatement, + P: BorrowToSql, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + let statement = statement.__convert().into_statement(self).await?; + query::execute(self.inner(), statement, params).await + } + + /// Executes a `COPY FROM STDIN` statement, returning a sink used to write the copy data. + /// + /// PostgreSQL does not support parameters in `COPY` statements, so this method does not take any. The copy *must* + /// be explicitly completed via the `Sink::close` or `finish` methods. If it is not, the copy will be aborted. + pub async fn copy_in(&self, statement: &T) -> Result, Error> + where + T: ?Sized + ToStatement, + U: Buf + 'static + Send, + { + let statement = statement.__convert().into_statement(self).await?; + copy_in::copy_in(self.inner(), statement).await + } + + /// Executes a `COPY TO STDOUT` statement, returning a stream of the resulting data. + /// + /// PostgreSQL does not support parameters in `COPY` statements, so this method does not take any. + pub async fn copy_out(&self, statement: &T) -> Result + where + T: ?Sized + ToStatement, + { + let statement = statement.__convert().into_statement(self).await?; + copy_out::copy_out(self.inner(), statement).await + } + + /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows. + /// + /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that + /// point. The simple query protocol returns the values in rows as strings rather than in their binary encodings, + /// so the associated row type doesn't work with the `FromSql` trait. Rather than simply returning a list of the + /// rows, this method returns a list of an enum which indicates either the completion of one of the commands, + /// or a row of data. This preserves the framing between the separate statements in the request. + /// + /// # Warning + /// + /// Prepared statements should be use for any query which contains user-specified data, as they provided the + /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass + /// them to this method! + pub async fn simple_query(&self, query: &str) -> Result, Error> { + self.simple_query_raw(query).await?.try_collect().await + } + + pub(crate) async fn simple_query_raw(&self, query: &str) -> Result { + simple_query::simple_query(self.inner(), query).await + } + + /// Executes a sequence of SQL statements using the simple query protocol. + /// + /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that + /// point. This is intended for use when, for example, initializing a database schema. + /// + /// # Warning + /// + /// Prepared statements should be use for any query which contains user-specified data, as they provided the + /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass + /// them to this method! + pub async fn batch_execute(&self, query: &str) -> Result<(), Error> { + simple_query::batch_execute(self.inner(), query).await + } + + /// Begins a new database transaction. + /// + /// The transaction will roll back by default - use the `commit` method to commit it. + pub async fn transaction(&mut self) -> Result, Error> { + self.build_transaction().start().await + } + + /// Returns a builder for a transaction with custom settings. + /// + /// Unlike the `transaction` method, the builder can be used to control the transaction's isolation level and other + /// attributes. + pub fn build_transaction(&mut self) -> TransactionBuilder<'_> { + TransactionBuilder::new(self) + } + + /// Constructs a cancellation token that can later be used to request cancellation of a query running on the + /// connection associated with this client. + pub fn cancel_token(&self) -> CancelToken { + CancelToken { + #[cfg(feature = "runtime")] + socket_config: self.socket_config.clone(), + ssl_mode: self.ssl_mode, + ssl_negotiation: self.ssl_negotiation, + process_id: self.process_id, + secret_key: self.secret_key, + } + } + + /// Attempts to cancel an in-progress query. + /// + /// The server provides no information about whether a cancellation attempt was successful or not. An error will + /// only be returned if the client was unable to connect to the database. + /// + /// Requires the `runtime` Cargo feature (enabled by default). + #[cfg(feature = "runtime")] + #[deprecated(since = "0.6.0", note = "use Client::cancel_token() instead")] + pub async fn cancel_query(&self, tls: T) -> Result<(), Error> + where + T: MakeTlsConnect, + { + self.cancel_token().cancel_query(tls).await + } + + /// Like `cancel_query`, but uses a stream which is already connected to the server rather than opening a new + /// connection itself. + #[deprecated(since = "0.6.0", note = "use Client::cancel_token() instead")] + pub async fn cancel_query_raw(&self, stream: S, tls: T) -> Result<(), Error> + where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsConnect, + { + self.cancel_token().cancel_query_raw(stream, tls).await + } + + /// Clears the client's type information cache. + /// + /// When user-defined types are used in a query, the client loads their definitions from the database and caches + /// them for the lifetime of the client. If those definitions are changed in the database, this method can be used + /// to flush the local cache and allow the new, updated definitions to be loaded. + pub fn clear_type_cache(&self) { + self.inner().clear_type_cache(); + } + + /// Determines if the connection to the server has already closed. + /// + /// In that case, all future queries will fail. + pub fn is_closed(&self) -> bool { + self.inner.sender.is_closed() + } + + #[doc(hidden)] + pub fn __private_api_close(&mut self) { + self.inner.sender.close_channel() + } +} + +impl fmt::Debug for Client { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Client").finish() + } +} diff --git a/tokio-postgres/src/codec.rs b/tokio-postgres/src/codec.rs new file mode 100644 index 000000000..9d078044b --- /dev/null +++ b/tokio-postgres/src/codec.rs @@ -0,0 +1,98 @@ +use bytes::{Buf, Bytes, BytesMut}; +use fallible_iterator::FallibleIterator; +use postgres_protocol::message::backend; +use postgres_protocol::message::frontend::CopyData; +use std::io; +use tokio_util::codec::{Decoder, Encoder}; + +pub enum FrontendMessage { + Raw(Bytes), + CopyData(CopyData>), +} + +pub enum BackendMessage { + Normal { + messages: BackendMessages, + request_complete: bool, + }, + Async(backend::Message), +} + +pub struct BackendMessages(BytesMut); + +impl BackendMessages { + pub fn empty() -> BackendMessages { + BackendMessages(BytesMut::new()) + } +} + +impl FallibleIterator for BackendMessages { + type Item = backend::Message; + type Error = io::Error; + + fn next(&mut self) -> io::Result> { + backend::Message::parse(&mut self.0) + } +} + +pub struct PostgresCodec; + +impl Encoder for PostgresCodec { + type Error = io::Error; + + fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> { + match item { + FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf), + FrontendMessage::CopyData(data) => data.write(dst), + } + + Ok(()) + } +} + +impl Decoder for PostgresCodec { + type Item = BackendMessage; + type Error = io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, io::Error> { + let mut idx = 0; + let mut request_complete = false; + + while let Some(header) = backend::Header::parse(&src[idx..])? { + let len = header.len() as usize + 1; + if src[idx..].len() < len { + break; + } + + match header.tag() { + backend::NOTICE_RESPONSE_TAG + | backend::NOTIFICATION_RESPONSE_TAG + | backend::PARAMETER_STATUS_TAG => { + if idx == 0 { + let message = backend::Message::parse(src)?.unwrap(); + return Ok(Some(BackendMessage::Async(message))); + } else { + break; + } + } + _ => {} + } + + idx += len; + + if header.tag() == backend::READY_FOR_QUERY_TAG { + request_complete = true; + break; + } + } + + if idx == 0 { + Ok(None) + } else { + Ok(Some(BackendMessage::Normal { + messages: BackendMessages(src.split_to(idx)), + request_complete, + })) + } + } +} diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs new file mode 100644 index 000000000..59edd8fe2 --- /dev/null +++ b/tokio-postgres/src/config.rs @@ -0,0 +1,1211 @@ +//! Connection configuration. + +#![allow(clippy::doc_overindented_list_items)] + +#[cfg(feature = "runtime")] +use crate::connect::connect; +use crate::connect_raw::connect_raw; +#[cfg(not(target_arch = "wasm32"))] +use crate::keepalive::KeepaliveConfig; +#[cfg(feature = "runtime")] +use crate::tls::MakeTlsConnect; +use crate::tls::TlsConnect; +#[cfg(feature = "runtime")] +use crate::Socket; +use crate::{Client, Connection, Error}; +use std::borrow::Cow; +#[cfg(unix)] +use std::ffi::OsStr; +use std::net::IpAddr; +use std::ops::Deref; +#[cfg(unix)] +use std::os::unix::ffi::OsStrExt; +#[cfg(unix)] +use std::path::{Path, PathBuf}; +use std::str; +use std::str::FromStr; +use std::time::Duration; +use std::{error, fmt, iter, mem}; +use tokio::io::{AsyncRead, AsyncWrite}; + +/// Properties required of a session. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum TargetSessionAttrs { + /// No special properties are required. + Any, + /// The session must allow writes. + ReadWrite, + /// The session allow only reads. + ReadOnly, +} + +/// TLS configuration. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum SslMode { + /// Do not use TLS. + Disable, + /// Attempt to connect with TLS but allow sessions without. + Prefer, + /// Require the use of TLS. + Require, +} + +/// TLS negotiation configuration +/// +/// See more information at +/// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-SSLNEGOTIATION +#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)] +#[non_exhaustive] +pub enum SslNegotiation { + /// Use PostgreSQL SslRequest for Ssl negotiation + #[default] + Postgres, + /// Start Ssl handshake without negotiation, only works for PostgreSQL 17+ + Direct, +} + +/// Channel binding configuration. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum ChannelBinding { + /// Do not use channel binding. + Disable, + /// Attempt to use channel binding but allow sessions without. + Prefer, + /// Require the use of channel binding. + Require, +} + +/// Load balancing configuration. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum LoadBalanceHosts { + /// Make connection attempts to hosts in the order provided. + Disable, + /// Make connection attempts to hosts in a random order. + Random, +} + +/// A host specification. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Host { + /// A TCP hostname. + Tcp(String), + /// A path to a directory containing the server's Unix socket. + /// + /// This variant is only available on Unix platforms. + #[cfg(unix)] + Unix(PathBuf), +} + +/// Connection configuration. +/// +/// Configuration can be parsed from libpq-style connection strings. These strings come in two formats: +/// +/// # Key-Value +/// +/// This format consists of space-separated key-value pairs. Values which are either the empty string or contain +/// whitespace should be wrapped in `'`. `'` and `\` characters should be backslash-escaped. +/// +/// ## Keys +/// +/// * `user` - The username to authenticate with. Defaults to the user executing this process. +/// * `password` - The password to authenticate with. +/// * `dbname` - The name of the database to connect to. Defaults to the username. +/// * `options` - Command line options used to configure the server. +/// * `application_name` - Sets the `application_name` parameter on the server. +/// * `sslmode` - Controls usage of TLS. If set to `disable`, TLS will not be used. If set to `prefer`, TLS will be used +/// if available, but not used otherwise. If set to `require`, TLS will be forced to be used. Defaults to `prefer`. +/// * `host` - The host to connect to. On Unix platforms, if the host starts with a `/` character it is treated as the +/// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts +/// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting +/// with the `connect` method. +/// * `sslnegotiation` - TLS negotiation method. If set to `direct`, the client +/// will perform direct TLS handshake, this only works for PostgreSQL 17 and +/// newer. +/// Note that you will need to setup ALPN of TLS client configuration to +/// `postgresql` when using direct TLS. If you are using postgres_openssl +/// as TLS backend, a `postgres_openssl::set_postgresql_alpn` helper is +/// provided for that. +/// If set to `postgres`, the default value, it follows original postgres +/// wire protocol to perform the negotiation. +/// * `hostaddr` - Numeric IP address of host to connect to. This should be in the standard IPv4 address format, +/// e.g., 172.28.40.9. If your machine supports IPv6, you can also use those addresses. +/// If this parameter is not specified, the value of `host` will be looked up to find the corresponding IP address, +/// or if host specifies an IP address, that value will be used directly. +/// Using `hostaddr` allows the application to avoid a host name look-up, which might be important in applications +/// with time constraints. However, a host name is required for TLS certificate verification. +/// Specifically: +/// * If `hostaddr` is specified without `host`, the value for `hostaddr` gives the server network address. +/// The connection attempt will fail if the authentication method requires a host name; +/// * If `host` is specified without `hostaddr`, a host name lookup occurs; +/// * If both `host` and `hostaddr` are specified, the value for `hostaddr` gives the server network address. +/// The value for `host` is ignored unless the authentication method requires it, +/// in which case it will be used as the host name. +/// * `port` - The port to connect to. Multiple ports can be specified, separated by commas. The number of ports must be +/// either 1, in which case it will be used for all hosts, or the same as the number of hosts. Defaults to 5432 if +/// omitted or the empty string. +/// * `connect_timeout` - The time limit in seconds applied to each socket-level connection attempt. Note that hostnames +/// can resolve to multiple IP addresses, and this limit is applied to each address. Defaults to no timeout. +/// * `tcp_user_timeout` - The time limit that transmitted data may remain unacknowledged before a connection is forcibly closed. +/// This is ignored for Unix domain socket connections. It is only supported on systems where TCP_USER_TIMEOUT is available +/// and will default to the system default if omitted or set to 0; on other systems, it has no effect. +/// * `keepalives` - Controls the use of TCP keepalive. A value of 0 disables keepalive and nonzero integers enable it. +/// This option is ignored when connecting with Unix sockets. Defaults to on. +/// * `keepalives_idle` - The number of seconds of inactivity after which a keepalive message is sent to the server. +/// This option is ignored when connecting with Unix sockets. Defaults to 2 hours. +/// * `keepalives_interval` - The time interval between TCP keepalive probes. +/// This option is ignored when connecting with Unix sockets. +/// * `keepalives_retries` - The maximum number of TCP keepalive probes that will be sent before dropping a connection. +/// This option is ignored when connecting with Unix sockets. +/// * `target_session_attrs` - Specifies requirements of the session. If set to `read-write`, the client will check that +/// the `transaction_read_write` session parameter is set to `on`. This can be used to connect to the primary server +/// in a database cluster as opposed to the secondary read-only mirrors. Defaults to `all`. +/// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel +/// binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise. +/// If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`. +/// * `load_balance_hosts` - Controls the order in which the client tries to connect to the available hosts and +/// addresses. Once a connection attempt is successful no other hosts and addresses will be tried. This parameter +/// is typically used in combination with multiple host names or a DNS record that returns multiple IPs. If set to +/// `disable`, hosts and addresses will be tried in the order provided. If set to `random`, hosts will be tried +/// in a random order, and the IP addresses resolved from a hostname will also be tried in a random order. Defaults +/// to `disable`. +/// +/// ## Examples +/// +/// ```not_rust +/// host=localhost user=postgres connect_timeout=10 keepalives=0 +/// ``` +/// +/// ```not_rust +/// host=/var/lib/postgresql,localhost port=1234 user=postgres password='password with spaces' +/// ``` +/// +/// ```not_rust +/// host=host1,host2,host3 port=1234,,5678 hostaddr=127.0.0.1,127.0.0.2,127.0.0.3 user=postgres target_session_attrs=read-write +/// ``` +/// +/// ```not_rust +/// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write +/// ``` +/// +/// # Url +/// +/// This format resembles a URL with a scheme of either `postgres://` or `postgresql://`. All components are optional, +/// and the format accepts query parameters for all of the key-value pairs described in the section above. Multiple +/// host/port pairs can be comma-separated. Unix socket paths in the host section of the URL should be percent-encoded, +/// as the path component of the URL specifies the database name. +/// +/// ## Examples +/// +/// ```not_rust +/// postgresql://user@localhost +/// ``` +/// +/// ```not_rust +/// postgresql://user:password@%2Fvar%2Flib%2Fpostgresql/mydb?connect_timeout=10 +/// ``` +/// +/// ```not_rust +/// postgresql://user@host1:1234,host2,host3:5678?target_session_attrs=read-write +/// ``` +/// +/// ```not_rust +/// postgresql:///mydb?user=user&host=/var/lib/postgresql +/// ``` +#[derive(Clone, PartialEq, Eq)] +pub struct Config { + pub(crate) user: Option, + pub(crate) password: Option>, + pub(crate) dbname: Option, + pub(crate) options: Option, + pub(crate) application_name: Option, + pub(crate) ssl_mode: SslMode, + pub(crate) ssl_negotiation: SslNegotiation, + pub(crate) host: Vec, + pub(crate) hostaddr: Vec, + pub(crate) port: Vec, + pub(crate) connect_timeout: Option, + pub(crate) tcp_user_timeout: Option, + pub(crate) keepalives: bool, + #[cfg(not(target_arch = "wasm32"))] + pub(crate) keepalive_config: KeepaliveConfig, + pub(crate) target_session_attrs: TargetSessionAttrs, + pub(crate) channel_binding: ChannelBinding, + pub(crate) load_balance_hosts: LoadBalanceHosts, +} + +impl Default for Config { + fn default() -> Config { + Config::new() + } +} + +impl Config { + /// Creates a new configuration. + pub fn new() -> Config { + Config { + user: None, + password: None, + dbname: None, + options: None, + application_name: None, + ssl_mode: SslMode::Prefer, + ssl_negotiation: SslNegotiation::Postgres, + host: vec![], + hostaddr: vec![], + port: vec![], + connect_timeout: None, + tcp_user_timeout: None, + keepalives: true, + #[cfg(not(target_arch = "wasm32"))] + keepalive_config: KeepaliveConfig { + idle: Duration::from_secs(2 * 60 * 60), + interval: None, + retries: None, + }, + target_session_attrs: TargetSessionAttrs::Any, + channel_binding: ChannelBinding::Prefer, + load_balance_hosts: LoadBalanceHosts::Disable, + } + } + + /// Sets the user to authenticate with. + /// + /// Defaults to the user executing this process. + pub fn user(&mut self, user: impl Into) -> &mut Config { + self.user = Some(user.into()); + self + } + + /// Gets the user to authenticate with, if one has been configured with + /// the `user` method. + pub fn get_user(&self) -> Option<&str> { + self.user.as_deref() + } + + /// Sets the password to authenticate with. + pub fn password(&mut self, password: T) -> &mut Config + where + T: AsRef<[u8]>, + { + self.password = Some(password.as_ref().to_vec()); + self + } + + /// Gets the password to authenticate with, if one has been configured with + /// the `password` method. + pub fn get_password(&self) -> Option<&[u8]> { + self.password.as_deref() + } + + /// Sets the name of the database to connect to. + /// + /// Defaults to the user. + pub fn dbname(&mut self, dbname: impl Into) -> &mut Config { + self.dbname = Some(dbname.into()); + self + } + + /// Gets the name of the database to connect to, if one has been configured + /// with the `dbname` method. + pub fn get_dbname(&self) -> Option<&str> { + self.dbname.as_deref() + } + + /// Sets command line options used to configure the server. + pub fn options(&mut self, options: impl Into) -> &mut Config { + self.options = Some(options.into()); + self + } + + /// Gets the command line options used to configure the server, if the + /// options have been set with the `options` method. + pub fn get_options(&self) -> Option<&str> { + self.options.as_deref() + } + + /// Sets the value of the `application_name` runtime parameter. + pub fn application_name(&mut self, application_name: impl Into) -> &mut Config { + self.application_name = Some(application_name.into()); + self + } + + /// Gets the value of the `application_name` runtime parameter, if it has + /// been set with the `application_name` method. + pub fn get_application_name(&self) -> Option<&str> { + self.application_name.as_deref() + } + + /// Sets the SSL configuration. + /// + /// Defaults to `prefer`. + pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config { + self.ssl_mode = ssl_mode; + self + } + + /// Gets the SSL configuration. + pub fn get_ssl_mode(&self) -> SslMode { + self.ssl_mode + } + + /// Sets the SSL negotiation method. + /// + /// Defaults to `postgres`. + pub fn ssl_negotiation(&mut self, ssl_negotiation: SslNegotiation) -> &mut Config { + self.ssl_negotiation = ssl_negotiation; + self + } + + /// Gets the SSL negotiation method. + pub fn get_ssl_negotiation(&self) -> SslNegotiation { + self.ssl_negotiation + } + + /// Adds a host to the configuration. + /// + /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix + /// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets. + /// There must be either no hosts, or the same number of hosts as hostaddrs. + pub fn host(&mut self, host: impl Into) -> &mut Config { + let host = host.into(); + + #[cfg(unix)] + { + if host.starts_with('/') { + return self.host_path(host); + } + } + + self.host.push(Host::Tcp(host)); + self + } + + /// Gets the hosts that have been added to the configuration with `host`. + pub fn get_hosts(&self) -> &[Host] { + &self.host + } + + /// Gets the hostaddrs that have been added to the configuration with `hostaddr`. + pub fn get_hostaddrs(&self) -> &[IpAddr] { + self.hostaddr.deref() + } + + /// Adds a Unix socket host to the configuration. + /// + /// Unlike `host`, this method allows non-UTF8 paths. + #[cfg(unix)] + pub fn host_path(&mut self, host: T) -> &mut Config + where + T: AsRef, + { + self.host.push(Host::Unix(host.as_ref().to_path_buf())); + self + } + + /// Adds a hostaddr to the configuration. + /// + /// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order. + /// There must be either no hostaddrs, or the same number of hostaddrs as hosts. + pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config { + self.hostaddr.push(hostaddr); + self + } + + /// Adds a port to the configuration. + /// + /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which + /// case the default of 5432 is used, a single port, in which it is used for all hosts, or the same number of ports + /// as hosts. + pub fn port(&mut self, port: u16) -> &mut Config { + self.port.push(port); + self + } + + /// Gets the ports that have been added to the configuration with `port`. + pub fn get_ports(&self) -> &[u16] { + &self.port + } + + /// Sets the timeout applied to socket-level connection attempts. + /// + /// Note that hostnames can resolve to multiple IP addresses, and this timeout will apply to each address of each + /// host separately. Defaults to no limit. + pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config { + self.connect_timeout = Some(connect_timeout); + self + } + + /// Gets the connection timeout, if one has been set with the + /// `connect_timeout` method. + pub fn get_connect_timeout(&self) -> Option<&Duration> { + self.connect_timeout.as_ref() + } + + /// Sets the TCP user timeout. + /// + /// This is ignored for Unix domain socket connections. It is only supported on systems where + /// TCP_USER_TIMEOUT is available and will default to the system default if omitted or set to 0; + /// on other systems, it has no effect. + pub fn tcp_user_timeout(&mut self, tcp_user_timeout: Duration) -> &mut Config { + self.tcp_user_timeout = Some(tcp_user_timeout); + self + } + + /// Gets the TCP user timeout, if one has been set with the + /// `user_timeout` method. + pub fn get_tcp_user_timeout(&self) -> Option<&Duration> { + self.tcp_user_timeout.as_ref() + } + + /// Controls the use of TCP keepalive. + /// + /// This is ignored for Unix domain socket connections. Defaults to `true`. + pub fn keepalives(&mut self, keepalives: bool) -> &mut Config { + self.keepalives = keepalives; + self + } + + /// Reports whether TCP keepalives will be used. + pub fn get_keepalives(&self) -> bool { + self.keepalives + } + + /// Sets the amount of idle time before a keepalive packet is sent on the connection. + /// + /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. Defaults to 2 hours. + #[cfg(not(target_arch = "wasm32"))] + pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config { + self.keepalive_config.idle = keepalives_idle; + self + } + + /// Gets the configured amount of idle time before a keepalive packet will + /// be sent on the connection. + #[cfg(not(target_arch = "wasm32"))] + pub fn get_keepalives_idle(&self) -> Duration { + self.keepalive_config.idle + } + + /// Sets the time interval between TCP keepalive probes. + /// On Windows, this sets the value of the tcp_keepalive struct’s keepaliveinterval field. + /// + /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. + #[cfg(not(target_arch = "wasm32"))] + pub fn keepalives_interval(&mut self, keepalives_interval: Duration) -> &mut Config { + self.keepalive_config.interval = Some(keepalives_interval); + self + } + + /// Gets the time interval between TCP keepalive probes. + #[cfg(not(target_arch = "wasm32"))] + pub fn get_keepalives_interval(&self) -> Option { + self.keepalive_config.interval + } + + /// Sets the maximum number of TCP keepalive probes that will be sent before dropping a connection. + /// + /// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. + #[cfg(not(target_arch = "wasm32"))] + pub fn keepalives_retries(&mut self, keepalives_retries: u32) -> &mut Config { + self.keepalive_config.retries = Some(keepalives_retries); + self + } + + /// Gets the maximum number of TCP keepalive probes that will be sent before dropping a connection. + #[cfg(not(target_arch = "wasm32"))] + pub fn get_keepalives_retries(&self) -> Option { + self.keepalive_config.retries + } + + /// Sets the requirements of the session. + /// + /// This can be used to connect to the primary server in a clustered database rather than one of the read-only + /// secondary servers. Defaults to `Any`. + pub fn target_session_attrs( + &mut self, + target_session_attrs: TargetSessionAttrs, + ) -> &mut Config { + self.target_session_attrs = target_session_attrs; + self + } + + /// Gets the requirements of the session. + pub fn get_target_session_attrs(&self) -> TargetSessionAttrs { + self.target_session_attrs + } + + /// Sets the channel binding behavior. + /// + /// Defaults to `prefer`. + pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config { + self.channel_binding = channel_binding; + self + } + + /// Gets the channel binding behavior. + pub fn get_channel_binding(&self) -> ChannelBinding { + self.channel_binding + } + + /// Sets the host load balancing behavior. + /// + /// Defaults to `disable`. + pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config { + self.load_balance_hosts = load_balance_hosts; + self + } + + /// Gets the host load balancing behavior. + pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts { + self.load_balance_hosts + } + + fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { + match key { + "user" => { + self.user(value); + } + "password" => { + self.password(value); + } + "dbname" => { + self.dbname(value); + } + "options" => { + self.options(value); + } + "application_name" => { + self.application_name(value); + } + "sslmode" => { + let mode = match value { + "disable" => SslMode::Disable, + "prefer" => SslMode::Prefer, + "require" => SslMode::Require, + _ => return Err(Error::config_parse(Box::new(InvalidValue("sslmode")))), + }; + self.ssl_mode(mode); + } + "sslnegotiation" => { + let mode = match value { + "postgres" => SslNegotiation::Postgres, + "direct" => SslNegotiation::Direct, + _ => { + return Err(Error::config_parse(Box::new(InvalidValue( + "sslnegotiation", + )))) + } + }; + self.ssl_negotiation(mode); + } + "host" => { + for host in value.split(',') { + self.host(host); + } + } + "hostaddr" => { + for hostaddr in value.split(',') { + let addr = hostaddr + .parse() + .map_err(|_| Error::config_parse(Box::new(InvalidValue("hostaddr"))))?; + self.hostaddr(addr); + } + } + "port" => { + for port in value.split(',') { + let port = if port.is_empty() { + 5432 + } else { + port.parse() + .map_err(|_| Error::config_parse(Box::new(InvalidValue("port"))))? + }; + self.port(port); + } + } + "connect_timeout" => { + let timeout = value + .parse::() + .map_err(|_| Error::config_parse(Box::new(InvalidValue("connect_timeout"))))?; + if timeout > 0 { + self.connect_timeout(Duration::from_secs(timeout as u64)); + } + } + "tcp_user_timeout" => { + let timeout = value + .parse::() + .map_err(|_| Error::config_parse(Box::new(InvalidValue("tcp_user_timeout"))))?; + if timeout > 0 { + self.tcp_user_timeout(Duration::from_secs(timeout as u64)); + } + } + #[cfg(not(target_arch = "wasm32"))] + "keepalives" => { + let keepalives = value + .parse::() + .map_err(|_| Error::config_parse(Box::new(InvalidValue("keepalives"))))?; + self.keepalives(keepalives != 0); + } + #[cfg(not(target_arch = "wasm32"))] + "keepalives_idle" => { + let keepalives_idle = value + .parse::() + .map_err(|_| Error::config_parse(Box::new(InvalidValue("keepalives_idle"))))?; + if keepalives_idle > 0 { + self.keepalives_idle(Duration::from_secs(keepalives_idle as u64)); + } + } + #[cfg(not(target_arch = "wasm32"))] + "keepalives_interval" => { + let keepalives_interval = value.parse::().map_err(|_| { + Error::config_parse(Box::new(InvalidValue("keepalives_interval"))) + })?; + if keepalives_interval > 0 { + self.keepalives_interval(Duration::from_secs(keepalives_interval as u64)); + } + } + #[cfg(not(target_arch = "wasm32"))] + "keepalives_retries" => { + let keepalives_retries = value.parse::().map_err(|_| { + Error::config_parse(Box::new(InvalidValue("keepalives_retries"))) + })?; + self.keepalives_retries(keepalives_retries); + } + "target_session_attrs" => { + let target_session_attrs = match value { + "any" => TargetSessionAttrs::Any, + "read-write" => TargetSessionAttrs::ReadWrite, + "read-only" => TargetSessionAttrs::ReadOnly, + _ => { + return Err(Error::config_parse(Box::new(InvalidValue( + "target_session_attrs", + )))); + } + }; + self.target_session_attrs(target_session_attrs); + } + "channel_binding" => { + let channel_binding = match value { + "disable" => ChannelBinding::Disable, + "prefer" => ChannelBinding::Prefer, + "require" => ChannelBinding::Require, + _ => { + return Err(Error::config_parse(Box::new(InvalidValue( + "channel_binding", + )))) + } + }; + self.channel_binding(channel_binding); + } + "load_balance_hosts" => { + let load_balance_hosts = match value { + "disable" => LoadBalanceHosts::Disable, + "random" => LoadBalanceHosts::Random, + _ => { + return Err(Error::config_parse(Box::new(InvalidValue( + "load_balance_hosts", + )))) + } + }; + self.load_balance_hosts(load_balance_hosts); + } + key => { + return Err(Error::config_parse(Box::new(UnknownOption( + key.to_string(), + )))); + } + } + + Ok(()) + } + + /// Opens a connection to a PostgreSQL database. + /// + /// Requires the `runtime` Cargo feature (enabled by default). + #[cfg(feature = "runtime")] + pub async fn connect(&self, tls: T) -> Result<(Client, Connection), Error> + where + T: MakeTlsConnect, + { + connect(tls, self).await + } + + /// Connects to a PostgreSQL database over an arbitrary stream. + /// + /// All of the settings other than `user`, `password`, `dbname`, `options`, and `application_name` name are ignored. + pub async fn connect_raw( + &self, + stream: S, + tls: T, + ) -> Result<(Client, Connection), Error> + where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsConnect, + { + connect_raw(stream, tls, true, self).await + } +} + +impl FromStr for Config { + type Err = Error; + + fn from_str(s: &str) -> Result { + match UrlParser::parse(s)? { + Some(config) => Ok(config), + None => Parser::parse(s), + } + } +} + +// Omit password from debug output +impl fmt::Debug for Config { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + struct Redaction {} + impl fmt::Debug for Redaction { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "_") + } + } + + let mut config_dbg = &mut f.debug_struct("Config"); + config_dbg = config_dbg + .field("user", &self.user) + .field("password", &self.password.as_ref().map(|_| Redaction {})) + .field("dbname", &self.dbname) + .field("options", &self.options) + .field("application_name", &self.application_name) + .field("ssl_mode", &self.ssl_mode) + .field("host", &self.host) + .field("hostaddr", &self.hostaddr) + .field("port", &self.port) + .field("connect_timeout", &self.connect_timeout) + .field("tcp_user_timeout", &self.tcp_user_timeout) + .field("keepalives", &self.keepalives); + + #[cfg(not(target_arch = "wasm32"))] + { + config_dbg = config_dbg + .field("keepalives_idle", &self.keepalive_config.idle) + .field("keepalives_interval", &self.keepalive_config.interval) + .field("keepalives_retries", &self.keepalive_config.retries); + } + + config_dbg + .field("target_session_attrs", &self.target_session_attrs) + .field("channel_binding", &self.channel_binding) + .field("load_balance_hosts", &self.load_balance_hosts) + .finish() + } +} + +#[derive(Debug)] +struct UnknownOption(String); + +impl fmt::Display for UnknownOption { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "unknown option `{}`", self.0) + } +} + +impl error::Error for UnknownOption {} + +#[derive(Debug)] +struct InvalidValue(&'static str); + +impl fmt::Display for InvalidValue { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "invalid value for option `{}`", self.0) + } +} + +impl error::Error for InvalidValue {} + +struct Parser<'a> { + s: &'a str, + it: iter::Peekable>, +} + +impl<'a> Parser<'a> { + fn parse(s: &'a str) -> Result { + let mut parser = Parser { + s, + it: s.char_indices().peekable(), + }; + + let mut config = Config::new(); + + while let Some((key, value)) = parser.parameter()? { + config.param(key, &value)?; + } + + Ok(config) + } + + fn skip_ws(&mut self) { + self.take_while(char::is_whitespace); + } + + fn take_while(&mut self, f: F) -> &'a str + where + F: Fn(char) -> bool, + { + let start = match self.it.peek() { + Some(&(i, _)) => i, + None => return "", + }; + + loop { + match self.it.peek() { + Some(&(_, c)) if f(c) => { + self.it.next(); + } + Some(&(i, _)) => return &self.s[start..i], + None => return &self.s[start..], + } + } + } + + fn eat(&mut self, target: char) -> Result<(), Error> { + match self.it.next() { + Some((_, c)) if c == target => Ok(()), + Some((i, c)) => { + let m = format!( + "unexpected character at byte {}: expected `{}` but got `{}`", + i, target, c + ); + Err(Error::config_parse(m.into())) + } + None => Err(Error::config_parse("unexpected EOF".into())), + } + } + + fn eat_if(&mut self, target: char) -> bool { + match self.it.peek() { + Some(&(_, c)) if c == target => { + self.it.next(); + true + } + _ => false, + } + } + + fn keyword(&mut self) -> Option<&'a str> { + let s = self.take_while(|c| match c { + c if c.is_whitespace() => false, + '=' => false, + _ => true, + }); + + if s.is_empty() { + None + } else { + Some(s) + } + } + + fn value(&mut self) -> Result { + let value = if self.eat_if('\'') { + let value = self.quoted_value()?; + self.eat('\'')?; + value + } else { + self.simple_value()? + }; + + Ok(value) + } + + fn simple_value(&mut self) -> Result { + let mut value = String::new(); + + while let Some(&(_, c)) = self.it.peek() { + if c.is_whitespace() { + break; + } + + self.it.next(); + if c == '\\' { + if let Some((_, c2)) = self.it.next() { + value.push(c2); + } + } else { + value.push(c); + } + } + + if value.is_empty() { + return Err(Error::config_parse("unexpected EOF".into())); + } + + Ok(value) + } + + fn quoted_value(&mut self) -> Result { + let mut value = String::new(); + + while let Some(&(_, c)) = self.it.peek() { + if c == '\'' { + return Ok(value); + } + + self.it.next(); + if c == '\\' { + if let Some((_, c2)) = self.it.next() { + value.push(c2); + } + } else { + value.push(c); + } + } + + Err(Error::config_parse( + "unterminated quoted connection parameter value".into(), + )) + } + + fn parameter(&mut self) -> Result, Error> { + self.skip_ws(); + let keyword = match self.keyword() { + Some(keyword) => keyword, + None => return Ok(None), + }; + self.skip_ws(); + self.eat('=')?; + self.skip_ws(); + let value = self.value()?; + + Ok(Some((keyword, value))) + } +} + +// This is a pretty sloppy "URL" parser, but it matches the behavior of libpq, where things really aren't very strict +struct UrlParser<'a> { + s: &'a str, + config: Config, +} + +impl<'a> UrlParser<'a> { + fn parse(s: &'a str) -> Result, Error> { + let s = match Self::remove_url_prefix(s) { + Some(s) => s, + None => return Ok(None), + }; + + let mut parser = UrlParser { + s, + config: Config::new(), + }; + + parser.parse_credentials()?; + parser.parse_host()?; + parser.parse_path()?; + parser.parse_params()?; + + Ok(Some(parser.config)) + } + + fn remove_url_prefix(s: &str) -> Option<&str> { + for prefix in &["postgres://", "postgresql://"] { + if let Some(stripped) = s.strip_prefix(prefix) { + return Some(stripped); + } + } + + None + } + + fn take_until(&mut self, end: &[char]) -> Option<&'a str> { + match self.s.find(end) { + Some(pos) => { + let (head, tail) = self.s.split_at(pos); + self.s = tail; + Some(head) + } + None => None, + } + } + + fn take_all(&mut self) -> &'a str { + mem::take(&mut self.s) + } + + fn eat_byte(&mut self) { + self.s = &self.s[1..]; + } + + fn parse_credentials(&mut self) -> Result<(), Error> { + let creds = match self.take_until(&['@']) { + Some(creds) => creds, + None => return Ok(()), + }; + self.eat_byte(); + + let mut it = creds.splitn(2, ':'); + let user = self.decode(it.next().unwrap())?; + self.config.user(user); + + if let Some(password) = it.next() { + let password = Cow::from(percent_encoding::percent_decode(password.as_bytes())); + self.config.password(password); + } + + Ok(()) + } + + fn parse_host(&mut self) -> Result<(), Error> { + let host = match self.take_until(&['/', '?']) { + Some(host) => host, + None => self.take_all(), + }; + + if host.is_empty() { + return Ok(()); + } + + for chunk in host.split(',') { + let (host, port) = if chunk.starts_with('[') { + let idx = match chunk.find(']') { + Some(idx) => idx, + None => return Err(Error::config_parse(InvalidValue("host").into())), + }; + + let host = &chunk[1..idx]; + let remaining = &chunk[idx + 1..]; + let port = if let Some(port) = remaining.strip_prefix(':') { + Some(port) + } else if remaining.is_empty() { + None + } else { + return Err(Error::config_parse(InvalidValue("host").into())); + }; + + (host, port) + } else { + let mut it = chunk.splitn(2, ':'); + (it.next().unwrap(), it.next()) + }; + + self.host_param(host)?; + let port = self.decode(port.unwrap_or("5432"))?; + self.config.param("port", &port)?; + } + + Ok(()) + } + + fn parse_path(&mut self) -> Result<(), Error> { + if !self.s.starts_with('/') { + return Ok(()); + } + self.eat_byte(); + + let dbname = match self.take_until(&['?']) { + Some(dbname) => dbname, + None => self.take_all(), + }; + + if !dbname.is_empty() { + self.config.dbname(self.decode(dbname)?); + } + + Ok(()) + } + + fn parse_params(&mut self) -> Result<(), Error> { + if !self.s.starts_with('?') { + return Ok(()); + } + self.eat_byte(); + + while !self.s.is_empty() { + let key = match self.take_until(&['=']) { + Some(key) => self.decode(key)?, + None => return Err(Error::config_parse("unterminated parameter".into())), + }; + self.eat_byte(); + + let value = match self.take_until(&['&']) { + Some(value) => { + self.eat_byte(); + value + } + None => self.take_all(), + }; + + if key == "host" { + self.host_param(value)?; + } else { + let value = self.decode(value)?; + self.config.param(&key, &value)?; + } + } + + Ok(()) + } + + #[cfg(unix)] + fn host_param(&mut self, s: &str) -> Result<(), Error> { + let decoded = Cow::from(percent_encoding::percent_decode(s.as_bytes())); + if decoded.first() == Some(&b'/') { + self.config.host_path(OsStr::from_bytes(&decoded)); + } else { + let decoded = str::from_utf8(&decoded).map_err(|e| Error::config_parse(Box::new(e)))?; + self.config.host(decoded); + } + + Ok(()) + } + + #[cfg(not(unix))] + fn host_param(&mut self, s: &str) -> Result<(), Error> { + let s = self.decode(s)?; + self.config.param("host", &s) + } + + fn decode(&self, s: &'a str) -> Result, Error> { + percent_encoding::percent_decode(s.as_bytes()) + .decode_utf8() + .map_err(|e| Error::config_parse(e.into())) + } +} + +#[cfg(test)] +mod tests { + use std::net::IpAddr; + + use crate::{config::Host, Config}; + + #[test] + fn test_simple_parsing() { + let s = "user=pass_user dbname=postgres host=host1,host2 hostaddr=127.0.0.1,127.0.0.2 port=26257"; + let config = s.parse::().unwrap(); + assert_eq!(Some("pass_user"), config.get_user()); + assert_eq!(Some("postgres"), config.get_dbname()); + assert_eq!( + [ + Host::Tcp("host1".to_string()), + Host::Tcp("host2".to_string()) + ], + config.get_hosts(), + ); + + assert_eq!( + [ + "127.0.0.1".parse::().unwrap(), + "127.0.0.2".parse::().unwrap() + ], + config.get_hostaddrs(), + ); + + assert_eq!(1, 1); + } + + #[test] + fn test_invalid_hostaddr_parsing() { + let s = "user=pass_user dbname=postgres host=host1 hostaddr=127.0.0 port=26257"; + s.parse::().err().unwrap(); + } +} diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs new file mode 100644 index 000000000..e97a7a2a3 --- /dev/null +++ b/tokio-postgres/src/connect.rs @@ -0,0 +1,227 @@ +use crate::client::{Addr, SocketConfig}; +use crate::config::{Host, LoadBalanceHosts, TargetSessionAttrs}; +use crate::connect_raw::connect_raw; +use crate::connect_socket::connect_socket; +use crate::tls::MakeTlsConnect; +use crate::{Client, Config, Connection, Error, SimpleQueryMessage, Socket}; +use futures_util::{future, pin_mut, Future, FutureExt, Stream}; +use rand::seq::SliceRandom; +use std::task::Poll; +use std::{cmp, io}; +use tokio::net; + +pub async fn connect( + mut tls: T, + config: &Config, +) -> Result<(Client, Connection), Error> +where + T: MakeTlsConnect, +{ + if config.host.is_empty() && config.hostaddr.is_empty() { + return Err(Error::config("both host and hostaddr are missing".into())); + } + + if !config.host.is_empty() + && !config.hostaddr.is_empty() + && config.host.len() != config.hostaddr.len() + { + let msg = format!( + "number of hosts ({}) is different from number of hostaddrs ({})", + config.host.len(), + config.hostaddr.len(), + ); + return Err(Error::config(msg.into())); + } + + // At this point, either one of the following two scenarios could happen: + // (1) either config.host or config.hostaddr must be empty; + // (2) if both config.host and config.hostaddr are NOT empty; their lengths must be equal. + let num_hosts = cmp::max(config.host.len(), config.hostaddr.len()); + + if config.port.len() > 1 && config.port.len() != num_hosts { + return Err(Error::config("invalid number of ports".into())); + } + + let mut indices = (0..num_hosts).collect::>(); + if config.load_balance_hosts == LoadBalanceHosts::Random { + indices.shuffle(&mut rand::rng()); + } + + let mut error = None; + for i in indices { + let host = config.host.get(i); + let hostaddr = config.hostaddr.get(i); + let port = config + .port + .get(i) + .or_else(|| config.port.first()) + .copied() + .unwrap_or(5432); + + // The value of host is used as the hostname for TLS validation, + let hostname = match host { + Some(Host::Tcp(host)) => Some(host.clone()), + // postgres doesn't support TLS over unix sockets, so the choice here doesn't matter + #[cfg(unix)] + Some(Host::Unix(_)) => None, + None => None, + }; + + // Try to use the value of hostaddr to establish the TCP connection, + // fallback to host if hostaddr is not present. + let addr = match hostaddr { + Some(ipaddr) => Host::Tcp(ipaddr.to_string()), + None => host.cloned().unwrap(), + }; + + match connect_host(addr, hostname, port, &mut tls, config).await { + Ok((client, connection)) => return Ok((client, connection)), + Err(e) => error = Some(e), + } + } + + Err(error.unwrap()) +} + +async fn connect_host( + host: Host, + hostname: Option, + port: u16, + tls: &mut T, + config: &Config, +) -> Result<(Client, Connection), Error> +where + T: MakeTlsConnect, +{ + match host { + Host::Tcp(host) => { + let mut addrs = net::lookup_host((&*host, port)) + .await + .map_err(Error::connect)? + .collect::>(); + + if config.load_balance_hosts == LoadBalanceHosts::Random { + addrs.shuffle(&mut rand::rng()); + } + + let mut last_err = None; + for addr in addrs { + match connect_once(Addr::Tcp(addr.ip()), hostname.as_deref(), port, tls, config) + .await + { + Ok(stream) => return Ok(stream), + Err(e) => { + last_err = Some(e); + continue; + } + }; + } + + Err(last_err.unwrap_or_else(|| { + Error::connect(io::Error::new( + io::ErrorKind::InvalidInput, + "could not resolve any addresses", + )) + })) + } + #[cfg(unix)] + Host::Unix(path) => { + connect_once(Addr::Unix(path), hostname.as_deref(), port, tls, config).await + } + } +} + +async fn connect_once( + addr: Addr, + hostname: Option<&str>, + port: u16, + tls: &mut T, + config: &Config, +) -> Result<(Client, Connection), Error> +where + T: MakeTlsConnect, +{ + let socket = connect_socket( + &addr, + port, + config.connect_timeout, + config.tcp_user_timeout, + if config.keepalives { + Some(&config.keepalive_config) + } else { + None + }, + ) + .await?; + + let tls = tls + .make_tls_connect(hostname.unwrap_or("")) + .map_err(|e| Error::tls(e.into()))?; + let has_hostname = hostname.is_some(); + let (mut client, mut connection) = connect_raw(socket, tls, has_hostname, config).await?; + + if config.target_session_attrs != TargetSessionAttrs::Any { + let rows = client.simple_query_raw("SHOW transaction_read_only"); + pin_mut!(rows); + + let rows = future::poll_fn(|cx| { + if connection.poll_unpin(cx)?.is_ready() { + return Poll::Ready(Err(Error::closed())); + } + + rows.as_mut().poll(cx) + }) + .await?; + pin_mut!(rows); + + loop { + let next = future::poll_fn(|cx| { + if connection.poll_unpin(cx)?.is_ready() { + return Poll::Ready(Some(Err(Error::closed()))); + } + + rows.as_mut().poll_next(cx) + }); + + match next.await.transpose()? { + Some(SimpleQueryMessage::Row(row)) => { + let read_only_result = row.try_get(0)?; + if read_only_result == Some("on") + && config.target_session_attrs == TargetSessionAttrs::ReadWrite + { + return Err(Error::connect(io::Error::new( + io::ErrorKind::PermissionDenied, + "database does not allow writes", + ))); + } else if read_only_result == Some("off") + && config.target_session_attrs == TargetSessionAttrs::ReadOnly + { + return Err(Error::connect(io::Error::new( + io::ErrorKind::PermissionDenied, + "database is not read only", + ))); + } else { + break; + } + } + Some(_) => {} + None => return Err(Error::unexpected_message()), + } + } + } + + client.set_socket_config(SocketConfig { + addr, + hostname: hostname.map(|s| s.to_string()), + port, + connect_timeout: config.connect_timeout, + tcp_user_timeout: config.tcp_user_timeout, + keepalive: if config.keepalives { + Some(config.keepalive_config.clone()) + } else { + None + }, + }); + + Ok((client, connection)) +} diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs new file mode 100644 index 000000000..cf7476cab --- /dev/null +++ b/tokio-postgres/src/connect_raw.rs @@ -0,0 +1,368 @@ +use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; +use crate::config::{self, Config}; +use crate::connect_tls::connect_tls; +use crate::maybe_tls_stream::MaybeTlsStream; +use crate::tls::{TlsConnect, TlsStream}; +use crate::{Client, Connection, Error}; +use bytes::BytesMut; +use fallible_iterator::FallibleIterator; +use futures_channel::mpsc; +use futures_util::{ready, Sink, SinkExt, Stream, TryStreamExt}; +use postgres_protocol::authentication; +use postgres_protocol::authentication::sasl; +use postgres_protocol::authentication::sasl::ScramSha256; +use postgres_protocol::message::backend::{AuthenticationSaslBody, Message}; +use postgres_protocol::message::frontend; +use std::borrow::Cow; +use std::collections::{HashMap, VecDeque}; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::codec::Framed; + +pub struct StartupStream { + inner: Framed, PostgresCodec>, + buf: BackendMessages, + delayed: VecDeque, +} + +impl Sink for StartupStream +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + type Error = io::Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_ready(cx) + } + + fn start_send(mut self: Pin<&mut Self>, item: FrontendMessage) -> io::Result<()> { + Pin::new(&mut self.inner).start_send(item) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_close(cx) + } +} + +impl Stream for StartupStream +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + type Item = io::Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + loop { + match self.buf.next() { + Ok(Some(message)) => return Poll::Ready(Some(Ok(message))), + Ok(None) => {} + Err(e) => return Poll::Ready(Some(Err(e))), + } + + match ready!(Pin::new(&mut self.inner).poll_next(cx)) { + Some(Ok(BackendMessage::Normal { messages, .. })) => self.buf = messages, + Some(Ok(BackendMessage::Async(message))) => return Poll::Ready(Some(Ok(message))), + Some(Err(e)) => return Poll::Ready(Some(Err(e))), + None => return Poll::Ready(None), + } + } + } +} + +pub async fn connect_raw( + stream: S, + tls: T, + has_hostname: bool, + config: &Config, +) -> Result<(Client, Connection), Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsConnect, +{ + let stream = connect_tls( + stream, + config.ssl_mode, + config.ssl_negotiation, + tls, + has_hostname, + ) + .await?; + + let mut stream = StartupStream { + inner: Framed::new(stream, PostgresCodec), + buf: BackendMessages::empty(), + delayed: VecDeque::new(), + }; + + let user = config + .user + .as_deref() + .map_or_else(|| Cow::Owned(whoami::username()), Cow::Borrowed); + + startup(&mut stream, config, &user).await?; + authenticate(&mut stream, config, &user).await?; + let (process_id, secret_key, parameters) = read_info(&mut stream).await?; + + let (sender, receiver) = mpsc::unbounded(); + let client = Client::new( + sender, + config.ssl_mode, + config.ssl_negotiation, + process_id, + secret_key, + ); + let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver); + + Ok((client, connection)) +} + +async fn startup( + stream: &mut StartupStream, + config: &Config, + user: &str, +) -> Result<(), Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + let mut params = vec![("client_encoding", "UTF8")]; + params.push(("user", user)); + if let Some(dbname) = &config.dbname { + params.push(("database", &**dbname)); + } + if let Some(options) = &config.options { + params.push(("options", &**options)); + } + if let Some(application_name) = &config.application_name { + params.push(("application_name", &**application_name)); + } + + let mut buf = BytesMut::new(); + frontend::startup_message(params, &mut buf).map_err(Error::encode)?; + + stream + .send(FrontendMessage::Raw(buf.freeze())) + .await + .map_err(Error::io) +} + +async fn authenticate( + stream: &mut StartupStream, + config: &Config, + user: &str, +) -> Result<(), Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsStream + Unpin, +{ + match stream.try_next().await.map_err(Error::io)? { + Some(Message::AuthenticationOk) => { + can_skip_channel_binding(config)?; + return Ok(()); + } + Some(Message::AuthenticationCleartextPassword) => { + can_skip_channel_binding(config)?; + + let pass = config + .password + .as_ref() + .ok_or_else(|| Error::config("password missing".into()))?; + + authenticate_password(stream, pass).await?; + } + Some(Message::AuthenticationMd5Password(body)) => { + can_skip_channel_binding(config)?; + + let pass = config + .password + .as_ref() + .ok_or_else(|| Error::config("password missing".into()))?; + + let output = authentication::md5_hash(user.as_bytes(), pass, body.salt()); + authenticate_password(stream, output.as_bytes()).await?; + } + Some(Message::AuthenticationSasl(body)) => { + authenticate_sasl(stream, body, config).await?; + } + Some(Message::AuthenticationKerberosV5) + | Some(Message::AuthenticationScmCredential) + | Some(Message::AuthenticationGss) + | Some(Message::AuthenticationSspi) => { + return Err(Error::authentication( + "unsupported authentication method".into(), + )) + } + Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), + Some(_) => return Err(Error::unexpected_message()), + None => return Err(Error::closed()), + } + + match stream.try_next().await.map_err(Error::io)? { + Some(Message::AuthenticationOk) => Ok(()), + Some(Message::ErrorResponse(body)) => Err(Error::db(body)), + Some(_) => Err(Error::unexpected_message()), + None => Err(Error::closed()), + } +} + +fn can_skip_channel_binding(config: &Config) -> Result<(), Error> { + match config.channel_binding { + config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()), + config::ChannelBinding::Require => Err(Error::authentication( + "server did not use channel binding".into(), + )), + } +} + +async fn authenticate_password( + stream: &mut StartupStream, + password: &[u8], +) -> Result<(), Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + let mut buf = BytesMut::new(); + frontend::password_message(password, &mut buf).map_err(Error::encode)?; + + stream + .send(FrontendMessage::Raw(buf.freeze())) + .await + .map_err(Error::io) +} + +async fn authenticate_sasl( + stream: &mut StartupStream, + body: AuthenticationSaslBody, + config: &Config, +) -> Result<(), Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsStream + Unpin, +{ + let password = config + .password + .as_ref() + .ok_or_else(|| Error::config("password missing".into()))?; + + let mut has_scram = false; + let mut has_scram_plus = false; + let mut mechanisms = body.mechanisms(); + while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? { + match mechanism { + sasl::SCRAM_SHA_256 => has_scram = true, + sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true, + _ => {} + } + } + + let channel_binding = stream + .inner + .get_ref() + .channel_binding() + .tls_server_end_point + .filter(|_| config.channel_binding != config::ChannelBinding::Disable) + .map(sasl::ChannelBinding::tls_server_end_point); + + let (channel_binding, mechanism) = if has_scram_plus { + match channel_binding { + Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS), + None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256), + } + } else if has_scram { + match channel_binding { + Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256), + None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256), + } + } else { + return Err(Error::authentication("unsupported SASL mechanism".into())); + }; + + if mechanism != sasl::SCRAM_SHA_256_PLUS { + can_skip_channel_binding(config)?; + } + + let mut scram = ScramSha256::new(password, channel_binding); + + let mut buf = BytesMut::new(); + frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?; + stream + .send(FrontendMessage::Raw(buf.freeze())) + .await + .map_err(Error::io)?; + + let body = match stream.try_next().await.map_err(Error::io)? { + Some(Message::AuthenticationSaslContinue(body)) => body, + Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), + Some(_) => return Err(Error::unexpected_message()), + None => return Err(Error::closed()), + }; + + scram + .update(body.data()) + .map_err(|e| Error::authentication(e.into()))?; + + let mut buf = BytesMut::new(); + frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?; + stream + .send(FrontendMessage::Raw(buf.freeze())) + .await + .map_err(Error::io)?; + + let body = match stream.try_next().await.map_err(Error::io)? { + Some(Message::AuthenticationSaslFinal(body)) => body, + Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), + Some(_) => return Err(Error::unexpected_message()), + None => return Err(Error::closed()), + }; + + scram + .finish(body.data()) + .map_err(|e| Error::authentication(e.into()))?; + + Ok(()) +} + +async fn read_info( + stream: &mut StartupStream, +) -> Result<(i32, i32, HashMap), Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + let mut process_id = 0; + let mut secret_key = 0; + let mut parameters = HashMap::new(); + + loop { + match stream.try_next().await.map_err(Error::io)? { + Some(Message::BackendKeyData(body)) => { + process_id = body.process_id(); + secret_key = body.secret_key(); + } + Some(Message::ParameterStatus(body)) => { + parameters.insert( + body.name().map_err(Error::parse)?.to_string(), + body.value().map_err(Error::parse)?.to_string(), + ); + } + Some(msg @ Message::NoticeResponse(_)) => { + stream.delayed.push_back(BackendMessage::Async(msg)) + } + Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)), + Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), + Some(_) => return Err(Error::unexpected_message()), + None => return Err(Error::closed()), + } + } +} diff --git a/tokio-postgres/src/connect_socket.rs b/tokio-postgres/src/connect_socket.rs new file mode 100644 index 000000000..26184701f --- /dev/null +++ b/tokio-postgres/src/connect_socket.rs @@ -0,0 +1,73 @@ +use crate::client::Addr; +use crate::keepalive::KeepaliveConfig; +use crate::{Error, Socket}; +use socket2::{SockRef, TcpKeepalive}; +use std::future::Future; +use std::io; +use std::time::Duration; +use tokio::net::TcpStream; +#[cfg(unix)] +use tokio::net::UnixStream; +use tokio::time; + +pub(crate) async fn connect_socket( + addr: &Addr, + port: u16, + connect_timeout: Option, + #[cfg_attr(not(target_os = "linux"), allow(unused_variables))] tcp_user_timeout: Option< + Duration, + >, + keepalive_config: Option<&KeepaliveConfig>, +) -> Result { + match addr { + Addr::Tcp(ip) => { + let stream = + connect_with_timeout(TcpStream::connect((*ip, port)), connect_timeout).await?; + + stream.set_nodelay(true).map_err(Error::connect)?; + + let sock_ref = SockRef::from(&stream); + + #[cfg(target_os = "linux")] + if let Some(tcp_user_timeout) = tcp_user_timeout { + sock_ref + .set_tcp_user_timeout(Some(tcp_user_timeout)) + .map_err(Error::connect)?; + } + + if let Some(keepalive_config) = keepalive_config { + sock_ref + .set_tcp_keepalive(&TcpKeepalive::from(keepalive_config)) + .map_err(Error::connect)?; + } + + Ok(Socket::new_tcp(stream)) + } + #[cfg(unix)] + Addr::Unix(dir) => { + let path = dir.join(format!(".s.PGSQL.{}", port)); + let socket = connect_with_timeout(UnixStream::connect(path), connect_timeout).await?; + Ok(Socket::new_unix(socket)) + } + } +} + +async fn connect_with_timeout(connect: F, timeout: Option) -> Result +where + F: Future>, +{ + match timeout { + Some(timeout) => match time::timeout(timeout, connect).await { + Ok(Ok(socket)) => Ok(socket), + Ok(Err(e)) => Err(Error::connect(e)), + Err(_) => Err(Error::connect(io::Error::new( + io::ErrorKind::TimedOut, + "connection timed out", + ))), + }, + None => match connect.await { + Ok(socket) => Ok(socket), + Err(e) => Err(Error::connect(e)), + }, + } +} diff --git a/tokio-postgres/src/connect_tls.rs b/tokio-postgres/src/connect_tls.rs new file mode 100644 index 000000000..d220cd3b5 --- /dev/null +++ b/tokio-postgres/src/connect_tls.rs @@ -0,0 +1,59 @@ +use crate::config::{SslMode, SslNegotiation}; +use crate::maybe_tls_stream::MaybeTlsStream; +use crate::tls::private::ForcePrivateApi; +use crate::tls::TlsConnect; +use crate::Error; +use bytes::BytesMut; +use postgres_protocol::message::frontend; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +pub async fn connect_tls( + mut stream: S, + mode: SslMode, + negotiation: SslNegotiation, + tls: T, + has_hostname: bool, +) -> Result, Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsConnect, +{ + match mode { + SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)), + SslMode::Prefer if !tls.can_connect(ForcePrivateApi) => { + return Ok(MaybeTlsStream::Raw(stream)) + } + SslMode::Prefer if negotiation == SslNegotiation::Direct => { + return Err(Error::tls("weak sslmode \"prefer\" may not be used with sslnegotiation=direct (use \"require\", \"verify-ca\", or \"verify-full\")".into())) + } + SslMode::Prefer | SslMode::Require => {} + } + + if negotiation == SslNegotiation::Postgres { + let mut buf = BytesMut::new(); + frontend::ssl_request(&mut buf); + stream.write_all(&buf).await.map_err(Error::io)?; + + let mut buf = [0]; + stream.read_exact(&mut buf).await.map_err(Error::io)?; + + if buf[0] != b'S' { + if SslMode::Require == mode { + return Err(Error::tls("server does not support TLS".into())); + } else { + return Ok(MaybeTlsStream::Raw(stream)); + } + } + } + + if !has_hostname { + return Err(Error::tls("no hostname provided for TLS handshake".into())); + } + + let stream = tls + .connect(stream) + .await + .map_err(|e| Error::tls(e.into()))?; + + Ok(MaybeTlsStream::Tls(stream)) +} diff --git a/tokio-postgres/src/connection.rs b/tokio-postgres/src/connection.rs new file mode 100644 index 000000000..414335955 --- /dev/null +++ b/tokio-postgres/src/connection.rs @@ -0,0 +1,343 @@ +use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; +use crate::copy_in::CopyInReceiver; +use crate::error::DbError; +use crate::maybe_tls_stream::MaybeTlsStream; +use crate::{AsyncMessage, Error, Notification}; +use bytes::BytesMut; +use fallible_iterator::FallibleIterator; +use futures_channel::mpsc; +use futures_util::{ready, stream::FusedStream, Sink, Stream, StreamExt}; +use log::{info, trace}; +use postgres_protocol::message::backend::Message; +use postgres_protocol::message::frontend; +use std::collections::{HashMap, VecDeque}; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_util::codec::Framed; + +pub enum RequestMessages { + Single(FrontendMessage), + CopyIn(CopyInReceiver), +} + +pub struct Request { + pub messages: RequestMessages, + pub sender: mpsc::Sender, +} + +pub struct Response { + sender: mpsc::Sender, +} + +#[derive(PartialEq, Debug)] +enum State { + Active, + Terminating, + Closing, +} + +/// A connection to a PostgreSQL database. +/// +/// This is one half of what is returned when a new connection is established. It performs the actual IO with the +/// server, and should generally be spawned off onto an executor to run in the background. +/// +/// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has +/// occurred, or because its associated `Client` has dropped and all outstanding work has completed. +#[must_use = "futures do nothing unless polled"] +pub struct Connection { + stream: Framed, PostgresCodec>, + parameters: HashMap, + receiver: mpsc::UnboundedReceiver, + pending_request: Option, + pending_responses: VecDeque, + responses: VecDeque, + state: State, +} + +impl Connection +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + pub(crate) fn new( + stream: Framed, PostgresCodec>, + pending_responses: VecDeque, + parameters: HashMap, + receiver: mpsc::UnboundedReceiver, + ) -> Connection { + Connection { + stream, + parameters, + receiver, + pending_request: None, + pending_responses, + responses: VecDeque::new(), + state: State::Active, + } + } + + fn poll_response( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + if let Some(message) = self.pending_responses.pop_front() { + trace!("retrying pending response"); + return Poll::Ready(Some(Ok(message))); + } + + Pin::new(&mut self.stream) + .poll_next(cx) + .map(|o| o.map(|r| r.map_err(Error::io))) + } + + fn poll_read(&mut self, cx: &mut Context<'_>) -> Result, Error> { + if self.state != State::Active { + trace!("poll_read: done"); + return Ok(None); + } + + loop { + let message = match self.poll_response(cx)? { + Poll::Ready(Some(message)) => message, + Poll::Ready(None) => return Err(Error::closed()), + Poll::Pending => { + trace!("poll_read: waiting on response"); + return Ok(None); + } + }; + + let (mut messages, request_complete) = match message { + BackendMessage::Async(Message::NoticeResponse(body)) => { + let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?; + return Ok(Some(AsyncMessage::Notice(error))); + } + BackendMessage::Async(Message::NotificationResponse(body)) => { + let notification = Notification { + process_id: body.process_id(), + channel: body.channel().map_err(Error::parse)?.to_string(), + payload: body.message().map_err(Error::parse)?.to_string(), + }; + return Ok(Some(AsyncMessage::Notification(notification))); + } + BackendMessage::Async(Message::ParameterStatus(body)) => { + self.parameters.insert( + body.name().map_err(Error::parse)?.to_string(), + body.value().map_err(Error::parse)?.to_string(), + ); + continue; + } + BackendMessage::Async(_) => unreachable!(), + BackendMessage::Normal { + messages, + request_complete, + } => (messages, request_complete), + }; + + let mut response = match self.responses.pop_front() { + Some(response) => response, + None => match messages.next().map_err(Error::parse)? { + Some(Message::ErrorResponse(error)) => return Err(Error::db(error)), + _ => return Err(Error::unexpected_message()), + }, + }; + + match response.sender.poll_ready(cx) { + Poll::Ready(Ok(())) => { + let _ = response.sender.start_send(messages); + if !request_complete { + self.responses.push_front(response); + } + } + Poll::Ready(Err(_)) => { + // we need to keep paging through the rest of the messages even if the receiver's hung up + if !request_complete { + self.responses.push_front(response); + } + } + Poll::Pending => { + self.responses.push_front(response); + self.pending_responses.push_back(BackendMessage::Normal { + messages, + request_complete, + }); + trace!("poll_read: waiting on sender"); + return Ok(None); + } + } + } + } + + fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll> { + if let Some(messages) = self.pending_request.take() { + trace!("retrying pending request"); + return Poll::Ready(Some(messages)); + } + + if self.receiver.is_terminated() { + return Poll::Ready(None); + } + + match self.receiver.poll_next_unpin(cx) { + Poll::Ready(Some(request)) => { + trace!("polled new request"); + self.responses.push_back(Response { + sender: request.sender, + }); + Poll::Ready(Some(request.messages)) + } + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } + + fn poll_write(&mut self, cx: &mut Context<'_>) -> Result { + loop { + if self.state == State::Closing { + trace!("poll_write: done"); + return Ok(false); + } + + if Pin::new(&mut self.stream) + .poll_ready(cx) + .map_err(Error::io)? + .is_pending() + { + trace!("poll_write: waiting on socket"); + return Ok(false); + } + + let request = match self.poll_request(cx) { + Poll::Ready(Some(request)) => request, + Poll::Ready(None) if self.responses.is_empty() && self.state == State::Active => { + trace!("poll_write: at eof, terminating"); + self.state = State::Terminating; + let mut request = BytesMut::new(); + frontend::terminate(&mut request); + RequestMessages::Single(FrontendMessage::Raw(request.freeze())) + } + Poll::Ready(None) => { + trace!( + "poll_write: at eof, pending responses {}", + self.responses.len() + ); + return Ok(true); + } + Poll::Pending => { + trace!("poll_write: waiting on request"); + return Ok(true); + } + }; + + match request { + RequestMessages::Single(request) => { + Pin::new(&mut self.stream) + .start_send(request) + .map_err(Error::io)?; + if self.state == State::Terminating { + trace!("poll_write: sent eof, closing"); + self.state = State::Closing; + } + } + RequestMessages::CopyIn(mut receiver) => { + let message = match receiver.poll_next_unpin(cx) { + Poll::Ready(Some(message)) => message, + Poll::Ready(None) => { + trace!("poll_write: finished copy_in request"); + continue; + } + Poll::Pending => { + trace!("poll_write: waiting on copy_in stream"); + self.pending_request = Some(RequestMessages::CopyIn(receiver)); + return Ok(true); + } + }; + Pin::new(&mut self.stream) + .start_send(message) + .map_err(Error::io)?; + self.pending_request = Some(RequestMessages::CopyIn(receiver)); + } + } + } + } + + fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<(), Error> { + match Pin::new(&mut self.stream) + .poll_flush(cx) + .map_err(Error::io)? + { + Poll::Ready(()) => trace!("poll_flush: flushed"), + Poll::Pending => trace!("poll_flush: waiting on socket"), + } + Ok(()) + } + + fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.state != State::Closing { + return Poll::Pending; + } + + match Pin::new(&mut self.stream) + .poll_close(cx) + .map_err(Error::io)? + { + Poll::Ready(()) => { + trace!("poll_shutdown: complete"); + Poll::Ready(Ok(())) + } + Poll::Pending => { + trace!("poll_shutdown: waiting on socket"); + Poll::Pending + } + } + } + + /// Returns the value of a runtime parameter for this connection. + pub fn parameter(&self, name: &str) -> Option<&str> { + self.parameters.get(name).map(|s| &**s) + } + + /// Polls for asynchronous messages from the server. + /// + /// The server can send notices as well as notifications asynchronously to the client. Applications that wish to + /// examine those messages should use this method to drive the connection rather than its `Future` implementation. + /// + /// Return values of `None` or `Some(Err(_))` are "terminal"; callers should not invoke this method again after + /// receiving one of those values. + pub fn poll_message( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + let message = self.poll_read(cx)?; + let want_flush = self.poll_write(cx)?; + if want_flush { + self.poll_flush(cx)?; + } + match message { + Some(message) => Poll::Ready(Some(Ok(message))), + None => match self.poll_shutdown(cx) { + Poll::Ready(Ok(())) => Poll::Ready(None), + Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))), + Poll::Pending => Poll::Pending, + }, + } + } +} + +impl Future for Connection +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + type Output = Result<(), Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + while let Some(message) = ready!(self.poll_message(cx)?) { + if let AsyncMessage::Notice(notice) = message { + info!("{}: {}", notice.severity(), notice.message()); + } + } + Poll::Ready(Ok(())) + } +} diff --git a/tokio-postgres/src/copy_in.rs b/tokio-postgres/src/copy_in.rs new file mode 100644 index 000000000..59e31fea6 --- /dev/null +++ b/tokio-postgres/src/copy_in.rs @@ -0,0 +1,226 @@ +use crate::client::{InnerClient, Responses}; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::query::extract_row_affected; +use crate::{query, slice_iter, Error, Statement}; +use bytes::{Buf, BufMut, BytesMut}; +use futures_channel::mpsc; +use futures_util::{future, ready, Sink, SinkExt, Stream, StreamExt}; +use log::debug; +use pin_project_lite::pin_project; +use postgres_protocol::message::backend::Message; +use postgres_protocol::message::frontend; +use postgres_protocol::message::frontend::CopyData; +use std::marker::{PhantomData, PhantomPinned}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +enum CopyInMessage { + Message(FrontendMessage), + Done, +} + +pub struct CopyInReceiver { + receiver: mpsc::Receiver, + done: bool, +} + +impl CopyInReceiver { + fn new(receiver: mpsc::Receiver) -> CopyInReceiver { + CopyInReceiver { + receiver, + done: false, + } + } +} + +impl Stream for CopyInReceiver { + type Item = FrontendMessage; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.done { + return Poll::Ready(None); + } + + match ready!(self.receiver.poll_next_unpin(cx)) { + Some(CopyInMessage::Message(message)) => Poll::Ready(Some(message)), + Some(CopyInMessage::Done) => { + self.done = true; + let mut buf = BytesMut::new(); + frontend::copy_done(&mut buf); + frontend::sync(&mut buf); + Poll::Ready(Some(FrontendMessage::Raw(buf.freeze()))) + } + None => { + self.done = true; + let mut buf = BytesMut::new(); + frontend::copy_fail("", &mut buf).unwrap(); + frontend::sync(&mut buf); + Poll::Ready(Some(FrontendMessage::Raw(buf.freeze()))) + } + } + } +} + +enum SinkState { + Active, + Closing, + Reading, +} + +pin_project! { + /// A sink for `COPY ... FROM STDIN` query data. + /// + /// The copy *must* be explicitly completed via the `Sink::close` or `finish` methods. If it is + /// not, the copy will be aborted. + pub struct CopyInSink { + #[pin] + sender: mpsc::Sender, + responses: Responses, + buf: BytesMut, + state: SinkState, + #[pin] + _p: PhantomPinned, + _p2: PhantomData, + } +} + +impl CopyInSink +where + T: Buf + 'static + Send, +{ + /// A poll-based version of `finish`. + pub fn poll_finish(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match self.state { + SinkState::Active => { + ready!(self.as_mut().poll_flush(cx))?; + let mut this = self.as_mut().project(); + ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?; + this.sender + .start_send(CopyInMessage::Done) + .map_err(|_| Error::closed())?; + *this.state = SinkState::Closing; + } + SinkState::Closing => { + let this = self.as_mut().project(); + ready!(this.sender.poll_close(cx)).map_err(|_| Error::closed())?; + *this.state = SinkState::Reading; + } + SinkState::Reading => { + let this = self.as_mut().project(); + match ready!(this.responses.poll_next(cx))? { + Message::CommandComplete(body) => { + let rows = extract_row_affected(&body)?; + return Poll::Ready(Ok(rows)); + } + _ => return Poll::Ready(Err(Error::unexpected_message())), + } + } + } + } + } + + /// Completes the copy, returning the number of rows inserted. + /// + /// The `Sink::close` method is equivalent to `finish`, except that it does not return the + /// number of rows. + pub async fn finish(mut self: Pin<&mut Self>) -> Result { + future::poll_fn(|cx| self.as_mut().poll_finish(cx)).await + } +} + +impl Sink for CopyInSink +where + T: Buf + 'static + Send, +{ + type Error = Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.project() + .sender + .poll_ready(cx) + .map_err(|_| Error::closed()) + } + + fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Error> { + let this = self.project(); + + let data: Box = if item.remaining() > 4096 { + if this.buf.is_empty() { + Box::new(item) + } else { + Box::new(this.buf.split().freeze().chain(item)) + } + } else { + this.buf.put(item); + if this.buf.len() > 4096 { + Box::new(this.buf.split().freeze()) + } else { + return Ok(()); + } + }; + + let data = CopyData::new(data).map_err(Error::encode)?; + this.sender + .start_send(CopyInMessage::Message(FrontendMessage::CopyData(data))) + .map_err(|_| Error::closed()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + + if !this.buf.is_empty() { + ready!(this.sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?; + let data: Box = Box::new(this.buf.split().freeze()); + let data = CopyData::new(data).map_err(Error::encode)?; + this.sender + .as_mut() + .start_send(CopyInMessage::Message(FrontendMessage::CopyData(data))) + .map_err(|_| Error::closed())?; + } + + this.sender.poll_flush(cx).map_err(|_| Error::closed()) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_finish(cx).map_ok(|_| ()) + } +} + +pub async fn copy_in(client: &InnerClient, statement: Statement) -> Result, Error> +where + T: Buf + 'static + Send, +{ + debug!("executing copy in statement {}", statement.name()); + + let buf = query::encode(client, &statement, slice_iter(&[]))?; + + let (mut sender, receiver) = mpsc::channel(1); + let receiver = CopyInReceiver::new(receiver); + let mut responses = client.send(RequestMessages::CopyIn(receiver))?; + + sender + .send(CopyInMessage::Message(FrontendMessage::Raw(buf))) + .await + .map_err(|_| Error::closed())?; + + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } + + match responses.next().await? { + Message::CopyInResponse(_) => {} + _ => return Err(Error::unexpected_message()), + } + + Ok(CopyInSink { + sender, + responses, + buf: BytesMut::new(), + state: SinkState::Active, + _p: PhantomPinned, + _p2: PhantomData, + }) +} diff --git a/tokio-postgres/src/copy_out.rs b/tokio-postgres/src/copy_out.rs new file mode 100644 index 000000000..1e6949252 --- /dev/null +++ b/tokio-postgres/src/copy_out.rs @@ -0,0 +1,62 @@ +use crate::client::{InnerClient, Responses}; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::{query, slice_iter, Error, Statement}; +use bytes::Bytes; +use futures_util::{ready, Stream}; +use log::debug; +use pin_project_lite::pin_project; +use postgres_protocol::message::backend::Message; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pub async fn copy_out(client: &InnerClient, statement: Statement) -> Result { + debug!("executing copy out statement {}", statement.name()); + + let buf = query::encode(client, &statement, slice_iter(&[]))?; + let responses = start(client, buf).await?; + Ok(CopyOutStream { + responses, + _p: PhantomPinned, + }) +} + +async fn start(client: &InnerClient, buf: Bytes) -> Result { + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } + + match responses.next().await? { + Message::CopyOutResponse(_) => {} + _ => return Err(Error::unexpected_message()), + } + + Ok(responses) +} + +pin_project! { + /// A stream of `COPY ... TO STDOUT` query data. + pub struct CopyOutStream { + responses: Responses, + #[pin] + _p: PhantomPinned, + } +} + +impl Stream for CopyOutStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + + match ready!(this.responses.poll_next(cx)?) { + Message::CopyData(body) => Poll::Ready(Some(Ok(body.into_bytes()))), + Message::CopyDone => Poll::Ready(None), + _ => Poll::Ready(Some(Err(Error::unexpected_message()))), + } + } +} diff --git a/tokio-postgres/src/error/mod.rs b/tokio-postgres/src/error/mod.rs index de35c9b1a..75664d258 100644 --- a/tokio-postgres/src/error/mod.rs +++ b/tokio-postgres/src/error/mod.rs @@ -2,13 +2,13 @@ use fallible_iterator::FallibleIterator; use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody}; -use std::error; +use std::error::{self, Error as _Error}; use std::fmt; use std::io; -use tokio_timer; pub use self::sqlstate::*; +#[allow(clippy::unreadable_literal)] mod sqlstate; /// The severity of a Postgres error or notice. @@ -33,7 +33,7 @@ pub enum Severity { } impl fmt::Display for Severity { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { let s = match *self { Severity::Panic => "PANIC", Severity::Fatal => "FATAL", @@ -86,7 +86,7 @@ pub struct DbError { } impl DbError { - pub(crate) fn new(fields: &mut ErrorFields) -> io::Result { + pub(crate) fn parse(fields: &mut ErrorFields<'_>) -> io::Result { let mut severity = None; let mut parsed_severity = None; let mut code = None; @@ -107,14 +107,15 @@ impl DbError { let mut routine = None; while let Some(field) = fields.next()? { + let value = String::from_utf8_lossy(field.value_bytes()); match field.type_() { - b'S' => severity = Some(field.value().to_owned()), - b'C' => code = Some(SqlState::from_code(field.value())), - b'M' => message = Some(field.value().to_owned()), - b'D' => detail = Some(field.value().to_owned()), - b'H' => hint = Some(field.value().to_owned()), + b'S' => severity = Some(value.into_owned()), + b'C' => code = Some(SqlState::from_code(&value)), + b'M' => message = Some(value.into_owned()), + b'D' => detail = Some(value.into_owned()), + b'H' => hint = Some(value.into_owned()), b'P' => { - normal_position = Some(field.value().parse::().map_err(|_| { + normal_position = Some(value.parse::().map_err(|_| { io::Error::new( io::ErrorKind::InvalidInput, "`P` field did not contain an integer", @@ -122,32 +123,32 @@ impl DbError { })?); } b'p' => { - internal_position = Some(field.value().parse::().map_err(|_| { + internal_position = Some(value.parse::().map_err(|_| { io::Error::new( io::ErrorKind::InvalidInput, "`p` field did not contain an integer", ) })?); } - b'q' => internal_query = Some(field.value().to_owned()), - b'W' => where_ = Some(field.value().to_owned()), - b's' => schema = Some(field.value().to_owned()), - b't' => table = Some(field.value().to_owned()), - b'c' => column = Some(field.value().to_owned()), - b'd' => datatype = Some(field.value().to_owned()), - b'n' => constraint = Some(field.value().to_owned()), - b'F' => file = Some(field.value().to_owned()), + b'q' => internal_query = Some(value.into_owned()), + b'W' => where_ = Some(value.into_owned()), + b's' => schema = Some(value.into_owned()), + b't' => table = Some(value.into_owned()), + b'c' => column = Some(value.into_owned()), + b'd' => datatype = Some(value.into_owned()), + b'n' => constraint = Some(value.into_owned()), + b'F' => file = Some(value.into_owned()), b'L' => { - line = Some(field.value().parse::().map_err(|_| { + line = Some(value.parse::().map_err(|_| { io::Error::new( io::ErrorKind::InvalidInput, "`L` field did not contain an integer", ) })?); } - b'R' => routine = Some(field.value().to_owned()), + b'R' => routine = Some(value.into_owned()), b'V' => { - parsed_severity = Some(Severity::from_str(field.value()).ok_or_else(|| { + parsed_severity = Some(Severity::from_str(&value).ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidInput, "`V` field contained an invalid value", @@ -161,18 +162,18 @@ impl DbError { Ok(DbError { severity: severity .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`S` field missing"))?, - parsed_severity: parsed_severity, + parsed_severity, code: code .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`C` field missing"))?, message: message .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`M` field missing"))?, - detail: detail, - hint: hint, + detail, + hint, position: match normal_position { Some(position) => Some(ErrorPosition::Original(position)), None => match internal_position { Some(position) => Some(ErrorPosition::Internal { - position: position, + position, query: internal_query.ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidInput, @@ -183,15 +184,15 @@ impl DbError { None => None, }, }, - where_: where_, - schema: schema, - table: table, - column: column, - datatype: datatype, - constraint: constraint, - file: file, - line: line, - routine: routine, + where_, + schema, + table, + column, + datatype, + constraint, + file, + line, + routine, }) } @@ -224,7 +225,7 @@ impl DbError { /// /// Might run to multiple lines. pub fn detail(&self) -> Option<&str> { - self.detail.as_ref().map(|s| &**s) + self.detail.as_deref() } /// An optional suggestion what to do about the problem. @@ -233,7 +234,7 @@ impl DbError { /// (potentially inappropriate) rather than hard facts. Might run to /// multiple lines. pub fn hint(&self) -> Option<&str> { - self.hint.as_ref().map(|s| &**s) + self.hint.as_deref() } /// An optional error cursor position into either the original query string @@ -248,20 +249,20 @@ impl DbError { /// language functions and internally-generated queries. The trace is one /// entry per line, most recent first. pub fn where_(&self) -> Option<&str> { - self.where_.as_ref().map(|s| &**s) + self.where_.as_deref() } /// If the error was associated with a specific database object, the name /// of the schema containing that object, if any. (PostgreSQL 9.3+) pub fn schema(&self) -> Option<&str> { - self.schema.as_ref().map(|s| &**s) + self.schema.as_deref() } /// If the error was associated with a specific table, the name of the /// table. (Refer to the schema name field for the name of the table's /// schema.) (PostgreSQL 9.3+) pub fn table(&self) -> Option<&str> { - self.table.as_ref().map(|s| &**s) + self.table.as_deref() } /// If the error was associated with a specific table column, the name of @@ -270,14 +271,14 @@ impl DbError { /// (Refer to the schema and table name fields to identify the table.) /// (PostgreSQL 9.3+) pub fn column(&self) -> Option<&str> { - self.column.as_ref().map(|s| &**s) + self.column.as_deref() } /// If the error was associated with a specific data type, the name of the /// data type. (Refer to the schema name field for the name of the data /// type's schema.) (PostgreSQL 9.3+) pub fn datatype(&self) -> Option<&str> { - self.datatype.as_ref().map(|s| &**s) + self.datatype.as_deref() } /// If the error was associated with a specific constraint, the name of the @@ -287,12 +288,12 @@ impl DbError { /// (For this purpose, indexes are treated as constraints, even if they /// weren't created with constraint syntax.) (PostgreSQL 9.3+) pub fn constraint(&self) -> Option<&str> { - self.constraint.as_ref().map(|s| &**s) + self.constraint.as_deref() } /// The file name of the source-code location where the error was reported. pub fn file(&self) -> Option<&str> { - self.file.as_ref().map(|s| &**s) + self.file.as_deref() } /// The line number of the source-code location where the error was @@ -303,21 +304,24 @@ impl DbError { /// The name of the source-code routine reporting the error. pub fn routine(&self) -> Option<&str> { - self.routine.as_ref().map(|s| &**s) + self.routine.as_deref() } } impl fmt::Display for DbError { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - write!(fmt, "{}: {}", self.severity, self.message) + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "{}: {}", self.severity, self.message)?; + if let Some(detail) = &self.detail { + write!(fmt, "\nDETAIL: {}", detail)?; + } + if let Some(hint) = &self.hint { + write!(fmt, "\nHINT: {}", hint)?; + } + Ok(()) } } -impl error::Error for DbError { - fn description(&self) -> &str { - &self.message - } -} +impl error::Error for DbError {} /// Represents the position of an error in a query. #[derive(Clone, PartialEq, Eq, Debug)] @@ -338,31 +342,33 @@ enum Kind { Io, UnexpectedMessage, Tls, - ToSql, - FromSql, - CopyInStream, + ToSql(usize), + FromSql(usize), + Column(String), + Parameters(usize, usize), Closed, Db, Parse, Encode, - MissingUser, - MissingPassword, - UnsupportedAuthentication, - Connect, - Timer, Authentication, + ConfigParse, + Config, + RowCount, + #[cfg(feature = "runtime")] + Connect, + Timeout, } struct ErrorInner { kind: Kind, - cause: Option>, + cause: Option>, } /// An error communicating with the Postgres server. pub struct Error(Box); impl fmt::Debug for Error { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("Error") .field("kind", &self.0.kind) .field("cause", &self.0.cause) @@ -371,8 +377,29 @@ impl fmt::Debug for Error { } impl fmt::Display for Error { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - fmt.write_str(error::Error::description(self))?; + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.0.kind { + Kind::Io => fmt.write_str("error communicating with the server")?, + Kind::UnexpectedMessage => fmt.write_str("unexpected message from server")?, + Kind::Tls => fmt.write_str("error performing TLS handshake")?, + Kind::ToSql(idx) => write!(fmt, "error serializing parameter {}", idx)?, + Kind::FromSql(idx) => write!(fmt, "error deserializing column {}", idx)?, + Kind::Column(column) => write!(fmt, "invalid column `{}`", column)?, + Kind::Parameters(real, expected) => { + write!(fmt, "expected {expected} parameters but got {real}")? + } + Kind::Closed => fmt.write_str("connection closed")?, + Kind::Db => fmt.write_str("db error")?, + Kind::Parse => fmt.write_str("error parsing response from server")?, + Kind::Encode => fmt.write_str("error encoding message to server")?, + Kind::Authentication => fmt.write_str("authentication error")?, + Kind::ConfigParse => fmt.write_str("invalid connection string")?, + Kind::Config => fmt.write_str("invalid configuration")?, + Kind::RowCount => fmt.write_str("query returned an unexpected number of rows")?, + #[cfg(feature = "runtime")] + Kind::Connect => fmt.write_str("error connecting to server")?, + Kind::Timeout => fmt.write_str("timeout waiting for server")?, + }; if let Some(ref cause) = self.0.cause { write!(fmt, ": {}", cause)?; } @@ -381,57 +408,37 @@ impl fmt::Display for Error { } impl error::Error for Error { - fn description(&self) -> &str { - match self.0.kind { - Kind::Io => "error communicating with the server", - Kind::UnexpectedMessage => "unexpected message from server", - Kind::Tls => "error performing TLS handshake", - Kind::ToSql => "error serializing a value", - Kind::FromSql => "error deserializing a value", - Kind::CopyInStream => "error from a copy_in stream", - Kind::Closed => "connection closed", - Kind::Db => "db error", - Kind::Parse => "error parsing response from server", - Kind::Encode => "error encoding message to server", - Kind::MissingUser => "username not provided", - Kind::MissingPassword => "password not provided", - Kind::UnsupportedAuthentication => "unsupported authentication method requested", - Kind::Connect => "error connecting to server", - Kind::Timer => "timer error", - Kind::Authentication => "authentication error", - } - } - - fn cause(&self) -> Option<&error::Error> { - self.0.cause.as_ref().map(|e| &**e as &error::Error) + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + self.0.cause.as_ref().map(|e| &**e as _) } } impl Error { - /// Returns the error's cause. + /// Consumes the error, returning its cause. + pub fn into_source(self) -> Option> { + self.0.cause + } + + /// Returns the source of this error if it was a `DbError`. /// - /// This is the same as `Error::cause` except that it provides extra bounds - /// required to be able to downcast the error. - pub fn cause2(&self) -> Option<&(error::Error + 'static + Sync + Send)> { - self.0.cause.as_ref().map(|e| &**e) + /// This is a simple convenience method. + pub fn as_db_error(&self) -> Option<&DbError> { + self.source().and_then(|e| e.downcast_ref::()) } - /// Consumes the error, returning its cause. - pub fn into_cause(self) -> Option> { - self.0.cause + /// Determines if the error was associated with closed connection. + pub fn is_closed(&self) -> bool { + self.0.kind == Kind::Closed } /// Returns the SQLSTATE error code associated with the error. /// - /// This is a convenience method that downcasts the cause to a `DbError` - /// and returns its code. + /// This is a convenience method that downcasts the cause to a `DbError` and returns its code. pub fn code(&self) -> Option<&SqlState> { - self.cause2() - .and_then(|e| e.downcast_ref::()) - .map(|e| e.code()) + self.as_db_error().map(DbError::code) } - fn new(kind: Kind, cause: Option>) -> Error { + fn new(kind: Kind, cause: Option>) -> Error { Error(Box::new(ErrorInner { kind, cause })) } @@ -443,8 +450,9 @@ impl Error { Error::new(Kind::UnexpectedMessage, None) } + #[allow(clippy::needless_pass_by_value)] pub(crate) fn db(error: ErrorResponseBody) -> Error { - match DbError::new(&mut error.fields()) { + match DbError::parse(&mut error.fields()) { Ok(e) => Error::new(Kind::Db, Some(Box::new(e))), Err(e) => Error::new(Kind::Parse, Some(Box::new(e))), } @@ -458,50 +466,54 @@ impl Error { Error::new(Kind::Encode, Some(Box::new(e))) } - pub(crate) fn to_sql(e: Box) -> Error { - Error::new(Kind::ToSql, Some(e)) + #[allow(clippy::wrong_self_convention)] + pub(crate) fn to_sql(e: Box, idx: usize) -> Error { + Error::new(Kind::ToSql(idx), Some(e)) } - pub(crate) fn from_sql(e: Box) -> Error { - Error::new(Kind::FromSql, Some(e)) + pub(crate) fn from_sql(e: Box, idx: usize) -> Error { + Error::new(Kind::FromSql(idx), Some(e)) } - pub(crate) fn copy_in_stream(e: E) -> Error - where - E: Into>, - { - Error::new(Kind::CopyInStream, Some(e.into())) + pub(crate) fn column(column: String) -> Error { + Error::new(Kind::Column(column), None) } - pub(crate) fn missing_user() -> Error { - Error::new(Kind::MissingUser, None) + pub(crate) fn parameters(real: usize, expected: usize) -> Error { + Error::new(Kind::Parameters(real, expected), None) } - pub(crate) fn missing_password() -> Error { - Error::new(Kind::MissingPassword, None) + pub(crate) fn tls(e: Box) -> Error { + Error::new(Kind::Tls, Some(e)) } - pub(crate) fn unsupported_authentication() -> Error { - Error::new(Kind::UnsupportedAuthentication, None) + pub(crate) fn io(e: io::Error) -> Error { + Error::new(Kind::Io, Some(Box::new(e))) } - pub(crate) fn tls(e: Box) -> Error { - Error::new(Kind::Tls, Some(e)) + pub(crate) fn authentication(e: Box) -> Error { + Error::new(Kind::Authentication, Some(e)) } - pub(crate) fn connect(e: io::Error) -> Error { - Error::new(Kind::Connect, Some(Box::new(e))) + pub(crate) fn config_parse(e: Box) -> Error { + Error::new(Kind::ConfigParse, Some(e)) } - pub(crate) fn timer(e: tokio_timer::Error) -> Error { - Error::new(Kind::Timer, Some(Box::new(e))) + pub(crate) fn config(e: Box) -> Error { + Error::new(Kind::Config, Some(e)) } - pub(crate) fn io(e: io::Error) -> Error { - Error::new(Kind::Io, Some(Box::new(e))) + pub(crate) fn row_count() -> Error { + Error::new(Kind::RowCount, None) + } + + #[cfg(feature = "runtime")] + pub(crate) fn connect(e: io::Error) -> Error { + Error::new(Kind::Connect, Some(Box::new(e))) } - pub(crate) fn authentication(e: io::Error) -> Error { - Error::new(Kind::Authentication, Some(Box::new(e))) + #[doc(hidden)] + pub fn __private_api_timeout() -> Error { + Error::new(Kind::Timeout, None) } } diff --git a/tokio-postgres/src/error/sqlstate.rs b/tokio-postgres/src/error/sqlstate.rs index c8e3ec2eb..13a1d75f9 100644 --- a/tokio-postgres/src/error/sqlstate.rs +++ b/tokio-postgres/src/error/sqlstate.rs @@ -1,1061 +1,1670 @@ // Autogenerated file - DO NOT EDIT -use phf; -use std::borrow::Cow; /// A SQLSTATE error code #[derive(PartialEq, Eq, Clone, Debug)] -pub struct SqlState(Cow<'static, str>); +pub struct SqlState(Inner); impl SqlState { /// Creates a `SqlState` from its error code. pub fn from_code(s: &str) -> SqlState { match SQLSTATE_MAP.get(s) { Some(state) => state.clone(), - None => SqlState(Cow::Owned(s.to_string())), + None => SqlState(Inner::Other(s.into())), } } /// Returns the error code corresponding to the `SqlState`. pub fn code(&self) -> &str { - &self.0 + match &self.0 { + Inner::E00000 => "00000", + Inner::E01000 => "01000", + Inner::E0100C => "0100C", + Inner::E01008 => "01008", + Inner::E01003 => "01003", + Inner::E01007 => "01007", + Inner::E01006 => "01006", + Inner::E01004 => "01004", + Inner::E01P01 => "01P01", + Inner::E02000 => "02000", + Inner::E02001 => "02001", + Inner::E03000 => "03000", + Inner::E08000 => "08000", + Inner::E08003 => "08003", + Inner::E08006 => "08006", + Inner::E08001 => "08001", + Inner::E08004 => "08004", + Inner::E08007 => "08007", + Inner::E08P01 => "08P01", + Inner::E09000 => "09000", + Inner::E0A000 => "0A000", + Inner::E0B000 => "0B000", + Inner::E0F000 => "0F000", + Inner::E0F001 => "0F001", + Inner::E0L000 => "0L000", + Inner::E0LP01 => "0LP01", + Inner::E0P000 => "0P000", + Inner::E0Z000 => "0Z000", + Inner::E0Z002 => "0Z002", + Inner::E20000 => "20000", + Inner::E21000 => "21000", + Inner::E22000 => "22000", + Inner::E2202E => "2202E", + Inner::E22021 => "22021", + Inner::E22008 => "22008", + Inner::E22012 => "22012", + Inner::E22005 => "22005", + Inner::E2200B => "2200B", + Inner::E22022 => "22022", + Inner::E22015 => "22015", + Inner::E2201E => "2201E", + Inner::E22014 => "22014", + Inner::E22016 => "22016", + Inner::E2201F => "2201F", + Inner::E2201G => "2201G", + Inner::E22018 => "22018", + Inner::E22007 => "22007", + Inner::E22019 => "22019", + Inner::E2200D => "2200D", + Inner::E22025 => "22025", + Inner::E22P06 => "22P06", + Inner::E22010 => "22010", + Inner::E22023 => "22023", + Inner::E22013 => "22013", + Inner::E2201B => "2201B", + Inner::E2201W => "2201W", + Inner::E2201X => "2201X", + Inner::E2202H => "2202H", + Inner::E2202G => "2202G", + Inner::E22009 => "22009", + Inner::E2200C => "2200C", + Inner::E2200G => "2200G", + Inner::E22004 => "22004", + Inner::E22002 => "22002", + Inner::E22003 => "22003", + Inner::E2200H => "2200H", + Inner::E22026 => "22026", + Inner::E22001 => "22001", + Inner::E22011 => "22011", + Inner::E22027 => "22027", + Inner::E22024 => "22024", + Inner::E2200F => "2200F", + Inner::E22P01 => "22P01", + Inner::E22P02 => "22P02", + Inner::E22P03 => "22P03", + Inner::E22P04 => "22P04", + Inner::E22P05 => "22P05", + Inner::E2200L => "2200L", + Inner::E2200M => "2200M", + Inner::E2200N => "2200N", + Inner::E2200S => "2200S", + Inner::E2200T => "2200T", + Inner::E22030 => "22030", + Inner::E22031 => "22031", + Inner::E22032 => "22032", + Inner::E22033 => "22033", + Inner::E22034 => "22034", + Inner::E22035 => "22035", + Inner::E22036 => "22036", + Inner::E22037 => "22037", + Inner::E22038 => "22038", + Inner::E22039 => "22039", + Inner::E2203A => "2203A", + Inner::E2203B => "2203B", + Inner::E2203C => "2203C", + Inner::E2203D => "2203D", + Inner::E2203E => "2203E", + Inner::E2203F => "2203F", + Inner::E2203G => "2203G", + Inner::E23000 => "23000", + Inner::E23001 => "23001", + Inner::E23502 => "23502", + Inner::E23503 => "23503", + Inner::E23505 => "23505", + Inner::E23514 => "23514", + Inner::E23P01 => "23P01", + Inner::E24000 => "24000", + Inner::E25000 => "25000", + Inner::E25001 => "25001", + Inner::E25002 => "25002", + Inner::E25008 => "25008", + Inner::E25003 => "25003", + Inner::E25004 => "25004", + Inner::E25005 => "25005", + Inner::E25006 => "25006", + Inner::E25007 => "25007", + Inner::E25P01 => "25P01", + Inner::E25P02 => "25P02", + Inner::E25P03 => "25P03", + Inner::E26000 => "26000", + Inner::E27000 => "27000", + Inner::E28000 => "28000", + Inner::E28P01 => "28P01", + Inner::E2B000 => "2B000", + Inner::E2BP01 => "2BP01", + Inner::E2D000 => "2D000", + Inner::E2F000 => "2F000", + Inner::E2F005 => "2F005", + Inner::E2F002 => "2F002", + Inner::E2F003 => "2F003", + Inner::E2F004 => "2F004", + Inner::E34000 => "34000", + Inner::E38000 => "38000", + Inner::E38001 => "38001", + Inner::E38002 => "38002", + Inner::E38003 => "38003", + Inner::E38004 => "38004", + Inner::E39000 => "39000", + Inner::E39001 => "39001", + Inner::E39004 => "39004", + Inner::E39P01 => "39P01", + Inner::E39P02 => "39P02", + Inner::E39P03 => "39P03", + Inner::E3B000 => "3B000", + Inner::E3B001 => "3B001", + Inner::E3D000 => "3D000", + Inner::E3F000 => "3F000", + Inner::E40000 => "40000", + Inner::E40002 => "40002", + Inner::E40001 => "40001", + Inner::E40003 => "40003", + Inner::E40P01 => "40P01", + Inner::E42000 => "42000", + Inner::E42601 => "42601", + Inner::E42501 => "42501", + Inner::E42846 => "42846", + Inner::E42803 => "42803", + Inner::E42P20 => "42P20", + Inner::E42P19 => "42P19", + Inner::E42830 => "42830", + Inner::E42602 => "42602", + Inner::E42622 => "42622", + Inner::E42939 => "42939", + Inner::E42804 => "42804", + Inner::E42P18 => "42P18", + Inner::E42P21 => "42P21", + Inner::E42P22 => "42P22", + Inner::E42809 => "42809", + Inner::E428C9 => "428C9", + Inner::E42703 => "42703", + Inner::E42883 => "42883", + Inner::E42P01 => "42P01", + Inner::E42P02 => "42P02", + Inner::E42704 => "42704", + Inner::E42701 => "42701", + Inner::E42P03 => "42P03", + Inner::E42P04 => "42P04", + Inner::E42723 => "42723", + Inner::E42P05 => "42P05", + Inner::E42P06 => "42P06", + Inner::E42P07 => "42P07", + Inner::E42712 => "42712", + Inner::E42710 => "42710", + Inner::E42702 => "42702", + Inner::E42725 => "42725", + Inner::E42P08 => "42P08", + Inner::E42P09 => "42P09", + Inner::E42P10 => "42P10", + Inner::E42611 => "42611", + Inner::E42P11 => "42P11", + Inner::E42P12 => "42P12", + Inner::E42P13 => "42P13", + Inner::E42P14 => "42P14", + Inner::E42P15 => "42P15", + Inner::E42P16 => "42P16", + Inner::E42P17 => "42P17", + Inner::E44000 => "44000", + Inner::E53000 => "53000", + Inner::E53100 => "53100", + Inner::E53200 => "53200", + Inner::E53300 => "53300", + Inner::E53400 => "53400", + Inner::E54000 => "54000", + Inner::E54001 => "54001", + Inner::E54011 => "54011", + Inner::E54023 => "54023", + Inner::E55000 => "55000", + Inner::E55006 => "55006", + Inner::E55P02 => "55P02", + Inner::E55P03 => "55P03", + Inner::E55P04 => "55P04", + Inner::E57000 => "57000", + Inner::E57014 => "57014", + Inner::E57P01 => "57P01", + Inner::E57P02 => "57P02", + Inner::E57P03 => "57P03", + Inner::E57P04 => "57P04", + Inner::E57P05 => "57P05", + Inner::E58000 => "58000", + Inner::E58030 => "58030", + Inner::E58P01 => "58P01", + Inner::E58P02 => "58P02", + Inner::E72000 => "72000", + Inner::EF0000 => "F0000", + Inner::EF0001 => "F0001", + Inner::EHV000 => "HV000", + Inner::EHV005 => "HV005", + Inner::EHV002 => "HV002", + Inner::EHV010 => "HV010", + Inner::EHV021 => "HV021", + Inner::EHV024 => "HV024", + Inner::EHV007 => "HV007", + Inner::EHV008 => "HV008", + Inner::EHV004 => "HV004", + Inner::EHV006 => "HV006", + Inner::EHV091 => "HV091", + Inner::EHV00B => "HV00B", + Inner::EHV00C => "HV00C", + Inner::EHV00D => "HV00D", + Inner::EHV090 => "HV090", + Inner::EHV00A => "HV00A", + Inner::EHV009 => "HV009", + Inner::EHV014 => "HV014", + Inner::EHV001 => "HV001", + Inner::EHV00P => "HV00P", + Inner::EHV00J => "HV00J", + Inner::EHV00K => "HV00K", + Inner::EHV00Q => "HV00Q", + Inner::EHV00R => "HV00R", + Inner::EHV00L => "HV00L", + Inner::EHV00M => "HV00M", + Inner::EHV00N => "HV00N", + Inner::EP0000 => "P0000", + Inner::EP0001 => "P0001", + Inner::EP0002 => "P0002", + Inner::EP0003 => "P0003", + Inner::EP0004 => "P0004", + Inner::EXX000 => "XX000", + Inner::EXX001 => "XX001", + Inner::EXX002 => "XX002", + Inner::Other(code) => code, + } } /// 00000 - pub const SUCCESSFUL_COMPLETION: SqlState = SqlState(Cow::Borrowed("00000")); + pub const SUCCESSFUL_COMPLETION: SqlState = SqlState(Inner::E00000); /// 01000 - pub const WARNING: SqlState = SqlState(Cow::Borrowed("01000")); + pub const WARNING: SqlState = SqlState(Inner::E01000); /// 0100C - pub const WARNING_DYNAMIC_RESULT_SETS_RETURNED: SqlState = SqlState(Cow::Borrowed("0100C")); + pub const WARNING_DYNAMIC_RESULT_SETS_RETURNED: SqlState = SqlState(Inner::E0100C); /// 01008 - pub const WARNING_IMPLICIT_ZERO_BIT_PADDING: SqlState = SqlState(Cow::Borrowed("01008")); + pub const WARNING_IMPLICIT_ZERO_BIT_PADDING: SqlState = SqlState(Inner::E01008); /// 01003 - pub const WARNING_NULL_VALUE_ELIMINATED_IN_SET_FUNCTION: SqlState = SqlState(Cow::Borrowed("01003")); + pub const WARNING_NULL_VALUE_ELIMINATED_IN_SET_FUNCTION: SqlState = SqlState(Inner::E01003); /// 01007 - pub const WARNING_PRIVILEGE_NOT_GRANTED: SqlState = SqlState(Cow::Borrowed("01007")); + pub const WARNING_PRIVILEGE_NOT_GRANTED: SqlState = SqlState(Inner::E01007); /// 01006 - pub const WARNING_PRIVILEGE_NOT_REVOKED: SqlState = SqlState(Cow::Borrowed("01006")); + pub const WARNING_PRIVILEGE_NOT_REVOKED: SqlState = SqlState(Inner::E01006); /// 01004 - pub const WARNING_STRING_DATA_RIGHT_TRUNCATION: SqlState = SqlState(Cow::Borrowed("01004")); + pub const WARNING_STRING_DATA_RIGHT_TRUNCATION: SqlState = SqlState(Inner::E01004); /// 01P01 - pub const WARNING_DEPRECATED_FEATURE: SqlState = SqlState(Cow::Borrowed("01P01")); + pub const WARNING_DEPRECATED_FEATURE: SqlState = SqlState(Inner::E01P01); /// 02000 - pub const NO_DATA: SqlState = SqlState(Cow::Borrowed("02000")); + pub const NO_DATA: SqlState = SqlState(Inner::E02000); /// 02001 - pub const NO_ADDITIONAL_DYNAMIC_RESULT_SETS_RETURNED: SqlState = SqlState(Cow::Borrowed("02001")); + pub const NO_ADDITIONAL_DYNAMIC_RESULT_SETS_RETURNED: SqlState = SqlState(Inner::E02001); /// 03000 - pub const SQL_STATEMENT_NOT_YET_COMPLETE: SqlState = SqlState(Cow::Borrowed("03000")); + pub const SQL_STATEMENT_NOT_YET_COMPLETE: SqlState = SqlState(Inner::E03000); /// 08000 - pub const CONNECTION_EXCEPTION: SqlState = SqlState(Cow::Borrowed("08000")); + pub const CONNECTION_EXCEPTION: SqlState = SqlState(Inner::E08000); /// 08003 - pub const CONNECTION_DOES_NOT_EXIST: SqlState = SqlState(Cow::Borrowed("08003")); + pub const CONNECTION_DOES_NOT_EXIST: SqlState = SqlState(Inner::E08003); /// 08006 - pub const CONNECTION_FAILURE: SqlState = SqlState(Cow::Borrowed("08006")); + pub const CONNECTION_FAILURE: SqlState = SqlState(Inner::E08006); /// 08001 - pub const SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION: SqlState = SqlState(Cow::Borrowed("08001")); + pub const SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION: SqlState = SqlState(Inner::E08001); /// 08004 - pub const SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION: SqlState = SqlState(Cow::Borrowed("08004")); + pub const SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION: SqlState = SqlState(Inner::E08004); /// 08007 - pub const TRANSACTION_RESOLUTION_UNKNOWN: SqlState = SqlState(Cow::Borrowed("08007")); + pub const TRANSACTION_RESOLUTION_UNKNOWN: SqlState = SqlState(Inner::E08007); /// 08P01 - pub const PROTOCOL_VIOLATION: SqlState = SqlState(Cow::Borrowed("08P01")); + pub const PROTOCOL_VIOLATION: SqlState = SqlState(Inner::E08P01); /// 09000 - pub const TRIGGERED_ACTION_EXCEPTION: SqlState = SqlState(Cow::Borrowed("09000")); + pub const TRIGGERED_ACTION_EXCEPTION: SqlState = SqlState(Inner::E09000); /// 0A000 - pub const FEATURE_NOT_SUPPORTED: SqlState = SqlState(Cow::Borrowed("0A000")); + pub const FEATURE_NOT_SUPPORTED: SqlState = SqlState(Inner::E0A000); /// 0B000 - pub const INVALID_TRANSACTION_INITIATION: SqlState = SqlState(Cow::Borrowed("0B000")); + pub const INVALID_TRANSACTION_INITIATION: SqlState = SqlState(Inner::E0B000); /// 0F000 - pub const LOCATOR_EXCEPTION: SqlState = SqlState(Cow::Borrowed("0F000")); + pub const LOCATOR_EXCEPTION: SqlState = SqlState(Inner::E0F000); /// 0F001 - pub const L_E_INVALID_SPECIFICATION: SqlState = SqlState(Cow::Borrowed("0F001")); + pub const L_E_INVALID_SPECIFICATION: SqlState = SqlState(Inner::E0F001); /// 0L000 - pub const INVALID_GRANTOR: SqlState = SqlState(Cow::Borrowed("0L000")); + pub const INVALID_GRANTOR: SqlState = SqlState(Inner::E0L000); /// 0LP01 - pub const INVALID_GRANT_OPERATION: SqlState = SqlState(Cow::Borrowed("0LP01")); + pub const INVALID_GRANT_OPERATION: SqlState = SqlState(Inner::E0LP01); /// 0P000 - pub const INVALID_ROLE_SPECIFICATION: SqlState = SqlState(Cow::Borrowed("0P000")); + pub const INVALID_ROLE_SPECIFICATION: SqlState = SqlState(Inner::E0P000); /// 0Z000 - pub const DIAGNOSTICS_EXCEPTION: SqlState = SqlState(Cow::Borrowed("0Z000")); + pub const DIAGNOSTICS_EXCEPTION: SqlState = SqlState(Inner::E0Z000); /// 0Z002 - pub const STACKED_DIAGNOSTICS_ACCESSED_WITHOUT_ACTIVE_HANDLER: SqlState = SqlState(Cow::Borrowed("0Z002")); + pub const STACKED_DIAGNOSTICS_ACCESSED_WITHOUT_ACTIVE_HANDLER: SqlState = + SqlState(Inner::E0Z002); /// 20000 - pub const CASE_NOT_FOUND: SqlState = SqlState(Cow::Borrowed("20000")); + pub const CASE_NOT_FOUND: SqlState = SqlState(Inner::E20000); /// 21000 - pub const CARDINALITY_VIOLATION: SqlState = SqlState(Cow::Borrowed("21000")); + pub const CARDINALITY_VIOLATION: SqlState = SqlState(Inner::E21000); /// 22000 - pub const DATA_EXCEPTION: SqlState = SqlState(Cow::Borrowed("22000")); + pub const DATA_EXCEPTION: SqlState = SqlState(Inner::E22000); /// 2202E - pub const ARRAY_ELEMENT_ERROR: SqlState = SqlState(Cow::Borrowed("2202E")); + pub const ARRAY_ELEMENT_ERROR: SqlState = SqlState(Inner::E2202E); /// 2202E - pub const ARRAY_SUBSCRIPT_ERROR: SqlState = SqlState(Cow::Borrowed("2202E")); + pub const ARRAY_SUBSCRIPT_ERROR: SqlState = SqlState(Inner::E2202E); /// 22021 - pub const CHARACTER_NOT_IN_REPERTOIRE: SqlState = SqlState(Cow::Borrowed("22021")); + pub const CHARACTER_NOT_IN_REPERTOIRE: SqlState = SqlState(Inner::E22021); /// 22008 - pub const DATETIME_FIELD_OVERFLOW: SqlState = SqlState(Cow::Borrowed("22008")); + pub const DATETIME_FIELD_OVERFLOW: SqlState = SqlState(Inner::E22008); /// 22008 - pub const DATETIME_VALUE_OUT_OF_RANGE: SqlState = SqlState(Cow::Borrowed("22008")); + pub const DATETIME_VALUE_OUT_OF_RANGE: SqlState = SqlState(Inner::E22008); /// 22012 - pub const DIVISION_BY_ZERO: SqlState = SqlState(Cow::Borrowed("22012")); + pub const DIVISION_BY_ZERO: SqlState = SqlState(Inner::E22012); /// 22005 - pub const ERROR_IN_ASSIGNMENT: SqlState = SqlState(Cow::Borrowed("22005")); + pub const ERROR_IN_ASSIGNMENT: SqlState = SqlState(Inner::E22005); /// 2200B - pub const ESCAPE_CHARACTER_CONFLICT: SqlState = SqlState(Cow::Borrowed("2200B")); + pub const ESCAPE_CHARACTER_CONFLICT: SqlState = SqlState(Inner::E2200B); /// 22022 - pub const INDICATOR_OVERFLOW: SqlState = SqlState(Cow::Borrowed("22022")); + pub const INDICATOR_OVERFLOW: SqlState = SqlState(Inner::E22022); /// 22015 - pub const INTERVAL_FIELD_OVERFLOW: SqlState = SqlState(Cow::Borrowed("22015")); + pub const INTERVAL_FIELD_OVERFLOW: SqlState = SqlState(Inner::E22015); /// 2201E - pub const INVALID_ARGUMENT_FOR_LOG: SqlState = SqlState(Cow::Borrowed("2201E")); + pub const INVALID_ARGUMENT_FOR_LOG: SqlState = SqlState(Inner::E2201E); /// 22014 - pub const INVALID_ARGUMENT_FOR_NTILE: SqlState = SqlState(Cow::Borrowed("22014")); + pub const INVALID_ARGUMENT_FOR_NTILE: SqlState = SqlState(Inner::E22014); /// 22016 - pub const INVALID_ARGUMENT_FOR_NTH_VALUE: SqlState = SqlState(Cow::Borrowed("22016")); + pub const INVALID_ARGUMENT_FOR_NTH_VALUE: SqlState = SqlState(Inner::E22016); /// 2201F - pub const INVALID_ARGUMENT_FOR_POWER_FUNCTION: SqlState = SqlState(Cow::Borrowed("2201F")); + pub const INVALID_ARGUMENT_FOR_POWER_FUNCTION: SqlState = SqlState(Inner::E2201F); /// 2201G - pub const INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION: SqlState = SqlState(Cow::Borrowed("2201G")); + pub const INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION: SqlState = SqlState(Inner::E2201G); /// 22018 - pub const INVALID_CHARACTER_VALUE_FOR_CAST: SqlState = SqlState(Cow::Borrowed("22018")); + pub const INVALID_CHARACTER_VALUE_FOR_CAST: SqlState = SqlState(Inner::E22018); /// 22007 - pub const INVALID_DATETIME_FORMAT: SqlState = SqlState(Cow::Borrowed("22007")); + pub const INVALID_DATETIME_FORMAT: SqlState = SqlState(Inner::E22007); /// 22019 - pub const INVALID_ESCAPE_CHARACTER: SqlState = SqlState(Cow::Borrowed("22019")); + pub const INVALID_ESCAPE_CHARACTER: SqlState = SqlState(Inner::E22019); /// 2200D - pub const INVALID_ESCAPE_OCTET: SqlState = SqlState(Cow::Borrowed("2200D")); + pub const INVALID_ESCAPE_OCTET: SqlState = SqlState(Inner::E2200D); /// 22025 - pub const INVALID_ESCAPE_SEQUENCE: SqlState = SqlState(Cow::Borrowed("22025")); + pub const INVALID_ESCAPE_SEQUENCE: SqlState = SqlState(Inner::E22025); /// 22P06 - pub const NONSTANDARD_USE_OF_ESCAPE_CHARACTER: SqlState = SqlState(Cow::Borrowed("22P06")); + pub const NONSTANDARD_USE_OF_ESCAPE_CHARACTER: SqlState = SqlState(Inner::E22P06); /// 22010 - pub const INVALID_INDICATOR_PARAMETER_VALUE: SqlState = SqlState(Cow::Borrowed("22010")); + pub const INVALID_INDICATOR_PARAMETER_VALUE: SqlState = SqlState(Inner::E22010); /// 22023 - pub const INVALID_PARAMETER_VALUE: SqlState = SqlState(Cow::Borrowed("22023")); + pub const INVALID_PARAMETER_VALUE: SqlState = SqlState(Inner::E22023); + + /// 22013 + pub const INVALID_PRECEDING_OR_FOLLOWING_SIZE: SqlState = SqlState(Inner::E22013); /// 2201B - pub const INVALID_REGULAR_EXPRESSION: SqlState = SqlState(Cow::Borrowed("2201B")); + pub const INVALID_REGULAR_EXPRESSION: SqlState = SqlState(Inner::E2201B); /// 2201W - pub const INVALID_ROW_COUNT_IN_LIMIT_CLAUSE: SqlState = SqlState(Cow::Borrowed("2201W")); + pub const INVALID_ROW_COUNT_IN_LIMIT_CLAUSE: SqlState = SqlState(Inner::E2201W); /// 2201X - pub const INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE: SqlState = SqlState(Cow::Borrowed("2201X")); + pub const INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE: SqlState = SqlState(Inner::E2201X); /// 2202H - pub const INVALID_TABLESAMPLE_ARGUMENT: SqlState = SqlState(Cow::Borrowed("2202H")); + pub const INVALID_TABLESAMPLE_ARGUMENT: SqlState = SqlState(Inner::E2202H); /// 2202G - pub const INVALID_TABLESAMPLE_REPEAT: SqlState = SqlState(Cow::Borrowed("2202G")); + pub const INVALID_TABLESAMPLE_REPEAT: SqlState = SqlState(Inner::E2202G); /// 22009 - pub const INVALID_TIME_ZONE_DISPLACEMENT_VALUE: SqlState = SqlState(Cow::Borrowed("22009")); + pub const INVALID_TIME_ZONE_DISPLACEMENT_VALUE: SqlState = SqlState(Inner::E22009); /// 2200C - pub const INVALID_USE_OF_ESCAPE_CHARACTER: SqlState = SqlState(Cow::Borrowed("2200C")); + pub const INVALID_USE_OF_ESCAPE_CHARACTER: SqlState = SqlState(Inner::E2200C); /// 2200G - pub const MOST_SPECIFIC_TYPE_MISMATCH: SqlState = SqlState(Cow::Borrowed("2200G")); + pub const MOST_SPECIFIC_TYPE_MISMATCH: SqlState = SqlState(Inner::E2200G); /// 22004 - pub const NULL_VALUE_NOT_ALLOWED: SqlState = SqlState(Cow::Borrowed("22004")); + pub const NULL_VALUE_NOT_ALLOWED: SqlState = SqlState(Inner::E22004); /// 22002 - pub const NULL_VALUE_NO_INDICATOR_PARAMETER: SqlState = SqlState(Cow::Borrowed("22002")); + pub const NULL_VALUE_NO_INDICATOR_PARAMETER: SqlState = SqlState(Inner::E22002); /// 22003 - pub const NUMERIC_VALUE_OUT_OF_RANGE: SqlState = SqlState(Cow::Borrowed("22003")); + pub const NUMERIC_VALUE_OUT_OF_RANGE: SqlState = SqlState(Inner::E22003); /// 2200H - pub const SEQUENCE_GENERATOR_LIMIT_EXCEEDED: SqlState = SqlState(Cow::Borrowed("2200H")); + pub const SEQUENCE_GENERATOR_LIMIT_EXCEEDED: SqlState = SqlState(Inner::E2200H); /// 22026 - pub const STRING_DATA_LENGTH_MISMATCH: SqlState = SqlState(Cow::Borrowed("22026")); + pub const STRING_DATA_LENGTH_MISMATCH: SqlState = SqlState(Inner::E22026); /// 22001 - pub const STRING_DATA_RIGHT_TRUNCATION: SqlState = SqlState(Cow::Borrowed("22001")); + pub const STRING_DATA_RIGHT_TRUNCATION: SqlState = SqlState(Inner::E22001); /// 22011 - pub const SUBSTRING_ERROR: SqlState = SqlState(Cow::Borrowed("22011")); + pub const SUBSTRING_ERROR: SqlState = SqlState(Inner::E22011); /// 22027 - pub const TRIM_ERROR: SqlState = SqlState(Cow::Borrowed("22027")); + pub const TRIM_ERROR: SqlState = SqlState(Inner::E22027); /// 22024 - pub const UNTERMINATED_C_STRING: SqlState = SqlState(Cow::Borrowed("22024")); + pub const UNTERMINATED_C_STRING: SqlState = SqlState(Inner::E22024); /// 2200F - pub const ZERO_LENGTH_CHARACTER_STRING: SqlState = SqlState(Cow::Borrowed("2200F")); + pub const ZERO_LENGTH_CHARACTER_STRING: SqlState = SqlState(Inner::E2200F); /// 22P01 - pub const FLOATING_POINT_EXCEPTION: SqlState = SqlState(Cow::Borrowed("22P01")); + pub const FLOATING_POINT_EXCEPTION: SqlState = SqlState(Inner::E22P01); /// 22P02 - pub const INVALID_TEXT_REPRESENTATION: SqlState = SqlState(Cow::Borrowed("22P02")); + pub const INVALID_TEXT_REPRESENTATION: SqlState = SqlState(Inner::E22P02); /// 22P03 - pub const INVALID_BINARY_REPRESENTATION: SqlState = SqlState(Cow::Borrowed("22P03")); + pub const INVALID_BINARY_REPRESENTATION: SqlState = SqlState(Inner::E22P03); /// 22P04 - pub const BAD_COPY_FILE_FORMAT: SqlState = SqlState(Cow::Borrowed("22P04")); + pub const BAD_COPY_FILE_FORMAT: SqlState = SqlState(Inner::E22P04); /// 22P05 - pub const UNTRANSLATABLE_CHARACTER: SqlState = SqlState(Cow::Borrowed("22P05")); + pub const UNTRANSLATABLE_CHARACTER: SqlState = SqlState(Inner::E22P05); /// 2200L - pub const NOT_AN_XML_DOCUMENT: SqlState = SqlState(Cow::Borrowed("2200L")); + pub const NOT_AN_XML_DOCUMENT: SqlState = SqlState(Inner::E2200L); /// 2200M - pub const INVALID_XML_DOCUMENT: SqlState = SqlState(Cow::Borrowed("2200M")); + pub const INVALID_XML_DOCUMENT: SqlState = SqlState(Inner::E2200M); /// 2200N - pub const INVALID_XML_CONTENT: SqlState = SqlState(Cow::Borrowed("2200N")); + pub const INVALID_XML_CONTENT: SqlState = SqlState(Inner::E2200N); /// 2200S - pub const INVALID_XML_COMMENT: SqlState = SqlState(Cow::Borrowed("2200S")); + pub const INVALID_XML_COMMENT: SqlState = SqlState(Inner::E2200S); /// 2200T - pub const INVALID_XML_PROCESSING_INSTRUCTION: SqlState = SqlState(Cow::Borrowed("2200T")); + pub const INVALID_XML_PROCESSING_INSTRUCTION: SqlState = SqlState(Inner::E2200T); + + /// 22030 + pub const DUPLICATE_JSON_OBJECT_KEY_VALUE: SqlState = SqlState(Inner::E22030); + + /// 22031 + pub const INVALID_ARGUMENT_FOR_SQL_JSON_DATETIME_FUNCTION: SqlState = SqlState(Inner::E22031); + + /// 22032 + pub const INVALID_JSON_TEXT: SqlState = SqlState(Inner::E22032); + + /// 22033 + pub const INVALID_SQL_JSON_SUBSCRIPT: SqlState = SqlState(Inner::E22033); + + /// 22034 + pub const MORE_THAN_ONE_SQL_JSON_ITEM: SqlState = SqlState(Inner::E22034); + + /// 22035 + pub const NO_SQL_JSON_ITEM: SqlState = SqlState(Inner::E22035); + + /// 22036 + pub const NON_NUMERIC_SQL_JSON_ITEM: SqlState = SqlState(Inner::E22036); + + /// 22037 + pub const NON_UNIQUE_KEYS_IN_A_JSON_OBJECT: SqlState = SqlState(Inner::E22037); + + /// 22038 + pub const SINGLETON_SQL_JSON_ITEM_REQUIRED: SqlState = SqlState(Inner::E22038); + + /// 22039 + pub const SQL_JSON_ARRAY_NOT_FOUND: SqlState = SqlState(Inner::E22039); + + /// 2203A + pub const SQL_JSON_MEMBER_NOT_FOUND: SqlState = SqlState(Inner::E2203A); + + /// 2203B + pub const SQL_JSON_NUMBER_NOT_FOUND: SqlState = SqlState(Inner::E2203B); + + /// 2203C + pub const SQL_JSON_OBJECT_NOT_FOUND: SqlState = SqlState(Inner::E2203C); + + /// 2203D + pub const TOO_MANY_JSON_ARRAY_ELEMENTS: SqlState = SqlState(Inner::E2203D); + + /// 2203E + pub const TOO_MANY_JSON_OBJECT_MEMBERS: SqlState = SqlState(Inner::E2203E); + + /// 2203F + pub const SQL_JSON_SCALAR_REQUIRED: SqlState = SqlState(Inner::E2203F); + + /// 2203G + pub const SQL_JSON_ITEM_CANNOT_BE_CAST_TO_TARGET_TYPE: SqlState = SqlState(Inner::E2203G); /// 23000 - pub const INTEGRITY_CONSTRAINT_VIOLATION: SqlState = SqlState(Cow::Borrowed("23000")); + pub const INTEGRITY_CONSTRAINT_VIOLATION: SqlState = SqlState(Inner::E23000); /// 23001 - pub const RESTRICT_VIOLATION: SqlState = SqlState(Cow::Borrowed("23001")); + pub const RESTRICT_VIOLATION: SqlState = SqlState(Inner::E23001); /// 23502 - pub const NOT_NULL_VIOLATION: SqlState = SqlState(Cow::Borrowed("23502")); + pub const NOT_NULL_VIOLATION: SqlState = SqlState(Inner::E23502); /// 23503 - pub const FOREIGN_KEY_VIOLATION: SqlState = SqlState(Cow::Borrowed("23503")); + pub const FOREIGN_KEY_VIOLATION: SqlState = SqlState(Inner::E23503); /// 23505 - pub const UNIQUE_VIOLATION: SqlState = SqlState(Cow::Borrowed("23505")); + pub const UNIQUE_VIOLATION: SqlState = SqlState(Inner::E23505); /// 23514 - pub const CHECK_VIOLATION: SqlState = SqlState(Cow::Borrowed("23514")); + pub const CHECK_VIOLATION: SqlState = SqlState(Inner::E23514); /// 23P01 - pub const EXCLUSION_VIOLATION: SqlState = SqlState(Cow::Borrowed("23P01")); + pub const EXCLUSION_VIOLATION: SqlState = SqlState(Inner::E23P01); /// 24000 - pub const INVALID_CURSOR_STATE: SqlState = SqlState(Cow::Borrowed("24000")); + pub const INVALID_CURSOR_STATE: SqlState = SqlState(Inner::E24000); /// 25000 - pub const INVALID_TRANSACTION_STATE: SqlState = SqlState(Cow::Borrowed("25000")); + pub const INVALID_TRANSACTION_STATE: SqlState = SqlState(Inner::E25000); /// 25001 - pub const ACTIVE_SQL_TRANSACTION: SqlState = SqlState(Cow::Borrowed("25001")); + pub const ACTIVE_SQL_TRANSACTION: SqlState = SqlState(Inner::E25001); /// 25002 - pub const BRANCH_TRANSACTION_ALREADY_ACTIVE: SqlState = SqlState(Cow::Borrowed("25002")); + pub const BRANCH_TRANSACTION_ALREADY_ACTIVE: SqlState = SqlState(Inner::E25002); /// 25008 - pub const HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL: SqlState = SqlState(Cow::Borrowed("25008")); + pub const HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL: SqlState = SqlState(Inner::E25008); /// 25003 - pub const INAPPROPRIATE_ACCESS_MODE_FOR_BRANCH_TRANSACTION: SqlState = SqlState(Cow::Borrowed("25003")); + pub const INAPPROPRIATE_ACCESS_MODE_FOR_BRANCH_TRANSACTION: SqlState = SqlState(Inner::E25003); /// 25004 - pub const INAPPROPRIATE_ISOLATION_LEVEL_FOR_BRANCH_TRANSACTION: SqlState = SqlState(Cow::Borrowed("25004")); + pub const INAPPROPRIATE_ISOLATION_LEVEL_FOR_BRANCH_TRANSACTION: SqlState = + SqlState(Inner::E25004); /// 25005 - pub const NO_ACTIVE_SQL_TRANSACTION_FOR_BRANCH_TRANSACTION: SqlState = SqlState(Cow::Borrowed("25005")); + pub const NO_ACTIVE_SQL_TRANSACTION_FOR_BRANCH_TRANSACTION: SqlState = SqlState(Inner::E25005); /// 25006 - pub const READ_ONLY_SQL_TRANSACTION: SqlState = SqlState(Cow::Borrowed("25006")); + pub const READ_ONLY_SQL_TRANSACTION: SqlState = SqlState(Inner::E25006); /// 25007 - pub const SCHEMA_AND_DATA_STATEMENT_MIXING_NOT_SUPPORTED: SqlState = SqlState(Cow::Borrowed("25007")); + pub const SCHEMA_AND_DATA_STATEMENT_MIXING_NOT_SUPPORTED: SqlState = SqlState(Inner::E25007); /// 25P01 - pub const NO_ACTIVE_SQL_TRANSACTION: SqlState = SqlState(Cow::Borrowed("25P01")); + pub const NO_ACTIVE_SQL_TRANSACTION: SqlState = SqlState(Inner::E25P01); /// 25P02 - pub const IN_FAILED_SQL_TRANSACTION: SqlState = SqlState(Cow::Borrowed("25P02")); + pub const IN_FAILED_SQL_TRANSACTION: SqlState = SqlState(Inner::E25P02); /// 25P03 - pub const IDLE_IN_TRANSACTION_SESSION_TIMEOUT: SqlState = SqlState(Cow::Borrowed("25P03")); + pub const IDLE_IN_TRANSACTION_SESSION_TIMEOUT: SqlState = SqlState(Inner::E25P03); /// 26000 - pub const INVALID_SQL_STATEMENT_NAME: SqlState = SqlState(Cow::Borrowed("26000")); + pub const INVALID_SQL_STATEMENT_NAME: SqlState = SqlState(Inner::E26000); /// 26000 - pub const UNDEFINED_PSTATEMENT: SqlState = SqlState(Cow::Borrowed("26000")); + pub const UNDEFINED_PSTATEMENT: SqlState = SqlState(Inner::E26000); /// 27000 - pub const TRIGGERED_DATA_CHANGE_VIOLATION: SqlState = SqlState(Cow::Borrowed("27000")); + pub const TRIGGERED_DATA_CHANGE_VIOLATION: SqlState = SqlState(Inner::E27000); /// 28000 - pub const INVALID_AUTHORIZATION_SPECIFICATION: SqlState = SqlState(Cow::Borrowed("28000")); + pub const INVALID_AUTHORIZATION_SPECIFICATION: SqlState = SqlState(Inner::E28000); /// 28P01 - pub const INVALID_PASSWORD: SqlState = SqlState(Cow::Borrowed("28P01")); + pub const INVALID_PASSWORD: SqlState = SqlState(Inner::E28P01); /// 2B000 - pub const DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST: SqlState = SqlState(Cow::Borrowed("2B000")); + pub const DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST: SqlState = SqlState(Inner::E2B000); /// 2BP01 - pub const DEPENDENT_OBJECTS_STILL_EXIST: SqlState = SqlState(Cow::Borrowed("2BP01")); + pub const DEPENDENT_OBJECTS_STILL_EXIST: SqlState = SqlState(Inner::E2BP01); /// 2D000 - pub const INVALID_TRANSACTION_TERMINATION: SqlState = SqlState(Cow::Borrowed("2D000")); + pub const INVALID_TRANSACTION_TERMINATION: SqlState = SqlState(Inner::E2D000); /// 2F000 - pub const SQL_ROUTINE_EXCEPTION: SqlState = SqlState(Cow::Borrowed("2F000")); + pub const SQL_ROUTINE_EXCEPTION: SqlState = SqlState(Inner::E2F000); /// 2F005 - pub const S_R_E_FUNCTION_EXECUTED_NO_RETURN_STATEMENT: SqlState = SqlState(Cow::Borrowed("2F005")); + pub const S_R_E_FUNCTION_EXECUTED_NO_RETURN_STATEMENT: SqlState = SqlState(Inner::E2F005); /// 2F002 - pub const S_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Cow::Borrowed("2F002")); + pub const S_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Inner::E2F002); /// 2F003 - pub const S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED: SqlState = SqlState(Cow::Borrowed("2F003")); + pub const S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED: SqlState = SqlState(Inner::E2F003); /// 2F004 - pub const S_R_E_READING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Cow::Borrowed("2F004")); + pub const S_R_E_READING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Inner::E2F004); /// 34000 - pub const INVALID_CURSOR_NAME: SqlState = SqlState(Cow::Borrowed("34000")); + pub const INVALID_CURSOR_NAME: SqlState = SqlState(Inner::E34000); /// 34000 - pub const UNDEFINED_CURSOR: SqlState = SqlState(Cow::Borrowed("34000")); + pub const UNDEFINED_CURSOR: SqlState = SqlState(Inner::E34000); /// 38000 - pub const EXTERNAL_ROUTINE_EXCEPTION: SqlState = SqlState(Cow::Borrowed("38000")); + pub const EXTERNAL_ROUTINE_EXCEPTION: SqlState = SqlState(Inner::E38000); /// 38001 - pub const E_R_E_CONTAINING_SQL_NOT_PERMITTED: SqlState = SqlState(Cow::Borrowed("38001")); + pub const E_R_E_CONTAINING_SQL_NOT_PERMITTED: SqlState = SqlState(Inner::E38001); /// 38002 - pub const E_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Cow::Borrowed("38002")); + pub const E_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Inner::E38002); /// 38003 - pub const E_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED: SqlState = SqlState(Cow::Borrowed("38003")); + pub const E_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED: SqlState = SqlState(Inner::E38003); /// 38004 - pub const E_R_E_READING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Cow::Borrowed("38004")); + pub const E_R_E_READING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Inner::E38004); /// 39000 - pub const EXTERNAL_ROUTINE_INVOCATION_EXCEPTION: SqlState = SqlState(Cow::Borrowed("39000")); + pub const EXTERNAL_ROUTINE_INVOCATION_EXCEPTION: SqlState = SqlState(Inner::E39000); /// 39001 - pub const E_R_I_E_INVALID_SQLSTATE_RETURNED: SqlState = SqlState(Cow::Borrowed("39001")); + pub const E_R_I_E_INVALID_SQLSTATE_RETURNED: SqlState = SqlState(Inner::E39001); /// 39004 - pub const E_R_I_E_NULL_VALUE_NOT_ALLOWED: SqlState = SqlState(Cow::Borrowed("39004")); + pub const E_R_I_E_NULL_VALUE_NOT_ALLOWED: SqlState = SqlState(Inner::E39004); /// 39P01 - pub const E_R_I_E_TRIGGER_PROTOCOL_VIOLATED: SqlState = SqlState(Cow::Borrowed("39P01")); + pub const E_R_I_E_TRIGGER_PROTOCOL_VIOLATED: SqlState = SqlState(Inner::E39P01); /// 39P02 - pub const E_R_I_E_SRF_PROTOCOL_VIOLATED: SqlState = SqlState(Cow::Borrowed("39P02")); + pub const E_R_I_E_SRF_PROTOCOL_VIOLATED: SqlState = SqlState(Inner::E39P02); /// 39P03 - pub const E_R_I_E_EVENT_TRIGGER_PROTOCOL_VIOLATED: SqlState = SqlState(Cow::Borrowed("39P03")); + pub const E_R_I_E_EVENT_TRIGGER_PROTOCOL_VIOLATED: SqlState = SqlState(Inner::E39P03); /// 3B000 - pub const SAVEPOINT_EXCEPTION: SqlState = SqlState(Cow::Borrowed("3B000")); + pub const SAVEPOINT_EXCEPTION: SqlState = SqlState(Inner::E3B000); /// 3B001 - pub const S_E_INVALID_SPECIFICATION: SqlState = SqlState(Cow::Borrowed("3B001")); + pub const S_E_INVALID_SPECIFICATION: SqlState = SqlState(Inner::E3B001); /// 3D000 - pub const INVALID_CATALOG_NAME: SqlState = SqlState(Cow::Borrowed("3D000")); + pub const INVALID_CATALOG_NAME: SqlState = SqlState(Inner::E3D000); /// 3D000 - pub const UNDEFINED_DATABASE: SqlState = SqlState(Cow::Borrowed("3D000")); + pub const UNDEFINED_DATABASE: SqlState = SqlState(Inner::E3D000); /// 3F000 - pub const INVALID_SCHEMA_NAME: SqlState = SqlState(Cow::Borrowed("3F000")); + pub const INVALID_SCHEMA_NAME: SqlState = SqlState(Inner::E3F000); /// 3F000 - pub const UNDEFINED_SCHEMA: SqlState = SqlState(Cow::Borrowed("3F000")); + pub const UNDEFINED_SCHEMA: SqlState = SqlState(Inner::E3F000); /// 40000 - pub const TRANSACTION_ROLLBACK: SqlState = SqlState(Cow::Borrowed("40000")); + pub const TRANSACTION_ROLLBACK: SqlState = SqlState(Inner::E40000); /// 40002 - pub const T_R_INTEGRITY_CONSTRAINT_VIOLATION: SqlState = SqlState(Cow::Borrowed("40002")); + pub const T_R_INTEGRITY_CONSTRAINT_VIOLATION: SqlState = SqlState(Inner::E40002); /// 40001 - pub const T_R_SERIALIZATION_FAILURE: SqlState = SqlState(Cow::Borrowed("40001")); + pub const T_R_SERIALIZATION_FAILURE: SqlState = SqlState(Inner::E40001); /// 40003 - pub const T_R_STATEMENT_COMPLETION_UNKNOWN: SqlState = SqlState(Cow::Borrowed("40003")); + pub const T_R_STATEMENT_COMPLETION_UNKNOWN: SqlState = SqlState(Inner::E40003); /// 40P01 - pub const T_R_DEADLOCK_DETECTED: SqlState = SqlState(Cow::Borrowed("40P01")); + pub const T_R_DEADLOCK_DETECTED: SqlState = SqlState(Inner::E40P01); /// 42000 - pub const SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION: SqlState = SqlState(Cow::Borrowed("42000")); + pub const SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION: SqlState = SqlState(Inner::E42000); /// 42601 - pub const SYNTAX_ERROR: SqlState = SqlState(Cow::Borrowed("42601")); + pub const SYNTAX_ERROR: SqlState = SqlState(Inner::E42601); /// 42501 - pub const INSUFFICIENT_PRIVILEGE: SqlState = SqlState(Cow::Borrowed("42501")); + pub const INSUFFICIENT_PRIVILEGE: SqlState = SqlState(Inner::E42501); /// 42846 - pub const CANNOT_COERCE: SqlState = SqlState(Cow::Borrowed("42846")); + pub const CANNOT_COERCE: SqlState = SqlState(Inner::E42846); /// 42803 - pub const GROUPING_ERROR: SqlState = SqlState(Cow::Borrowed("42803")); + pub const GROUPING_ERROR: SqlState = SqlState(Inner::E42803); /// 42P20 - pub const WINDOWING_ERROR: SqlState = SqlState(Cow::Borrowed("42P20")); + pub const WINDOWING_ERROR: SqlState = SqlState(Inner::E42P20); /// 42P19 - pub const INVALID_RECURSION: SqlState = SqlState(Cow::Borrowed("42P19")); + pub const INVALID_RECURSION: SqlState = SqlState(Inner::E42P19); /// 42830 - pub const INVALID_FOREIGN_KEY: SqlState = SqlState(Cow::Borrowed("42830")); + pub const INVALID_FOREIGN_KEY: SqlState = SqlState(Inner::E42830); /// 42602 - pub const INVALID_NAME: SqlState = SqlState(Cow::Borrowed("42602")); + pub const INVALID_NAME: SqlState = SqlState(Inner::E42602); /// 42622 - pub const NAME_TOO_LONG: SqlState = SqlState(Cow::Borrowed("42622")); + pub const NAME_TOO_LONG: SqlState = SqlState(Inner::E42622); /// 42939 - pub const RESERVED_NAME: SqlState = SqlState(Cow::Borrowed("42939")); + pub const RESERVED_NAME: SqlState = SqlState(Inner::E42939); /// 42804 - pub const DATATYPE_MISMATCH: SqlState = SqlState(Cow::Borrowed("42804")); + pub const DATATYPE_MISMATCH: SqlState = SqlState(Inner::E42804); /// 42P18 - pub const INDETERMINATE_DATATYPE: SqlState = SqlState(Cow::Borrowed("42P18")); + pub const INDETERMINATE_DATATYPE: SqlState = SqlState(Inner::E42P18); /// 42P21 - pub const COLLATION_MISMATCH: SqlState = SqlState(Cow::Borrowed("42P21")); + pub const COLLATION_MISMATCH: SqlState = SqlState(Inner::E42P21); /// 42P22 - pub const INDETERMINATE_COLLATION: SqlState = SqlState(Cow::Borrowed("42P22")); + pub const INDETERMINATE_COLLATION: SqlState = SqlState(Inner::E42P22); /// 42809 - pub const WRONG_OBJECT_TYPE: SqlState = SqlState(Cow::Borrowed("42809")); + pub const WRONG_OBJECT_TYPE: SqlState = SqlState(Inner::E42809); /// 428C9 - pub const GENERATED_ALWAYS: SqlState = SqlState(Cow::Borrowed("428C9")); + pub const GENERATED_ALWAYS: SqlState = SqlState(Inner::E428C9); /// 42703 - pub const UNDEFINED_COLUMN: SqlState = SqlState(Cow::Borrowed("42703")); + pub const UNDEFINED_COLUMN: SqlState = SqlState(Inner::E42703); /// 42883 - pub const UNDEFINED_FUNCTION: SqlState = SqlState(Cow::Borrowed("42883")); + pub const UNDEFINED_FUNCTION: SqlState = SqlState(Inner::E42883); /// 42P01 - pub const UNDEFINED_TABLE: SqlState = SqlState(Cow::Borrowed("42P01")); + pub const UNDEFINED_TABLE: SqlState = SqlState(Inner::E42P01); /// 42P02 - pub const UNDEFINED_PARAMETER: SqlState = SqlState(Cow::Borrowed("42P02")); + pub const UNDEFINED_PARAMETER: SqlState = SqlState(Inner::E42P02); /// 42704 - pub const UNDEFINED_OBJECT: SqlState = SqlState(Cow::Borrowed("42704")); + pub const UNDEFINED_OBJECT: SqlState = SqlState(Inner::E42704); /// 42701 - pub const DUPLICATE_COLUMN: SqlState = SqlState(Cow::Borrowed("42701")); + pub const DUPLICATE_COLUMN: SqlState = SqlState(Inner::E42701); /// 42P03 - pub const DUPLICATE_CURSOR: SqlState = SqlState(Cow::Borrowed("42P03")); + pub const DUPLICATE_CURSOR: SqlState = SqlState(Inner::E42P03); /// 42P04 - pub const DUPLICATE_DATABASE: SqlState = SqlState(Cow::Borrowed("42P04")); + pub const DUPLICATE_DATABASE: SqlState = SqlState(Inner::E42P04); /// 42723 - pub const DUPLICATE_FUNCTION: SqlState = SqlState(Cow::Borrowed("42723")); + pub const DUPLICATE_FUNCTION: SqlState = SqlState(Inner::E42723); /// 42P05 - pub const DUPLICATE_PSTATEMENT: SqlState = SqlState(Cow::Borrowed("42P05")); + pub const DUPLICATE_PSTATEMENT: SqlState = SqlState(Inner::E42P05); /// 42P06 - pub const DUPLICATE_SCHEMA: SqlState = SqlState(Cow::Borrowed("42P06")); + pub const DUPLICATE_SCHEMA: SqlState = SqlState(Inner::E42P06); /// 42P07 - pub const DUPLICATE_TABLE: SqlState = SqlState(Cow::Borrowed("42P07")); + pub const DUPLICATE_TABLE: SqlState = SqlState(Inner::E42P07); /// 42712 - pub const DUPLICATE_ALIAS: SqlState = SqlState(Cow::Borrowed("42712")); + pub const DUPLICATE_ALIAS: SqlState = SqlState(Inner::E42712); /// 42710 - pub const DUPLICATE_OBJECT: SqlState = SqlState(Cow::Borrowed("42710")); + pub const DUPLICATE_OBJECT: SqlState = SqlState(Inner::E42710); /// 42702 - pub const AMBIGUOUS_COLUMN: SqlState = SqlState(Cow::Borrowed("42702")); + pub const AMBIGUOUS_COLUMN: SqlState = SqlState(Inner::E42702); /// 42725 - pub const AMBIGUOUS_FUNCTION: SqlState = SqlState(Cow::Borrowed("42725")); + pub const AMBIGUOUS_FUNCTION: SqlState = SqlState(Inner::E42725); /// 42P08 - pub const AMBIGUOUS_PARAMETER: SqlState = SqlState(Cow::Borrowed("42P08")); + pub const AMBIGUOUS_PARAMETER: SqlState = SqlState(Inner::E42P08); /// 42P09 - pub const AMBIGUOUS_ALIAS: SqlState = SqlState(Cow::Borrowed("42P09")); + pub const AMBIGUOUS_ALIAS: SqlState = SqlState(Inner::E42P09); /// 42P10 - pub const INVALID_COLUMN_REFERENCE: SqlState = SqlState(Cow::Borrowed("42P10")); + pub const INVALID_COLUMN_REFERENCE: SqlState = SqlState(Inner::E42P10); /// 42611 - pub const INVALID_COLUMN_DEFINITION: SqlState = SqlState(Cow::Borrowed("42611")); + pub const INVALID_COLUMN_DEFINITION: SqlState = SqlState(Inner::E42611); /// 42P11 - pub const INVALID_CURSOR_DEFINITION: SqlState = SqlState(Cow::Borrowed("42P11")); + pub const INVALID_CURSOR_DEFINITION: SqlState = SqlState(Inner::E42P11); /// 42P12 - pub const INVALID_DATABASE_DEFINITION: SqlState = SqlState(Cow::Borrowed("42P12")); + pub const INVALID_DATABASE_DEFINITION: SqlState = SqlState(Inner::E42P12); /// 42P13 - pub const INVALID_FUNCTION_DEFINITION: SqlState = SqlState(Cow::Borrowed("42P13")); + pub const INVALID_FUNCTION_DEFINITION: SqlState = SqlState(Inner::E42P13); /// 42P14 - pub const INVALID_PSTATEMENT_DEFINITION: SqlState = SqlState(Cow::Borrowed("42P14")); + pub const INVALID_PSTATEMENT_DEFINITION: SqlState = SqlState(Inner::E42P14); /// 42P15 - pub const INVALID_SCHEMA_DEFINITION: SqlState = SqlState(Cow::Borrowed("42P15")); + pub const INVALID_SCHEMA_DEFINITION: SqlState = SqlState(Inner::E42P15); /// 42P16 - pub const INVALID_TABLE_DEFINITION: SqlState = SqlState(Cow::Borrowed("42P16")); + pub const INVALID_TABLE_DEFINITION: SqlState = SqlState(Inner::E42P16); /// 42P17 - pub const INVALID_OBJECT_DEFINITION: SqlState = SqlState(Cow::Borrowed("42P17")); + pub const INVALID_OBJECT_DEFINITION: SqlState = SqlState(Inner::E42P17); /// 44000 - pub const WITH_CHECK_OPTION_VIOLATION: SqlState = SqlState(Cow::Borrowed("44000")); + pub const WITH_CHECK_OPTION_VIOLATION: SqlState = SqlState(Inner::E44000); /// 53000 - pub const INSUFFICIENT_RESOURCES: SqlState = SqlState(Cow::Borrowed("53000")); + pub const INSUFFICIENT_RESOURCES: SqlState = SqlState(Inner::E53000); /// 53100 - pub const DISK_FULL: SqlState = SqlState(Cow::Borrowed("53100")); + pub const DISK_FULL: SqlState = SqlState(Inner::E53100); /// 53200 - pub const OUT_OF_MEMORY: SqlState = SqlState(Cow::Borrowed("53200")); + pub const OUT_OF_MEMORY: SqlState = SqlState(Inner::E53200); /// 53300 - pub const TOO_MANY_CONNECTIONS: SqlState = SqlState(Cow::Borrowed("53300")); + pub const TOO_MANY_CONNECTIONS: SqlState = SqlState(Inner::E53300); /// 53400 - pub const CONFIGURATION_LIMIT_EXCEEDED: SqlState = SqlState(Cow::Borrowed("53400")); + pub const CONFIGURATION_LIMIT_EXCEEDED: SqlState = SqlState(Inner::E53400); /// 54000 - pub const PROGRAM_LIMIT_EXCEEDED: SqlState = SqlState(Cow::Borrowed("54000")); + pub const PROGRAM_LIMIT_EXCEEDED: SqlState = SqlState(Inner::E54000); /// 54001 - pub const STATEMENT_TOO_COMPLEX: SqlState = SqlState(Cow::Borrowed("54001")); + pub const STATEMENT_TOO_COMPLEX: SqlState = SqlState(Inner::E54001); /// 54011 - pub const TOO_MANY_COLUMNS: SqlState = SqlState(Cow::Borrowed("54011")); + pub const TOO_MANY_COLUMNS: SqlState = SqlState(Inner::E54011); /// 54023 - pub const TOO_MANY_ARGUMENTS: SqlState = SqlState(Cow::Borrowed("54023")); + pub const TOO_MANY_ARGUMENTS: SqlState = SqlState(Inner::E54023); /// 55000 - pub const OBJECT_NOT_IN_PREREQUISITE_STATE: SqlState = SqlState(Cow::Borrowed("55000")); + pub const OBJECT_NOT_IN_PREREQUISITE_STATE: SqlState = SqlState(Inner::E55000); /// 55006 - pub const OBJECT_IN_USE: SqlState = SqlState(Cow::Borrowed("55006")); + pub const OBJECT_IN_USE: SqlState = SqlState(Inner::E55006); /// 55P02 - pub const CANT_CHANGE_RUNTIME_PARAM: SqlState = SqlState(Cow::Borrowed("55P02")); + pub const CANT_CHANGE_RUNTIME_PARAM: SqlState = SqlState(Inner::E55P02); /// 55P03 - pub const LOCK_NOT_AVAILABLE: SqlState = SqlState(Cow::Borrowed("55P03")); + pub const LOCK_NOT_AVAILABLE: SqlState = SqlState(Inner::E55P03); /// 55P04 - pub const UNSAFE_NEW_ENUM_VALUE_USAGE: SqlState = SqlState(Cow::Borrowed("55P04")); + pub const UNSAFE_NEW_ENUM_VALUE_USAGE: SqlState = SqlState(Inner::E55P04); /// 57000 - pub const OPERATOR_INTERVENTION: SqlState = SqlState(Cow::Borrowed("57000")); + pub const OPERATOR_INTERVENTION: SqlState = SqlState(Inner::E57000); /// 57014 - pub const QUERY_CANCELED: SqlState = SqlState(Cow::Borrowed("57014")); + pub const QUERY_CANCELED: SqlState = SqlState(Inner::E57014); /// 57P01 - pub const ADMIN_SHUTDOWN: SqlState = SqlState(Cow::Borrowed("57P01")); + pub const ADMIN_SHUTDOWN: SqlState = SqlState(Inner::E57P01); /// 57P02 - pub const CRASH_SHUTDOWN: SqlState = SqlState(Cow::Borrowed("57P02")); + pub const CRASH_SHUTDOWN: SqlState = SqlState(Inner::E57P02); /// 57P03 - pub const CANNOT_CONNECT_NOW: SqlState = SqlState(Cow::Borrowed("57P03")); + pub const CANNOT_CONNECT_NOW: SqlState = SqlState(Inner::E57P03); /// 57P04 - pub const DATABASE_DROPPED: SqlState = SqlState(Cow::Borrowed("57P04")); + pub const DATABASE_DROPPED: SqlState = SqlState(Inner::E57P04); + + /// 57P05 + pub const IDLE_SESSION_TIMEOUT: SqlState = SqlState(Inner::E57P05); /// 58000 - pub const SYSTEM_ERROR: SqlState = SqlState(Cow::Borrowed("58000")); + pub const SYSTEM_ERROR: SqlState = SqlState(Inner::E58000); /// 58030 - pub const IO_ERROR: SqlState = SqlState(Cow::Borrowed("58030")); + pub const IO_ERROR: SqlState = SqlState(Inner::E58030); /// 58P01 - pub const UNDEFINED_FILE: SqlState = SqlState(Cow::Borrowed("58P01")); + pub const UNDEFINED_FILE: SqlState = SqlState(Inner::E58P01); /// 58P02 - pub const DUPLICATE_FILE: SqlState = SqlState(Cow::Borrowed("58P02")); + pub const DUPLICATE_FILE: SqlState = SqlState(Inner::E58P02); /// 72000 - pub const SNAPSHOT_TOO_OLD: SqlState = SqlState(Cow::Borrowed("72000")); + pub const SNAPSHOT_TOO_OLD: SqlState = SqlState(Inner::E72000); /// F0000 - pub const CONFIG_FILE_ERROR: SqlState = SqlState(Cow::Borrowed("F0000")); + pub const CONFIG_FILE_ERROR: SqlState = SqlState(Inner::EF0000); /// F0001 - pub const LOCK_FILE_EXISTS: SqlState = SqlState(Cow::Borrowed("F0001")); + pub const LOCK_FILE_EXISTS: SqlState = SqlState(Inner::EF0001); /// HV000 - pub const FDW_ERROR: SqlState = SqlState(Cow::Borrowed("HV000")); + pub const FDW_ERROR: SqlState = SqlState(Inner::EHV000); /// HV005 - pub const FDW_COLUMN_NAME_NOT_FOUND: SqlState = SqlState(Cow::Borrowed("HV005")); + pub const FDW_COLUMN_NAME_NOT_FOUND: SqlState = SqlState(Inner::EHV005); /// HV002 - pub const FDW_DYNAMIC_PARAMETER_VALUE_NEEDED: SqlState = SqlState(Cow::Borrowed("HV002")); + pub const FDW_DYNAMIC_PARAMETER_VALUE_NEEDED: SqlState = SqlState(Inner::EHV002); /// HV010 - pub const FDW_FUNCTION_SEQUENCE_ERROR: SqlState = SqlState(Cow::Borrowed("HV010")); + pub const FDW_FUNCTION_SEQUENCE_ERROR: SqlState = SqlState(Inner::EHV010); /// HV021 - pub const FDW_INCONSISTENT_DESCRIPTOR_INFORMATION: SqlState = SqlState(Cow::Borrowed("HV021")); + pub const FDW_INCONSISTENT_DESCRIPTOR_INFORMATION: SqlState = SqlState(Inner::EHV021); /// HV024 - pub const FDW_INVALID_ATTRIBUTE_VALUE: SqlState = SqlState(Cow::Borrowed("HV024")); + pub const FDW_INVALID_ATTRIBUTE_VALUE: SqlState = SqlState(Inner::EHV024); /// HV007 - pub const FDW_INVALID_COLUMN_NAME: SqlState = SqlState(Cow::Borrowed("HV007")); + pub const FDW_INVALID_COLUMN_NAME: SqlState = SqlState(Inner::EHV007); /// HV008 - pub const FDW_INVALID_COLUMN_NUMBER: SqlState = SqlState(Cow::Borrowed("HV008")); + pub const FDW_INVALID_COLUMN_NUMBER: SqlState = SqlState(Inner::EHV008); /// HV004 - pub const FDW_INVALID_DATA_TYPE: SqlState = SqlState(Cow::Borrowed("HV004")); + pub const FDW_INVALID_DATA_TYPE: SqlState = SqlState(Inner::EHV004); /// HV006 - pub const FDW_INVALID_DATA_TYPE_DESCRIPTORS: SqlState = SqlState(Cow::Borrowed("HV006")); + pub const FDW_INVALID_DATA_TYPE_DESCRIPTORS: SqlState = SqlState(Inner::EHV006); /// HV091 - pub const FDW_INVALID_DESCRIPTOR_FIELD_IDENTIFIER: SqlState = SqlState(Cow::Borrowed("HV091")); + pub const FDW_INVALID_DESCRIPTOR_FIELD_IDENTIFIER: SqlState = SqlState(Inner::EHV091); /// HV00B - pub const FDW_INVALID_HANDLE: SqlState = SqlState(Cow::Borrowed("HV00B")); + pub const FDW_INVALID_HANDLE: SqlState = SqlState(Inner::EHV00B); /// HV00C - pub const FDW_INVALID_OPTION_INDEX: SqlState = SqlState(Cow::Borrowed("HV00C")); + pub const FDW_INVALID_OPTION_INDEX: SqlState = SqlState(Inner::EHV00C); /// HV00D - pub const FDW_INVALID_OPTION_NAME: SqlState = SqlState(Cow::Borrowed("HV00D")); + pub const FDW_INVALID_OPTION_NAME: SqlState = SqlState(Inner::EHV00D); /// HV090 - pub const FDW_INVALID_STRING_LENGTH_OR_BUFFER_LENGTH: SqlState = SqlState(Cow::Borrowed("HV090")); + pub const FDW_INVALID_STRING_LENGTH_OR_BUFFER_LENGTH: SqlState = SqlState(Inner::EHV090); /// HV00A - pub const FDW_INVALID_STRING_FORMAT: SqlState = SqlState(Cow::Borrowed("HV00A")); + pub const FDW_INVALID_STRING_FORMAT: SqlState = SqlState(Inner::EHV00A); /// HV009 - pub const FDW_INVALID_USE_OF_NULL_POINTER: SqlState = SqlState(Cow::Borrowed("HV009")); + pub const FDW_INVALID_USE_OF_NULL_POINTER: SqlState = SqlState(Inner::EHV009); /// HV014 - pub const FDW_TOO_MANY_HANDLES: SqlState = SqlState(Cow::Borrowed("HV014")); + pub const FDW_TOO_MANY_HANDLES: SqlState = SqlState(Inner::EHV014); /// HV001 - pub const FDW_OUT_OF_MEMORY: SqlState = SqlState(Cow::Borrowed("HV001")); + pub const FDW_OUT_OF_MEMORY: SqlState = SqlState(Inner::EHV001); /// HV00P - pub const FDW_NO_SCHEMAS: SqlState = SqlState(Cow::Borrowed("HV00P")); + pub const FDW_NO_SCHEMAS: SqlState = SqlState(Inner::EHV00P); /// HV00J - pub const FDW_OPTION_NAME_NOT_FOUND: SqlState = SqlState(Cow::Borrowed("HV00J")); + pub const FDW_OPTION_NAME_NOT_FOUND: SqlState = SqlState(Inner::EHV00J); /// HV00K - pub const FDW_REPLY_HANDLE: SqlState = SqlState(Cow::Borrowed("HV00K")); + pub const FDW_REPLY_HANDLE: SqlState = SqlState(Inner::EHV00K); /// HV00Q - pub const FDW_SCHEMA_NOT_FOUND: SqlState = SqlState(Cow::Borrowed("HV00Q")); + pub const FDW_SCHEMA_NOT_FOUND: SqlState = SqlState(Inner::EHV00Q); /// HV00R - pub const FDW_TABLE_NOT_FOUND: SqlState = SqlState(Cow::Borrowed("HV00R")); + pub const FDW_TABLE_NOT_FOUND: SqlState = SqlState(Inner::EHV00R); /// HV00L - pub const FDW_UNABLE_TO_CREATE_EXECUTION: SqlState = SqlState(Cow::Borrowed("HV00L")); + pub const FDW_UNABLE_TO_CREATE_EXECUTION: SqlState = SqlState(Inner::EHV00L); /// HV00M - pub const FDW_UNABLE_TO_CREATE_REPLY: SqlState = SqlState(Cow::Borrowed("HV00M")); + pub const FDW_UNABLE_TO_CREATE_REPLY: SqlState = SqlState(Inner::EHV00M); /// HV00N - pub const FDW_UNABLE_TO_ESTABLISH_CONNECTION: SqlState = SqlState(Cow::Borrowed("HV00N")); + pub const FDW_UNABLE_TO_ESTABLISH_CONNECTION: SqlState = SqlState(Inner::EHV00N); /// P0000 - pub const PLPGSQL_ERROR: SqlState = SqlState(Cow::Borrowed("P0000")); + pub const PLPGSQL_ERROR: SqlState = SqlState(Inner::EP0000); /// P0001 - pub const RAISE_EXCEPTION: SqlState = SqlState(Cow::Borrowed("P0001")); + pub const RAISE_EXCEPTION: SqlState = SqlState(Inner::EP0001); /// P0002 - pub const NO_DATA_FOUND: SqlState = SqlState(Cow::Borrowed("P0002")); + pub const NO_DATA_FOUND: SqlState = SqlState(Inner::EP0002); /// P0003 - pub const TOO_MANY_ROWS: SqlState = SqlState(Cow::Borrowed("P0003")); + pub const TOO_MANY_ROWS: SqlState = SqlState(Inner::EP0003); /// P0004 - pub const ASSERT_FAILURE: SqlState = SqlState(Cow::Borrowed("P0004")); + pub const ASSERT_FAILURE: SqlState = SqlState(Inner::EP0004); /// XX000 - pub const INTERNAL_ERROR: SqlState = SqlState(Cow::Borrowed("XX000")); + pub const INTERNAL_ERROR: SqlState = SqlState(Inner::EXX000); /// XX001 - pub const DATA_CORRUPTED: SqlState = SqlState(Cow::Borrowed("XX001")); + pub const DATA_CORRUPTED: SqlState = SqlState(Inner::EXX001); /// XX002 - pub const INDEX_CORRUPTED: SqlState = SqlState(Cow::Borrowed("XX002")); + pub const INDEX_CORRUPTED: SqlState = SqlState(Inner::EXX002); +} + +#[derive(PartialEq, Eq, Clone, Debug)] +#[allow(clippy::upper_case_acronyms)] +enum Inner { + E00000, + E01000, + E0100C, + E01008, + E01003, + E01007, + E01006, + E01004, + E01P01, + E02000, + E02001, + E03000, + E08000, + E08003, + E08006, + E08001, + E08004, + E08007, + E08P01, + E09000, + E0A000, + E0B000, + E0F000, + E0F001, + E0L000, + E0LP01, + E0P000, + E0Z000, + E0Z002, + E20000, + E21000, + E22000, + E2202E, + E22021, + E22008, + E22012, + E22005, + E2200B, + E22022, + E22015, + E2201E, + E22014, + E22016, + E2201F, + E2201G, + E22018, + E22007, + E22019, + E2200D, + E22025, + E22P06, + E22010, + E22023, + E22013, + E2201B, + E2201W, + E2201X, + E2202H, + E2202G, + E22009, + E2200C, + E2200G, + E22004, + E22002, + E22003, + E2200H, + E22026, + E22001, + E22011, + E22027, + E22024, + E2200F, + E22P01, + E22P02, + E22P03, + E22P04, + E22P05, + E2200L, + E2200M, + E2200N, + E2200S, + E2200T, + E22030, + E22031, + E22032, + E22033, + E22034, + E22035, + E22036, + E22037, + E22038, + E22039, + E2203A, + E2203B, + E2203C, + E2203D, + E2203E, + E2203F, + E2203G, + E23000, + E23001, + E23502, + E23503, + E23505, + E23514, + E23P01, + E24000, + E25000, + E25001, + E25002, + E25008, + E25003, + E25004, + E25005, + E25006, + E25007, + E25P01, + E25P02, + E25P03, + E26000, + E27000, + E28000, + E28P01, + E2B000, + E2BP01, + E2D000, + E2F000, + E2F005, + E2F002, + E2F003, + E2F004, + E34000, + E38000, + E38001, + E38002, + E38003, + E38004, + E39000, + E39001, + E39004, + E39P01, + E39P02, + E39P03, + E3B000, + E3B001, + E3D000, + E3F000, + E40000, + E40002, + E40001, + E40003, + E40P01, + E42000, + E42601, + E42501, + E42846, + E42803, + E42P20, + E42P19, + E42830, + E42602, + E42622, + E42939, + E42804, + E42P18, + E42P21, + E42P22, + E42809, + E428C9, + E42703, + E42883, + E42P01, + E42P02, + E42704, + E42701, + E42P03, + E42P04, + E42723, + E42P05, + E42P06, + E42P07, + E42712, + E42710, + E42702, + E42725, + E42P08, + E42P09, + E42P10, + E42611, + E42P11, + E42P12, + E42P13, + E42P14, + E42P15, + E42P16, + E42P17, + E44000, + E53000, + E53100, + E53200, + E53300, + E53400, + E54000, + E54001, + E54011, + E54023, + E55000, + E55006, + E55P02, + E55P03, + E55P04, + E57000, + E57014, + E57P01, + E57P02, + E57P03, + E57P04, + E57P05, + E58000, + E58030, + E58P01, + E58P02, + E72000, + EF0000, + EF0001, + EHV000, + EHV005, + EHV002, + EHV010, + EHV021, + EHV024, + EHV007, + EHV008, + EHV004, + EHV006, + EHV091, + EHV00B, + EHV00C, + EHV00D, + EHV090, + EHV00A, + EHV009, + EHV014, + EHV001, + EHV00P, + EHV00J, + EHV00K, + EHV00Q, + EHV00R, + EHV00L, + EHV00M, + EHV00N, + EP0000, + EP0001, + EP0002, + EP0003, + EP0004, + EXX000, + EXX001, + EXX002, + Other(Box), } -#[cfg_attr(rustfmt, rustfmt_skip)] -static SQLSTATE_MAP: phf::Map<&'static str, SqlState> = ::phf::Map { - key: 1897749892740154578, - disps: ::phf::Slice::Static(&[ - (1, 99), + +#[rustfmt::skip] +static SQLSTATE_MAP: phf::Map<&'static str, SqlState> = +::phf::Map { + key: 12913932095322966823, + disps: &[ + (0, 24), + (0, 12), + (0, 74), + (0, 109), + (0, 11), + (0, 9), (0, 0), - (1, 5), - (0, 3), - (0, 110), - (0, 54), - (0, 3), + (4, 38), + (3, 155), + (0, 6), + (1, 242), + (0, 66), + (0, 53), + (5, 180), + (3, 221), + (7, 230), + (0, 125), + (1, 46), + (0, 11), + (1, 2), + (0, 5), (0, 13), - (0, 0), - (0, 24), - (0, 214), - (0, 52), - (1, 34), - (0, 33), - (0, 44), - (0, 130), - (0, 16), - (0, 187), - (0, 3), - (13, 168), + (0, 171), + (0, 15), (0, 4), - (0, 19), - (0, 13), - (0, 87), + (0, 22), + (1, 85), + (0, 75), + (2, 0), + (1, 25), + (7, 47), + (0, 45), + (0, 35), + (0, 7), + (7, 124), (0, 0), - (0, 108), - (0, 123), - (7, 181), - (0, 109), - (0, 32), + (14, 104), + (1, 183), + (61, 50), + (3, 76), + (0, 12), + (0, 7), + (4, 189), + (0, 1), + (64, 102), (0, 0), - (1, 69), - (1, 81), - (1, 219), - (0, 157), - (2, 41), - (8, 141), + (16, 192), + (24, 19), (0, 5), - (0, 0), - (1, 6), - (0, 3), - (1, 146), - (1, 227), - (9, 94), - (10, 158), - (29, 65), - (3, 2), - (0, 33), - (1, 94), - ]), - entries: ::phf::Slice::Static(&[ - ("23001", SqlState::RESTRICT_VIOLATION), - ("42830", SqlState::INVALID_FOREIGN_KEY), - ("P0000", SqlState::PLPGSQL_ERROR), - ("58000", SqlState::SYSTEM_ERROR), - ("57P01", SqlState::ADMIN_SHUTDOWN), - ("22P04", SqlState::BAD_COPY_FILE_FORMAT), - ("42P05", SqlState::DUPLICATE_PSTATEMENT), - ("28000", SqlState::INVALID_AUTHORIZATION_SPECIFICATION), - ("2202E", SqlState::ARRAY_ELEMENT_ERROR), - ("2F005", SqlState::S_R_E_FUNCTION_EXECUTED_NO_RETURN_STATEMENT), - ("53400", SqlState::CONFIGURATION_LIMIT_EXCEEDED), - ("20000", SqlState::CASE_NOT_FOUND), - ("25004", SqlState::INAPPROPRIATE_ISOLATION_LEVEL_FOR_BRANCH_TRANSACTION), - ("09000", SqlState::TRIGGERED_ACTION_EXCEPTION), - ("42P10", SqlState::INVALID_COLUMN_REFERENCE), - ("39P03", SqlState::E_R_I_E_EVENT_TRIGGER_PROTOCOL_VIOLATED), - ("08000", SqlState::CONNECTION_EXCEPTION), - ("08006", SqlState::CONNECTION_FAILURE), - ("2201W", SqlState::INVALID_ROW_COUNT_IN_LIMIT_CLAUSE), - ("03000", SqlState::SQL_STATEMENT_NOT_YET_COMPLETE), - ("22014", SqlState::INVALID_ARGUMENT_FOR_NTILE), - ("42611", SqlState::INVALID_COLUMN_DEFINITION), - ("42P11", SqlState::INVALID_CURSOR_DEFINITION), + (0, 87), + (0, 89), + (0, 14), + ], + entries: &[ + ("2F000", SqlState::SQL_ROUTINE_EXCEPTION), + ("01008", SqlState::WARNING_IMPLICIT_ZERO_BIT_PADDING), + ("42501", SqlState::INSUFFICIENT_PRIVILEGE), + ("22000", SqlState::DATA_EXCEPTION), + ("0100C", SqlState::WARNING_DYNAMIC_RESULT_SETS_RETURNED), ("2200N", SqlState::INVALID_XML_CONTENT), - ("57014", SqlState::QUERY_CANCELED), - ("01003", SqlState::WARNING_NULL_VALUE_ELIMINATED_IN_SET_FUNCTION), - ("01000", SqlState::WARNING), - ("55P04", SqlState::UNSAFE_NEW_ENUM_VALUE_USAGE), - ("25003", SqlState::INAPPROPRIATE_ACCESS_MODE_FOR_BRANCH_TRANSACTION), - ("2200L", SqlState::NOT_AN_XML_DOCUMENT), - ("42846", SqlState::CANNOT_COERCE), - ("55P03", SqlState::LOCK_NOT_AVAILABLE), - ("08007", SqlState::TRANSACTION_RESOLUTION_UNKNOWN), - ("XX000", SqlState::INTERNAL_ERROR), - ("22005", SqlState::ERROR_IN_ASSIGNMENT), - ("22P03", SqlState::INVALID_BINARY_REPRESENTATION), - ("2201X", SqlState::INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE), - ("54011", SqlState::TOO_MANY_COLUMNS), - ("HV008", SqlState::FDW_INVALID_COLUMN_NUMBER), - ("HV009", SqlState::FDW_INVALID_USE_OF_NULL_POINTER), - ("0LP01", SqlState::INVALID_GRANT_OPERATION), - ("42704", SqlState::UNDEFINED_OBJECT), - ("25005", SqlState::NO_ACTIVE_SQL_TRANSACTION_FOR_BRANCH_TRANSACTION), - ("25P03", SqlState::IDLE_IN_TRANSACTION_SESSION_TIMEOUT), - ("44000", SqlState::WITH_CHECK_OPTION_VIOLATION), - ("22024", SqlState::UNTERMINATED_C_STRING), - ("0L000", SqlState::INVALID_GRANTOR), - ("40000", SqlState::TRANSACTION_ROLLBACK), - ("42P08", SqlState::AMBIGUOUS_PARAMETER), - ("38000", SqlState::EXTERNAL_ROUTINE_EXCEPTION), - ("42939", SqlState::RESERVED_NAME), ("40001", SqlState::T_R_SERIALIZATION_FAILURE), - ("HV00K", SqlState::FDW_REPLY_HANDLE), - ("2F002", SqlState::S_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED), - ("HV001", SqlState::FDW_OUT_OF_MEMORY), - ("42P19", SqlState::INVALID_RECURSION), - ("HV002", SqlState::FDW_DYNAMIC_PARAMETER_VALUE_NEEDED), - ("0A000", SqlState::FEATURE_NOT_SUPPORTED), - ("58P02", SqlState::DUPLICATE_FILE), + ("28P01", SqlState::INVALID_PASSWORD), + ("38000", SqlState::EXTERNAL_ROUTINE_EXCEPTION), ("25006", SqlState::READ_ONLY_SQL_TRANSACTION), + ("2203D", SqlState::TOO_MANY_JSON_ARRAY_ELEMENTS), + ("42P09", SqlState::AMBIGUOUS_ALIAS), + ("F0000", SqlState::CONFIG_FILE_ERROR), + ("42P18", SqlState::INDETERMINATE_DATATYPE), + ("40002", SqlState::T_R_INTEGRITY_CONSTRAINT_VIOLATION), ("22009", SqlState::INVALID_TIME_ZONE_DISPLACEMENT_VALUE), - ("0F001", SqlState::L_E_INVALID_SPECIFICATION), + ("42P08", SqlState::AMBIGUOUS_PARAMETER), + ("08000", SqlState::CONNECTION_EXCEPTION), + ("25P01", SqlState::NO_ACTIVE_SQL_TRANSACTION), + ("22024", SqlState::UNTERMINATED_C_STRING), + ("55000", SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE), + ("25001", SqlState::ACTIVE_SQL_TRANSACTION), + ("03000", SqlState::SQL_STATEMENT_NOT_YET_COMPLETE), + ("42710", SqlState::DUPLICATE_OBJECT), + ("2D000", SqlState::INVALID_TRANSACTION_TERMINATION), + ("2200G", SqlState::MOST_SPECIFIC_TYPE_MISMATCH), + ("22022", SqlState::INDICATOR_OVERFLOW), + ("55006", SqlState::OBJECT_IN_USE), + ("53200", SqlState::OUT_OF_MEMORY), + ("22012", SqlState::DIVISION_BY_ZERO), ("P0002", SqlState::NO_DATA_FOUND), - ("2F000", SqlState::SQL_ROUTINE_EXCEPTION), - ("01006", SqlState::WARNING_PRIVILEGE_NOT_REVOKED), - ("22025", SqlState::INVALID_ESCAPE_SEQUENCE), - ("22027", SqlState::TRIM_ERROR), - ("54001", SqlState::STATEMENT_TOO_COMPLEX), - ("42602", SqlState::INVALID_NAME), - ("54023", SqlState::TOO_MANY_ARGUMENTS), - ("2200T", SqlState::INVALID_XML_PROCESSING_INSTRUCTION), - ("01007", SqlState::WARNING_PRIVILEGE_NOT_GRANTED), - ("22000", SqlState::DATA_EXCEPTION), - ("28P01", SqlState::INVALID_PASSWORD), - ("23514", SqlState::CHECK_VIOLATION), - ("39P02", SqlState::E_R_I_E_SRF_PROTOCOL_VIOLATED), - ("57P02", SqlState::CRASH_SHUTDOWN), - ("42P03", SqlState::DUPLICATE_CURSOR), + ("XX001", SqlState::DATA_CORRUPTED), + ("22P05", SqlState::UNTRANSLATABLE_CHARACTER), + ("40003", SqlState::T_R_STATEMENT_COMPLETION_UNKNOWN), ("22021", SqlState::CHARACTER_NOT_IN_REPERTOIRE), - ("HV00P", SqlState::FDW_NO_SCHEMAS), - ("42701", SqlState::DUPLICATE_COLUMN), + ("25000", SqlState::INVALID_TRANSACTION_STATE), ("42P15", SqlState::INVALID_SCHEMA_DEFINITION), - ("HV00B", SqlState::FDW_INVALID_HANDLE), - ("34000", SqlState::INVALID_CURSOR_NAME), - ("22P06", SqlState::NONSTANDARD_USE_OF_ESCAPE_CHARACTER), - ("P0001", SqlState::RAISE_EXCEPTION), - ("08P01", SqlState::PROTOCOL_VIOLATION), + ("0B000", SqlState::INVALID_TRANSACTION_INITIATION), + ("22004", SqlState::NULL_VALUE_NOT_ALLOWED), + ("42804", SqlState::DATATYPE_MISMATCH), + ("42803", SqlState::GROUPING_ERROR), + ("02001", SqlState::NO_ADDITIONAL_DYNAMIC_RESULT_SETS_RETURNED), + ("25002", SqlState::BRANCH_TRANSACTION_ALREADY_ACTIVE), + ("28000", SqlState::INVALID_AUTHORIZATION_SPECIFICATION), + ("HV009", SqlState::FDW_INVALID_USE_OF_NULL_POINTER), + ("22P01", SqlState::FLOATING_POINT_EXCEPTION), + ("2B000", SqlState::DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST), ("42723", SqlState::DUPLICATE_FUNCTION), - ("08001", SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION), - ("HV006", SqlState::FDW_INVALID_DATA_TYPE_DESCRIPTORS), - ("23000", SqlState::INTEGRITY_CONSTRAINT_VIOLATION), + ("21000", SqlState::CARDINALITY_VIOLATION), + ("0Z002", SqlState::STACKED_DIAGNOSTICS_ACCESSED_WITHOUT_ACTIVE_HANDLER), + ("23505", SqlState::UNIQUE_VIOLATION), + ("HV00J", SqlState::FDW_OPTION_NAME_NOT_FOUND), + ("23P01", SqlState::EXCLUSION_VIOLATION), + ("39P03", SqlState::E_R_I_E_EVENT_TRIGGER_PROTOCOL_VIOLATED), + ("42P10", SqlState::INVALID_COLUMN_REFERENCE), + ("2202H", SqlState::INVALID_TABLESAMPLE_ARGUMENT), + ("55P04", SqlState::UNSAFE_NEW_ENUM_VALUE_USAGE), + ("P0000", SqlState::PLPGSQL_ERROR), + ("2F005", SqlState::S_R_E_FUNCTION_EXECUTED_NO_RETURN_STATEMENT), + ("HV00M", SqlState::FDW_UNABLE_TO_CREATE_REPLY), + ("0A000", SqlState::FEATURE_NOT_SUPPORTED), + ("24000", SqlState::INVALID_CURSOR_STATE), + ("25008", SqlState::HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL), + ("01003", SqlState::WARNING_NULL_VALUE_ELIMINATED_IN_SET_FUNCTION), ("42712", SqlState::DUPLICATE_ALIAS), - ("2201G", SqlState::INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION), - ("2200F", SqlState::ZERO_LENGTH_CHARACTER_STRING), - ("XX002", SqlState::INDEX_CORRUPTED), - ("53300", SqlState::TOO_MANY_CONNECTIONS), - ("38002", SqlState::E_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED), - ("22015", SqlState::INTERVAL_FIELD_OVERFLOW), - ("22P01", SqlState::FLOATING_POINT_EXCEPTION), - ("22012", SqlState::DIVISION_BY_ZERO), - ("XX001", SqlState::DATA_CORRUPTED), - ("0100C", SqlState::WARNING_DYNAMIC_RESULT_SETS_RETURNED), - ("42P01", SqlState::UNDEFINED_TABLE), - ("25002", SqlState::BRANCH_TRANSACTION_ALREADY_ACTIVE), - ("2D000", SqlState::INVALID_TRANSACTION_TERMINATION), + ("HV014", SqlState::FDW_TOO_MANY_HANDLES), + ("58030", SqlState::IO_ERROR), + ("2201W", SqlState::INVALID_ROW_COUNT_IN_LIMIT_CLAUSE), + ("22033", SqlState::INVALID_SQL_JSON_SUBSCRIPT), + ("2BP01", SqlState::DEPENDENT_OBJECTS_STILL_EXIST), + ("HV005", SqlState::FDW_COLUMN_NAME_NOT_FOUND), + ("25004", SqlState::INAPPROPRIATE_ISOLATION_LEVEL_FOR_BRANCH_TRANSACTION), + ("54000", SqlState::PROGRAM_LIMIT_EXCEEDED), + ("20000", SqlState::CASE_NOT_FOUND), + ("2203G", SqlState::SQL_JSON_ITEM_CANNOT_BE_CAST_TO_TARGET_TYPE), + ("22038", SqlState::SINGLETON_SQL_JSON_ITEM_REQUIRED), + ("22007", SqlState::INVALID_DATETIME_FORMAT), + ("08004", SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION), + ("2200H", SqlState::SEQUENCE_GENERATOR_LIMIT_EXCEEDED), + ("HV00D", SqlState::FDW_INVALID_OPTION_NAME), ("P0004", SqlState::ASSERT_FAILURE), - ("2200C", SqlState::INVALID_USE_OF_ESCAPE_CHARACTER), - ("HV00R", SqlState::FDW_TABLE_NOT_FOUND), - ("22016", SqlState::INVALID_ARGUMENT_FOR_NTH_VALUE), + ("22018", SqlState::INVALID_CHARACTER_VALUE_FOR_CAST), + ("0L000", SqlState::INVALID_GRANTOR), + ("22P04", SqlState::BAD_COPY_FILE_FORMAT), + ("22031", SqlState::INVALID_ARGUMENT_FOR_SQL_JSON_DATETIME_FUNCTION), ("01P01", SqlState::WARNING_DEPRECATED_FEATURE), - ("F0000", SqlState::CONFIG_FILE_ERROR), - ("0Z000", SqlState::DIAGNOSTICS_EXCEPTION), - ("42P02", SqlState::UNDEFINED_PARAMETER), - ("2200S", SqlState::INVALID_XML_COMMENT), - ("2200H", SqlState::SEQUENCE_GENERATOR_LIMIT_EXCEEDED), + ("0LP01", SqlState::INVALID_GRANT_OPERATION), + ("58P02", SqlState::DUPLICATE_FILE), + ("26000", SqlState::INVALID_SQL_STATEMENT_NAME), + ("54001", SqlState::STATEMENT_TOO_COMPLEX), + ("22010", SqlState::INVALID_INDICATOR_PARAMETER_VALUE), ("HV00C", SqlState::FDW_INVALID_OPTION_INDEX), - ("38004", SqlState::E_R_E_READING_SQL_DATA_NOT_PERMITTED), - ("42703", SqlState::UNDEFINED_COLUMN), - ("23503", SqlState::FOREIGN_KEY_VIOLATION), - ("42000", SqlState::SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION), - ("22004", SqlState::NULL_VALUE_NOT_ALLOWED), - ("25008", SqlState::HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL), - ("22018", SqlState::INVALID_CHARACTER_VALUE_FOR_CAST), - ("22023", SqlState::INVALID_PARAMETER_VALUE), - ("22011", SqlState::SUBSTRING_ERROR), - ("40002", SqlState::T_R_INTEGRITY_CONSTRAINT_VIOLATION), - ("42803", SqlState::GROUPING_ERROR), - ("72000", SqlState::SNAPSHOT_TOO_OLD), - ("HV010", SqlState::FDW_FUNCTION_SEQUENCE_ERROR), - ("42809", SqlState::WRONG_OBJECT_TYPE), - ("42P16", SqlState::INVALID_TABLE_DEFINITION), - ("HV00D", SqlState::FDW_INVALID_OPTION_NAME), - ("39000", SqlState::EXTERNAL_ROUTINE_INVOCATION_EXCEPTION), - ("2202G", SqlState::INVALID_TABLESAMPLE_REPEAT), - ("42601", SqlState::SYNTAX_ERROR), - ("42622", SqlState::NAME_TOO_LONG), - ("HV00L", SqlState::FDW_UNABLE_TO_CREATE_EXECUTION), - ("25000", SqlState::INVALID_TRANSACTION_STATE), - ("3B000", SqlState::SAVEPOINT_EXCEPTION), - ("42P21", SqlState::COLLATION_MISMATCH), - ("23505", SqlState::UNIQUE_VIOLATION), - ("22001", SqlState::STRING_DATA_RIGHT_TRUNCATION), - ("02001", SqlState::NO_ADDITIONAL_DYNAMIC_RESULT_SETS_RETURNED), - ("21000", SqlState::CARDINALITY_VIOLATION), - ("58P01", SqlState::UNDEFINED_FILE), + ("22008", SqlState::DATETIME_FIELD_OVERFLOW), + ("42P06", SqlState::DUPLICATE_SCHEMA), + ("25007", SqlState::SCHEMA_AND_DATA_STATEMENT_MIXING_NOT_SUPPORTED), + ("42P20", SqlState::WINDOWING_ERROR), ("HV091", SqlState::FDW_INVALID_DESCRIPTOR_FIELD_IDENTIFIER), - ("25P01", SqlState::NO_ACTIVE_SQL_TRANSACTION), - ("40P01", SqlState::T_R_DEADLOCK_DETECTED), ("HV021", SqlState::FDW_INCONSISTENT_DESCRIPTOR_INFORMATION), - ("42P09", SqlState::AMBIGUOUS_ALIAS), - ("25007", SqlState::SCHEMA_AND_DATA_STATEMENT_MIXING_NOT_SUPPORTED), - ("23P01", SqlState::EXCLUSION_VIOLATION), - ("HV00J", SqlState::FDW_OPTION_NAME_NOT_FOUND), - ("58030", SqlState::IO_ERROR), + ("42702", SqlState::AMBIGUOUS_COLUMN), + ("02000", SqlState::NO_DATA), + ("54011", SqlState::TOO_MANY_COLUMNS), ("HV004", SqlState::FDW_INVALID_DATA_TYPE), - ("42710", SqlState::DUPLICATE_OBJECT), - ("HV090", SqlState::FDW_INVALID_STRING_LENGTH_OR_BUFFER_LENGTH), - ("42P18", SqlState::INDETERMINATE_DATATYPE), - ("HV00M", SqlState::FDW_UNABLE_TO_CREATE_REPLY), - ("42804", SqlState::DATATYPE_MISMATCH), - ("24000", SqlState::INVALID_CURSOR_STATE), - ("HV007", SqlState::FDW_INVALID_COLUMN_NAME), - ("2201E", SqlState::INVALID_ARGUMENT_FOR_LOG), - ("42P22", SqlState::INDETERMINATE_COLLATION), - ("22P05", SqlState::UNTRANSLATABLE_CHARACTER), - ("42P07", SqlState::DUPLICATE_TABLE), - ("2F004", SqlState::S_R_E_READING_SQL_DATA_NOT_PERMITTED), - ("23502", SqlState::NOT_NULL_VIOLATION), + ("01006", SqlState::WARNING_PRIVILEGE_NOT_REVOKED), + ("42701", SqlState::DUPLICATE_COLUMN), + ("08P01", SqlState::PROTOCOL_VIOLATION), + ("42622", SqlState::NAME_TOO_LONG), + ("P0003", SqlState::TOO_MANY_ROWS), + ("22003", SqlState::NUMERIC_VALUE_OUT_OF_RANGE), + ("42P03", SqlState::DUPLICATE_CURSOR), + ("23001", SqlState::RESTRICT_VIOLATION), ("57000", SqlState::OPERATOR_INTERVENTION), - ("HV000", SqlState::FDW_ERROR), - ("42883", SqlState::UNDEFINED_FUNCTION), + ("22027", SqlState::TRIM_ERROR), + ("42P12", SqlState::INVALID_DATABASE_DEFINITION), + ("3B000", SqlState::SAVEPOINT_EXCEPTION), ("2201B", SqlState::INVALID_REGULAR_EXPRESSION), - ("2200D", SqlState::INVALID_ESCAPE_OCTET), - ("42P06", SqlState::DUPLICATE_SCHEMA), - ("38003", SqlState::E_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED), - ("22026", SqlState::STRING_DATA_LENGTH_MISMATCH), - ("P0003", SqlState::TOO_MANY_ROWS), - ("3D000", SqlState::INVALID_CATALOG_NAME), - ("0B000", SqlState::INVALID_TRANSACTION_INITIATION), - ("55006", SqlState::OBJECT_IN_USE), - ("53200", SqlState::OUT_OF_MEMORY), - ("3F000", SqlState::INVALID_SCHEMA_NAME), - ("53100", SqlState::DISK_FULL), - ("2F003", SqlState::S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED), - ("55P02", SqlState::CANT_CHANGE_RUNTIME_PARAM), - ("01004", SqlState::WARNING_STRING_DATA_RIGHT_TRUNCATION), - ("3B001", SqlState::S_E_INVALID_SPECIFICATION), - ("2200G", SqlState::MOST_SPECIFIC_TYPE_MISMATCH), + ("22030", SqlState::DUPLICATE_JSON_OBJECT_KEY_VALUE), + ("2F004", SqlState::S_R_E_READING_SQL_DATA_NOT_PERMITTED), ("428C9", SqlState::GENERATED_ALWAYS), - ("HV005", SqlState::FDW_COLUMN_NAME_NOT_FOUND), - ("2201F", SqlState::INVALID_ARGUMENT_FOR_POWER_FUNCTION), - ("22022", SqlState::INDICATOR_OVERFLOW), - ("HV00Q", SqlState::FDW_SCHEMA_NOT_FOUND), - ("0F000", SqlState::LOCATOR_EXCEPTION), - ("22002", SqlState::NULL_VALUE_NO_INDICATOR_PARAMETER), - ("02000", SqlState::NO_DATA), - ("2202H", SqlState::INVALID_TABLESAMPLE_ARGUMENT), - ("27000", SqlState::TRIGGERED_DATA_CHANGE_VIOLATION), - ("2BP01", SqlState::DEPENDENT_OBJECTS_STILL_EXIST), - ("55000", SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE), + ("2200S", SqlState::INVALID_XML_COMMENT), + ("22039", SqlState::SQL_JSON_ARRAY_NOT_FOUND), + ("42809", SqlState::WRONG_OBJECT_TYPE), + ("2201X", SqlState::INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE), ("39001", SqlState::E_R_I_E_INVALID_SQLSTATE_RETURNED), - ("08004", SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION), - ("42P13", SqlState::INVALID_FUNCTION_DEFINITION), - ("HV024", SqlState::FDW_INVALID_ATTRIBUTE_VALUE), - ("22019", SqlState::INVALID_ESCAPE_CHARACTER), - ("54000", SqlState::PROGRAM_LIMIT_EXCEEDED), - ("42501", SqlState::INSUFFICIENT_PRIVILEGE), - ("HV00A", SqlState::FDW_INVALID_STRING_FORMAT), - ("42702", SqlState::AMBIGUOUS_COLUMN), - ("53000", SqlState::INSUFFICIENT_RESOURCES), ("25P02", SqlState::IN_FAILED_SQL_TRANSACTION), - ("22010", SqlState::INVALID_INDICATOR_PARAMETER_VALUE), - ("01008", SqlState::WARNING_IMPLICIT_ZERO_BIT_PADDING), - ("HV014", SqlState::FDW_TOO_MANY_HANDLES), - ("42P20", SqlState::WINDOWING_ERROR), - ("42725", SqlState::AMBIGUOUS_FUNCTION), - ("F0001", SqlState::LOCK_FILE_EXISTS), - ("08003", SqlState::CONNECTION_DOES_NOT_EXIST), - ("2200M", SqlState::INVALID_XML_DOCUMENT), - ("22003", SqlState::NUMERIC_VALUE_OUT_OF_RANGE), - ("39004", SqlState::E_R_I_E_NULL_VALUE_NOT_ALLOWED), - ("2200B", SqlState::ESCAPE_CHARACTER_CONFLICT), ("0P000", SqlState::INVALID_ROLE_SPECIFICATION), + ("HV00N", SqlState::FDW_UNABLE_TO_ESTABLISH_CONNECTION), + ("53100", SqlState::DISK_FULL), + ("42601", SqlState::SYNTAX_ERROR), + ("23000", SqlState::INTEGRITY_CONSTRAINT_VIOLATION), + ("HV006", SqlState::FDW_INVALID_DATA_TYPE_DESCRIPTORS), + ("HV00B", SqlState::FDW_INVALID_HANDLE), + ("HV00Q", SqlState::FDW_SCHEMA_NOT_FOUND), + ("01000", SqlState::WARNING), + ("42883", SqlState::UNDEFINED_FUNCTION), + ("57P01", SqlState::ADMIN_SHUTDOWN), + ("22037", SqlState::NON_UNIQUE_KEYS_IN_A_JSON_OBJECT), ("00000", SqlState::SUCCESSFUL_COMPLETION), + ("55P03", SqlState::LOCK_NOT_AVAILABLE), + ("42P01", SqlState::UNDEFINED_TABLE), + ("42830", SqlState::INVALID_FOREIGN_KEY), + ("22005", SqlState::ERROR_IN_ASSIGNMENT), + ("22025", SqlState::INVALID_ESCAPE_SEQUENCE), + ("XX002", SqlState::INDEX_CORRUPTED), + ("42P16", SqlState::INVALID_TABLE_DEFINITION), + ("55P02", SqlState::CANT_CHANGE_RUNTIME_PARAM), + ("22019", SqlState::INVALID_ESCAPE_CHARACTER), + ("P0001", SqlState::RAISE_EXCEPTION), + ("72000", SqlState::SNAPSHOT_TOO_OLD), + ("42P11", SqlState::INVALID_CURSOR_DEFINITION), + ("40P01", SqlState::T_R_DEADLOCK_DETECTED), + ("57P02", SqlState::CRASH_SHUTDOWN), + ("HV00A", SqlState::FDW_INVALID_STRING_FORMAT), + ("2F002", SqlState::S_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED), + ("23503", SqlState::FOREIGN_KEY_VIOLATION), + ("40000", SqlState::TRANSACTION_ROLLBACK), + ("22032", SqlState::INVALID_JSON_TEXT), + ("2202E", SqlState::ARRAY_ELEMENT_ERROR), + ("42P19", SqlState::INVALID_RECURSION), + ("42611", SqlState::INVALID_COLUMN_DEFINITION), + ("42P13", SqlState::INVALID_FUNCTION_DEFINITION), + ("25003", SqlState::INAPPROPRIATE_ACCESS_MODE_FOR_BRANCH_TRANSACTION), + ("39P02", SqlState::E_R_I_E_SRF_PROTOCOL_VIOLATED), + ("XX000", SqlState::INTERNAL_ERROR), + ("08006", SqlState::CONNECTION_FAILURE), + ("57P04", SqlState::DATABASE_DROPPED), + ("42P07", SqlState::DUPLICATE_TABLE), + ("22P03", SqlState::INVALID_BINARY_REPRESENTATION), + ("22035", SqlState::NO_SQL_JSON_ITEM), + ("42P14", SqlState::INVALID_PSTATEMENT_DEFINITION), + ("01007", SqlState::WARNING_PRIVILEGE_NOT_GRANTED), + ("38004", SqlState::E_R_E_READING_SQL_DATA_NOT_PERMITTED), + ("42P21", SqlState::COLLATION_MISMATCH), + ("0Z000", SqlState::DIAGNOSTICS_EXCEPTION), + ("HV001", SqlState::FDW_OUT_OF_MEMORY), + ("0F000", SqlState::LOCATOR_EXCEPTION), + ("22013", SqlState::INVALID_PRECEDING_OR_FOLLOWING_SIZE), + ("2201E", SqlState::INVALID_ARGUMENT_FOR_LOG), + ("22011", SqlState::SUBSTRING_ERROR), + ("42602", SqlState::INVALID_NAME), + ("01004", SqlState::WARNING_STRING_DATA_RIGHT_TRUNCATION), + ("42P02", SqlState::UNDEFINED_PARAMETER), + ("2203C", SqlState::SQL_JSON_OBJECT_NOT_FOUND), + ("HV002", SqlState::FDW_DYNAMIC_PARAMETER_VALUE_NEEDED), + ("0F001", SqlState::L_E_INVALID_SPECIFICATION), + ("58P01", SqlState::UNDEFINED_FILE), + ("38001", SqlState::E_R_E_CONTAINING_SQL_NOT_PERMITTED), + ("42703", SqlState::UNDEFINED_COLUMN), + ("57P05", SqlState::IDLE_SESSION_TIMEOUT), + ("57P03", SqlState::CANNOT_CONNECT_NOW), + ("HV007", SqlState::FDW_INVALID_COLUMN_NAME), + ("22014", SqlState::INVALID_ARGUMENT_FOR_NTILE), + ("22P06", SqlState::NONSTANDARD_USE_OF_ESCAPE_CHARACTER), + ("2203F", SqlState::SQL_JSON_SCALAR_REQUIRED), + ("2200F", SqlState::ZERO_LENGTH_CHARACTER_STRING), + ("09000", SqlState::TRIGGERED_ACTION_EXCEPTION), + ("2201F", SqlState::INVALID_ARGUMENT_FOR_POWER_FUNCTION), + ("08003", SqlState::CONNECTION_DOES_NOT_EXIST), + ("38002", SqlState::E_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED), + ("F0001", SqlState::LOCK_FILE_EXISTS), + ("42P22", SqlState::INDETERMINATE_COLLATION), + ("2200C", SqlState::INVALID_USE_OF_ESCAPE_CHARACTER), + ("2203E", SqlState::TOO_MANY_JSON_OBJECT_MEMBERS), + ("23514", SqlState::CHECK_VIOLATION), ("22P02", SqlState::INVALID_TEXT_REPRESENTATION), - ("25001", SqlState::ACTIVE_SQL_TRANSACTION), - ("HV00N", SqlState::FDW_UNABLE_TO_ESTABLISH_CONNECTION), + ("54023", SqlState::TOO_MANY_ARGUMENTS), + ("2200T", SqlState::INVALID_XML_PROCESSING_INSTRUCTION), + ("22016", SqlState::INVALID_ARGUMENT_FOR_NTH_VALUE), + ("25P03", SqlState::IDLE_IN_TRANSACTION_SESSION_TIMEOUT), + ("3B001", SqlState::S_E_INVALID_SPECIFICATION), + ("08001", SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION), + ("22036", SqlState::NON_NUMERIC_SQL_JSON_ITEM), + ("3F000", SqlState::INVALID_SCHEMA_NAME), ("39P01", SqlState::E_R_I_E_TRIGGER_PROTOCOL_VIOLATED), - ("2B000", SqlState::DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST), - ("22008", SqlState::DATETIME_FIELD_OVERFLOW), - ("42P14", SqlState::INVALID_PSTATEMENT_DEFINITION), - ("57P04", SqlState::DATABASE_DROPPED), - ("26000", SqlState::INVALID_SQL_STATEMENT_NAME), + ("22026", SqlState::STRING_DATA_LENGTH_MISMATCH), ("42P17", SqlState::INVALID_OBJECT_DEFINITION), + ("22034", SqlState::MORE_THAN_ONE_SQL_JSON_ITEM), + ("HV000", SqlState::FDW_ERROR), + ("2200B", SqlState::ESCAPE_CHARACTER_CONFLICT), + ("HV008", SqlState::FDW_INVALID_COLUMN_NUMBER), + ("34000", SqlState::INVALID_CURSOR_NAME), + ("2201G", SqlState::INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION), + ("44000", SqlState::WITH_CHECK_OPTION_VIOLATION), + ("HV010", SqlState::FDW_FUNCTION_SEQUENCE_ERROR), + ("39004", SqlState::E_R_I_E_NULL_VALUE_NOT_ALLOWED), + ("22001", SqlState::STRING_DATA_RIGHT_TRUNCATION), + ("3D000", SqlState::INVALID_CATALOG_NAME), + ("25005", SqlState::NO_ACTIVE_SQL_TRANSACTION_FOR_BRANCH_TRANSACTION), + ("2200L", SqlState::NOT_AN_XML_DOCUMENT), + ("27000", SqlState::TRIGGERED_DATA_CHANGE_VIOLATION), + ("HV090", SqlState::FDW_INVALID_STRING_LENGTH_OR_BUFFER_LENGTH), + ("42939", SqlState::RESERVED_NAME), + ("58000", SqlState::SYSTEM_ERROR), + ("2200M", SqlState::INVALID_XML_DOCUMENT), + ("HV00L", SqlState::FDW_UNABLE_TO_CREATE_EXECUTION), + ("57014", SqlState::QUERY_CANCELED), + ("23502", SqlState::NOT_NULL_VIOLATION), + ("22002", SqlState::NULL_VALUE_NO_INDICATOR_PARAMETER), + ("HV00R", SqlState::FDW_TABLE_NOT_FOUND), + ("HV00P", SqlState::FDW_NO_SCHEMAS), + ("38003", SqlState::E_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED), + ("39000", SqlState::EXTERNAL_ROUTINE_INVOCATION_EXCEPTION), + ("22015", SqlState::INTERVAL_FIELD_OVERFLOW), + ("HV00K", SqlState::FDW_REPLY_HANDLE), + ("HV024", SqlState::FDW_INVALID_ATTRIBUTE_VALUE), + ("2200D", SqlState::INVALID_ESCAPE_OCTET), + ("08007", SqlState::TRANSACTION_RESOLUTION_UNKNOWN), + ("2F003", SqlState::S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED), + ("42725", SqlState::AMBIGUOUS_FUNCTION), + ("2203A", SqlState::SQL_JSON_MEMBER_NOT_FOUND), + ("42846", SqlState::CANNOT_COERCE), ("42P04", SqlState::DUPLICATE_DATABASE), - ("38001", SqlState::E_R_E_CONTAINING_SQL_NOT_PERMITTED), - ("0Z002", SqlState::STACKED_DIAGNOSTICS_ACCESSED_WITHOUT_ACTIVE_HANDLER), - ("22007", SqlState::INVALID_DATETIME_FORMAT), - ("40003", SqlState::T_R_STATEMENT_COMPLETION_UNKNOWN), - ("42P12", SqlState::INVALID_DATABASE_DEFINITION), - ("57P03", SqlState::CANNOT_CONNECT_NOW), - ]), + ("42000", SqlState::SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION), + ("2203B", SqlState::SQL_JSON_NUMBER_NOT_FOUND), + ("42P05", SqlState::DUPLICATE_PSTATEMENT), + ("53300", SqlState::TOO_MANY_CONNECTIONS), + ("53400", SqlState::CONFIGURATION_LIMIT_EXCEEDED), + ("42704", SqlState::UNDEFINED_OBJECT), + ("2202G", SqlState::INVALID_TABLESAMPLE_REPEAT), + ("22023", SqlState::INVALID_PARAMETER_VALUE), + ("53000", SqlState::INSUFFICIENT_RESOURCES), + ], }; diff --git a/tokio-postgres/src/generic_client.rs b/tokio-postgres/src/generic_client.rs new file mode 100644 index 000000000..dcda147b5 --- /dev/null +++ b/tokio-postgres/src/generic_client.rs @@ -0,0 +1,305 @@ +use crate::query::RowStream; +use crate::types::{BorrowToSql, ToSql, Type}; +use crate::{Client, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction}; +use async_trait::async_trait; + +mod private { + pub trait Sealed {} +} + +/// A trait allowing abstraction over connections and transactions. +/// +/// This trait is "sealed", and cannot be implemented outside of this crate. +#[async_trait] +pub trait GenericClient: private::Sealed { + /// Like [`Client::execute`]. + async fn execute(&self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result + where + T: ?Sized + ToStatement + Sync + Send; + + /// Like [`Client::execute_raw`]. + async fn execute_raw(&self, statement: &T, params: I) -> Result + where + T: ?Sized + ToStatement + Sync + Send, + P: BorrowToSql, + I: IntoIterator + Sync + Send, + I::IntoIter: ExactSizeIterator; + + /// Like [`Client::query`]. + async fn query(&self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result, Error> + where + T: ?Sized + ToStatement + Sync + Send; + + /// Like [`Client::query_one`]. + async fn query_one( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result + where + T: ?Sized + ToStatement + Sync + Send; + + /// Like [`Client::query_opt`]. + async fn query_opt( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result, Error> + where + T: ?Sized + ToStatement + Sync + Send; + + /// Like [`Client::query_raw`]. + async fn query_raw(&self, statement: &T, params: I) -> Result + where + T: ?Sized + ToStatement + Sync + Send, + P: BorrowToSql, + I: IntoIterator + Sync + Send, + I::IntoIter: ExactSizeIterator; + + /// Like [`Client::query_typed`] + async fn query_typed( + &self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result, Error>; + + /// Like [`Client::query_typed_raw`] + async fn query_typed_raw(&self, statement: &str, params: I) -> Result + where + P: BorrowToSql, + I: IntoIterator + Sync + Send; + + /// Like [`Client::prepare`]. + async fn prepare(&self, query: &str) -> Result; + + /// Like [`Client::prepare_typed`]. + async fn prepare_typed( + &self, + query: &str, + parameter_types: &[Type], + ) -> Result; + + /// Like [`Client::transaction`]. + async fn transaction<'a>(&'a mut self) -> Result, Error>; + + /// Like [`Client::batch_execute`]. + async fn batch_execute(&self, query: &str) -> Result<(), Error>; + + /// Like [`Client::simple_query`]. + async fn simple_query(&self, query: &str) -> Result, Error>; + + /// Returns a reference to the underlying [`Client`]. + fn client(&self) -> &Client; +} + +impl private::Sealed for Client {} + +#[async_trait] +impl GenericClient for Client { + async fn execute(&self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result + where + T: ?Sized + ToStatement + Sync + Send, + { + self.execute(query, params).await + } + + async fn execute_raw(&self, statement: &T, params: I) -> Result + where + T: ?Sized + ToStatement + Sync + Send, + P: BorrowToSql, + I: IntoIterator + Sync + Send, + I::IntoIter: ExactSizeIterator, + { + self.execute_raw(statement, params).await + } + + async fn query(&self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result, Error> + where + T: ?Sized + ToStatement + Sync + Send, + { + self.query(query, params).await + } + + async fn query_one( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result + where + T: ?Sized + ToStatement + Sync + Send, + { + self.query_one(statement, params).await + } + + async fn query_opt( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result, Error> + where + T: ?Sized + ToStatement + Sync + Send, + { + self.query_opt(statement, params).await + } + + async fn query_raw(&self, statement: &T, params: I) -> Result + where + T: ?Sized + ToStatement + Sync + Send, + P: BorrowToSql, + I: IntoIterator + Sync + Send, + I::IntoIter: ExactSizeIterator, + { + self.query_raw(statement, params).await + } + + async fn query_typed( + &self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result, Error> { + self.query_typed(statement, params).await + } + + async fn query_typed_raw(&self, statement: &str, params: I) -> Result + where + P: BorrowToSql, + I: IntoIterator + Sync + Send, + { + self.query_typed_raw(statement, params).await + } + + async fn prepare(&self, query: &str) -> Result { + self.prepare(query).await + } + + async fn prepare_typed( + &self, + query: &str, + parameter_types: &[Type], + ) -> Result { + self.prepare_typed(query, parameter_types).await + } + + async fn transaction<'a>(&'a mut self) -> Result, Error> { + self.transaction().await + } + + async fn batch_execute(&self, query: &str) -> Result<(), Error> { + self.batch_execute(query).await + } + + async fn simple_query(&self, query: &str) -> Result, Error> { + self.simple_query(query).await + } + + fn client(&self) -> &Client { + self + } +} + +impl private::Sealed for Transaction<'_> {} + +#[async_trait] +#[allow(clippy::needless_lifetimes)] +impl GenericClient for Transaction<'_> { + async fn execute(&self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result + where + T: ?Sized + ToStatement + Sync + Send, + { + self.execute(query, params).await + } + + async fn execute_raw(&self, statement: &T, params: I) -> Result + where + T: ?Sized + ToStatement + Sync + Send, + P: BorrowToSql, + I: IntoIterator + Sync + Send, + I::IntoIter: ExactSizeIterator, + { + self.execute_raw(statement, params).await + } + + async fn query(&self, query: &T, params: &[&(dyn ToSql + Sync)]) -> Result, Error> + where + T: ?Sized + ToStatement + Sync + Send, + { + self.query(query, params).await + } + + async fn query_one( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result + where + T: ?Sized + ToStatement + Sync + Send, + { + self.query_one(statement, params).await + } + + async fn query_opt( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result, Error> + where + T: ?Sized + ToStatement + Sync + Send, + { + self.query_opt(statement, params).await + } + + async fn query_raw(&self, statement: &T, params: I) -> Result + where + T: ?Sized + ToStatement + Sync + Send, + P: BorrowToSql, + I: IntoIterator + Sync + Send, + I::IntoIter: ExactSizeIterator, + { + self.query_raw(statement, params).await + } + + async fn query_typed( + &self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result, Error> { + self.query_typed(statement, params).await + } + + async fn query_typed_raw(&self, statement: &str, params: I) -> Result + where + P: BorrowToSql, + I: IntoIterator + Sync + Send, + { + self.query_typed_raw(statement, params).await + } + + async fn prepare(&self, query: &str) -> Result { + self.prepare(query).await + } + + async fn prepare_typed( + &self, + query: &str, + parameter_types: &[Type], + ) -> Result { + self.prepare_typed(query, parameter_types).await + } + + #[allow(clippy::needless_lifetimes)] + async fn transaction<'a>(&'a mut self) -> Result, Error> { + self.transaction().await + } + + async fn batch_execute(&self, query: &str) -> Result<(), Error> { + self.batch_execute(query).await + } + + async fn simple_query(&self, query: &str) -> Result, Error> { + self.simple_query(query).await + } + + fn client(&self) -> &Client { + self.client() + } +} diff --git a/tokio-postgres/src/keepalive.rs b/tokio-postgres/src/keepalive.rs new file mode 100644 index 000000000..7bdd76341 --- /dev/null +++ b/tokio-postgres/src/keepalive.rs @@ -0,0 +1,38 @@ +use socket2::TcpKeepalive; +use std::time::Duration; + +#[derive(Clone, PartialEq, Eq)] +pub(crate) struct KeepaliveConfig { + pub idle: Duration, + pub interval: Option, + pub retries: Option, +} + +impl From<&KeepaliveConfig> for TcpKeepalive { + fn from(keepalive_config: &KeepaliveConfig) -> Self { + let mut tcp_keepalive = Self::new().with_time(keepalive_config.idle); + + #[cfg(not(any( + target_os = "aix", + target_os = "redox", + target_os = "solaris", + target_os = "openbsd" + )))] + if let Some(interval) = keepalive_config.interval { + tcp_keepalive = tcp_keepalive.with_interval(interval); + } + + #[cfg(not(any( + target_os = "aix", + target_os = "redox", + target_os = "solaris", + target_os = "windows", + target_os = "openbsd" + )))] + if let Some(retries) = keepalive_config.retries { + tcp_keepalive = tcp_keepalive.with_retries(retries); + } + + tcp_keepalive + } +} diff --git a/tokio-postgres/src/lib.rs b/tokio-postgres/src/lib.rs index a16dd15e6..ec843d511 100644 --- a/tokio-postgres/src/lib.rs +++ b/tokio-postgres/src/lib.rs @@ -1,365 +1,261 @@ -extern crate antidote; -extern crate bytes; -extern crate fallible_iterator; -extern crate futures_cpupool; -extern crate phf; -extern crate postgres_protocol; -extern crate postgres_shared; -extern crate tokio_codec; -extern crate tokio_io; -extern crate tokio_tcp; -extern crate tokio_timer; - -#[macro_use] -extern crate futures; -#[macro_use] -extern crate lazy_static; -#[macro_use] -extern crate log; -#[macro_use] -extern crate state_machine_future; - -#[cfg(unix)] -extern crate tokio_uds; - -use bytes::Bytes; -use futures::{Async, Future, Poll, Stream}; -use postgres_shared::rows::RowIndex; -use std::error::Error as StdError; -use std::fmt; -use std::sync::atomic::{AtomicUsize, Ordering}; - -#[doc(inline)] -pub use postgres_shared::stmt::Column; -#[doc(inline)] -pub use postgres_shared::{params, types}; -#[doc(inline)] -pub use postgres_shared::{CancelData, Notification}; - -use error::{DbError, Error}; -use params::ConnectParams; -use tls::TlsConnect; -use types::{FromSql, ToSql, Type}; - +//! An asynchronous, pipelined, PostgreSQL client. +//! +//! # Example +//! +//! ```no_run +//! use tokio_postgres::{NoTls, Error}; +//! +//! # #[cfg(not(feature = "runtime"))] fn main() {} +//! # #[cfg(feature = "runtime")] +//! #[tokio::main] // By default, tokio_postgres uses the tokio crate as its runtime. +//! async fn main() -> Result<(), Error> { +//! // Connect to the database. +//! let (client, connection) = +//! tokio_postgres::connect("host=localhost user=postgres", NoTls).await?; +//! +//! // The connection object performs the actual communication with the database, +//! // so spawn it off to run on its own. +//! tokio::spawn(async move { +//! if let Err(e) = connection.await { +//! eprintln!("connection error: {}", e); +//! } +//! }); +//! +//! // Now we can execute a simple statement that just returns its parameter. +//! let rows = client +//! .query("SELECT $1::TEXT", &[&"hello world"]) +//! .await?; +//! +//! // And then check that we got back the same string we sent over. +//! let value: &str = rows[0].get(0); +//! assert_eq!(value, "hello world"); +//! +//! Ok(()) +//! } +//! ``` +//! +//! # Behavior +//! +//! Calling a method like `Client::query` on its own does nothing. The associated request is not sent to the database +//! until the future returned by the method is first polled. Requests are executed in the order that they are first +//! polled, not in the order that their futures are created. +//! +//! # Pipelining +//! +//! The client supports *pipelined* requests. Pipelining can improve performance in use cases in which multiple, +//! independent queries need to be executed. In a traditional workflow, each query is sent to the server after the +//! previous query completes. In contrast, pipelining allows the client to send all of the queries to the server up +//! front, minimizing time spent by one side waiting for the other to finish sending data: +//! +//! ```not_rust +//! Sequential Pipelined +//! | Client | Server | | Client | Server | +//! |----------------|-----------------| |----------------|-----------------| +//! | send query 1 | | | send query 1 | | +//! | | process query 1 | | send query 2 | process query 1 | +//! | receive rows 1 | | | send query 3 | process query 2 | +//! | send query 2 | | | receive rows 1 | process query 3 | +//! | | process query 2 | | receive rows 2 | | +//! | receive rows 2 | | | receive rows 3 | | +//! | send query 3 | | +//! | | process query 3 | +//! | receive rows 3 | | +//! ``` +//! +//! In both cases, the PostgreSQL server is executing the queries sequentially - pipelining just allows both sides of +//! the connection to work concurrently when possible. +//! +//! Pipelining happens automatically when futures are polled concurrently (for example, by using the futures `join` +//! combinator): +//! +//! ```rust +//! use futures_util::future; +//! use std::future::Future; +//! use tokio_postgres::{Client, Error, Statement}; +//! +//! async fn pipelined_prepare( +//! client: &Client, +//! ) -> Result<(Statement, Statement), Error> +//! { +//! future::try_join( +//! client.prepare("SELECT * FROM foo"), +//! client.prepare("INSERT INTO bar (id, name) VALUES ($1, $2)") +//! ).await +//! } +//! ``` +//! +//! # Runtime +//! +//! The client works with arbitrary `AsyncRead + AsyncWrite` streams. Convenience APIs are provided to handle the +//! connection process, but these are gated by the `runtime` Cargo feature, which is enabled by default. If disabled, +//! all dependence on the tokio runtime is removed. +//! +//! # SSL/TLS support +//! +//! TLS support is implemented via external libraries. `Client::connect` and `Config::connect` take a TLS implementation +//! as an argument. The `NoTls` type in this crate can be used when TLS is not required. Otherwise, the +//! `postgres-openssl` and `postgres-native-tls` crates provide implementations backed by the `openssl` and `native-tls` +//! crates, respectively. +//! +//! # Features +//! +//! The following features can be enabled from `Cargo.toml`: +//! +//! | Feature | Description | Extra dependencies | Default | +//! | ------- | ----------- | ------------------ | ------- | +//! | `runtime` | Enable convenience API for the connection process based on the `tokio` crate. | [tokio](https://crates.io/crates/tokio) 1.0 with the features `net` and `time` | yes | +//! | `array-impls` | Enables `ToSql` and `FromSql` trait impls for arrays | - | no | +//! | `with-bit-vec-0_6` | Enable support for the `bit-vec` crate. | [bit-vec](https://crates.io/crates/bit-vec) 0.6 | no | +//! | `with-chrono-0_4` | Enable support for the `chrono` crate. | [chrono](https://crates.io/crates/chrono) 0.4 | no | +//! | `with-eui48-0_4` | Enable support for the 0.4 version of the `eui48` crate. This is deprecated and will be removed. | [eui48](https://crates.io/crates/eui48) 0.4 | no | +//! | `with-eui48-1` | Enable support for the 1.0 version of the `eui48` crate. | [eui48](https://crates.io/crates/eui48) 1.0 | no | +//! | `with-geo-types-0_6` | Enable support for the 0.6 version of the `geo-types` crate. | [geo-types](https://crates.io/crates/geo-types/0.6.0) 0.6 | no | +//! | `with-geo-types-0_7` | Enable support for the 0.7 version of the `geo-types` crate. | [geo-types](https://crates.io/crates/geo-types/0.7.0) 0.7 | no | +//! | `with-jiff-0_1` | Enable support for the 0.1 version of the `jiff` crate. | [jiff](https://crates.io/crates/jiff/0.1.0) 0.1 | no | +//! | `with-serde_json-1` | Enable support for the `serde_json` crate. | [serde_json](https://crates.io/crates/serde_json) 1.0 | no | +//! | `with-uuid-0_8` | Enable support for the `uuid` crate. | [uuid](https://crates.io/crates/uuid) 0.8 | no | +//! | `with-uuid-1` | Enable support for the `uuid` crate. | [uuid](https://crates.io/crates/uuid) 1.0 | no | +//! | `with-time-0_2` | Enable support for the 0.2 version of the `time` crate. | [time](https://crates.io/crates/time/0.2.0) 0.2 | no | +//! | `with-time-0_3` | Enable support for the 0.3 version of the `time` crate. | [time](https://crates.io/crates/time/0.3.0) 0.3 | no | +#![warn(rust_2018_idioms, clippy::all, missing_docs)] + +pub use crate::cancel_token::CancelToken; +pub use crate::client::Client; +pub use crate::config::Config; +pub use crate::connection::Connection; +pub use crate::copy_in::CopyInSink; +pub use crate::copy_out::CopyOutStream; +use crate::error::DbError; +pub use crate::error::Error; +pub use crate::generic_client::GenericClient; +pub use crate::portal::Portal; +pub use crate::query::RowStream; +pub use crate::row::{Row, SimpleQueryRow}; +pub use crate::simple_query::{SimpleColumn, SimpleQueryStream}; +#[cfg(feature = "runtime")] +pub use crate::socket::Socket; +pub use crate::statement::{Column, Statement}; +#[cfg(feature = "runtime")] +use crate::tls::MakeTlsConnect; +pub use crate::tls::NoTls; +pub use crate::to_statement::ToStatement; +pub use crate::transaction::Transaction; +pub use crate::transaction_builder::{IsolationLevel, TransactionBuilder}; +use crate::types::ToSql; +use std::sync::Arc; + +pub mod binary_copy; +mod bind; +#[cfg(feature = "runtime")] +mod cancel_query; +mod cancel_query_raw; +mod cancel_token; +mod client; +mod codec; +pub mod config; +#[cfg(feature = "runtime")] +mod connect; +mod connect_raw; +#[cfg(feature = "runtime")] +mod connect_socket; +mod connect_tls; +mod connection; +mod copy_in; +mod copy_out; pub mod error; -mod proto; +mod generic_client; +#[cfg(not(target_arch = "wasm32"))] +mod keepalive; +mod maybe_tls_stream; +mod portal; +mod prepare; +mod query; +pub mod row; +mod simple_query; +#[cfg(feature = "runtime")] +mod socket; +mod statement; pub mod tls; - -fn next_statement() -> String { - static ID: AtomicUsize = AtomicUsize::new(0); - format!("s{}", ID.fetch_add(1, Ordering::SeqCst)) -} - -fn next_portal() -> String { - static ID: AtomicUsize = AtomicUsize::new(0); - format!("p{}", ID.fetch_add(1, Ordering::SeqCst)) -} - -pub enum TlsMode { - None, - Prefer(Box), - Require(Box), -} - -pub fn cancel_query(params: ConnectParams, tls: TlsMode, cancel_data: CancelData) -> CancelQuery { - CancelQuery(proto::CancelFuture::new(params, tls, cancel_data)) +mod to_statement; +mod transaction; +mod transaction_builder; +pub mod types; + +/// A convenience function which parses a connection string and connects to the database. +/// +/// See the documentation for [`Config`] for details on the connection string format. +/// +/// Requires the `runtime` Cargo feature (enabled by default). +/// +/// [`Config`]: config/struct.Config.html +#[cfg(feature = "runtime")] +pub async fn connect( + config: &str, + tls: T, +) -> Result<(Client, Connection), Error> +where + T: MakeTlsConnect, +{ + let config = config.parse::()?; + config.connect(tls).await } -pub fn connect(params: ConnectParams, tls: TlsMode) -> Handshake { - Handshake(proto::HandshakeFuture::new(params, tls)) +/// An asynchronous notification. +#[derive(Clone, Debug)] +pub struct Notification { + process_id: i32, + channel: String, + payload: String, } -pub struct Client(proto::Client); - -impl Client { - pub fn prepare(&mut self, query: &str) -> Prepare { - self.prepare_typed(query, &[]) - } - - pub fn prepare_typed(&mut self, query: &str, param_types: &[Type]) -> Prepare { - Prepare(self.0.prepare(next_statement(), query, param_types)) - } - - pub fn execute(&mut self, statement: &Statement, params: &[&ToSql]) -> Execute { - Execute(self.0.execute(&statement.0, params)) - } - - pub fn query(&mut self, statement: &Statement, params: &[&ToSql]) -> Query { - Query(self.0.query(&statement.0, params)) - } - - pub fn bind(&mut self, statement: &Statement, params: &[&ToSql]) -> Bind { - Bind(self.0.bind(&statement.0, next_portal(), params)) - } - - pub fn query_portal(&mut self, portal: &Portal, max_rows: i32) -> QueryPortal { - QueryPortal(self.0.query_portal(&portal.0, max_rows)) - } - - pub fn copy_in(&mut self, statement: &Statement, params: &[&ToSql], stream: S) -> CopyIn - where - S: Stream, - S::Item: AsRef<[u8]>, - // FIXME error type? - S::Error: Into>, - { - CopyIn(self.0.copy_in(&statement.0, params, stream)) - } - - pub fn copy_out(&mut self, statement: &Statement, params: &[&ToSql]) -> CopyOut { - CopyOut(self.0.copy_out(&statement.0, params)) +impl Notification { + /// The process ID of the notifying backend process. + pub fn process_id(&self) -> i32 { + self.process_id } - pub fn transaction(&mut self, future: T) -> Transaction - where - T: Future, - // FIXME error type? - T::Error: From, - { - Transaction(proto::TransactionFuture::new(self.0.clone(), future)) + /// The name of the channel that the notify has been raised on. + pub fn channel(&self) -> &str { + &self.channel } - pub fn batch_execute(&mut self, query: &str) -> BatchExecute { - BatchExecute(self.0.batch_execute(query)) - } -} - -#[must_use = "futures do nothing unless polled"] -pub struct Connection(proto::Connection); - -impl Connection { - pub fn cancel_data(&self) -> CancelData { - self.0.cancel_data() - } - - pub fn parameter(&self, name: &str) -> Option<&str> { - self.0.parameter(name) - } - - pub fn poll_message(&mut self) -> Poll, Error> { - self.0.poll_message() - } -} - -impl Future for Connection { - type Item = (); - type Error = Error; - - fn poll(&mut self) -> Poll<(), Error> { - self.0.poll() + /// The "payload" string passed from the notifying process. + pub fn payload(&self) -> &str { + &self.payload } } +/// An asynchronous message from the server. +#[allow(clippy::large_enum_variant)] +#[derive(Debug, Clone)] +#[non_exhaustive] pub enum AsyncMessage { + /// A notice. + /// + /// Notices use the same format as errors, but aren't "errors" per-se. Notice(DbError), + /// A notification. + /// + /// Connections can subscribe to notifications with the `LISTEN` command. Notification(Notification), - #[doc(hidden)] - __NonExhaustive, -} - -#[must_use = "futures do nothing unless polled"] -pub struct CancelQuery(proto::CancelFuture); - -impl Future for CancelQuery { - type Item = (); - type Error = Error; - - fn poll(&mut self) -> Poll<(), Error> { - self.0.poll() - } -} - -#[must_use = "futures do nothing unless polled"] -pub struct Handshake(proto::HandshakeFuture); - -impl Future for Handshake { - type Item = (Client, Connection); - type Error = Error; - - fn poll(&mut self) -> Poll<(Client, Connection), Error> { - let (client, connection) = try_ready!(self.0.poll()); - - Ok(Async::Ready((Client(client), Connection(connection)))) - } -} - -#[must_use = "futures do nothing unless polled"] -pub struct Prepare(proto::PrepareFuture); - -impl Future for Prepare { - type Item = Statement; - type Error = Error; - - fn poll(&mut self) -> Poll { - let statement = try_ready!(self.0.poll()); - - Ok(Async::Ready(Statement(statement))) - } } -pub struct Statement(proto::Statement); - -impl Statement { - pub fn params(&self) -> &[Type] { - self.0.params() - } - - pub fn columns(&self) -> &[Column] { - self.0.columns() - } -} - -#[must_use = "futures do nothing unless polled"] -pub struct Execute(proto::ExecuteFuture); - -impl Future for Execute { - type Item = u64; - type Error = Error; - - fn poll(&mut self) -> Poll { - self.0.poll() - } -} - -#[must_use = "streams do nothing unless polled"] -pub struct Query(proto::QueryStream); - -impl Stream for Query { - type Item = Row; - type Error = Error; - - fn poll(&mut self) -> Poll, Error> { - match self.0.poll() { - Ok(Async::Ready(Some(row))) => Ok(Async::Ready(Some(Row(row)))), - Ok(Async::Ready(None)) => Ok(Async::Ready(None)), - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(e) => Err(e), - } - } -} - -#[must_use = "futures do nothing unless polled"] -pub struct Bind(proto::BindFuture); - -impl Future for Bind { - type Item = Portal; - type Error = Error; - - fn poll(&mut self) -> Poll { - match self.0.poll() { - Ok(Async::Ready(portal)) => Ok(Async::Ready(Portal(portal))), - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(e) => Err(e), - } - } +/// Message returned by the `SimpleQuery` stream. +#[derive(Debug)] +#[non_exhaustive] +pub enum SimpleQueryMessage { + /// A row of data. + Row(SimpleQueryRow), + /// A statement in the query has completed. + /// + /// The number of rows modified or selected is returned. + CommandComplete(u64), + /// Column values of the proceeding row values + RowDescription(Arc<[SimpleColumn]>), } -#[must_use = "streams do nothing unless polled"] -pub struct QueryPortal(proto::QueryStream); - -impl Stream for QueryPortal { - type Item = Row; - type Error = Error; - - fn poll(&mut self) -> Poll, Error> { - match self.0.poll() { - Ok(Async::Ready(Some(row))) => Ok(Async::Ready(Some(Row(row)))), - Ok(Async::Ready(None)) => Ok(Async::Ready(None)), - Ok(Async::NotReady) => Ok(Async::NotReady), - Err(e) => Err(e), - } - } -} - -pub struct Portal(proto::Portal); - -#[must_use = "futures do nothing unless polled"] -pub struct CopyIn(proto::CopyInFuture) -where - S: Stream, - S::Item: AsRef<[u8]>, - S::Error: Into>; - -impl Future for CopyIn -where - S: Stream>, - S::Error: Into>, -{ - type Item = u64; - type Error = Error; - - fn poll(&mut self) -> Poll { - self.0.poll() - } -} - -#[must_use = "streams do nothing unless polled"] -pub struct CopyOut(proto::CopyOutStream); - -impl Stream for CopyOut { - type Item = Bytes; - type Error = Error; - - fn poll(&mut self) -> Poll, Error> { - self.0.poll() - } -} - -pub struct Row(proto::Row); - -impl Row { - pub fn columns(&self) -> &[Column] { - self.0.columns() - } - - pub fn len(&self) -> usize { - self.0.len() - } - - pub fn get<'a, I, T>(&'a self, idx: I) -> T - where - I: RowIndex + fmt::Debug, - T: FromSql<'a>, - { - self.0.get(idx) - } - - pub fn try_get<'a, I, T>(&'a self, idx: I) -> Result, Error> - where - I: RowIndex, - T: FromSql<'a>, - { - self.0.try_get(idx) - } -} - -#[must_use = "futures do nothing unless polled"] -pub struct Transaction(proto::TransactionFuture) -where - T: Future, - T::Error: From; - -impl Future for Transaction -where - T: Future, - T::Error: From, -{ - type Item = T::Item; - type Error = T::Error; - - fn poll(&mut self) -> Poll { - self.0.poll() - } -} - -#[must_use = "futures do nothing unless polled"] -pub struct BatchExecute(proto::SimpleQueryFuture); - -impl Future for BatchExecute { - type Item = (); - type Error = Error; - - fn poll(&mut self) -> Poll<(), Error> { - self.0.poll() - } +fn slice_iter<'a>( + s: &'a [&'a (dyn ToSql + Sync)], +) -> impl ExactSizeIterator + 'a { + s.iter().map(|s| *s as _) } diff --git a/tokio-postgres/src/maybe_tls_stream.rs b/tokio-postgres/src/maybe_tls_stream.rs new file mode 100644 index 000000000..73b0c4721 --- /dev/null +++ b/tokio-postgres/src/maybe_tls_stream.rs @@ -0,0 +1,71 @@ +use crate::tls::{ChannelBinding, TlsStream}; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +pub enum MaybeTlsStream { + Raw(S), + Tls(T), +} + +impl AsyncRead for MaybeTlsStream +where + S: AsyncRead + Unpin, + T: AsyncRead + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match &mut *self { + MaybeTlsStream::Raw(s) => Pin::new(s).poll_read(cx, buf), + MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for MaybeTlsStream +where + S: AsyncWrite + Unpin, + T: AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match &mut *self { + MaybeTlsStream::Raw(s) => Pin::new(s).poll_write(cx, buf), + MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + MaybeTlsStream::Raw(s) => Pin::new(s).poll_flush(cx), + MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx), + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + MaybeTlsStream::Raw(s) => Pin::new(s).poll_shutdown(cx), + MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx), + } + } +} + +impl TlsStream for MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsStream + Unpin, +{ + fn channel_binding(&self) -> ChannelBinding { + match self { + MaybeTlsStream::Raw(_) => ChannelBinding::none(), + MaybeTlsStream::Tls(s) => s.channel_binding(), + } + } +} diff --git a/tokio-postgres/src/portal.rs b/tokio-postgres/src/portal.rs new file mode 100644 index 000000000..464d175da --- /dev/null +++ b/tokio-postgres/src/portal.rs @@ -0,0 +1,50 @@ +use crate::client::InnerClient; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::Statement; +use postgres_protocol::message::frontend; +use std::sync::{Arc, Weak}; + +struct Inner { + client: Weak, + name: String, + statement: Statement, +} + +impl Drop for Inner { + fn drop(&mut self) { + if let Some(client) = self.client.upgrade() { + let buf = client.with_buf(|buf| { + frontend::close(b'P', &self.name, buf).unwrap(); + frontend::sync(buf); + buf.split().freeze() + }); + let _ = client.send(RequestMessages::Single(FrontendMessage::Raw(buf))); + } + } +} + +/// A portal. +/// +/// Portals can only be used with the connection that created them, and only exist for the duration of the transaction +/// in which they were created. +#[derive(Clone)] +pub struct Portal(Arc); + +impl Portal { + pub(crate) fn new(client: &Arc, name: String, statement: Statement) -> Portal { + Portal(Arc::new(Inner { + client: Arc::downgrade(client), + name, + statement, + })) + } + + pub(crate) fn name(&self) -> &str { + &self.0.name + } + + pub(crate) fn statement(&self) -> &Statement { + &self.0.statement + } +} diff --git a/tokio-postgres/src/prepare.rs b/tokio-postgres/src/prepare.rs new file mode 100644 index 000000000..1d9bacb16 --- /dev/null +++ b/tokio-postgres/src/prepare.rs @@ -0,0 +1,267 @@ +use crate::client::InnerClient; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::error::SqlState; +use crate::types::{Field, Kind, Oid, Type}; +use crate::{query, slice_iter}; +use crate::{Column, Error, Statement}; +use bytes::Bytes; +use fallible_iterator::FallibleIterator; +use futures_util::{pin_mut, TryStreamExt}; +use log::debug; +use postgres_protocol::message::backend::Message; +use postgres_protocol::message::frontend; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +const TYPEINFO_QUERY: &str = "\ +SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid +FROM pg_catalog.pg_type t +LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid +INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid +WHERE t.oid = $1 +"; + +// Range types weren't added until Postgres 9.2, so pg_range may not exist +const TYPEINFO_FALLBACK_QUERY: &str = "\ +SELECT t.typname, t.typtype, t.typelem, NULL::OID, t.typbasetype, n.nspname, t.typrelid +FROM pg_catalog.pg_type t +INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid +WHERE t.oid = $1 +"; + +const TYPEINFO_ENUM_QUERY: &str = "\ +SELECT enumlabel +FROM pg_catalog.pg_enum +WHERE enumtypid = $1 +ORDER BY enumsortorder +"; + +// Postgres 9.0 didn't have enumsortorder +const TYPEINFO_ENUM_FALLBACK_QUERY: &str = "\ +SELECT enumlabel +FROM pg_catalog.pg_enum +WHERE enumtypid = $1 +ORDER BY oid +"; + +const TYPEINFO_COMPOSITE_QUERY: &str = "\ +SELECT attname, atttypid +FROM pg_catalog.pg_attribute +WHERE attrelid = $1 +AND NOT attisdropped +AND attnum > 0 +ORDER BY attnum +"; + +static NEXT_ID: AtomicUsize = AtomicUsize::new(0); + +pub async fn prepare( + client: &Arc, + query: &str, + types: &[Type], +) -> Result { + let name = format!("s{}", NEXT_ID.fetch_add(1, Ordering::SeqCst)); + let buf = encode(client, &name, query, types)?; + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + match responses.next().await? { + Message::ParseComplete => {} + _ => return Err(Error::unexpected_message()), + } + + let parameter_description = match responses.next().await? { + Message::ParameterDescription(body) => body, + _ => return Err(Error::unexpected_message()), + }; + + let row_description = match responses.next().await? { + Message::RowDescription(body) => Some(body), + Message::NoData => None, + _ => return Err(Error::unexpected_message()), + }; + + let mut parameters = vec![]; + let mut it = parameter_description.parameters(); + while let Some(oid) = it.next().map_err(Error::parse)? { + let type_ = get_type(client, oid).await?; + parameters.push(type_); + } + + let mut columns = vec![]; + if let Some(row_description) = row_description { + let mut it = row_description.fields(); + while let Some(field) = it.next().map_err(Error::parse)? { + let type_ = get_type(client, field.type_oid()).await?; + let column = Column { + name: field.name().to_string(), + table_oid: Some(field.table_oid()).filter(|n| *n != 0), + column_id: Some(field.column_id()).filter(|n| *n != 0), + r#type: type_, + }; + columns.push(column); + } + } + + Ok(Statement::new(client, name, parameters, columns)) +} + +fn prepare_rec<'a>( + client: &'a Arc, + query: &'a str, + types: &'a [Type], +) -> Pin> + 'a + Send>> { + Box::pin(prepare(client, query, types)) +} + +fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result { + if types.is_empty() { + debug!("preparing query {}: {}", name, query); + } else { + debug!("preparing query {} with types {:?}: {}", name, types, query); + } + + client.with_buf(|buf| { + frontend::parse(name, query, types.iter().map(Type::oid), buf).map_err(Error::encode)?; + frontend::describe(b'S', name, buf).map_err(Error::encode)?; + frontend::sync(buf); + Ok(buf.split().freeze()) + }) +} + +pub(crate) async fn get_type(client: &Arc, oid: Oid) -> Result { + if let Some(type_) = Type::from_oid(oid) { + return Ok(type_); + } + + if let Some(type_) = client.type_(oid) { + return Ok(type_); + } + + let stmt = typeinfo_statement(client).await?; + + let rows = query::query(client, stmt, slice_iter(&[&oid])).await?; + pin_mut!(rows); + + let row = match rows.try_next().await? { + Some(row) => row, + None => return Err(Error::unexpected_message()), + }; + + let name: String = row.try_get(0)?; + let type_: i8 = row.try_get(1)?; + let elem_oid: Oid = row.try_get(2)?; + let rngsubtype: Option = row.try_get(3)?; + let basetype: Oid = row.try_get(4)?; + let schema: String = row.try_get(5)?; + let relid: Oid = row.try_get(6)?; + + let kind = if type_ == b'e' as i8 { + let variants = get_enum_variants(client, oid).await?; + Kind::Enum(variants) + } else if type_ == b'p' as i8 { + Kind::Pseudo + } else if basetype != 0 { + let type_ = get_type_rec(client, basetype).await?; + Kind::Domain(type_) + } else if elem_oid != 0 { + let type_ = get_type_rec(client, elem_oid).await?; + Kind::Array(type_) + } else if relid != 0 { + let fields = get_composite_fields(client, relid).await?; + Kind::Composite(fields) + } else if let Some(rngsubtype) = rngsubtype { + let type_ = get_type_rec(client, rngsubtype).await?; + Kind::Range(type_) + } else { + Kind::Simple + }; + + let type_ = Type::new(name, oid, kind, schema); + client.set_type(oid, &type_); + + Ok(type_) +} + +fn get_type_rec<'a>( + client: &'a Arc, + oid: Oid, +) -> Pin> + Send + 'a>> { + Box::pin(get_type(client, oid)) +} + +async fn typeinfo_statement(client: &Arc) -> Result { + if let Some(stmt) = client.typeinfo() { + return Ok(stmt); + } + + let stmt = match prepare_rec(client, TYPEINFO_QUERY, &[]).await { + Ok(stmt) => stmt, + Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_TABLE) => { + prepare_rec(client, TYPEINFO_FALLBACK_QUERY, &[]).await? + } + Err(e) => return Err(e), + }; + + client.set_typeinfo(&stmt); + Ok(stmt) +} + +async fn get_enum_variants(client: &Arc, oid: Oid) -> Result, Error> { + let stmt = typeinfo_enum_statement(client).await?; + + query::query(client, stmt, slice_iter(&[&oid])) + .await? + .and_then(|row| async move { row.try_get(0) }) + .try_collect() + .await +} + +async fn typeinfo_enum_statement(client: &Arc) -> Result { + if let Some(stmt) = client.typeinfo_enum() { + return Ok(stmt); + } + + let stmt = match prepare_rec(client, TYPEINFO_ENUM_QUERY, &[]).await { + Ok(stmt) => stmt, + Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_COLUMN) => { + prepare_rec(client, TYPEINFO_ENUM_FALLBACK_QUERY, &[]).await? + } + Err(e) => return Err(e), + }; + + client.set_typeinfo_enum(&stmt); + Ok(stmt) +} + +async fn get_composite_fields(client: &Arc, oid: Oid) -> Result, Error> { + let stmt = typeinfo_composite_statement(client).await?; + + let rows = query::query(client, stmt, slice_iter(&[&oid])) + .await? + .try_collect::>() + .await?; + + let mut fields = vec![]; + for row in rows { + let name = row.try_get(0)?; + let oid = row.try_get(1)?; + let type_ = get_type_rec(client, oid).await?; + fields.push(Field::new(name, type_)); + } + + Ok(fields) +} + +async fn typeinfo_composite_statement(client: &Arc) -> Result { + if let Some(stmt) = client.typeinfo_composite() { + return Ok(stmt); + } + + let stmt = prepare_rec(client, TYPEINFO_COMPOSITE_QUERY, &[]).await?; + + client.set_typeinfo_composite(&stmt); + Ok(stmt) +} diff --git a/tokio-postgres/src/proto/bind.rs b/tokio-postgres/src/proto/bind.rs deleted file mode 100644 index 00f78a0f4..000000000 --- a/tokio-postgres/src/proto/bind.rs +++ /dev/null @@ -1,72 +0,0 @@ -use futures::sync::mpsc; -use futures::{Poll, Stream}; -use postgres_protocol::message::backend::Message; -use proto::client::{Client, PendingRequest}; -use proto::portal::Portal; -use proto::statement::Statement; -use state_machine_future::RentToOwn; -use Error; - -#[derive(StateMachineFuture)] -pub enum Bind { - #[state_machine_future(start, transitions(ReadBindComplete))] - Start { - client: Client, - request: PendingRequest, - name: String, - statement: Statement, - }, - #[state_machine_future(transitions(Finished))] - ReadBindComplete { - receiver: mpsc::Receiver, - client: Client, - name: String, - statement: Statement, - }, - #[state_machine_future(ready)] - Finished(Portal), - #[state_machine_future(error)] - Failed(Error), -} - -impl PollBind for Bind { - fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll { - let state = state.take(); - let receiver = state.client.send(state.request)?; - - transition!(ReadBindComplete { - receiver, - client: state.client, - name: state.name, - statement: state.statement, - }) - } - - fn poll_read_bind_complete<'a>( - state: &'a mut RentToOwn<'a, ReadBindComplete>, - ) -> Poll { - let message = try_ready_receive!(state.receiver.poll()); - let state = state.take(); - - match message { - Some(Message::BindComplete) => transition!(Finished(Portal::new( - state.client.downgrade(), - state.name, - state.statement, - ))), - Some(_) => Err(Error::unexpected_message()), - None => Err(Error::closed()), - } - } -} - -impl BindFuture { - pub fn new( - client: Client, - request: PendingRequest, - name: String, - statement: Statement, - ) -> BindFuture { - Bind::start(client, request, name, statement) - } -} diff --git a/tokio-postgres/src/proto/cancel.rs b/tokio-postgres/src/proto/cancel.rs deleted file mode 100644 index 138fb9bb6..000000000 --- a/tokio-postgres/src/proto/cancel.rs +++ /dev/null @@ -1,69 +0,0 @@ -use futures::{Future, Poll}; -use postgres_protocol::message::frontend; -use state_machine_future::RentToOwn; -use tokio_io::io::{self, Flush, WriteAll}; - -use error::Error; -use params::ConnectParams; -use proto::connect::ConnectFuture; -use tls::TlsStream; -use {CancelData, TlsMode}; - -#[derive(StateMachineFuture)] -pub enum Cancel { - #[state_machine_future(start, transitions(SendingCancel))] - Start { - future: ConnectFuture, - cancel_data: CancelData, - }, - #[state_machine_future(transitions(FlushingCancel))] - SendingCancel { - future: WriteAll, Vec>, - }, - #[state_machine_future(transitions(Finished))] - FlushingCancel { future: Flush> }, - #[state_machine_future(ready)] - Finished(()), - #[state_machine_future(error)] - Failed(Error), -} - -impl PollCancel for Cancel { - fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll { - let stream = try_ready!(state.future.poll()); - - let mut buf = vec![]; - frontend::cancel_request( - state.cancel_data.process_id, - state.cancel_data.secret_key, - &mut buf, - ); - - transition!(SendingCancel { - future: io::write_all(stream, buf), - }) - } - - fn poll_sending_cancel<'a>( - state: &'a mut RentToOwn<'a, SendingCancel>, - ) -> Poll { - let (stream, _) = try_ready_closed!(state.future.poll()); - - transition!(FlushingCancel { - future: io::flush(stream), - }) - } - - fn poll_flushing_cancel<'a>( - state: &'a mut RentToOwn<'a, FlushingCancel>, - ) -> Poll { - try_ready_closed!(state.future.poll()); - transition!(Finished(())) - } -} - -impl CancelFuture { - pub fn new(params: ConnectParams, mode: TlsMode, cancel_data: CancelData) -> CancelFuture { - Cancel::start(ConnectFuture::new(params, mode), cancel_data) - } -} diff --git a/tokio-postgres/src/proto/client.rs b/tokio-postgres/src/proto/client.rs deleted file mode 100644 index b7060128f..000000000 --- a/tokio-postgres/src/proto/client.rs +++ /dev/null @@ -1,251 +0,0 @@ -use antidote::Mutex; -use futures::sync::mpsc; -use futures::{AsyncSink, Sink, Stream}; -use postgres_protocol; -use postgres_protocol::message::backend::Message; -use postgres_protocol::message::frontend; -use std::collections::HashMap; -use std::error::Error as StdError; -use std::sync::{Arc, Weak}; - -use proto::bind::BindFuture; -use proto::connection::{Request, RequestMessages}; -use proto::copy_in::{CopyInFuture, CopyInReceiver, CopyMessage}; -use proto::copy_out::CopyOutStream; -use proto::execute::ExecuteFuture; -use proto::portal::Portal; -use proto::prepare::PrepareFuture; -use proto::query::QueryStream; -use proto::simple_query::SimpleQueryFuture; -use proto::statement::Statement; -use types::{IsNull, Oid, ToSql, Type}; -use Error; - -pub struct PendingRequest(Result); - -pub struct WeakClient(Weak); - -impl WeakClient { - pub fn upgrade(&self) -> Option { - self.0.upgrade().map(Client) - } -} - -struct State { - types: HashMap, - typeinfo_query: Option, - typeinfo_enum_query: Option, - typeinfo_composite_query: Option, -} - -struct Inner { - state: Mutex, - sender: mpsc::UnboundedSender, -} - -#[derive(Clone)] -pub struct Client(Arc); - -impl Client { - pub fn new(sender: mpsc::UnboundedSender) -> Client { - Client(Arc::new(Inner { - state: Mutex::new(State { - types: HashMap::new(), - typeinfo_query: None, - typeinfo_enum_query: None, - typeinfo_composite_query: None, - }), - sender, - })) - } - - pub fn downgrade(&self) -> WeakClient { - WeakClient(Arc::downgrade(&self.0)) - } - - pub fn cached_type(&self, oid: Oid) -> Option { - self.0.state.lock().types.get(&oid).cloned() - } - - pub fn cache_type(&self, ty: &Type) { - self.0.state.lock().types.insert(ty.oid(), ty.clone()); - } - - pub fn typeinfo_query(&self) -> Option { - self.0.state.lock().typeinfo_query.clone() - } - - pub fn set_typeinfo_query(&self, statement: &Statement) { - self.0.state.lock().typeinfo_query = Some(statement.clone()); - } - - pub fn typeinfo_enum_query(&self) -> Option { - self.0.state.lock().typeinfo_enum_query.clone() - } - - pub fn set_typeinfo_enum_query(&self, statement: &Statement) { - self.0.state.lock().typeinfo_enum_query = Some(statement.clone()); - } - - pub fn typeinfo_composite_query(&self) -> Option { - self.0.state.lock().typeinfo_composite_query.clone() - } - - pub fn set_typeinfo_composite_query(&self, statement: &Statement) { - self.0.state.lock().typeinfo_composite_query = Some(statement.clone()); - } - - pub fn send(&self, request: PendingRequest) -> Result, Error> { - let messages = request.0?; - let (sender, receiver) = mpsc::channel(0); - self.0 - .sender - .unbounded_send(Request { messages, sender }) - .map(|_| receiver) - .map_err(|_| Error::closed()) - } - - pub fn batch_execute(&self, query: &str) -> SimpleQueryFuture { - let pending = self.pending(|buf| { - frontend::query(query, buf).map_err(Error::parse)?; - Ok(()) - }); - - SimpleQueryFuture::new(self.clone(), pending) - } - - pub fn prepare(&self, name: String, query: &str, param_types: &[Type]) -> PrepareFuture { - let pending = self.pending(|buf| { - frontend::parse(&name, query, param_types.iter().map(|t| t.oid()), buf) - .map_err(Error::parse)?; - frontend::describe(b'S', &name, buf).map_err(Error::parse)?; - frontend::sync(buf); - Ok(()) - }); - - PrepareFuture::new(self.clone(), pending, name) - } - - pub fn execute(&self, statement: &Statement, params: &[&ToSql]) -> ExecuteFuture { - let pending = PendingRequest( - self.excecute_message(statement, params) - .map(RequestMessages::Single), - ); - ExecuteFuture::new(self.clone(), pending, statement.clone()) - } - - pub fn query(&self, statement: &Statement, params: &[&ToSql]) -> QueryStream { - let pending = PendingRequest( - self.excecute_message(statement, params) - .map(RequestMessages::Single), - ); - QueryStream::new(self.clone(), pending, statement.clone()) - } - - pub fn bind(&self, statement: &Statement, name: String, params: &[&ToSql]) -> BindFuture { - let mut buf = self.bind_message(statement, &name, params); - if let Ok(ref mut buf) = buf { - frontend::sync(buf); - } - let pending = PendingRequest(buf.map(RequestMessages::Single)); - BindFuture::new(self.clone(), pending, name, statement.clone()) - } - - pub fn query_portal(&self, portal: &Portal, rows: i32) -> QueryStream { - let pending = self.pending(|buf| { - frontend::execute(portal.name(), rows, buf).map_err(Error::parse)?; - frontend::sync(buf); - Ok(()) - }); - QueryStream::new(self.clone(), pending, portal.clone()) - } - - pub fn copy_in(&self, statement: &Statement, params: &[&ToSql], stream: S) -> CopyInFuture - where - S: Stream, - S::Item: AsRef<[u8]>, - S::Error: Into>, - { - let (mut sender, receiver) = mpsc::channel(0); - let pending = PendingRequest(self.excecute_message(statement, params).map(|buf| { - match sender.start_send(CopyMessage::Data(buf)) { - Ok(AsyncSink::Ready) => {} - _ => unreachable!("channel should have capacity"), - } - RequestMessages::CopyIn { - receiver: CopyInReceiver::new(receiver), - pending_message: None, - } - })); - CopyInFuture::new(self.clone(), pending, statement.clone(), stream, sender) - } - - pub fn copy_out(&self, statement: &Statement, params: &[&ToSql]) -> CopyOutStream { - let pending = PendingRequest( - self.excecute_message(statement, params) - .map(RequestMessages::Single), - ); - CopyOutStream::new(self.clone(), pending, statement.clone()) - } - - pub fn close_statement(&self, name: &str) { - self.close(b'S', name) - } - - pub fn close_portal(&self, name: &str) { - self.close(b'P', name) - } - - fn close(&self, ty: u8, name: &str) { - let mut buf = vec![]; - frontend::close(ty, name, &mut buf).expect("statement name not valid"); - frontend::sync(&mut buf); - let (sender, _) = mpsc::channel(0); - let _ = self.0.sender.unbounded_send(Request { - messages: RequestMessages::Single(buf), - sender, - }); - } - - fn bind_message( - &self, - statement: &Statement, - name: &str, - params: &[&ToSql], - ) -> Result, Error> { - let mut buf = vec![]; - let r = frontend::bind( - name, - statement.name(), - Some(1), - params.iter().zip(statement.params()), - |(param, ty), buf| match param.to_sql_checked(ty, buf) { - Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No), - Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes), - Err(e) => Err(e), - }, - Some(1), - &mut buf, - ); - match r { - Ok(()) => Ok(buf), - Err(frontend::BindError::Conversion(e)) => return Err(Error::to_sql(e)), - Err(frontend::BindError::Serialization(e)) => return Err(Error::encode(e)), - } - } - - fn excecute_message(&self, statement: &Statement, params: &[&ToSql]) -> Result, Error> { - let mut buf = self.bind_message(statement, "", params)?; - frontend::execute("", 0, &mut buf).map_err(Error::parse)?; - frontend::sync(&mut buf); - Ok(buf) - } - - fn pending(&self, messages: F) -> PendingRequest - where - F: FnOnce(&mut Vec) -> Result<(), Error>, - { - let mut buf = vec![]; - PendingRequest(messages(&mut buf).map(|()| RequestMessages::Single(buf))) - } -} diff --git a/tokio-postgres/src/proto/codec.rs b/tokio-postgres/src/proto/codec.rs deleted file mode 100644 index 4e37ab603..000000000 --- a/tokio-postgres/src/proto/codec.rs +++ /dev/null @@ -1,25 +0,0 @@ -use bytes::BytesMut; -use postgres_protocol::message::backend; -use std::io; -use tokio_codec::{Decoder, Encoder}; - -pub struct PostgresCodec; - -impl Encoder for PostgresCodec { - type Item = Vec; - type Error = io::Error; - - fn encode(&mut self, item: Vec, dst: &mut BytesMut) -> Result<(), io::Error> { - dst.extend_from_slice(&item); - Ok(()) - } -} - -impl Decoder for PostgresCodec { - type Item = backend::Message; - type Error = io::Error; - - fn decode(&mut self, src: &mut BytesMut) -> Result, io::Error> { - backend::Message::parse(src) - } -} diff --git a/tokio-postgres/src/proto/connect.rs b/tokio-postgres/src/proto/connect.rs deleted file mode 100644 index 00da8117e..000000000 --- a/tokio-postgres/src/proto/connect.rs +++ /dev/null @@ -1,288 +0,0 @@ -use futures::{Async, Future, Poll}; -use futures_cpupool::{CpuFuture, CpuPool}; -use postgres_protocol::message::frontend; -use state_machine_future::RentToOwn; -use std::error::Error as StdError; -use std::io; -use std::net::{SocketAddr, ToSocketAddrs}; -use std::time::{Duration, Instant}; -use std::vec; -use tokio_io::io::{read_exact, write_all, ReadExact, WriteAll}; -use tokio_tcp::{self, TcpStream}; -use tokio_timer::Delay; - -#[cfg(unix)] -use tokio_uds::{self, UnixStream}; - -use params::{ConnectParams, Host}; -use proto::socket::Socket; -use tls::{self, TlsConnect, TlsStream}; -use {Error, TlsMode}; - -lazy_static! { - static ref DNS_POOL: CpuPool = CpuPool::new(2); -} - -#[derive(StateMachineFuture)] -pub enum Connect { - #[state_machine_future(start)] - #[cfg_attr( - unix, - state_machine_future(transitions(ResolvingDns, ConnectingUnix)) - )] - #[cfg_attr(not(unix), state_machine_future(transitions(ResolvingDns)))] - Start { params: ConnectParams, tls: TlsMode }, - #[state_machine_future(transitions(ConnectingTcp))] - ResolvingDns { - future: CpuFuture, io::Error>, - timeout: Option, - params: ConnectParams, - tls: TlsMode, - }, - #[state_machine_future(transitions(PreparingSsl))] - ConnectingTcp { - addrs: vec::IntoIter, - future: tokio_tcp::ConnectFuture, - timeout: Option<(Duration, Delay)>, - params: ConnectParams, - tls: TlsMode, - }, - #[cfg(unix)] - #[state_machine_future(transitions(PreparingSsl))] - ConnectingUnix { - future: tokio_uds::ConnectFuture, - timeout: Option, - params: ConnectParams, - tls: TlsMode, - }, - #[state_machine_future(transitions(Ready, SendingSsl))] - PreparingSsl { - socket: Socket, - params: ConnectParams, - tls: TlsMode, - }, - #[state_machine_future(transitions(ReadingSsl))] - SendingSsl { - future: WriteAll>, - params: ConnectParams, - connector: Box, - required: bool, - }, - #[state_machine_future(transitions(ConnectingTls, Ready))] - ReadingSsl { - future: ReadExact, - params: ConnectParams, - connector: Box, - required: bool, - }, - #[state_machine_future(transitions(Ready))] - ConnectingTls { - future: - Box, Error = Box> + Sync + Send>, - params: ConnectParams, - }, - #[state_machine_future(ready)] - Ready(Box), - #[state_machine_future(error)] - Failed(Error), -} - -impl PollConnect for Connect { - fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll { - let state = state.take(); - - let timeout = state.params.connect_timeout(); - let port = state.params.port(); - - match state.params.host().clone() { - Host::Tcp(host) => transition!(ResolvingDns { - future: DNS_POOL.spawn_fn(move || (&*host, port).to_socket_addrs()), - params: state.params, - tls: state.tls, - timeout, - }), - #[cfg(unix)] - Host::Unix(mut path) => { - path.push(format!(".s.PGSQL.{}", port)); - transition!(ConnectingUnix { - future: UnixStream::connect(path), - timeout: timeout.map(|t| Delay::new(Instant::now() + t)), - params: state.params, - tls: state.tls, - }) - }, - #[cfg(not(unix))] - Host::Unix(_) => { - Err(Error::connect(io::Error::new( - io::ErrorKind::Other, - "unix sockets are not supported on this platform", - ))) - }, - } - } - - fn poll_resolving_dns<'a>( - state: &'a mut RentToOwn<'a, ResolvingDns>, - ) -> Poll { - let mut addrs = try_ready!(state.future.poll().map_err(Error::connect)); - let state = state.take(); - - let addr = match addrs.next() { - Some(addr) => addr, - None => { - return Err(Error::connect(io::Error::new( - io::ErrorKind::Other, - "resolved to 0 addresses", - ))) - } - }; - - transition!(ConnectingTcp { - addrs, - future: TcpStream::connect(&addr), - timeout: state.timeout.map(|t| (t, Delay::new(Instant::now() + t))), - params: state.params, - tls: state.tls, - }) - } - - fn poll_connecting_tcp<'a>( - state: &'a mut RentToOwn<'a, ConnectingTcp>, - ) -> Poll { - let socket = loop { - let error = match state.future.poll() { - Ok(Async::Ready(socket)) => break socket, - Ok(Async::NotReady) => match state.timeout { - Some((_, ref mut delay)) => { - try_ready!(delay.poll().map_err(Error::timer)); - io::Error::new(io::ErrorKind::TimedOut, "connection timed out") - } - None => return Ok(Async::NotReady), - }, - Err(e) => e, - }; - - let addr = match state.addrs.next() { - Some(addr) => addr, - None => return Err(Error::connect(error)), - }; - - state.future = TcpStream::connect(&addr); - if let Some((timeout, ref mut delay)) = state.timeout { - delay.reset(Instant::now() + timeout); - } - }; - - // Our read/write patterns may trigger Nagle's algorithm since we're pipelining which - // we don't want. Each individual write should be a full command we want the backend to - // see immediately. - socket.set_nodelay(true).map_err(Error::connect)?; - - let state = state.take(); - transition!(PreparingSsl { - socket: Socket::Tcp(socket), - params: state.params, - tls: state.tls, - }) - } - - #[cfg(unix)] - fn poll_connecting_unix<'a>( - state: &'a mut RentToOwn<'a, ConnectingUnix>, - ) -> Poll { - match state.future.poll().map_err(Error::connect)? { - Async::Ready(socket) => { - let state = state.take(); - transition!(PreparingSsl { - socket: Socket::Unix(socket), - params: state.params, - tls: state.tls, - }) - } - Async::NotReady => match state.timeout { - Some(ref mut delay) => { - try_ready!(delay.poll().map_err(Error::timer)); - Err(Error::connect(io::Error::new( - io::ErrorKind::TimedOut, - "connection timed out", - ))) - } - None => Ok(Async::NotReady), - }, - } - } - - fn poll_preparing_ssl<'a>( - state: &'a mut RentToOwn<'a, PreparingSsl>, - ) -> Poll { - let state = state.take(); - - let (connector, required) = match state.tls { - TlsMode::None => { - transition!(Ready(Box::new(state.socket))); - } - TlsMode::Prefer(connector) => (connector, false), - TlsMode::Require(connector) => (connector, true), - }; - - let mut buf = vec![]; - frontend::ssl_request(&mut buf); - transition!(SendingSsl { - future: write_all(state.socket, buf), - params: state.params, - connector, - required, - }) - } - - fn poll_sending_ssl<'a>( - state: &'a mut RentToOwn<'a, SendingSsl>, - ) -> Poll { - let (stream, _) = try_ready_closed!(state.future.poll()); - let state = state.take(); - transition!(ReadingSsl { - future: read_exact(stream, [0]), - params: state.params, - connector: state.connector, - required: state.required, - }) - } - - fn poll_reading_ssl<'a>( - state: &'a mut RentToOwn<'a, ReadingSsl>, - ) -> Poll { - let (stream, buf) = try_ready_closed!(state.future.poll()); - let state = state.take(); - - match buf[0] { - b'S' => { - let future = match state.params.host() { - Host::Tcp(domain) => state.connector.connect(domain, tls::Socket(stream)), - Host::Unix(_) => { - return Err(Error::tls("TLS over unix sockets not supported".into())) - } - }; - transition!(ConnectingTls { - future, - params: state.params, - }) - } - b'N' if !state.required => transition!(Ready(Box::new(stream))), - b'N' => Err(Error::tls("TLS was required but not supported".into())), - _ => Err(Error::unexpected_message()), - } - } - - fn poll_connecting_tls<'a>( - state: &'a mut RentToOwn<'a, ConnectingTls>, - ) -> Poll { - let stream = try_ready!(state.future.poll().map_err(Error::tls)); - transition!(Ready(stream)) - } -} - -impl ConnectFuture { - pub fn new(params: ConnectParams, tls: TlsMode) -> ConnectFuture { - Connect::start(params, tls) - } -} diff --git a/tokio-postgres/src/proto/connection.rs b/tokio-postgres/src/proto/connection.rs deleted file mode 100644 index 562ec6ee3..000000000 --- a/tokio-postgres/src/proto/connection.rs +++ /dev/null @@ -1,306 +0,0 @@ -use futures::sync::mpsc; -use futures::{Async, AsyncSink, Future, Poll, Sink, Stream}; -use postgres_protocol::message::backend::Message; -use postgres_protocol::message::frontend; -use std::collections::{HashMap, VecDeque}; -use std::io; -use tokio_codec::Framed; - -use proto::codec::PostgresCodec; -use proto::copy_in::CopyInReceiver; -use tls::TlsStream; -use {AsyncMessage, CancelData, Notification}; -use {DbError, Error}; - -pub enum RequestMessages { - Single(Vec), - CopyIn { - receiver: CopyInReceiver, - pending_message: Option>, - }, -} - -pub struct Request { - pub messages: RequestMessages, - pub sender: mpsc::Sender, -} - -#[derive(PartialEq, Debug)] -enum State { - Active, - Terminating, - Closing, -} - -pub struct Connection { - stream: Framed, PostgresCodec>, - cancel_data: CancelData, - parameters: HashMap, - receiver: mpsc::UnboundedReceiver, - pending_request: Option, - pending_response: Option, - responses: VecDeque>, - state: State, -} - -impl Connection { - pub fn new( - stream: Framed, PostgresCodec>, - cancel_data: CancelData, - parameters: HashMap, - receiver: mpsc::UnboundedReceiver, - ) -> Connection { - Connection { - stream, - cancel_data, - parameters, - receiver, - pending_request: None, - pending_response: None, - responses: VecDeque::new(), - state: State::Active, - } - } - - pub fn cancel_data(&self) -> CancelData { - self.cancel_data - } - - pub fn parameter(&self, name: &str) -> Option<&str> { - self.parameters.get(name).map(|s| &**s) - } - - fn poll_response(&mut self) -> Poll, io::Error> { - if let Some(message) = self.pending_response.take() { - trace!("retrying pending response"); - return Ok(Async::Ready(Some(message))); - } - - self.stream.poll() - } - - fn poll_read(&mut self) -> Result, Error> { - if self.state != State::Active { - trace!("poll_read: done"); - return Ok(None); - } - - loop { - let message = match self.poll_response().map_err(Error::io)? { - Async::Ready(Some(message)) => message, - Async::Ready(None) => { - return Err(Error::closed()); - } - Async::NotReady => { - trace!("poll_read: waiting on response"); - return Ok(None); - } - }; - - let message = match message { - Message::NoticeResponse(body) => { - let error = DbError::new(&mut body.fields()).map_err(Error::parse)?; - return Ok(Some(AsyncMessage::Notice(error))); - } - Message::NotificationResponse(body) => { - let notification = Notification { - process_id: body.process_id(), - channel: body.channel().map_err(Error::parse)?.to_string(), - payload: body.message().map_err(Error::parse)?.to_string(), - }; - return Ok(Some(AsyncMessage::Notification(notification))); - } - Message::ParameterStatus(body) => { - self.parameters.insert( - body.name().map_err(Error::parse)?.to_string(), - body.value().map_err(Error::parse)?.to_string(), - ); - continue; - } - m => m, - }; - - let mut sender = match self.responses.pop_front() { - Some(sender) => sender, - None => match message { - Message::ErrorResponse(error) => return Err(Error::db(error)), - _ => return Err(Error::unexpected_message()), - }, - }; - - let request_complete = match message { - Message::ReadyForQuery(_) => true, - _ => false, - }; - - match sender.start_send(message) { - // if the receiver's hung up we still need to page through the rest of the messages - // designated to it - Ok(AsyncSink::Ready) | Err(_) => { - if !request_complete { - self.responses.push_front(sender); - } - } - Ok(AsyncSink::NotReady(message)) => { - self.responses.push_front(sender); - self.pending_response = Some(message); - trace!("poll_read: waiting on sender"); - return Ok(None); - } - } - } - } - - fn poll_request(&mut self) -> Poll, Error> { - if let Some(message) = self.pending_request.take() { - trace!("retrying pending request"); - return Ok(Async::Ready(Some(message))); - } - - match try_ready_receive!(self.receiver.poll()) { - Some(request) => { - trace!("polled new request"); - self.responses.push_back(request.sender); - Ok(Async::Ready(Some(request.messages))) - } - None => Ok(Async::Ready(None)), - } - } - - fn poll_write(&mut self) -> Result { - loop { - if self.state == State::Closing { - trace!("poll_write: done"); - return Ok(false); - } - - let request = match self.poll_request()? { - Async::Ready(Some(request)) => request, - Async::Ready(None) if self.responses.is_empty() && self.state == State::Active => { - trace!("poll_write: at eof, terminating"); - self.state = State::Terminating; - let mut request = vec![]; - frontend::terminate(&mut request); - RequestMessages::Single(request) - } - Async::Ready(None) => { - trace!( - "poll_write: at eof, pending responses {}", - self.responses.len(), - ); - return Ok(true); - } - Async::NotReady => { - trace!("poll_write: waiting on request"); - return Ok(true); - } - }; - - match request { - RequestMessages::Single(request) => { - match self.stream.start_send(request).map_err(Error::io)? { - AsyncSink::Ready => { - if self.state == State::Terminating { - trace!("poll_write: sent eof, closing"); - self.state = State::Closing; - } - } - AsyncSink::NotReady(request) => { - trace!("poll_write: waiting on socket"); - self.pending_request = Some(RequestMessages::Single(request)); - return Ok(false); - } - } - } - RequestMessages::CopyIn { - mut receiver, - mut pending_message, - } => { - let message = match pending_message.take() { - Some(message) => message, - None => match receiver.poll() { - Ok(Async::Ready(Some(message))) => message, - Ok(Async::Ready(None)) => { - trace!("poll_write: finished copy_in request"); - continue; - } - Ok(Async::NotReady) => { - trace!("poll_write: waiting on copy_in stream"); - self.pending_request = Some(RequestMessages::CopyIn { - receiver, - pending_message, - }); - return Ok(true); - } - Err(()) => unreachable!("mpsc::Receiver doesn't return errors"), - }, - }; - - match self.stream.start_send(message).map_err(Error::io)? { - AsyncSink::Ready => { - self.pending_request = Some(RequestMessages::CopyIn { - receiver, - pending_message: None, - }); - } - AsyncSink::NotReady(message) => { - trace!("poll_write: waiting on socket"); - self.pending_request = Some(RequestMessages::CopyIn { - receiver, - pending_message: Some(message), - }); - return Ok(false); - } - }; - } - } - } - } - - fn poll_flush(&mut self) -> Result<(), Error> { - match self.stream.poll_complete().map_err(Error::io)? { - Async::Ready(()) => trace!("poll_flush: flushed"), - Async::NotReady => trace!("poll_flush: waiting on socket"), - } - Ok(()) - } - - fn poll_shutdown(&mut self) -> Poll<(), Error> { - if self.state != State::Closing { - return Ok(Async::NotReady); - } - - match self.stream.close().map_err(Error::io)? { - Async::Ready(()) => { - trace!("poll_shutdown: complete"); - Ok(Async::Ready(())) - } - Async::NotReady => { - trace!("poll_shutdown: waiting on socket"); - Ok(Async::NotReady) - } - } - } - - pub fn poll_message(&mut self) -> Poll, Error> { - let message = self.poll_read()?; - let want_flush = self.poll_write()?; - if want_flush { - self.poll_flush()?; - } - match message { - Some(message) => Ok(Async::Ready(Some(message))), - None => self.poll_shutdown().map(|r| r.map(|()| None)), - } - } -} - -impl Future for Connection { - type Item = (); - type Error = Error; - - fn poll(&mut self) -> Poll<(), Error> { - while let Some(_) = try_ready!(self.poll_message()) {} - Ok(Async::Ready(())) - } -} diff --git a/tokio-postgres/src/proto/copy_in.rs b/tokio-postgres/src/proto/copy_in.rs deleted file mode 100644 index 8f22dbca0..000000000 --- a/tokio-postgres/src/proto/copy_in.rs +++ /dev/null @@ -1,228 +0,0 @@ -use futures::sink; -use futures::sync::mpsc; -use futures::{Async, AsyncSink, Future, Poll, Sink, Stream}; -use postgres_protocol::message::backend::Message; -use postgres_protocol::message::frontend; -use state_machine_future::RentToOwn; -use std::error::Error as StdError; - -use proto::client::{Client, PendingRequest}; -use proto::statement::Statement; -use Error; - -pub enum CopyMessage { - Data(Vec), - Done, -} - -pub struct CopyInReceiver { - receiver: mpsc::Receiver, - done: bool, -} - -impl CopyInReceiver { - pub fn new(receiver: mpsc::Receiver) -> CopyInReceiver { - CopyInReceiver { - receiver, - done: false, - } - } -} - -impl Stream for CopyInReceiver { - type Item = Vec; - type Error = (); - - fn poll(&mut self) -> Poll>, ()> { - if self.done { - return Ok(Async::Ready(None)); - } - - match self.receiver.poll()? { - Async::Ready(Some(CopyMessage::Data(buf))) => Ok(Async::Ready(Some(buf))), - Async::Ready(Some(CopyMessage::Done)) => { - self.done = true; - let mut buf = vec![]; - frontend::copy_done(&mut buf); - frontend::sync(&mut buf); - Ok(Async::Ready(Some(buf))) - } - Async::Ready(None) => { - self.done = true; - let mut buf = vec![]; - frontend::copy_fail("", &mut buf).unwrap(); - frontend::sync(&mut buf); - Ok(Async::Ready(Some(buf))) - } - Async::NotReady => Ok(Async::NotReady), - } - } -} - -#[derive(StateMachineFuture)] -pub enum CopyIn -where - S: Stream, - S::Item: AsRef<[u8]>, - S::Error: Into>, -{ - #[state_machine_future(start, transitions(ReadCopyInResponse))] - Start { - client: Client, - request: PendingRequest, - statement: Statement, - stream: S, - sender: mpsc::Sender, - }, - #[state_machine_future(transitions(WriteCopyData))] - ReadCopyInResponse { - stream: S, - sender: mpsc::Sender, - receiver: mpsc::Receiver, - }, - #[state_machine_future(transitions(WriteCopyDone))] - WriteCopyData { - stream: S, - pending_message: Option, - sender: mpsc::Sender, - receiver: mpsc::Receiver, - }, - #[state_machine_future(transitions(ReadCommandComplete))] - WriteCopyDone { - future: sink::Send>, - receiver: mpsc::Receiver, - }, - #[state_machine_future(transitions(Finished))] - ReadCommandComplete { receiver: mpsc::Receiver }, - #[state_machine_future(ready)] - Finished(u64), - #[state_machine_future(error)] - Failed(Error), -} - -impl PollCopyIn for CopyIn -where - S: Stream, - S::Item: AsRef<[u8]>, - S::Error: Into>, -{ - fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll, Error> { - let state = state.take(); - let receiver = state.client.send(state.request)?; - - // the statement can drop after this point, since its close will queue up after the copy - transition!(ReadCopyInResponse { - stream: state.stream, - sender: state.sender, - receiver - }) - } - - fn poll_read_copy_in_response<'a>( - state: &'a mut RentToOwn<'a, ReadCopyInResponse>, - ) -> Poll, Error> { - loop { - let message = try_ready_receive!(state.receiver.poll()); - - match message { - Some(Message::BindComplete) => {} - Some(Message::CopyInResponse(_)) => { - let state = state.take(); - transition!(WriteCopyData { - stream: state.stream, - pending_message: None, - sender: state.sender, - receiver: state.receiver - }) - } - Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), - Some(_) => return Err(Error::unexpected_message()), - None => return Err(Error::closed()), - } - } - } - - fn poll_write_copy_data<'a>( - state: &'a mut RentToOwn<'a, WriteCopyData>, - ) -> Poll { - loop { - let message = match state.pending_message.take() { - Some(message) => message, - None => match try_ready!(state.stream.poll().map_err(Error::copy_in_stream)) { - Some(data) => { - let mut buf = vec![]; - frontend::copy_data(data.as_ref(), &mut buf).map_err(Error::encode)?; - CopyMessage::Data(buf) - } - None => { - let state = state.take(); - transition!(WriteCopyDone { - future: state.sender.send(CopyMessage::Done), - receiver: state.receiver - }) - } - }, - }; - - match state.sender.start_send(message) { - Ok(AsyncSink::Ready) => {} - Ok(AsyncSink::NotReady(message)) => { - state.pending_message = Some(message); - return Ok(Async::NotReady); - } - Err(_) => return Err(Error::closed()), - } - } - } - - fn poll_write_copy_done<'a>( - state: &'a mut RentToOwn<'a, WriteCopyDone>, - ) -> Poll { - try_ready!(state.future.poll().map_err(|_| Error::closed())); - let state = state.take(); - - transition!(ReadCommandComplete { - receiver: state.receiver - }) - } - - fn poll_read_command_complete<'a>( - state: &'a mut RentToOwn<'a, ReadCommandComplete>, - ) -> Poll { - let message = try_ready_receive!(state.receiver.poll()); - - match message { - Some(Message::CommandComplete(body)) => { - let rows = body - .tag() - .map_err(Error::parse)? - .rsplit(' ') - .next() - .unwrap() - .parse() - .unwrap_or(0); - transition!(Finished(rows)) - } - Some(Message::ErrorResponse(body)) => Err(Error::db(body)), - Some(_) => Err(Error::unexpected_message()), - None => Err(Error::closed()), - } - } -} - -impl CopyInFuture -where - S: Stream, - S::Item: AsRef<[u8]>, - S::Error: Into>, -{ - pub fn new( - client: Client, - request: PendingRequest, - statement: Statement, - stream: S, - sender: mpsc::Sender, - ) -> CopyInFuture { - CopyIn::start(client, request, statement, stream, sender) - } -} diff --git a/tokio-postgres/src/proto/copy_out.rs b/tokio-postgres/src/proto/copy_out.rs deleted file mode 100644 index 2fdf1dbf1..000000000 --- a/tokio-postgres/src/proto/copy_out.rs +++ /dev/null @@ -1,105 +0,0 @@ -use bytes::Bytes; -use futures::sync::mpsc; -use futures::{Async, Poll, Stream}; -use postgres_protocol::message::backend::Message; -use std::mem; - -use proto::client::{Client, PendingRequest}; -use proto::statement::Statement; -use Error; - -enum State { - Start { - client: Client, - request: PendingRequest, - statement: Statement, - }, - ReadingCopyOutResponse { - receiver: mpsc::Receiver, - }, - ReadingCopyData { - receiver: mpsc::Receiver, - }, - Done, -} - -pub struct CopyOutStream(State); - -impl Stream for CopyOutStream { - type Item = Bytes; - type Error = Error; - - fn poll(&mut self) -> Poll, Error> { - loop { - match mem::replace(&mut self.0, State::Done) { - State::Start { - client, - request, - statement, - } => { - let receiver = client.send(request)?; - // it's ok for the statement to close now that we've queued the query - drop(statement); - self.0 = State::ReadingCopyOutResponse { receiver }; - } - State::ReadingCopyOutResponse { mut receiver } => { - let message = match receiver.poll() { - Ok(Async::Ready(message)) => message, - Ok(Async::NotReady) => { - self.0 = State::ReadingCopyOutResponse { receiver }; - break Ok(Async::NotReady); - } - Err(()) => unreachable!("mpsc::Receiver doesn't return errors"), - }; - - match message { - Some(Message::BindComplete) => { - self.0 = State::ReadingCopyOutResponse { receiver }; - } - Some(Message::CopyOutResponse(_)) => { - self.0 = State::ReadingCopyData { receiver }; - } - Some(Message::ErrorResponse(body)) => break Err(Error::db(body)), - Some(_) => break Err(Error::unexpected_message()), - None => break Err(Error::closed()), - } - } - State::ReadingCopyData { mut receiver } => { - let message = match receiver.poll() { - Ok(Async::Ready(message)) => message, - Ok(Async::NotReady) => { - self.0 = State::ReadingCopyData { receiver }; - break Ok(Async::NotReady); - } - Err(()) => unreachable!("mpsc::Reciever doesn't return errors"), - }; - - match message { - Some(Message::CopyData(body)) => { - self.0 = State::ReadingCopyData { receiver }; - break Ok(Async::Ready(Some(body.into_bytes()))); - } - Some(Message::CopyDone) | Some(Message::CommandComplete(_)) => { - self.0 = State::ReadingCopyData { receiver }; - } - Some(Message::ReadyForQuery(_)) => break Ok(Async::Ready(None)), - Some(Message::ErrorResponse(body)) => break Err(Error::db(body)), - Some(_) => break Err(Error::unexpected_message()), - None => break Err(Error::closed()), - } - } - State::Done => break Ok(Async::Ready(None)), - } - } - } -} - -impl CopyOutStream { - pub fn new(client: Client, request: PendingRequest, statement: Statement) -> CopyOutStream { - CopyOutStream(State::Start { - client, - request, - statement, - }) - } -} diff --git a/tokio-postgres/src/proto/execute.rs b/tokio-postgres/src/proto/execute.rs deleted file mode 100644 index 6ad3234ae..000000000 --- a/tokio-postgres/src/proto/execute.rs +++ /dev/null @@ -1,68 +0,0 @@ -use futures::sync::mpsc; -use futures::{Poll, Stream}; -use postgres_protocol::message::backend::Message; -use state_machine_future::RentToOwn; - -use proto::client::{Client, PendingRequest}; -use proto::statement::Statement; -use Error; - -#[derive(StateMachineFuture)] -pub enum Execute { - #[state_machine_future(start, transitions(ReadResponse))] - Start { - client: Client, - request: PendingRequest, - statement: Statement, - }, - #[state_machine_future(transitions(Finished))] - ReadResponse { receiver: mpsc::Receiver }, - #[state_machine_future(ready)] - Finished(u64), - #[state_machine_future(error)] - Failed(Error), -} - -impl PollExecute for Execute { - fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll { - let state = state.take(); - let receiver = state.client.send(state.request)?; - - // the statement can drop after this point, since its close will queue up after the execution - transition!(ReadResponse { receiver }) - } - - fn poll_read_response<'a>( - state: &'a mut RentToOwn<'a, ReadResponse>, - ) -> Poll { - loop { - let message = try_ready_receive!(state.receiver.poll()); - - match message { - Some(Message::BindComplete) => {} - Some(Message::DataRow(_)) => {} - Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), - Some(Message::CommandComplete(body)) => { - let rows = body - .tag() - .map_err(Error::parse)? - .rsplit(' ') - .next() - .unwrap() - .parse() - .unwrap_or(0); - transition!(Finished(rows)) - } - Some(Message::EmptyQueryResponse) => transition!(Finished(0)), - Some(_) => return Err(Error::unexpected_message()), - None => return Err(Error::closed()), - } - } - } -} - -impl ExecuteFuture { - pub fn new(client: Client, request: PendingRequest, statement: Statement) -> ExecuteFuture { - Execute::start(client, request, statement) - } -} diff --git a/tokio-postgres/src/proto/handshake.rs b/tokio-postgres/src/proto/handshake.rs deleted file mode 100644 index 119669f75..000000000 --- a/tokio-postgres/src/proto/handshake.rs +++ /dev/null @@ -1,328 +0,0 @@ -use fallible_iterator::FallibleIterator; -use futures::sink; -use futures::sync::mpsc; -use futures::{Future, Poll, Sink, Stream}; -use postgres_protocol::authentication; -use postgres_protocol::authentication::sasl::{self, ChannelBinding, ScramSha256}; -use postgres_protocol::message::backend::Message; -use postgres_protocol::message::frontend; -use state_machine_future::RentToOwn; -use std::collections::HashMap; -use std::io; -use tokio_codec::Framed; - -use params::{ConnectParams, User}; -use proto::client::Client; -use proto::codec::PostgresCodec; -use proto::connect::ConnectFuture; -use proto::connection::Connection; -use tls::TlsStream; -use {CancelData, Error, TlsMode}; - -#[derive(StateMachineFuture)] -pub enum Handshake { - #[state_machine_future(start, transitions(SendingStartup))] - Start { - future: ConnectFuture, - params: ConnectParams, - }, - #[state_machine_future(transitions(ReadingAuth))] - SendingStartup { - future: sink::Send, PostgresCodec>>, - user: User, - }, - #[state_machine_future(transitions(ReadingInfo, SendingPassword, SendingSasl))] - ReadingAuth { - stream: Framed, PostgresCodec>, - user: User, - }, - #[state_machine_future(transitions(ReadingAuthCompletion))] - SendingPassword { - future: sink::Send, PostgresCodec>>, - }, - #[state_machine_future(transitions(ReadingSasl))] - SendingSasl { - future: sink::Send, PostgresCodec>>, - scram: ScramSha256, - }, - #[state_machine_future(transitions(SendingSasl, ReadingAuthCompletion))] - ReadingSasl { - stream: Framed, PostgresCodec>, - scram: ScramSha256, - }, - #[state_machine_future(transitions(ReadingInfo))] - ReadingAuthCompletion { - stream: Framed, PostgresCodec>, - }, - #[state_machine_future(transitions(Finished))] - ReadingInfo { - stream: Framed, PostgresCodec>, - cancel_data: Option, - parameters: HashMap, - }, - #[state_machine_future(ready)] - Finished((Client, Connection)), - #[state_machine_future(error)] - Failed(Error), -} - -impl PollHandshake for Handshake { - fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll { - let stream = try_ready!(state.future.poll()); - let state = state.take(); - - let user = match state.params.user() { - Some(user) => user.clone(), - None => return Err(Error::missing_user()), - }; - - let mut buf = vec![]; - { - let options = state - .params - .options() - .iter() - .map(|&(ref key, ref value)| (&**key, &**value)); - let client_encoding = Some(("client_encoding", "UTF8")); - let timezone = Some(("timezone", "GMT")); - let user = Some(("user", user.name())); - let database = state.params.database().map(|s| ("database", s)); - - frontend::startup_message( - options - .chain(client_encoding) - .chain(timezone) - .chain(user) - .chain(database), - &mut buf, - ).map_err(Error::encode)?; - } - - let stream = Framed::new(stream, PostgresCodec); - transition!(SendingStartup { - future: stream.send(buf), - user, - }) - } - - fn poll_sending_startup<'a>( - state: &'a mut RentToOwn<'a, SendingStartup>, - ) -> Poll { - let stream = try_ready!(state.future.poll().map_err(Error::io)); - let state = state.take(); - transition!(ReadingAuth { - stream, - user: state.user, - }) - } - - fn poll_reading_auth<'a>( - state: &'a mut RentToOwn<'a, ReadingAuth>, - ) -> Poll { - let message = try_ready!(state.stream.poll().map_err(Error::io)); - let state = state.take(); - - match message { - Some(Message::AuthenticationOk) => transition!(ReadingInfo { - stream: state.stream, - cancel_data: None, - parameters: HashMap::new(), - }), - Some(Message::AuthenticationCleartextPassword) => { - let pass = state.user.password().ok_or_else(Error::missing_password)?; - let mut buf = vec![]; - frontend::password_message(pass, &mut buf).map_err(Error::encode)?; - transition!(SendingPassword { - future: state.stream.send(buf) - }) - } - Some(Message::AuthenticationMd5Password(body)) => { - let pass = state.user.password().ok_or_else(Error::missing_password)?; - let output = authentication::md5_hash( - state.user.name().as_bytes(), - pass.as_bytes(), - body.salt(), - ); - let mut buf = vec![]; - frontend::password_message(&output, &mut buf).map_err(Error::encode)?; - transition!(SendingPassword { - future: state.stream.send(buf) - }) - } - Some(Message::AuthenticationSasl(body)) => { - let pass = state.user.password().ok_or_else(Error::missing_password)?; - - let mut has_scram = false; - let mut has_scram_plus = false; - let mut mechanisms = body.mechanisms(); - while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? { - match mechanism { - sasl::SCRAM_SHA_256 => has_scram = true, - sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true, - _ => {} - } - } - let channel_binding = state - .stream - .get_ref() - .tls_unique() - .map(ChannelBinding::tls_unique) - .or_else(|| { - state - .stream - .get_ref() - .tls_server_end_point() - .map(ChannelBinding::tls_server_end_point) - }); - - let (channel_binding, mechanism) = if has_scram_plus { - match channel_binding { - Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS), - None => (ChannelBinding::unsupported(), sasl::SCRAM_SHA_256), - } - } else if has_scram { - match channel_binding { - Some(_) => (ChannelBinding::unrequested(), sasl::SCRAM_SHA_256), - None => (ChannelBinding::unsupported(), sasl::SCRAM_SHA_256), - } - } else { - return Err(Error::unsupported_authentication()); - }; - - let mut scram = ScramSha256::new(pass.as_bytes(), channel_binding); - - let mut buf = vec![]; - frontend::sasl_initial_response(mechanism, scram.message(), &mut buf) - .map_err(Error::encode)?; - - transition!(SendingSasl { - future: state.stream.send(buf), - scram, - }) - } - Some(Message::AuthenticationKerberosV5) - | Some(Message::AuthenticationScmCredential) - | Some(Message::AuthenticationGss) - | Some(Message::AuthenticationSspi) => Err(Error::unsupported_authentication()), - Some(Message::ErrorResponse(body)) => Err(Error::db(body)), - Some(_) => Err(Error::unexpected_message()), - None => Err(Error::closed()), - } - } - - fn poll_sending_password<'a>( - state: &'a mut RentToOwn<'a, SendingPassword>, - ) -> Poll { - let stream = try_ready!(state.future.poll().map_err(Error::io)); - transition!(ReadingAuthCompletion { stream }) - } - - fn poll_sending_sasl<'a>( - state: &'a mut RentToOwn<'a, SendingSasl>, - ) -> Poll { - let stream = try_ready!(state.future.poll().map_err(Error::io)); - let state = state.take(); - transition!(ReadingSasl { - stream, - scram: state.scram - }) - } - - fn poll_reading_sasl<'a>( - state: &'a mut RentToOwn<'a, ReadingSasl>, - ) -> Poll { - let message = try_ready!(state.stream.poll().map_err(Error::io)); - let mut state = state.take(); - - match message { - Some(Message::AuthenticationSaslContinue(body)) => { - state - .scram - .update(body.data()) - .map_err(Error::authentication)?; - let mut buf = vec![]; - frontend::sasl_response(state.scram.message(), &mut buf).map_err(Error::encode)?; - transition!(SendingSasl { - future: state.stream.send(buf), - scram: state.scram, - }) - } - Some(Message::AuthenticationSaslFinal(body)) => { - state - .scram - .finish(body.data()) - .map_err(Error::authentication)?; - transition!(ReadingAuthCompletion { - stream: state.stream, - }) - } - Some(Message::ErrorResponse(body)) => Err(Error::db(body)), - Some(_) => Err(Error::unexpected_message()), - None => Err(Error::closed()), - } - } - - fn poll_reading_auth_completion<'a>( - state: &'a mut RentToOwn<'a, ReadingAuthCompletion>, - ) -> Poll { - let message = try_ready!(state.stream.poll().map_err(Error::io)); - let state = state.take(); - - match message { - Some(Message::AuthenticationOk) => transition!(ReadingInfo { - stream: state.stream, - cancel_data: None, - parameters: HashMap::new(), - }), - Some(Message::ErrorResponse(body)) => Err(Error::db(body)), - Some(_) => Err(Error::unexpected_message()), - None => Err(Error::closed()), - } - } - - fn poll_reading_info<'a>( - state: &'a mut RentToOwn<'a, ReadingInfo>, - ) -> Poll { - loop { - let message = try_ready!(state.stream.poll().map_err(Error::io)); - match message { - Some(Message::BackendKeyData(body)) => { - state.cancel_data = Some(CancelData { - process_id: body.process_id(), - secret_key: body.secret_key(), - }); - } - Some(Message::ParameterStatus(body)) => { - state.parameters.insert( - body.name().map_err(Error::parse)?.to_string(), - body.value().map_err(Error::parse)?.to_string(), - ); - } - Some(Message::ReadyForQuery(_)) => { - let state = state.take(); - let cancel_data = state.cancel_data.ok_or_else(|| { - Error::parse(io::Error::new( - io::ErrorKind::InvalidData, - "BackendKeyData message missing", - )) - })?; - let (sender, receiver) = mpsc::unbounded(); - let client = Client::new(sender); - let connection = - Connection::new(state.stream, cancel_data, state.parameters, receiver); - transition!(Finished((client, connection))) - } - Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), - Some(Message::NoticeResponse(_)) => {} - Some(_) => return Err(Error::unexpected_message()), - None => return Err(Error::closed()), - } - } - } -} - -impl HandshakeFuture { - pub fn new(params: ConnectParams, tls: TlsMode) -> HandshakeFuture { - Handshake::start(ConnectFuture::new(params.clone(), tls), params) - } -} diff --git a/tokio-postgres/src/proto/mod.rs b/tokio-postgres/src/proto/mod.rs deleted file mode 100644 index 6471badfa..000000000 --- a/tokio-postgres/src/proto/mod.rs +++ /dev/null @@ -1,59 +0,0 @@ -macro_rules! try_ready_receive { - ($e:expr) => { - match $e { - Ok(::futures::Async::Ready(v)) => v, - Ok(::futures::Async::NotReady) => return Ok(::futures::Async::NotReady), - Err(()) => unreachable!("mpsc::Receiver doesn't return errors"), - } - }; -} - -macro_rules! try_ready_closed { - ($e:expr) => { - match $e { - Ok(::futures::Async::Ready(v)) => v, - Ok(::futures::Async::NotReady) => return Ok(::futures::Async::NotReady), - Err(_) => return Err(::Error::closed()), - } - }; -} - -mod bind; -mod cancel; -mod client; -mod codec; -mod connect; -mod connection; -mod copy_in; -mod copy_out; -mod execute; -mod handshake; -mod portal; -mod prepare; -mod query; -mod row; -mod simple_query; -mod socket; -mod statement; -mod transaction; -mod typeinfo; -mod typeinfo_composite; -mod typeinfo_enum; - -pub use proto::bind::BindFuture; -pub use proto::cancel::CancelFuture; -pub use proto::client::Client; -pub use proto::codec::PostgresCodec; -pub use proto::connection::Connection; -pub use proto::copy_in::CopyInFuture; -pub use proto::copy_out::CopyOutStream; -pub use proto::execute::ExecuteFuture; -pub use proto::handshake::HandshakeFuture; -pub use proto::portal::Portal; -pub use proto::prepare::PrepareFuture; -pub use proto::query::QueryStream; -pub use proto::row::Row; -pub use proto::simple_query::SimpleQueryFuture; -pub use proto::socket::Socket; -pub use proto::statement::Statement; -pub use proto::transaction::TransactionFuture; diff --git a/tokio-postgres/src/proto/portal.rs b/tokio-postgres/src/proto/portal.rs deleted file mode 100644 index ef982fc5e..000000000 --- a/tokio-postgres/src/proto/portal.rs +++ /dev/null @@ -1,39 +0,0 @@ -use std::sync::Arc; - -use proto::client::WeakClient; -use proto::statement::Statement; - -struct Inner { - client: WeakClient, - name: String, - statement: Statement, -} - -impl Drop for Inner { - fn drop(&mut self) { - if let Some(client) = self.client.upgrade() { - client.close_portal(&self.name); - } - } -} - -#[derive(Clone)] -pub struct Portal(Arc); - -impl Portal { - pub fn new(client: WeakClient, name: String, statement: Statement) -> Portal { - Portal(Arc::new(Inner { - client, - name, - statement, - })) - } - - pub fn name(&self) -> &str { - &self.0.name - } - - pub fn statement(&self) -> &Statement { - &self.0.statement - } -} diff --git a/tokio-postgres/src/proto/prepare.rs b/tokio-postgres/src/proto/prepare.rs deleted file mode 100644 index b4bfaaf5a..000000000 --- a/tokio-postgres/src/proto/prepare.rs +++ /dev/null @@ -1,226 +0,0 @@ -use fallible_iterator::FallibleIterator; -use futures::sync::mpsc; -use futures::{Future, Poll, Stream}; -use postgres_protocol::message::backend::Message; -use state_machine_future::RentToOwn; -use std::mem; -use std::vec; - -use proto::client::{Client, PendingRequest}; -use proto::statement::Statement; -use proto::typeinfo::TypeinfoFuture; -use types::{Oid, Type}; -use {Column, Error}; - -#[derive(StateMachineFuture)] -pub enum Prepare { - #[state_machine_future(start, transitions(ReadParseComplete))] - Start { - client: Client, - request: PendingRequest, - name: String, - }, - #[state_machine_future(transitions(ReadParameterDescription))] - ReadParseComplete { - client: Client, - receiver: mpsc::Receiver, - name: String, - }, - #[state_machine_future(transitions(ReadRowDescription))] - ReadParameterDescription { - client: Client, - receiver: mpsc::Receiver, - name: String, - }, - #[state_machine_future(transitions(GetParameterTypes, GetColumnTypes, Finished))] - ReadRowDescription { - client: Client, - receiver: mpsc::Receiver, - name: String, - parameters: Vec, - }, - #[state_machine_future(transitions(GetColumnTypes, Finished))] - GetParameterTypes { - future: TypeinfoFuture, - remaining_parameters: vec::IntoIter, - name: String, - parameters: Vec, - columns: Vec<(String, Oid)>, - }, - #[state_machine_future(transitions(Finished))] - GetColumnTypes { - future: TypeinfoFuture, - cur_column_name: String, - remaining_columns: vec::IntoIter<(String, Oid)>, - name: String, - parameters: Vec, - columns: Vec, - }, - #[state_machine_future(ready)] - Finished(Statement), - #[state_machine_future(error)] - Failed(Error), -} - -impl PollPrepare for Prepare { - fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll { - let state = state.take(); - let receiver = state.client.send(state.request)?; - - transition!(ReadParseComplete { - receiver, - name: state.name, - client: state.client, - }) - } - - fn poll_read_parse_complete<'a>( - state: &'a mut RentToOwn<'a, ReadParseComplete>, - ) -> Poll { - let message = try_ready_receive!(state.receiver.poll()); - let state = state.take(); - - match message { - Some(Message::ParseComplete) => transition!(ReadParameterDescription { - receiver: state.receiver, - name: state.name, - client: state.client, - }), - Some(Message::ErrorResponse(body)) => Err(Error::db(body)), - Some(_) => Err(Error::unexpected_message()), - None => Err(Error::closed()), - } - } - - fn poll_read_parameter_description<'a>( - state: &'a mut RentToOwn<'a, ReadParameterDescription>, - ) -> Poll { - let message = try_ready_receive!(state.receiver.poll()); - let state = state.take(); - - match message { - Some(Message::ParameterDescription(body)) => transition!(ReadRowDescription { - receiver: state.receiver, - name: state.name, - parameters: body.parameters().collect().map_err(Error::parse)?, - client: state.client, - }), - Some(_) => Err(Error::unexpected_message()), - None => Err(Error::closed()), - } - } - - fn poll_read_row_description<'a>( - state: &'a mut RentToOwn<'a, ReadRowDescription>, - ) -> Poll { - let message = try_ready_receive!(state.receiver.poll()); - let state = state.take(); - - let columns = match message { - Some(Message::RowDescription(body)) => body - .fields() - .map(|f| (f.name().to_string(), f.type_oid())) - .collect() - .map_err(Error::parse)?, - Some(Message::NoData) => vec![], - Some(_) => return Err(Error::unexpected_message()), - None => return Err(Error::closed()), - }; - - let mut parameters = state.parameters.into_iter(); - if let Some(oid) = parameters.next() { - transition!(GetParameterTypes { - future: TypeinfoFuture::new(oid, state.client), - remaining_parameters: parameters, - name: state.name, - parameters: vec![], - columns: columns, - }); - } - - let mut columns = columns.into_iter(); - if let Some((name, oid)) = columns.next() { - transition!(GetColumnTypes { - future: TypeinfoFuture::new(oid, state.client), - cur_column_name: name, - remaining_columns: columns, - name: state.name, - parameters: vec![], - columns: vec![], - }); - } - - transition!(Finished(Statement::new( - state.client.downgrade(), - state.name, - vec![], - vec![] - ))) - } - - fn poll_get_parameter_types<'a>( - state: &'a mut RentToOwn<'a, GetParameterTypes>, - ) -> Poll { - let client = loop { - let (ty, client) = try_ready!(state.future.poll()); - state.parameters.push(ty); - - match state.remaining_parameters.next() { - Some(oid) => state.future = TypeinfoFuture::new(oid, client), - None => break client, - } - }; - let state = state.take(); - - let mut columns = state.columns.into_iter(); - if let Some((name, oid)) = columns.next() { - transition!(GetColumnTypes { - future: TypeinfoFuture::new(oid, client), - cur_column_name: name, - remaining_columns: columns, - name: state.name, - parameters: state.parameters, - columns: vec![], - }) - } - - transition!(Finished(Statement::new( - client.downgrade(), - state.name, - state.parameters, - vec![], - ))) - } - - fn poll_get_column_types<'a>( - state: &'a mut RentToOwn<'a, GetColumnTypes>, - ) -> Poll { - let client = loop { - let (ty, client) = try_ready!(state.future.poll()); - let name = mem::replace(&mut state.cur_column_name, String::new()); - state.columns.push(Column::new(name, ty)); - - match state.remaining_columns.next() { - Some((name, oid)) => { - state.cur_column_name = name; - state.future = TypeinfoFuture::new(oid, client); - } - None => break client, - } - }; - let state = state.take(); - - transition!(Finished(Statement::new( - client.downgrade(), - state.name, - state.parameters, - state.columns, - ))) - } -} - -impl PrepareFuture { - pub fn new(client: Client, request: PendingRequest, name: String) -> PrepareFuture { - Prepare::start(client, request, name) - } -} diff --git a/tokio-postgres/src/proto/query.rs b/tokio-postgres/src/proto/query.rs deleted file mode 100644 index 3cb5d1372..000000000 --- a/tokio-postgres/src/proto/query.rs +++ /dev/null @@ -1,122 +0,0 @@ -use futures::sync::mpsc; -use futures::{Async, Poll, Stream}; -use postgres_protocol::message::backend::Message; -use std::mem; - -use proto::client::{Client, PendingRequest}; -use proto::portal::Portal; -use proto::row::Row; -use proto::statement::Statement; -use Error; - -pub trait StatementHolder { - fn statement(&self) -> &Statement; -} - -impl StatementHolder for Statement { - fn statement(&self) -> &Statement { - self - } -} - -impl StatementHolder for Portal { - fn statement(&self) -> &Statement { - self.statement() - } -} - -enum State { - Start { - client: Client, - request: PendingRequest, - statement: T, - }, - ReadingResponse { - receiver: mpsc::Receiver, - statement: T, - }, - Done, -} - -pub struct QueryStream(State); - -impl Stream for QueryStream -where - T: StatementHolder, -{ - type Item = Row; - type Error = Error; - - fn poll(&mut self) -> Poll, Error> { - loop { - match mem::replace(&mut self.0, State::Done) { - State::Start { - client, - request, - statement, - } => { - let receiver = client.send(request)?; - self.0 = State::ReadingResponse { - receiver, - statement, - }; - } - State::ReadingResponse { - mut receiver, - statement, - } => { - let message = match receiver.poll() { - Ok(Async::Ready(message)) => message, - Ok(Async::NotReady) => { - self.0 = State::ReadingResponse { - receiver, - statement, - }; - break Ok(Async::NotReady); - } - Err(()) => unreachable!("mpsc::Receiver doesn't return errors"), - }; - - match message { - Some(Message::BindComplete) => { - self.0 = State::ReadingResponse { - receiver, - statement, - }; - } - Some(Message::ErrorResponse(body)) => break Err(Error::db(body)), - Some(Message::DataRow(body)) => { - let row = Row::new(statement.statement().clone(), body)?; - self.0 = State::ReadingResponse { - receiver, - statement, - }; - break Ok(Async::Ready(Some(row))); - } - Some(Message::EmptyQueryResponse) - | Some(Message::PortalSuspended) - | Some(Message::CommandComplete(_)) => { - break Ok(Async::Ready(None)); - } - Some(_) => break Err(Error::unexpected_message()), - None => break Err(Error::closed()), - } - } - State::Done => break Ok(Async::Ready(None)), - } - } - } -} - -impl QueryStream -where - T: StatementHolder, -{ - pub fn new(client: Client, request: PendingRequest, statement: T) -> QueryStream { - QueryStream(State::Start { - client, - request, - statement, - }) - } -} diff --git a/tokio-postgres/src/proto/row.rs b/tokio-postgres/src/proto/row.rs deleted file mode 100644 index 38c270d45..000000000 --- a/tokio-postgres/src/proto/row.rs +++ /dev/null @@ -1,65 +0,0 @@ -use postgres_protocol::message::backend::DataRowBody; -use postgres_shared::rows::{RowData, RowIndex}; -use std::fmt; - -use proto::statement::Statement; -use types::{FromSql, WrongType}; -use {Column, Error}; - -pub struct Row { - statement: Statement, - data: RowData, -} - -impl Row { - pub fn new(statement: Statement, data: DataRowBody) -> Result { - let data = RowData::new(data).map_err(Error::parse)?; - Ok(Row { statement, data }) - } - - pub fn columns(&self) -> &[Column] { - self.statement.columns() - } - - pub fn len(&self) -> usize { - self.columns().len() - } - - pub fn get<'b, I, T>(&'b self, idx: I) -> T - where - I: RowIndex + fmt::Debug, - T: FromSql<'b>, - { - match self.get_inner(&idx) { - Ok(Some(ok)) => ok, - Err(err) => panic!("error retrieving column {:?}: {:?}", idx, err), - Ok(None) => panic!("no such column {:?}", idx), - } - } - - pub fn try_get<'b, I, T>(&'b self, idx: I) -> Result, Error> - where - I: RowIndex, - T: FromSql<'b>, - { - self.get_inner(&idx) - } - - fn get_inner<'b, I, T>(&'b self, idx: &I) -> Result, Error> - where - I: RowIndex, - T: FromSql<'b>, - { - let idx = match idx.__idx(&self.columns()) { - Some(idx) => idx, - None => return Ok(None), - }; - - let ty = self.statement.columns()[idx].type_(); - if !::accepts(ty) { - return Err(Error::from_sql(Box::new(WrongType::new(ty.clone())))); - } - let value = FromSql::from_sql_nullable(ty, self.data.get(idx)); - value.map(Some).map_err(Error::from_sql) - } -} diff --git a/tokio-postgres/src/proto/simple_query.rs b/tokio-postgres/src/proto/simple_query.rs deleted file mode 100644 index e39d1b4e1..000000000 --- a/tokio-postgres/src/proto/simple_query.rs +++ /dev/null @@ -1,56 +0,0 @@ -use futures::sync::mpsc; -use futures::{Poll, Stream}; -use postgres_protocol::message::backend::Message; -use state_machine_future::RentToOwn; - -use proto::client::{Client, PendingRequest}; -use Error; - -#[derive(StateMachineFuture)] -pub enum SimpleQuery { - #[state_machine_future(start, transitions(ReadResponse))] - Start { - client: Client, - request: PendingRequest, - }, - #[state_machine_future(transitions(Finished))] - ReadResponse { receiver: mpsc::Receiver }, - #[state_machine_future(ready)] - Finished(()), - #[state_machine_future(error)] - Failed(Error), -} - -impl PollSimpleQuery for SimpleQuery { - fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll { - let state = state.take(); - let receiver = state.client.send(state.request)?; - - transition!(ReadResponse { receiver }) - } - - fn poll_read_response<'a>( - state: &'a mut RentToOwn<'a, ReadResponse>, - ) -> Poll { - loop { - let message = try_ready_receive!(state.receiver.poll()); - - match message { - Some(Message::CommandComplete(_)) - | Some(Message::RowDescription(_)) - | Some(Message::DataRow(_)) - | Some(Message::EmptyQueryResponse) => {} - Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), - Some(Message::ReadyForQuery(_)) => transition!(Finished(())), - Some(_) => return Err(Error::unexpected_message()), - None => return Err(Error::closed()), - } - } - } -} - -impl SimpleQueryFuture { - pub fn new(client: Client, request: PendingRequest) -> SimpleQueryFuture { - SimpleQuery::start(client, request) - } -} diff --git a/tokio-postgres/src/proto/socket.rs b/tokio-postgres/src/proto/socket.rs deleted file mode 100644 index f6de498ae..000000000 --- a/tokio-postgres/src/proto/socket.rs +++ /dev/null @@ -1,84 +0,0 @@ -use bytes::{Buf, BufMut}; -use futures::Poll; -use std::io::{self, Read, Write}; -use tokio_io::{AsyncRead, AsyncWrite}; -use tokio_tcp::TcpStream; - -#[cfg(unix)] -use tokio_uds::UnixStream; - -pub enum Socket { - Tcp(TcpStream), - #[cfg(unix)] - Unix(UnixStream), -} - -impl Read for Socket { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match self { - Socket::Tcp(stream) => stream.read(buf), - #[cfg(unix)] - Socket::Unix(stream) => stream.read(buf), - } - } -} - -impl AsyncRead for Socket { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { - match self { - Socket::Tcp(stream) => stream.prepare_uninitialized_buffer(buf), - #[cfg(unix)] - Socket::Unix(stream) => stream.prepare_uninitialized_buffer(buf), - } - } - - fn read_buf(&mut self, buf: &mut B) -> Poll - where - B: BufMut, - { - match self { - Socket::Tcp(stream) => stream.read_buf(buf), - #[cfg(unix)] - Socket::Unix(stream) => stream.read_buf(buf), - } - } -} - -impl Write for Socket { - fn write(&mut self, buf: &[u8]) -> io::Result { - match self { - Socket::Tcp(stream) => stream.write(buf), - #[cfg(unix)] - Socket::Unix(stream) => stream.write(buf), - } - } - - fn flush(&mut self) -> io::Result<()> { - match self { - Socket::Tcp(stream) => stream.flush(), - #[cfg(unix)] - Socket::Unix(stream) => stream.flush(), - } - } -} - -impl AsyncWrite for Socket { - fn shutdown(&mut self) -> Poll<(), io::Error> { - match self { - Socket::Tcp(stream) => stream.shutdown(), - #[cfg(unix)] - Socket::Unix(stream) => stream.shutdown(), - } - } - - fn write_buf(&mut self, buf: &mut B) -> Poll - where - B: Buf, - { - match self { - Socket::Tcp(stream) => stream.write_buf(buf), - #[cfg(unix)] - Socket::Unix(stream) => stream.write_buf(buf), - } - } -} diff --git a/tokio-postgres/src/proto/statement.rs b/tokio-postgres/src/proto/statement.rs deleted file mode 100644 index 3460a76c0..000000000 --- a/tokio-postgres/src/proto/statement.rs +++ /dev/null @@ -1,51 +0,0 @@ -use postgres_shared::stmt::Column; -use std::sync::Arc; - -use proto::client::WeakClient; -use types::Type; - -pub struct StatementInner { - client: WeakClient, - name: String, - params: Vec, - columns: Vec, -} - -impl Drop for StatementInner { - fn drop(&mut self) { - if let Some(client) = self.client.upgrade() { - client.close_statement(&self.name); - } - } -} - -#[derive(Clone)] -pub struct Statement(Arc); - -impl Statement { - pub fn new( - client: WeakClient, - name: String, - params: Vec, - columns: Vec, - ) -> Statement { - Statement(Arc::new(StatementInner { - client, - name, - params, - columns, - })) - } - - pub fn name(&self) -> &str { - &self.0.name - } - - pub fn params(&self) -> &[Type] { - &self.0.params - } - - pub fn columns(&self) -> &[Column] { - &self.0.columns - } -} diff --git a/tokio-postgres/src/proto/transaction.rs b/tokio-postgres/src/proto/transaction.rs deleted file mode 100644 index 2a7b08265..000000000 --- a/tokio-postgres/src/proto/transaction.rs +++ /dev/null @@ -1,104 +0,0 @@ -use futures::{Async, Future, Poll}; -use proto::client::Client; -use proto::simple_query::SimpleQueryFuture; -use state_machine_future::RentToOwn; - -use Error; - -#[derive(StateMachineFuture)] -pub enum Transaction -where - F: Future, - E: From, -{ - #[state_machine_future(start, transitions(Beginning))] - Start { client: Client, future: F }, - #[state_machine_future(transitions(Running))] - Beginning { - client: Client, - begin: SimpleQueryFuture, - future: F, - }, - #[state_machine_future(transitions(Finishing))] - Running { client: Client, future: F }, - #[state_machine_future(transitions(Finished))] - Finishing { - future: SimpleQueryFuture, - result: Result, - }, - #[state_machine_future(ready)] - Finished(T), - #[state_machine_future(error)] - Failed(E), -} - -impl PollTransaction for Transaction -where - F: Future, - E: From, -{ - fn poll_start<'a>( - state: &'a mut RentToOwn<'a, Start>, - ) -> Poll, E> { - let state = state.take(); - transition!(Beginning { - begin: state.client.batch_execute("BEGIN"), - client: state.client, - future: state.future, - }) - } - - fn poll_beginning<'a>( - state: &'a mut RentToOwn<'a, Beginning>, - ) -> Poll, E> { - try_ready!(state.begin.poll()); - let state = state.take(); - transition!(Running { - client: state.client, - future: state.future, - }) - } - - fn poll_running<'a>( - state: &'a mut RentToOwn<'a, Running>, - ) -> Poll, E> { - match state.future.poll() { - Ok(Async::NotReady) => return Ok(Async::NotReady), - Ok(Async::Ready(t)) => transition!(Finishing { - future: state.client.batch_execute("COMMIT"), - result: Ok(t), - }), - Err(e) => transition!(Finishing { - future: state.client.batch_execute("ROLLBACK"), - result: Err(e), - }), - } - } - - fn poll_finishing<'a>( - state: &'a mut RentToOwn<'a, Finishing>, - ) -> Poll, E> { - match state.future.poll() { - Ok(Async::NotReady) => return Ok(Async::NotReady), - Ok(Async::Ready(())) => { - let t = state.take().result?; - transition!(Finished(t)) - } - Err(e) => match state.take().result { - Ok(_) => Err(e.into()), - // prioritize the future's error over the rollback error - Err(e) => Err(e), - }, - } - } -} - -impl TransactionFuture -where - F: Future, - E: From, -{ - pub fn new(client: Client, future: F) -> TransactionFuture { - Transaction::start(client, future) - } -} diff --git a/tokio-postgres/src/proto/typeinfo.rs b/tokio-postgres/src/proto/typeinfo.rs deleted file mode 100644 index 81aedcd9f..000000000 --- a/tokio-postgres/src/proto/typeinfo.rs +++ /dev/null @@ -1,344 +0,0 @@ -use futures::stream::{self, Stream}; -use futures::{Async, Future, Poll}; -use state_machine_future::RentToOwn; - -use error::{Error, SqlState}; -use next_statement; -use proto::client::Client; -use proto::prepare::PrepareFuture; -use proto::query::QueryStream; -use proto::statement::Statement; -use proto::typeinfo_composite::TypeinfoCompositeFuture; -use proto::typeinfo_enum::TypeinfoEnumFuture; -use types::{Kind, Oid, Type}; - -const TYPEINFO_QUERY: &'static str = " -SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid -FROM pg_catalog.pg_type t -LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid -INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid -WHERE t.oid = $1 -"; - -// Range types weren't added until Postgres 9.2, so pg_range may not exist -const TYPEINFO_FALLBACK_QUERY: &'static str = " -SELECT t.typname, t.typtype, t.typelem, NULL::OID, t.typbasetype, n.nspname, t.typrelid -FROM pg_catalog.pg_type t -INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid -WHERE t.oid = $1 -"; - -#[derive(StateMachineFuture)] -pub enum Typeinfo { - #[state_machine_future( - start, - transitions(PreparingTypeinfo, QueryingTypeinfo, Finished) - )] - Start { oid: Oid, client: Client }, - #[state_machine_future(transitions(PreparingTypeinfoFallback, QueryingTypeinfo))] - PreparingTypeinfo { - future: Box, - oid: Oid, - client: Client, - }, - #[state_machine_future(transitions(QueryingTypeinfo))] - PreparingTypeinfoFallback { - future: Box, - oid: Oid, - client: Client, - }, - #[state_machine_future(transitions( - CachingType, - QueryingEnumVariants, - QueryingDomainBasetype, - QueryingArrayElem, - QueryingCompositeFields, - QueryingRangeSubtype - ))] - QueryingTypeinfo { - future: stream::Collect>, - oid: Oid, - client: Client, - }, - #[state_machine_future(transitions(CachingType))] - QueryingEnumVariants { - future: TypeinfoEnumFuture, - name: String, - oid: Oid, - schema: String, - }, - #[state_machine_future(transitions(CachingType))] - QueryingDomainBasetype { - future: Box, - name: String, - oid: Oid, - schema: String, - }, - #[state_machine_future(transitions(CachingType))] - QueryingArrayElem { - future: Box, - name: String, - oid: Oid, - schema: String, - }, - #[state_machine_future(transitions(CachingType))] - QueryingCompositeFields { - future: TypeinfoCompositeFuture, - name: String, - oid: Oid, - schema: String, - }, - #[state_machine_future(transitions(CachingType))] - QueryingRangeSubtype { - future: Box, - name: String, - oid: Oid, - schema: String, - }, - #[state_machine_future(transitions(Finished))] - CachingType { ty: Type, oid: Oid, client: Client }, - #[state_machine_future(ready)] - Finished((Type, Client)), - #[state_machine_future(error)] - Failed(Error), -} - -impl PollTypeinfo for Typeinfo { - fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll { - let state = state.take(); - - if let Some(ty) = Type::from_oid(state.oid) { - transition!(Finished((ty, state.client))); - } - - if let Some(ty) = state.client.cached_type(state.oid) { - transition!(Finished((ty, state.client))); - } - - match state.client.typeinfo_query() { - Some(statement) => transition!(QueryingTypeinfo { - future: state.client.query(&statement, &[&state.oid]).collect(), - oid: state.oid, - client: state.client, - }), - None => transition!(PreparingTypeinfo { - future: Box::new(state.client.prepare(next_statement(), TYPEINFO_QUERY, &[])), - oid: state.oid, - client: state.client, - }), - } - } - - fn poll_preparing_typeinfo<'a>( - state: &'a mut RentToOwn<'a, PreparingTypeinfo>, - ) -> Poll { - let statement = match state.future.poll() { - Ok(Async::Ready(statement)) => statement, - Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_TABLE) => { - let mut state = state.take(); - - transition!(PreparingTypeinfoFallback { - future: Box::new(state.client.prepare( - next_statement(), - TYPEINFO_FALLBACK_QUERY, - &[] - )), - oid: state.oid, - client: state.client, - }) - } - Err(e) => return Err(e), - }; - let state = state.take(); - - let future = state.client.query(&statement, &[&state.oid]).collect(); - state.client.set_typeinfo_query(&statement); - transition!(QueryingTypeinfo { - future, - oid: state.oid, - client: state.client - }) - } - - fn poll_preparing_typeinfo_fallback<'a>( - state: &'a mut RentToOwn<'a, PreparingTypeinfoFallback>, - ) -> Poll { - let statement = try_ready!(state.future.poll()); - let state = state.take(); - - let future = state.client.query(&statement, &[&state.oid]).collect(); - state.client.set_typeinfo_query(&statement); - transition!(QueryingTypeinfo { - future, - oid: state.oid, - client: state.client - }) - } - - fn poll_querying_typeinfo<'a>( - state: &'a mut RentToOwn<'a, QueryingTypeinfo>, - ) -> Poll { - let rows = try_ready!(state.future.poll()); - let state = state.take(); - - let row = match rows.get(0) { - Some(row) => row, - None => return Err(Error::unexpected_message()), - }; - - let name = row - .try_get::<_, String>(0)? - .ok_or_else(Error::unexpected_message)?; - let type_ = row - .try_get::<_, i8>(1)? - .ok_or_else(Error::unexpected_message)?; - let elem_oid = row - .try_get::<_, Oid>(2)? - .ok_or_else(Error::unexpected_message)?; - let rngsubtype = row - .try_get::<_, Option>(3)? - .ok_or_else(Error::unexpected_message)?; - let basetype = row - .try_get::<_, Oid>(4)? - .ok_or_else(Error::unexpected_message)?; - let schema = row - .try_get::<_, String>(5)? - .ok_or_else(Error::unexpected_message)?; - let relid = row - .try_get::<_, Oid>(6)? - .ok_or_else(Error::unexpected_message)?; - - let kind = if type_ == b'e' as i8 { - transition!(QueryingEnumVariants { - future: TypeinfoEnumFuture::new(state.oid, state.client), - name, - oid: state.oid, - schema, - }) - } else if type_ == b'p' as i8 { - Kind::Pseudo - } else if basetype != 0 { - transition!(QueryingDomainBasetype { - future: Box::new(TypeinfoFuture::new(basetype, state.client)), - name, - oid: state.oid, - schema, - }) - } else if elem_oid != 0 { - transition!(QueryingArrayElem { - future: Box::new(TypeinfoFuture::new(elem_oid, state.client)), - name, - oid: state.oid, - schema, - }) - } else if relid != 0 { - transition!(QueryingCompositeFields { - future: TypeinfoCompositeFuture::new(relid, state.client), - name, - oid: state.oid, - schema, - }) - } else if let Some(rngsubtype) = rngsubtype { - transition!(QueryingRangeSubtype { - future: Box::new(TypeinfoFuture::new(rngsubtype, state.client)), - name, - oid: state.oid, - schema, - }) - } else { - Kind::Simple - }; - - let ty = Type::_new(name.to_string(), state.oid, kind, schema.to_string()); - transition!(CachingType { - ty, - oid: state.oid, - client: state.client, - }) - } - - fn poll_querying_enum_variants<'a>( - state: &'a mut RentToOwn<'a, QueryingEnumVariants>, - ) -> Poll { - let (variants, client) = try_ready!(state.future.poll()); - let state = state.take(); - - let ty = Type::_new(state.name, state.oid, Kind::Enum(variants), state.schema); - transition!(CachingType { - ty, - oid: state.oid, - client, - }) - } - - fn poll_querying_domain_basetype<'a>( - state: &'a mut RentToOwn<'a, QueryingDomainBasetype>, - ) -> Poll { - let (basetype, client) = try_ready!(state.future.poll()); - let state = state.take(); - - let ty = Type::_new(state.name, state.oid, Kind::Domain(basetype), state.schema); - transition!(CachingType { - ty, - oid: state.oid, - client, - }) - } - - fn poll_querying_array_elem<'a>( - state: &'a mut RentToOwn<'a, QueryingArrayElem>, - ) -> Poll { - let (elem, client) = try_ready!(state.future.poll()); - let state = state.take(); - - let ty = Type::_new(state.name, state.oid, Kind::Array(elem), state.schema); - transition!(CachingType { - ty, - oid: state.oid, - client, - }) - } - - fn poll_querying_composite_fields<'a>( - state: &'a mut RentToOwn<'a, QueryingCompositeFields>, - ) -> Poll { - let (fields, client) = try_ready!(state.future.poll()); - let state = state.take(); - - let ty = Type::_new(state.name, state.oid, Kind::Composite(fields), state.schema); - transition!(CachingType { - ty, - oid: state.oid, - client, - }) - } - - fn poll_querying_range_subtype<'a>( - state: &'a mut RentToOwn<'a, QueryingRangeSubtype>, - ) -> Poll { - let (subtype, client) = try_ready!(state.future.poll()); - let state = state.take(); - - let ty = Type::_new(state.name, state.oid, Kind::Range(subtype), state.schema); - transition!(CachingType { - ty, - oid: state.oid, - client, - }) - } - - fn poll_caching_type<'a>( - state: &'a mut RentToOwn<'a, CachingType>, - ) -> Poll { - let state = state.take(); - state.client.cache_type(&state.ty); - transition!(Finished((state.ty, state.client))) - } -} - -impl TypeinfoFuture { - pub fn new(oid: Oid, client: Client) -> TypeinfoFuture { - Typeinfo::start(oid, client) - } -} diff --git a/tokio-postgres/src/proto/typeinfo_composite.rs b/tokio-postgres/src/proto/typeinfo_composite.rs deleted file mode 100644 index ab2f16fbd..000000000 --- a/tokio-postgres/src/proto/typeinfo_composite.rs +++ /dev/null @@ -1,143 +0,0 @@ -use futures::stream::{self, Stream}; -use futures::{Future, Poll}; -use state_machine_future::RentToOwn; -use std::mem; -use std::vec; - -use error::Error; -use next_statement; -use proto::client::Client; -use proto::prepare::PrepareFuture; -use proto::query::QueryStream; -use proto::statement::Statement; -use proto::typeinfo::TypeinfoFuture; -use types::{Field, Oid}; - -const TYPEINFO_COMPOSITE_QUERY: &'static str = " -SELECT attname, atttypid -FROM pg_catalog.pg_attribute -WHERE attrelid = $1 -AND NOT attisdropped -AND attnum > 0 -ORDER BY attnum -"; - -#[derive(StateMachineFuture)] -pub enum TypeinfoComposite { - #[state_machine_future( - start, - transitions(PreparingTypeinfoComposite, QueryingCompositeFields) - )] - Start { oid: Oid, client: Client }, - #[state_machine_future(transitions(QueryingCompositeFields))] - PreparingTypeinfoComposite { - future: Box, - oid: Oid, - client: Client, - }, - #[state_machine_future(transitions(QueryingCompositeFieldTypes, Finished))] - QueryingCompositeFields { - future: stream::Collect>, - client: Client, - }, - #[state_machine_future(transitions(Finished))] - QueryingCompositeFieldTypes { - future: Box, - cur_field_name: String, - remaining_fields: vec::IntoIter<(String, Oid)>, - fields: Vec, - }, - #[state_machine_future(ready)] - Finished((Vec, Client)), - #[state_machine_future(error)] - Failed(Error), -} - -impl PollTypeinfoComposite for TypeinfoComposite { - fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll { - let state = state.take(); - - match state.client.typeinfo_composite_query() { - Some(statement) => transition!(QueryingCompositeFields { - future: state.client.query(&statement, &[&state.oid]).collect(), - client: state.client, - }), - None => transition!(PreparingTypeinfoComposite { - future: Box::new(state.client.prepare( - next_statement(), - TYPEINFO_COMPOSITE_QUERY, - &[] - )), - oid: state.oid, - client: state.client, - }), - } - } - - fn poll_preparing_typeinfo_composite<'a>( - state: &'a mut RentToOwn<'a, PreparingTypeinfoComposite>, - ) -> Poll { - let statement = try_ready!(state.future.poll()); - let state = state.take(); - - state.client.set_typeinfo_composite_query(&statement); - transition!(QueryingCompositeFields { - future: state.client.query(&statement, &[&state.oid]).collect(), - client: state.client, - }) - } - - fn poll_querying_composite_fields<'a>( - state: &'a mut RentToOwn<'a, QueryingCompositeFields>, - ) -> Poll { - let rows = try_ready!(state.future.poll()); - let state = state.take(); - - let fields = rows - .iter() - .map(|row| { - let name = row.try_get(0)?.ok_or_else(Error::unexpected_message)?; - let oid = row.try_get(1)?.ok_or_else(Error::unexpected_message)?; - Ok((name, oid)) - }).collect::, Error>>()?; - - let mut remaining_fields = fields.into_iter(); - match remaining_fields.next() { - Some((cur_field_name, oid)) => transition!(QueryingCompositeFieldTypes { - future: Box::new(TypeinfoFuture::new(oid, state.client)), - cur_field_name, - fields: vec![], - remaining_fields, - }), - None => transition!(Finished((vec![], state.client))), - } - } - - fn poll_querying_composite_field_types<'a>( - state: &'a mut RentToOwn<'a, QueryingCompositeFieldTypes>, - ) -> Poll { - loop { - let (ty, client) = try_ready!(state.future.poll()); - - let name = mem::replace(&mut state.cur_field_name, String::new()); - state.fields.push(Field::new(name, ty)); - - match state.remaining_fields.next() { - Some((cur_field_name, oid)) => { - state.cur_field_name = cur_field_name; - state.future = Box::new(TypeinfoFuture::new(oid, client)); - } - None => { - let state = state.take(); - transition!(Finished((state.fields, client))); - } - } - } - } -} - -impl TypeinfoCompositeFuture { - pub fn new(oid: Oid, client: Client) -> TypeinfoCompositeFuture { - TypeinfoComposite::start(oid, client) - } -} diff --git a/tokio-postgres/src/proto/typeinfo_enum.rs b/tokio-postgres/src/proto/typeinfo_enum.rs deleted file mode 100644 index 82283dbe4..000000000 --- a/tokio-postgres/src/proto/typeinfo_enum.rs +++ /dev/null @@ -1,141 +0,0 @@ -use futures::stream::{self, Stream}; -use futures::{Async, Future, Poll}; -use state_machine_future::RentToOwn; - -use error::{Error, SqlState}; -use next_statement; -use proto::client::Client; -use proto::prepare::PrepareFuture; -use proto::query::QueryStream; -use proto::statement::Statement; -use types::Oid; - -const TYPEINFO_ENUM_QUERY: &'static str = " -SELECT enumlabel -FROM pg_catalog.pg_enum -WHERE enumtypid = $1 -ORDER BY enumsortorder -"; - -// Postgres 9.0 didn't have enumsortorder -const TYPEINFO_ENUM_FALLBACK_QUERY: &'static str = " -SELECT enumlabel -FROM pg_catalog.pg_enum -WHERE enumtypid = $1 -ORDER BY oid -"; - -#[derive(StateMachineFuture)] -pub enum TypeinfoEnum { - #[state_machine_future( - start, - transitions(PreparingTypeinfoEnum, QueryingEnumVariants) - )] - Start { oid: Oid, client: Client }, - #[state_machine_future(transitions(PreparingTypeinfoEnumFallback, QueryingEnumVariants))] - PreparingTypeinfoEnum { - future: Box, - oid: Oid, - client: Client, - }, - #[state_machine_future(transitions(QueryingEnumVariants))] - PreparingTypeinfoEnumFallback { - future: Box, - oid: Oid, - client: Client, - }, - #[state_machine_future(transitions(Finished))] - QueryingEnumVariants { - future: stream::Collect>, - client: Client, - }, - #[state_machine_future(ready)] - Finished((Vec, Client)), - #[state_machine_future(error)] - Failed(Error), -} - -impl PollTypeinfoEnum for TypeinfoEnum { - fn poll_start<'a>(state: &'a mut RentToOwn<'a, Start>) -> Poll { - let state = state.take(); - - match state.client.typeinfo_enum_query() { - Some(statement) => transition!(QueryingEnumVariants { - future: state.client.query(&statement, &[&state.oid]).collect(), - client: state.client, - }), - None => transition!(PreparingTypeinfoEnum { - future: Box::new( - state - .client - .prepare(next_statement(), TYPEINFO_ENUM_QUERY, &[]) - ), - oid: state.oid, - client: state.client, - }), - } - } - - fn poll_preparing_typeinfo_enum<'a>( - state: &'a mut RentToOwn<'a, PreparingTypeinfoEnum>, - ) -> Poll { - let statement = match state.future.poll() { - Ok(Async::Ready(statement)) => statement, - Ok(Async::NotReady) => return Ok(Async::NotReady), - Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_COLUMN) => { - let mut state = state.take(); - - transition!(PreparingTypeinfoEnumFallback { - future: Box::new(state.client.prepare( - next_statement(), - TYPEINFO_ENUM_FALLBACK_QUERY, - &[] - )), - oid: state.oid, - client: state.client, - }) - } - Err(e) => return Err(e), - }; - let state = state.take(); - - state.client.set_typeinfo_enum_query(&statement); - transition!(QueryingEnumVariants { - future: state.client.query(&statement, &[&state.oid]).collect(), - client: state.client, - }) - } - - fn poll_preparing_typeinfo_enum_fallback<'a>( - state: &'a mut RentToOwn<'a, PreparingTypeinfoEnumFallback>, - ) -> Poll { - let statement = try_ready!(state.future.poll()); - let state = state.take(); - - state.client.set_typeinfo_enum_query(&statement); - transition!(QueryingEnumVariants { - future: state.client.query(&statement, &[&state.oid]).collect(), - client: state.client, - }) - } - - fn poll_querying_enum_variants<'a>( - state: &'a mut RentToOwn<'a, QueryingEnumVariants>, - ) -> Poll { - let rows = try_ready!(state.future.poll()); - let state = state.take(); - - let variants = rows - .iter() - .map(|row| row.try_get(0)?.ok_or_else(Error::unexpected_message)) - .collect::, _>>()?; - - transition!(Finished((variants, state.client))) - } -} - -impl TypeinfoEnumFuture { - pub fn new(oid: Oid, client: Client) -> TypeinfoEnumFuture { - TypeinfoEnum::start(oid, client) - } -} diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs new file mode 100644 index 000000000..2fcb22d57 --- /dev/null +++ b/tokio-postgres/src/query.rs @@ -0,0 +1,325 @@ +use crate::client::{InnerClient, Responses}; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::prepare::get_type; +use crate::types::{BorrowToSql, IsNull}; +use crate::{Column, Error, Portal, Row, Statement}; +use bytes::{Bytes, BytesMut}; +use fallible_iterator::FallibleIterator; +use futures_util::{ready, Stream}; +use log::{debug, log_enabled, Level}; +use pin_project_lite::pin_project; +use postgres_protocol::message::backend::{CommandCompleteBody, Message}; +use postgres_protocol::message::frontend; +use postgres_types::Type; +use std::fmt; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +struct BorrowToSqlParamsDebug<'a, T>(&'a [T]); + +impl fmt::Debug for BorrowToSqlParamsDebug<'_, T> +where + T: BorrowToSql, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list() + .entries(self.0.iter().map(|x| x.borrow_to_sql())) + .finish() + } +} + +pub async fn query( + client: &InnerClient, + statement: Statement, + params: I, +) -> Result +where + P: BorrowToSql, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, +{ + let buf = if log_enabled!(Level::Debug) { + let params = params.into_iter().collect::>(); + debug!( + "executing statement {} with parameters: {:?}", + statement.name(), + BorrowToSqlParamsDebug(params.as_slice()), + ); + encode(client, &statement, params)? + } else { + encode(client, &statement, params)? + }; + let responses = start(client, buf).await?; + Ok(RowStream { + statement, + responses, + rows_affected: None, + _p: PhantomPinned, + }) +} + +pub async fn query_typed( + client: &Arc, + query: &str, + params: I, +) -> Result +where + P: BorrowToSql, + I: IntoIterator, +{ + let buf = { + let params = params.into_iter().collect::>(); + let param_oids = params.iter().map(|(_, t)| t.oid()).collect::>(); + + client.with_buf(|buf| { + frontend::parse("", query, param_oids.into_iter(), buf).map_err(Error::parse)?; + encode_bind_raw("", params, "", buf)?; + frontend::describe(b'S', "", buf).map_err(Error::encode)?; + frontend::execute("", 0, buf).map_err(Error::encode)?; + frontend::sync(buf); + + Ok(buf.split().freeze()) + })? + }; + + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + loop { + match responses.next().await? { + Message::ParseComplete | Message::BindComplete | Message::ParameterDescription(_) => {} + Message::NoData => { + return Ok(RowStream { + statement: Statement::unnamed(vec![], vec![]), + responses, + rows_affected: None, + _p: PhantomPinned, + }); + } + Message::RowDescription(row_description) => { + let mut columns: Vec = vec![]; + let mut it = row_description.fields(); + while let Some(field) = it.next().map_err(Error::parse)? { + let type_ = get_type(client, field.type_oid()).await?; + let column = Column { + name: field.name().to_string(), + table_oid: Some(field.table_oid()).filter(|n| *n != 0), + column_id: Some(field.column_id()).filter(|n| *n != 0), + r#type: type_, + }; + columns.push(column); + } + return Ok(RowStream { + statement: Statement::unnamed(vec![], columns), + responses, + rows_affected: None, + _p: PhantomPinned, + }); + } + _ => return Err(Error::unexpected_message()), + } + } +} + +pub async fn query_portal( + client: &InnerClient, + portal: &Portal, + max_rows: i32, +) -> Result { + let buf = client.with_buf(|buf| { + frontend::execute(portal.name(), max_rows, buf).map_err(Error::encode)?; + frontend::sync(buf); + Ok(buf.split().freeze()) + })?; + + let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + Ok(RowStream { + statement: portal.statement().clone(), + responses, + rows_affected: None, + _p: PhantomPinned, + }) +} + +/// Extract the number of rows affected from [`CommandCompleteBody`]. +pub fn extract_row_affected(body: &CommandCompleteBody) -> Result { + let rows = body + .tag() + .map_err(Error::parse)? + .rsplit(' ') + .next() + .unwrap() + .parse() + .unwrap_or(0); + Ok(rows) +} + +pub async fn execute( + client: &InnerClient, + statement: Statement, + params: I, +) -> Result +where + P: BorrowToSql, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, +{ + let buf = if log_enabled!(Level::Debug) { + let params = params.into_iter().collect::>(); + debug!( + "executing statement {} with parameters: {:?}", + statement.name(), + BorrowToSqlParamsDebug(params.as_slice()), + ); + encode(client, &statement, params)? + } else { + encode(client, &statement, params)? + }; + let mut responses = start(client, buf).await?; + + let mut rows = 0; + loop { + match responses.next().await? { + Message::DataRow(_) => {} + Message::CommandComplete(body) => { + rows = extract_row_affected(&body)?; + } + Message::EmptyQueryResponse => rows = 0, + Message::ReadyForQuery(_) => return Ok(rows), + _ => return Err(Error::unexpected_message()), + } + } +} + +async fn start(client: &InnerClient, buf: Bytes) -> Result { + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } + + Ok(responses) +} + +pub fn encode(client: &InnerClient, statement: &Statement, params: I) -> Result +where + P: BorrowToSql, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, +{ + client.with_buf(|buf| { + encode_bind(statement, params, "", buf)?; + frontend::execute("", 0, buf).map_err(Error::encode)?; + frontend::sync(buf); + Ok(buf.split().freeze()) + }) +} + +pub fn encode_bind( + statement: &Statement, + params: I, + portal: &str, + buf: &mut BytesMut, +) -> Result<(), Error> +where + P: BorrowToSql, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, +{ + let params = params.into_iter(); + if params.len() != statement.params().len() { + return Err(Error::parameters(params.len(), statement.params().len())); + } + + encode_bind_raw( + statement.name(), + params.zip(statement.params().iter().cloned()), + portal, + buf, + ) +} + +fn encode_bind_raw( + statement_name: &str, + params: I, + portal: &str, + buf: &mut BytesMut, +) -> Result<(), Error> +where + P: BorrowToSql, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, +{ + let (param_formats, params): (Vec<_>, Vec<_>) = params + .into_iter() + .map(|(p, ty)| (p.borrow_to_sql().encode_format(&ty) as i16, (p, ty))) + .unzip(); + + let mut error_idx = 0; + let r = frontend::bind( + portal, + statement_name, + param_formats, + params.into_iter().enumerate(), + |(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(&ty, buf) { + Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No), + Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes), + Err(e) => { + error_idx = idx; + Err(e) + } + }, + Some(1), + buf, + ); + match r { + Ok(()) => Ok(()), + Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)), + Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), + } +} + +pin_project! { + /// A stream of table rows. + pub struct RowStream { + statement: Statement, + responses: Responses, + rows_affected: Option, + #[pin] + _p: PhantomPinned, + } +} + +impl Stream for RowStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + loop { + match ready!(this.responses.poll_next(cx)?) { + Message::DataRow(body) => { + return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?))) + } + Message::CommandComplete(body) => { + *this.rows_affected = Some(extract_row_affected(&body)?); + } + Message::EmptyQueryResponse | Message::PortalSuspended => {} + Message::ReadyForQuery(_) => return Poll::Ready(None), + _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), + } + } + } +} + +impl RowStream { + /// Returns the number of rows affected by the query. + /// + /// This function will return `None` until the stream has been exhausted. + pub fn rows_affected(&self) -> Option { + self.rows_affected + } +} diff --git a/tokio-postgres/src/row.rs b/tokio-postgres/src/row.rs new file mode 100644 index 000000000..ccb8817d0 --- /dev/null +++ b/tokio-postgres/src/row.rs @@ -0,0 +1,275 @@ +//! Rows. + +use crate::row::sealed::{AsName, Sealed}; +use crate::simple_query::SimpleColumn; +use crate::statement::Column; +use crate::types::{FromSql, Type, WrongType}; +use crate::{Error, Statement}; +use fallible_iterator::FallibleIterator; +use postgres_protocol::message::backend::DataRowBody; +use std::fmt; +use std::ops::Range; +use std::str; +use std::sync::Arc; + +mod sealed { + pub trait Sealed {} + + pub trait AsName { + fn as_name(&self) -> &str; + } +} + +impl AsName for Column { + fn as_name(&self) -> &str { + self.name() + } +} + +impl AsName for String { + fn as_name(&self) -> &str { + self + } +} + +/// A trait implemented by types that can index into columns of a row. +/// +/// This cannot be implemented outside of this crate. +pub trait RowIndex: Sealed { + #[doc(hidden)] + fn __idx(&self, columns: &[T]) -> Option + where + T: AsName; +} + +impl Sealed for usize {} + +impl RowIndex for usize { + #[inline] + fn __idx(&self, columns: &[T]) -> Option + where + T: AsName, + { + if *self >= columns.len() { + None + } else { + Some(*self) + } + } +} + +impl Sealed for str {} + +impl RowIndex for str { + #[inline] + fn __idx(&self, columns: &[T]) -> Option + where + T: AsName, + { + if let Some(idx) = columns.iter().position(|d| d.as_name() == self) { + return Some(idx); + }; + + // FIXME ASCII-only case insensitivity isn't really the right thing to + // do. Postgres itself uses a dubious wrapper around tolower and JDBC + // uses the US locale. + columns + .iter() + .position(|d| d.as_name().eq_ignore_ascii_case(self)) + } +} + +impl Sealed for &T where T: ?Sized + Sealed {} + +impl RowIndex for &T +where + T: ?Sized + RowIndex, +{ + #[inline] + fn __idx(&self, columns: &[U]) -> Option + where + U: AsName, + { + T::__idx(*self, columns) + } +} + +/// A row of data returned from the database by a query. +#[derive(Clone)] +pub struct Row { + statement: Statement, + body: DataRowBody, + ranges: Vec>>, +} + +impl fmt::Debug for Row { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Row") + .field("columns", &self.columns()) + .finish() + } +} + +impl Row { + pub(crate) fn new(statement: Statement, body: DataRowBody) -> Result { + let ranges = body.ranges().collect().map_err(Error::parse)?; + Ok(Row { + statement, + body, + ranges, + }) + } + + /// Returns information about the columns of data in the row. + pub fn columns(&self) -> &[Column] { + self.statement.columns() + } + + /// Determines if the row contains no values. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns the number of values in the row. + pub fn len(&self) -> usize { + self.columns().len() + } + + /// Deserializes a value from the row. + /// + /// The value can be specified either by its numeric index in the row, or by its column name. + /// + /// # Panics + /// + /// Panics if the index is out of bounds or if the value cannot be converted to the specified type. + #[track_caller] + pub fn get<'a, I, T>(&'a self, idx: I) -> T + where + I: RowIndex + fmt::Display, + T: FromSql<'a>, + { + match self.get_inner(&idx) { + Ok(ok) => ok, + Err(err) => panic!("error retrieving column {}: {}", idx, err), + } + } + + /// Like `Row::get`, but returns a `Result` rather than panicking. + pub fn try_get<'a, I, T>(&'a self, idx: I) -> Result + where + I: RowIndex + fmt::Display, + T: FromSql<'a>, + { + self.get_inner(&idx) + } + + fn get_inner<'a, I, T>(&'a self, idx: &I) -> Result + where + I: RowIndex + fmt::Display, + T: FromSql<'a>, + { + let idx = match idx.__idx(self.columns()) { + Some(idx) => idx, + None => return Err(Error::column(idx.to_string())), + }; + + let ty = self.columns()[idx].type_(); + if !T::accepts(ty) { + return Err(Error::from_sql( + Box::new(WrongType::new::(ty.clone())), + idx, + )); + } + + FromSql::from_sql_nullable(ty, self.col_buffer(idx)).map_err(|e| Error::from_sql(e, idx)) + } + + /// Get the raw bytes for the column at the given index. + fn col_buffer(&self, idx: usize) -> Option<&[u8]> { + let range = self.ranges[idx].to_owned()?; + Some(&self.body.buffer()[range]) + } +} + +impl AsName for SimpleColumn { + fn as_name(&self) -> &str { + self.name() + } +} + +/// A row of data returned from the database by a simple query. +#[derive(Debug)] +pub struct SimpleQueryRow { + columns: Arc<[SimpleColumn]>, + body: DataRowBody, + ranges: Vec>>, +} + +impl SimpleQueryRow { + #[allow(clippy::new_ret_no_self)] + pub(crate) fn new( + columns: Arc<[SimpleColumn]>, + body: DataRowBody, + ) -> Result { + let ranges = body.ranges().collect().map_err(Error::parse)?; + Ok(SimpleQueryRow { + columns, + body, + ranges, + }) + } + + /// Returns information about the columns of data in the row. + pub fn columns(&self) -> &[SimpleColumn] { + &self.columns + } + + /// Determines if the row contains no values. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns the number of values in the row. + pub fn len(&self) -> usize { + self.columns.len() + } + + /// Returns a value from the row. + /// + /// The value can be specified either by its numeric index in the row, or by its column name. + /// + /// # Panics + /// + /// Panics if the index is out of bounds or if the value cannot be converted to the specified type. + #[track_caller] + pub fn get(&self, idx: I) -> Option<&str> + where + I: RowIndex + fmt::Display, + { + match self.get_inner(&idx) { + Ok(ok) => ok, + Err(err) => panic!("error retrieving column {}: {}", idx, err), + } + } + + /// Like `SimpleQueryRow::get`, but returns a `Result` rather than panicking. + pub fn try_get(&self, idx: I) -> Result, Error> + where + I: RowIndex + fmt::Display, + { + self.get_inner(&idx) + } + + fn get_inner(&self, idx: &I) -> Result, Error> + where + I: RowIndex + fmt::Display, + { + let idx = match idx.__idx(&self.columns) { + Some(idx) => idx, + None => return Err(Error::column(idx.to_string())), + }; + + let buf = self.ranges[idx].clone().map(|r| &self.body.buffer()[r]); + FromSql::from_sql_nullable(&Type::TEXT, buf).map_err(|e| Error::from_sql(e, idx)) + } +} diff --git a/tokio-postgres/src/simple_query.rs b/tokio-postgres/src/simple_query.rs new file mode 100644 index 000000000..24473b896 --- /dev/null +++ b/tokio-postgres/src/simple_query.rs @@ -0,0 +1,118 @@ +use crate::client::{InnerClient, Responses}; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::query::extract_row_affected; +use crate::{Error, SimpleQueryMessage, SimpleQueryRow}; +use bytes::Bytes; +use fallible_iterator::FallibleIterator; +use futures_util::{ready, Stream}; +use log::debug; +use pin_project_lite::pin_project; +use postgres_protocol::message::backend::Message; +use postgres_protocol::message::frontend; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +/// Information about a column of a single query row. +#[derive(Debug)] +pub struct SimpleColumn { + name: String, +} + +impl SimpleColumn { + pub(crate) fn new(name: String) -> SimpleColumn { + SimpleColumn { name } + } + + /// Returns the name of the column. + pub fn name(&self) -> &str { + &self.name + } +} + +pub async fn simple_query(client: &InnerClient, query: &str) -> Result { + debug!("executing simple query: {}", query); + + let buf = encode(client, query)?; + let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + Ok(SimpleQueryStream { + responses, + columns: None, + _p: PhantomPinned, + }) +} + +pub async fn batch_execute(client: &InnerClient, query: &str) -> Result<(), Error> { + debug!("executing statement batch: {}", query); + + let buf = encode(client, query)?; + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + loop { + match responses.next().await? { + Message::ReadyForQuery(_) => return Ok(()), + Message::CommandComplete(_) + | Message::EmptyQueryResponse + | Message::RowDescription(_) + | Message::DataRow(_) => {} + _ => return Err(Error::unexpected_message()), + } + } +} + +fn encode(client: &InnerClient, query: &str) -> Result { + client.with_buf(|buf| { + frontend::query(query, buf).map_err(Error::encode)?; + Ok(buf.split().freeze()) + }) +} + +pin_project! { + /// A stream of simple query results. + pub struct SimpleQueryStream { + responses: Responses, + columns: Option>, + #[pin] + _p: PhantomPinned, + } +} + +impl Stream for SimpleQueryStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + match ready!(this.responses.poll_next(cx)?) { + Message::CommandComplete(body) => { + let rows = extract_row_affected(&body)?; + Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(rows)))) + } + Message::EmptyQueryResponse => { + Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(0)))) + } + Message::RowDescription(body) => { + let columns: Arc<[SimpleColumn]> = body + .fields() + .map(|f| Ok(SimpleColumn::new(f.name().to_string()))) + .collect::>() + .map_err(Error::parse)? + .into(); + + *this.columns = Some(columns.clone()); + Poll::Ready(Some(Ok(SimpleQueryMessage::RowDescription(columns)))) + } + Message::DataRow(body) => { + let row = match &this.columns { + Some(columns) => SimpleQueryRow::new(columns.clone(), body)?, + None => return Poll::Ready(Some(Err(Error::unexpected_message()))), + }; + Poll::Ready(Some(Ok(SimpleQueryMessage::Row(row)))) + } + Message::ReadyForQuery(_) => Poll::Ready(None), + _ => Poll::Ready(Some(Err(Error::unexpected_message()))), + } + } +} diff --git a/tokio-postgres/src/socket.rs b/tokio-postgres/src/socket.rs new file mode 100644 index 000000000..966510d56 --- /dev/null +++ b/tokio-postgres/src/socket.rs @@ -0,0 +1,75 @@ +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::net::TcpStream; +#[cfg(unix)] +use tokio::net::UnixStream; + +#[derive(Debug)] +enum Inner { + Tcp(TcpStream), + #[cfg(unix)] + Unix(UnixStream), +} + +/// The standard stream type used by the crate. +/// +/// Requires the `runtime` Cargo feature (enabled by default). +#[derive(Debug)] +pub struct Socket(Inner); + +impl Socket { + pub(crate) fn new_tcp(stream: TcpStream) -> Socket { + Socket(Inner::Tcp(stream)) + } + + #[cfg(unix)] + pub(crate) fn new_unix(stream: UnixStream) -> Socket { + Socket(Inner::Unix(stream)) + } +} + +impl AsyncRead for Socket { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match &mut self.0 { + Inner::Tcp(s) => Pin::new(s).poll_read(cx, buf), + #[cfg(unix)] + Inner::Unix(s) => Pin::new(s).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for Socket { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match &mut self.0 { + Inner::Tcp(s) => Pin::new(s).poll_write(cx, buf), + #[cfg(unix)] + Inner::Unix(s) => Pin::new(s).poll_write(cx, buf), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut self.0 { + Inner::Tcp(s) => Pin::new(s).poll_flush(cx), + #[cfg(unix)] + Inner::Unix(s) => Pin::new(s).poll_flush(cx), + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut self.0 { + Inner::Tcp(s) => Pin::new(s).poll_shutdown(cx), + #[cfg(unix)] + Inner::Unix(s) => Pin::new(s).poll_shutdown(cx), + } + } +} diff --git a/tokio-postgres/src/statement.rs b/tokio-postgres/src/statement.rs new file mode 100644 index 000000000..4f7ddaec6 --- /dev/null +++ b/tokio-postgres/src/statement.rs @@ -0,0 +1,116 @@ +use crate::client::InnerClient; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::types::Type; +use postgres_protocol::message::frontend; +use std::sync::{Arc, Weak}; + +struct StatementInner { + client: Weak, + name: String, + params: Vec, + columns: Vec, +} + +impl Drop for StatementInner { + fn drop(&mut self) { + if self.name.is_empty() { + // Unnamed statements don't need to be closed + return; + } + if let Some(client) = self.client.upgrade() { + let buf = client.with_buf(|buf| { + frontend::close(b'S', &self.name, buf).unwrap(); + frontend::sync(buf); + buf.split().freeze() + }); + let _ = client.send(RequestMessages::Single(FrontendMessage::Raw(buf))); + } + } +} + +/// A prepared statement. +/// +/// Prepared statements can only be used with the connection that created them. +#[derive(Clone)] +pub struct Statement(Arc); + +impl Statement { + pub(crate) fn new( + inner: &Arc, + name: String, + params: Vec, + columns: Vec, + ) -> Statement { + Statement(Arc::new(StatementInner { + client: Arc::downgrade(inner), + name, + params, + columns, + })) + } + + pub(crate) fn unnamed(params: Vec, columns: Vec) -> Statement { + Statement(Arc::new(StatementInner { + client: Weak::new(), + name: String::new(), + params, + columns, + })) + } + + pub(crate) fn name(&self) -> &str { + &self.0.name + } + + /// Returns the expected types of the statement's parameters. + pub fn params(&self) -> &[Type] { + &self.0.params + } + + /// Returns information about the columns returned when the statement is queried. + pub fn columns(&self) -> &[Column] { + &self.0.columns + } +} + +impl std::fmt::Debug for Statement { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> { + f.debug_struct("Statement") + .field("name", &self.0.name) + .field("params", &self.0.params) + .field("columns", &self.0.columns) + .finish_non_exhaustive() + } +} + +/// Information about a column of a query. +#[derive(Debug)] +pub struct Column { + pub(crate) name: String, + pub(crate) table_oid: Option, + pub(crate) column_id: Option, + pub(crate) r#type: Type, +} + +impl Column { + /// Returns the name of the column. + pub fn name(&self) -> &str { + &self.name + } + + /// Returns the OID of the underlying database table. + pub fn table_oid(&self) -> Option { + self.table_oid + } + + /// Return the column ID within the underlying database table. + pub fn column_id(&self) -> Option { + self.column_id + } + + /// Returns the type of the column. + pub fn type_(&self) -> &Type { + &self.r#type + } +} diff --git a/tokio-postgres/src/tls.rs b/tokio-postgres/src/tls.rs index bcb497787..963daed18 100644 --- a/tokio-postgres/src/tls.rs +++ b/tokio-postgres/src/tls.rs @@ -1,83 +1,164 @@ -use bytes::{Buf, BufMut}; -use futures::{Future, Poll}; +//! TLS support. + use std::error::Error; -use std::io::{self, Read, Write}; -use tokio_io::{AsyncRead, AsyncWrite}; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{fmt, io}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +pub(crate) mod private { + pub struct ForcePrivateApi; +} -use proto; +/// Channel binding information returned from a TLS handshake. +pub struct ChannelBinding { + pub(crate) tls_server_end_point: Option>, +} -pub struct Socket(pub(crate) proto::Socket); +impl ChannelBinding { + /// Creates a `ChannelBinding` containing no information. + pub fn none() -> ChannelBinding { + ChannelBinding { + tls_server_end_point: None, + } + } -impl Read for Socket { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - self.0.read(buf) + /// Creates a `ChannelBinding` containing `tls-server-end-point` channel binding information. + pub fn tls_server_end_point(tls_server_end_point: Vec) -> ChannelBinding { + ChannelBinding { + tls_server_end_point: Some(tls_server_end_point), + } } } -impl AsyncRead for Socket { - unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { - self.0.prepare_uninitialized_buffer(buf) +/// A constructor of `TlsConnect`ors. +/// +/// Requires the `runtime` Cargo feature (enabled by default). +#[cfg(feature = "runtime")] +pub trait MakeTlsConnect { + /// The stream type created by the `TlsConnect` implementation. + type Stream: TlsStream + Unpin; + /// The `TlsConnect` implementation created by this type. + type TlsConnect: TlsConnect; + /// The error type returned by the `TlsConnect` implementation. + type Error: Into>; + + /// Creates a new `TlsConnect`or. + /// + /// The domain name is provided for certificate verification and SNI. + fn make_tls_connect(&mut self, domain: &str) -> Result; +} + +/// An asynchronous function wrapping a stream in a TLS session. +pub trait TlsConnect { + /// The stream returned by the future. + type Stream: TlsStream + Unpin; + /// The error returned by the future. + type Error: Into>; + /// The future returned by the connector. + type Future: Future>; + + /// Returns a future performing a TLS handshake over the stream. + fn connect(self, stream: S) -> Self::Future; + + #[doc(hidden)] + fn can_connect(&self, _: private::ForcePrivateApi) -> bool { + true } +} + +/// A TLS-wrapped connection to a PostgreSQL database. +pub trait TlsStream: AsyncRead + AsyncWrite { + /// Returns channel binding information for the session. + fn channel_binding(&self) -> ChannelBinding; +} + +/// A `MakeTlsConnect` and `TlsConnect` implementation which simply returns an error. +/// +/// This can be used when `sslmode` is `none` or `prefer`. +#[derive(Debug, Copy, Clone)] +pub struct NoTls; + +#[cfg(feature = "runtime")] +impl MakeTlsConnect for NoTls { + type Stream = NoTlsStream; + type TlsConnect = NoTls; + type Error = NoTlsError; - fn read_buf(&mut self, buf: &mut B) -> Poll - where - B: BufMut, - { - self.0.read_buf(buf) + fn make_tls_connect(&mut self, _: &str) -> Result { + Ok(NoTls) } } -impl Write for Socket { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.0.write(buf) +impl TlsConnect for NoTls { + type Stream = NoTlsStream; + type Error = NoTlsError; + type Future = NoTlsFuture; + + fn connect(self, _: S) -> NoTlsFuture { + NoTlsFuture(()) } - fn flush(&mut self) -> io::Result<()> { - self.0.flush() + fn can_connect(&self, _: private::ForcePrivateApi) -> bool { + false } } -impl AsyncWrite for Socket { - fn shutdown(&mut self) -> Poll<(), io::Error> { - self.0.shutdown() +/// The future returned by `NoTls`. +pub struct NoTlsFuture(()); + +impl Future for NoTlsFuture { + type Output = Result; + + fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { + Poll::Ready(Err(NoTlsError(()))) } +} - fn write_buf(&mut self, buf: &mut B) -> Poll - where - B: Buf, - { - self.0.write_buf(buf) +/// The TLS "stream" type produced by the `NoTls` connector. +/// +/// Since `NoTls` doesn't support TLS, this type is uninhabited. +pub enum NoTlsStream {} + +impl AsyncRead for NoTlsStream { + fn poll_read( + self: Pin<&mut Self>, + _: &mut Context<'_>, + _: &mut ReadBuf<'_>, + ) -> Poll> { + match *self {} } } -pub trait TlsConnect { - fn connect( - &self, - domain: &str, - socket: Socket, - ) -> Box, Error = Box> + Sync + Send>; +impl AsyncWrite for NoTlsStream { + fn poll_write(self: Pin<&mut Self>, _: &mut Context<'_>, _: &[u8]) -> Poll> { + match *self {} + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + match *self {} + } + + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + match *self {} + } } -pub trait TlsStream: 'static + Sync + Send + AsyncRead + AsyncWrite { - /// Returns the data associated with the `tls-unique` channel binding type as described in - /// [RFC 5929], if supported. - /// - /// An implementation only needs to support one of this or `tls_server_end_point`. - /// - /// [RFC 5929]: https://tools.ietf.org/html/rfc5929 - fn tls_unique(&self) -> Option> { - None +impl TlsStream for NoTlsStream { + fn channel_binding(&self) -> ChannelBinding { + match *self {} } +} - /// Returns the data associated with the `tls-server-end-point` channel binding type as - /// described in [RFC 5929], if supported. - /// - /// An implementation only needs to support one of this or `tls_unique`. - /// - /// [RFC 5929]: https://tools.ietf.org/html/rfc5929 - fn tls_server_end_point(&self) -> Option> { - None +/// The error returned by `NoTls`. +#[derive(Debug)] +pub struct NoTlsError(()); + +impl fmt::Display for NoTlsError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.write_str("no TLS implementation configured") } } -impl TlsStream for proto::Socket {} +impl Error for NoTlsError {} diff --git a/tokio-postgres/src/to_statement.rs b/tokio-postgres/src/to_statement.rs new file mode 100644 index 000000000..7e1299272 --- /dev/null +++ b/tokio-postgres/src/to_statement.rs @@ -0,0 +1,57 @@ +use crate::to_statement::private::{Sealed, ToStatementType}; +use crate::Statement; + +mod private { + use crate::{Client, Error, Statement}; + + pub trait Sealed {} + + pub enum ToStatementType<'a> { + Statement(&'a Statement), + Query(&'a str), + } + + impl ToStatementType<'_> { + pub async fn into_statement(self, client: &Client) -> Result { + match self { + ToStatementType::Statement(s) => Ok(s.clone()), + ToStatementType::Query(s) => client.prepare(s).await, + } + } + } +} + +/// A trait abstracting over prepared and unprepared statements. +/// +/// Many methods are generic over this bound, so that they support both a raw query string as well as a statement which +/// was prepared previously. +/// +/// This trait is "sealed" and cannot be implemented by anything outside this crate. +pub trait ToStatement: Sealed { + #[doc(hidden)] + fn __convert(&self) -> ToStatementType<'_>; +} + +impl ToStatement for Statement { + fn __convert(&self) -> ToStatementType<'_> { + ToStatementType::Statement(self) + } +} + +impl Sealed for Statement {} + +impl ToStatement for str { + fn __convert(&self) -> ToStatementType<'_> { + ToStatementType::Query(self) + } +} + +impl Sealed for str {} + +impl ToStatement for String { + fn __convert(&self) -> ToStatementType<'_> { + ToStatementType::Query(self) + } +} + +impl Sealed for String {} diff --git a/tokio-postgres/src/transaction.rs b/tokio-postgres/src/transaction.rs new file mode 100644 index 000000000..782c476c4 --- /dev/null +++ b/tokio-postgres/src/transaction.rs @@ -0,0 +1,332 @@ +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::copy_out::CopyOutStream; +use crate::query::RowStream; +#[cfg(feature = "runtime")] +use crate::tls::MakeTlsConnect; +use crate::tls::TlsConnect; +use crate::types::{BorrowToSql, ToSql, Type}; +#[cfg(feature = "runtime")] +use crate::Socket; +use crate::{ + bind, query, slice_iter, CancelToken, Client, CopyInSink, Error, Portal, Row, + SimpleQueryMessage, Statement, ToStatement, +}; +use bytes::Buf; +use futures_util::TryStreamExt; +use postgres_protocol::message::frontend; +use tokio::io::{AsyncRead, AsyncWrite}; + +/// A representation of a PostgreSQL database transaction. +/// +/// Transactions will implicitly roll back when dropped. Use the `commit` method to commit the changes made in the +/// transaction. Transactions can be nested, with inner transactions implemented via safepoints. +pub struct Transaction<'a> { + client: &'a mut Client, + savepoint: Option, + done: bool, +} + +/// A representation of a PostgreSQL database savepoint. +struct Savepoint { + name: String, + depth: u32, +} + +impl Drop for Transaction<'_> { + fn drop(&mut self) { + if self.done { + return; + } + + let query = if let Some(sp) = self.savepoint.as_ref() { + format!("ROLLBACK TO {}", sp.name) + } else { + "ROLLBACK".to_string() + }; + let buf = self.client.inner().with_buf(|buf| { + frontend::query(&query, buf).unwrap(); + buf.split().freeze() + }); + let _ = self + .client + .inner() + .send(RequestMessages::Single(FrontendMessage::Raw(buf))); + } +} + +impl<'a> Transaction<'a> { + pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> { + Transaction { + client, + savepoint: None, + done: false, + } + } + + /// Consumes the transaction, committing all changes made within it. + pub async fn commit(mut self) -> Result<(), Error> { + self.done = true; + let query = if let Some(sp) = self.savepoint.as_ref() { + format!("RELEASE {}", sp.name) + } else { + "COMMIT".to_string() + }; + self.client.batch_execute(&query).await + } + + /// Rolls the transaction back, discarding all changes made within it. + /// + /// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller. + pub async fn rollback(mut self) -> Result<(), Error> { + self.done = true; + let query = if let Some(sp) = self.savepoint.as_ref() { + format!("ROLLBACK TO {}", sp.name) + } else { + "ROLLBACK".to_string() + }; + self.client.batch_execute(&query).await + } + + /// Like `Client::prepare`. + pub async fn prepare(&self, query: &str) -> Result { + self.client.prepare(query).await + } + + /// Like `Client::prepare_typed`. + pub async fn prepare_typed( + &self, + query: &str, + parameter_types: &[Type], + ) -> Result { + self.client.prepare_typed(query, parameter_types).await + } + + /// Like `Client::query`. + pub async fn query( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result, Error> + where + T: ?Sized + ToStatement, + { + self.client.query(statement, params).await + } + + /// Like `Client::query_one`. + pub async fn query_one( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result + where + T: ?Sized + ToStatement, + { + self.client.query_one(statement, params).await + } + + /// Like `Client::query_opt`. + pub async fn query_opt( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result, Error> + where + T: ?Sized + ToStatement, + { + self.client.query_opt(statement, params).await + } + + /// Like `Client::query_raw`. + pub async fn query_raw(&self, statement: &T, params: I) -> Result + where + T: ?Sized + ToStatement, + P: BorrowToSql, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + self.client.query_raw(statement, params).await + } + + /// Like `Client::query_typed`. + pub async fn query_typed( + &self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> Result, Error> { + self.client.query_typed(statement, params).await + } + + /// Like `Client::query_typed_raw`. + pub async fn query_typed_raw(&self, query: &str, params: I) -> Result + where + P: BorrowToSql, + I: IntoIterator, + { + self.client.query_typed_raw(query, params).await + } + + /// Like `Client::execute`. + pub async fn execute( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result + where + T: ?Sized + ToStatement, + { + self.client.execute(statement, params).await + } + + /// Like `Client::execute_iter`. + pub async fn execute_raw(&self, statement: &T, params: I) -> Result + where + T: ?Sized + ToStatement, + P: BorrowToSql, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + self.client.execute_raw(statement, params).await + } + + /// Binds a statement to a set of parameters, creating a `Portal` which can be incrementally queried. + /// + /// Portals only last for the duration of the transaction in which they are created, and can only be used on the + /// connection that created them. + /// + /// # Panics + /// + /// Panics if the number of parameters provided does not match the number expected. + pub async fn bind( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result + where + T: ?Sized + ToStatement, + { + self.bind_raw(statement, slice_iter(params)).await + } + + /// A maximally flexible version of [`bind`]. + /// + /// [`bind`]: #method.bind + pub async fn bind_raw(&self, statement: &T, params: I) -> Result + where + T: ?Sized + ToStatement, + P: BorrowToSql, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + let statement = statement.__convert().into_statement(self.client).await?; + bind::bind(self.client.inner(), statement, params).await + } + + /// Continues execution of a portal, returning a stream of the resulting rows. + /// + /// Unlike `query`, portals can be incrementally evaluated by limiting the number of rows returned in each call to + /// `query_portal`. If the requested number is negative or 0, all rows will be returned. + pub async fn query_portal(&self, portal: &Portal, max_rows: i32) -> Result, Error> { + self.query_portal_raw(portal, max_rows) + .await? + .try_collect() + .await + } + + /// The maximally flexible version of [`query_portal`]. + /// + /// [`query_portal`]: #method.query_portal + pub async fn query_portal_raw( + &self, + portal: &Portal, + max_rows: i32, + ) -> Result { + query::query_portal(self.client.inner(), portal, max_rows).await + } + + /// Like `Client::copy_in`. + pub async fn copy_in(&self, statement: &T) -> Result, Error> + where + T: ?Sized + ToStatement, + U: Buf + 'static + Send, + { + self.client.copy_in(statement).await + } + + /// Like `Client::copy_out`. + pub async fn copy_out(&self, statement: &T) -> Result + where + T: ?Sized + ToStatement, + { + self.client.copy_out(statement).await + } + + /// Like `Client::simple_query`. + pub async fn simple_query(&self, query: &str) -> Result, Error> { + self.client.simple_query(query).await + } + + /// Like `Client::batch_execute`. + pub async fn batch_execute(&self, query: &str) -> Result<(), Error> { + self.client.batch_execute(query).await + } + + /// Like `Client::cancel_token`. + pub fn cancel_token(&self) -> CancelToken { + self.client.cancel_token() + } + + /// Like `Client::cancel_query`. + #[cfg(feature = "runtime")] + #[deprecated(since = "0.6.0", note = "use Transaction::cancel_token() instead")] + pub async fn cancel_query(&self, tls: T) -> Result<(), Error> + where + T: MakeTlsConnect, + { + #[allow(deprecated)] + self.client.cancel_query(tls).await + } + + /// Like `Client::cancel_query_raw`. + #[deprecated(since = "0.6.0", note = "use Transaction::cancel_token() instead")] + pub async fn cancel_query_raw(&self, stream: S, tls: T) -> Result<(), Error> + where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsConnect, + { + #[allow(deprecated)] + self.client.cancel_query_raw(stream, tls).await + } + + /// Like `Client::transaction`, but creates a nested transaction via a savepoint. + pub async fn transaction(&mut self) -> Result, Error> { + self._savepoint(None).await + } + + /// Like `Client::transaction`, but creates a nested transaction via a savepoint with the specified name. + pub async fn savepoint(&mut self, name: I) -> Result, Error> + where + I: Into, + { + self._savepoint(Some(name.into())).await + } + + async fn _savepoint(&mut self, name: Option) -> Result, Error> { + let depth = self.savepoint.as_ref().map_or(0, |sp| sp.depth) + 1; + let name = name.unwrap_or_else(|| format!("sp_{}", depth)); + let query = format!("SAVEPOINT {}", name); + self.batch_execute(&query).await?; + + Ok(Transaction { + client: self.client, + savepoint: Some(Savepoint { name, depth }), + done: false, + }) + } + + /// Returns a reference to the underlying `Client`. + pub fn client(&self) -> &Client { + self.client + } +} diff --git a/tokio-postgres/src/transaction_builder.rs b/tokio-postgres/src/transaction_builder.rs new file mode 100644 index 000000000..88c883176 --- /dev/null +++ b/tokio-postgres/src/transaction_builder.rs @@ -0,0 +1,149 @@ +use postgres_protocol::message::frontend; + +use crate::{codec::FrontendMessage, connection::RequestMessages, Client, Error, Transaction}; + +/// The isolation level of a database transaction. +#[derive(Debug, Copy, Clone)] +#[non_exhaustive] +pub enum IsolationLevel { + /// Equivalent to `ReadCommitted`. + ReadUncommitted, + + /// An individual statement in the transaction will see rows committed before it began. + ReadCommitted, + + /// All statements in the transaction will see the same view of rows committed before the first query in the + /// transaction. + RepeatableRead, + + /// The reads and writes in this transaction must be able to be committed as an atomic "unit" with respect to reads + /// and writes of all other concurrent serializable transactions without interleaving. + Serializable, +} + +/// A builder for database transactions. +pub struct TransactionBuilder<'a> { + client: &'a mut Client, + isolation_level: Option, + read_only: Option, + deferrable: Option, +} + +impl<'a> TransactionBuilder<'a> { + pub(crate) fn new(client: &'a mut Client) -> TransactionBuilder<'a> { + TransactionBuilder { + client, + isolation_level: None, + read_only: None, + deferrable: None, + } + } + + /// Sets the isolation level of the transaction. + pub fn isolation_level(mut self, isolation_level: IsolationLevel) -> Self { + self.isolation_level = Some(isolation_level); + self + } + + /// Sets the access mode of the transaction. + pub fn read_only(mut self, read_only: bool) -> Self { + self.read_only = Some(read_only); + self + } + + /// Sets the deferrability of the transaction. + /// + /// If the transaction is also serializable and read only, creation of the transaction may block, but when it + /// completes the transaction is able to run with less overhead and a guarantee that it will not be aborted due to + /// serialization failure. + pub fn deferrable(mut self, deferrable: bool) -> Self { + self.deferrable = Some(deferrable); + self + } + + /// Begins the transaction. + /// + /// The transaction will roll back by default - use the `commit` method to commit it. + pub async fn start(self) -> Result, Error> { + let mut query = "START TRANSACTION".to_string(); + let mut first = true; + + if let Some(level) = self.isolation_level { + first = false; + + query.push_str(" ISOLATION LEVEL "); + let level = match level { + IsolationLevel::ReadUncommitted => "READ UNCOMMITTED", + IsolationLevel::ReadCommitted => "READ COMMITTED", + IsolationLevel::RepeatableRead => "REPEATABLE READ", + IsolationLevel::Serializable => "SERIALIZABLE", + }; + query.push_str(level); + } + + if let Some(read_only) = self.read_only { + if !first { + query.push(','); + } + first = false; + + let s = if read_only { + " READ ONLY" + } else { + " READ WRITE" + }; + query.push_str(s); + } + + if let Some(deferrable) = self.deferrable { + if !first { + query.push(','); + } + + let s = if deferrable { + " DEFERRABLE" + } else { + " NOT DEFERRABLE" + }; + query.push_str(s); + } + + struct RollbackIfNotDone<'me> { + client: &'me Client, + done: bool, + } + + impl Drop for RollbackIfNotDone<'_> { + fn drop(&mut self) { + if self.done { + return; + } + + let buf = self.client.inner().with_buf(|buf| { + frontend::query("ROLLBACK", buf).unwrap(); + buf.split().freeze() + }); + let _ = self + .client + .inner() + .send(RequestMessages::Single(FrontendMessage::Raw(buf))); + } + } + + // This is done as `Future` created by this method can be dropped after + // `RequestMessages` is synchronously send to the `Connection` by + // `batch_execute()`, but before `Responses` is asynchronously polled to + // completion. In that case `Transaction` won't be created and thus + // won't be rolled back. + { + let mut cleaner = RollbackIfNotDone { + client: self.client, + done: false, + }; + self.client.batch_execute(&query).await?; + cleaner.done = true; + } + + Ok(Transaction::new(self.client)) + } +} diff --git a/tokio-postgres/src/types.rs b/tokio-postgres/src/types.rs new file mode 100644 index 000000000..b2e15d059 --- /dev/null +++ b/tokio-postgres/src/types.rs @@ -0,0 +1,6 @@ +//! Types. +//! +//! This module is a reexport of the `postgres_types` crate. + +#[doc(inline)] +pub use postgres_types::*; diff --git a/tokio-postgres/tests/test.rs b/tokio-postgres/tests/test.rs deleted file mode 100644 index 015a6b8e9..000000000 --- a/tokio-postgres/tests/test.rs +++ /dev/null @@ -1,694 +0,0 @@ -extern crate env_logger; -extern crate tokio; -extern crate tokio_postgres; - -#[macro_use] -extern crate futures; -#[macro_use] -extern crate log; - -use futures::future; -use futures::stream; -use futures::sync::mpsc; -use std::error::Error; -use std::time::{Duration, Instant}; -use tokio::prelude::*; -use tokio::runtime::current_thread::Runtime; -use tokio::timer::Delay; -use tokio_postgres::error::SqlState; -use tokio_postgres::types::{Kind, Type}; -use tokio_postgres::{AsyncMessage, TlsMode}; - -fn smoke_test(url: &str) { - let _ = env_logger::try_init(); - let mut runtime = Runtime::new().unwrap(); - - let handshake = tokio_postgres::connect(url.parse().unwrap(), TlsMode::None); - let (mut client, connection) = runtime.block_on(handshake).unwrap(); - let connection = connection.map_err(|e| panic!("{}", e)); - runtime.handle().spawn(connection).unwrap(); - - let prepare = client.prepare("SELECT 1::INT4"); - let statement = runtime.block_on(prepare).unwrap(); - let select = client.query(&statement, &[]).collect().map(|rows| { - assert_eq!(rows.len(), 1); - assert_eq!(rows[0].get::<_, i32>(0), 1); - }); - runtime.block_on(select).unwrap(); - - drop(statement); - drop(client); - runtime.run().unwrap(); -} - -#[test] -fn plain_password_missing() { - let _ = env_logger::try_init(); - let mut runtime = Runtime::new().unwrap(); - - let handshake = tokio_postgres::connect( - "postgres://pass_user@localhost:5433".parse().unwrap(), - TlsMode::None, - ); - runtime.block_on(handshake).err().unwrap(); -} - -#[test] -fn plain_password_wrong() { - let _ = env_logger::try_init(); - let mut runtime = Runtime::new().unwrap(); - - let handshake = tokio_postgres::connect( - "postgres://pass_user:foo@localhost:5433".parse().unwrap(), - TlsMode::None, - ); - match runtime.block_on(handshake) { - Ok(_) => panic!("unexpected success"), - Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {} - Err(e) => panic!("{}", e), - } -} - -#[test] -fn plain_password_ok() { - smoke_test("postgres://pass_user:password@localhost:5433/postgres"); -} - -#[test] -fn md5_password_missing() { - let _ = env_logger::try_init(); - let mut runtime = Runtime::new().unwrap(); - - let handshake = tokio_postgres::connect( - "postgres://md5_user@localhost:5433".parse().unwrap(), - TlsMode::None, - ); - runtime.block_on(handshake).err().unwrap(); -} - -#[test] -fn md5_password_wrong() { - let _ = env_logger::try_init(); - let mut runtime = Runtime::new().unwrap(); - - let handshake = tokio_postgres::connect( - "postgres://md5_user:foo@localhost:5433".parse().unwrap(), - TlsMode::None, - ); - match runtime.block_on(handshake) { - Ok(_) => panic!("unexpected success"), - Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {} - Err(e) => panic!("{}", e), - } -} - -#[test] -fn md5_password_ok() { - smoke_test("postgres://md5_user:password@localhost:5433/postgres"); -} - -#[test] -fn scram_password_missing() { - let _ = env_logger::try_init(); - let mut runtime = Runtime::new().unwrap(); - - let handshake = tokio_postgres::connect( - "postgres://scram_user@localhost:5433".parse().unwrap(), - TlsMode::None, - ); - runtime.block_on(handshake).err().unwrap(); -} - -#[test] -fn scram_password_wrong() { - let _ = env_logger::try_init(); - let mut runtime = Runtime::new().unwrap(); - - let handshake = tokio_postgres::connect( - "postgres://scram_user:foo@localhost:5433".parse().unwrap(), - TlsMode::None, - ); - match runtime.block_on(handshake) { - Ok(_) => panic!("unexpected success"), - Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {} - Err(e) => panic!("{}", e), - } -} - -#[test] -fn scram_password_ok() { - smoke_test("postgres://scram_user:password@localhost:5433/postgres"); -} - -#[test] -fn pipelined_prepare() { - let _ = env_logger::try_init(); - let mut runtime = Runtime::new().unwrap(); - - let handshake = tokio_postgres::connect( - "postgres://postgres@localhost:5433".parse().unwrap(), - TlsMode::None, - ); - let (mut client, connection) = runtime.block_on(handshake).unwrap(); - let connection = connection.map_err(|e| panic!("{}", e)); - runtime.handle().spawn(connection).unwrap(); - - let prepare1 = client.prepare("SELECT $1::HSTORE[]"); - let prepare2 = client.prepare("SELECT $1::HSTORE[]"); - let prepare = prepare1.join(prepare2); - runtime.block_on(prepare).unwrap(); - - drop(client); - runtime.run().unwrap(); -} - -#[test] -fn insert_select() { - let _ = env_logger::try_init(); - let mut runtime = Runtime::new().unwrap(); - - let handshake = tokio_postgres::connect( - "postgres://postgres@localhost:5433".parse().unwrap(), - TlsMode::None, - ); - let (mut client, connection) = runtime.block_on(handshake).unwrap(); - let connection = connection.map_err(|e| panic!("{}", e)); - runtime.handle().spawn(connection).unwrap(); - - runtime - .block_on(client.batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL, name TEXT)")) - .unwrap(); - - let insert = client.prepare("INSERT INTO foo (name) VALUES ($1), ($2)"); - let select = client.prepare("SELECT id, name FROM foo ORDER BY id"); - let prepare = insert.join(select); - let (insert, select) = runtime.block_on(prepare).unwrap(); - - let insert = client - .execute(&insert, &[&"alice", &"bob"]) - .map(|n| assert_eq!(n, 2)); - let select = client.query(&select, &[]).collect().map(|rows| { - assert_eq!(rows.len(), 2); - assert_eq!(rows[0].get::<_, i32>(0), 1); - assert_eq!(rows[0].get::<_, &str>(1), "alice"); - assert_eq!(rows[1].get::<_, i32>(0), 2); - assert_eq!(rows[1].get::<_, &str>(1), "bob"); - }); - let tests = insert.join(select); - runtime.block_on(tests).unwrap(); -} - -#[test] -fn query_portal() { - let _ = env_logger::try_init(); - let mut runtime = Runtime::new().unwrap(); - - let handshake = tokio_postgres::connect( - "postgres://postgres@localhost:5433".parse().unwrap(), - TlsMode::None, - ); - let (mut client, connection) = runtime.block_on(handshake).unwrap(); - let connection = connection.map_err(|e| panic!("{}", e)); - runtime.handle().spawn(connection).unwrap(); - - runtime - .block_on(client.batch_execute( - "CREATE TEMPORARY TABLE foo (id SERIAL, name TEXT); - INSERT INTO foo (name) VALUES ('alice'), ('bob'), ('charlie'); - BEGIN;", - )).unwrap(); - - let statement = runtime - .block_on(client.prepare("SELECT id, name FROM foo ORDER BY id")) - .unwrap(); - let portal = runtime.block_on(client.bind(&statement, &[])).unwrap(); - - let f1 = client.query_portal(&portal, 2).collect(); - let f2 = client.query_portal(&portal, 2).collect(); - let f3 = client.query_portal(&portal, 2).collect(); - let (r1, r2, r3) = runtime.block_on(f1.join3(f2, f3)).unwrap(); - - assert_eq!(r1.len(), 2); - assert_eq!(r1[0].get::<_, i32>(0), 1); - assert_eq!(r1[0].get::<_, &str>(1), "alice"); - assert_eq!(r1[1].get::<_, i32>(0), 2); - assert_eq!(r1[1].get::<_, &str>(1), "bob"); - - assert_eq!(r2.len(), 1); - assert_eq!(r2[0].get::<_, i32>(0), 3); - assert_eq!(r2[0].get::<_, &str>(1), "charlie"); - - assert_eq!(r3.len(), 0); -} - -#[test] -fn cancel_query() { - let _ = env_logger::try_init(); - let mut runtime = Runtime::new().unwrap(); - - let handshake = tokio_postgres::connect( - "postgres://postgres@localhost:5433".parse().unwrap(), - TlsMode::None, - ); - let (mut client, connection) = runtime.block_on(handshake).unwrap(); - let cancel_data = connection.cancel_data(); - let connection = connection.map_err(|e| panic!("{}", e)); - runtime.handle().spawn(connection).unwrap(); - - let sleep = client - .batch_execute("SELECT pg_sleep(100)") - .then(|r| match r { - Ok(_) => panic!("unexpected success"), - Err(ref e) if e.code() == Some(&SqlState::QUERY_CANCELED) => Ok::<(), ()>(()), - Err(e) => panic!("unexpected error {}", e), - }); - let cancel = Delay::new(Instant::now() + Duration::from_millis(100)) - .then(|r| { - r.unwrap(); - tokio_postgres::cancel_query( - "postgres://postgres@localhost:5433".parse().unwrap(), - TlsMode::None, - cancel_data, - ) - }).then(|r| { - r.unwrap(); - Ok::<(), ()>(()) - }); - - let ((), ()) = runtime.block_on(sleep.join(cancel)).unwrap(); -} - -#[test] -fn custom_enum() { - let _ = env_logger::try_init(); - let mut runtime = Runtime::new().unwrap(); - - let handshake = tokio_postgres::connect( - "postgres://postgres@localhost:5433".parse().unwrap(), - TlsMode::None, - ); - let (mut client, connection) = runtime.block_on(handshake).unwrap(); - let connection = connection.map_err(|e| panic!("{}", e)); - runtime.handle().spawn(connection).unwrap(); - - runtime - .block_on(client.batch_execute( - "CREATE TYPE pg_temp.mood AS ENUM ( - 'sad', - 'ok', - 'happy' - )", - )).unwrap(); - - let select = client.prepare("SELECT $1::mood"); - let select = runtime.block_on(select).unwrap(); - - let ty = &select.params()[0]; - assert_eq!("mood", ty.name()); - assert_eq!( - &Kind::Enum(vec![ - "sad".to_string(), - "ok".to_string(), - "happy".to_string(), - ]), - ty.kind() - ); -} - -#[test] -fn custom_domain() { - let _ = env_logger::try_init(); - let mut runtime = Runtime::new().unwrap(); - - let handshake = tokio_postgres::connect( - "postgres://postgres@localhost:5433".parse().unwrap(), - TlsMode::None, - ); - let (mut client, connection) = runtime.block_on(handshake).unwrap(); - let connection = connection.map_err(|e| panic!("{}", e)); - runtime.handle().spawn(connection).unwrap(); - - runtime - .block_on(client.batch_execute( - "CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16)", - )).unwrap(); - - let select = client.prepare("SELECT $1::session_id"); - let select = runtime.block_on(select).unwrap(); - - let ty = &select.params()[0]; - assert_eq!("session_id", ty.name()); - assert_eq!(&Kind::Domain(Type::BYTEA), ty.kind()); -} - -#[test] -fn custom_array() { - let _ = env_logger::try_init(); - let mut runtime = Runtime::new().unwrap(); - - let handshake = tokio_postgres::connect( - "postgres://postgres@localhost:5433".parse().unwrap(), - TlsMode::None, - ); - let (mut client, connection) = runtime.block_on(handshake).unwrap(); - let connection = connection.map_err(|e| panic!("{}", e)); - runtime.handle().spawn(connection).unwrap(); - - let select = client.prepare("SELECT $1::HSTORE[]"); - let select = runtime.block_on(select).unwrap(); - - let ty = &select.params()[0]; - assert_eq!("_hstore", ty.name()); - match *ty.kind() { - Kind::Array(ref ty) => { - assert_eq!("hstore", ty.name()); - assert_eq!(&Kind::Simple, ty.kind()); - } - _ => panic!("unexpected kind"), - } -} - -#[test] -fn custom_composite() { - let _ = env_logger::try_init(); - let mut runtime = Runtime::new().unwrap(); - - let handshake = tokio_postgres::connect( - "postgres://postgres@localhost:5433".parse().unwrap(), - TlsMode::None, - ); - let (mut client, connection) = runtime.block_on(handshake).unwrap(); - let connection = connection.map_err(|e| panic!("{}", e)); - runtime.handle().spawn(connection).unwrap(); - - runtime - .block_on(client.batch_execute( - "CREATE TYPE pg_temp.inventory_item AS ( - name TEXT, - supplier INTEGER, - price NUMERIC - )", - )).unwrap(); - - let select = client.prepare("SELECT $1::inventory_item"); - let select = runtime.block_on(select).unwrap(); - - let ty = &select.params()[0]; - assert_eq!(ty.name(), "inventory_item"); - match *ty.kind() { - Kind::Composite(ref fields) => { - assert_eq!(fields[0].name(), "name"); - assert_eq!(fields[0].type_(), &Type::TEXT); - assert_eq!(fields[1].name(), "supplier"); - assert_eq!(fields[1].type_(), &Type::INT4); - assert_eq!(fields[2].name(), "price"); - assert_eq!(fields[2].type_(), &Type::NUMERIC); - } - ref t => panic!("bad type {:?}", t), - } -} - -#[test] -fn custom_range() { - let _ = env_logger::try_init(); - let mut runtime = Runtime::new().unwrap(); - - let handshake = tokio_postgres::connect( - "postgres://postgres@localhost:5433".parse().unwrap(), - TlsMode::None, - ); - let (mut client, connection) = runtime.block_on(handshake).unwrap(); - let connection = connection.map_err(|e| panic!("{}", e)); - runtime.handle().spawn(connection).unwrap(); - - runtime - .block_on(client.batch_execute( - "CREATE TYPE pg_temp.floatrange AS RANGE ( - subtype = float8, - subtype_diff = float8mi - )", - )).unwrap(); - - let select = client.prepare("SELECT $1::floatrange"); - let select = runtime.block_on(select).unwrap(); - - let ty = &select.params()[0]; - assert_eq!("floatrange", ty.name()); - assert_eq!(&Kind::Range(Type::FLOAT8), ty.kind()); -} - -#[test] -fn custom_simple() { - let _ = env_logger::try_init(); - let mut runtime = Runtime::new().unwrap(); - - let handshake = tokio_postgres::connect( - "postgres://postgres@localhost:5433".parse().unwrap(), - TlsMode::None, - ); - let (mut client, connection) = runtime.block_on(handshake).unwrap(); - let connection = connection.map_err(|e| panic!("{}", e)); - runtime.handle().spawn(connection).unwrap(); - - let select = client.prepare("SELECT $1::HSTORE"); - let select = runtime.block_on(select).unwrap(); - - let ty = &select.params()[0]; - assert_eq!("hstore", ty.name()); - assert_eq!(&Kind::Simple, ty.kind()); -} - -#[test] -fn notifications() { - let _ = env_logger::try_init(); - let mut runtime = Runtime::new().unwrap(); - - let handshake = tokio_postgres::connect( - "postgres://postgres@localhost:5433".parse().unwrap(), - TlsMode::None, - ); - let (mut client, mut connection) = runtime.block_on(handshake).unwrap(); - - let (tx, rx) = mpsc::unbounded(); - let connection = future::poll_fn(move || { - while let Some(message) = try_ready!(connection.poll_message().map_err(|e| panic!("{}", e))) - { - if let AsyncMessage::Notification(notification) = message { - debug!("received {}", notification.payload); - tx.unbounded_send(notification).unwrap(); - } - } - - Ok(Async::Ready(())) - }); - runtime.handle().spawn(connection).unwrap(); - - runtime - .block_on(client.batch_execute("LISTEN test_notifications")) - .unwrap(); - - runtime - .block_on(client.batch_execute("NOTIFY test_notifications, 'hello'")) - .unwrap(); - - runtime - .block_on(client.batch_execute("NOTIFY test_notifications, 'world'")) - .unwrap(); - - drop(client); - runtime.run().unwrap(); - - let notifications = rx.collect().wait().unwrap(); - assert_eq!(notifications.len(), 2); - assert_eq!(notifications[0].channel, "test_notifications"); - assert_eq!(notifications[0].payload, "hello"); - assert_eq!(notifications[1].channel, "test_notifications"); - assert_eq!(notifications[1].payload, "world"); -} - -#[test] -fn transaction_commit() { - let _ = env_logger::try_init(); - let mut runtime = Runtime::new().unwrap(); - - let (mut client, connection) = runtime - .block_on(tokio_postgres::connect( - "postgres://postgres@localhost:5433".parse().unwrap(), - TlsMode::None, - )).unwrap(); - let connection = connection.map_err(|e| panic!("{}", e)); - runtime.handle().spawn(connection).unwrap(); - - runtime - .block_on(client.batch_execute( - "CREATE TEMPORARY TABLE foo ( - id SERIAL, - name TEXT - )", - )).unwrap(); - - let f = client.batch_execute("INSERT INTO foo (name) VALUES ('steven')"); - runtime.block_on(client.transaction(f)).unwrap(); - - let rows = runtime - .block_on( - client - .prepare("SELECT name FROM foo") - .and_then(|s| client.query(&s, &[]).collect()), - ).unwrap(); - - assert_eq!(rows.len(), 1); - assert_eq!(rows[0].get::<_, &str>(0), "steven"); -} - -#[test] -fn transaction_abort() { - let _ = env_logger::try_init(); - let mut runtime = Runtime::new().unwrap(); - - let (mut client, connection) = runtime - .block_on(tokio_postgres::connect( - "postgres://postgres@localhost:5433".parse().unwrap(), - TlsMode::None, - )).unwrap(); - let connection = connection.map_err(|e| panic!("{}", e)); - runtime.handle().spawn(connection).unwrap(); - - runtime - .block_on(client.batch_execute( - "CREATE TEMPORARY TABLE foo ( - id SERIAL, - name TEXT - )", - )).unwrap(); - - let f = client - .batch_execute("INSERT INTO foo (name) VALUES ('steven')") - .map_err(|e| Box::new(e) as Box) - .and_then(|_| Err::<(), _>(Box::::from(""))); - runtime.block_on(client.transaction(f)).unwrap_err(); - - let rows = runtime - .block_on( - client - .prepare("SELECT name FROM foo") - .and_then(|s| client.query(&s, &[]).collect()), - ).unwrap(); - - assert_eq!(rows.len(), 0); -} - -#[test] -fn copy_in() { - let _ = env_logger::try_init(); - let mut runtime = Runtime::new().unwrap(); - - let (mut client, connection) = runtime - .block_on(tokio_postgres::connect( - "postgres://postgres@localhost:5433".parse().unwrap(), - TlsMode::None, - )).unwrap(); - let connection = connection.map_err(|e| panic!("{}", e)); - runtime.handle().spawn(connection).unwrap(); - - runtime - .block_on(client.batch_execute( - "CREATE TEMPORARY TABLE foo ( - id INTEGER, - name TEXT - )", - )).unwrap(); - - let stream = stream::iter_ok::<_, String>(vec![b"1\tjim\n".to_vec(), b"2\tjoe\n".to_vec()]); - let rows = runtime - .block_on( - client - .prepare("COPY foo FROM STDIN") - .and_then(|s| client.copy_in(&s, &[], stream)), - ).unwrap(); - assert_eq!(rows, 2); - - let rows = runtime - .block_on( - client - .prepare("SELECT id, name FROM foo ORDER BY id") - .and_then(|s| client.query(&s, &[]).collect()), - ).unwrap(); - - assert_eq!(rows.len(), 2); - assert_eq!(rows[0].get::<_, i32>(0), 1); - assert_eq!(rows[0].get::<_, &str>(1), "jim"); - assert_eq!(rows[1].get::<_, i32>(0), 2); - assert_eq!(rows[1].get::<_, &str>(1), "joe"); -} - -#[test] -fn copy_in_error() { - let _ = env_logger::try_init(); - let mut runtime = Runtime::new().unwrap(); - - let (mut client, connection) = runtime - .block_on(tokio_postgres::connect( - "postgres://postgres@localhost:5433".parse().unwrap(), - TlsMode::None, - )).unwrap(); - let connection = connection.map_err(|e| panic!("{}", e)); - runtime.handle().spawn(connection).unwrap(); - - runtime - .block_on(client.batch_execute( - "CREATE TEMPORARY TABLE foo ( - id INTEGER, - name TEXT - )", - )).unwrap(); - - let stream = stream::iter_result(vec![Ok(b"1\tjim\n".to_vec()), Err("asdf")]); - let error = runtime - .block_on( - client - .prepare("COPY foo FROM STDIN") - .and_then(|s| client.copy_in(&s, &[], stream)), - ).unwrap_err(); - assert!(error.to_string().contains("asdf")); - - let rows = runtime - .block_on( - client - .prepare("SELECT id, name FROM foo ORDER BY id") - .and_then(|s| client.query(&s, &[]).collect()), - ).unwrap(); - - assert_eq!(rows.len(), 0); -} - -#[test] -fn copy_out() { - let _ = env_logger::try_init(); - let mut runtime = Runtime::new().unwrap(); - - let (mut client, connection) = runtime - .block_on(tokio_postgres::connect( - "postgres://postgres@localhost:5433".parse().unwrap(), - TlsMode::None, - )).unwrap(); - let connection = connection.map_err(|e| panic!("{}", e)); - runtime.handle().spawn(connection).unwrap(); - - runtime - .block_on(client.batch_execute( - "CREATE TEMPORARY TABLE foo ( - id SERIAL, - name TEXT - ); - INSERT INTO foo (name) VALUES ('jim'), ('joe');", - )).unwrap(); - - let data = runtime - .block_on( - client - .prepare("COPY foo TO STDOUT") - .and_then(|s| client.copy_out(&s, &[]).concat2()), - ).unwrap(); - assert_eq!(&data[..], b"1\tjim\n2\tjoe\n"); -} diff --git a/tokio-postgres/tests/test/binary_copy.rs b/tokio-postgres/tests/test/binary_copy.rs new file mode 100644 index 000000000..94b96ab85 --- /dev/null +++ b/tokio-postgres/tests/test/binary_copy.rs @@ -0,0 +1,203 @@ +use crate::connect; +use futures_util::{pin_mut, TryStreamExt}; +use tokio_postgres::binary_copy::{BinaryCopyInWriter, BinaryCopyOutStream}; +use tokio_postgres::types::Type; + +#[tokio::test] +async fn write_basic() { + let client = connect("user=postgres").await; + + client + .batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar TEXT)") + .await + .unwrap(); + + let sink = client + .copy_in("COPY foo (id, bar) FROM STDIN BINARY") + .await + .unwrap(); + let writer = BinaryCopyInWriter::new(sink, &[Type::INT4, Type::TEXT]); + pin_mut!(writer); + writer.as_mut().write(&[&1i32, &"foobar"]).await.unwrap(); + writer + .as_mut() + .write(&[&2i32, &None::<&str>]) + .await + .unwrap(); + writer.finish().await.unwrap(); + + let rows = client + .query("SELECT id, bar FROM foo ORDER BY id", &[]) + .await + .unwrap(); + assert_eq!(rows.len(), 2); + assert_eq!(rows[0].get::<_, i32>(0), 1); + assert_eq!(rows[0].get::<_, Option<&str>>(1), Some("foobar")); + assert_eq!(rows[1].get::<_, i32>(0), 2); + assert_eq!(rows[1].get::<_, Option<&str>>(1), None); +} + +#[tokio::test] +async fn write_many_rows() { + let client = connect("user=postgres").await; + + client + .batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar TEXT)") + .await + .unwrap(); + + let sink = client + .copy_in("COPY foo (id, bar) FROM STDIN BINARY") + .await + .unwrap(); + let writer = BinaryCopyInWriter::new(sink, &[Type::INT4, Type::TEXT]); + pin_mut!(writer); + + for i in 0..10_000i32 { + writer + .as_mut() + .write(&[&i, &format!("the value for {}", i)]) + .await + .unwrap(); + } + + writer.finish().await.unwrap(); + + let rows = client + .query("SELECT id, bar FROM foo ORDER BY id", &[]) + .await + .unwrap(); + for (i, row) in rows.iter().enumerate() { + assert_eq!(row.get::<_, i32>(0), i as i32); + assert_eq!(row.get::<_, &str>(1), format!("the value for {}", i)); + } +} + +#[tokio::test] +async fn write_big_rows() { + let client = connect("user=postgres").await; + + client + .batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar BYTEA)") + .await + .unwrap(); + + let sink = client + .copy_in("COPY foo (id, bar) FROM STDIN BINARY") + .await + .unwrap(); + let writer = BinaryCopyInWriter::new(sink, &[Type::INT4, Type::BYTEA]); + pin_mut!(writer); + + for i in 0..2i32 { + writer + .as_mut() + .write(&[&i, &vec![i as u8; 128 * 1024]]) + .await + .unwrap(); + } + + writer.finish().await.unwrap(); + + let rows = client + .query("SELECT id, bar FROM foo ORDER BY id", &[]) + .await + .unwrap(); + for (i, row) in rows.iter().enumerate() { + assert_eq!(row.get::<_, i32>(0), i as i32); + assert_eq!(row.get::<_, &[u8]>(1), &*vec![i as u8; 128 * 1024]); + } +} + +#[tokio::test] +async fn read_basic() { + let client = connect("user=postgres").await; + + client + .batch_execute( + " + CREATE TEMPORARY TABLE foo (id INT, bar TEXT); + INSERT INTO foo (id, bar) VALUES (1, 'foobar'), (2, NULL); + ", + ) + .await + .unwrap(); + + let stream = client + .copy_out("COPY foo (id, bar) TO STDIN BINARY") + .await + .unwrap(); + let rows = BinaryCopyOutStream::new(stream, &[Type::INT4, Type::TEXT]) + .try_collect::>() + .await + .unwrap(); + assert_eq!(rows.len(), 2); + + assert_eq!(rows[0].get::(0), 1); + assert_eq!(rows[0].get::>(1), Some("foobar")); + assert_eq!(rows[1].get::(0), 2); + assert_eq!(rows[1].get::>(1), None); +} + +#[tokio::test] +async fn read_many_rows() { + let client = connect("user=postgres").await; + + client + .batch_execute( + " + CREATE TEMPORARY TABLE foo (id INT, bar TEXT); + INSERT INTO foo (id, bar) SELECT i, 'the value for ' || i FROM generate_series(0, 9999) i;" + ) + .await + .unwrap(); + + let stream = client + .copy_out("COPY foo (id, bar) TO STDIN BINARY") + .await + .unwrap(); + let rows = BinaryCopyOutStream::new(stream, &[Type::INT4, Type::TEXT]) + .try_collect::>() + .await + .unwrap(); + assert_eq!(rows.len(), 10_000); + + for (i, row) in rows.iter().enumerate() { + assert_eq!(row.get::(0), i as i32); + assert_eq!(row.get::<&str>(1), format!("the value for {}", i)); + } +} + +#[tokio::test] +async fn read_big_rows() { + let client = connect("user=postgres").await; + + client + .batch_execute("CREATE TEMPORARY TABLE foo (id INT, bar BYTEA)") + .await + .unwrap(); + for i in 0..2i32 { + client + .execute( + "INSERT INTO foo (id, bar) VALUES ($1, $2)", + &[&i, &vec![i as u8; 128 * 1024]], + ) + .await + .unwrap(); + } + + let stream = client + .copy_out("COPY foo (id, bar) TO STDIN BINARY") + .await + .unwrap(); + let rows = BinaryCopyOutStream::new(stream, &[Type::INT4, Type::BYTEA]) + .try_collect::>() + .await + .unwrap(); + assert_eq!(rows.len(), 2); + + for (i, row) in rows.iter().enumerate() { + assert_eq!(row.get::(0), i as i32); + assert_eq!(row.get::<&[u8]>(1), &vec![i as u8; 128 * 1024][..]); + } +} diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs new file mode 100644 index 000000000..9a6aa26fe --- /dev/null +++ b/tokio-postgres/tests/test/main.rs @@ -0,0 +1,1081 @@ +#![warn(rust_2018_idioms)] + +use bytes::{Bytes, BytesMut}; +use futures_channel::mpsc; +use futures_util::{ + future, join, pin_mut, stream, try_join, Future, FutureExt, SinkExt, StreamExt, TryStreamExt, +}; +use pin_project_lite::pin_project; +use std::fmt::Write; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::time::Duration; +use tokio::net::TcpStream; +use tokio::time; +use tokio_postgres::error::SqlState; +use tokio_postgres::tls::{NoTls, NoTlsStream}; +use tokio_postgres::types::{Kind, Type}; +use tokio_postgres::{ + AsyncMessage, Client, Config, Connection, Error, IsolationLevel, SimpleQueryMessage, +}; + +mod binary_copy; +mod parse; +#[cfg(feature = "runtime")] +mod runtime; +mod types; + +pin_project! { + /// Polls `F` at most `polls_left` times returning `Some(F::Output)` if + /// [`Future`] returned [`Poll::Ready`] or [`None`] otherwise. + struct Cancellable { + #[pin] + fut: F, + polls_left: usize, + } +} + +impl Future for Cancellable { + type Output = Option; + + fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { + let this = self.project(); + match this.fut.poll(ctx) { + Poll::Ready(r) => Poll::Ready(Some(r)), + Poll::Pending => { + *this.polls_left = this.polls_left.saturating_sub(1); + if *this.polls_left == 0 { + Poll::Ready(None) + } else { + Poll::Pending + } + } + } + } +} + +async fn connect_raw(s: &str) -> Result<(Client, Connection), Error> { + let socket = TcpStream::connect("127.0.0.1:5433").await.unwrap(); + let config = s.parse::().unwrap(); + config.connect_raw(socket, NoTls).await +} + +async fn connect(s: &str) -> Client { + let (client, connection) = connect_raw(s).await.unwrap(); + let connection = connection.map(|r| r.unwrap()); + tokio::spawn(connection); + client +} + +async fn current_transaction_id(client: &Client) -> i64 { + client + .query("SELECT txid_current()", &[]) + .await + .unwrap() + .pop() + .unwrap() + .get::<_, i64>("txid_current") +} + +async fn in_transaction(client: &Client) -> bool { + current_transaction_id(client).await == current_transaction_id(client).await +} + +#[tokio::test] +async fn plain_password_missing() { + connect_raw("user=pass_user dbname=postgres") + .await + .err() + .unwrap(); +} + +#[tokio::test] +async fn plain_password_wrong() { + match connect_raw("user=pass_user password=foo dbname=postgres").await { + Ok(_) => panic!("unexpected success"), + Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {} + Err(e) => panic!("{}", e), + } +} + +#[tokio::test] +async fn plain_password_ok() { + connect("user=pass_user password=password dbname=postgres").await; +} + +#[tokio::test] +async fn md5_password_missing() { + connect_raw("user=md5_user dbname=postgres") + .await + .err() + .unwrap(); +} + +#[tokio::test] +async fn md5_password_wrong() { + match connect_raw("user=md5_user password=foo dbname=postgres").await { + Ok(_) => panic!("unexpected success"), + Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {} + Err(e) => panic!("{}", e), + } +} + +#[tokio::test] +async fn md5_password_ok() { + connect("user=md5_user password=password dbname=postgres").await; +} + +#[tokio::test] +async fn scram_password_missing() { + connect_raw("user=scram_user dbname=postgres") + .await + .err() + .unwrap(); +} + +#[tokio::test] +async fn scram_password_wrong() { + match connect_raw("user=scram_user password=foo dbname=postgres").await { + Ok(_) => panic!("unexpected success"), + Err(ref e) if e.code() == Some(&SqlState::INVALID_PASSWORD) => {} + Err(e) => panic!("{}", e), + } +} + +#[tokio::test] +async fn scram_password_ok() { + connect("user=scram_user password=password dbname=postgres").await; +} + +#[tokio::test] +async fn pipelined_prepare() { + let client = connect("user=postgres").await; + + let prepare1 = client.prepare("SELECT $1::HSTORE[]"); + let prepare2 = client.prepare("SELECT $1::BIGINT"); + + let (statement1, statement2) = try_join!(prepare1, prepare2).unwrap(); + + assert_eq!(statement1.params()[0].name(), "_hstore"); + assert_eq!(statement1.columns()[0].type_().name(), "_hstore"); + + assert_eq!(statement2.params()[0], Type::INT8); + assert_eq!(statement2.columns()[0].type_(), &Type::INT8); +} + +#[tokio::test] +async fn insert_select() { + let client = connect("user=postgres").await; + + client + .batch_execute("CREATE TEMPORARY TABLE foo (id SERIAL, name TEXT)") + .await + .unwrap(); + + let insert = client.prepare("INSERT INTO foo (name) VALUES ($1), ($2)"); + let select = client.prepare("SELECT id, name FROM foo ORDER BY id"); + let (insert, select) = try_join!(insert, select).unwrap(); + + let insert = client.execute(&insert, &[&"alice", &"bob"]); + let select = client.query(&select, &[]); + let (_, rows) = try_join!(insert, select).unwrap(); + + assert_eq!(rows.len(), 2); + assert_eq!(rows[0].get::<_, i32>(0), 1); + assert_eq!(rows[0].get::<_, &str>(1), "alice"); + assert_eq!(rows[1].get::<_, i32>(0), 2); + assert_eq!(rows[1].get::<_, &str>(1), "bob"); +} + +#[tokio::test] +async fn custom_enum() { + let client = connect("user=postgres").await; + + client + .batch_execute( + "CREATE TYPE pg_temp.mood AS ENUM ( + 'sad', + 'ok', + 'happy' + )", + ) + .await + .unwrap(); + + let select = client.prepare("SELECT $1::mood").await.unwrap(); + + let ty = &select.params()[0]; + assert_eq!("mood", ty.name()); + assert_eq!( + &Kind::Enum(vec![ + "sad".to_string(), + "ok".to_string(), + "happy".to_string(), + ]), + ty.kind(), + ); +} + +#[tokio::test] +async fn custom_domain() { + let client = connect("user=postgres").await; + + client + .batch_execute("CREATE DOMAIN pg_temp.session_id AS bytea CHECK(octet_length(VALUE) = 16)") + .await + .unwrap(); + + let select = client.prepare("SELECT $1::session_id").await.unwrap(); + + let ty = &select.params()[0]; + assert_eq!("session_id", ty.name()); + assert_eq!(&Kind::Domain(Type::BYTEA), ty.kind()); +} + +#[tokio::test] +async fn custom_array() { + let client = connect("user=postgres").await; + + let select = client.prepare("SELECT $1::HSTORE[]").await.unwrap(); + + let ty = &select.params()[0]; + assert_eq!("_hstore", ty.name()); + match ty.kind() { + Kind::Array(ty) => { + assert_eq!("hstore", ty.name()); + assert_eq!(&Kind::Simple, ty.kind()); + } + _ => panic!("unexpected kind"), + } +} + +#[tokio::test] +async fn custom_composite() { + let client = connect("user=postgres").await; + + client + .batch_execute( + "CREATE TYPE pg_temp.inventory_item AS ( + name TEXT, + supplier INTEGER, + price NUMERIC + )", + ) + .await + .unwrap(); + + let select = client.prepare("SELECT $1::inventory_item").await.unwrap(); + + let ty = &select.params()[0]; + assert_eq!(ty.name(), "inventory_item"); + match ty.kind() { + Kind::Composite(fields) => { + assert_eq!(fields[0].name(), "name"); + assert_eq!(fields[0].type_(), &Type::TEXT); + assert_eq!(fields[1].name(), "supplier"); + assert_eq!(fields[1].type_(), &Type::INT4); + assert_eq!(fields[2].name(), "price"); + assert_eq!(fields[2].type_(), &Type::NUMERIC); + } + _ => panic!("unexpected kind"), + } +} + +#[tokio::test] +async fn custom_range() { + let client = connect("user=postgres").await; + + client + .batch_execute( + "CREATE TYPE pg_temp.floatrange AS RANGE ( + subtype = float8, + subtype_diff = float8mi + )", + ) + .await + .unwrap(); + + let select = client.prepare("SELECT $1::floatrange").await.unwrap(); + + let ty = &select.params()[0]; + assert_eq!("floatrange", ty.name()); + assert_eq!(&Kind::Range(Type::FLOAT8), ty.kind()); +} + +#[tokio::test] +#[allow(clippy::get_first)] +async fn simple_query() { + let client = connect("user=postgres").await; + + let messages = client + .simple_query( + "CREATE TEMPORARY TABLE foo ( + id SERIAL, + name TEXT + ); + INSERT INTO foo (name) VALUES ('steven'), ('joe'); + SELECT * FROM foo ORDER BY id;", + ) + .await + .unwrap(); + + match messages[0] { + SimpleQueryMessage::CommandComplete(0) => {} + _ => panic!("unexpected message"), + } + match messages[1] { + SimpleQueryMessage::CommandComplete(2) => {} + _ => panic!("unexpected message"), + } + match &messages[2] { + SimpleQueryMessage::RowDescription(columns) => { + assert_eq!(columns.get(0).map(|c| c.name()), Some("id")); + assert_eq!(columns.get(1).map(|c| c.name()), Some("name")); + } + _ => panic!("unexpected message"), + } + match &messages[3] { + SimpleQueryMessage::Row(row) => { + assert_eq!(row.columns().get(0).map(|c| c.name()), Some("id")); + assert_eq!(row.columns().get(1).map(|c| c.name()), Some("name")); + assert_eq!(row.get(0), Some("1")); + assert_eq!(row.get(1), Some("steven")); + } + _ => panic!("unexpected message"), + } + match &messages[4] { + SimpleQueryMessage::Row(row) => { + assert_eq!(row.columns().get(0).map(|c| c.name()), Some("id")); + assert_eq!(row.columns().get(1).map(|c| c.name()), Some("name")); + assert_eq!(row.get(0), Some("2")); + assert_eq!(row.get(1), Some("joe")); + } + _ => panic!("unexpected message"), + } + match messages[5] { + SimpleQueryMessage::CommandComplete(2) => {} + _ => panic!("unexpected message"), + } + assert_eq!(messages.len(), 6); +} + +#[tokio::test] +async fn cancel_query_raw() { + let client = connect("user=postgres").await; + + let socket = TcpStream::connect("127.0.0.1:5433").await.unwrap(); + let cancel_token = client.cancel_token(); + let cancel = cancel_token.cancel_query_raw(socket, NoTls); + let cancel = time::sleep(Duration::from_millis(100)).then(|()| cancel); + + let sleep = client.batch_execute("SELECT pg_sleep(100)"); + + match join!(sleep, cancel) { + (Err(ref e), Ok(())) if e.code() == Some(&SqlState::QUERY_CANCELED) => {} + t => panic!("unexpected return: {:?}", t), + } +} + +#[tokio::test] +async fn transaction_commit() { + let mut client = connect("user=postgres").await; + + client + .batch_execute( + "CREATE TEMPORARY TABLE foo( + id SERIAL, + name TEXT + )", + ) + .await + .unwrap(); + + let transaction = client.transaction().await.unwrap(); + transaction + .batch_execute("INSERT INTO foo (name) VALUES ('steven')") + .await + .unwrap(); + transaction.commit().await.unwrap(); + + let stmt = client.prepare("SELECT name FROM foo").await.unwrap(); + let rows = client.query(&stmt, &[]).await.unwrap(); + + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].get::<_, &str>(0), "steven"); +} + +#[tokio::test] +async fn transaction_rollback() { + let mut client = connect("user=postgres").await; + + client + .batch_execute( + "CREATE TEMPORARY TABLE foo( + id SERIAL, + name TEXT + )", + ) + .await + .unwrap(); + + let transaction = client.transaction().await.unwrap(); + transaction + .batch_execute("INSERT INTO foo (name) VALUES ('steven')") + .await + .unwrap(); + transaction.rollback().await.unwrap(); + + let stmt = client.prepare("SELECT name FROM foo").await.unwrap(); + let rows = client.query(&stmt, &[]).await.unwrap(); + + assert_eq!(rows.len(), 0); +} + +#[tokio::test] +async fn transaction_future_cancellation() { + let mut client = connect("user=postgres").await; + + for i in 0.. { + let done = { + let txn = client.transaction(); + let fut = Cancellable { + fut: txn, + polls_left: i, + }; + fut.await + .map(|res| res.expect("transaction failed")) + .is_some() + }; + + assert!(!in_transaction(&client).await); + + if done { + break; + } + } +} + +#[tokio::test] +async fn transaction_commit_future_cancellation() { + let mut client = connect("user=postgres").await; + + for i in 0.. { + let done = { + let txn = client.transaction().await.unwrap(); + let commit = txn.commit(); + let fut = Cancellable { + fut: commit, + polls_left: i, + }; + fut.await + .map(|res| res.expect("transaction failed")) + .is_some() + }; + + assert!(!in_transaction(&client).await); + + if done { + break; + } + } +} + +#[tokio::test] +async fn transaction_rollback_future_cancellation() { + let mut client = connect("user=postgres").await; + + for i in 0.. { + let done = { + let txn = client.transaction().await.unwrap(); + let rollback = txn.rollback(); + let fut = Cancellable { + fut: rollback, + polls_left: i, + }; + fut.await + .map(|res| res.expect("transaction failed")) + .is_some() + }; + + assert!(!in_transaction(&client).await); + + if done { + break; + } + } +} + +#[tokio::test] +async fn transaction_rollback_drop() { + let mut client = connect("user=postgres").await; + + client + .batch_execute( + "CREATE TEMPORARY TABLE foo( + id SERIAL, + name TEXT + )", + ) + .await + .unwrap(); + + let transaction = client.transaction().await.unwrap(); + transaction + .batch_execute("INSERT INTO foo (name) VALUES ('steven')") + .await + .unwrap(); + drop(transaction); + + let stmt = client.prepare("SELECT name FROM foo").await.unwrap(); + let rows = client.query(&stmt, &[]).await.unwrap(); + + assert_eq!(rows.len(), 0); +} + +#[tokio::test] +async fn transaction_builder() { + let mut client = connect("user=postgres").await; + + client + .batch_execute( + "CREATE TEMPORARY TABLE foo( + id SERIAL, + name TEXT + )", + ) + .await + .unwrap(); + + let transaction = client + .build_transaction() + .isolation_level(IsolationLevel::Serializable) + .read_only(true) + .deferrable(true) + .start() + .await + .unwrap(); + transaction + .batch_execute("INSERT INTO foo (name) VALUES ('steven')") + .await + .unwrap(); + transaction.commit().await.unwrap(); + + let stmt = client.prepare("SELECT name FROM foo").await.unwrap(); + let rows = client.query(&stmt, &[]).await.unwrap(); + + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].get::<_, &str>(0), "steven"); +} + +#[tokio::test] +async fn copy_in() { + let client = connect("user=postgres").await; + + client + .batch_execute( + "CREATE TEMPORARY TABLE foo ( + id INTEGER, + name TEXT + )", + ) + .await + .unwrap(); + + let mut stream = stream::iter( + vec![ + Bytes::from_static(b"1\tjim\n"), + Bytes::from_static(b"2\tjoe\n"), + ] + .into_iter() + .map(Ok::<_, Error>), + ); + let sink = client.copy_in("COPY foo FROM STDIN").await.unwrap(); + pin_mut!(sink); + sink.send_all(&mut stream).await.unwrap(); + let rows = sink.finish().await.unwrap(); + assert_eq!(rows, 2); + + let rows = client + .query("SELECT id, name FROM foo ORDER BY id", &[]) + .await + .unwrap(); + + assert_eq!(rows.len(), 2); + assert_eq!(rows[0].get::<_, i32>(0), 1); + assert_eq!(rows[0].get::<_, &str>(1), "jim"); + assert_eq!(rows[1].get::<_, i32>(0), 2); + assert_eq!(rows[1].get::<_, &str>(1), "joe"); +} + +#[tokio::test] +async fn copy_in_large() { + let client = connect("user=postgres").await; + + client + .batch_execute( + "CREATE TEMPORARY TABLE foo ( + id INTEGER, + name TEXT + )", + ) + .await + .unwrap(); + + let a = Bytes::from_static(b"0\tname0\n"); + let mut b = BytesMut::new(); + for i in 1..5_000 { + writeln!(b, "{0}\tname{0}", i).unwrap(); + } + let mut c = BytesMut::new(); + for i in 5_000..10_000 { + writeln!(c, "{0}\tname{0}", i).unwrap(); + } + let mut stream = stream::iter( + vec![a, b.freeze(), c.freeze()] + .into_iter() + .map(Ok::<_, Error>), + ); + + let sink = client.copy_in("COPY foo FROM STDIN").await.unwrap(); + pin_mut!(sink); + sink.send_all(&mut stream).await.unwrap(); + let rows = sink.finish().await.unwrap(); + assert_eq!(rows, 10_000); +} + +#[tokio::test] +async fn copy_in_error() { + let client = connect("user=postgres").await; + + client + .batch_execute( + "CREATE TEMPORARY TABLE foo ( + id INTEGER, + name TEXT + )", + ) + .await + .unwrap(); + + { + let sink = client.copy_in("COPY foo FROM STDIN").await.unwrap(); + pin_mut!(sink); + sink.send(Bytes::from_static(b"1\tsteven")).await.unwrap(); + } + + let rows = client + .query("SELECT id, name FROM foo ORDER BY id", &[]) + .await + .unwrap(); + assert_eq!(rows.len(), 0); +} + +#[tokio::test] +async fn copy_out() { + let client = connect("user=postgres").await; + + client + .batch_execute( + "CREATE TEMPORARY TABLE foo ( + id SERIAL, + name TEXT + ); + + INSERT INTO foo (name) VALUES ('jim'), ('joe');", + ) + .await + .unwrap(); + + let stmt = client.prepare("COPY foo TO STDOUT").await.unwrap(); + let data = client + .copy_out(&stmt) + .await + .unwrap() + .try_fold(BytesMut::new(), |mut buf, chunk| async move { + buf.extend_from_slice(&chunk); + Ok(buf) + }) + .await + .unwrap(); + assert_eq!(&data[..], b"1\tjim\n2\tjoe\n"); +} + +#[tokio::test] +async fn notices() { + let long_name = "x".repeat(65); + let (client, mut connection) = + connect_raw(&format!("user=postgres application_name={}", long_name,)) + .await + .unwrap(); + + let (tx, rx) = mpsc::unbounded(); + let stream = + stream::poll_fn(move |cx| connection.poll_message(cx)).map_err(|e| panic!("{}", e)); + let connection = stream.forward(tx).map(|r| r.unwrap()); + tokio::spawn(connection); + + client + .batch_execute("DROP DATABASE IF EXISTS noexistdb") + .await + .unwrap(); + + drop(client); + + let notices = rx + .filter_map(|m| match m { + AsyncMessage::Notice(n) => future::ready(Some(n)), + _ => future::ready(None), + }) + .collect::>() + .await; + assert_eq!(notices.len(), 2); + assert_eq!( + notices[0].message(), + "identifier \"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\" \ + will be truncated to \"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\"" + ); + assert_eq!( + notices[1].message(), + "database \"noexistdb\" does not exist, skipping" + ); +} + +#[tokio::test] +async fn notifications() { + let (client, mut connection) = connect_raw("user=postgres").await.unwrap(); + + let (tx, rx) = mpsc::unbounded(); + let stream = + stream::poll_fn(move |cx| connection.poll_message(cx)).map_err(|e| panic!("{}", e)); + let connection = stream.forward(tx).map(|r| r.unwrap()); + tokio::spawn(connection); + + client + .batch_execute( + "LISTEN test_notifications; + NOTIFY test_notifications, 'hello'; + NOTIFY test_notifications, 'world';", + ) + .await + .unwrap(); + + drop(client); + + let notifications = rx + .filter_map(|m| match m { + AsyncMessage::Notification(n) => future::ready(Some(n)), + _ => future::ready(None), + }) + .collect::>() + .await; + assert_eq!(notifications.len(), 2); + assert_eq!(notifications[0].channel(), "test_notifications"); + assert_eq!(notifications[0].payload(), "hello"); + assert_eq!(notifications[1].channel(), "test_notifications"); + assert_eq!(notifications[1].payload(), "world"); +} + +#[tokio::test] +async fn query_portal() { + let mut client = connect("user=postgres").await; + + client + .batch_execute( + "CREATE TEMPORARY TABLE foo ( + id SERIAL, + name TEXT + ); + + INSERT INTO foo (name) VALUES ('alice'), ('bob'), ('charlie');", + ) + .await + .unwrap(); + + let stmt = client + .prepare("SELECT id, name FROM foo ORDER BY id") + .await + .unwrap(); + + let transaction = client.transaction().await.unwrap(); + + let portal = transaction.bind(&stmt, &[]).await.unwrap(); + let f1 = transaction.query_portal(&portal, 2); + let f2 = transaction.query_portal(&portal, 2); + let f3 = transaction.query_portal(&portal, 2); + + let (r1, r2, r3) = try_join!(f1, f2, f3).unwrap(); + + assert_eq!(r1.len(), 2); + assert_eq!(r1[0].get::<_, i32>(0), 1); + assert_eq!(r1[0].get::<_, &str>(1), "alice"); + assert_eq!(r1[1].get::<_, i32>(0), 2); + assert_eq!(r1[1].get::<_, &str>(1), "bob"); + + assert_eq!(r2.len(), 1); + assert_eq!(r2[0].get::<_, i32>(0), 3); + assert_eq!(r2[0].get::<_, &str>(1), "charlie"); + + assert_eq!(r3.len(), 0); +} + +#[tokio::test] +async fn require_channel_binding() { + connect_raw("user=postgres channel_binding=require") + .await + .err() + .unwrap(); +} + +#[tokio::test] +async fn prefer_channel_binding() { + connect("user=postgres channel_binding=prefer").await; +} + +#[tokio::test] +async fn disable_channel_binding() { + connect("user=postgres channel_binding=disable").await; +} + +#[tokio::test] +async fn check_send() { + fn is_send(_: &T) {} + + let f = connect("user=postgres"); + is_send(&f); + let mut client = f.await; + + let f = client.prepare("SELECT $1::TEXT"); + is_send(&f); + let stmt = f.await.unwrap(); + + let f = client.query(&stmt, &[&"hello"]); + is_send(&f); + drop(f); + + let f = client.execute(&stmt, &[&"hello"]); + is_send(&f); + drop(f); + + let f = client.transaction(); + is_send(&f); + let trans = f.await.unwrap(); + + let f = trans.query(&stmt, &[&"hello"]); + is_send(&f); + drop(f); + + let f = trans.execute(&stmt, &[&"hello"]); + is_send(&f); + drop(f); +} + +#[tokio::test] +async fn query_one() { + let client = connect("user=postgres").await; + + client + .batch_execute( + " + CREATE TEMPORARY TABLE foo ( + name TEXT + ); + INSERT INTO foo (name) VALUES ('alice'), ('bob'), ('carol'); + ", + ) + .await + .unwrap(); + + client + .query_one("SELECT * FROM foo WHERE name = 'dave'", &[]) + .await + .err() + .unwrap(); + client + .query_one("SELECT * FROM foo WHERE name = 'alice'", &[]) + .await + .unwrap(); + client + .query_one("SELECT * FROM foo", &[]) + .await + .err() + .unwrap(); +} + +#[tokio::test] +async fn query_opt() { + let client = connect("user=postgres").await; + + client + .batch_execute( + " + CREATE TEMPORARY TABLE foo ( + name TEXT + ); + INSERT INTO foo (name) VALUES ('alice'), ('bob'), ('carol'); + ", + ) + .await + .unwrap(); + + assert!(client + .query_opt("SELECT * FROM foo WHERE name = 'dave'", &[]) + .await + .unwrap() + .is_none()); + client + .query_opt("SELECT * FROM foo WHERE name = 'alice'", &[]) + .await + .unwrap() + .unwrap(); + client + .query_one("SELECT * FROM foo", &[]) + .await + .err() + .unwrap(); +} + +#[tokio::test] +async fn deferred_constraint() { + let client = connect("user=postgres").await; + + client + .batch_execute( + " + CREATE TEMPORARY TABLE t ( + i INT, + UNIQUE (i) DEFERRABLE INITIALLY DEFERRED + ); + ", + ) + .await + .unwrap(); + + client + .execute("INSERT INTO t (i) VALUES (1)", &[]) + .await + .unwrap(); + client + .execute("INSERT INTO t (i) VALUES (1)", &[]) + .await + .unwrap_err(); +} + +#[tokio::test] +async fn query_typed_no_transaction() { + let client = connect("user=postgres").await; + + client + .batch_execute( + " + CREATE TEMPORARY TABLE foo ( + name TEXT, + age INT + ); + INSERT INTO foo (name, age) VALUES ('alice', 20), ('bob', 30), ('carol', 40); + ", + ) + .await + .unwrap(); + + let rows: Vec = client + .query_typed( + "SELECT name, age, 'literal', 5 FROM foo WHERE name <> $1 AND age < $2 ORDER BY age", + &[(&"alice", Type::TEXT), (&50i32, Type::INT4)], + ) + .await + .unwrap(); + + assert_eq!(rows.len(), 2); + let first_row = &rows[0]; + assert_eq!(first_row.get::<_, &str>(0), "bob"); + assert_eq!(first_row.get::<_, i32>(1), 30); + assert_eq!(first_row.get::<_, &str>(2), "literal"); + assert_eq!(first_row.get::<_, i32>(3), 5); + + let second_row = &rows[1]; + assert_eq!(second_row.get::<_, &str>(0), "carol"); + assert_eq!(second_row.get::<_, i32>(1), 40); + assert_eq!(second_row.get::<_, &str>(2), "literal"); + assert_eq!(second_row.get::<_, i32>(3), 5); + + // Test for UPDATE that returns no data + let updated_rows = client + .query_typed("UPDATE foo set age = 33", &[]) + .await + .unwrap(); + assert_eq!(updated_rows.len(), 0); +} + +#[tokio::test] +async fn query_typed_with_transaction() { + let mut client = connect("user=postgres").await; + + client + .batch_execute( + " + CREATE TEMPORARY TABLE foo ( + name TEXT, + age INT + ); + ", + ) + .await + .unwrap(); + + let transaction = client.transaction().await.unwrap(); + + let rows: Vec = transaction + .query_typed( + "INSERT INTO foo (name, age) VALUES ($1, $2), ($3, $4), ($5, $6) returning name, age", + &[ + (&"alice", Type::TEXT), + (&20i32, Type::INT4), + (&"bob", Type::TEXT), + (&30i32, Type::INT4), + (&"carol", Type::TEXT), + (&40i32, Type::INT4), + ], + ) + .await + .unwrap(); + let inserted_values: Vec<(String, i32)> = rows + .iter() + .map(|row| (row.get::<_, String>(0), row.get::<_, i32>(1))) + .collect(); + assert_eq!( + inserted_values, + [ + ("alice".to_string(), 20), + ("bob".to_string(), 30), + ("carol".to_string(), 40) + ] + ); + + let rows: Vec = transaction + .query_typed( + "SELECT name, age, 'literal', 5 FROM foo WHERE name <> $1 AND age < $2 ORDER BY age", + &[(&"alice", Type::TEXT), (&50i32, Type::INT4)], + ) + .await + .unwrap(); + + assert_eq!(rows.len(), 2); + let first_row = &rows[0]; + assert_eq!(first_row.get::<_, &str>(0), "bob"); + assert_eq!(first_row.get::<_, i32>(1), 30); + assert_eq!(first_row.get::<_, &str>(2), "literal"); + assert_eq!(first_row.get::<_, i32>(3), 5); + + let second_row = &rows[1]; + assert_eq!(second_row.get::<_, &str>(0), "carol"); + assert_eq!(second_row.get::<_, i32>(1), 40); + assert_eq!(second_row.get::<_, &str>(2), "literal"); + assert_eq!(second_row.get::<_, i32>(3), 5); + + // Test for UPDATE that returns no data + let updated_rows = transaction + .query_typed("UPDATE foo set age = 33", &[]) + .await + .unwrap(); + assert_eq!(updated_rows.len(), 0); +} diff --git a/tokio-postgres/tests/test/parse.rs b/tokio-postgres/tests/test/parse.rs new file mode 100644 index 000000000..35eeca72b --- /dev/null +++ b/tokio-postgres/tests/test/parse.rs @@ -0,0 +1,144 @@ +use std::time::Duration; +use tokio_postgres::config::{Config, SslNegotiation, TargetSessionAttrs}; + +fn check(s: &str, config: &Config) { + assert_eq!(s.parse::().expect(s), *config, "`{}`", s); +} + +#[test] +fn pairs_ok() { + check( + r"user=foo password=' fizz \'buzz\\ ' application_name = ''", + Config::new() + .user("foo") + .password(r" fizz 'buzz\ ") + .application_name(""), + ); +} + +#[test] +fn pairs_ws() { + check( + " user\t=\r\n\x0bfoo \t password = hunter2 ", + Config::new().user("foo").password("hunter2"), + ); +} + +#[test] +fn settings() { + check( + "connect_timeout=3 keepalives=0 keepalives_idle=30 target_session_attrs=read-write", + Config::new() + .connect_timeout(Duration::from_secs(3)) + .keepalives(false) + .keepalives_idle(Duration::from_secs(30)) + .target_session_attrs(TargetSessionAttrs::ReadWrite), + ); + check( + "connect_timeout=3 keepalives=0 keepalives_idle=30 target_session_attrs=read-only", + Config::new() + .connect_timeout(Duration::from_secs(3)) + .keepalives(false) + .keepalives_idle(Duration::from_secs(30)) + .target_session_attrs(TargetSessionAttrs::ReadOnly), + ); + check( + "sslnegotiation=direct", + Config::new().ssl_negotiation(SslNegotiation::Direct), + ); +} + +#[test] +fn keepalive_settings() { + check( + "keepalives=1 keepalives_idle=15 keepalives_interval=5 keepalives_retries=9", + Config::new() + .keepalives(true) + .keepalives_idle(Duration::from_secs(15)) + .keepalives_interval(Duration::from_secs(5)) + .keepalives_retries(9), + ); +} + +#[test] +fn url() { + check("postgresql://", &Config::new()); + check( + "postgresql://localhost", + Config::new().host("localhost").port(5432), + ); + check( + "postgresql://localhost:5433", + Config::new().host("localhost").port(5433), + ); + check( + "postgresql://localhost/mydb", + Config::new().host("localhost").port(5432).dbname("mydb"), + ); + check( + "postgresql://user@localhost", + Config::new().user("user").host("localhost").port(5432), + ); + check( + "postgresql://user:secret@localhost", + Config::new() + .user("user") + .password("secret") + .host("localhost") + .port(5432), + ); + check( + "postgresql://other@localhost/otherdb?connect_timeout=10&application_name=myapp", + Config::new() + .user("other") + .host("localhost") + .port(5432) + .dbname("otherdb") + .connect_timeout(Duration::from_secs(10)) + .application_name("myapp"), + ); + check( + "postgresql://host1:123,host2:456/somedb?target_session_attrs=any&application_name=myapp", + Config::new() + .host("host1") + .port(123) + .host("host2") + .port(456) + .dbname("somedb") + .target_session_attrs(TargetSessionAttrs::Any) + .application_name("myapp"), + ); + check( + "postgresql:///mydb?host=localhost&port=5433", + Config::new().dbname("mydb").host("localhost").port(5433), + ); + check( + "postgresql://[2001:db8::1234]/database", + Config::new() + .host("2001:db8::1234") + .port(5432) + .dbname("database"), + ); + check( + "postgresql://[2001:db8::1234]:5433/database", + Config::new() + .host("2001:db8::1234") + .port(5433) + .dbname("database"), + ); + #[cfg(unix)] + check( + "postgresql:///dbname?host=/var/lib/postgresql", + Config::new() + .dbname("dbname") + .host_path("/var/lib/postgresql"), + ); + #[cfg(unix)] + check( + "postgresql://%2Fvar%2Flib%2Fpostgresql/dbname", + Config::new() + .host_path("/var/lib/postgresql") + .port(5432) + .dbname("dbname"), + ) +} diff --git a/tokio-postgres/tests/test/runtime.rs b/tokio-postgres/tests/test/runtime.rs new file mode 100644 index 000000000..86c1f0701 --- /dev/null +++ b/tokio-postgres/tests/test/runtime.rs @@ -0,0 +1,135 @@ +use futures_util::{join, FutureExt}; +use std::time::Duration; +use tokio::time; +use tokio_postgres::error::SqlState; +use tokio_postgres::{Client, NoTls}; + +async fn connect(s: &str) -> Client { + let (client, connection) = tokio_postgres::connect(s, NoTls).await.unwrap(); + let connection = connection.map(|e| e.unwrap()); + tokio::spawn(connection); + + client +} + +async fn smoke_test(s: &str) { + let client = connect(s).await; + + let stmt = client.prepare("SELECT $1::INT").await.unwrap(); + let rows = client.query(&stmt, &[&1i32]).await.unwrap(); + assert_eq!(rows[0].get::<_, i32>(0), 1i32); +} + +#[tokio::test] +#[ignore] // FIXME doesn't work with our docker-based tests :( +async fn unix_socket() { + smoke_test("host=/var/run/postgresql port=5433 user=postgres").await; +} + +#[tokio::test] +async fn tcp() { + smoke_test("host=localhost port=5433 user=postgres").await; +} + +#[tokio::test] +async fn multiple_hosts_one_port() { + smoke_test("host=foobar.invalid,localhost port=5433 user=postgres").await; +} + +#[tokio::test] +async fn multiple_hosts_multiple_ports() { + smoke_test("host=foobar.invalid,localhost port=5432,5433 user=postgres").await; +} + +#[tokio::test] +async fn wrong_port_count() { + tokio_postgres::connect("host=localhost port=5433,5433 user=postgres", NoTls) + .await + .err() + .unwrap(); +} + +#[tokio::test] +async fn target_session_attrs_ok() { + smoke_test("host=localhost port=5433 user=postgres target_session_attrs=read-write").await; +} + +#[tokio::test] +async fn target_session_attrs_err() { + tokio_postgres::connect( + "host=localhost port=5433 user=postgres target_session_attrs=read-write + options='-c default_transaction_read_only=on'", + NoTls, + ) + .await + .err() + .unwrap(); +} + +#[tokio::test] +async fn host_only_ok() { + let _ = tokio_postgres::connect( + "host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_only_ok() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1 port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_and_host_ok() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1 host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_host_mismatch() { + let _ = tokio_postgres::connect( + "hostaddr=127.0.0.1,127.0.0.2 host=localhost port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .err() + .unwrap(); +} + +#[tokio::test] +async fn hostaddr_host_both_missing() { + let _ = tokio_postgres::connect( + "port=5433 user=pass_user dbname=postgres password=password", + NoTls, + ) + .await + .err() + .unwrap(); +} + +#[tokio::test] +async fn cancel_query() { + let client = connect("host=localhost port=5433 user=postgres").await; + + let cancel_token = client.cancel_token(); + let cancel = cancel_token.cancel_query(NoTls); + let cancel = time::sleep(Duration::from_millis(100)).then(|()| cancel); + + let sleep = client.batch_execute("SELECT pg_sleep(100)"); + + match join!(sleep, cancel) { + (Err(ref e), Ok(())) if e.code() == Some(&SqlState::QUERY_CANCELED) => {} + t => panic!("unexpected return: {:?}", t), + } +} diff --git a/postgres/tests/types/bit_vec.rs b/tokio-postgres/tests/test/types/bit_vec_06.rs similarity index 74% rename from postgres/tests/types/bit_vec.rs rename to tokio-postgres/tests/test/types/bit_vec_06.rs index 2e0ac53de..4d01dc2f2 100644 --- a/postgres/tests/types/bit_vec.rs +++ b/tokio-postgres/tests/test/types/bit_vec_06.rs @@ -1,10 +1,9 @@ -extern crate bit_vec; +use bit_vec_06::BitVec; -use self::bit_vec::BitVec; -use types::test_type; +use crate::types::test_type; -#[test] -fn test_bit_params() { +#[tokio::test] +async fn test_bit_params() { let mut bv = BitVec::from_bytes(&[0b0110_1001, 0b0000_0111]); bv.pop(); bv.pop(); @@ -12,10 +11,11 @@ fn test_bit_params() { "BIT(14)", &[(Some(bv), "B'01101001000001'"), (None, "NULL")], ) + .await } -#[test] -fn test_varbit_params() { +#[tokio::test] +async fn test_varbit_params() { let mut bv = BitVec::from_bytes(&[0b0110_1001, 0b0000_0111]); bv.pop(); bv.pop(); @@ -27,4 +27,5 @@ fn test_varbit_params() { (None, "NULL"), ], ) + .await } diff --git a/tokio-postgres/tests/test/types/chrono_04.rs b/tokio-postgres/tests/test/types/chrono_04.rs new file mode 100644 index 000000000..eda8151a6 --- /dev/null +++ b/tokio-postgres/tests/test/types/chrono_04.rs @@ -0,0 +1,190 @@ +use chrono_04::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc}; +use std::fmt; +use tokio_postgres::types::{Date, FromSqlOwned, Timestamp}; +use tokio_postgres::Client; + +use crate::connect; +use crate::types::test_type; + +#[tokio::test] +async fn test_naive_date_time_params() { + fn make_check(time: &str) -> (Option, &str) { + ( + Some(NaiveDateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'").unwrap()), + time, + ) + } + test_type( + "TIMESTAMP", + &[ + make_check("'1970-01-01 00:00:00.010000000'"), + make_check("'1965-09-25 11:19:33.100314000'"), + make_check("'2010-02-09 23:11:45.120200000'"), + (None, "NULL"), + ], + ) + .await; +} + +#[tokio::test] +async fn test_with_special_naive_date_time_params() { + fn make_check(time: &str) -> (Timestamp, &str) { + ( + Timestamp::Value( + NaiveDateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'").unwrap(), + ), + time, + ) + } + test_type( + "TIMESTAMP", + &[ + make_check("'1970-01-01 00:00:00.010000000'"), + make_check("'1965-09-25 11:19:33.100314000'"), + make_check("'2010-02-09 23:11:45.120200000'"), + (Timestamp::PosInfinity, "'infinity'"), + (Timestamp::NegInfinity, "'-infinity'"), + ], + ) + .await; +} + +#[tokio::test] +async fn test_date_time_params() { + fn make_check(time: &str) -> (Option>, &str) { + ( + Some( + DateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f%#z'") + .unwrap() + .to_utc(), + ), + time, + ) + } + test_type( + "TIMESTAMP WITH TIME ZONE", + &[ + make_check("'1970-01-01 00:00:00.010000000Z'"), + make_check("'1965-09-25 11:19:33.100314000Z'"), + make_check("'2010-02-09 23:11:45.120200000Z'"), + (None, "NULL"), + ], + ) + .await; +} + +#[tokio::test] +async fn test_with_special_date_time_params() { + fn make_check(time: &str) -> (Timestamp>, &str) { + ( + Timestamp::Value( + DateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f%#z'") + .unwrap() + .to_utc(), + ), + time, + ) + } + test_type( + "TIMESTAMP WITH TIME ZONE", + &[ + make_check("'1970-01-01 00:00:00.010000000Z'"), + make_check("'1965-09-25 11:19:33.100314000Z'"), + make_check("'2010-02-09 23:11:45.120200000Z'"), + (Timestamp::PosInfinity, "'infinity'"), + (Timestamp::NegInfinity, "'-infinity'"), + ], + ) + .await; +} + +#[tokio::test] +async fn test_date_params() { + fn make_check(time: &str) -> (Option, &str) { + ( + Some(NaiveDate::parse_from_str(time, "'%Y-%m-%d'").unwrap()), + time, + ) + } + test_type( + "DATE", + &[ + make_check("'1970-01-01'"), + make_check("'1965-09-25'"), + make_check("'2010-02-09'"), + (None, "NULL"), + ], + ) + .await; +} + +#[tokio::test] +async fn test_with_special_date_params() { + fn make_check(date: &str) -> (Date, &str) { + ( + Date::Value(NaiveDate::parse_from_str(date, "'%Y-%m-%d'").unwrap()), + date, + ) + } + test_type( + "DATE", + &[ + make_check("'1970-01-01'"), + make_check("'1965-09-25'"), + make_check("'2010-02-09'"), + (Date::PosInfinity, "'infinity'"), + (Date::NegInfinity, "'-infinity'"), + ], + ) + .await; +} + +#[tokio::test] +async fn test_time_params() { + fn make_check(time: &str) -> (Option, &str) { + ( + Some(NaiveTime::parse_from_str(time, "'%H:%M:%S.%f'").unwrap()), + time, + ) + } + test_type( + "TIME", + &[ + make_check("'00:00:00.010000000'"), + make_check("'11:19:33.100314000'"), + make_check("'23:11:45.120200000'"), + (None, "NULL"), + ], + ) + .await; +} + +#[tokio::test] +async fn test_special_params_without_wrapper() { + async fn assert_overflows(client: &mut Client, val: &str, sql_type: &str) + where + T: FromSqlOwned + fmt::Debug, + { + let err = client + .query_one(&*format!("SELECT {}::{}", val, sql_type), &[]) + .await + .unwrap() + .try_get::<_, T>(0) + .unwrap_err(); + assert_eq!( + err.to_string(), + "error deserializing column 0: value too large to decode" + ); + } + + let mut client = connect("user=postgres").await; + + assert_overflows::>(&mut client, "'-infinity'", "timestamptz").await; + assert_overflows::>(&mut client, "'infinity'", "timestamptz").await; + + assert_overflows::(&mut client, "'-infinity'", "timestamp").await; + assert_overflows::(&mut client, "'infinity'", "timestamp").await; + + assert_overflows::(&mut client, "'-infinity'", "date").await; + assert_overflows::(&mut client, "'infinity'", "date").await; +} diff --git a/tokio-postgres/tests/test/types/eui48_1.rs b/tokio-postgres/tests/test/types/eui48_1.rs new file mode 100644 index 000000000..0c22e9e87 --- /dev/null +++ b/tokio-postgres/tests/test/types/eui48_1.rs @@ -0,0 +1,18 @@ +use eui48_1::MacAddress; + +use crate::types::test_type; + +#[tokio::test] +async fn test_eui48_params() { + test_type( + "MACADDR", + &[ + ( + Some(MacAddress::parse_str("12-34-56-AB-CD-EF").unwrap()), + "'12-34-56-ab-cd-ef'", + ), + (None, "NULL"), + ], + ) + .await +} diff --git a/tokio-postgres/tests/test/types/geo_types_06.rs b/tokio-postgres/tests/test/types/geo_types_06.rs new file mode 100644 index 000000000..7195abc06 --- /dev/null +++ b/tokio-postgres/tests/test/types/geo_types_06.rs @@ -0,0 +1,60 @@ +use geo_types_06::{Coordinate, LineString, Point, Rect}; + +use crate::types::test_type; + +#[tokio::test] +async fn test_point_params() { + test_type( + "POINT", + &[ + (Some(Point::new(0.0, 0.0)), "POINT(0, 0)"), + (Some(Point::new(-3.2, 1.618)), "POINT(-3.2, 1.618)"), + (None, "NULL"), + ], + ) + .await; +} + +#[tokio::test] +async fn test_box_params() { + test_type( + "BOX", + &[ + ( + Some(Rect::new( + Coordinate { x: -3.2, y: 1.618 }, + Coordinate { + x: 160.0, + y: 69701.5615, + }, + )), + "BOX(POINT(160.0, 69701.5615), POINT(-3.2, 1.618))", + ), + (None, "NULL"), + ], + ) + .await; +} + +#[tokio::test] +async fn test_path_params() { + let points = vec![ + Coordinate { x: 0., y: 0. }, + Coordinate { x: -3.2, y: 1.618 }, + Coordinate { + x: 160.0, + y: 69701.5615, + }, + ]; + test_type( + "PATH", + &[ + ( + Some(LineString(points)), + "path '((0, 0), (-3.2, 1.618), (160.0, 69701.5615))'", + ), + (None, "NULL"), + ], + ) + .await; +} diff --git a/tokio-postgres/tests/test/types/geo_types_07.rs b/tokio-postgres/tests/test/types/geo_types_07.rs new file mode 100644 index 000000000..43a13f451 --- /dev/null +++ b/tokio-postgres/tests/test/types/geo_types_07.rs @@ -0,0 +1,61 @@ +#[cfg(feature = "with-geo-types-0_7")] +use geo_types_07::{Coord, LineString, Point, Rect}; + +use crate::types::test_type; + +#[tokio::test] +async fn test_point_params() { + test_type( + "POINT", + &[ + (Some(Point::new(0.0, 0.0)), "POINT(0, 0)"), + (Some(Point::new(-3.2, 1.618)), "POINT(-3.2, 1.618)"), + (None, "NULL"), + ], + ) + .await; +} + +#[tokio::test] +async fn test_box_params() { + test_type( + "BOX", + &[ + ( + Some(Rect::new( + Coord { x: -3.2, y: 1.618 }, + Coord { + x: 160.0, + y: 69701.5615, + }, + )), + "BOX(POINT(160.0, 69701.5615), POINT(-3.2, 1.618))", + ), + (None, "NULL"), + ], + ) + .await; +} + +#[tokio::test] +async fn test_path_params() { + let points = vec![ + Coord { x: 0., y: 0. }, + Coord { x: -3.2, y: 1.618 }, + Coord { + x: 160.0, + y: 69701.5615, + }, + ]; + test_type( + "PATH", + &[ + ( + Some(LineString(points)), + "path '((0, 0), (-3.2, 1.618), (160.0, 69701.5615))'", + ), + (None, "NULL"), + ], + ) + .await; +} diff --git a/tokio-postgres/tests/test/types/jiff_01.rs b/tokio-postgres/tests/test/types/jiff_01.rs new file mode 100644 index 000000000..7c9052676 --- /dev/null +++ b/tokio-postgres/tests/test/types/jiff_01.rs @@ -0,0 +1,175 @@ +use jiff_01::{ + civil::{Date as JiffDate, DateTime, Time}, + Timestamp as JiffTimestamp, +}; +use std::fmt; +use tokio_postgres::{ + types::{Date, FromSqlOwned, Timestamp}, + Client, +}; + +use crate::connect; +use crate::types::test_type; + +#[tokio::test] +async fn test_datetime_params() { + fn make_check(s: &str) -> (Option, &str) { + (Some(s.trim_matches('\'').parse().unwrap()), s) + } + test_type( + "TIMESTAMP", + &[ + make_check("'1970-01-01 00:00:00.010000000'"), + make_check("'1965-09-25 11:19:33.100314000'"), + make_check("'2010-02-09 23:11:45.120200000'"), + (None, "NULL"), + ], + ) + .await; +} + +#[tokio::test] +async fn test_with_special_datetime_params() { + fn make_check(s: &str) -> (Timestamp, &str) { + (Timestamp::Value(s.trim_matches('\'').parse().unwrap()), s) + } + test_type( + "TIMESTAMP", + &[ + make_check("'1970-01-01 00:00:00.010000000'"), + make_check("'1965-09-25 11:19:33.100314000'"), + make_check("'2010-02-09 23:11:45.120200000'"), + (Timestamp::PosInfinity, "'infinity'"), + (Timestamp::NegInfinity, "'-infinity'"), + ], + ) + .await; +} + +#[tokio::test] +async fn test_timestamp_params() { + fn make_check(s: &str) -> (Option, &str) { + (Some(s.trim_matches('\'').parse().unwrap()), s) + } + test_type( + "TIMESTAMP WITH TIME ZONE", + &[ + make_check("'1970-01-01 00:00:00.010000000Z'"), + make_check("'1965-09-25 11:19:33.100314000Z'"), + make_check("'2010-02-09 23:11:45.120200000Z'"), + (None, "NULL"), + ], + ) + .await; +} + +#[tokio::test] +async fn test_with_special_timestamp_params() { + fn make_check(s: &str) -> (Timestamp, &str) { + (Timestamp::Value(s.trim_matches('\'').parse().unwrap()), s) + } + test_type( + "TIMESTAMP WITH TIME ZONE", + &[ + make_check("'1970-01-01 00:00:00.010000000Z'"), + make_check("'1965-09-25 11:19:33.100314000Z'"), + make_check("'2010-02-09 23:11:45.120200000Z'"), + (Timestamp::PosInfinity, "'infinity'"), + (Timestamp::NegInfinity, "'-infinity'"), + ], + ) + .await; +} + +#[tokio::test] +async fn test_date_params() { + fn make_check(s: &str) -> (Option, &str) { + (Some(s.trim_matches('\'').parse().unwrap()), s) + } + test_type( + "DATE", + &[ + make_check("'1970-01-01'"), + make_check("'1965-09-25'"), + make_check("'2010-02-09'"), + (None, "NULL"), + ], + ) + .await; +} + +#[tokio::test] +async fn test_with_special_date_params() { + fn make_check(s: &str) -> (Date, &str) { + (Date::Value(s.trim_matches('\'').parse().unwrap()), s) + } + test_type( + "DATE", + &[ + make_check("'1970-01-01'"), + make_check("'1965-09-25'"), + make_check("'2010-02-09'"), + (Date::PosInfinity, "'infinity'"), + (Date::NegInfinity, "'-infinity'"), + ], + ) + .await; +} + +#[tokio::test] +async fn test_time_params() { + fn make_check(s: &str) -> (Option