Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit bb042f6

Browse files
author
alvarosg
committed
Added parameter check, and other feedback from the PR
1 parent 509ed9f commit bb042f6

File tree

2 files changed

+104
-64
lines changed

2 files changed

+104
-64
lines changed

lib/matplotlib/cbook.py

Lines changed: 82 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2671,18 +2671,36 @@ class _FuncInfo(object):
26712671
* The direct function (direct)
26722672
* The inverse function (inverse)
26732673
* A boolean indicating whether the function
2674-
is bounded in the interval 0-1 (bounded_0_1)
2674+
is bounded in the interval 0-1 (bounded_0_1), or
2675+
a method that returns the information depending
2676+
on this
2677+
* A callable (check_params) that returns a bool specifying if a
2678+
certain combination of parameters is valid.
26752679
26762680
"""
2677-
def __init__(self, direct, inverse, bounded_0_1):
2681+
def __init__(self, direct, inverse, bounded_0_1=True, check_params=None):
26782682
self.direct = direct
26792683
self.inverse = inverse
2680-
self.bounded_0_1 = bounded_0_1
26812684

2682-
def copy(self):
2683-
return _FuncInfo(self.direct,
2684-
self.inverse,
2685-
self.bounded_0_1)
2685+
if (hasattr(bounded_0_1, '__call__')):
2686+
self._bounded_0_1 = bounded_0_1
2687+
else:
2688+
self._bounded_0_1 = lambda x: bounded_0_1
2689+
2690+
if check_params is None:
2691+
self._check_params = lambda x: True
2692+
elif (hasattr(check_params, '__call__')):
2693+
self._check_params = check_params
2694+
else:
2695+
raise ValueError("Check params must be a callable, returning "
2696+
"a boolean with the validity of the passed "
2697+
"parameters or None.")
2698+
2699+
def is_bounded_0_1(self, params=None):
2700+
return self._bounded_0_1(params)
2701+
2702+
def check_params(self, params=None):
2703+
return self._check_params(params)
26862704

26872705

26882706
class _StringFuncParser(object):
@@ -2697,40 +2715,49 @@ class _StringFuncParser(object):
26972715
_funcs['linear'] = _FuncInfo(lambda x: x,
26982716
lambda x: x,
26992717
True)
2700-
_funcs['quadratic'] = _FuncInfo(lambda x: x**2,
2701-
lambda x: x**(1. / 2),
2718+
_funcs['quadratic'] = _FuncInfo(np.square,
2719+
np.sqrt,
27022720
True)
27032721
_funcs['cubic'] = _FuncInfo(lambda x: x**3,
2704-
lambda x: x**(1. / 3),
2722+
np.cbrt,
27052723
True)
2706-
_funcs['sqrt'] = _FuncInfo(lambda x: x**(1. / 2),
2707-
lambda x: x**2,
2724+
_funcs['sqrt'] = _FuncInfo(np.sqrt,
2725+
np.square,
27082726
True)
2709-
_funcs['cbrt'] = _FuncInfo(lambda x: x**(1. / 3),
2727+
_funcs['cbrt'] = _FuncInfo(np.cbrt,
27102728
lambda x: x**3,
27112729
True)
2712-
_funcs['log10'] = _FuncInfo(lambda x: np.log10(x),
2730+
_funcs['log10'] = _FuncInfo(np.log10,
27132731
lambda x: (10**(x)),
27142732
False)
2715-
_funcs['log'] = _FuncInfo(lambda x: np.log(x),
2716-
lambda x: (np.exp(x)),
2733+
_funcs['log'] = _FuncInfo(np.log,
2734+
np.exp,
27172735
False)
2736+
_funcs['log2'] = _FuncInfo(np.log2,
2737+
lambda x: (2**x),
2738+
False)
27182739
_funcs['x**{p}'] = _FuncInfo(lambda x, p: x**p[0],
27192740
lambda x, p: x**(1. / p[0]),
27202741
True)
27212742
_funcs['root{p}(x)'] = _FuncInfo(lambda x, p: x**(1. / p[0]),
27222743
lambda x, p: x**p,
27232744
True)
2745+
_funcs['log{p}(x)'] = _FuncInfo(lambda x, p: (np.log(x) /
2746+
np.log(p[0])),
2747+
lambda x, p: p[0]**(x),
2748+
False,
2749+
lambda p: p[0] > 0)
27242750
_funcs['log10(x+{p})'] = _FuncInfo(lambda x, p: np.log10(x + p[0]),
27252751
lambda x, p: 10**x - p[0],
2726-
True)
2752+
lambda p: p[0] > 0)
27272753
_funcs['log(x+{p})'] = _FuncInfo(lambda x, p: np.log(x + p[0]),
27282754
lambda x, p: np.exp(x) - p[0],
2729-
True)
2755+
lambda p: p[0] > 0)
27302756
_funcs['log{p}(x+{p})'] = _FuncInfo(lambda x, p: (np.log(x + p[1]) /
27312757
np.log(p[0])),
27322758
lambda x, p: p[0]**(x) - p[1],
2733-
True)
2759+
lambda p: p[1] > 0,
2760+
lambda p: p[0] > 0)
27342761

27352762
def __init__(self, str_func):
27362763
"""
@@ -2749,82 +2776,88 @@ def __init__(self, str_func):
27492776
raise ValueError("The argument passed is not a string.")
27502777
self._str_func = str_func
27512778
self._key, self._params = self._get_key_params()
2752-
self._func = self.get_func()
2779+
self._func = self.func
27532780

