Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit d567754

Browse files
authored
BUG: Ensure seed sequences are restored through pickling (#26260)
Explicity store and restore seed sequence closes #26234 --- * BUG: Ensure seed sequences are restored through pickling Explicity store and restore seed sequence closes #26234 * CLN: Simplify refactor Make more use of set and getstate to avoid changes in the pickling functions * BUG: Correct behavior for legacy pickles Add test for legacy pickles Include pickles for tests * MAINT: Correct types for pickle related functions * REF: Switch from string to type * REF: Swtich to returning bit generators Explicitly return bit generator rather than ctor
1 parent c3ce003 commit d567754

16 files changed

+167
-48
lines changed

numpy/random/_generator.pyi

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,12 @@ class Generator:
6868
def __init__(self, bit_generator: BitGenerator) -> None: ...
6969
def __repr__(self) -> str: ...
7070
def __str__(self) -> str: ...
71-
def __getstate__(self) -> dict[str, Any]: ...
72-
def __setstate__(self, state: dict[str, Any]) -> None: ...
73-
def __reduce__(self) -> tuple[Callable[[str], Generator], tuple[str], dict[str, Any]]: ...
71+
def __getstate__(self) -> None: ...
72+
def __setstate__(self, state: dict[str, Any] | None) -> None: ...
73+
def __reduce__(self) -> tuple[
74+
Callable[[BitGenerator], Generator],
75+
tuple[BitGenerator],
76+
None]: ...
7477
@property
7578
def bit_generator(self) -> BitGenerator: ...
7679
def spawn(self, n_children: int) -> list[Generator]: ...

numpy/random/_generator.pyx

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -214,17 +214,19 @@ cdef class Generator:
214214

215215
# Pickling support:
216216
def __getstate__(self):
217-
return self.bit_generator.state
217+
return None
218218

219-
def __setstate__(self, state):
220-
self.bit_generator.state = state
219+
def __setstate__(self, bit_gen):
220+
if isinstance(bit_gen, dict):
221+
# Legacy path
222+
# Prior to 2.0.x only the state of the underlying bit generator
223+
# was preserved and any seed sequence information was lost
224+
self.bit_generator.state = bit_gen
221225

222226
def __reduce__(self):
223-
ctor, name_tpl, state = self._bit_generator.__reduce__()
224-
225227
from ._pickle import __generator_ctor
226-
# Requirements of __generator_ctor are (name, ctor)
227-
return __generator_ctor, (name_tpl[0], ctor), state
228+
# Requirements of __generator_ctor are (bit_generator, )
229+
return __generator_ctor, (self._bit_generator, ), None
228230

229231
@property
230232
def bit_generator(self):

numpy/random/_pickle.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .bit_generator import BitGenerator
12
from .mtrand import RandomState
23
from ._philox import Philox
34
from ._pcg64 import PCG64, PCG64DXSM
@@ -14,27 +15,30 @@
1415
}
1516

1617

17-
def __bit_generator_ctor(bit_generator_name='MT19937'):
18+
def __bit_generator_ctor(bit_generator: str | type[BitGenerator] = 'MT19937'):
1819
"""
1920
Pickling helper function that returns a bit generator object
2021
2122
Parameters
2223
----------
23-
bit_generator_name : str
24-
String containing the name of the BitGenerator
24+
bit_generator : type[BitGenerator] or str
25+
BitGenerator class or string containing the name of the BitGenerator
2526
2627
Returns
2728
-------
28-
bit_generator : BitGenerator
29+
BitGenerator
2930
BitGenerator instance
3031
"""
31-
if bit_generator_name in BitGenerators:
32-
bit_generator = BitGenerators[bit_generator_name]
32+
if isinstance(bit_generator, type):
33+
bit_gen_class = bit_generator
34+
elif bit_generator in BitGenerators:
35+
bit_gen_class = BitGenerators[bit_generator]
3336
else:
34-
raise ValueError(str(bit_generator_name) + ' is not a known '
35-
'BitGenerator module.')
37+
raise ValueError(
38+
str(bit_generator) + ' is not a known BitGenerator module.'
39+
)
3640

37-
return bit_generator()
41+
return bit_gen_class()
3842

3943

