diff --git a/numpy/__init__.py b/numpy/__init__.py index 0fcd5097d2d1..076f61065ff8 100644 --- a/numpy/__init__.py +++ b/numpy/__init__.py @@ -107,8 +107,21 @@ from __future__ import division, absolute_import, print_function import sys +import warnings +# Disallow reloading numpy. Doing that does nothing to change previously +# loaded modules, which would need to be reloaded separately, but it does +# change the identity of the warnings and sentinal classes defined below +# with dire consequences when checking for identity. +if '_is_loaded' in globals(): + raise RuntimeError('Reloading numpy is not supported') +_is_loaded = True + +# Define some global warnings and the _NoValue sentinal. Defining them here +# means that their identity will change if numpy is reloaded, hence if that is +# to be allowed they should be moved into their own, non-reloadable module. +# Note that these should be defined (or imported) before the other imports. class ModuleDeprecationWarning(DeprecationWarning): """Module deprecation warning. @@ -135,9 +148,8 @@ class VisibleDeprecationWarning(UserWarning): class _NoValue: """Special keyword value. - This class may be used as the default value assigned to a - deprecated keyword in order to check if it has been given a user - defined value. + This class may be used as the default value assigned to a deprecated + keyword in order to check if it has been given a user defined value. """ pass @@ -155,11 +167,8 @@ class _NoValue: except NameError: __NUMPY_SETUP__ = False - if __NUMPY_SETUP__: - import sys as _sys - _sys.stderr.write('Running from numpy source directory.\n') - del _sys + sys.stderr.write('Running from numpy source directory.\n') else: try: from numpy.__config__ import show as show_config @@ -206,7 +215,7 @@ def pkgload(*packages, **options): from .compat import long # Make these accessible from numpy name-space - # but not imported in from numpy import * + # but not imported in from numpy import * if sys.version_info[0] >= 3: from builtins import bool, int, float, complex, object, str unicode = str @@ -222,8 +231,8 @@ def pkgload(*packages, **options): __all__.extend(lib.__all__) __all__.extend(['linalg', 'fft', 'random', 'ctypeslib', 'ma']) + # Filter annoying Cython warnings that serve no good purpose. - import warnings warnings.filterwarnings("ignore", message="numpy.dtype size changed") warnings.filterwarnings("ignore", message="numpy.ufunc size changed") warnings.filterwarnings("ignore", message="numpy.ndarray size changed") diff --git a/numpy/tests/test_reloading.py b/numpy/tests/test_reloading.py new file mode 100644 index 000000000000..141e11f6cfb0 --- /dev/null +++ b/numpy/tests/test_reloading.py @@ -0,0 +1,26 @@ +from __future__ import division, absolute_import, print_function + +import sys + +import numpy as np +from numpy.testing import assert_raises, assert_, run_module_suite + +if sys.version_info[:2] >= (3, 4): + from importlib import reload +else: + from imp import reload + +def test_reloading_exception(): + # gh-7844. Also check that relevant globals retain their identity. + _NoValue = np._NoValue + VisibleDeprecationWarning = np.VisibleDeprecationWarning + ModuleDeprecationWarning = np.ModuleDeprecationWarning + + assert_raises(RuntimeError, reload, np) + assert_(_NoValue is np._NoValue) + assert_(ModuleDeprecationWarning is np.ModuleDeprecationWarning) + assert_(VisibleDeprecationWarning is np.VisibleDeprecationWarning) + + +if __name__ == "__main__": + run_module_suite()