2754-
def get_func(self):
2781+
@property
2782+
def func(self):
27552783
"""
27562784
Returns the _FuncInfo object, replacing the relevant parameters if
27572785
necessary in the lambda functions.
27582786
27592787
"""
27602788

2761-
func = self._funcs[self._key].copy()
2762-
if len(self._params) > 0:
2789+
func = self._funcs[self._key]
2790+
if self._params:
27632791
m = func.direct
2764-
func.direct = (lambda x, m=m: m(x, self._params))
2792+
direct = (lambda x, m=m: m(x, self._params))
2793+
27652794
m = func.inverse
2766-
func.inverse = (lambda x, m=m: m(x, self._params))
2795+
inverse = (lambda x, m=m: m(x, self._params))
2796+
2797+
is_bounded_0_1 = func.is_bounded_0_1(self._params)
2798+
2799+
func = _FuncInfo(direct, inverse,
2800+
is_bounded_0_1)
2801+
else:
2802+
func = _FuncInfo(func.direct, func.inverse,
2803+
func.is_bounded_0_1())
27672804
return func
27682805

2769-
def get_directfunc(self):
2806+
@property
2807+
def directfunc(self):
27702808
"""
27712809
Returns the callable for the direct function.
27722810
27732811
"""
27742812
return self._func.direct
27752813

2776-
def get_invfunc(self):
2814+
@property
2815+
def invfunc(self):
27772816
"""
27782817
Returns the callable for the inverse function.
27792818
27802819
"""
27812820
return self._func.inverse
27822821

2822+
@property
27832823
def is_bounded_0_1(self):
27842824
"""
27852825
Returns a boolean indicating if the function is bounded
27862826
in the [0-1 interval].
27872827
27882828
"""
2789-
return self._func.bounded_0_1
2829+
return self._func.is_bounded_0_1()
27902830

27912831
def _get_key_params(self):
27922832
str_func = six.text_type(self._str_func)
27932833
# Checking if it comes with parameters
27942834
regex = '\{(.*?)\}'
27952835
params = re.findall(regex, str_func)
27962836

2797-
if len(params) > 0:
2837+
if params:
27982838
for i in range(len(params)):
27992839
try:
28002840
params[i] = float(params[i])
28012841
except:
2802-
raise ValueError("'p' in parametric function strings must"
2842+
raise ValueError("Error with parameter number %i: '%s'. "
2843+
"'p' in parametric function strings must "
28032844
" be replaced by a number that is not "
2804-
"zero, e.g. 'log10(x+{0.1})'.")
2845+
"zero, e.g. 'log10(x+{0.1})'." %
2846+
(i, params[i]))
28052847

2806-
if params[i] == 0:
2807-
raise ValueError("'p' in parametric function strings must"
2808-
" be replaced by a number that is not "
2809-
"zero.")
28102848
str_func = re.sub(regex, '{p}', str_func)
28112849

