
Merged
Changes from all commits · 22 commits
011c8d2
TST Adds failing test
thomasjpfan Nov 2, 2020
59b4a8e
ENH Makes copies of global configuration per thread
thomasjpfan Nov 2, 2020
f6a1357
TST Skip loky tests for joblib version < 0.12
thomasjpfan Nov 2, 2020
bddc33f
Merge remote-tracking branch 'upstream/master' into thread_safe_config
thomasjpfan Nov 13, 2020
e7de9a0
Merge remote-tracking branch 'upstream/master' into thread_safe_config
thomasjpfan Nov 18, 2020
02e738f
Merge remote-tracking branch 'upstream/master' into thread_safe_config
thomasjpfan Nov 25, 2020
99a4ce1
ENH Small refactor
thomasjpfan Nov 26, 2020
966498d
DOC Update docstring of test
thomasjpfan Nov 26, 2020
df291d9
DOC Update names
thomasjpfan Nov 26, 2020
bba6465
Merge remote-tracking branch 'upstream/main' into thread_safe_config
thomasjpfan Apr 9, 2021
ba93b54
DOC Updates changelog and docstring about threadsafeness
thomasjpfan Apr 9, 2021
ea516c4
DOC Update test's docstring
thomasjpfan Apr 9, 2021
6fe0619
ENH Cleaner code
thomasjpfan Apr 9, 2021
91d18c1
DOC Better docstring
thomasjpfan Apr 9, 2021
8e733cb
Merge remote-tracking branch 'upstream/main' into thread_safe_config
thomasjpfan Apr 9, 2021
276c720
WIP Fixes merge error
thomasjpfan Apr 10, 2021
1b28965
CLN Less lines of code
thomasjpfan Apr 10, 2021
319f007
CLN Better copy logic
thomasjpfan Apr 10, 2021
450a456
CLN Do not need copy in config_context
thomasjpfan Apr 10, 2021
957b3c0
DOC Adds comment about mutable
thomasjpfan Apr 27, 2021
4b89c6f
Merge remote-tracking branch 'upstream/main' into thread_safe_config
thomasjpfan Apr 27, 2021
8ef81f7
DOC Adds comment for copy
thomasjpfan Apr 27, 2021
5 changes: 5 additions & 0 deletions doc/whats_new/v1.0.rst
@@ -76,6 +76,11 @@ Changelog
- For :class:`tree.ExtraTreeRegressor`, `criterion="mse"` is deprecated,
use `"squared_error"` instead which is now the default.

+:mod:`sklearn.base`
+...................

+- |Fix| :func:`config_context` is now threadsafe. :pr:`18736` by `Thomas Fan`_.

:mod:`sklearn.calibration`
..........................

29 changes: 21 additions & 8 deletions sklearn/_config.py
@@ -2,13 +2,23 @@
"""
import os
from contextlib import contextmanager as contextmanager
+import threading

_global_config = {
'assume_finite': bool(os.environ.get('SKLEARN_ASSUME_FINITE', False)),
'working_memory': int(os.environ.get('SKLEARN_WORKING_MEMORY', 1024)),
'print_changed_only': True,
'display': 'text',
}
+_threadlocal = threading.local()


+def _get_threadlocal_config():
+    """Get a threadlocal **mutable** configuration. If the configuration
+    does not exist, copy the default global configuration."""
+    if not hasattr(_threadlocal, 'global_config'):
+        _threadlocal.global_config = _global_config.copy()
+    return _threadlocal.global_config
Comment on lines +19 to +21
@jeremiedbb (Member) · Apr 27, 2021:
Just a nitpick: why not put that directly in get_config?
Also, I think the copy is unnecessary here, since there will be a copy afterwards anyway.

@thomasjpfan (Member, Author):

I thought of _get_threadlocal_config as returning a threadlocal mutable version of the config, while get_config returns an immutable copy of the config. This is a safeguard against someone doing the following:

config = get_config()

# Does not change the configuration
config["assume_finite"] = True

which is consistent with the current behavior on main.

I added comments describing this reasoning to the PR.

@jeremiedbb (Member):

But do we intend to use _get_threadlocal_config outside of get_config? If not, why don't we do:

def get_config():
    if not hasattr(_threadlocal, 'global_config'):
        _threadlocal.global_config = _global_config
    return _threadlocal.global_config.copy()

@thomasjpfan (Member, Author):

_get_threadlocal_config is used in set_config to update the threadlocal config.

In any case, for set_config to not mutate _global_config, _threadlocal.global_config needs to be initialized with a copy of _global_config.

@jeremiedbb (Member):

> _get_threadlocal_config is used in set_config

Yep missed that. Sorry. Forget my comments :)
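
To make the point being settled here concrete: below is a minimal standalone sketch (not part of the PR; it uses a pared-down one-key config) of why the copy at initialization matters. Each thread lazily copies the module-level defaults, so a set_config in one thread can never reach _global_config or another thread's view.

import threading

_global_config = {'assume_finite': False}
_threadlocal = threading.local()


def _get_threadlocal_config():
    # Each thread lazily receives its own copy of the defaults. Without
    # the .copy(), every thread would share (and mutate) the single
    # module-level dict, defeating the per-thread isolation.
    if not hasattr(_threadlocal, 'global_config'):
        _threadlocal.global_config = _global_config.copy()
    return _threadlocal.global_config


def set_config(assume_finite=None):
    if assume_finite is not None:
        _get_threadlocal_config()['assume_finite'] = assume_finite


t = threading.Thread(target=set_config, kwargs={'assume_finite': True})
t.start()
t.join()

assert _global_config['assume_finite'] is False             # defaults untouched
assert _get_threadlocal_config()['assume_finite'] is False  # main thread untouched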



def get_config():
Expand All @@ -24,7 +34,9 @@ def get_config():
config_context : Context manager for global scikit-learn configuration.
set_config : Set global scikit-learn configuration.
"""
-    return _global_config.copy()
+    # Return a copy of the threadlocal configuration so that users will
+    # not be able to modify the configuration with the returned dict.
+    return _get_threadlocal_config().copy()


def set_config(assume_finite=None, working_memory=None,
@@ -72,14 +84,16 @@ def set_config(assume_finite=None, working_memory=None,
config_context : Context manager for global scikit-learn configuration.
get_config : Retrieve current values of the global configuration.
"""
+    local_config = _get_threadlocal_config()

    if assume_finite is not None:
-        _global_config['assume_finite'] = assume_finite
+        local_config['assume_finite'] = assume_finite
    if working_memory is not None:
-        _global_config['working_memory'] = working_memory
+        local_config['working_memory'] = working_memory
    if print_changed_only is not None:
-        _global_config['print_changed_only'] = print_changed_only
+        local_config['print_changed_only'] = print_changed_only
    if display is not None:
-        _global_config['display'] = display
+        local_config['display'] = display


@contextmanager
@@ -120,8 +134,7 @@ def config_context(**new_config):
Notes
-----
    All settings, not just those presently modified, will be returned to
-    their previous values when the context manager is exited. This is not
-    thread-safe.
+    their previous values when the context manager is exited.

Examples
--------
@@ -141,7 +154,7 @@ def config_context(**new_config):
set_config : Set global scikit-learn configuration.
get_config : Retrieve current values of the global configuration.
"""
-    old_config = get_config().copy()
+    old_config = get_config()
set_config(**new_config)

try:
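Putting the hunks above together, the user-visible semantics come out as in the short sketch below (assuming the default environment, i.e. SKLEARN_ASSUME_FINITE and SKLEARN_WORKING_MEMORY unset, so assume_finite is False and working_memory is 1024):

from sklearn import config_context, get_config, set_config

# get_config hands back a copy, so mutating the returned dict is a no-op
# (the safeguard described in the review thread above).
config = get_config()
config['assume_finite'] = True
assert get_config()['assume_finite'] is False   # unchanged

# set_config is the supported way to change the (now thread-local) config.
set_config(assume_finite=True)
assert get_config()['assume_finite'] is True
set_config(assume_finite=False)

# config_context restores *all* settings on exit, per the updated Notes.
with config_context(assume_finite=True, working_memory=256):
    assert get_config()['assume_finite'] is True
assert get_config()['assume_finite'] is False
assert get_config()['working_memory'] == 1024
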
55 changes: 55 additions & 0 deletions sklearn/tests/test_config.py
@@ -1,5 +1,13 @@
import time
from concurrent.futures import ThreadPoolExecutor

from joblib import Parallel
import joblib
import pytest

from sklearn import get_config, set_config, config_context
from sklearn.utils.fixes import delayed
from sklearn.utils.fixes import parse_version


def test_config_context():
@@ -76,3 +84,50 @@ def test_set_config():
    # No unknown arguments
    with pytest.raises(TypeError):
        set_config(do_something_else=True)


def set_assume_finite(assume_finite, sleep_duration):
    """Return the value of assume_finite after waiting `sleep_duration`."""
    with config_context(assume_finite=assume_finite):
        time.sleep(sleep_duration)
        return get_config()['assume_finite']


@pytest.mark.parametrize("backend",
                         ["loky", "multiprocessing", "threading"])
def test_config_threadsafe_joblib(backend):
    """Test that the global config is threadsafe with all joblib backends.

    Two jobs are spawned, and each sets assume_finite to a different value.
    When the job with the shorter 0.1s sleep completes, its assume_finite
    value should be the one passed to the function; in other words, it is
    not influenced by the other job setting assume_finite to True.
    """

    if (parse_version(joblib.__version__) < parse_version('0.12')
            and backend == 'loky'):
        pytest.skip('loky backend does not exist in joblib <0.12')  # noqa

    assume_finites = [False, True]
    sleep_durations = [0.1, 0.2]

    items = Parallel(backend=backend, n_jobs=2)(
        delayed(set_assume_finite)(assume_finite, sleep_dur)
        for assume_finite, sleep_dur
        in zip(assume_finites, sleep_durations))

    assert items == [False, True]


def test_config_threadsafe():
    """Uses threads directly to test that the global config does not change
    between threads. Same test as `test_config_threadsafe_joblib` but with
    `ThreadPoolExecutor`."""

    assume_finites = [False, True]
    sleep_durations = [0.1, 0.2]

    with ThreadPoolExecutor(max_workers=2) as e:
        items = [output for output in
                 e.map(set_assume_finite, assume_finites, sleep_durations)]

    assert items == [False, True]
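
One consequence of this design that the tests above do not exercise directly, shown in a hedged sketch below (not one of the PR's tests; it assumes SKLEARN_ASSUME_FINITE is unset): a freshly spawned thread initializes from the module-level defaults rather than inheriting the spawning thread's modified configuration.

import threading

from sklearn import get_config, set_config

set_config(assume_finite=True)   # affects only the calling thread

seen = {}


def worker():
    # A brand-new thread copies the module-level defaults on first access,
    # so it does not see the parent thread's assume_finite=True.
    seen['assume_finite'] = get_config()['assume_finite']


t = threading.Thread(target=worker)
t.start()
t.join()

assert seen['assume_finite'] is False   # defaults, not the parent's value
set_config(assume_finite=False)         # restore the main thread's default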