diff --git a/lib/matplotlib/cbook.py b/lib/matplotlib/cbook.py index 32ddabbb5a73..be5d3151be8a 100644 --- a/lib/matplotlib/cbook.py +++ b/lib/matplotlib/cbook.py @@ -2661,3 +2661,259 @@ def __exit__(self, exc_type, exc_value, traceback): os.rmdir(path) except OSError: pass + + +class _FuncInfo(object): + """ + Class used to store a function. + + """ + + def __init__(self, function, inverse, bounded_0_1=True, check_params=None): + """ + Parameters + ---------- + + function : callable + A callable implementing the function receiving the variable as + first argument and any additional parameters in a list as second + argument. + inverse : callable + A callable implementing the inverse function receiving the variable + as first argument and any additional parameters in a list as + second argument. It must satisfy 'inverse(function(x, p), p) == x'. + bounded_0_1: bool or callable + A boolean indicating whether the function is bounded in the [0,1] + interval, or a callable taking a list of values for the additional + parameters, and returning a boolean indicating whether the function + is bounded in the [0,1] interval for that combination of + parameters. Default True. + check_params: callable or None + A callable taking a list of values for the additional parameters + and returning a boolean indicating whether that combination of + parameters is valid. It is only required if the function has + additional parameters and some of them are restricted. + Default None. + + """ + + self.function = function + self.inverse = inverse + + if callable(bounded_0_1): + self._bounded_0_1 = bounded_0_1 + else: + self._bounded_0_1 = lambda x: bounded_0_1 + + if check_params is None: + self._check_params = lambda x: True + elif callable(check_params): + self._check_params = check_params + else: + raise ValueError("Invalid 'check_params' argument.") + + def is_bounded_0_1(self, params=None): + """ + Returns a boolean indicating if the function is bounded in the [0,1] + interval for a particular set of additional parameters. + + Parameters + ---------- + + params : list + The list of additional parameters. Default None. + + Returns + ------- + + out : bool + True if the function is bounded in the [0,1] interval for + parameters 'params'. Otherwise False. + + """ + + return self._bounded_0_1(params) + + def check_params(self, params=None): + """ + Returns a boolean indicating if the set of additional parameters is + valid. + + Parameters + ---------- + + params : list + The list of additional parameters. Default None. + + Returns + ------- + + out : bool + True if 'params' is a valid set of additional parameters for the + function. Otherwise False. + + """ + + return self._check_params(params) + + +class _StringFuncParser(object): + """ + A class used to convert predefined strings into + _FuncInfo objects, or to directly obtain _FuncInfo + properties. + + """ + + _funcs = {} + _funcs['linear'] = _FuncInfo(lambda x: x, + lambda x: x, + True) + _funcs['quadratic'] = _FuncInfo(np.square, + np.sqrt, + True) + _funcs['cubic'] = _FuncInfo(lambda x: x**3, + lambda x: x**(1. / 3), + True) + _funcs['sqrt'] = _FuncInfo(np.sqrt, + np.square, + True) + _funcs['cbrt'] = _FuncInfo(lambda x: x**(1. / 3), + lambda x: x**3, + True) + _funcs['log10'] = _FuncInfo(np.log10, + lambda x: (10**(x)), + False) + _funcs['log'] = _FuncInfo(np.log, + np.exp, + False) + _funcs['log2'] = _FuncInfo(np.log2, + lambda x: (2**x), + False) + _funcs['x**{p}'] = _FuncInfo(lambda x, p: x**p[0], + lambda x, p: x**(1. / p[0]), + True) + _funcs['root{p}(x)'] = _FuncInfo(lambda x, p: x**(1. / p[0]), + lambda x, p: x**p, + True) + _funcs['log{p}(x)'] = _FuncInfo(lambda x, p: (np.log(x) / + np.log(p[0])), + lambda x, p: p[0]**(x), + False, + lambda p: p[0] > 0) + _funcs['log10(x+{p})'] = _FuncInfo(lambda x, p: np.log10(x + p[0]), + lambda x, p: 10**x - p[0], + lambda p: p[0] > 0) + _funcs['log(x+{p})'] = _FuncInfo(lambda x, p: np.log(x + p[0]), + lambda x, p: np.exp(x) - p[0], + lambda p: p[0] > 0) + _funcs['log{p}(x+{p})'] = _FuncInfo(lambda x, p: (np.log(x + p[1]) / + np.log(p[0])), + lambda x, p: p[0]**(x) - p[1], + lambda p: p[1] > 0, + lambda p: p[0] > 0) + + def __init__(self, str_func): + """ + Parameters + ---------- + str_func : string + String to be parsed. + + """ + + if not isinstance(str_func, six.string_types): + raise ValueError("'%s' must be a string." % str_func) + self._str_func = six.text_type(str_func) + self._key, self._params = self._get_key_params() + self._func = self._parse_func() + + def _parse_func(self): + """ + Parses the parameters to build a new _FuncInfo object, + replacing the relevant parameters if necessary in the lambda + functions. + + """ + + func = self._funcs[self._key] + + if not self._params: + func = _FuncInfo(func.function, func.inverse, + func.is_bounded_0_1()) + else: + m = func.function + function = (lambda x, m=m: m(x, self._params)) + + m = func.inverse + inverse = (lambda x, m=m: m(x, self._params)) + + is_bounded_0_1 = func.is_bounded_0_1(self._params) + + func = _FuncInfo(function, inverse, + is_bounded_0_1) + return func + + @property + def func_info(self): + """ + Returns the _FuncInfo object. + + """ + return self._func + + @property + def function(self): + """ + Returns the callable for the direct function. + + """ + return self._func.function + + @property + def inverse(self): + """ + Returns the callable for the inverse function. + + """ + return self._func.inverse + + @property + def is_bounded_0_1(self): + """ + Returns a boolean indicating if the function is bounded + in the [0-1 interval]. + + """ + return self._func.is_bounded_0_1() + + def _get_key_params(self): + str_func = self._str_func + # Checking if it comes with parameters + regex = '\{(.*?)\}' + params = re.findall(regex, str_func) + + for i, param in enumerate(params): + try: + params[i] = float(param) + except ValueError: + raise ValueError("Parameter %i is '%s', which is " + "not a number." % + (i, param)) + + str_func = re.sub(regex, '{p}', str_func) + + try: + func = self._funcs[str_func] + except (ValueError, KeyError): + raise ValueError("'%s' is an invalid string. The only strings " + "recognized as functions are %s." % + (str_func, list(self._funcs))) + + # Checking that the parameters are valid + if not func.check_params(params): + raise ValueError("%s are invalid values for the parameters " + "in %s." % + (params, str_func)) + + return str_func, params diff --git a/lib/matplotlib/tests/test_cbook.py b/lib/matplotlib/tests/test_cbook.py index 65136c77bbc6..bc69e7c1d3a9 100644 --- a/lib/matplotlib/tests/test_cbook.py +++ b/lib/matplotlib/tests/test_cbook.py @@ -517,3 +517,64 @@ def test_flatiter(): assert 0 == next(it) assert 1 == next(it) + + +class TestFuncParser(object): + x_test = np.linspace(0.01, 0.5, 3) + validstrings = ['linear', 'quadratic', 'cubic', 'sqrt', 'cbrt', + 'log', 'log10', 'log2', 'x**{1.5}', 'root{2.5}(x)', + 'log{2}(x)', + 'log(x+{0.5})', 'log10(x+{0.1})', 'log{2}(x+{0.1})', + 'log{2}(x+{0})'] + results = [(lambda x: x), + np.square, + (lambda x: x**3), + np.sqrt, + (lambda x: x**(1. / 3)), + np.log, + np.log10, + np.log2, + (lambda x: x**1.5), + (lambda x: x**(1 / 2.5)), + (lambda x: np.log2(x)), + (lambda x: np.log(x + 0.5)), + (lambda x: np.log10(x + 0.1)), + (lambda x: np.log2(x + 0.1)), + (lambda x: np.log2(x))] + + bounded_list = [True, True, True, True, True, + False, False, False, True, True, + False, + True, True, True, + False] + + @pytest.mark.parametrize("string, func", + zip(validstrings, results), + ids=validstrings) + def test_values(self, string, func): + func_parser = cbook._StringFuncParser(string) + f = func_parser.function + assert_array_almost_equal(f(self.x_test), func(self.x_test)) + + @pytest.mark.parametrize("string", validstrings, ids=validstrings) + def test_inverse(self, string): + func_parser = cbook._StringFuncParser(string) + f = func_parser.func_info + fdir = f.function + finv = f.inverse + assert_array_almost_equal(finv(fdir(self.x_test)), self.x_test) + + @pytest.mark.parametrize("string", validstrings, ids=validstrings) + def test_get_inverse(self, string): + func_parser = cbook._StringFuncParser(string) + finv1 = func_parser.inverse + finv2 = func_parser.func_info.inverse + assert_array_almost_equal(finv1(self.x_test), finv2(self.x_test)) + + @pytest.mark.parametrize("string, bounded", + zip(validstrings, bounded_list), + ids=validstrings) + def test_bounded(self, string, bounded): + func_parser = cbook._StringFuncParser(string) + b = func_parser.is_bounded_0_1 + assert_array_equal(b, bounded)