Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 0ab3699

Browse files
seberg and ogrisel authored
MAINT: Adapt sklearn for NumPy default integer change (#27041)
Co-authored-by: Olivier Grisel <[email protected]>
1 parent c36ab3a commit 0ab3699

File tree

4 files changed

+57
-26
lines changed

4 files changed

+57
-26
lines changed

sklearn/cluster/_hdbscan/hdbscan.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def _brute_mst(mutual_reachability, min_samples):
124124
# Compute the minimum spanning tree for the sparse graph
125125
sparse_min_spanning_tree = csgraph.minimum_spanning_tree(mutual_reachability)
126126
rows, cols = sparse_min_spanning_tree.nonzero()
127-
mst = np.core.records.fromarrays(
127+
mst = np.rec.fromarrays(
128128
[rows, cols, sparse_min_spanning_tree.data],
129129
dtype=MST_edge_dtype,
130130
)

sklearn/utils/_random.pxd

-4
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,6 @@ cdef enum:
1616
# 32-bit signed integers (i.e. 2^31 - 1).
1717
RAND_R_MAX = 2147483647
1818

19-
cpdef sample_without_replacement(cnp.int_t n_population,
20-
cnp.int_t n_samples,
21-
method=*,
22-
random_state=*)
2319

2420
# rand_r replacement using a 32bit XorShift generator
2521
# See http://www.jstatsoft.org/v08/i14/paper for details

sklearn/utils/_random.pyx

+55-20
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,17 @@ from . import check_random_state
1919
cdef UINT32_t DEFAULT_SEED = 1
2020

2121

22-
cpdef _sample_without_replacement_check_input(cnp.int_t n_population,
23-
cnp.int_t n_samples):
22+
# Compatibility type to always accept the default int type used by NumPy, both
23+
# before and after NumPy 2. On Windows, `long` does not always match `cnp.intp_t`.
24+
# See the comments in the `sample_without_replacement` Python function for more
25+
# details.
26+
ctypedef fused default_int:
27+
cnp.intp_t
28+
long
29+
30+
31+
cpdef _sample_without_replacement_check_input(default_int n_population,
32+
default_int n_samples):
2433
""" Check that input are consistent for sample_without_replacement"""
2534
if n_population < 0:
2635
raise ValueError('n_population should be greater than 0, got %s.'
@@ -33,8 +42,8 @@ cpdef _sample_without_replacement_check_input(cnp.int_t n_population,
3342

3443

3544
cpdef _sample_without_replacement_with_tracking_selection(
36-
cnp.int_t n_population,
37-
cnp.int_t n_samples,
45+
default_int n_population,
46+
default_int n_samples,
3847
random_state=None):
3948
r"""Sample integers without replacement.
4049
@@ -76,9 +85,9 @@ cpdef _sample_without_replacement_with_tracking_selection(
7685
"""
7786
_sample_without_replacement_check_input(n_population, n_samples)
7887
79-
cdef cnp.int_t i
80-
cdef cnp.int_t j
81-
cdef cnp.int_t[::1] out = np.empty((n_samples, ), dtype=int)
88+
cdef default_int i
89+
cdef default_int j
90+
cdef default_int[::1] out = np.empty((n_samples, ), dtype=int)
8291
8392
rng = check_random_state(random_state)
8493
rng_randint = rng.randint
@@ -97,8 +106,8 @@ cpdef _sample_without_replacement_with_tracking_selection(
97106
return np.asarray(out)
98107
99108
100-
cpdef _sample_without_replacement_with_pool(cnp.int_t n_population,
101-
cnp.int_t n_samples,
109+
cpdef _sample_without_replacement_with_pool(default_int n_population,
110+
default_int n_samples,
102111
random_state=None):
103112
"""Sample integers without replacement.
104113

@@ -131,10 +140,10 @@ cpdef _sample_without_replacement_with_pool(cnp.int_t n_population,
131140
"""
132141
_sample_without_replacement_check_input(n_population, n_samples)
133142
134-
cdef cnp.int_t i
135-
cdef cnp.int_t j
136-
cdef cnp.int_t[::1] out = np.empty((n_samples,), dtype=int)
137-
cdef cnp.int_t[::1] pool = np.empty((n_population,), dtype=int)
143+
cdef default_int i
144+
cdef default_int j
145+
cdef default_int[::1] out = np.empty((n_samples,), dtype=int)
146+
cdef default_int[::1] pool = np.empty((n_population,), dtype=int)
138147
139148
rng = check_random_state(random_state)
140149
rng_randint = rng.randint
@@ -154,8 +163,8 @@ cpdef _sample_without_replacement_with_pool(cnp.int_t n_population,
154163
155164
156165
cpdef _sample_without_replacement_with_reservoir_sampling(
157-
cnp.int_t n_population,
158-
cnp.int_t n_samples,
166+
default_int n_population,
167+
default_int n_samples,
159168
random_state=None
160169
):
161170
"""Sample integers without replacement.
@@ -191,9 +200,9 @@ cpdef _sample_without_replacement_with_reservoir_sampling(
191200
"""
192201
_sample_without_replacement_check_input(n_population, n_samples)
193202
194-
cdef cnp.int_t i
195-
cdef cnp.int_t j
196-
cdef cnp.int_t[::1] out = np.empty((n_samples, ), dtype=int)
203+
cdef default_int i
204+
cdef default_int j
205+
cdef default_int[::1] out = np.empty((n_samples, ), dtype=int)
197206
198207
rng = check_random_state(random_state)
199208
rng_randint = rng.randint
@@ -213,8 +222,8 @@ cpdef _sample_without_replacement_with_reservoir_sampling(
213222
return np.asarray(out)
214223
215224
216-
cpdef sample_without_replacement(cnp.int_t n_population,
217-
cnp.int_t n_samples,
225+
cdef _sample_without_replacement(default_int n_population,
226+
default_int n_samples,
218227
method="auto",
219228
random_state=None):
220229
"""Sample integers without replacement.
@@ -303,6 +312,32 @@ cpdef sample_without_replacement(cnp.int_t n_population,
303312
% (all_methods, method))
304313
305314
315+
def sample_without_replacement(
316+
object n_population, object n_samples, method="auto", random_state=None):
317+
cdef:
318+
cnp.intp_t n_pop_intp, n_samples_intp
319+
long n_pop_long, n_samples_long
320+
321+
# On most platforms `np.int_ is np.intp`. However, before NumPy 2 the
322+
# default integer `np.int_` was a long which is 32bit on 64bit windows
323+
# while `intp` is 64bit on 64bit platforms and 32bit on 32bit ones.
324+
if np.int_ is np.intp:
325+
# Branch always taken on NumPy >=2 (or when not on 64bit windows).
326+
# Cython has different rules for conversion of values to integers.
327+
# For NumPy <1.26.2 AND Cython 3, this first branch requires `int()`
328+
# called explicitly to allow e.g. floats.
329+
n_pop_intp = int(n_population)
330+
n_samples_intp = int(n_samples)
331+
return _sample_without_replacement(
332+
n_pop_intp, n_samples_intp, method, random_state)
333+
else:
334+
# Branch taken on 64bit windows with NumPy<2.0 where `long` is 32bit
335+
n_pop_long = n_population
336+
n_samples_long = n_samples
337+
return _sample_without_replacement(
338+
n_pop_long, n_samples_long, method, random_state)
339+
340+
306341
def _our_rand_r_py(seed):
307342
"""Python utils to test the our_rand_r function"""
308343
cdef UINT32_t my_seed = seed

sklearn/utils/estimator_checks.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3470,7 +3470,7 @@ def param_filter(p):
34703470
type,
34713471
}
34723472
# Any numpy numeric such as np.int32.
3473-
allowed_types.update(np.core.numerictypes.allTypes.values())
3473+
allowed_types.update(np.sctypeDict.values())
34743474

34753475
allowed_value = (
34763476
type(init_param.default) in allowed_types

0 commit comments

Comments
 (0)