
Conversation

@amcandio

Hi! I noticed that np.searchsorted can be optimized.

This PR optimizes the C++ binsearch implementation used by np.searchsorted for u64 and adds benchmarks (I haven't reworked the other binsearch implementations yet).

The main idea is to express the binary search in terms of a range [base, base+length] whose length halves on each iteration. With this formulation, each key only needs a single base value as intermediate state. The PR stores that state in the ret array, so no extra memory is needed (base becomes the result after the last iteration).
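For reference, here is a minimal single-key sketch of that base/length formulation in Python (illustrative only; the PR's actual change lives in NumPy's C++ binsearch code, and the function name below is made up):

def bisect_left_base_length(arr, key):
    # The search range is [base, base + length); only `base` is per-key state.
    base = 0
    length = len(arr)
    while length > 1:
        half = length >> 1
        if arr[base + half] < key:      # use <= for side='right'
            base += half                # keep the upper half of the range
        length -= half                  # length shrinks by the same amount either way
    # One final comparison resolves the last remaining element.
    if length > 0 and arr[base] < key:
        base += 1
    return base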

The sequence of lengths depends only on the initial length of the array, which lets us batch each intermediate step of the algorithm across all keys: every key is processed against the same sequence of lengths.

This ends up being faster for two reasons:

  1. Cache locality of pivots. In early iterations every key is compared against the same small set of pivots. For example, in the first iteration all keys are compared against the median; in the second iteration, all keys are compared against the first or third quartile.
  2. Independent calculations for out-of-order execution. In the single-key version, step i+1 depends on the result of step i, so it must wait for step i to complete. When batching multiple keys, we compute each step for all keys before moving on to the next step; the computations at a given step are independent across keys, so the CPU can execute them out of order in parallel (see the sketch after this list).
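Here is a rough sketch of that loop interchange with the scalar loops written out (plain Python for side='left' only; not the PR's literal C++ structure, and the function name is made up):

def searchsorted_batched(arr, keys):
    base = [0] * len(keys)              # doubles as the result array
    length = len(arr)
    while length > 1:
        half = length >> 1
        # One binary-search step for *all* keys; the comparisons below are
        # independent of each other, so the CPU can overlap them.
        for i, key in enumerate(keys):
            if arr[base[i] + half] < key:
                base[i] += half
        length -= half                  # same length sequence for every key
    for i, key in enumerate(keys):
        if arr[base[i]] < key:
            base[i] += 1
    return base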

A numpy vectorized version

We could also implement this algorithm in Python by just relying on numpy array broadcasting:

import numpy as np

def searchsorted_np(arr, keys, side="left"):
    N = arr.shape[0]
    K = keys.shape[0]

    result = np.zeros(K, dtype=np.intp)    # per-key "base"; becomes the answer
    length = N

    if side == "left":
        cmp = np.less
    elif side == "right":
        cmp = np.less_equal
    else:
        raise ValueError("side must be 'left' or 'right'")

    # Each pass performs one binary-search step for all keys at once.
    while length > 1:
        half = length >> 1
        mid = result + half
        mid_val = arr[mid]                 # gather this step's pivots for every key
        move = cmp(mid_val, keys)          # which keys move to the upper half
        result += move.astype(np.intp) * half
        length -= half

    # Final comparison resolves the last remaining element.
    mid_val = arr[result]
    move = cmp(mid_val, keys)
    result += move.astype(np.intp)

    return result

A quick benchmark shows this approach to be significantly faster than the current implementation when querying multiple keys:

import numpy as np
import timeit

np.random.seed(42)

N = 5_000_000
K = 100_000

arr = np.sort(np.random.randint(0, 10_000_000, size=N))
keys = np.random.randint(0, 10_000_000, size=K)

setup = """
import numpy as np
from __main__ import arr, keys, searchsorted_np
"""
number = 10

stmt_np = "np.searchsorted(arr, keys, side='left')"
stmt_my = "searchsorted_np(arr, keys, side='left')"

time_np = timeit.timeit(stmt=stmt_np, setup=setup, number=number)
time_my = timeit.timeit(stmt=stmt_my, setup=setup, number=number)

print((np.searchsorted(arr, keys, side='left') == searchsorted_np(arr, keys, side='left')).all())
print(f"np.searchsorted: {time_np/number:.4f}s per run")
print(f"searchsorted_np: {time_my/number:.4f}s per run")
Output:

True
np.searchsorted: 0.1160s per run
searchsorted_np: 0.0079s per run

That is roughly 15x faster (numpy beats numpy!), although it does not beat the C++ implementation from the PR:

True
np.searchsorted: 0.0052s per run
searchsorted_np: 0.0081s per run

Drawbacks

The main drawback of this approach is that we shrink the range with length = ceil(length / 2) on each iteration. The PR algorithm therefore always does exactly ceil(log2(length)) iterations, whereas the current implementation may only need floor(log2(length)) for some keys. This extra overhead most likely explains why the PR's implementation is ~40ns slower in the single-key, large-array benchmarks.
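To make the iteration-count argument concrete, here is a small sketch that counts comparisons for both strategies (the "current implementation" is approximated here by a textbook lo/hi binary search, which may not match the real C++ code exactly):

def iters_base_length(n):
    # Batched formulation: the iteration count depends only on n, never on the key.
    length, it = n, 0
    while length > 1:
        length -= length >> 1           # i.e. length = ceil(length / 2)
        it += 1
    return it + 1                       # +1 for the final comparison

def iters_classic(arr, key):
    # Textbook lo/hi binary search: the iteration count depends on the key's path.
    lo, hi, it = 0, len(arr), 0
    while lo < hi:
        it += 1
        mid = (lo + hi) >> 1
        if arr[mid] < key:
            lo = mid + 1
        else:
            hi = mid
    return it

n = 1_000_000
arr = list(range(n))
print(iters_base_length(n))                            # 21, regardless of the key
print(iters_classic(arr, -1), iters_classic(arr, n))   # 20 and 19: varies with the key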

Results

| Change   | Before [d9e70f6e] <main>   | After [d1efa8aa] <batched-binary-search>   |   Ratio | Benchmark (Parameter)                                                                           |
|----------|----------------------------|--------------------------------------------|---------|-------------------------------------------------------------------------------------------------|
| +        | 903±4ns                    | 949±10ns                                   |    1.05 | bench_searchsorted.SearchSortedInt64.time_searchsorted(1000000000, 1, 'random', 18122022)       |
| +        | 905±6ns                    | 947±10ns                                   |    1.05 | bench_searchsorted.SearchSortedInt64.time_searchsorted(1000000000, 1, 'random', 42)             |
| +        | 899±3ns                    | 936±4ns                                    |    1.04 | bench_searchsorted.SearchSortedInt64.time_searchsorted(1000000000, 1, 'ordered', 42)            |
| -        | 1.06±0.01μs                | 1.03±0.01μs                                |    0.97 | bench_searchsorted.SearchSortedInt64.time_searchsorted(1000000000, 2, 'ordered', 18122022)      |
| -        | 1.04±0.01μs                | 1.01±0.01μs                                |    0.97 | bench_searchsorted.SearchSortedInt64.time_searchsorted(1000000000, 2, 'random', 18122022)       |
| -        | 3.02±0.1μs                 | 1.25±0.02μs                                |    0.41 | bench_searchsorted.SearchSortedInt64.time_searchsorted(100, 100, 'ordered', 18122022)           |
| -        | 3.07±0.09μs                | 1.27±0.01μs                                |    0.41 | bench_searchsorted.SearchSortedInt64.time_searchsorted(100, 100, 'ordered', 42)                 |
| -        | 3.41±0.1μs                 | 1.26±0.01μs                                |    0.37 | bench_searchsorted.SearchSortedInt64.time_searchsorted(100, 100, 'random', 18122022)            |
| -        | 3.46±0.1μs                 | 1.26±0.01μs                                |    0.36 | bench_searchsorted.SearchSortedInt64.time_searchsorted(100, 100, 'random', 42)                  |
| -        | 5.70±0.1μs                 | 1.66±0.03μs                                |    0.29 | bench_searchsorted.SearchSortedInt64.time_searchsorted(10000, 100, 'ordered', 18122022)         |
| -        | 5.78±0.2μs                 | 1.66±0.01μs                                |    0.29 | bench_searchsorted.SearchSortedInt64.time_searchsorted(10000, 100, 'ordered', 42)               |
| -        | 439±20ms                   | 122±5ms                                    |    0.28 | bench_searchsorted.SearchSortedInt64.time_searchsorted(1000000000, 100000, 'ordered', 18122022) |
| -        | 440±10ms                   | 120±2ms                                    |    0.27 | bench_searchsorted.SearchSortedInt64.time_searchsorted(1000000000, 100000, 'ordered', 42)       |
| -        | 6.42±0.2μs                 | 1.69±0.07μs                                |    0.26 | bench_searchsorted.SearchSortedInt64.time_searchsorted(10000, 100, 'random', 18122022)          |
| -        | 6.44±0.1μs                 | 1.68±0.05μs                                |    0.26 | bench_searchsorted.SearchSortedInt64.time_searchsorted(10000, 100, 'random', 42)                |
| -        | 565±40ms                   | 146±30ms                                   |    0.26 | bench_searchsorted.SearchSortedInt64.time_searchsorted(1000000000, 100000, 'random', 42)        |
| -        | 8.04±0.4ms                 | 1.99±0.04ms                                |    0.25 | bench_searchsorted.SearchSortedInt64.time_searchsorted(1000000, 100000, 'ordered', 18122022)    |
| -        | 7.98±0.3ms                 | 1.97±0.01ms                                |    0.25 | bench_searchsorted.SearchSortedInt64.time_searchsorted(1000000, 100000, 'ordered', 42)          |
| -        | 525±30ms                   | 127±20ms                                   |    0.24 | bench_searchsorted.SearchSortedInt64.time_searchsorted(1000000000, 100000, 'random', 18122022)  |
| -        | 2.21±0.09ms                | 450±8μs                                    |    0.2  | bench_searchsorted.SearchSortedInt64.time_searchsorted(100, 100000, 'ordered', 18122022)        |
| -        | 2.29±0.09ms                | 453±2μs                                    |    0.2  | bench_searchsorted.SearchSortedInt64.time_searchsorted(100, 100000, 'ordered', 42)              |
| -        | 11.2±0.6μs                 | 2.16±0.03μs                                |    0.19 | bench_searchsorted.SearchSortedInt64.time_searchsorted(1000000, 100, 'ordered', 18122022)       |
| -        | 4.92±0.08ms                | 881±3μs                                    |    0.18 | bench_searchsorted.SearchSortedInt64.time_searchsorted(10000, 100000, 'ordered', 18122022)      |
| -        | 4.99±0.1ms                 | 883±9μs                                    |    0.18 | bench_searchsorted.SearchSortedInt64.time_searchsorted(10000, 100000, 'ordered', 42)            |
| -        | 12.3±0.5μs                 | 2.14±0μs                                   |    0.17 | bench_searchsorted.SearchSortedInt64.time_searchsorted(1000000, 100, 'ordered', 42)             |
| -        | 12.6±0.8μs                 | 2.09±0.01μs                                |    0.17 | bench_searchsorted.SearchSortedInt64.time_searchsorted(1000000, 100, 'random', 18122022)        |
| -        | 2.91±0.1ms                 | 459±10μs                                   |    0.16 | bench_searchsorted.SearchSortedInt64.time_searchsorted(100, 100000, 'random', 18122022)         |
| -        | 2.98±0.02ms                | 464±10μs                                   |    0.16 | bench_searchsorted.SearchSortedInt64.time_searchsorted(100, 100000, 'random', 42)               |
| -        | 5.74±0.1ms                 | 892±2μs                                    |    0.16 | bench_searchsorted.SearchSortedInt64.time_searchsorted(10000, 100000, 'random', 18122022)       |
| -        | 5.81±0.2ms                 | 895±5μs                                    |    0.15 | bench_searchsorted.SearchSortedInt64.time_searchsorted(10000, 100000, 'random', 42)             |
| -        | 14.8±0.8μs                 | 2.12±0.02μs                                |    0.14 | bench_searchsorted.SearchSortedInt64.time_searchsorted(1000000, 100, 'random', 42)              |
| -        | 24.4±0.2μs                 | 3.28±0.09μs                                |    0.13 | bench_searchsorted.SearchSortedInt64.time_searchsorted(1000000000, 100, 'ordered', 18122022)    |
| -        | 27.8±0.2μs                 | 3.31±0.04μs                                |    0.12 | bench_searchsorted.SearchSortedInt64.time_searchsorted(1000000000, 100, 'ordered', 42)          |
| -        | 29.3±0.2μs                 | 3.21±0.02μs                                |    0.11 | bench_searchsorted.SearchSortedInt64.time_searchsorted(1000000000, 100, 'random', 18122022)     |
| -        | 31.9±0.6μs                 | 3.28±0.04μs                                |    0.1  | bench_searchsorted.SearchSortedInt64.time_searchsorted(1000000000, 100, 'random', 42)           |
| -        | 28.6±10ms                  | 2.05±0.05ms                                |    0.07 | bench_searchsorted.SearchSortedInt64.time_searchsorted(1000000, 100000, 'random', 18122022)     |
| -        | 32.1±20ms                  | 1.97±0.02ms                                |    0.06 | bench_searchsorted.SearchSortedInt64.time_searchsorted(1000000, 100000, 'random', 42)           |

@eendebakpt (Contributor) left a comment

Nice optimization! I have a few minor comments, but this looks good overall.

@amcandio (Author)

> Nice optimization! I have a few minor comments, but this looks good overall.

Thanks for the comments! I addressed them in the last commit.
