From 952fac697b8ff1b32e5a2f11274a24e01273c9dd Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Mon, 18 Dec 2023 16:44:47 -0700 Subject: [PATCH] BUG: avoid seg fault from OOB access in RandomState.set_state() --- numpy/random/mtrand.pyx | 19 ++++++++++--------- numpy/random/tests/test_random.py | 5 +++++ 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/numpy/random/mtrand.pyx b/numpy/random/mtrand.pyx index 53066f6ab98b..0236658c13bf 100644 --- a/numpy/random/mtrand.pyx +++ b/numpy/random/mtrand.pyx @@ -368,15 +368,16 @@ cdef class RandomState: else: if not isinstance(state, (tuple, list)): raise TypeError('state must be a dict or a tuple.') - if state[0] != 'MT19937': - raise ValueError('set_state can only be used with legacy MT19937' - 'state instances.') - st = {'bit_generator': state[0], - 'state': {'key': state[1], 'pos': state[2]}} - if len(state) > 3: - st['has_gauss'] = state[3] - st['gauss'] = state[4] - value = st + with cython.boundscheck(True): + if state[0] != 'MT19937': + raise ValueError('set_state can only be used with legacy ' + 'MT19937 state instances.') + st = {'bit_generator': state[0], + 'state': {'key': state[1], 'pos': state[2]}} + if len(state) > 3: + st['has_gauss'] = state[3] + st['gauss'] = state[4] + value = st self._aug_state.gauss = st.get('gauss', 0.0) self._aug_state.has_gauss = st.get('has_gauss', 0) diff --git a/numpy/random/tests/test_random.py b/numpy/random/tests/test_random.py index 65cf7431ed23..c98584aeda9d 100644 --- a/numpy/random/tests/test_random.py +++ b/numpy/random/tests/test_random.py @@ -147,6 +147,11 @@ def test_negative_binomial(self): # arguments without truncation. self.prng.negative_binomial(0.5, 0.5) + def test_set_invalid_state(self): + # gh-25402 + with pytest.raises(IndexError): + self.prng.set_state(()) + class TestRandint: