Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 936a08d

Browse files
committed
BUG: Avoid heap buffer overflow for stringdtype searchsorted
1 parent c35a3f4 commit 936a08d

6 files changed

Lines changed: 100 additions & 24 deletions

File tree

numpy/_core/src/common/npy_binsearch.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,21 @@
1010
extern "C" {
1111
#endif
1212

13+
typedef int (PyArray_BinSearchCompareFunc)(const void*, const void*,
14+
PyArrayObject*, PyArrayObject*);
15+
1316
typedef void (PyArray_BinSearchFunc)(const char*, const char*, char*,
1417
npy_intp, npy_intp,
1518
npy_intp, npy_intp, npy_intp,
16-
PyArrayObject*);
19+
PyArrayObject*, PyArrayObject*,
20+
PyArray_BinSearchCompareFunc*);
1721

1822
typedef int (PyArray_ArgBinSearchFunc)(const char*, const char*,
1923
const char*, char*,
2024
npy_intp, npy_intp, npy_intp,
2125
npy_intp, npy_intp, npy_intp,
22-
PyArrayObject*);
26+
PyArrayObject*, PyArrayObject*,
27+
PyArray_BinSearchCompareFunc*);
2328

2429
NPY_NO_EXPORT PyArray_BinSearchFunc* get_binsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side);
2530
NPY_NO_EXPORT PyArray_ArgBinSearchFunc* get_argbinsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side);

numpy/_core/src/multiarray/item_selection.c

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
#include "array_coercion.h"
3636
#include "simd/simd.h"
3737

38+
#include "stringdtype/dtype.h"
39+
3840
static NPY_GCC_OPT_3 inline int
3941
npy_fasttake_impl(
4042
char *dest, char *src, const npy_intp *indices,
@@ -2061,6 +2063,16 @@ PyArray_LexSort(PyObject *sort_keys, int axis)
20612063
}
20622064

20632065

2066+
static int
2067+
binsearch_compare_default(const void *a, const void *b,
2068+
PyArrayObject *arr_a, PyArrayObject *NPY_UNUSED(arr_b))
2069+
{
2070+
PyArray_CompareFunc *compare =
2071+
PyDataType_GetArrFuncs(PyArray_DESCR(arr_a))->compare;
2072+
return compare(a, b, arr_a);
2073+
}
2074+
2075+
20642076
/*NUMPY_API
20652077
*
20662078
* Search the sorted array op1 for the location of the items in op2. The
@@ -2193,35 +2205,58 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2,
21932205
goto fail;
21942206
}
21952207

2208+
/*
2209+
* TODO: add a way to register per-dtype searchsorted loops and move all this
2210+
* stringdtype-specific code into a loop defined for StringDType
2211+
*/
2212+
int error = 0;
2213+
PyArray_Descr *cmp_descrs[2] = {PyArray_DESCR(ap1), PyArray_DESCR(ap2)};
2214+
npy_string_allocator *allocators[2] = {NULL, NULL};
2215+
2216+
PyArray_BinSearchCompareFunc *cmp_func = &binsearch_compare_default;
2217+
if (NPY_DTYPE(PyArray_DESCR(ap2)) == &PyArray_StringDType) {
2218+
cmp_func = &stringdtype_binsearch_compare;
2219+
}
2220+
2221+
NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2));
2222+
2223+
if (NPY_DTYPE(PyArray_DESCR(ap2)) == &PyArray_StringDType) {
2224+
NpyString_acquire_allocators(2, cmp_descrs, allocators);
2225+
}
2226+
21962227
if (ap3 == NULL) {
21972228
/* do regular binsearch */
2198-
NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2));
21992229
binsearch((const char *)PyArray_DATA(ap1),
22002230
(const char *)PyArray_DATA(ap2),
22012231
(char *)PyArray_DATA(ret),
22022232
PyArray_SIZE(ap1), PyArray_SIZE(ap2),
22032233
PyArray_STRIDES(ap1)[0], PyArray_ITEMSIZE(ap2),
2204-
NPY_SIZEOF_INTP, ap2);
2205-
NPY_END_THREADS_DESCR(PyArray_DESCR(ap2));
2234+
NPY_SIZEOF_INTP, ap2, ap1, cmp_func);
22062235
}
22072236
else {
22082237
/* do binsearch with a sorter array */
2209-
int error = 0;
2210-
NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2));
22112238
error = argbinsearch((const char *)PyArray_DATA(ap1),
22122239
(const char *)PyArray_DATA(ap2),
22132240
(const char *)PyArray_DATA(sorter),
22142241
(char *)PyArray_DATA(ret),
22152242
PyArray_SIZE(ap1), PyArray_SIZE(ap2),
22162243
PyArray_STRIDES(ap1)[0],
22172244
PyArray_ITEMSIZE(ap2),
2218-
PyArray_STRIDES(sorter)[0], NPY_SIZEOF_INTP, ap2);
2219-
NPY_END_THREADS_DESCR(PyArray_DESCR(ap2));
2220-
if (error < 0) {
2221-
PyErr_SetString(PyExc_ValueError,
2222-
"Sorter index out of range.");
2223-
goto fail;
2224-
}
2245+
PyArray_STRIDES(sorter)[0], NPY_SIZEOF_INTP, ap2,
2246+
ap1, cmp_func);
2247+
}
2248+
2249+
if (NPY_DTYPE(PyArray_DESCR(ap2)) == &PyArray_StringDType) {
2250+
NpyString_release_allocators(2, allocators);
2251+
}
2252+
2253+
NPY_END_THREADS_DESCR(PyArray_DESCR(ap2));
2254+
2255+
if (error < 0) {
2256+
PyErr_SetString(PyExc_ValueError, "Sorter index out of range.");
2257+
goto fail;
2258+
}
2259+
if (ap3 != NULL) {
22252260
Py_DECREF(ap3);
22262261
Py_DECREF(sorter);
22272262
}