4044
def __generator_ctor(bit_generator_name="MT19937",
@@ -44,8 +48,9 @@ def __generator_ctor(bit_generator_name="MT19937",
4448
4549
Parameters
4650
----------
47-
bit_generator_name : str
48-
String containing the core BitGenerator's name
51+
bit_generator_name : str or BitGenerator
52+
String containing the core BitGenerator's name or a
53+
BitGenerator instance
4954
bit_generator_ctor : callable, optional
5055
Callable function that takes bit_generator_name as its only argument
5156
and returns an instantized bit generator.
@@ -55,6 +60,9 @@ def __generator_ctor(bit_generator_name="MT19937",
5560
rg : Generator
5661
Generator using the named core BitGenerator
5762
"""
63+
if isinstance(bit_generator_name, BitGenerator):
64+
return Generator(bit_generator_name)
65+
# Legacy path that uses a bit generator name and ctor
5866
return Generator(bit_generator_ctor(bit_generator_name))
5967

6068

@@ -76,5 +84,6 @@ def __randomstate_ctor(bit_generator_name="MT19937",
7684
rs : RandomState
7785
Legacy RandomState using the named core BitGenerator
7886
"""
79-
87+
if isinstance(bit_generator_name, BitGenerator):
88+
return RandomState(bit_generator_name)
8089
return RandomState(bit_generator_ctor(bit_generator_name))

numpy/random/bit_generator.pyi

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,17 @@ class SeedSequence(ISpawnableSeedSequence):
9292
class BitGenerator(abc.ABC):
9393
lock: Lock
9494
def __init__(self, seed: None | _ArrayLikeInt_co | SeedSequence = ...) -> None: ...
95-
def __getstate__(self) -> dict[str, Any]: ...
96-
def __setstate__(self, state: dict[str, Any]) -> None: ...
95+
def __getstate__(self) -> tuple[dict[str, Any], ISeedSequence]: ...
96+
def __setstate__(
97+
self, state_seed_seq: dict[str, Any] | tuple[dict[str, Any], ISeedSequence]
98+
) -> None: ...
9799
def __reduce__(
98100
self,
99-
) -> tuple[Callable[[str], BitGenerator], tuple[str], tuple[dict[str, Any]]]: ...
101+
) -> tuple[
102+
Callable[[str], BitGenerator],
103+
tuple[str],
104+
tuple[dict[str, Any], ISeedSequence]
105+
]: ...
100106
@abc.abstractmethod
101107
@property
102108
def state(self) -> Mapping[str, Any]: ...

numpy/random/bit_generator.pyx

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -537,14 +537,27 @@ cdef class BitGenerator():
537537

538538
# Pickling support:
539539
def __getstate__(self):
540-
return self.state
540+
return self.state, self._seed_seq
541541

542-
def __setstate__(self, state):
543-
self.state = state
542+
def __setstate__(self, state_seed_seq):
543+
544+
if isinstance(state_seed_seq, dict):
545+
# Legacy path
546+
# Prior to 2.0.x only the state of the underlying bit generator
547+
# was preserved and any seed sequence information was lost
548+
self.state = state_seed_seq
549+
else:
550+
self._seed_seq = state_seed_seq[1]
551+
self.state = state_seed_seq[0]
544552

545553
def __reduce__(self):
546554
from ._pickle import __bit_generator_ctor
547-
return __bit_generator_ctor, (self.state['bit_generator'],), self.state
555+
556+
return (
557+
__bit_generator_ctor,
558+
(type(self), ),
559+
(self.state, self._seed_seq)
560+
)
548561

549562
@property
550563
def state(self):

numpy/random/meson.build

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,9 @@ py.install_sources(
139139
'tests/data/philox-testset-2.csv',
140140
'tests/data/sfc64-testset-1.csv',
141141
'tests/data/sfc64-testset-2.csv',
142+
'tests/data/sfc64_np126.pkl.gz',
143+
'tests/data/generator_pcg64_np126.pkl.gz',
144+
'tests/data/generator_pcg64_np121.pkl.gz',
142145
],
143146
subdir: 'numpy/random/tests/data'
144147
)

numpy/random/mtrand.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ class RandomState:
7373
def __str__(self) -> str: ...
7474
def __getstate__(self) -> dict[str, Any]: ...
7575
def __setstate__(self, state: dict[str, Any]) -> None: ...
76-
def __reduce__(self) -> tuple[Callable[[str], RandomState], tuple[str], dict[str, Any]]: ...
76+
def __reduce__(self) -> tuple[Callable[[BitGenerator], RandomState], tuple[BitGenerator], dict[str, Any]]: ...
7777
def seed(self, seed: None | _ArrayLikeFloat_co = ...) -> None: ...
7878
@overload
7979
def get_state(self, legacy: Literal[False] = ...) -> dict[str, Any]: ...

numpy/random/mtrand.pyx

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -205,10 +205,13 @@ cdef class RandomState:
205205
self.set_state(state)
206206

207207
def __reduce__(self):
208-
ctor, name_tpl, _ = self._bit_generator.__reduce__()
209-
210208
from ._pickle import __randomstate_ctor
211-
return __randomstate_ctor, (name_tpl[0], ctor), self.get_state(legacy=False)
209+
# The third argument containing the state is required here since
210+
# RandomState contains state information in addition to the state
211+
# contained in the bit generator that described the gaussian
212+
# generator. This argument is passed to __setstate__ after the
213+
# Generator is created.
214+
return __randomstate_ctor, (self._bit_generator, ), self.get_state(legacy=False)
212215

213216
cdef _initialize_bit_generator(self, bit_generator):
214217
self._bit_generator = bit_generator
Binary file not shown.
Binary file not shown.
290 Bytes
Binary file not shown.

numpy/random/tests/test_direct.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,24 @@ def test_pickle(self):
298298
aa = pickle.loads(pickle.dumps(ss))
299299
assert_equal(ss.state, aa.state)
300300

301+
def test_pickle_preserves_seed_sequence(self):
302+
# GH 26234
303+
# Add explicit test that bit generators preserve seed sequences
304+
import pickle
305+
306+
bit_generator = self.bit_generator(*self.data1['seed'])
307+
ss = bit_generator.seed_seq
308+
bg_plk = pickle.loads(pickle.dumps(bit_generator))
309+
ss_plk = bg_plk.seed_seq
310+
assert_equal(ss.state, ss_plk.state)
311+
assert_equal(ss.pool, ss_plk.pool)
312+
313+
bit_generator.seed_seq.spawn(10)
314+
bg_plk = pickle.loads(pickle.dumps(bit_generator))
315+
ss_plk = bg_plk.seed_seq
316+
assert_equal(ss.state, ss_plk.state)
317+
assert_equal(ss.n_children_spawned, ss_plk.n_children_spawned)
318+
301319
def test_invalid_state_type(self):
302320
bit_generator = self.bit_generator(*self.data1['seed'])
303321
with pytest.raises(TypeError):
@@ -349,8 +367,9 @@ def test_getstate(self):
349367
bit_generator = self.bit_generator(*self.data1['seed'])
350368
state = bit_generator.state
351369
alt_state = bit_generator.__getstate__()
352-
assert_state_equal(state, alt_state)
353-
370+
assert isinstance(alt_state, tuple)
371+
assert_state_equal(state, alt_state[0])
372+
assert isinstance(alt_state[1], SeedSequence)
354373

355374
class TestPhilox(Base):
356375
@classmethod
@@ -413,6 +432,7 @@ def test_advange_large(self):
413432
assert state["state"] == advanced_state
414433

415434

435+
416436
class TestPCG64DXSM(Base):
417437
@classmethod
418438
def setup_class(cls):
@@ -502,6 +522,29 @@ def setup_class(cls):
502522
cls.invalid_init_types = [(3.2,), ([None],), (1, None)]
503523
cls.invalid_init_values = [(-1,)]
504524

525+
def test_legacy_pickle(self):
526+
# Pickling format was changed in 2.0.x
527+
import gzip
528+
import pickle
529+
530+
expected_state = np.array(
531+
[
532+
9957867060933711493,
533+
532597980065565856,
534+
14769588338631205282,
535+
13
536+
],
537+
dtype=np.uint64
538+
)
539+
540+
base_path = os.path.split(os.path.abspath(__file__))[0]
541+
pkl_file = os.path.join(base_path, "data", f"sfc64_np126.pkl.gz")
542+
with gzip.open(pkl_file) as gz:
543+
sfc = pickle.load(gz)
544+
545+
assert isinstance(sfc, SFC64)
546+
assert_equal(sfc.state["state"]["state"], expected_state)
547+
505548

506549
class TestDefaultRNG:
507550
def test_seed(self):

numpy/random/tests/test_generator_mt19937.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os.path
12
import sys
23
import hashlib
34

@@ -2738,10 +2739,50 @@ def test_generator_ctor_old_style_pickle():
27382739
rg = np.random.Generator(np.random.PCG64DXSM(0))
27392740
rg.standard_normal(1)
27402741
# Directly call reduce which is used in pickling
2741-
ctor, args, state_a = rg.__reduce__()
2742+
ctor, (bit_gen, ), _ = rg.__reduce__()
27422743
# Simulate unpickling an old pickle that only has the name
2743-
assert args[:1] == ("PCG64DXSM",)
2744-
b = ctor(*args[:1])
2745-
b.bit_generator.state = state_a
2744+
assert bit_gen.__class__.__name__ == "PCG64DXSM"
2745+
print(ctor)
2746+
b = ctor(*("PCG64DXSM",))
2747+
print(b)
2748+
b.bit_generator.state = bit_gen.state
27462749
state_b = b.bit_generator.state
2747-
assert state_a == state_b
2750+
assert bit_gen.state == state_b
2751+
2752+
2753+
def test_pickle_preserves_seed_sequence():
2754+
# GH 26234
2755+
# Add explicit test that bit generators preserve seed sequences
2756+
import pickle
2757+
2758+
rg = np.random.Generator(np.random.PCG64DXSM(20240411))
2759+
ss = rg.bit_generator.seed_seq
2760+
rg_plk = pickle.loads(pickle.dumps(rg))
2761+
ss_plk = rg_plk.bit_generator.seed_seq
2762+
assert_equal(ss.state, ss_plk.state)
2763+
assert_equal(ss.pool, ss_plk.pool)
2764+
2765+
rg.bit_generator.seed_seq.spawn(10)
2766+
rg_plk = pickle.loads(pickle.dumps(rg))
2767+
ss_plk = rg_plk.bit_generator.seed_seq
2768+
assert_equal(ss.state, ss_plk.state)
2769+
2770+
2771+
@pytest.mark.parametrize("version", [121, 126])
2772+
def test_legacy_pickle(version):
2773+
# Pickling format was changes in 1.22.x and in 2.0.x
2774+
import pickle
2775+
import gzip
2776+
2777+
base_path = os.path.split(os.path.abspath(__file__))[0]
2778+
pkl_file = os.path.join(
2779+
base_path, "data", f"generator_pcg64_np{version}.pkl.gz"
2780+
)
2781+
with gzip.open(pkl_file) as gz:
2782+
rg = pickle.load(gz)
2783+
state = rg.bit_generator.state['state']
2784+
2785+
assert isinstance(rg, Generator)
2786+
assert isinstance(rg.bit_generator, np.random.PCG64)
2787+
assert state['state'] == 35399562948360463058890781895381311971
2788+
assert state['inc'] == 87136372517582989555478159403783844777

numpy/random/tests/test_randomstate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2052,8 +2052,8 @@ def test_randomstate_ctor_old_style_pickle():
20522052
# Directly call reduce which is used in pickling
20532053
ctor, args, state_a = rs.__reduce__()
20542054
# Simulate unpickling an old pickle that only has the name
2055-
assert args[:1] == ("MT19937",)
2056-
b = ctor(*args[:1])
2055+
assert args[0].__class__.__name__ == "MT19937"
2056+
b = ctor(*("MT19937",))
20572057
b.set_state(state_a)
20582058
state_b = b.get_state(legacy=False)
20592059

numpy/typing/tests/data/pass/random.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -911,9 +911,7 @@
911911

912912
def_gen.__str__()
913913
def_gen.__repr__()
914-
def_gen_state: dict[str, Any]
915-
def_gen_state = def_gen.__getstate__()
916-
def_gen.__setstate__(def_gen_state)
914+
def_gen.__setstate__(dict(def_gen.bit_generator.state))
917915

918916
# RandomState
919917
random_st: np.random.RandomState = np.random.RandomState()

numpy/typing/tests/data/reveal/random.pyi

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -953,9 +953,7 @@ assert_type(def_gen.shuffle(D_2D, axis=1), None)
953953
assert_type(np.random.Generator(pcg64), np.random.Generator)
954954
assert_type(def_gen.__str__(), str)
955955
assert_type(def_gen.__repr__(), str)
956-
def_gen_state = def_gen.__getstate__()
957-
assert_type(def_gen_state, dict[str, Any])
958-
assert_type(def_gen.__setstate__(def_gen_state), None)
956+
assert_type(def_gen.__setstate__(dict(def_gen.bit_generator.state)), None)
959957

960958
# RandomState
961959
random_st: np.random.RandomState = np.random.RandomState()

0 commit comments

Comments
 (0)