diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst index 0cd1d6a89d158..6e3c063a45dcb 100644 --- a/doc/whats_new/v1.0.rst +++ b/doc/whats_new/v1.0.rst @@ -76,6 +76,11 @@ Changelog - For :class:`tree.ExtraTreeRegressor`, `criterion="mse"` is deprecated, use `"squared_error"` instead which is now the default. +:mod:`sklearn.base` +................... + +- |Fix| :func:`config_context` is now threadsafe. :pr:`18736` by `Thomas Fan`_. + :mod:`sklearn.calibration` .......................... diff --git a/sklearn/_config.py b/sklearn/_config.py index feb5e86287c38..e81d50849db05 100644 --- a/sklearn/_config.py +++ b/sklearn/_config.py @@ -2,6 +2,7 @@ """ import os from contextlib import contextmanager as contextmanager +import threading _global_config = { 'assume_finite': bool(os.environ.get('SKLEARN_ASSUME_FINITE', False)), @@ -9,6 +10,15 @@ 'print_changed_only': True, 'display': 'text', } +_threadlocal = threading.local() + + +def _get_threadlocal_config(): + """Get a threadlocal **mutable** configuration. If the configuration + does not exist, copy the default global configuration.""" + if not hasattr(_threadlocal, 'global_config'): + _threadlocal.global_config = _global_config.copy() + return _threadlocal.global_config def get_config(): @@ -24,7 +34,9 @@ def get_config(): config_context : Context manager for global scikit-learn configuration. set_config : Set global scikit-learn configuration. """ - return _global_config.copy() + # Return a copy of the threadlocal configuration so that users will + # not be able to modify the configuration with the returned dict. + return _get_threadlocal_config().copy() def set_config(assume_finite=None, working_memory=None, @@ -72,14 +84,16 @@ def set_config(assume_finite=None, working_memory=None, config_context : Context manager for global scikit-learn configuration. get_config : Retrieve current values of the global configuration. """ + local_config = _get_threadlocal_config() + if assume_finite is not None: - _global_config['assume_finite'] = assume_finite + local_config['assume_finite'] = assume_finite if working_memory is not None: - _global_config['working_memory'] = working_memory + local_config['working_memory'] = working_memory if print_changed_only is not None: - _global_config['print_changed_only'] = print_changed_only + local_config['print_changed_only'] = print_changed_only if display is not None: - _global_config['display'] = display + local_config['display'] = display @contextmanager @@ -120,8 +134,7 @@ def config_context(**new_config): Notes ----- All settings, not just those presently modified, will be returned to - their previous values when the context manager is exited. This is not - thread-safe. + their previous values when the context manager is exited. Examples -------- @@ -141,7 +154,7 @@ def config_context(**new_config): set_config : Set global scikit-learn configuration. get_config : Retrieve current values of the global configuration. """ - old_config = get_config().copy() + old_config = get_config() set_config(**new_config) try: diff --git a/sklearn/tests/test_config.py b/sklearn/tests/test_config.py index 22ec862ef24a3..6d458088a37a8 100644 --- a/sklearn/tests/test_config.py +++ b/sklearn/tests/test_config.py @@ -1,5 +1,13 @@ +import time +from concurrent.futures import ThreadPoolExecutor + +from joblib import Parallel +import joblib import pytest + from sklearn import get_config, set_config, config_context +from sklearn.utils.fixes import delayed +from sklearn.utils.fixes import parse_version def test_config_context(): @@ -76,3 +84,50 @@ def test_set_config(): # No unknown arguments with pytest.raises(TypeError): set_config(do_something_else=True) + + +def set_assume_finite(assume_finite, sleep_duration): + """Return the value of assume_finite after waiting `sleep_duration`.""" + with config_context(assume_finite=assume_finite): + time.sleep(sleep_duration) + return get_config()['assume_finite'] + + +@pytest.mark.parametrize("backend", + ["loky", "multiprocessing", "threading"]) +def test_config_threadsafe_joblib(backend): + """Test that the global config is threadsafe with all joblib backends. + Two jobs are spawned and sets assume_finite to two different values. + When the job with a duration 0.1s completes, the assume_finite value + should be the same as the value passed to the function. In other words, + it is not influenced by the other job setting assume_finite to True. + """ + + if (parse_version(joblib.__version__) < parse_version('0.12') + and backend == 'loky'): + pytest.skip('loky backend does not exist in joblib <0.12') # noqa + + assume_finites = [False, True] + sleep_durations = [0.1, 0.2] + + items = Parallel(backend=backend, n_jobs=2)( + delayed(set_assume_finite)(assume_finite, sleep_dur) + for assume_finite, sleep_dur + in zip(assume_finites, sleep_durations)) + + assert items == [False, True] + + +def test_config_threadsafe(): + """Uses threads directly to test that the global config does not change + between threads. Same test as `test_config_threadsafe_joblib` but with + `ThreadPoolExecutor`.""" + + assume_finites = [False, True] + sleep_durations = [0.1, 0.2] + + with ThreadPoolExecutor(max_workers=2) as e: + items = [output for output in + e.map(set_assume_finite, assume_finites, sleep_durations)] + + assert items == [False, True]