numpy/_core/src/multiarray/stringdtype/dtype.c

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,15 @@ _compare(void *a, void *b, PyArray_StringDTypeObject *descr_a,
512512
return NpyString_cmp(&s_a, &s_b);
513513
}
514514

515+
NPY_NO_EXPORT int
516+
stringdtype_binsearch_compare(const void *a, const void *b,
517+
PyArrayObject *arr_a, PyArrayObject *arr_b)
518+
{
519+
return _compare((void *)a, (void *)b,
520+
(PyArray_StringDTypeObject *)PyArray_DESCR(arr_a),
521+
(PyArray_StringDTypeObject *)PyArray_DESCR(arr_b));
522+
}
523+
515524
int
516525
_sort_compare(const void *a, const void *b, void *context)
517526
{

numpy/_core/src/multiarray/stringdtype/dtype.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ NPY_NO_EXPORT int
2525
_compare(void *a, void *b, PyArray_StringDTypeObject *descr_a,
2626
PyArray_StringDTypeObject *descr_b);
2727

28+
NPY_NO_EXPORT int
29+
stringdtype_binsearch_compare(const void *a, const void *b,
30+
PyArrayObject *arr_a, PyArrayObject *arr_b);
31+
2832
NPY_NO_EXPORT int
2933
init_string_na_object(PyObject *mod);
3034

numpy/_core/src/npysort/binsearch.cpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ template <class Tag, side_t side>
6060
static void
6161
binsearch(const char *arr, const char *key, char *ret, npy_intp arr_len,
6262
npy_intp key_len, npy_intp arr_str, npy_intp key_str,
63-
npy_intp ret_str, PyArrayObject *)
63+
npy_intp ret_str, PyArrayObject *, PyArrayObject *,
64+
PyArray_BinSearchCompareFunc *)
6465
{
6566
using T = typename Tag::type;
6667
auto cmp = side_to_cmp<Tag, side>::value;
@@ -177,7 +178,7 @@ static int
177178
argbinsearch(const char *arr, const char *key, const char *sort, char *ret,
178179
npy_intp arr_len, npy_intp key_len, npy_intp arr_str,
179180
npy_intp key_str, npy_intp sort_str, npy_intp ret_str,
180-
PyArrayObject *)
181+
PyArrayObject *, PyArrayObject *, PyArray_BinSearchCompareFunc *)
181182
{
182183
using T = typename Tag::type;
183184
auto cmp = side_to_cmp<Tag, side>::value;
@@ -244,10 +245,10 @@ template <side_t side>
244245
static void
245246
npy_binsearch(const char *arr, const char *key, char *ret, npy_intp arr_len,
246247
npy_intp key_len, npy_intp arr_str, npy_intp key_str,
247-
npy_intp ret_str, PyArrayObject *cmp)
248+
npy_intp ret_str, PyArrayObject *key_arr, PyArrayObject *arr_arr,
249+
PyArray_BinSearchCompareFunc *compare)
248250
{
249251
using Cmp = typename side_to_generic_cmp<side>::type;
250-
PyArray_CompareFunc *compare = PyDataType_GetArrFuncs(PyArray_DESCR(cmp))->compare;
251252
npy_intp min_idx = 0;
252253
npy_intp max_idx = arr_len;
253254
const char *last_key = key;
@@ -258,7 +259,8 @@ npy_binsearch(const char *arr, const char *key, char *ret, npy_intp arr_len,
258259
* gives the search a big boost when keys are sorted, but slightly
259260
* slows down things for purely random ones.
260261
*/
261-
if (Cmp{}(compare(last_key, key, cmp), 0)) {
262+
/* last_key and key are both elements of the key array */
263+
if (Cmp{}(compare(last_key, key, key_arr, key_arr), 0)) {
262264
max_idx = arr_len;
263265
}
264266
else {
@@ -272,7 +274,8 @@ npy_binsearch(const char *arr, const char *key, char *ret, npy_intp arr_len,
272274
const npy_intp mid_idx = min_idx + ((max_idx - min_idx) >> 1);
273275
const char *arr_ptr = arr + mid_idx * arr_str;
274276

275-
if (Cmp{}(compare(arr_ptr, key, cmp), 0)) {
277+
/* arr_ptr belongs to the haystack, key to the key array */
278+
if (Cmp{}(compare(arr_ptr, key, arr_arr, key_arr), 0)) {
276279
min_idx = mid_idx + 1;
277280
}
278281
else {
@@ -288,10 +291,10 @@ static int
288291
npy_argbinsearch(const char *arr, const char *key, const char *sort, char *ret,
289292
npy_intp arr_len, npy_intp key_len, npy_intp arr_str,
290293
npy_intp key_str, npy_intp sort_str, npy_intp ret_str,
291-
PyArrayObject *cmp)
294+
PyArrayObject *key_arr, PyArrayObject *arr_arr,
295+
PyArray_BinSearchCompareFunc *compare)
292296
{
293297
using Cmp = typename side_to_generic_cmp<side>::type;
294-
PyArray_CompareFunc *compare = PyDataType_GetArrFuncs(PyArray_DESCR(cmp))->compare;
295298
npy_intp min_idx = 0;
296299
npy_intp max_idx = arr_len;
297300
const char *last_key = key;
@@ -302,7 +305,8 @@ npy_argbinsearch(const char *arr, const char *key, const char *sort, char *ret,
302305
* gives the search a big boost when keys are sorted, but slightly
303306
* slows down things for purely random ones.
304307
*/
305-
if (Cmp{}(compare(last_key, key, cmp), 0)) {
308+
/* last_key and key are both elements of the key array */
309+
if (Cmp{}(compare(last_key, key, key_arr, key_arr), 0)) {
306310
max_idx = arr_len;
307311
}
308312
else {
@@ -323,7 +327,8 @@ npy_argbinsearch(const char *arr, const char *key, const char *sort, char *ret,
323327

324328
arr_ptr = arr + sort_idx * arr_str;
325329

326-
if (Cmp{}(compare(arr_ptr, key, cmp), 0)) {
330+
/* arr_ptr belongs to the haystack, key to the key array */
331+
if (Cmp{}(compare(arr_ptr, key, arr_arr, key_arr), 0)) {
327332
min_idx = mid_idx + 1;
328333
}
329334
else {

numpy/_core/tests/test_stringdtype.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,24 @@ def test_sort(strings, arr_sorted):
492492
test_sort(strings, arr_sorted)
493493

494494

495+
def test_searchsorted_gh31533():
496+
n = 100_000
497+
# all > 15 bytes -> arena
498+
values = [f"{i:020d}" for i in range(n)]
499+
haystack = np.array(values, dtype="T")
500+
# a handful of needles -> tiny arena
501+
needle_values = values[:: n // 23]
502+
expected = np.searchsorted(
503+
np.array(values, dtype="U20"), np.array(needle_values, dtype="U20")
504+
)
505+
506+
needles = np.array(needle_values, dtype="T")
507+
assert_array_equal(np.searchsorted(haystack, needles), expected)
508+
assert_array_equal(
509+
np.searchsorted(haystack, needles, sorter=np.arange(n)), expected
510+
)
511+
512+
495513
@pytest.mark.parametrize(
496514
"strings",
497515
[

0 commit comments

Comments
 (0)