diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 00000000..739c6fe9 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,30 @@ +--- +name: Bug report +about: Create a bug report for ndarray-stats +title: '' +labels: '' +assignees: '' + +--- + +**Description** +Description of the bug. + +**Version Information** +- `ndarray`: ??? +- `ndarray-stats`: ??? +- Rust: ??? + +Please make sure that: +- the version of `ndarray-stats` you're using corresponds to the version of `ndarray` you're using +- the version of the Rust compiler you're using is supported by the version of `ndarray-stats` you're using +(See the "Releases" section of the README for correct version information.) + +**To Reproduce** +Example code which doesn't work. + +**Expected behavior** +Description of what you expected to happen. + +**Additional context** +Add any other context about the problem here. diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..7555a6ce --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,97 @@ +name: Continuous integration + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +env: + CARGO_TERM_COLOR: always + RUSTFLAGS: "-D warnings" + +jobs: + + test: + runs-on: ubuntu-latest + strategy: + matrix: + rust: + - stable + - beta + - nightly + - 1.64.0 # MSRV + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ matrix.rust }} + - name: Build + run: cargo build --verbose + - name: Run tests + run: cargo test --verbose + + cross_test: + runs-on: ubuntu-latest + strategy: + matrix: + include: + # 64-bit, big-endian + - rust: stable + target: s390x-unknown-linux-gnu + # 32-bit, little-endian + - rust: stable + target: i686-unknown-linux-gnu + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ matrix.rust }} + target: ${{ matrix.target }} + - name: Install cross + run: cargo install cross -f + - name: Build + run: cross build --verbose --target=${{ matrix.target }} + - name: Run tests + run: cross test --verbose --target=${{ matrix.target }} + + format: + runs-on: ubuntu-latest + strategy: + matrix: + rust: + - stable + steps: + - uses: actions/checkout@v2 + - uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: ${{ matrix.rust }} + override: true + components: rustfmt + - name: Rustfmt + run: cargo fmt -- --check + + coverage: + runs-on: ubuntu-latest + strategy: + matrix: + rust: + - nightly + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ matrix.rust }} + - name: Install tarpaulin + uses: taiki-e/cache-cargo-install-action@v2 + with: + tool: cargo-tarpaulin + - name: Generate code coverage + run: cargo tarpaulin --verbose --all-features --workspace --timeout 120 --out Xml + - name: Upload to codecov.io + uses: codecov/codecov-action@v4 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + with: + fail_ci_if_error: true diff --git a/.github/workflows/latest-deps.yml b/.github/workflows/latest-deps.yml new file mode 100644 index 00000000..79fc81b1 --- /dev/null +++ b/.github/workflows/latest-deps.yml @@ -0,0 +1,67 @@ +name: Check Latest Dependencies +on: + schedule: + # Chosen so that it runs right before the international date line experiences the weekend. + # Since we're open source, that means globally we should be aware of it right when we have the most + # time to fix it. 
+ # + # Sorry if this ruins your weekend, future maintainer... + - cron: '0 12 * * FRI' + workflow_dispatch: # For running manually + pull_request: + paths: + - '.github/workflows/latest-deps.yaml' + +env: + CARGO_TERM_COLOR: always + HOST: x86_64-unknown-linux-gnu + RUSTFLAGS: "-D warnings" + MSRV: 1.64.0 + +jobs: + latest_deps_stable: + runs-on: ubuntu-latest + name: Check Latest Dependencies on Stable + steps: + - name: Check Out Repo + uses: actions/checkout@v4 + - name: Install Rust + uses: dtolnay/rust-toolchain@master + with: + toolchain: stable + - name: Setup Mold Linker + uses: rui314/setup-mold@v1 + - name: Setup Rust Cache + uses: Swatinem/rust-cache@v2 + - name: Install nextest + uses: taiki-e/install-action@nextest + - name: Ensure latest dependencies + run: cargo update + - name: Run Tests + run: cargo nextest run + + latest_deps_msrv: + runs-on: ubuntu-latest + name: Check Latest Dependencies on MSRV + steps: + - name: Check Out Repo + uses: actions/checkout@v4 + - name: Install Stable Rust for Update + uses: dtolnay/rust-toolchain@master + with: + toolchain: stable + - name: Setup Mold Linker + uses: rui314/setup-mold@v1 + - name: Setup Rust Cache + uses: Swatinem/rust-cache@v2 + - name: Install nextest + uses: taiki-e/install-action@nextest + - name: Ensure latest dependencies + # The difference is here between this and `latest_deps_stable` + run: CARGO_RESOLVER_INCOMPATIBLE_RUST_VERSIONS="fallback" cargo update + - name: Install MSRV Rust for Test + uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ env.MSRV }} + - name: Run Tests + run: cargo nextest run \ No newline at end of file diff --git a/.gitignore b/.gitignore index 91a0d835..e31f4a3d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,5 @@ /target **/*.rs.bk -Cargo.lock # IDE-related tags diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 6e1a620a..00000000 --- a/.travis.yml +++ /dev/null @@ -1,35 +0,0 @@ -language: rust -sudo: required -dist: trusty -addons: - apt: - packages: - - libssl-dev -cache: cargo -rust: - - 1.34.0 - - stable - - beta - - nightly -matrix: - allow_failures: - - rust: nightly -before_cache: | - if [[ "$TRAVIS_RUST_VERSION" == nightly ]]; then - RUSTFLAGS="--cfg procmacro2_semver_exempt" cargo install cargo-tarpaulin - fi -before_script: -- rustup component add rustfmt -# As a result of https://github.com/travis-ci/travis-ci/issues/1066, we run -# everything in one large command instead of multiple commands. -# In this way, the build stops immediately if one of the commands fails. -script: | - cargo clean && - cargo fmt --all -- --check && - cargo build && - cargo test -after_success: | - if [[ "$TRAVIS_RUST_VERSION" == nightly ]]; then - cargo tarpaulin --out Xml - bash <(curl -s https://codecov.io/bash) - fi diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 00000000..8b99b1e2 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,923 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 3 + +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + +[[package]] +name = "anstyle" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15c4c2c83f81532e5845a733998b6971faca23490340a418e9b72a3ec9de12ea" + +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "bumpalo" +version = "3.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" + +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + +[[package]] +name = "clap" +version = "4.3.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb690e81c7840c0d7aade59f242ea3b41b9bc27bcd5997890e7702ae4b32e487" +dependencies = [ + "clap_builder", +] + +[[package]] +name = "clap_builder" +version = "4.3.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ed2e96bc16d8d740f6f48d663eddf4b8a0983e79210fd55479b7bcd0a69860e" +dependencies = [ + "anstyle", + "clap_lex", +] + +[[package]] +name = "clap_lex" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b" + +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools 0.10.5", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name 
= "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + +[[package]] +name = "crunchy" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" + +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "getrandom" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.9.0+wasi-snapshot-preview1", +] + +[[package]] +name = "getrandom" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.11.1+wasi-snapshot-preview1", +] + +[[package]] +name = "half" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02b4af3693f1b705df946e9fe5631932443781d0aabb423b62fcd4d73f6d2fd0" +dependencies = [ + "crunchy", +] + +[[package]] +name = "hashbrown" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e087f84d4f86bf4b218b927129862374b72199ae7d8657835f1e89000eea4fb" + +[[package]] +name = "hermit-abi" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" + +[[package]] +name = "indexmap" +version = "2.11.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b0f83760fb341a774ed326568e19f5a863af4a952def8c39f9ab92fd95b88e5" +dependencies = [ + "equivalent", + "hashbrown", +] + +[[package]] +name = "is-terminal" +version = "0.4.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" +dependencies = [ + "hermit-abi", + "libc", + "windows-sys", +] + +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.13.0" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" + +[[package]] +name = "js-sys" +version = "0.3.82" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b011eec8cc36da2aab2d5cff675ec18454fad408585853910a202391cf9f8e65" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "libc" +version = "0.2.177" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976" + +[[package]] +name = "libm" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" + +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "memchr" +version = "2.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" + +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "approx", + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + +[[package]] +name = "ndarray-rand" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f093b3db6fd194718dcdeea6bd8c829417deae904e3fcc7732dabcd4416d25d8" +dependencies = [ + "ndarray", + "rand 0.8.5", + "rand_distr", +] + +[[package]] +name = "ndarray-stats" +version = "0.6.0" +dependencies = [ + "approx", + "criterion", + "indexmap", + "itertools 0.13.0", + "ndarray", + "ndarray-rand", + "noisy_float", + "num-bigint", + "num-integer", + "num-traits", + "quickcheck", + "quickcheck_macros", + "rand 0.8.5", +] + +[[package]] +name = "noisy_float" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978fe6e6ebc0bf53de533cd456ca2d9de13de13856eda1518a285d7705a213af" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "once_cell" +version = "1.20.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e" + +[[package]] +name = "oorandom" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + +[[package]] +name = "plotters" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a" + +[[package]] +name = "plotters-svg" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670" +dependencies = [ + "plotters-backend", +] + +[[package]] +name = "portable-atomic" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" + +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "proc-macro2" +version = "1.0.103" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quickcheck" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a44883e74aa97ad63db83c4bf8ca490f02b2fc02f92575e720c8551e843c945f" +dependencies = [ + "rand 0.7.3", + "rand_core 0.5.1", +] + +[[package]] +name = "quickcheck_macros" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f71ee38b42f8459a88d3362be6f9b841ad2d5421844f61eb1c59c11bff3ac14a" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "quote" +version = "1.0.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce25767e7b499d1b604768e7cde645d14cc8584231ea6b295e9c9eb22c02e1d1" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" +dependencies = [ + "getrandom 0.1.16", + "libc", + "rand_chacha 0.2.2", + "rand_core 0.5.1", + "rand_hc", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name 
= "rand_chacha" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402" +dependencies = [ + "ppv-lite86", + "rand_core 0.5.1", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_core" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" +dependencies = [ + "getrandom 0.1.16", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.16", +] + +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand 0.8.5", +] + +[[package]] +name = "rand_hc" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" +dependencies = [ + "rand_core 0.5.1", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "regex" +version = "1.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebee201405406dbf528b8b672104ae6d6d63e6d118cb10e4d51abbc7b58044ff" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59b23e92ee4318893fa3fe3e6fb365258efbfe6ac6ab30f090cdcbb7aa37efa9" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "ryu" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ 
+ "winapi-util", +] + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.145" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", + "serde_core", +] + +[[package]] +name = "syn" +version = "2.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "unicode-ident" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" + +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "wasi" +version = "0.9.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasm-bindgen" +version = "0.2.105" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da95793dfc411fbbd93f5be7715b0578ec61fe87cb1a42b12eb625caa5c5ea60" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.105" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04264334509e04a7bf8690f2384ef5265f05143a4bff3889ab7a3269adab59c2" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.105" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "420bc339d9f322e562942d52e115d57e950d12d88983a14c79b86859ee6c7ebc" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.105" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"76f218a38c84bcb33c25ec7059b07847d465ce0e0a76b995e134a45adcb6af76" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "web-sys" +version = "0.3.82" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a1f95c0d03a47f4ae1f7a64643a6bb97465d9b740f0fa8f90ea33915c99a9a1" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "winapi-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.53.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" +dependencies = [ + "windows-link", + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" + +[[package]] +name = "windows_i686_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" + +[[package]] +name = "windows_i686_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" + +[[package]] +name = "zerocopy" +version = "0.8.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0894878a5fa3edfd6da3f88c4805f4c8558e2b996227a3d864f47fe11e38282c" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/Cargo.toml b/Cargo.toml index 7e135862..9377d81e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,12 +1,16 @@ [package] name = "ndarray-stats" -version = "0.2.0" -authors = ["Jim Turner ", "LukeMathWalker "] +version = "0.6.0" +authors = [ + "Jim Turner ", + "LukeMathWalker ", +] edition = "2018" +rust-version = "1.64.0" license = "MIT/Apache-2.0" -repository = "https://github.com/jturner314/ndarray-stats" +repository = "https://github.com/rust-ndarray/ndarray-stats" documentation = "https://docs.rs/ndarray-stats/" readme = "README.md" @@ -16,21 +20,31 @@ keywords = ["array", "multidimensional", "statistics", "matrix", "ndarray"] categories = ["data-structures", "science"] [dependencies] -ndarray = "0.12.1" -noisy_float = "0.1.8" +ndarray = "0.16.0" +noisy_float = "0.2.0" num-integer = "0.1" num-traits = "0.2" -rand = "0.6" -itertools = { version = "0.8.0", default-features = false } -indexmap = "1.0" +rand = "0.8.3" +itertools = { version = "0.13", default-features = false } +indexmap = "2.4" [dev-dependencies] -criterion = "0.2" -quickcheck = { version = "0.8.1", default-features = false } -ndarray-rand = "0.9" -approx = "0.3" -quickcheck_macros = "0.8" +ndarray = { version = "0.16.1", features = ["approx"] } +criterion = "0.5.1" +quickcheck = { version = "0.9.2", default-features = false } +ndarray-rand = "0.15.0" +approx = "0.5" +quickcheck_macros = "1.0.0" +num-bigint = "0.4.0" [[bench]] name = "sort" harness = false + +[[bench]] +name = "summary_statistics" +harness = false + +[[bench]] +name = "deviation" +harness = false diff --git a/LICENSE-MIT b/LICENSE-MIT index 02f8c44f..c750fecc 100644 --- a/LICENSE-MIT +++ b/LICENSE-MIT @@ -1,4 +1,4 @@ -Copyright 2018 ndarray-stats developers +Copyright 2018–2024 ndarray-stats developers Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in diff --git a/README.md b/README.md index 54fa5b3d..3a565da8 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,7 @@ # ndarray-stats -[![Build status](https://travis-ci.org/jturner314/ndarray-stats.svg?branch=master)](https://travis-ci.org/jturner314/ndarray-stats) -[![Coverage](https://codecov.io/gh/jturner314/ndarray-stats/branch/master/graph/badge.svg)](https://codecov.io/gh/jturner314/ndarray-stats) -[![Dependencies status](https://deps.rs/repo/github/jturner314/ndarray-stats/status.svg)](https://deps.rs/repo/github/jturner314/ndarray-stats) +[![Coverage](https://codecov.io/gh/rust-ndarray/ndarray-stats/branch/master/graph/badge.svg)](https://codecov.io/gh/rust-ndarray/ndarray-stats) +[![Dependencies status](https://deps.rs/repo/github/rust-ndarray/ndarray-stats/status.svg)](https://deps.rs/repo/github/rust-ndarray/ndarray-stats) [![Crate](https://img.shields.io/crates/v/ndarray-stats.svg)](https://crates.io/crates/ndarray-stats) [![Documentation](https://docs.rs/ndarray-stats/badge.svg)](https://docs.rs/ndarray-stats) @@ -14,11 +13,12 @@ Currently available routines include: - partitioning; - correlation analysis (covariance, pearson correlation); - measures from information theory (entropy, KL divergence, etc.); +- deviation functions (distances, counts, errors, etc.); - histogram computation. See the [documentation](https://docs.rs/ndarray-stats) for more information. -Please feel free to contribute new functionality! 
A roadmap can be found [here](https://github.com/jturner314/ndarray-stats/issues/1). +Please feel free to contribute new functionality! A roadmap can be found [here](https://github.com/rust-ndarray/ndarray-stats/issues/1). [`ndarray`]: https://github.com/rust-ndarray/ndarray @@ -26,14 +26,79 @@ Please feel free to contribute new functionality! A roadmap can be found [here]( ```toml [dependencies] -ndarray = "0.12.1" -ndarray-stats = "0.2" +ndarray = "0.16" +ndarray-stats = "0.6.0" ``` ## Releases +* **0.6.0** + + * Breaking changes + * Minimum supported Rust version: `1.64.0` + * Updated to `ndarray:v0.16.0` + * Updated to `approx:v0.5.0` + + * Updated to `ndarray-rand:v0.15.0` + * Updated to `indexmap:v2.4` + * Updated to `itertools:v0.13` + + *Contributors*: [@bluss](https://github.com/bluss) + +* **0.5.1** + * Fixed bug in implementation of `MaybeNaN::remove_nan_mut` for `f32` and + `f64` for views with non-standard layouts. Before this fix, the bug could + cause incorrect results, buffer overflows, etc., in this method and others + which use it. Thanks to [@JacekCzupyt](https://github.com/JacekCzupyt) for + reporting the issue (#89). + * Minor docs improvements. + + *Contributors*: [@jturner314](https://github.com/jturner314), [@BenMoon](https://github.com/BenMoon) + +* **0.5.0** + * Breaking changes + * Minimum supported Rust version: `1.49.0` + * Updated to `ndarray:v0.15.0` + + *Contributors*: [@Armavica](https://github.com/armavica), [@cassiersg](https://github.com/cassiersg) + +* **0.4.0** + * Breaking changes + * Minimum supported Rust version: `1.42.0` + * New functionality: + * Summary statistics: + * Weighted variance + * Weighted standard deviation + * Improvements / breaking changes: + * Documentation improvements for Histograms + * Updated to `ndarray:v0.14.0` + + *Contributors*: [@munckymagik](https://github.com/munckymagik), [@nilgoyette](https://github.com/nilgoyette), [@LukeMathWalker](https://github.com/LukeMathWalker), [@lebensterben](https://github.com/lebensterben), [@xd009642](https://github.com/xd009642) + +* **0.3.0** + + * Breaking changes + * Minimum supported Rust version: `1.37` + * New functionality: + * Deviation functions: + * Counts equal/unequal + * `l1`, `l2`, `linf` distances + * (Root) mean squared error + * Peak signal-to-noise ratio + * Summary statistics: + * Weighted sum + * Weighted mean + * Improvements / breaking changes: + * Updated to `ndarray:v0.13.0` + + *Contributors*: [@munckymagik](https://github.com/munckymagik), [@nilgoyette](https://github.com/nilgoyette), [@jturner314](https://github.com/jturner314), [@LukeMathWalker](https://github.com/LukeMathWalker) + * **0.2.0** + * Breaking changes + * All `ndarray-stats`' extension traits are now impossible to implement by + users of the library (see [#34]) + * Redesigned error handling across the whole crate, standardising on `Result` * New functionality: * Summary statistics: * Harmonic mean @@ -51,14 +116,10 @@ ndarray-stats = "0.2" * Optimized bulk quantile computation (`quantiles_mut`, `quantiles_axis_mut`) * Fixes: * Reduced occurrences of overflow for `interpolate::midpoint` - * Improvements / breaking changes: - * Redesigned error handling across the whole crate, standardising on `Result` - * All `ndarray-stats`' extension traits are now impossible to implement by - users of the library (see [#34]) *Contributors*: [@jturner314](https://github.com/jturner314), [@LukeMathWalker](https://github.com/LukeMathWalker), [@phungleson](https://github.com/phungleson), 
[@munckymagik](https://github.com/munckymagik) - [#34]: https://github.com/jturner314/ndarray-stats/issues/34 + [#34]: https://github.com/rust-ndarray/ndarray-stats/issues/34 * **0.1.0** @@ -70,7 +131,7 @@ Please feel free to create issues and submit PRs. ## License -Copyright 2018 `ndarray-stats` developers +Copyright 2018–2024 `ndarray-stats` developers Licensed under the [Apache License, Version 2.0](LICENSE-APACHE), or the [MIT license](LICENSE-MIT), at your option. You may not use this project except in diff --git a/benches/deviation.rs b/benches/deviation.rs new file mode 100644 index 00000000..c0ceecb5 --- /dev/null +++ b/benches/deviation.rs @@ -0,0 +1,29 @@ +use criterion::{ + black_box, criterion_group, criterion_main, AxisScale, Criterion, PlotConfiguration, +}; +use ndarray::prelude::*; +use ndarray_rand::rand_distr::Uniform; +use ndarray_rand::RandomExt; +use ndarray_stats::DeviationExt; + +fn sq_l2_dist(c: &mut Criterion) { + let lens = vec![10, 100, 1000, 10000]; + let mut group = c.benchmark_group("sq_l2_dist"); + group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); + for len in &lens { + group.bench_with_input(format!("{}", len), len, |b, &len| { + let data = Array::random(len, Uniform::new(0.0, 1.0)); + let data2 = Array::random(len, Uniform::new(0.0, 1.0)); + + b.iter(|| black_box(data.sq_l2_dist(&data2).unwrap())) + }); + } + group.finish(); +} + +criterion_group! { + name = benches; + config = Criterion::default(); + targets = sq_l2_dist +} +criterion_main!(benches); diff --git a/benches/sort.rs b/benches/sort.rs index 03bce688..1a2f4429 100644 --- a/benches/sort.rs +++ b/benches/sort.rs @@ -1,6 +1,5 @@ use criterion::{ - black_box, criterion_group, criterion_main, AxisScale, BatchSize, Criterion, - ParameterizedBenchmark, PlotConfiguration, + black_box, criterion_group, criterion_main, AxisScale, BatchSize, Criterion, PlotConfiguration, }; use ndarray::prelude::*; use ndarray_stats::Sort1dExt; @@ -8,14 +7,15 @@ use rand::prelude::*; fn get_from_sorted_mut(c: &mut Criterion) { let lens = vec![10, 100, 1000, 10000]; - let benchmark = ParameterizedBenchmark::new( - "get_from_sorted_mut", - |bencher, &len| { + let mut group = c.benchmark_group("get_from_sorted_mut"); + group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); + for len in &lens { + group.bench_with_input(format!("{}", len), len, |b, &len| { let mut rng = StdRng::seed_from_u64(42); let mut data: Vec<_> = (0..len).collect(); data.shuffle(&mut rng); let indices: Vec<_> = (0..len).step_by(len / 10).collect(); - bencher.iter_batched( + b.iter_batched( || Array1::from(data.clone()), |mut arr| { for &i in &indices { @@ -24,34 +24,31 @@ fn get_from_sorted_mut(c: &mut Criterion) { }, BatchSize::SmallInput, ) - }, - lens, - ) - .plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); - c.bench("get_from_sorted_mut", benchmark); + }); + } + group.finish(); } fn get_many_from_sorted_mut(c: &mut Criterion) { let lens = vec![10, 100, 1000, 10000]; - let benchmark = ParameterizedBenchmark::new( - "get_many_from_sorted_mut", - |bencher, &len| { + let mut group = c.benchmark_group("get_many_from_sorted_mut"); + group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); + for len in &lens { + group.bench_with_input(format!("{}", len), len, |b, &len| { let mut rng = StdRng::seed_from_u64(42); let mut data: Vec<_> = (0..len).collect(); data.shuffle(&mut rng); let indices: Array1<_> = (0..len).step_by(len / 
10).collect(); - bencher.iter_batched( + b.iter_batched( || Array1::from(data.clone()), |mut arr| { black_box(arr.get_many_from_sorted_mut(&indices)); }, BatchSize::SmallInput, ) - }, - lens, - ) - .plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); - c.bench("get_many_from_sorted_mut", benchmark); + }); + } + group.finish(); } criterion_group! { diff --git a/benches/summary_statistics.rs b/benches/summary_statistics.rs new file mode 100644 index 00000000..5796fc02 --- /dev/null +++ b/benches/summary_statistics.rs @@ -0,0 +1,35 @@ +use criterion::{ + black_box, criterion_group, criterion_main, AxisScale, BatchSize, Criterion, PlotConfiguration, +}; +use ndarray::prelude::*; +use ndarray_rand::rand_distr::Uniform; +use ndarray_rand::RandomExt; +use ndarray_stats::SummaryStatisticsExt; + +fn weighted_std(c: &mut Criterion) { + let lens = vec![10, 100, 1000, 10000]; + let mut group = c.benchmark_group("weighted_std"); + group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); + for len in &lens { + group.bench_with_input(format!("{}", len), len, |b, &len| { + let data = Array::random(len, Uniform::new(0.0, 1.0)); + let mut weights = Array::random(len, Uniform::new(0.0, 1.0)); + weights /= weights.sum(); + b.iter_batched( + || data.clone(), + |arr| { + black_box(arr.weighted_std(&weights, 0.0).unwrap()); + }, + BatchSize::SmallInput, + ) + }); + } + group.finish(); +} + +criterion_group! { + name = benches; + config = Criterion::default(); + targets = weighted_std +} +criterion_main!(benches); diff --git a/src/correlation.rs b/src/correlation.rs index 9985ad87..5ae194ba 100644 --- a/src/correlation.rs +++ b/src/correlation.rs @@ -1,3 +1,4 @@ +use crate::errors::EmptyInput; use ndarray::prelude::*; use ndarray::Data; use num_traits::{Float, FromPrimitive}; @@ -41,10 +42,10 @@ where /// ``` /// and similarly for ̅y. /// - /// **Panics** if `ddof` is greater than or equal to the number of - /// observations, if the number of observations is zero and division by - /// zero panics for type `A`, or if the type cast of `n_observations` from - /// `usize` to `A` fails. + /// If `M` is empty (either zero observations or zero random variables), it returns `Err(EmptyInput)`. + /// + /// **Panics** if `ddof` is negative or greater than or equal to the number of + /// observations, or if the type cast of `n_observations` from `usize` to `A` fails. /// /// # Example /// @@ -54,13 +55,13 @@ where /// /// let a = arr2(&[[1., 3., 5.], /// [2., 4., 6.]]); - /// let covariance = a.cov(1.); + /// let covariance = a.cov(1.).unwrap(); /// assert_eq!( /// covariance, /// aview2(&[[4., 4.], [4., 4.]]) /// ); /// ``` - fn cov(&self, ddof: A) -> Array2 + fn cov(&self, ddof: A) -> Result, EmptyInput> where A: Float + FromPrimitive; @@ -89,30 +90,35 @@ where /// R_ij = rho(X_i, X_j) /// ``` /// - /// **Panics** if `M` is empty, if the type cast of `n_observations` - /// from `usize` to `A` fails or if the standard deviation of one of the random + /// If `M` is empty (either zero observations or zero random variables), it returns `Err(EmptyInput)`. + /// + /// **Panics** if the type cast of `n_observations` from `usize` to `A` fails or + /// if the standard deviation of one of the random variables is zero and + /// division by zero panics for type A. /// /// # Example /// - /// variables is zero and division by zero panics for type A. 
/// ``` + /// use approx; /// use ndarray::arr2; /// use ndarray_stats::CorrelationExt; + /// use approx::AbsDiffEq; /// /// let a = arr2(&[[1., 3., 5.], /// [2., 4., 6.]]); - /// let corr = a.pearson_correlation(); + /// let corr = a.pearson_correlation().unwrap(); + /// let epsilon = 1e-7; /// assert!( - /// corr.all_close( + /// corr.abs_diff_eq( /// &arr2(&[ /// [1., 1.], /// [1., 1.], /// ]), - /// 1e-7 + /// epsilon /// ) /// ); /// ``` - fn pearson_correlation(&self) -> Array2 + fn pearson_correlation(&self) -> Result, EmptyInput> where A: Float + FromPrimitive; @@ -123,7 +129,7 @@ impl CorrelationExt for ArrayBase where S: Data, { - fn cov(&self, ddof: A) -> Array2 + fn cov(&self, ddof: A) -> Result, EmptyInput> where A: Float + FromPrimitive, { @@ -139,28 +145,37 @@ where n_observations - ddof }; let mean = self.mean_axis(observation_axis); - let denoised = self - &mean.insert_axis(observation_axis); - let covariance = denoised.dot(&denoised.t()); - covariance.mapv_into(|x| x / dof) + match mean { + Some(mean) => { + let denoised = self - &mean.insert_axis(observation_axis); + let covariance = denoised.dot(&denoised.t()); + Ok(covariance.mapv_into(|x| x / dof)) + } + None => Err(EmptyInput), + } } - fn pearson_correlation(&self) -> Array2 + fn pearson_correlation(&self) -> Result, EmptyInput> where A: Float + FromPrimitive, { - let observation_axis = Axis(1); - // The ddof value doesn't matter, as long as we use the same one - // for computing covariance and standard deviation - // We choose -1 to avoid panicking when we only have one - // observation per random variable (or no observations at all) - let ddof = -A::one(); - let cov = self.cov(ddof); - let std = self - .std_axis(observation_axis, ddof) - .insert_axis(observation_axis); - let std_matrix = std.dot(&std.t()); - // element-wise division - cov / std_matrix + match self.dim() { + (n, m) if n > 0 && m > 0 => { + let observation_axis = Axis(1); + // The ddof value doesn't matter, as long as we use the same one + // for computing covariance and standard deviation + // We choose 0 as it is the smallest number admitted by std_axis + let ddof = A::zero(); + let cov = self.cov(ddof).unwrap(); + let std = self + .std_axis(observation_axis, ddof) + .insert_axis(observation_axis); + let std_matrix = std.dot(&std.t()); + // element-wise division + Ok(cov / std_matrix) + } + _ => Err(EmptyInput), + } } private_impl! 
{} @@ -170,19 +185,20 @@ where mod cov_tests { use super::*; use ndarray::array; + use ndarray_rand::rand; + use ndarray_rand::rand_distr::Uniform; use ndarray_rand::RandomExt; use quickcheck_macros::quickcheck; - use rand; - use rand::distributions::Uniform; #[quickcheck] fn constant_random_variables_have_zero_covariance_matrix(value: f64) -> bool { let n_random_variables = 3; let n_observations = 4; let a = Array::from_elem((n_random_variables, n_observations), value); - a.cov(1.).all_close( + abs_diff_eq!( + a.cov(1.).unwrap(), &Array::zeros((n_random_variables, n_random_variables)), - 1e-8, + epsilon = 1e-8, ) } @@ -194,8 +210,8 @@ mod cov_tests { (n_random_variables, n_observations), Uniform::new(-bound.abs(), bound.abs()), ); - let covariance = a.cov(1.); - covariance.all_close(&covariance.t(), 1e-8) + let covariance = a.cov(1.).unwrap(); + abs_diff_eq!(covariance, &covariance.t(), epsilon = 1e-8) } #[test] @@ -205,14 +221,15 @@ mod cov_tests { let n_observations = 4; let a = Array::random((n_random_variables, n_observations), Uniform::new(0., 10.)); let invalid_ddof = (n_observations as f64) + rand::random::().abs(); - a.cov(invalid_ddof); + let _ = a.cov(invalid_ddof); } #[test] fn test_covariance_zero_variables() { let a = Array2::::zeros((0, 2)); let cov = a.cov(1.); - assert_eq!(cov.shape(), &[0, 0]); + assert!(cov.is_ok()); + assert_eq!(cov.unwrap().shape(), &[0, 0]); } #[test] @@ -220,8 +237,7 @@ mod cov_tests { let a = Array2::::zeros((2, 0)); // Negative ddof (-1 < 0) to avoid invalid-ddof panic let cov = a.cov(-1.); - assert_eq!(cov.shape(), &[2, 2]); - cov.mapv(|x| assert_eq!(x, 0.)); + assert_eq!(cov, Err(EmptyInput)); } #[test] @@ -229,7 +245,7 @@ mod cov_tests { let a = Array2::::zeros((0, 0)); // Negative ddof (-1 < 0) to avoid invalid-ddof panic let cov = a.cov(-1.); - assert_eq!(cov.shape(), &[0, 0]); + assert_eq!(cov, Err(EmptyInput)); } #[test] @@ -255,7 +271,7 @@ mod cov_tests { ] ]; assert_eq!(a.ndim(), 2); - assert!(a.cov(1.).all_close(&numpy_covariance, 1e-8)); + assert_abs_diff_eq!(a.cov(1.).unwrap(), &numpy_covariance, epsilon = 1e-8); } #[test] @@ -264,7 +280,7 @@ mod cov_tests { fn test_covariance_for_badly_conditioned_array() { let a: Array2 = array![[1e12 + 1., 1e12 - 1.], [1e-6 + 1e-12, 1e-6 - 1e-12],]; let expected_covariance = array![[2., 2e-12], [2e-12, 2e-24]]; - assert!(a.cov(1.).all_close(&expected_covariance, 1e-24)); + assert_abs_diff_eq!(a.cov(1.).unwrap(), &expected_covariance, epsilon = 1e-24); } } @@ -272,9 +288,10 @@ mod cov_tests { mod pearson_correlation_tests { use super::*; use ndarray::array; + use ndarray::Array; + use ndarray_rand::rand_distr::Uniform; use ndarray_rand::RandomExt; use quickcheck_macros::quickcheck; - use rand::distributions::Uniform; #[quickcheck] fn output_matrix_is_symmetric(bound: f64) -> bool { @@ -284,8 +301,12 @@ mod pearson_correlation_tests { (n_random_variables, n_observations), Uniform::new(-bound.abs(), bound.abs()), ); - let pearson_correlation = a.pearson_correlation(); - pearson_correlation.all_close(&pearson_correlation.t(), 1e-8) + let pearson_correlation = a.pearson_correlation().unwrap(); + abs_diff_eq!( + pearson_correlation.view(), + pearson_correlation.t(), + epsilon = 1e-8 + ) } #[quickcheck] @@ -295,6 +316,7 @@ mod pearson_correlation_tests { let a = Array::from_elem((n_random_variables, n_observations), value); let pearson_correlation = a.pearson_correlation(); pearson_correlation + .unwrap() .iter() .map(|x| x.is_nan()) .fold(true, |acc, flag| acc & flag) @@ -304,21 +326,21 @@ mod 
pearson_correlation_tests { fn test_zero_variables() { let a = Array2::::zeros((0, 2)); let pearson_correlation = a.pearson_correlation(); - assert_eq!(pearson_correlation.shape(), &[0, 0]); + assert_eq!(pearson_correlation, Err(EmptyInput)) } #[test] fn test_zero_observations() { let a = Array2::::zeros((2, 0)); let pearson = a.pearson_correlation(); - pearson.mapv(|x| x.is_nan()); + assert_eq!(pearson, Err(EmptyInput)); } #[test] fn test_zero_variables_zero_observations() { let a = Array2::::zeros((0, 0)); let pearson = a.pearson_correlation(); - assert_eq!(pearson.shape(), &[0, 0]); + assert_eq!(pearson, Err(EmptyInput)); } #[test] @@ -338,7 +360,10 @@ mod pearson_correlation_tests { [0.1365648, 0.38954398, -0.17324776, -0.8743213, 1.] ]; assert_eq!(a.ndim(), 2); - assert!(a.pearson_correlation().all_close(&numpy_corrcoeff, 1e-7)); + assert_abs_diff_eq!( + a.pearson_correlation().unwrap(), + numpy_corrcoeff, + epsilon = 1e-7 + ); } - } diff --git a/src/deviation.rs b/src/deviation.rs new file mode 100644 index 00000000..de85885f --- /dev/null +++ b/src/deviation.rs @@ -0,0 +1,377 @@ +use ndarray::{ArrayBase, Data, Dimension, Zip}; +use num_traits::{Signed, ToPrimitive}; +use std::convert::Into; +use std::ops::AddAssign; + +use crate::errors::MultiInputError; + +/// An extension trait for `ArrayBase` providing functions +/// to compute different deviation measures. +pub trait DeviationExt +where + S: Data, + D: Dimension, +{ + /// Counts the number of indices at which the elements of the arrays `self` + /// and `other` are equal. + /// + /// The following **errors** may be returned: + /// + /// * `MultiInputError::EmptyInput` if `self` is empty + /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape + fn count_eq(&self, other: &ArrayBase) -> Result + where + A: PartialEq, + T: Data; + + /// Counts the number of indices at which the elements of the arrays `self` + /// and `other` are not equal. + /// + /// The following **errors** may be returned: + /// + /// * `MultiInputError::EmptyInput` if `self` is empty + /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape + fn count_neq(&self, other: &ArrayBase) -> Result + where + A: PartialEq, + T: Data; + + /// Computes the [squared L2 distance] between `self` and `other`. + /// + /// ```text + /// n + /// ∑ |aᵢ - bᵢ|² + /// i=1 + /// ``` + /// + /// where `self` is `a` and `other` is `b`. + /// + /// The following **errors** may be returned: + /// + /// * `MultiInputError::EmptyInput` if `self` is empty + /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape + /// + /// [squared L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance#Squared_Euclidean_distance + fn sq_l2_dist(&self, other: &ArrayBase) -> Result + where + A: AddAssign + Clone + Signed, + T: Data; + + /// Computes the [L2 distance] between `self` and `other`. + /// + /// ```text + /// n + /// √ ( ∑ |aᵢ - bᵢ|² ) + /// i=1 + /// ``` + /// + /// where `self` is `a` and `other` is `b`. + /// + /// The following **errors** may be returned: + /// + /// * `MultiInputError::EmptyInput` if `self` is empty + /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape + /// + /// **Panics** if the type cast from `A` to `f64` fails. 
+ /// + /// [L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance + fn l2_dist(&self, other: &ArrayBase) -> Result + where + A: AddAssign + Clone + Signed + ToPrimitive, + T: Data; + + /// Computes the [L1 distance] between `self` and `other`. + /// + /// ```text + /// n + /// ∑ |aᵢ - bᵢ| + /// i=1 + /// ``` + /// + /// where `self` is `a` and `other` is `b`. + /// + /// The following **errors** may be returned: + /// + /// * `MultiInputError::EmptyInput` if `self` is empty + /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape + /// + /// [L1 distance]: https://en.wikipedia.org/wiki/Taxicab_geometry + fn l1_dist(&self, other: &ArrayBase) -> Result + where + A: AddAssign + Clone + Signed, + T: Data; + + /// Computes the [L∞ distance] between `self` and `other`. + /// + /// ```text + /// max(|aᵢ - bᵢ|) + /// ᵢ + /// ``` + /// + /// where `self` is `a` and `other` is `b`. + /// + /// The following **errors** may be returned: + /// + /// * `MultiInputError::EmptyInput` if `self` is empty + /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape + /// + /// [L∞ distance]: https://en.wikipedia.org/wiki/Chebyshev_distance + fn linf_dist(&self, other: &ArrayBase) -> Result + where + A: Clone + PartialOrd + Signed, + T: Data; + + /// Computes the [mean absolute error] between `self` and `other`. + /// + /// ```text + /// n + /// 1/n * ∑ |aᵢ - bᵢ| + /// i=1 + /// ``` + /// + /// where `self` is `a` and `other` is `b`. + /// + /// The following **errors** may be returned: + /// + /// * `MultiInputError::EmptyInput` if `self` is empty + /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape + /// + /// **Panics** if the type cast from `A` to `f64` fails. + /// + /// [mean absolute error]: https://en.wikipedia.org/wiki/Mean_absolute_error + fn mean_abs_err(&self, other: &ArrayBase) -> Result + where + A: AddAssign + Clone + Signed + ToPrimitive, + T: Data; + + /// Computes the [mean squared error] between `self` and `other`. + /// + /// ```text + /// n + /// 1/n * ∑ |aᵢ - bᵢ|² + /// i=1 + /// ``` + /// + /// where `self` is `a` and `other` is `b`. + /// + /// The following **errors** may be returned: + /// + /// * `MultiInputError::EmptyInput` if `self` is empty + /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape + /// + /// **Panics** if the type cast from `A` to `f64` fails. + /// + /// [mean squared error]: https://en.wikipedia.org/wiki/Mean_squared_error + fn mean_sq_err(&self, other: &ArrayBase) -> Result + where + A: AddAssign + Clone + Signed + ToPrimitive, + T: Data; + + /// Computes the unnormalized [root-mean-square error] between `self` and `other`. + /// + /// ```text + /// √ mse(a, b) + /// ``` + /// + /// where `self` is `a`, `other` is `b` and `mse` is the mean-squared-error. + /// + /// The following **errors** may be returned: + /// + /// * `MultiInputError::EmptyInput` if `self` is empty + /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape + /// + /// **Panics** if the type cast from `A` to `f64` fails. + /// + /// [root-mean-square error]: https://en.wikipedia.org/wiki/Root-mean-square_deviation + fn root_mean_sq_err(&self, other: &ArrayBase) -> Result + where + A: AddAssign + Clone + Signed + ToPrimitive, + T: Data; + + /// Computes the [peak signal-to-noise ratio] between `self` and `other`. 
+ /// + /// ```text + /// 10 * log10(maxv^2 / mse(a, b)) + /// ``` + /// + /// where `self` is `a`, `other` is `b`, `mse` is the mean-squared-error + /// and `maxv` is the maximum possible value either array can take. + /// + /// The following **errors** may be returned: + /// + /// * `MultiInputError::EmptyInput` if `self` is empty + /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape + /// + /// **Panics** if the type cast from `A` to `f64` fails. + /// + /// [peak signal-to-noise ratio]: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + fn peak_signal_to_noise_ratio( + &self, + other: &ArrayBase, + maxv: A, + ) -> Result + where + A: AddAssign + Clone + Signed + ToPrimitive, + T: Data; + + private_decl! {} +} + +impl DeviationExt for ArrayBase +where + S: Data, + D: Dimension, +{ + fn count_eq(&self, other: &ArrayBase) -> Result + where + A: PartialEq, + T: Data, + { + return_err_if_empty!(self); + return_err_unless_same_shape!(self, other); + + let mut count = 0; + + Zip::from(self).and(other).for_each(|a, b| { + if a == b { + count += 1; + } + }); + + Ok(count) + } + + fn count_neq(&self, other: &ArrayBase) -> Result + where + A: PartialEq, + T: Data, + { + self.count_eq(other).map(|n_eq| self.len() - n_eq) + } + + fn sq_l2_dist(&self, other: &ArrayBase) -> Result + where + A: AddAssign + Clone + Signed, + T: Data, + { + return_err_if_empty!(self); + return_err_unless_same_shape!(self, other); + + let mut result = A::zero(); + + Zip::from(self).and(other).for_each(|self_i, other_i| { + let (a, b) = (self_i.clone(), other_i.clone()); + let diff = a - b; + result += diff.clone() * diff; + }); + + Ok(result) + } + + fn l2_dist(&self, other: &ArrayBase) -> Result + where + A: AddAssign + Clone + Signed + ToPrimitive, + T: Data, + { + let sq_l2_dist = self + .sq_l2_dist(other)? + .to_f64() + .expect("failed cast from type A to f64"); + + Ok(sq_l2_dist.sqrt()) + } + + fn l1_dist(&self, other: &ArrayBase) -> Result + where + A: AddAssign + Clone + Signed, + T: Data, + { + return_err_if_empty!(self); + return_err_unless_same_shape!(self, other); + + let mut result = A::zero(); + + Zip::from(self).and(other).for_each(|self_i, other_i| { + let (a, b) = (self_i.clone(), other_i.clone()); + result += (a - b).abs(); + }); + + Ok(result) + } + + fn linf_dist(&self, other: &ArrayBase) -> Result + where + A: Clone + PartialOrd + Signed, + T: Data, + { + return_err_if_empty!(self); + return_err_unless_same_shape!(self, other); + + let mut max = A::zero(); + + Zip::from(self).and(other).for_each(|self_i, other_i| { + let (a, b) = (self_i.clone(), other_i.clone()); + let diff = (a - b).abs(); + if diff > max { + max = diff; + } + }); + + Ok(max) + } + + fn mean_abs_err(&self, other: &ArrayBase) -> Result + where + A: AddAssign + Clone + Signed + ToPrimitive, + T: Data, + { + let l1_dist = self + .l1_dist(other)? + .to_f64() + .expect("failed cast from type A to f64"); + let n = self.len() as f64; + + Ok(l1_dist / n) + } + + fn mean_sq_err(&self, other: &ArrayBase) -> Result + where + A: AddAssign + Clone + Signed + ToPrimitive, + T: Data, + { + let sq_l2_dist = self + .sq_l2_dist(other)? 
+ .to_f64() + .expect("failed cast from type A to f64"); + let n = self.len() as f64; + + Ok(sq_l2_dist / n) + } + + fn root_mean_sq_err(&self, other: &ArrayBase) -> Result + where + A: AddAssign + Clone + Signed + ToPrimitive, + T: Data, + { + let msd = self.mean_sq_err(other)?; + Ok(msd.sqrt()) + } + + fn peak_signal_to_noise_ratio( + &self, + other: &ArrayBase, + maxv: A, + ) -> Result + where + A: AddAssign + Clone + Signed + ToPrimitive, + T: Data, + { + let maxv_f = maxv.to_f64().expect("failed cast from type A to f64"); + let msd = self.mean_sq_err(&other)?; + let psnr = 10. * f64::log10(maxv_f * maxv_f / msd); + + Ok(psnr) + } + + private_impl! {} +} diff --git a/src/entropy.rs b/src/entropy.rs index 77e9d78f..e029729b 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -133,7 +133,7 @@ where where A: Float, { - if self.len() == 0 { + if self.is_empty() { Err(EmptyInput) } else { let entropy = -self @@ -154,7 +154,7 @@ where A: Float, S2: Data, { - if self.len() == 0 { + if self.is_empty() { return Err(MultiInputError::EmptyInput); } if self.shape() != q.shape() { @@ -169,7 +169,7 @@ where Zip::from(&mut temp) .and(self) .and(q) - .apply(|result, &p, &q| { + .for_each(|result, &p, &q| { *result = { if p == A::zero() { A::zero() @@ -187,7 +187,7 @@ where S2: Data, A: Float, { - if self.len() == 0 { + if self.is_empty() { return Err(MultiInputError::EmptyInput); } if self.shape() != q.shape() { @@ -202,7 +202,7 @@ where Zip::from(&mut temp) .and(self) .and(q) - .apply(|result, &p, &q| { + .for_each(|result, &p, &q| { *result = { if p == A::zero() { A::zero() diff --git a/src/errors.rs b/src/errors.rs index 2386a301..e2617f39 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -46,7 +46,7 @@ impl From for MinMaxError { /// An error used by methods and functions that take two arrays as argument and /// expect them to have exactly the same shape /// (e.g. `ShapeMismatch` is raised when `a.shape() == b.shape()` evaluates to `False`). -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub struct ShapeMismatch { pub first_shape: Vec, pub second_shape: Vec, @@ -65,7 +65,7 @@ impl fmt::Display for ShapeMismatch { impl Error for ShapeMismatch {} /// An error for methods that take multiple non-empty array inputs. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub enum MultiInputError { /// One or more of the arrays were empty. EmptyInput, diff --git a/src/histogram/bins.rs b/src/histogram/bins.rs index 887f3e04..f6ff818e 100644 --- a/src/histogram/bins.rs +++ b/src/histogram/bins.rs @@ -1,42 +1,53 @@ +#![warn(missing_docs, clippy::all, clippy::pedantic)] + use ndarray::prelude::*; use std::ops::{Index, Range}; -/// `Edges` is a sorted collection of `A` elements used -/// to represent the boundaries of intervals ([`Bins`]) on -/// a 1-dimensional axis. +/// A sorted collection of type `A` elements used to represent the boundaries of intervals, i.e. +/// [`Bins`] on a 1-dimensional axis. /// -/// [`Bins`]: struct.Bins.html -/// # Example: +/// **Note** that all intervals are left-closed and right-open. See examples below. 
+/// +/// # Examples /// /// ``` -/// use ndarray_stats::histogram::{Edges, Bins}; +/// use ndarray_stats::histogram::{Bins, Edges}; /// use noisy_float::types::n64; /// /// let unit_edges = Edges::from(vec![n64(0.), n64(1.)]); /// let unit_interval = Bins::new(unit_edges); -/// // left inclusive +/// // left-closed /// assert_eq!( /// unit_interval.range_of(&n64(0.)).unwrap(), /// n64(0.)..n64(1.), /// ); -/// // right exclusive +/// // right-open /// assert_eq!( /// unit_interval.range_of(&n64(1.)), /// None /// ); /// ``` +/// +/// [`Bins`]: struct.Bins.html #[derive(Clone, Debug, Eq, PartialEq)] pub struct Edges { edges: Vec, } impl From> for Edges { - /// Get an `Edges` instance from a `Vec`: - /// the vector will be sorted in increasing order - /// using an unstable sorting algorithm and duplicates - /// will be removed. + /// Converts a `Vec` into an `Edges`, consuming the edges. + /// The vector will be sorted in increasing order using an unstable sorting algorithm, with + /// duplicates removed. + /// + /// # Current implementation /// - /// # Example: + /// The current sorting algorithm is the same as [`std::slice::sort_unstable()`][sort], + /// which is based on [pattern-defeating quicksort][pdqsort]. + /// + /// This sort is unstable (i.e., may reorder equal elements), in-place (i.e., does not allocate) + /// , and O(n log n) worst-case. + /// + /// # Examples /// /// ``` /// use ndarray::array; @@ -49,6 +60,9 @@ impl From> for Edges { /// 15 /// ); /// ``` + /// + /// [sort]: https://doc.rust-lang.org/stable/std/primitive.slice.html#method.sort_unstable + /// [pdqsort]: https://github.com/orlp/pdqsort fn from(mut edges: Vec) -> Self { // sort the array in-place edges.sort_unstable(); @@ -59,11 +73,19 @@ impl From> for Edges { } impl From> for Edges { - /// Get an `Edges` instance from a `Array1`: - /// the array elements will be sorted in increasing order - /// using an unstable sorting algorithm and duplicates will be removed. + /// Converts an `Array1` into an `Edges`, consuming the 1-dimensional array. + /// The array will be sorted in increasing order using an unstable sorting algorithm, with + /// duplicates removed. + /// + /// # Current implementation + /// + /// The current sorting algorithm is the same as [`std::slice::sort_unstable()`][sort], + /// which is based on [pattern-defeating quicksort][pdqsort]. + /// + /// This sort is unstable (i.e., may reorder equal elements), in-place (i.e., does not allocate) + /// , and O(n log n) worst-case. /// - /// # Example: + /// # Examples /// /// ``` /// use ndarray_stats::histogram::Edges; @@ -75,6 +97,9 @@ impl From> for Edges { /// 10 /// ); /// ``` + /// + /// [sort]: https://doc.rust-lang.org/stable/std/primitive.slice.html#method.sort_unstable + /// [pdqsort]: https://github.com/orlp/pdqsort fn from(edges: Array1) -> Self { let edges = edges.to_vec(); Self::from(edges) @@ -84,11 +109,13 @@ impl From> for Edges { impl Index for Edges { type Output = A; - /// Get the `i`-th edge. + /// Returns a reference to the `i`-th edge in `self`. + /// + /// # Panics /// - /// **Panics** if the index `i` is out of bounds. + /// Panics if the index `i` is out of bounds. /// - /// # Example: + /// # Examples /// /// ``` /// use ndarray_stats::histogram::Edges; @@ -105,9 +132,9 @@ impl Index for Edges { } impl Edges { - /// Number of edges in `self`. + /// Returns the number of edges in `self`. 
/// - /// # Example: + /// # Examples /// /// ``` /// use ndarray_stats::histogram::Edges; @@ -119,14 +146,33 @@ impl Edges { /// 3 /// ); /// ``` + #[must_use] pub fn len(&self) -> usize { self.edges.len() } - /// Borrow an immutable reference to the edges as a 1-dimensional - /// array view. + /// Returns `true` if `self` contains no edges. + /// + /// # Examples + /// + /// ``` + /// use ndarray_stats::histogram::Edges; + /// use noisy_float::types::{N64, n64}; + /// + /// let edges = Edges::::from(vec![]); + /// assert_eq!(edges.is_empty(), true); + /// + /// let edges = Edges::from(vec![n64(0.), n64(2.), n64(5.)]); + /// assert_eq!(edges.is_empty(), false); + /// ``` + #[must_use] + pub fn is_empty(&self) -> bool { + self.edges.is_empty() + } + + /// Returns an immutable 1-dimensional array view of edges. /// - /// # Example: + /// # Examples /// /// ``` /// use ndarray::array; @@ -138,25 +184,31 @@ impl Edges { /// array![0, 3, 5].view() /// ); /// ``` + #[must_use] pub fn as_array_view(&self) -> ArrayView1<'_, A> { ArrayView1::from(&self.edges) } - /// Given `value`, it returns an option: - /// - `Some((left, right))`, where `right=left+1`, if there are two consecutive edges in - /// `self` such that `self[left] <= value < self[right]`; + /// Returns indices of two consecutive `edges` in `self`, if the interval they represent + /// contains the given `value`, or returns `None` otherwise. + /// + /// That is to say, it returns + /// - `Some((left, right))`, where `left` and `right` are the indices of two consecutive edges + /// in `self` and `right == left + 1`, if `self[left] <= value < self[right]`; /// - `None`, otherwise. /// - /// # Example: + /// # Examples /// /// ``` /// use ndarray_stats::histogram::Edges; /// /// let edges = Edges::from(vec![0, 2, 3]); + /// // `1` is in the interval [0, 2), whose indices are (0, 1) /// assert_eq!( /// edges.indices_of(&1), /// Some((0, 1)) /// ); + /// // `5` is not in any of intervals /// assert_eq!( /// edges.indices_of(&5), /// None @@ -176,17 +228,17 @@ impl Edges { } } + /// Returns an iterator over the `edges` in `self`. pub fn iter(&self) -> impl Iterator { self.edges.iter() } } -/// `Bins` is a sorted collection of non-overlapping -/// 1-dimensional intervals. +/// A sorted collection of non-overlapping 1-dimensional intervals. /// -/// All intervals are left-inclusive and right-exclusive. +/// **Note** that all intervals are left-closed and right-open. /// -/// # Example: +/// # Examples /// /// ``` /// use ndarray_stats::histogram::{Edges, Bins}; @@ -211,16 +263,18 @@ pub struct Bins { } impl Bins { - /// Given a collection of [`Edges`], it returns the corresponding `Bins` instance. + /// Returns a `Bins` instance where each bin corresponds to two consecutive members of the given + /// [`Edges`], consuming the edges. /// /// [`Edges`]: struct.Edges.html + #[must_use] pub fn new(edges: Edges) -> Self { Bins { edges } } - /// Returns the number of bins. + /// Returns the number of bins in `self`. /// - /// # Example: + /// # Examples /// /// ``` /// use ndarray_stats::histogram::{Edges, Bins}; @@ -233,6 +287,7 @@ impl Bins { /// 2 /// ); /// ``` + #[must_use] pub fn len(&self) -> usize { match self.edges.len() { 0 => 0, @@ -240,11 +295,38 @@ impl Bins { } } - /// Given `value`, it returns: - /// - `Some(i)`, if the `i`-th bin in `self` contains `value`; - /// - `None`, if `value` does not belong to any of the bins in `self`. + /// Returns `true` if the number of bins is zero, i.e. if the number of edges is 0 or 1. 
/// - /// # Example: + /// # Examples + /// + /// ``` + /// use ndarray_stats::histogram::{Edges, Bins}; + /// use noisy_float::types::{N64, n64}; + /// + /// // At least 2 edges is needed to represent 1 interval + /// let edges = Edges::from(vec![n64(0.), n64(1.), n64(3.)]); + /// let bins = Bins::new(edges); + /// assert_eq!(bins.is_empty(), false); + /// + /// // No valid interval == Empty + /// let edges = Edges::::from(vec![]); + /// let bins = Bins::new(edges); + /// assert_eq!(bins.is_empty(), true); + /// let edges = Edges::from(vec![n64(0.)]); + /// let bins = Bins::new(edges); + /// assert_eq!(bins.is_empty(), true); + /// ``` + #[must_use] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns the index of the bin in `self` that contains the given `value`, + /// or returns `None` if `value` does not belong to any bins in `self`. + /// + /// # Examples + /// + /// Basic usage: /// /// ``` /// use ndarray_stats::histogram::{Edges, Bins}; @@ -252,35 +334,51 @@ impl Bins { /// let edges = Edges::from(vec![0, 2, 4, 6]); /// let bins = Bins::new(edges); /// let value = 1; + /// // The first bin [0, 2) contains `1` /// assert_eq!( /// bins.index_of(&1), /// Some(0) /// ); + /// // No bin contains 100 + /// assert_eq!( + /// bins.index_of(&100), + /// None + /// ) + /// ``` + /// + /// Chaining [`Bins::index`] and [`Bins::index_of`] to get the boundaries of the bin containing + /// the value: + /// + /// ``` + /// # use ndarray_stats::histogram::{Edges, Bins}; + /// # let edges = Edges::from(vec![0, 2, 4, 6]); + /// # let bins = Bins::new(edges); + /// # let value = 1; /// assert_eq!( - /// bins.index(bins.index_of(&1).unwrap()), - /// 0..2 + /// // using `Option::map` to avoid panic on index out-of-bounds + /// bins.index_of(&1).map(|i| bins.index(i)), + /// Some(0..2) /// ); /// ``` pub fn index_of(&self, value: &A) -> Option { self.edges.indices_of(value).map(|t| t.0) } - /// Given `value`, it returns: - /// - `Some(left_edge..right_edge)`, if there exists a bin in `self` such that - /// `left_edge <= value < right_edge`; - /// - `None`, otherwise. + /// Returns a range as the bin which contains the given `value`, or returns `None` otherwise. /// - /// # Example: + /// # Examples /// /// ``` /// use ndarray_stats::histogram::{Edges, Bins}; /// /// let edges = Edges::from(vec![0, 2, 4, 6]); /// let bins = Bins::new(edges); + /// // [0, 2) contains `1` /// assert_eq!( /// bins.range_of(&1), /// Some(0..2) /// ); + /// // `10` is not in any interval /// assert_eq!( /// bins.range_of(&10), /// None @@ -297,11 +395,13 @@ impl Bins { }) } - /// Get the `i`-th bin. + /// Returns a range as the bin at the given `index` position. /// - /// **Panics** if `index` is out of bounds. + /// # Panics /// - /// # Example: + /// Panics if `index` is out of bounds. 
+ /// + /// # Examples /// /// ``` /// use ndarray_stats::histogram::{Edges, Bins}; @@ -313,6 +413,7 @@ impl Bins { /// 5..10 /// ); /// ``` + #[must_use] pub fn index(&self, index: usize) -> Range where A: Clone, @@ -330,7 +431,7 @@ impl Bins { #[cfg(test)] mod edges_tests { - use super::*; + use super::{Array1, Edges}; use quickcheck_macros::quickcheck; use std::collections::BTreeSet; use std::iter::FromIterator; @@ -349,7 +450,7 @@ mod edges_tests { #[quickcheck] fn check_sorted_from_array(v: Vec) -> bool { - let a = Array1::from_vec(v); + let a = Array1::from(v); let edges = Edges::from(a); let n = edges.len(); for i in 1..n { @@ -361,10 +462,10 @@ mod edges_tests { } #[quickcheck] - fn edges_are_right_exclusive(v: Vec) -> bool { + fn edges_are_right_open(v: Vec) -> bool { let edges = Edges::from(v); let view = edges.as_array_view(); - if view.len() == 0 { + if view.is_empty() { true } else { let last = view[view.len() - 1]; @@ -373,23 +474,23 @@ mod edges_tests { } #[quickcheck] - fn edges_are_left_inclusive(v: Vec) -> bool { + fn edges_are_left_closed(v: Vec) -> bool { let edges = Edges::from(v); - match edges.len() { - 1 => true, - _ => { - let view = edges.as_array_view(); - if view.len() == 0 { - true - } else { - let first = view[0]; - edges.indices_of(&first).is_some() - } + if let 1 = edges.len() { + true + } else { + let view = edges.as_array_view(); + if view.is_empty() { + true + } else { + let first = view[0]; + edges.indices_of(&first).is_some() } } } #[quickcheck] + #[allow(clippy::needless_pass_by_value)] fn edges_are_deduped(v: Vec) -> bool { let unique_elements = BTreeSet::from_iter(v.iter()); let edges = Edges::from(v.clone()); @@ -401,11 +502,12 @@ mod edges_tests { #[cfg(test)] mod bins_tests { - use super::*; + use super::{Bins, Edges}; #[test] #[should_panic] - fn get_panics_for_out_of_bound_indexes() { + #[allow(unused_must_use)] + fn get_panics_for_out_of_bounds_indexes() { let edges = Edges::from(vec![0]); let bins = Bins::new(edges); // we need at least two edges to make a valid bin! diff --git a/src/histogram/grid.rs b/src/histogram/grid.rs index 66d8e9a9..57e85061 100644 --- a/src/histogram/grid.rs +++ b/src/histogram/grid.rs @@ -1,80 +1,104 @@ -use super::bins::Bins; -use super::errors::BinsBuildError; -use super::strategies::BinsBuildingStrategy; +#![warn(missing_docs, clippy::all, clippy::pedantic)] + +use super::{bins::Bins, errors::BinsBuildError, strategies::BinsBuildingStrategy}; use itertools::izip; use ndarray::{ArrayBase, Axis, Data, Ix1, Ix2}; use std::ops::Range; -/// A `Grid` is a partition of a rectangular region of an *n*-dimensional -/// space—e.g. [*a*0, *b*0) × ⋯ × [*a**n*−1, -/// *b**n*−1)—into a collection of rectangular *n*-dimensional bins. +/// An orthogonal partition of a rectangular region in an *n*-dimensional space, e.g. +/// [*a*0, *b*0) × ⋯ × [*a**n*−1, *b**n*−1), +/// represented as a collection of rectangular *n*-dimensional bins. +/// +/// The grid is **solely determined by the Cartesian product of its projections** on each coordinate +/// axis. Therefore, each element in the product set should correspond to a sub-region in the grid. +/// +/// For example, this partition can be represented as a `Grid` struct: /// -/// The grid is **fully determined by its 1-dimensional projections** on the -/// coordinate axes. 
For example, this is a partition that can be represented -/// as a `Grid` struct: /// ```text -/// +---+-------+-+ -/// | | | | -/// +---+-------+-+ -/// | | | | -/// | | | | -/// | | | | -/// | | | | -/// +---+-------+-+ +/// +/// g +---+-------+---+ +/// | 3 | 4 | 5 | +/// f +---+-------+---+ +/// | | | | +/// | 0 | 1 | 2 | +/// | | | | +/// e +---+-------+---+ +/// a b c d +/// +/// R0: [a, b) × [e, f) +/// R1: [b, c) × [e, f) +/// R2: [c, d) × [e, f) +/// R3: [a, b) × [f, g) +/// R4: [b, d) × [f, g) +/// R5: [c, d) × [f, g) +/// Grid: { [a, b), [b, c), [c, d) } × { [e, f), [f, g) } == { R0, R1, R2, R3, R4, R5 } /// ``` +/// /// while the next one can't: +/// /// ```text -/// +---+-------+-+ -/// | | | | -/// | +-------+-+ -/// | | | -/// | | | -/// | | | -/// | | | -/// +---+-------+-+ +/// g +---+-----+---+ +/// | | 2 | 3 | +/// (f) | +-----+---+ +/// | 0 | | +/// | | 1 | +/// | | | +/// e +---+-----+---+ +/// a b c d +/// +/// R0: [a, b) × [e, g) +/// R1: [b, d) × [e, f) +/// R2: [b, c) × [f, g) +/// R3: [c, d) × [f, g) +/// // 'f', as long as 'R1', 'R2', or 'R3', doesn't appear on LHS +/// // [b, c) × [e, g), [c, d) × [e, g) doesn't appear on RHS +/// Grid: { [a, b), [b, c), [c, d) } × { [e, g) } != { R0, R1, R2, R3 } /// ``` /// -/// # Example: +/// # Examples +/// +/// Basic usage, building a `Grid` via [`GridBuilder`], with optimal grid layout determined by +/// a given [`strategy`], and generating a [`histogram`]: /// /// ``` /// use ndarray::{Array, array}; -/// use ndarray_stats::{HistogramExt, -/// histogram::{Histogram, Grid, GridBuilder, -/// Edges, Bins, strategies::Auto}}; -/// use noisy_float::types::{N64, n64}; +/// use ndarray_stats::{ +/// histogram::{strategies::Auto, Bins, Edges, Grid, GridBuilder}, +/// HistogramExt, +/// }; /// -/// // 1-dimensional observations, as a (n_observations, 1) 2-d matrix +/// // 1-dimensional observations, as a (n_observations, n_dimension) 2-d matrix /// let observations = Array::from_shape_vec( /// (12, 1), /// vec![1, 4, 5, 2, 100, 20, 50, 65, 27, 40, 45, 23], /// ).unwrap(); /// -/// // The optimal grid layout is inferred from the data, -/// // specifying a strategy (Auto in this case) +/// // The optimal grid layout is inferred from the data, given a chosen strategy, Auto in this case /// let grid = GridBuilder::>::from_array(&observations).unwrap().build(); -/// let expected_grid = Grid::from(vec![Bins::new(Edges::from(vec![1, 20, 39, 58, 77, 96, 115]))]); -/// assert_eq!(grid, expected_grid); /// /// let histogram = observations.histogram(grid); /// /// let histogram_matrix = histogram.counts(); -/// // Bins are left inclusive, right exclusive! +/// // Bins are left-closed, right-open! /// let expected = array![4, 3, 3, 1, 0, 1]; /// assert_eq!(histogram_matrix, expected.into_dyn()); /// ``` +/// +/// [`histogram`]: trait.HistogramExt.html +/// [`GridBuilder`]: struct.GridBuilder.html +/// [`strategy`]: strategies/index.html #[derive(Clone, Debug, Eq, PartialEq)] pub struct Grid { projections: Vec>, } impl From>> for Grid { - /// Get a `Grid` instance from a `Vec>`. + /// Converts a `Vec>` into a `Grid`, consuming the vector of bins. /// - /// The `i`-th element in `Vec>` represents the 1-dimensional - /// projection of the bin grid on the `i`-th axis. + /// The `i`-th element in `Vec>` represents the projection of the bin grid onto the + /// `i`-th axis. /// - /// Alternatively, a `Grid` can be built directly from data using a - /// [`GridBuilder`]. 
+ /// Alternatively, a `Grid` can be built directly from data using a [`GridBuilder`]. /// /// [`GridBuilder`]: struct.GridBuilder.html fn from(projections: Vec>) -> Self { @@ -83,27 +107,99 @@ impl From>> for Grid { } impl Grid { - /// Returns `n`, the number of dimensions of the region partitioned by the grid. + /// Returns the number of dimensions of the region partitioned by the grid. + /// + /// # Examples + /// + /// ``` + /// use ndarray_stats::histogram::{Edges, Bins, Grid}; + /// + /// let edges = Edges::from(vec![0, 1]); + /// let bins = Bins::new(edges); + /// let square_grid = Grid::from(vec![bins.clone(), bins.clone()]); + /// + /// assert_eq!(square_grid.ndim(), 2usize) + /// ``` + #[must_use] pub fn ndim(&self) -> usize { self.projections.len() } - /// Returns the number of bins along each coordinate axis. + /// Returns the numbers of bins along each coordinate axis. + /// + /// # Examples + /// + /// ``` + /// use ndarray_stats::histogram::{Edges, Bins, Grid}; + /// + /// let edges_x = Edges::from(vec![0, 1]); + /// let edges_y = Edges::from(vec![-1, 0, 1]); + /// let bins_x = Bins::new(edges_x); + /// let bins_y = Bins::new(edges_y); + /// let square_grid = Grid::from(vec![bins_x, bins_y]); + /// + /// assert_eq!(square_grid.shape(), vec![1usize, 2usize]); + /// ``` + #[must_use] pub fn shape(&self) -> Vec { - self.projections.iter().map(|e| e.len()).collect() + self.projections.iter().map(Bins::len).collect() } - /// Returns the grid projections on the coordinate axes as a slice of immutable references. + /// Returns the grid projections on each coordinate axis as a slice of immutable references. + #[must_use] pub fn projections(&self) -> &[Bins] { &self.projections } - /// Returns the index of the *n*-dimensional bin containing the point, if - /// one exists. + /// Returns an `n-dimensional` index, of bins along each axis that contains the point, if one + /// exists. /// /// Returns `None` if the point is outside the grid. /// - /// **Panics** if `point.len()` does not equal `self.ndim()`. + /// # Panics + /// + /// Panics if dimensionality of the point doesn't equal the grid's. 
+ /// + /// # Examples + /// + /// Basic usage: + /// + /// ``` + /// use ndarray::array; + /// use ndarray_stats::histogram::{Edges, Bins, Grid}; + /// use noisy_float::types::n64; + /// + /// let edges = Edges::from(vec![n64(-1.), n64(0.), n64(1.)]); + /// let bins = Bins::new(edges); + /// let square_grid = Grid::from(vec![bins.clone(), bins.clone()]); + /// + /// // (0., -0.7) falls in 1st and 0th bin respectively + /// assert_eq!( + /// square_grid.index_of(&array![n64(0.), n64(-0.7)]), + /// Some(vec![1, 0]), + /// ); + /// // Returns `None`, as `1.` is outside the grid since bins are right-open + /// assert_eq!( + /// square_grid.index_of(&array![n64(0.), n64(1.)]), + /// None, + /// ); + /// ``` + /// + /// A panic upon dimensionality mismatch: + /// + /// ```should_panic + /// # use ndarray::array; + /// # use ndarray_stats::histogram::{Edges, Bins, Grid}; + /// # use noisy_float::types::n64; + /// # let edges = Edges::from(vec![n64(-1.), n64(0.), n64(1.)]); + /// # let bins = Bins::new(edges); + /// # let square_grid = Grid::from(vec![bins.clone(), bins.clone()]); + /// // the point has 3 dimensions, the grid expected 2 dimensions + /// assert_eq!( + /// square_grid.index_of(&array![n64(0.), n64(-0.7), n64(0.5)]), + /// Some(vec![1, 0, 1]), + /// ); + /// ``` pub fn index_of(&self, point: &ArrayBase) -> Option> where S: Data, @@ -125,12 +221,54 @@ impl Grid { } impl Grid { - /// Given `i=(i_0, ..., i_{n-1})`, an `n`-dimensional index, it returns - /// `I_{i_0}x...xI_{i_{n-1}}`, an `n`-dimensional bin, where `I_{i_j}` is - /// the `i_j`-th interval on the `j`-th projection of the grid on the coordinate axes. + /// Given an `n`-dimensional index, `i = (i_0, ..., i_{n-1})`, returns an `n`-dimensional bin, + /// `I_{i_0} x ... x I_{i_{n-1}}`, where `I_{i_j}` is the `i_j`-th interval on the `j`-th + /// projection of the grid on the coordinate axes. + /// + /// # Panics + /// + /// Panics if at least one in the index, `(i_0, ..., i_{n-1})`, is out of bounds on the + /// corresponding coordinate axis, i.e. if there exists `j` s.t. + /// `i_j >= self.projections[j].len()`. + /// + /// # Examples + /// + /// Basic usage: + /// + /// ``` + /// use ndarray::array; + /// use ndarray_stats::histogram::{Edges, Bins, Grid}; + /// + /// let edges_x = Edges::from(vec![0, 1]); + /// let edges_y = Edges::from(vec![2, 3, 4]); + /// let bins_x = Bins::new(edges_x); + /// let bins_y = Bins::new(edges_y); + /// let square_grid = Grid::from(vec![bins_x, bins_y]); + /// + /// // Query the 0-th bin on x-axis, and 1-st bin on y-axis + /// assert_eq!( + /// square_grid.index(&[0, 1]), + /// vec![0..1, 3..4], + /// ); + /// ``` /// - /// **Panics** if at least one among `(i_0, ..., i_{n-1})` is out of bounds on the respective - /// coordinate axis - i.e. if there exists `j` such that `i_j >= self.projections[j].len()`. 
+ /// A panic upon out-of-bounds: + /// + /// ```should_panic + /// # use ndarray::array; + /// # use ndarray_stats::histogram::{Edges, Bins, Grid}; + /// # let edges_x = Edges::from(vec![0, 1]); + /// # let edges_y = Edges::from(vec![2, 3, 4]); + /// # let bins_x = Bins::new(edges_x); + /// # let bins_y = Bins::new(edges_y); + /// # let square_grid = Grid::from(vec![bins_x, bins_y]); + /// // out-of-bound on y-axis + /// assert_eq!( + /// square_grid.index(&[0, 2]), + /// vec![0..1, 3..4], + /// ); + /// ``` + #[must_use] pub fn index(&self, index: &[usize]) -> Vec> { assert_eq!( index.len(), @@ -146,12 +284,34 @@ } } -/// `GridBuilder`, given a [`strategy`] and some observations, returns a [`Grid`] -/// instance for [`histogram`] computation. +/// A builder used to create [`Grid`] instances for [`histogram`] computations. +/// +/// # Examples +/// +/// Basic usage, creating a `Grid` with some observations and a given [`strategy`]: +/// +/// ``` +/// use ndarray::Array; +/// use ndarray_stats::histogram::{strategies::Auto, Bins, Edges, Grid, GridBuilder}; +/// +/// // 1-dimensional observations, as a (n_observations, n_dimension) 2-d matrix +/// let observations = Array::from_shape_vec( +/// (12, 1), +/// vec![1, 4, 5, 2, 100, 20, 50, 65, 27, 40, 45, 23], +/// ).unwrap(); +/// +/// // The optimal grid layout is inferred from the data, given a chosen strategy, Auto in this case +/// let grid = GridBuilder::>::from_array(&observations).unwrap().build(); +/// // Equivalently, build a Grid directly +/// let expected_grid = Grid::from(vec![Bins::new(Edges::from(vec![1, 20, 39, 58, 77, 96, 115]))]); +/// +/// assert_eq!(grid, expected_grid); +/// ``` /// /// [`Grid`]: struct.Grid.html /// [`histogram`]: trait.HistogramExt.html /// [`strategy`]: strategies/index.html +#[allow(clippy::module_name_repetitions)] pub struct GridBuilder { bin_builders: Vec, } @@ -161,15 +321,22 @@ where A: Ord, B: BinsBuildingStrategy, { - /// Given some observations in a 2-dimensional array with shape `(n_observations, n_dimension)` - /// it returns a `GridBuilder` instance that has learned the required parameter - /// to build a [`Grid`] according to the specified [`strategy`]. + /// Returns a `GridBuilder` for building a [`Grid`] with a given [`strategy`] and some + /// observations in a 2-dimensional array with shape `(n_observations, n_dimension)`. /// - /// It returns `Err` if it is not possible to build a [`Grid`] given + /// # Errors + /// + /// It returns [`BinsBuildError`] if it is not possible to build a [`Grid`] given /// the observed data according to the chosen [`strategy`]. /// + /// # Examples + /// + /// See [Trait-level examples] for basic usage. + /// /// [`Grid`]: struct.Grid.html /// [`strategy`]: strategies/index.html + /// [`BinsBuildError`]: errors/enum.BinsBuildError.html + /// [Trait-level examples]: struct.GridBuilder.html#examples pub fn from_array(array: &ArrayBase) -> Result where S: Data, @@ -181,12 +348,17 @@ where Ok(Self { bin_builders }) } - /// Returns a [`Grid`] instance, built accordingly to the specified [`strategy`] - /// using the parameters inferred from observations in [`from_array`]. + /// Returns a [`Grid`] instance, with building parameters inferred in [`from_array`], according + /// to the specified [`strategy`] and observations provided. + /// + /// # Examples + /// + /// See [Trait-level examples] for basic usage.
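To make the `# Errors` note above concrete (a hypothetical sketch, not part of the patch): observations without any spread give the strategies nothing to work with, so `from_array` is expected to fail, while data with some variability builds fine.

```rust
use ndarray::Array2;
use ndarray_stats::histogram::{strategies::Sqrt, GridBuilder};

fn main() {
    // Constant observations: max == min, so no positive bin width can be inferred
    // and the builder reports a `BinsBuildError`.
    let constant = Array2::from_elem((5, 1), 1usize);
    assert!(GridBuilder::<Sqrt<usize>>::from_array(&constant).is_err());

    // Observations with some spread let the builder succeed.
    let spread = Array2::from_shape_vec((4, 1), vec![1usize, 5, 9, 20]).unwrap();
    assert!(GridBuilder::<Sqrt<usize>>::from_array(&spread).is_ok());
}
```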
/// /// [`Grid`]: struct.Grid.html /// [`strategy`]: strategies/index.html /// [`from_array`]: #method.from_array.html + #[must_use] pub fn build(&self) -> Grid { let projections: Vec<_> = self.bin_builders.iter().map(|b| b.build()).collect(); Grid::from(projections) diff --git a/src/histogram/strategies.rs b/src/histogram/strategies.rs index 0892b311..a1522109 100644 --- a/src/histogram/strategies.rs +++ b/src/histogram/strategies.rs @@ -1,49 +1,78 @@ -//! Strategies to build [`Bins`]s and [`Grid`]s (using [`GridBuilder`]) inferring -//! optimal parameters directly from data. +//! Strategies used by [`GridBuilder`] to infer optimal parameters from data for building [`Bins`] +//! and [`Grid`] instances. //! //! The docs for each strategy have been taken almost verbatim from [`NumPy`]. //! -//! Each strategy specifies how to compute the optimal number of [`Bins`] or -//! the optimal bin width. -//! For those strategies that prescribe the optimal number -//! of [`Bins`] we then compute the optimal bin width with +//! Each strategy specifies how to compute the optimal number of [`Bins`] or the optimal bin width. +//! For those strategies that prescribe the optimal number of [`Bins`], the optimal bin width is +//! computed by `bin_width = (max - min)/n`. //! -//! `bin_width = (max - min)/n` +//! Since all bins are left-closed and right-open, an extra bin is added when necessary to include +//! the maximum value of the given data, so that no data is discarded. //! -//! All our bins are left-inclusive and right-exclusive: we make sure to add an extra bin -//! if it is necessary to include the maximum value of the array that has been passed as argument -//! to the `from_array` method. +//! # Strategies //! +//! Currently, the following strategies are implemented: +//! +//! - [`Auto`]: Maximum of the [`Sturges`] and [`FreedmanDiaconis`] strategies. Provides good all +//! around performance. +//! - [`FreedmanDiaconis`]: Robust (resilient to outliers) strategy that takes into account data +//! variability and data size. +//! - [`Rice`]: A strategy that does not take variability into account, only data size. Commonly +//! overestimates number of bins required. +//! - [`Sqrt`]: Square root (of data size) strategy, used by Excel and other programs +//! for its speed and simplicity. +//! - [`Sturges`]: R’s default strategy, only accounts for data size. Only optimal for gaussian data +//! and underestimates number of bins for large non-gaussian datasets. +//! +//! # Notes +//! +//! In general, successful inference of the optimal bin width and number of bins relies on the +//! **variability** of the data. In other words, the provided observations should not be empty or +//! constant. +//! +//! In addition, [`Auto`] and [`FreedmanDiaconis`] require the [`interquartile range (IQR)`][iqr], +//! i.e. the difference between upper and lower quartiles, to be positive. +//! +//! [`GridBuilder`]: ../struct.GridBuilder.html //! [`Bins`]: ../struct.Bins.html //! [`Grid`]: ../struct.Grid.html -//! [`GridBuilder`]: ../struct.GridBuilder.html //! [`NumPy`]: https://docs.scipy.org/doc/numpy/reference/generated/numpy.histogram_bin_edges.html#numpy.histogram_bin_edges -use super::super::interpolate::Nearest; -use super::super::{Quantile1dExt, QuantileExt}; -use super::errors::BinsBuildError; -use super::{Bins, Edges}; -use ndarray::prelude::*; -use ndarray::Data; +//! [`Auto`]: struct.Auto.html +//! [`Sturges`]: struct.Sturges.html +//! [`FreedmanDiaconis`]: struct.FreedmanDiaconis.html +//!
[`Rice`]: struct.Rice.html +//! [`Sqrt`]: struct.Sqrt.html +//! [iqr]: https://www.wikiwand.com/en/Interquartile_range +#![warn(missing_docs, clippy::all, clippy::pedantic)] + +use crate::{ + histogram::{errors::BinsBuildError, Bins, Edges}, + quantile::{interpolate::Nearest, Quantile1dExt, QuantileExt}, +}; +use ndarray::{prelude::*, Data}; use noisy_float::types::n64; use num_traits::{FromPrimitive, NumOps, Zero}; -/// A trait implemented by all strategies to build [`Bins`] -/// with parameters inferred from observations. +/// A trait implemented by all strategies to build [`Bins`] with parameters inferred from +/// observations. /// -/// A `BinsBuildingStrategy` is required by [`GridBuilder`] -/// to know how to build a [`Grid`]'s projections on the +/// This is required by [`GridBuilder`] to know how to build a [`Grid`]'s projections on the /// coordinate axes. /// /// [`Bins`]: ../struct.Bins.html -/// [`Grid`]: ../struct.Grid.html /// [`GridBuilder`]: ../struct.GridBuilder.html +/// [`Grid`]: ../struct.Grid.html pub trait BinsBuildingStrategy { + #[allow(missing_docs)] type Elem: Ord; - /// Given some observations in a 1-dimensional array it returns a `BinsBuildingStrategy` - /// that has learned the required parameter to build a collection of [`Bins`]. + /// Returns a strategy that has learnt the required parameter fo building [`Bins`] for given + /// 1-dimensional array, or an `Err` if it is not possible to infer the required parameter + /// with the given data and specified strategy. /// - /// It returns `Err` if it is not possible to build a collection of - /// [`Bins`] given the observed data according to the chosen strategy. + /// # Errors + /// + /// See each of the struct-level documentation for details on errors an implementor may return. /// /// [`Bins`]: ../struct.Bins.html fn from_array(array: &ArrayBase) -> Result @@ -51,17 +80,12 @@ pub trait BinsBuildingStrategy { S: Data, Self: std::marker::Sized; - /// Returns a [`Bins`] instance, built accordingly to the parameters - /// inferred from observations in [`from_array`]. + /// Returns a [`Bins`] instance, according to parameters inferred from observations. /// /// [`Bins`]: ../struct.Bins.html - /// [`from_array`]: #method.from_array.html fn build(&self) -> Bins; - /// Returns the optimal number of bins, according to the parameters - /// inferred from observations in [`from_array`]. - /// - /// [`from_array`]: #method.from_array.html + /// Returns the optimal number of bins, according to parameters inferred from observations. fn n_bins(&self) -> usize; } @@ -72,12 +96,19 @@ struct EquiSpaced { max: T, } -/// Square root (of data size) strategy, used by Excel and other programs -/// for its speed and simplicity. +/// Square root (of data size) strategy, used by Excel and other programs for its speed and +/// simplicity. /// /// Let `n` be the number of observations. Then /// /// `n_bins` = `sqrt(n)` +/// +/// # Notes +/// +/// This strategy requires the data +/// +/// - not being empty +/// - not being constant #[derive(Debug)] pub struct Sqrt { builder: EquiSpaced, @@ -86,12 +117,19 @@ pub struct Sqrt { /// A strategy that does not take variability into account, only data size. Commonly /// overestimates number of bins required. /// -/// Let `n` be the number of observations and `n_bins` the number of bins. +/// Let `n` be the number of observations and `n_bins` be the number of bins. /// /// `n_bins` = 2`n`1/3 /// /// `n_bins` is only proportional to cube root of `n`. 
It tends to overestimate /// the `n_bins` and it does not take into account data variability. +/// +/// # Notes +/// +/// This strategy requires the data +/// +/// - not being empty +/// - not being constant #[derive(Debug)] pub struct Rice { builder: EquiSpaced, @@ -105,24 +143,38 @@ pub struct Rice { /// is too conservative for larger, non-normal datasets. /// /// This is the default method in R’s hist method. +/// +/// # Notes +/// +/// This strategy requires the data +/// +/// - not being empty +/// - not being constant #[derive(Debug)] pub struct Sturges { builder: EquiSpaced, } -/// Robust (resilient to outliers) strategy that takes into -/// account data variability and data size. +/// Robust (resilient to outliers) strategy that takes into account data variability and data size. /// /// Let `n` be the number of observations. /// -/// `bin_width` = 2×`IQR`×`n`−1/3 +/// `bin_width` = 2 × `IQR` × `n`−1/3 /// /// The bin width is proportional to the interquartile range ([`IQR`]) and inversely proportional to -/// cube root of `n`. It can be too conservative for small datasets, but it is quite good for -/// large datasets. +/// cube root of `n`. It can be too conservative for small datasets, but it is quite good for large +/// datasets. /// /// The [`IQR`] is very robust to outliers. /// +/// # Notes +/// +/// This strategy requires the data +/// +/// - not being empty +/// - not being constant +/// - having positive [`IQR`] +/// /// [`IQR`]: https://en.wikipedia.org/wiki/Interquartile_range #[derive(Debug)] pub struct FreedmanDiaconis { @@ -135,16 +187,25 @@ enum SturgesOrFD { FreedmanDiaconis(FreedmanDiaconis), } -/// Maximum of the [`Sturges`] and [`FreedmanDiaconis`] strategies. -/// Provides good all around performance. +/// Maximum of the [`Sturges`] and [`FreedmanDiaconis`] strategies. Provides good all around +/// performance. +/// +/// A compromise to get a good value. For small datasets the [`Sturges`] value will usually be +/// chosen, while larger datasets will usually default to [`FreedmanDiaconis`]. Avoids the overly +/// conservative behaviour of [`FreedmanDiaconis`] and [`Sturges`] for small and large datasets +/// respectively. /// -/// A compromise to get a good value. For small datasets the [`Sturges`] value will usually be chosen, -/// while larger datasets will usually default to [`FreedmanDiaconis`]. Avoids the overly -/// conservative behaviour of [`FreedmanDiaconis`] and [`Sturges`] for -/// small and large datasets respectively. 
+/// # Notes +/// +/// This strategy requires the data +/// +/// - not being empty +/// - not being constant +/// - having positive [`IQR`] /// /// [`Sturges`]: struct.Sturges.html /// [`FreedmanDiaconis`]: struct.FreedmanDiaconis.html +/// [`IQR`]: https://en.wikipedia.org/wiki/Interquartile_range #[derive(Debug)] pub struct Auto { builder: SturgesOrFD, @@ -171,7 +232,7 @@ where fn build(&self) -> Bins { let n_bins = self.n_bins(); let mut edges: Vec = vec![]; - for i in 0..(n_bins + 1) { + for i in 0..=n_bins { let edge = self.min.clone() + T::from_usize(i).unwrap() * self.bin_width.clone(); edges.push(edge); } @@ -185,7 +246,7 @@ where max_edge = max_edge + self.bin_width.clone(); n_bins += 1; } - return n_bins; + n_bins } fn bin_width(&self) -> T { @@ -207,6 +268,11 @@ where S: Data, { let n_elems = a.len(); + // casting `n_elems: usize` to `f64` may casus off-by-one error here if `n_elems` > 2 ^ 53, + // but it's not relevant here + #[allow(clippy::cast_precision_loss)] + // casting the rounded square root from `f64` to `usize` is safe + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] let n_bins = (n_elems as f64).sqrt().round() as usize; let min = a.min()?; let max = a.max()?; @@ -248,6 +314,11 @@ where S: Data, { let n_elems = a.len(); + // casting `n_elems: usize` to `f64` may casus off-by-one error here if `n_elems` > 2 ^ 53, + // but it's not relevant here + #[allow(clippy::cast_precision_loss)] + // casting the rounded cube root from `f64` to `usize` is safe + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] let n_bins = (2. * (n_elems as f64).powf(1. / 3.)).round() as usize; let min = a.min()?; let max = a.max()?; @@ -289,6 +360,11 @@ where S: Data, { let n_elems = a.len(); + // casting `n_elems: usize` to `f64` may casus off-by-one error here if `n_elems` > 2 ^ 53, + // but it's not relevant here + #[allow(clippy::cast_precision_loss)] + // casting the rounded base-2 log from `f64` to `usize` is safe + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] let n_bins = (n_elems as f64).log2().round() as usize + 1; let min = a.min()?; let max = a.max()?; @@ -360,9 +436,11 @@ where T: Ord + Clone + FromPrimitive + NumOps + Zero, { fn compute_bin_width(n_bins: usize, iqr: T) -> T { + // casting `n_bins: usize` to `f64` may casus off-by-one error here if `n_bins` > 2 ^ 53, + // but it's not relevant here + #[allow(clippy::cast_precision_loss)] let denominator = (n_bins as f64).powf(1. / 3.); - let bin_width = T::from_usize(2).unwrap() * iqr / T::from_f64(denominator).unwrap(); - bin_width + T::from_usize(2).unwrap() * iqr / T::from_f64(denominator).unwrap() } /// The bin width (or bin length) according to the fitted strategy. 
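For reference (not part of the patch), the Freedman–Diaconis rule that `compute_bin_width` implements for a generic `T` can be written out directly for `f64` data; the nearest-rank quartile below is a deliberate simplification of the crate's `Quantile1dExt`-based IQR computation:

```rust
/// Freedman–Diaconis bin width, 2 × IQR × n^(-1/3), spelled out for a sorted `f64` slice.
fn fd_bin_width(sorted: &[f64]) -> f64 {
    let n = sorted.len() as f64;
    // Nearest-rank quartiles, purely for illustration.
    let q = |p: f64| sorted[(p * (sorted.len() - 1) as f64).round() as usize];
    let iqr = q(0.75) - q(0.25);
    2.0 * iqr / n.powf(1.0 / 3.0)
}

fn main() {
    let xs: Vec<f64> = (1..=27).map(f64::from).collect();
    // Nearest-rank quartiles are 8 and 21, so IQR = 13; 27^(1/3) = 3,
    // giving a bin width of 2 * 13 / 3 ≈ 8.67.
    let width = fd_bin_width(&xs);
    assert!((width - 2.0 * 13.0 / 3.0).abs() < 1e-6);
}
```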
@@ -438,8 +516,8 @@ where } } -/// Given a range (max, min) and the number of bins, it returns -/// the associated bin_width: +/// Returns the `bin_width`, given the two end points of a range (`max`, `min`), and the number of +/// bins, consuming endpoints /// /// `bin_width = (max - min)/n` /// @@ -448,14 +526,13 @@ fn compute_bin_width(min: T, max: T, n_bins: usize) -> T where T: Ord + Clone + FromPrimitive + NumOps + Zero, { - let range = max.clone() - min.clone(); - let bin_width = range / T::from_usize(n_bins).unwrap(); - bin_width + let range = max - min; + range / T::from_usize(n_bins).unwrap() } #[cfg(test)] mod equispaced_tests { - use super::*; + use super::EquiSpaced; #[test] fn bin_width_has_to_be_positive() { @@ -470,7 +547,7 @@ mod equispaced_tests { #[cfg(test)] mod sqrt_tests { - use super::*; + use super::{BinsBuildingStrategy, Sqrt}; use ndarray::array; #[test] @@ -490,7 +567,7 @@ mod sqrt_tests { #[cfg(test)] mod rice_tests { - use super::*; + use super::{BinsBuildingStrategy, Rice}; use ndarray::array; #[test] @@ -510,7 +587,7 @@ mod rice_tests { #[cfg(test)] mod sturges_tests { - use super::*; + use super::{BinsBuildingStrategy, Sturges}; use ndarray::array; #[test] @@ -530,7 +607,7 @@ mod sturges_tests { #[cfg(test)] mod fd_tests { - use super::*; + use super::{BinsBuildingStrategy, FreedmanDiaconis}; use ndarray::array; #[test] @@ -559,7 +636,7 @@ mod fd_tests { #[cfg(test)] mod auto_tests { - use super::*; + use super::{Auto, BinsBuildingStrategy}; use ndarray::array; #[test] diff --git a/src/lib.rs b/src/lib.rs index 66577676..4ae11004 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,6 +7,7 @@ //! - [partitioning]; //! - [correlation analysis] (covariance, pearson correlation); //! - [measures from information theory] (entropy, KL divergence, etc.); +//! - [measures of deviation] (count equal, L1, L2 distances, mean squared err etc.) //! - [histogram computation]. //! //! Please feel free to contribute new functionality! A roadmap can be found [here]. @@ -15,19 +16,21 @@ //! [`NumPy`] (Python) and [`StatsBase.jl`] (Julia) - any contribution bringing us closer to //! feature parity is more than welcome! //! -//! [`ndarray-stats`]: https://github.com/jturner314/ndarray-stats/ +//! [`ndarray-stats`]: https://github.com/rust-ndarray/ndarray-stats/ //! [`ndarray`]: https://github.com/rust-ndarray/ndarray //! [order statistics]: trait.QuantileExt.html //! [partitioning]: trait.Sort1dExt.html //! [summary statistics]: trait.SummaryStatisticsExt.html //! [correlation analysis]: trait.CorrelationExt.html +//! [measures of deviation]: trait.DeviationExt.html //! [measures from information theory]: trait.EntropyExt.html //! [histogram computation]: histogram/index.html -//! [here]: https://github.com/jturner314/ndarray-stats/issues/1 +//! [here]: https://github.com/rust-ndarray/ndarray-stats/issues/1 //! [`NumPy`]: https://docs.scipy.org/doc/numpy-1.14.1/reference/routines.statistics.html //! [`StatsBase.jl`]: https://juliastats.github.io/StatsBase.jl/latest/ pub use crate::correlation::CorrelationExt; +pub use crate::deviation::DeviationExt; pub use crate::entropy::EntropyExt; pub use crate::histogram::HistogramExt; pub use crate::maybe_nan::{MaybeNan, MaybeNanExt}; @@ -35,6 +38,33 @@ pub use crate::quantile::{interpolate, Quantile1dExt, QuantileExt}; pub use crate::sort::Sort1dExt; pub use crate::summary_statistics::SummaryStatisticsExt; +#[cfg(test)] +#[macro_use] +extern crate approx; + +#[macro_use] +mod multi_input_error_macros { + macro_rules! 
return_err_if_empty { + ($arr:expr) => { + if $arr.len() == 0 { + return Err(MultiInputError::EmptyInput); + } + }; + } + macro_rules! return_err_unless_same_shape { + ($arr_a:expr, $arr_b:expr) => { + use crate::errors::{MultiInputError, ShapeMismatch}; + if $arr_a.shape() != $arr_b.shape() { + return Err(MultiInputError::ShapeMismatch(ShapeMismatch { + first_shape: $arr_a.shape().to_vec(), + second_shape: $arr_b.shape().to_vec(), + }) + .into()); + } + }; + } +} + #[macro_use] mod private { /// This is a public type in a private module, so it can be included in @@ -69,6 +99,7 @@ mod private { } mod correlation; +mod deviation; mod entropy; pub mod errors; pub mod histogram; diff --git a/src/maybe_nan/impl_not_none.rs b/src/maybe_nan/impl_not_none.rs index e8c2755b..2ab4f075 100644 --- a/src/maybe_nan/impl_not_none.rs +++ b/src/maybe_nan/impl_not_none.rs @@ -35,9 +35,6 @@ impl PartialEq for NotNone { fn eq(&self, other: &Self) -> bool { self.deref().eq(other) } - fn ne(&self, other: &Self) -> bool { - self.deref().eq(other) - } } impl Ord for NotNone { diff --git a/src/maybe_nan/mod.rs b/src/maybe_nan/mod.rs index 8b42b73a..02cce16d 100644 --- a/src/maybe_nan/mod.rs +++ b/src/maybe_nan/mod.rs @@ -1,6 +1,7 @@ use ndarray::prelude::*; use ndarray::{s, Data, DataMut, RemoveAxis}; use noisy_float::types::{N32, N64}; +use std::mem; /// A number type that can have not-a-number values. pub trait MaybeNan: Sized { @@ -43,7 +44,7 @@ pub trait MaybeNan: Sized { /// /// This modifies the input view by moving elements as necessary. fn remove_nan_mut(mut view: ArrayViewMut1<'_, A>) -> ArrayViewMut1<'_, A> { - if view.len() == 0 { + if view.is_empty() { return view.slice_move(s![..0]); } let mut i = 0; @@ -69,6 +70,42 @@ fn remove_nan_mut(mut view: ArrayViewMut1<'_, A>) -> ArrayViewMut1< } } +/// Casts a view from one element type to another. +/// +/// # Panics +/// +/// Panics if `T` and `U` differ in size or alignment. +/// +/// # Safety +/// +/// The caller must ensure that qll elements in `view` are valid values for type `U`. +unsafe fn cast_view_mut(mut view: ArrayViewMut1<'_, T>) -> ArrayViewMut1<'_, U> { + assert_eq!(mem::size_of::(), mem::size_of::()); + assert_eq!(mem::align_of::(), mem::align_of::()); + let ptr: *mut U = view.as_mut_ptr().cast(); + let len: usize = view.len_of(Axis(0)); + let stride: isize = view.stride_of(Axis(0)); + if len <= 1 { + // We can use a stride of `0` because the stride is irrelevant for the `len == 1` case. + let stride = 0; + ArrayViewMut1::from_shape_ptr([len].strides([stride]), ptr) + } else if stride >= 0 { + let stride = stride as usize; + ArrayViewMut1::from_shape_ptr([len].strides([stride]), ptr) + } else { + // At this point, stride < 0. We have to construct the view by using the inverse of the + // stride and then inverting the axis, since `ArrayViewMut::from_shape_ptr` requires the + // stride to be nonnegative. + let neg_stride = stride.checked_neg().unwrap() as usize; + // This is safe because `ndarray` guarantees that it's safe to offset the + // pointer anywhere in the array. + let neg_ptr = ptr.offset((len - 1) as isize * stride); + let mut v = ArrayViewMut1::from_shape_ptr([len].strides([neg_stride]), neg_ptr); + v.invert_axis(Axis(0)); + v + } +} + macro_rules! impl_maybenan_for_fxx { ($fxx:ident, $Nxx:ident) => { impl MaybeNan for $fxx { @@ -102,11 +139,9 @@ macro_rules! 
impl_maybenan_for_fxx { fn remove_nan_mut(view: ArrayViewMut1<'_, $fxx>) -> ArrayViewMut1<'_, $Nxx> { let not_nan = remove_nan_mut(view); - // This is safe because `remove_nan_mut` has removed the NaN - // values, and `$Nxx` is a thin wrapper around `$fxx`. - unsafe { - ArrayViewMut1::from_shape_ptr(not_nan.dim(), not_nan.as_ptr() as *mut $Nxx) - } + // This is safe because `remove_nan_mut` has removed the NaN values, and `$Nxx` is + // a thin wrapper around `$fxx`. + unsafe { cast_view_mut(not_nan) } } } }; @@ -332,7 +367,7 @@ where A: 'a, F: FnMut(&'a A::NotNan), { - self.visit(|elem| { + self.for_each(|elem| { if let Some(not_nan) = elem.try_as_not_nan() { f(not_nan) } diff --git a/src/quantile/mod.rs b/src/quantile/mod.rs index ba3f5356..3fea4a65 100644 --- a/src/quantile/mod.rs +++ b/src/quantile/mod.rs @@ -477,7 +477,7 @@ where let mut results = Array::from_elem(results_shape, data.first().unwrap().clone()); Zip::from(results.lanes_mut(axis)) .and(data.lanes_mut(axis)) - .apply(|mut results, mut data| { + .for_each(|mut results, mut data| { let index_map = get_many_from_sorted_mut_unchecked(&mut data, &searched_indexes); for (result, &q) in results.iter_mut().zip(qs) { diff --git a/src/sort.rs b/src/sort.rs index 56db3715..f43a95b1 100644 --- a/src/sort.rs +++ b/src/sort.rs @@ -116,7 +116,7 @@ where self[0].clone() } else { let mut rng = thread_rng(); - let pivot_index = rng.gen_range(0, n); + let pivot_index = rng.gen_range(0..n); let partition_index = self.partition_mut(pivot_index); if i < partition_index { self.slice_axis_mut(Axis(0), Slice::from(..partition_index)) @@ -251,7 +251,7 @@ fn _get_many_from_sorted_mut_unchecked( // We pick a random pivot index: the corresponding element is the pivot value let mut rng = thread_rng(); - let pivot_index = rng.gen_range(0, n); + let pivot_index = rng.gen_range(0..n); // We partition the array with respect to the pivot value. // The pivot value moves to `array_partition_index`. 
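A side note on the `gen_range` changes just above (not part of the patch): with rand 0.8 the two-argument form is gone and the bounds are passed as a single half-open range, so pivot selection looks like this minimal sketch:

```rust
use rand::{thread_rng, Rng};

fn main() {
    let n = 10;
    let mut rng = thread_rng();
    // `0..n` is half-open, so the sampled pivot index is always a valid position.
    let pivot_index = rng.gen_range(0..n);
    assert!(pivot_index < n);
}
```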
diff --git a/src/summary_statistics/means.rs b/src/summary_statistics/means.rs index 89d4df9d..d5226263 100644 --- a/src/summary_statistics/means.rs +++ b/src/summary_statistics/means.rs @@ -1,9 +1,9 @@ use super::SummaryStatisticsExt; -use crate::errors::EmptyInput; -use ndarray::{ArrayBase, Data, Dimension}; +use crate::errors::{EmptyInput, MultiInputError, ShapeMismatch}; +use ndarray::{Array, ArrayBase, Axis, Data, Dimension, Ix1, RemoveAxis}; use num_integer::IterBinomial; use num_traits::{Float, FromPrimitive, Zero}; -use std::ops::{Add, Div}; +use std::ops::{Add, AddAssign, Div, Mul}; impl SummaryStatisticsExt for ArrayBase where @@ -24,18 +24,153 @@ where } } + fn weighted_mean(&self, weights: &Self) -> Result + where + A: Copy + Div + Mul + Zero, + { + return_err_if_empty!(self); + let weighted_sum = self.weighted_sum(weights)?; + Ok(weighted_sum / weights.sum()) + } + + fn weighted_sum(&self, weights: &ArrayBase) -> Result + where + A: Copy + Mul + Zero, + { + return_err_unless_same_shape!(self, weights); + Ok(self + .iter() + .zip(weights) + .fold(A::zero(), |acc, (&d, &w)| acc + d * w)) + } + + fn weighted_mean_axis( + &self, + axis: Axis, + weights: &ArrayBase, + ) -> Result, MultiInputError> + where + A: Copy + Div + Mul + Zero, + D: RemoveAxis, + { + return_err_if_empty!(self); + let mut weighted_sum = self.weighted_sum_axis(axis, weights)?; + let weights_sum = weights.sum(); + weighted_sum.mapv_inplace(|v| v / weights_sum); + Ok(weighted_sum) + } + + fn weighted_sum_axis( + &self, + axis: Axis, + weights: &ArrayBase, + ) -> Result, MultiInputError> + where + A: Copy + Mul + Zero, + D: RemoveAxis, + { + if self.shape()[axis.index()] != weights.len() { + return Err(MultiInputError::ShapeMismatch(ShapeMismatch { + first_shape: self.shape().to_vec(), + second_shape: weights.shape().to_vec(), + })); + } + + // We could use `lane.weighted_sum` here, but we're avoiding 2 + // conditions and an unwrap per lane. 
+ Ok(self.map_axis(axis, |lane| { + lane.iter() + .zip(weights) + .fold(A::zero(), |acc, (&d, &w)| acc + d * w) + })) + } + fn harmonic_mean(&self) -> Result where A: Float + FromPrimitive, { - self.map(|x| x.recip()).mean().map(|x| x.recip()) + self.map(|x| x.recip()) + .mean() + .map(|x| x.recip()) + .ok_or(EmptyInput) } fn geometric_mean(&self) -> Result where A: Float + FromPrimitive, { - self.map(|x| x.ln()).mean().map(|x| x.exp()) + self.map(|x| x.ln()) + .mean() + .map(|x| x.exp()) + .ok_or(EmptyInput) + } + + fn weighted_var(&self, weights: &Self, ddof: A) -> Result + where + A: AddAssign + Float + FromPrimitive, + { + return_err_if_empty!(self); + return_err_unless_same_shape!(self, weights); + let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail."); + let one = A::from_usize(1).expect("Converting 1 to `A` must not fail."); + assert!( + !(ddof < zero || ddof > one), + "`ddof` must not be less than zero or greater than one", + ); + inner_weighted_var(self, weights, ddof, zero) + } + + fn weighted_std(&self, weights: &Self, ddof: A) -> Result + where + A: AddAssign + Float + FromPrimitive, + { + Ok(self.weighted_var(weights, ddof)?.sqrt()) + } + + fn weighted_var_axis( + &self, + axis: Axis, + weights: &ArrayBase, + ddof: A, + ) -> Result, MultiInputError> + where + A: AddAssign + Float + FromPrimitive, + D: RemoveAxis, + { + return_err_if_empty!(self); + if self.shape()[axis.index()] != weights.len() { + return Err(MultiInputError::ShapeMismatch(ShapeMismatch { + first_shape: self.shape().to_vec(), + second_shape: weights.shape().to_vec(), + })); + } + let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail."); + let one = A::from_usize(1).expect("Converting 1 to `A` must not fail."); + assert!( + !(ddof < zero || ddof > one), + "`ddof` must not be less than zero or greater than one", + ); + + // `weights` must be a view because `lane` is a view in this context. + let weights = weights.view(); + Ok(self.map_axis(axis, |lane| { + inner_weighted_var(&lane, &weights, ddof, zero).unwrap() + })) + } + + fn weighted_std_axis( + &self, + axis: Axis, + weights: &ArrayBase, + ddof: A, + ) -> Result, MultiInputError> + where + A: AddAssign + Float + FromPrimitive, + D: RemoveAxis, + { + Ok(self + .weighted_var_axis(axis, weights, ddof)? + .mapv_into(|x| x.sqrt())) } fn kurtosis(&self) -> Result @@ -109,6 +244,30 @@ where private_impl! {} } +/// Private function for `weighted_var` without conditions and asserts. 
+fn inner_weighted_var( + arr: &ArrayBase, + weights: &ArrayBase, + ddof: A, + zero: A, +) -> Result +where + S: Data, + A: AddAssign + Float + FromPrimitive, + D: Dimension, +{ + let mut weight_sum = zero; + let mut mean = zero; + let mut s = zero; + for (&x, &w) in arr.iter().zip(weights.iter()) { + weight_sum += w; + let x_minus_mean = x - mean; + mean += (w / weight_sum) * x_minus_mean; + s += w * x_minus_mean * (x - mean); + } + Ok(s / (weight_sum - ddof)) +} + /// Returns a vector containing all moments of the array elements up to /// *order*, where the *p*-th moment is defined as: /// @@ -129,7 +288,7 @@ where { let n_elements = A::from_usize(a.len()).expect("Converting number of elements to `A` must not fail"); - let order = order as i32; + let order = i32::from(order); // When k=0, we are raising each element to the 0th power // No need to waste CPU cycles going through the array @@ -184,171 +343,3 @@ where } result } - -#[cfg(test)] -mod tests { - use super::SummaryStatisticsExt; - use crate::errors::EmptyInput; - use approx::assert_abs_diff_eq; - use ndarray::{array, Array, Array1}; - use ndarray_rand::RandomExt; - use noisy_float::types::N64; - use rand::distributions::Uniform; - use std::f64; - - #[test] - fn test_means_with_nan_values() { - let a = array![f64::NAN, 1.]; - assert!(a.mean().unwrap().is_nan()); - assert!(a.harmonic_mean().unwrap().is_nan()); - assert!(a.geometric_mean().unwrap().is_nan()); - } - - #[test] - fn test_means_with_empty_array_of_floats() { - let a: Array1 = array![]; - assert_eq!(a.mean(), Err(EmptyInput)); - assert_eq!(a.harmonic_mean(), Err(EmptyInput)); - assert_eq!(a.geometric_mean(), Err(EmptyInput)); - } - - #[test] - fn test_means_with_empty_array_of_noisy_floats() { - let a: Array1 = array![]; - assert_eq!(a.mean(), Err(EmptyInput)); - assert_eq!(a.harmonic_mean(), Err(EmptyInput)); - assert_eq!(a.geometric_mean(), Err(EmptyInput)); - } - - #[test] - fn test_means_with_array_of_floats() { - let a: Array1 = array![ - 0.99889651, 0.0150731, 0.28492482, 0.83819218, 0.48413156, 0.80710412, 0.41762936, - 0.22879429, 0.43997224, 0.23831807, 0.02416466, 0.6269962, 0.47420614, 0.56275487, - 0.78995021, 0.16060581, 0.64635041, 0.34876609, 0.78543249, 0.19938356, 0.34429457, - 0.88072369, 0.17638164, 0.60819363, 0.250392, 0.69912532, 0.78855523, 0.79140914, - 0.85084218, 0.31839879, 0.63381769, 0.22421048, 0.70760302, 0.99216018, 0.80199153, - 0.19239188, 0.61356023, 0.31505352, 0.06120481, 0.66417377, 0.63608897, 0.84959691, - 0.43599069, 0.77867775, 0.88267754, 0.83003623, 0.67016118, 0.67547638, 0.65220036, - 0.68043427 - ]; - // Computed using NumPy - let expected_mean = 0.5475494059146699; - // Computed using SciPy - let expected_harmonic_mean = 0.21790094950226022; - // Computed using SciPy - let expected_geometric_mean = 0.4345897639796527; - - assert_abs_diff_eq!(a.mean().unwrap(), expected_mean, epsilon = 1e-9); - assert_abs_diff_eq!( - a.harmonic_mean().unwrap(), - expected_harmonic_mean, - epsilon = 1e-7 - ); - assert_abs_diff_eq!( - a.geometric_mean().unwrap(), - expected_geometric_mean, - epsilon = 1e-12 - ); - } - - #[test] - fn test_central_moment_with_empty_array_of_floats() { - let a: Array1 = array![]; - for order in 0..=3 { - assert_eq!(a.central_moment(order), Err(EmptyInput)); - assert_eq!(a.central_moments(order), Err(EmptyInput)); - } - } - - #[test] - fn test_zeroth_central_moment_is_one() { - let n = 50; - let bound: f64 = 200.; - let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs())); - 
assert_eq!(a.central_moment(0).unwrap(), 1.); - } - - #[test] - fn test_first_central_moment_is_zero() { - let n = 50; - let bound: f64 = 200.; - let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs())); - assert_eq!(a.central_moment(1).unwrap(), 0.); - } - - #[test] - fn test_central_moments() { - let a: Array1 = array![ - 0.07820559, 0.5026185, 0.80935324, 0.39384033, 0.9483038, 0.62516215, 0.90772261, - 0.87329831, 0.60267392, 0.2960298, 0.02810356, 0.31911966, 0.86705506, 0.96884832, - 0.2222465, 0.42162446, 0.99909868, 0.47619762, 0.91696979, 0.9972741, 0.09891734, - 0.76934818, 0.77566862, 0.7692585, 0.2235759, 0.44821286, 0.79732186, 0.04804275, - 0.87863238, 0.1111003, 0.6653943, 0.44386445, 0.2133176, 0.39397086, 0.4374617, - 0.95896624, 0.57850146, 0.29301706, 0.02329879, 0.2123203, 0.62005503, 0.996492, - 0.5342986, 0.97822099, 0.5028445, 0.6693834, 0.14256682, 0.52724704, 0.73482372, - 0.1809703, - ]; - // Computed using scipy.stats.moment - let expected_moments = vec![ - 1., - 0., - 0.09339920262960291, - -0.0026849636727735186, - 0.015403769257729755, - -0.001204176487006564, - 0.002976822584939186, - ]; - for (order, expected_moment) in expected_moments.iter().enumerate() { - assert_abs_diff_eq!( - a.central_moment(order as u16).unwrap(), - expected_moment, - epsilon = 1e-8 - ); - } - } - - #[test] - fn test_bulk_central_moments() { - // Test that the bulk method is coherent with the non-bulk method - let n = 50; - let bound: f64 = 200.; - let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs())); - let order = 10; - let central_moments = a.central_moments(order).unwrap(); - for i in 0..=order { - assert_eq!(a.central_moment(i).unwrap(), central_moments[i as usize]); - } - } - - #[test] - fn test_kurtosis_and_skewness_is_none_with_empty_array_of_floats() { - let a: Array1 = array![]; - assert_eq!(a.skewness(), Err(EmptyInput)); - assert_eq!(a.kurtosis(), Err(EmptyInput)); - } - - #[test] - fn test_kurtosis_and_skewness() { - let a: Array1 = array![ - 0.33310096, 0.98757449, 0.9789796, 0.96738114, 0.43545674, 0.06746873, 0.23706562, - 0.04241815, 0.38961714, 0.52421271, 0.93430327, 0.33911604, 0.05112372, 0.5013455, - 0.05291507, 0.62511183, 0.20749633, 0.22132433, 0.14734804, 0.51960608, 0.00449208, - 0.4093339, 0.2237519, 0.28070469, 0.7887231, 0.92224523, 0.43454188, 0.18335111, - 0.08646856, 0.87979847, 0.25483457, 0.99975627, 0.52712442, 0.41163279, 0.85162594, - 0.52618733, 0.75815023, 0.30640695, 0.14205781, 0.59695813, 0.851331, 0.39524328, - 0.73965373, 0.4007615, 0.02133069, 0.92899207, 0.79878191, 0.38947334, 0.22042183, - 0.77768353, - ]; - // Computed using scipy.stats.kurtosis(a, fisher=False) - let expected_kurtosis = 1.821933711687523; - // Computed using scipy.stats.skew - let expected_skewness = 0.2604785422878771; - - let kurtosis = a.kurtosis().unwrap(); - let skewness = a.skewness().unwrap(); - - assert_abs_diff_eq!(kurtosis, expected_kurtosis, epsilon = 1e-12); - assert_abs_diff_eq!(skewness, expected_skewness, epsilon = 1e-8); - } -} diff --git a/src/summary_statistics/mod.rs b/src/summary_statistics/mod.rs index 8351ff82..1f8fe000 100644 --- a/src/summary_statistics/mod.rs +++ b/src/summary_statistics/mod.rs @@ -1,8 +1,8 @@ //! Summary statistics (e.g. mean, variance, etc.). 
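+//!
+//! A minimal usage sketch of the weighted estimators documented below (the data and
+//! weights here are illustrative):
+//!
+//! ```
+//! use ndarray::{array, Axis};
+//! use ndarray_stats::SummaryStatisticsExt;
+//!
+//! let data = array![[1., 2., 3.], [4., 5., 6.]];
+//! // Weights need not be normalized: the `weighted_mean*` methods divide by their
+//! // sum, while the `weighted_sum*` methods compute a plain dot product.
+//! let weights = array![0.2, 0.3, 0.5];
+//!
+//! // One weighted mean per row, taken along the column axis.
+//! let row_means = data.weighted_mean_axis(Axis(1), &weights).unwrap();
+//! assert_eq!(row_means.len(), 2);
+//! assert!((row_means[0] - 2.3).abs() < 1e-12); // 0.2·1 + 0.3·2 + 0.5·3
+//! ```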
-use crate::errors::EmptyInput; -use ndarray::{Data, Dimension}; +use crate::errors::{EmptyInput, MultiInputError}; +use ndarray::{Array, ArrayBase, Axis, Data, Dimension, Ix1, RemoveAxis}; use num_traits::{Float, FromPrimitive, Zero}; -use std::ops::{Add, Div}; +use std::ops::{Add, AddAssign, Div, Mul}; /// Extension trait for `ArrayBase` providing methods /// to compute several summary statistics (e.g. mean, variance, etc.). @@ -28,6 +28,100 @@ where where A: Clone + FromPrimitive + Add + Div + Zero; + /// Returns the [`arithmetic weighted mean`] x̅ of all elements in the array. Use `weighted_sum` + /// if the `weights` are normalized (they sum up to 1.0). + /// + /// ```text + /// n + /// ∑ wᵢxᵢ + /// i=1 + /// x̅ = ――――――――― + /// n + /// ∑ wᵢ + /// i=1 + /// ``` + /// + /// **Panics** if division by zero panics for type A. + /// + /// The following **errors** may be returned: + /// + /// * `MultiInputError::EmptyInput` if `self` is empty + /// * `MultiInputError::ShapeMismatch` if `self` and `weights` don't have the same shape + /// + /// [`arithmetic weighted mean`] https://en.wikipedia.org/wiki/Weighted_arithmetic_mean + fn weighted_mean(&self, weights: &Self) -> Result + where + A: Copy + Div + Mul + Zero; + + /// Returns the weighted sum of all elements in the array, that is, the dot product of the + /// arrays `self` and `weights`. Equivalent to `weighted_mean` if the `weights` are normalized. + /// + /// ```text + /// n + /// x̅ = ∑ wᵢxᵢ + /// i=1 + /// ``` + /// + /// The following **errors** may be returned: + /// + /// * `MultiInputError::ShapeMismatch` if `self` and `weights` don't have the same shape + fn weighted_sum(&self, weights: &Self) -> Result + where + A: Copy + Mul + Zero; + + /// Returns the [`arithmetic weighted mean`] x̅ along `axis`. Use `weighted_mean_axis ` if the + /// `weights` are normalized. + /// + /// ```text + /// n + /// ∑ wᵢxᵢ + /// i=1 + /// x̅ = ――――――――― + /// n + /// ∑ wᵢ + /// i=1 + /// ``` + /// + /// **Panics** if `axis` is out of bounds. + /// + /// The following **errors** may be returned: + /// + /// * `MultiInputError::EmptyInput` if `self` is empty + /// * `MultiInputError::ShapeMismatch` if `self` length along axis is not equal to `weights` length + /// + /// [`arithmetic weighted mean`] https://en.wikipedia.org/wiki/Weighted_arithmetic_mean + fn weighted_mean_axis( + &self, + axis: Axis, + weights: &ArrayBase, + ) -> Result, MultiInputError> + where + A: Copy + Div + Mul + Zero, + D: RemoveAxis; + + /// Returns the weighted sum along `axis`, that is, the dot product of `weights` and each lane + /// of `self` along `axis`. Equivalent to `weighted_mean_axis` if the `weights` are normalized. + /// + /// ```text + /// n + /// x̅ = ∑ wᵢxᵢ + /// i=1 + /// ``` + /// + /// **Panics** if `axis` is out of bounds. + /// + /// The following **errors** may be returned + /// + /// * `MultiInputError::ShapeMismatch` if `self` and `weights` don't have the same shape + fn weighted_sum_axis( + &self, + axis: Axis, + weights: &ArrayBase, + ) -> Result, MultiInputError> + where + A: Copy + Mul + Zero, + D: RemoveAxis; + /// Returns the [`harmonic mean`] `HM(X)` of all elements in the array: /// /// ```text @@ -62,6 +156,82 @@ where where A: Float + FromPrimitive; + /// Return weighted variance of all elements in the array. + /// + /// The weighted variance is computed using the [`West, D. H. D.`] incremental algorithm. + /// Equivalent to `var_axis` if the `weights` are normalized. + /// + /// The parameter `ddof` specifies the "delta degrees of freedom". 
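+    /// The divisor used in the computation is `weight_sum - ddof`, where `weight_sum`
+    /// is the sum of the `weights`.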
For example, to calculate the + /// population variance, use `ddof = 0`, or to calculate the sample variance, use `ddof = 1`. + /// + /// **Panics** if `ddof` is less than zero or greater than one, or if `axis` is out of bounds, + /// or if `A::from_usize()` fails for zero or one. + /// + /// [`West, D. H. D.`]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm + fn weighted_var(&self, weights: &Self, ddof: A) -> Result + where + A: AddAssign + Float + FromPrimitive; + + /// Return weighted standard deviation of all elements in the array. + /// + /// The weighted standard deviation is computed using the [`West, D. H. D.`] incremental + /// algorithm. Equivalent to `var_axis` if the `weights` are normalized. + /// + /// The parameter `ddof` specifies the "delta degrees of freedom". For example, to calculate the + /// population variance, use `ddof = 0`, or to calculate the sample variance, use `ddof = 1`. + /// + /// **Panics** if `ddof` is less than zero or greater than one, or if `axis` is out of bounds, + /// or if `A::from_usize()` fails for zero or one. + /// + /// [`West, D. H. D.`]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm + fn weighted_std(&self, weights: &Self, ddof: A) -> Result + where + A: AddAssign + Float + FromPrimitive; + + /// Return weighted variance along `axis`. + /// + /// The weighted variance is computed using the [`West, D. H. D.`] incremental algorithm. + /// Equivalent to `var_axis` if the `weights` are normalized. + /// + /// The parameter `ddof` specifies the "delta degrees of freedom". For example, to calculate the + /// population variance, use `ddof = 0`, or to calculate the sample variance, use `ddof = 1`. + /// + /// **Panics** if `ddof` is less than zero or greater than one, or if `axis` is out of bounds, + /// or if `A::from_usize()` fails for zero or one. + /// + /// [`West, D. H. D.`]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm + fn weighted_var_axis( + &self, + axis: Axis, + weights: &ArrayBase, + ddof: A, + ) -> Result, MultiInputError> + where + A: AddAssign + Float + FromPrimitive, + D: RemoveAxis; + + /// Return weighted standard deviation along `axis`. + /// + /// The weighted standard deviation is computed using the [`West, D. H. D.`] incremental + /// algorithm. Equivalent to `var_axis` if the `weights` are normalized. + /// + /// The parameter `ddof` specifies the "delta degrees of freedom". For example, to calculate the + /// population variance, use `ddof = 0`, or to calculate the sample variance, use `ddof = 1`. + /// + /// **Panics** if `ddof` is less than zero or greater than one, or if `axis` is out of bounds, + /// or if `A::from_usize()` fails for zero or one. + /// + /// [`West, D. H. 
D.`]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_incremental_algorithm + fn weighted_std_axis( + &self, + axis: Axis, + weights: &ArrayBase, + ddof: A, + ) -> Result, MultiInputError> + where + A: AddAssign + Float + FromPrimitive, + D: RemoveAxis; + /// Returns the [kurtosis] `Kurt[X]` of all elements in the array: /// /// ```text diff --git a/tests/deviation.rs b/tests/deviation.rs new file mode 100644 index 00000000..54eb7af6 --- /dev/null +++ b/tests/deviation.rs @@ -0,0 +1,270 @@ +use ndarray_stats::errors::{MultiInputError, ShapeMismatch}; +use ndarray_stats::DeviationExt; + +use approx::assert_abs_diff_eq; +use ndarray::{array, Array1}; +use num_bigint::BigInt; +use num_traits::Float; + +use std::f64; + +#[test] +fn test_count_eq() -> Result<(), MultiInputError> { + let a = array![0., 0.]; + let b = array![1., 0.]; + let c = array![0., 1.]; + let d = array![1., 1.]; + + assert_eq!(a.count_eq(&a)?, 2); + assert_eq!(a.count_eq(&b)?, 1); + assert_eq!(a.count_eq(&c)?, 1); + assert_eq!(a.count_eq(&d)?, 0); + + Ok(()) +} + +#[test] +fn test_count_neq() -> Result<(), MultiInputError> { + let a = array![0., 0.]; + let b = array![1., 0.]; + let c = array![0., 1.]; + let d = array![1., 1.]; + + assert_eq!(a.count_neq(&a)?, 0); + assert_eq!(a.count_neq(&b)?, 1); + assert_eq!(a.count_neq(&c)?, 1); + assert_eq!(a.count_neq(&d)?, 2); + + Ok(()) +} + +#[test] +fn test_sq_l2_dist() -> Result<(), MultiInputError> { + let a = array![0., 1., 4., 2.]; + let b = array![1., 1., 2., 4.]; + + assert_eq!(a.sq_l2_dist(&b)?, 9.); + + Ok(()) +} + +#[test] +fn test_l2_dist() -> Result<(), MultiInputError> { + let a = array![0., 1., 4., 2.]; + let b = array![1., 1., 2., 4.]; + + assert_eq!(a.l2_dist(&b)?, 3.); + + Ok(()) +} + +#[test] +fn test_l1_dist() -> Result<(), MultiInputError> { + let a = array![0., 1., 4., 2.]; + let b = array![1., 1., 2., 4.]; + + assert_eq!(a.l1_dist(&b)?, 5.); + + Ok(()) +} + +#[test] +fn test_linf_dist() -> Result<(), MultiInputError> { + let a = array![0., 0.]; + let b = array![1., 0.]; + let c = array![1., 2.]; + + assert_eq!(a.linf_dist(&a)?, 0.); + + assert_eq!(a.linf_dist(&b)?, 1.); + assert_eq!(b.linf_dist(&a)?, 1.); + + assert_eq!(a.linf_dist(&c)?, 2.); + assert_eq!(c.linf_dist(&a)?, 2.); + + Ok(()) +} + +#[test] +fn test_mean_abs_err() -> Result<(), MultiInputError> { + let a = array![1., 1.]; + let b = array![3., 5.]; + + assert_eq!(a.mean_abs_err(&a)?, 0.); + assert_eq!(a.mean_abs_err(&b)?, 3.); + assert_eq!(b.mean_abs_err(&a)?, 3.); + + Ok(()) +} + +#[test] +fn test_mean_sq_err() -> Result<(), MultiInputError> { + let a = array![1., 1.]; + let b = array![3., 5.]; + + assert_eq!(a.mean_sq_err(&a)?, 0.); + assert_eq!(a.mean_sq_err(&b)?, 10.); + assert_eq!(b.mean_sq_err(&a)?, 10.); + + Ok(()) +} + +#[test] +fn test_root_mean_sq_err() -> Result<(), MultiInputError> { + let a = array![1., 1.]; + let b = array![3., 5.]; + + assert_eq!(a.root_mean_sq_err(&a)?, 0.); + assert_abs_diff_eq!(a.root_mean_sq_err(&b)?, 10.0.sqrt()); + assert_abs_diff_eq!(b.root_mean_sq_err(&a)?, 10.0.sqrt()); + + Ok(()) +} + +#[test] +fn test_peak_signal_to_noise_ratio() -> Result<(), MultiInputError> { + let a = array![1., 1.]; + assert!(a.peak_signal_to_noise_ratio(&a, 1.)?.is_infinite()); + + let a = array![1., 2., 3., 4., 5., 6., 7.]; + let b = array![1., 3., 3., 4., 6., 7., 8.]; + let maxv = 8.; + let expected = 20. * Float::log10(maxv) - 10. 
* Float::log10(a.mean_sq_err(&b)?); + let actual = a.peak_signal_to_noise_ratio(&b, maxv)?; + + assert_abs_diff_eq!(actual, expected); + + Ok(()) +} + +#[test] +fn test_deviations_with_n_by_m_ints() -> Result<(), MultiInputError> { + let a = array![[0, 1], [4, 2]]; + let b = array![[1, 1], [2, 4]]; + + assert_eq!(a.count_eq(&a)?, 4); + assert_eq!(a.count_neq(&a)?, 0); + + assert_eq!(a.sq_l2_dist(&b)?, 9); + assert_eq!(a.l2_dist(&b)?, 3.); + assert_eq!(a.l1_dist(&b)?, 5); + assert_eq!(a.linf_dist(&b)?, 2); + + assert_abs_diff_eq!(a.mean_abs_err(&b)?, 1.25); + assert_abs_diff_eq!(a.mean_sq_err(&b)?, 2.25); + assert_abs_diff_eq!(a.root_mean_sq_err(&b)?, 1.5); + assert_abs_diff_eq!(a.peak_signal_to_noise_ratio(&b, 4)?, 8.519374645445623); + + Ok(()) +} + +#[test] +fn test_deviations_with_empty_receiver() { + let a: Array1 = array![]; + let b: Array1 = array![1.]; + + assert_eq!(a.count_eq(&b), Err(MultiInputError::EmptyInput)); + assert_eq!(a.count_neq(&b), Err(MultiInputError::EmptyInput)); + + assert_eq!(a.sq_l2_dist(&b), Err(MultiInputError::EmptyInput)); + assert_eq!(a.l2_dist(&b), Err(MultiInputError::EmptyInput)); + assert_eq!(a.l1_dist(&b), Err(MultiInputError::EmptyInput)); + assert_eq!(a.linf_dist(&b), Err(MultiInputError::EmptyInput)); + + assert_eq!(a.mean_abs_err(&b), Err(MultiInputError::EmptyInput)); + assert_eq!(a.mean_sq_err(&b), Err(MultiInputError::EmptyInput)); + assert_eq!(a.root_mean_sq_err(&b), Err(MultiInputError::EmptyInput)); + assert_eq!( + a.peak_signal_to_noise_ratio(&b, 0.), + Err(MultiInputError::EmptyInput) + ); +} + +#[test] +fn test_deviations_do_not_panic_if_nans() -> Result<(), MultiInputError> { + let a: Array1 = array![1., f64::NAN, 3., f64::NAN]; + let b: Array1 = array![1., f64::NAN, 3., 4.]; + + assert_eq!(a.count_eq(&b)?, 2); + assert_eq!(a.count_neq(&b)?, 2); + + assert!(a.sq_l2_dist(&b)?.is_nan()); + assert!(a.l2_dist(&b)?.is_nan()); + assert!(a.l1_dist(&b)?.is_nan()); + assert_eq!(a.linf_dist(&b)?, 0.); + + assert!(a.mean_abs_err(&b)?.is_nan()); + assert!(a.mean_sq_err(&b)?.is_nan()); + assert!(a.root_mean_sq_err(&b)?.is_nan()); + assert!(a.peak_signal_to_noise_ratio(&b, 0.)?.is_nan()); + + Ok(()) +} + +#[test] +fn test_deviations_with_empty_argument() { + let a: Array1 = array![1.]; + let b: Array1 = array![]; + + let shape_mismatch_err = MultiInputError::ShapeMismatch(ShapeMismatch { + first_shape: a.shape().to_vec(), + second_shape: b.shape().to_vec(), + }); + let expected_err_usize = Err(shape_mismatch_err.clone()); + let expected_err_f64 = Err(shape_mismatch_err); + + assert_eq!(a.count_eq(&b), expected_err_usize); + assert_eq!(a.count_neq(&b), expected_err_usize); + + assert_eq!(a.sq_l2_dist(&b), expected_err_f64); + assert_eq!(a.l2_dist(&b), expected_err_f64); + assert_eq!(a.l1_dist(&b), expected_err_f64); + assert_eq!(a.linf_dist(&b), expected_err_f64); + + assert_eq!(a.mean_abs_err(&b), expected_err_f64); + assert_eq!(a.mean_sq_err(&b), expected_err_f64); + assert_eq!(a.root_mean_sq_err(&b), expected_err_f64); + assert_eq!(a.peak_signal_to_noise_ratio(&b, 0.), expected_err_f64); +} + +#[test] +fn test_deviations_with_non_copyable() -> Result<(), MultiInputError> { + let a: Array1 = array![0.into(), 1.into(), 4.into(), 2.into()]; + let b: Array1 = array![1.into(), 1.into(), 2.into(), 4.into()]; + + assert_eq!(a.count_eq(&a)?, 4); + assert_eq!(a.count_neq(&a)?, 0); + + assert_eq!(a.sq_l2_dist(&b)?, 9.into()); + assert_eq!(a.l2_dist(&b)?, 3.); + assert_eq!(a.l1_dist(&b)?, 5.into()); + assert_eq!(a.linf_dist(&b)?, 2.into()); + + 
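+    // The error metrics below are computed as `f64` even for `BigInt` inputs, hence
+    // the approximate comparisons.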
assert_abs_diff_eq!(a.mean_abs_err(&b)?, 1.25); + assert_abs_diff_eq!(a.mean_sq_err(&b)?, 2.25); + assert_abs_diff_eq!(a.root_mean_sq_err(&b)?, 1.5); + assert_abs_diff_eq!( + a.peak_signal_to_noise_ratio(&b, 4.into())?, + 8.519374645445623 + ); + + Ok(()) +} + +#[test] +fn test_deviation_computation_for_mixed_ownership() { + // It's enough to check that the code compiles! + let a = array![0., 0.]; + let b = array![1., 0.]; + + let _ = a.count_eq(&b.view()); + let _ = a.count_neq(&b.view()); + let _ = a.l2_dist(&b.view()); + let _ = a.sq_l2_dist(&b.view()); + let _ = a.l1_dist(&b.view()); + let _ = a.linf_dist(&b.view()); + let _ = a.mean_abs_err(&b.view()); + let _ = a.mean_sq_err(&b.view()); + let _ = a.root_mean_sq_err(&b.view()); + let _ = a.peak_signal_to_noise_ratio(&b.view(), 10.); +} diff --git a/tests/maybe_nan.rs b/tests/maybe_nan.rs new file mode 100644 index 00000000..074352ae --- /dev/null +++ b/tests/maybe_nan.rs @@ -0,0 +1,31 @@ +use ndarray::prelude::*; +use ndarray_stats::MaybeNan; +use noisy_float::types::{n64, N64}; + +#[test] +fn remove_nan_mut_nonstandard_layout() { + fn eq_unordered(mut a: Vec, mut b: Vec) -> bool { + a.sort(); + b.sort(); + a == b + } + let a = aview1(&[1., 2., f64::NAN, f64::NAN, 3., f64::NAN, 4., 5.]); + { + let mut a = a.to_owned(); + let v = f64::remove_nan_mut(a.slice_mut(s![..;2])); + assert!(eq_unordered(v.to_vec(), vec![n64(1.), n64(3.), n64(4.)])); + } + { + let mut a = a.to_owned(); + let v = f64::remove_nan_mut(a.slice_mut(s![..;-1])); + assert!(eq_unordered( + v.to_vec(), + vec![n64(5.), n64(4.), n64(3.), n64(2.), n64(1.)], + )); + } + { + let mut a = a.to_owned(); + let v = f64::remove_nan_mut(a.slice_mut(s![..;-2])); + assert!(eq_unordered(v.to_vec(), vec![n64(5.), n64(2.)])); + } +} diff --git a/tests/quantile.rs b/tests/quantile.rs index 4b312a07..9d58071f 100644 --- a/tests/quantile.rs +++ b/tests/quantile.rs @@ -279,7 +279,7 @@ fn test_midpoint_overflow() { #[quickcheck] fn test_quantiles_mut(xs: Vec) -> bool { - let v = Array::from_vec(xs.clone()); + let v = Array::from(xs.clone()); // Unordered list of quantile indexes to look up, with a duplicate let quantile_indexes = Array::from(vec![ diff --git a/tests/sort.rs b/tests/sort.rs index af2717c4..b2bd12f1 100644 --- a/tests/sort.rs +++ b/tests/sort.rs @@ -49,7 +49,7 @@ fn test_sorted_get_many_mut(mut xs: Vec) -> bool { if n == 0 { true } else { - let mut v = Array::from_vec(xs.clone()); + let mut v = Array::from(xs.clone()); // Insert each index twice, to get a set of indexes with duplicates, not sorted let mut indexes: Vec = (0..n).into_iter().collect(); @@ -78,7 +78,7 @@ fn test_sorted_get_mut_as_sorting_algorithm(mut xs: Vec) -> bool { if n == 0 { true } else { - let mut v = Array::from_vec(xs.clone()); + let mut v = Array::from(xs.clone()); let sorted_v: Vec<_> = (0..n).map(|i| v.get_from_sorted_mut(i)).collect(); xs.sort(); xs == sorted_v diff --git a/tests/summary_statistics.rs b/tests/summary_statistics.rs new file mode 100644 index 00000000..5269e332 --- /dev/null +++ b/tests/summary_statistics.rs @@ -0,0 +1,414 @@ +use approx::{abs_diff_eq, assert_abs_diff_eq}; +use ndarray::{arr0, array, Array, Array1, Array2, Axis}; +use ndarray_rand::rand_distr::Uniform; +use ndarray_rand::RandomExt; +use ndarray_stats::{ + errors::{EmptyInput, MultiInputError, ShapeMismatch}, + SummaryStatisticsExt, +}; +use noisy_float::types::N64; +use quickcheck::{quickcheck, TestResult}; +use std::f64; + +#[test] +fn test_with_nan_values() { + let a = array![f64::NAN, 1.]; + let weights = 
array![1.0, f64::NAN]; + assert!(a.mean().unwrap().is_nan()); + assert!(a.weighted_mean(&weights).unwrap().is_nan()); + assert!(a.weighted_sum(&weights).unwrap().is_nan()); + assert!(a + .weighted_mean_axis(Axis(0), &weights) + .unwrap() + .into_scalar() + .is_nan()); + assert!(a + .weighted_sum_axis(Axis(0), &weights) + .unwrap() + .into_scalar() + .is_nan()); + assert!(a.harmonic_mean().unwrap().is_nan()); + assert!(a.geometric_mean().unwrap().is_nan()); + assert!(a.weighted_var(&weights, 0.0).unwrap().is_nan()); + assert!(a.weighted_std(&weights, 0.0).unwrap().is_nan()); + assert!(a + .weighted_var_axis(Axis(0), &weights, 0.0) + .unwrap() + .into_scalar() + .is_nan()); + assert!(a + .weighted_std_axis(Axis(0), &weights, 0.0) + .unwrap() + .into_scalar() + .is_nan()); +} + +#[test] +fn test_with_empty_array_of_floats() { + let a: Array1 = array![]; + let weights = array![1.0]; + assert_eq!(a.mean(), None); + assert_eq!(a.weighted_mean(&weights), Err(MultiInputError::EmptyInput)); + assert_eq!( + a.weighted_mean_axis(Axis(0), &weights), + Err(MultiInputError::EmptyInput) + ); + assert_eq!(a.harmonic_mean(), Err(EmptyInput)); + assert_eq!(a.geometric_mean(), Err(EmptyInput)); + assert_eq!( + a.weighted_var(&weights, 0.0), + Err(MultiInputError::EmptyInput) + ); + assert_eq!( + a.weighted_std(&weights, 0.0), + Err(MultiInputError::EmptyInput) + ); + assert_eq!( + a.weighted_var_axis(Axis(0), &weights, 0.0), + Err(MultiInputError::EmptyInput) + ); + assert_eq!( + a.weighted_std_axis(Axis(0), &weights, 0.0), + Err(MultiInputError::EmptyInput) + ); + + // The sum methods accept empty arrays + assert_eq!(a.weighted_sum(&array![]), Ok(0.0)); + assert_eq!(a.weighted_sum_axis(Axis(0), &array![]), Ok(arr0(0.0))); +} + +#[test] +fn test_with_empty_array_of_noisy_floats() { + let a: Array1 = array![]; + let weights = array![]; + assert_eq!(a.mean(), None); + assert_eq!(a.weighted_mean(&weights), Err(MultiInputError::EmptyInput)); + assert_eq!( + a.weighted_mean_axis(Axis(0), &weights), + Err(MultiInputError::EmptyInput) + ); + assert_eq!(a.harmonic_mean(), Err(EmptyInput)); + assert_eq!(a.geometric_mean(), Err(EmptyInput)); + assert_eq!( + a.weighted_var(&weights, N64::new(0.0)), + Err(MultiInputError::EmptyInput) + ); + assert_eq!( + a.weighted_std(&weights, N64::new(0.0)), + Err(MultiInputError::EmptyInput) + ); + assert_eq!( + a.weighted_var_axis(Axis(0), &weights, N64::new(0.0)), + Err(MultiInputError::EmptyInput) + ); + assert_eq!( + a.weighted_std_axis(Axis(0), &weights, N64::new(0.0)), + Err(MultiInputError::EmptyInput) + ); + + // The sum methods accept empty arrays + assert_eq!(a.weighted_sum(&weights), Ok(N64::new(0.0))); + assert_eq!( + a.weighted_sum_axis(Axis(0), &weights), + Ok(arr0(N64::new(0.0))) + ); +} + +#[test] +fn test_with_array_of_floats() { + let a: Array1 = array![ + 0.99889651, 0.0150731, 0.28492482, 0.83819218, 0.48413156, 0.80710412, 0.41762936, + 0.22879429, 0.43997224, 0.23831807, 0.02416466, 0.6269962, 0.47420614, 0.56275487, + 0.78995021, 0.16060581, 0.64635041, 0.34876609, 0.78543249, 0.19938356, 0.34429457, + 0.88072369, 0.17638164, 0.60819363, 0.250392, 0.69912532, 0.78855523, 0.79140914, + 0.85084218, 0.31839879, 0.63381769, 0.22421048, 0.70760302, 0.99216018, 0.80199153, + 0.19239188, 0.61356023, 0.31505352, 0.06120481, 0.66417377, 0.63608897, 0.84959691, + 0.43599069, 0.77867775, 0.88267754, 0.83003623, 0.67016118, 0.67547638, 0.65220036, + 0.68043427 + ]; + // Computed using NumPy + let expected_mean = 0.5475494059146699; + let expected_weighted_mean = 
0.6782420496397121; + let expected_weighted_var = 0.04306695637838332; + // Computed using SciPy + let expected_harmonic_mean = 0.21790094950226022; + let expected_geometric_mean = 0.4345897639796527; + + assert_abs_diff_eq!(a.mean().unwrap(), expected_mean, epsilon = 1e-9); + assert_abs_diff_eq!( + a.harmonic_mean().unwrap(), + expected_harmonic_mean, + epsilon = 1e-7 + ); + assert_abs_diff_eq!( + a.geometric_mean().unwrap(), + expected_geometric_mean, + epsilon = 1e-12 + ); + + // Input array used as weights, normalized + let weights = &a / a.sum(); + assert_abs_diff_eq!( + a.weighted_sum(&weights).unwrap(), + expected_weighted_mean, + epsilon = 1e-12 + ); + assert_abs_diff_eq!( + a.weighted_var(&weights, 0.0).unwrap(), + expected_weighted_var, + epsilon = 1e-12 + ); + assert_abs_diff_eq!( + a.weighted_std(&weights, 0.0).unwrap(), + expected_weighted_var.sqrt(), + epsilon = 1e-12 + ); + + let data = a.into_shape_with_order((2, 5, 5)).unwrap(); + let weights = array![0.1, 0.5, 0.25, 0.15, 0.2]; + assert_abs_diff_eq!( + data.weighted_mean_axis(Axis(1), &weights).unwrap(), + array![ + [0.50202721, 0.53347361, 0.29086033, 0.56995637, 0.37087139], + [0.58028328, 0.50485216, 0.59349973, 0.70308937, 0.72280630] + ], + epsilon = 1e-8 + ); + assert_abs_diff_eq!( + data.weighted_mean_axis(Axis(2), &weights).unwrap(), + array![ + [0.33434378, 0.38365259, 0.56405781, 0.48676574, 0.55016179], + [0.71112376, 0.55134174, 0.45566513, 0.74228516, 0.68405851] + ], + epsilon = 1e-8 + ); + assert_abs_diff_eq!( + data.weighted_sum_axis(Axis(1), &weights).unwrap(), + array![ + [0.60243266, 0.64016833, 0.34903240, 0.68394765, 0.44504567], + [0.69633993, 0.60582259, 0.71219968, 0.84370724, 0.86736757] + ], + epsilon = 1e-8 + ); + assert_abs_diff_eq!( + data.weighted_sum_axis(Axis(2), &weights).unwrap(), + array![ + [0.40121254, 0.46038311, 0.67686937, 0.58411889, 0.66019415], + [0.85334851, 0.66161009, 0.54679815, 0.89074219, 0.82087021] + ], + epsilon = 1e-8 + ); +} + +#[test] +fn weighted_sum_dimension_zero() { + let a = Array2::::zeros((0, 20)); + assert_eq!( + a.weighted_sum_axis(Axis(0), &Array1::zeros(0)).unwrap(), + Array1::from_elem(20, 0) + ); + assert_eq!( + a.weighted_sum_axis(Axis(1), &Array1::zeros(20)).unwrap(), + Array1::from_elem(0, 0) + ); + assert_eq!( + a.weighted_sum_axis(Axis(0), &Array1::zeros(1)), + Err(MultiInputError::ShapeMismatch(ShapeMismatch { + first_shape: vec![0, 20], + second_shape: vec![1] + })) + ); + assert_eq!( + a.weighted_sum(&Array2::zeros((10, 20))), + Err(MultiInputError::ShapeMismatch(ShapeMismatch { + first_shape: vec![0, 20], + second_shape: vec![10, 20] + })) + ); +} + +#[test] +fn mean_eq_if_uniform_weights() { + fn prop(a: Vec) -> TestResult { + if a.len() < 1 { + return TestResult::discard(); + } + let a = Array1::from(a); + let weights = Array1::from_elem(a.len(), 1.0 / a.len() as f64); + let m = a.mean().unwrap(); + let wm = a.weighted_mean(&weights).unwrap(); + let ws = a.weighted_sum(&weights).unwrap(); + TestResult::from_bool( + abs_diff_eq!(m, wm, epsilon = 1e-9) && abs_diff_eq!(wm, ws, epsilon = 1e-9), + ) + } + quickcheck(prop as fn(Vec) -> TestResult); +} + +#[test] +fn mean_axis_eq_if_uniform_weights() { + fn prop(mut a: Vec) -> TestResult { + if a.len() < 24 { + return TestResult::discard(); + } + let depth = a.len() / 12; + a.truncate(depth * 3 * 4); + let weights = Array1::from_elem(depth, 1.0 / depth as f64); + let a = Array1::from(a) + .into_shape_with_order((depth, 3, 4)) + .unwrap(); + let ma = a.mean_axis(Axis(0)).unwrap(); + let wm = 
a.weighted_mean_axis(Axis(0), &weights).unwrap(); + let ws = a.weighted_sum_axis(Axis(0), &weights).unwrap(); + TestResult::from_bool( + abs_diff_eq!(ma, wm, epsilon = 1e-12) && abs_diff_eq!(wm, ws, epsilon = 1e12), + ) + } + quickcheck(prop as fn(Vec) -> TestResult); +} + +#[test] +fn weighted_var_eq_var_if_uniform_weight() { + fn prop(a: Vec) -> TestResult { + if a.len() < 1 { + return TestResult::discard(); + } + let a = Array1::from(a); + let weights = Array1::from_elem(a.len(), 1.0 / a.len() as f64); + let weighted_var = a.weighted_var(&weights, 0.0).unwrap(); + let var = a.var_axis(Axis(0), 0.0).into_scalar(); + TestResult::from_bool(abs_diff_eq!(weighted_var, var, epsilon = 1e-10)) + } + quickcheck(prop as fn(Vec) -> TestResult); +} + +#[test] +fn weighted_var_algo_eq_simple_algo() { + fn prop(mut a: Vec) -> TestResult { + if a.len() < 24 { + return TestResult::discard(); + } + let depth = a.len() / 12; + a.truncate(depth * 3 * 4); + let a = Array1::from(a) + .into_shape_with_order((depth, 3, 4)) + .unwrap(); + let mut success = true; + for axis in 0..3 { + let axis = Axis(axis); + + let weights = Array::random(a.len_of(axis), Uniform::new(0.0, 1.0)); + let mean = a + .weighted_mean_axis(axis, &weights) + .unwrap() + .insert_axis(axis); + let res_1_pass = a.weighted_var_axis(axis, &weights, 0.0).unwrap(); + let res_2_pass = (&a - &mean) + .mapv_into(|v| v.powi(2)) + .weighted_mean_axis(axis, &weights) + .unwrap(); + success &= abs_diff_eq!(res_1_pass, res_2_pass, epsilon = 1e-10); + } + TestResult::from_bool(success) + } + quickcheck(prop as fn(Vec) -> TestResult); +} + +#[test] +fn test_central_moment_with_empty_array_of_floats() { + let a: Array1 = array![]; + for order in 0..=3 { + assert_eq!(a.central_moment(order), Err(EmptyInput)); + assert_eq!(a.central_moments(order), Err(EmptyInput)); + } +} + +#[test] +fn test_zeroth_central_moment_is_one() { + let n = 50; + let bound: f64 = 200.; + let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs())); + assert_eq!(a.central_moment(0).unwrap(), 1.); +} + +#[test] +fn test_first_central_moment_is_zero() { + let n = 50; + let bound: f64 = 200.; + let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs())); + assert_eq!(a.central_moment(1).unwrap(), 0.); +} + +#[test] +fn test_central_moments() { + let a: Array1 = array![ + 0.07820559, 0.5026185, 0.80935324, 0.39384033, 0.9483038, 0.62516215, 0.90772261, + 0.87329831, 0.60267392, 0.2960298, 0.02810356, 0.31911966, 0.86705506, 0.96884832, + 0.2222465, 0.42162446, 0.99909868, 0.47619762, 0.91696979, 0.9972741, 0.09891734, + 0.76934818, 0.77566862, 0.7692585, 0.2235759, 0.44821286, 0.79732186, 0.04804275, + 0.87863238, 0.1111003, 0.6653943, 0.44386445, 0.2133176, 0.39397086, 0.4374617, 0.95896624, + 0.57850146, 0.29301706, 0.02329879, 0.2123203, 0.62005503, 0.996492, 0.5342986, 0.97822099, + 0.5028445, 0.6693834, 0.14256682, 0.52724704, 0.73482372, 0.1809703, + ]; + // Computed using scipy.stats.moment + let expected_moments = vec![ + 1., + 0., + 0.09339920262960291, + -0.0026849636727735186, + 0.015403769257729755, + -0.001204176487006564, + 0.002976822584939186, + ]; + for (order, expected_moment) in expected_moments.iter().enumerate() { + assert_abs_diff_eq!( + a.central_moment(order as u16).unwrap(), + expected_moment, + epsilon = 1e-8 + ); + } +} + +#[test] +fn test_bulk_central_moments() { + // Test that the bulk method is coherent with the non-bulk method + let n = 50; + let bound: f64 = 200.; + let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs())); + let 
order = 10; + let central_moments = a.central_moments(order).unwrap(); + for i in 0..=order { + assert_eq!(a.central_moment(i).unwrap(), central_moments[i as usize]); + } +} + +#[test] +fn test_kurtosis_and_skewness_is_none_with_empty_array_of_floats() { + let a: Array1 = array![]; + assert_eq!(a.skewness(), Err(EmptyInput)); + assert_eq!(a.kurtosis(), Err(EmptyInput)); +} + +#[test] +fn test_kurtosis_and_skewness() { + let a: Array1 = array![ + 0.33310096, 0.98757449, 0.9789796, 0.96738114, 0.43545674, 0.06746873, 0.23706562, + 0.04241815, 0.38961714, 0.52421271, 0.93430327, 0.33911604, 0.05112372, 0.5013455, + 0.05291507, 0.62511183, 0.20749633, 0.22132433, 0.14734804, 0.51960608, 0.00449208, + 0.4093339, 0.2237519, 0.28070469, 0.7887231, 0.92224523, 0.43454188, 0.18335111, + 0.08646856, 0.87979847, 0.25483457, 0.99975627, 0.52712442, 0.41163279, 0.85162594, + 0.52618733, 0.75815023, 0.30640695, 0.14205781, 0.59695813, 0.851331, 0.39524328, + 0.73965373, 0.4007615, 0.02133069, 0.92899207, 0.79878191, 0.38947334, 0.22042183, + 0.77768353, + ]; + // Computed using scipy.stats.kurtosis(a, fisher=False) + let expected_kurtosis = 1.821933711687523; + // Computed using scipy.stats.skew + let expected_skewness = 0.2604785422878771; + + let kurtosis = a.kurtosis().unwrap(); + let skewness = a.skewness().unwrap(); + + assert_abs_diff_eq!(kurtosis, expected_kurtosis, epsilon = 1e-12); + assert_abs_diff_eq!(skewness, expected_skewness, epsilon = 1e-8); +}
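+
+// A hand-checkable sanity sketch: with uniform, normalized weights and `ddof = 0`,
+// the weighted variance of [1., 2., 3., 4.] equals the unweighted population
+// variance, (2.25 + 0.25 + 0.25 + 2.25) / 4 = 1.25.
+#[test]
+fn weighted_var_hand_computed_example() {
+    let a = array![1.0_f64, 2.0, 3.0, 4.0];
+    let weights = array![0.25, 0.25, 0.25, 0.25];
+    let weighted_var = a.weighted_var(&weights, 0.0).unwrap();
+    let plain_var = a.var_axis(Axis(0), 0.0).into_scalar();
+    assert_abs_diff_eq!(weighted_var, 1.25, epsilon = 1e-12);
+    assert_abs_diff_eq!(plain_var, 1.25, epsilon = 1e-12);
+    assert_abs_diff_eq!(
+        a.weighted_std(&weights, 0.0).unwrap(),
+        1.25_f64.sqrt(),
+        epsilon = 1e-12
+    );
+}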