28122850
try:
28132851
func = self._funcs[str_func]
2814-
except KeyError:
2815-
raise ValueError("%s: invalid function. The only strings "
2816-
"recognized as functions are %s." %
2817-
(str_func, self.funcs.keys()))
28182852
except:
2819-
raise ValueError("Invalid function. The only strings recognized "
2820-
"as functions are %s." %
2821-
(self.funcs.keys()))
2822-
if len(params) > 0:
2823-
func.direct(0.5, params)
2824-
try:
2825-
func.direct(0.5, params)
2826-
except:
2827-
raise ValueError("Invalid parameters set for '%s'." %
2828-
(str_func))
2853+
raise ValueError("%s: invalid string. The only strings "
2854+
"recognized as functions are %s." %
2855+
(str_func, self._funcs.keys()))
2856+
2857+
# Checking that the parameters are valid
2858+
if not func.check_params(params):
2859+
raise ValueError("%s: are invalid values for the parameters "
2860+
"in %s." %
2861+
(params, str_func))
28292862

28302863
return str_func, params

lib/matplotlib/tests/test_cbook.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -522,52 +522,59 @@ def test_flatiter():
522522
class TestFuncParser(object):
523523
x_test = np.linspace(0.01, 0.5, 3)
524524
validstrings = ['linear', 'quadratic', 'cubic', 'sqrt', 'cbrt',
525-
'log', 'log10', 'x**{1.5}', 'root{2.5}(x)',
526-
'log(x+{0.5})', 'log10(x+{0.1})', 'log{2}(x+{0.1})']
525+
'log', 'log10', 'log2', 'x**{1.5}', 'root{2.5}(x)',
526+
'log{2}(x)',
527+
'log(x+{0.5})', 'log10(x+{0.1})', 'log{2}(x+{0.1})',
528+
'log{2}(x+{0})']
527529
results = [(lambda x: x),
528-
(lambda x: x**2),
530+
np.square,
529531
(lambda x: x**3),
530-
(lambda x: x**(1. / 2)),
531-
(lambda x: x**(1. / 3)),
532-
(lambda x: np.log(x)),
533-
(lambda x: np.log10(x)),
532+
np.sqrt,
533+
np.cbrt,
534+
np.log,
535+
np.log10,
536+
np.log2,
534537
(lambda x: x**1.5),
535538
(lambda x: x**(1 / 2.5)),
539+
(lambda x: np.log2(x)),
536540
(lambda x: np.log(x + 0.5)),
537541
(lambda x: np.log10(x + 0.1)),
538-
(lambda x: np.log2(x + 0.1))]
542+
(lambda x: np.log2(x + 0.1)),
543+
(lambda x: np.log2(x))]
539544

540545
bounded_list = [True, True, True, True, True,
541-
False, False, True, True,
542-
True, True, True]
546+
False, False, False, True, True,
547+
False,
548+
True, True, True,
549+
False]
543550

544551
@pytest.mark.parametrize("string, func",
545552
zip(validstrings, results),
546553
ids=validstrings)
547554
def test_values(self, string, func):
548555
func_parser = cbook._StringFuncParser(string)
549-
f = func_parser.get_directfunc()
556+
f = func_parser.directfunc
550557
assert_array_almost_equal(f(self.x_test), func(self.x_test))
551558

552559
@pytest.mark.parametrize("string", validstrings, ids=validstrings)
553560
def test_inverse(self, string):
554561
func_parser = cbook._StringFuncParser(string)
555-
f = func_parser.get_func()
562+
f = func_parser.func
556563
fdir = f.direct
557564
finv = f.inverse
558565
assert_array_almost_equal(finv(fdir(self.x_test)), self.x_test)
559566

560567
@pytest.mark.parametrize("string", validstrings, ids=validstrings)
561568
def test_get_invfunc(self, string):
562569
func_parser = cbook._StringFuncParser(string)
563-
finv1 = func_parser.get_invfunc()
564-
finv2 = func_parser.get_func().inverse
570+
finv1 = func_parser.invfunc
571+
finv2 = func_parser.func.inverse
565572
assert_array_almost_equal(finv1(self.x_test), finv2(self.x_test))
566573

567574
@pytest.mark.parametrize("string, bounded",
568575
zip(validstrings, bounded_list),
569576
ids=validstrings)
570577
def test_bounded(self, string, bounded):
571578
func_parser = cbook._StringFuncParser(string)
572-
b = func_parser.is_bounded_0_1()
579+
b = func_parser.is_bounded_0_1
573580
assert_array_equal(b, bounded)

0 commit comments

Comments
 (0)