diff --git a/lib/matplotlib/style/core.py b/lib/matplotlib/style/core.py index 8bd0d13b44ed..a67bc30c0859 100644 --- a/lib/matplotlib/style/core.py +++ b/lib/matplotlib/style/core.py @@ -39,43 +39,71 @@ def is_style_file(filename): return STYLE_FILE_PATTERN.match(filename) is not None -def use(name): - """Use matplotlib style settings from a known style sheet or from a file. +def use(style): + """Use matplotlib style settings from a style specification. Parameters ---------- - name : str or list of str - Name of style or path/URL to a style file. For a list of available - style names, see `style.available`. If given a list, each style is - applied from first to last in the list. + style : str, dict, or list + A style specification. Valid options are: + + +------+-------------------------------------------------------------+ + | str | The name of a style or a path/URL to a style file. For a | + | | list of available style names, see `style.available`. | + +------+-------------------------------------------------------------+ + | dict | Dictionary with valid key/value pairs for | + | | `matplotlib.rcParams`. | + +------+-------------------------------------------------------------+ + | list | A list of style specifiers (str or dict) applied from first | + | | to last in the list. | + +------+-------------------------------------------------------------+ + + """ - if cbook.is_string_like(name): - name = [name] + if cbook.is_string_like(style) or hasattr(style, 'keys'): + # If name is a single str or dict, make it a single element list. + styles = [style] + else: + styles = style + + for style in styles: + if not cbook.is_string_like(style): + mpl.rcParams.update(style) + continue - for style in name: if style in library: mpl.rcParams.update(library[style]) else: try: rc = rc_params_from_file(style, use_default_template=False) mpl.rcParams.update(rc) - except: + except IOError: msg = ("'%s' not found in the style library and input is " "not a valid URL or path. See `style.available` for " "list of available styles.") - raise ValueError(msg % style) + raise IOError(msg % style) @contextlib.contextmanager -def context(name, after_reset=False): +def context(style, after_reset=False): """Context manager for using style settings temporarily. Parameters ---------- - name : str or list of str - Name of style or path/URL to a style file. For a list of available - style names, see `style.available`. If given a list, each style is - applied from first to last in the list. + style : str, dict, or list + A style specification. Valid options are: + + +------+-------------------------------------------------------------+ + | str | The name of a style or a path/URL to a style file. For a | + | | list of available style names, see `style.available`. | + +------+-------------------------------------------------------------+ + | dict | Dictionary with valid key/value pairs for | + | | `matplotlib.rcParams`. | + +------+-------------------------------------------------------------+ + | list | A list of style specifiers (str or dict) applied from first | + | | to last in the list. | + +------+-------------------------------------------------------------+ + after_reset : bool If True, apply style after resetting settings to their defaults; otherwise, apply style on top of the current settings. @@ -83,9 +111,16 @@ def context(name, after_reset=False): initial_settings = mpl.rcParams.copy() if after_reset: mpl.rcdefaults() - use(name) - yield - mpl.rcParams.update(initial_settings) + try: + use(style) + except: + # Restore original settings before raising errors during the update. + mpl.rcParams.update(initial_settings) + raise + else: + yield + finally: + mpl.rcParams.update(initial_settings) def load_base_library(): diff --git a/lib/matplotlib/tests/test_style.py b/lib/matplotlib/tests/test_style.py index d8e71f5cc4d2..005064dd2421 100644 --- a/lib/matplotlib/tests/test_style.py +++ b/lib/matplotlib/tests/test_style.py @@ -2,17 +2,20 @@ unicode_literals) import os +import sys import shutil import tempfile from contextlib import contextmanager +from nose import SkipTest +from nose.tools import assert_raises + import matplotlib as mpl from matplotlib import style from matplotlib.style.core import USER_LIBRARY_PATHS, STYLE_EXTENSION import six - PARAM = 'image.cmap' VALUE = 'pink' DUMMY_SETTINGS = {PARAM: VALUE} @@ -68,6 +71,70 @@ def test_context(): assert mpl.rcParams[PARAM] == 'gray' +def test_context_with_dict(): + original_value = 'gray' + other_value = 'blue' + mpl.rcParams[PARAM] = original_value + with style.context({PARAM: other_value}): + assert mpl.rcParams[PARAM] == other_value + assert mpl.rcParams[PARAM] == original_value + + +def test_context_with_dict_after_namedstyle(): + # Test dict after style name where dict modifies the same parameter. + original_value = 'gray' + other_value = 'blue' + mpl.rcParams[PARAM] = original_value + with temp_style('test', DUMMY_SETTINGS): + with style.context(['test', {PARAM: other_value}]): + assert mpl.rcParams[PARAM] == other_value + assert mpl.rcParams[PARAM] == original_value + + +def test_context_with_dict_before_namedstyle(): + # Test dict before style name where dict modifies the same parameter. + original_value = 'gray' + other_value = 'blue' + mpl.rcParams[PARAM] = original_value + with temp_style('test', DUMMY_SETTINGS): + with style.context([{PARAM: other_value}, 'test']): + assert mpl.rcParams[PARAM] == VALUE + assert mpl.rcParams[PARAM] == original_value + + +def test_context_with_union_of_dict_and_namedstyle(): + # Test dict after style name where dict modifies the a different parameter. + original_value = 'gray' + other_param = 'text.usetex' + other_value = True + d = {other_param: other_value} + mpl.rcParams[PARAM] = original_value + mpl.rcParams[other_param] = (not other_value) + with temp_style('test', DUMMY_SETTINGS): + with style.context(['test', d]): + assert mpl.rcParams[PARAM] == VALUE + assert mpl.rcParams[other_param] == other_value + assert mpl.rcParams[PARAM] == original_value + assert mpl.rcParams[other_param] == (not other_value) + + +def test_context_with_badparam(): + if sys.version_info[:2] >= (2, 7): + from collections import OrderedDict + else: + m = "Test can only be run in Python >= 2.7 as it requires OrderedDict" + raise SkipTest(m) + + original_value = 'gray' + other_value = 'blue' + d = OrderedDict([(PARAM, original_value), ('badparam', None)]) + with style.context({PARAM: other_value}): + assert mpl.rcParams[PARAM] == other_value + x = style.context([d]) + assert_raises(KeyError, x.__enter__) + assert mpl.rcParams[PARAM] == other_value + + if __name__ == '__main__': from numpy import testing testing.run_module_suite()