1
+ import pytest
2
+ import numpy as np
3
+
1
4
from matplotlib .sankey import Sankey
5
+ from matplotlib .testing .decorators import check_figures_equal
2
6
3
7
4
8
def test_sankey ():
@@ -22,3 +26,80 @@ def show_three_decimal_places(value):
22
26
format = show_three_decimal_places )
23
27
24
28
assert s .diagrams [0 ].texts [0 ].get_text () == 'First\n 0.250'
29
+
30
+
31
+ @pytest .mark .parametrize ('kwargs, msg' , (
32
+ ({'gap' : - 1 }, "'gap' is negative" ),
33
+ ({'gap' : 1 , 'radius' : 2 }, "'radius' is greater than 'gap'" ),
34
+ ({'head_angle' : - 1 }, "'head_angle' is negative" ),
35
+ ({'tolerance' : - 1 }, "'tolerance' is negative" ),
36
+ ({'flows' : [1 , - 1 ], 'orientations' : [- 1 , 0 , 1 ]},
37
+ r"The shapes of 'flows' \(2,\) and 'orientations'" ),
38
+ ({'flows' : [1 , - 1 ], 'labels' : ['a' , 'b' , 'c' ]},
39
+ r"The shapes of 'flows' \(2,\) and 'labels'" ),
40
+ ))
41
+ def test_sankey_errors (kwargs , msg ):
42
+ with pytest .raises (ValueError , match = msg ):
43
+ Sankey (** kwargs )
44
+
45
+
46
+ @pytest .mark .parametrize ('kwargs, msg' , (
47
+ ({'trunklength' : - 1 }, "'trunklength' is negative" ),
48
+ ({'flows' : [0.2 , 0.3 ], 'prior' : 0 }, "The scaled sum of the connected" ),
49
+ ({'prior' : - 1 }, "The index of the prior diagram is negative" ),
50
+ ({'prior' : 1 }, "The index of the prior diagram is 1" ),
51
+ ({'connect' : (- 1 , 1 ), 'prior' : 0 }, "At least one of the connection" ),
52
+ ({'connect' : (2 , 1 ), 'prior' : 0 }, "The connection index to the source" ),
53
+ ({'connect' : (1 , 3 ), 'prior' : 0 }, "The connection index to this dia" ),
54
+ ({'connect' : (1 , 1 ), 'prior' : 0 , 'flows' : [- 0.2 , 0.2 ],
55
+ 'orientations' : [2 ]}, "The value of orientations" ),
56
+ ({'connect' : (1 , 1 ), 'prior' : 0 , 'flows' : [- 0.2 , 0.2 ],
57
+ 'pathlengths' : [2 ]}, "The lengths of 'flows'" ),
58
+ ))
59
+ def test_sankey_add_errors (kwargs , msg ):
60
+ sankey = Sankey ()
61
+ with pytest .raises (ValueError , match = msg ):
62
+ sankey .add (flows = [0.2 , - 0.2 ])
63
+ sankey .add (** kwargs )
64
+
65
+
66
+ def test_sankey2 ():
67
+ s = Sankey (flows = [0.25 , - 0.25 , 0.5 , - 0.5 ], labels = ['Foo' ],
68
+ orientations = [- 1 ], unit = 'Bar' )
69
+ sf = s .finish ()
70
+ assert np .all (np .equal (np .array ((0.25 , - 0.25 , 0.5 , - 0.5 )), sf [0 ].flows ))
71
+ assert sf [0 ].angles == [1 , 3 , 1 , 3 ]
72
+ assert all ([text .get_text ()[0 :3 ] == 'Foo' for text in sf [0 ].texts ])
73
+ assert all ([text .get_text ()[- 3 :] == 'Bar' for text in sf [0 ].texts ])
74
+ assert sf [0 ].text .get_text () == ''
75
+ assert np .allclose (np .array (((- 1.375 , - 0.52011255 ),
76
+ (1.375 , - 0.75506044 ),
77
+ (- 0.75 , - 0.41522509 ),
78
+ (0.75 , - 0.8599479 ))),
79
+ sf [0 ].tips )
80
+
81
+ s = Sankey (flows = [0.25 , - 0.25 , 0 , 0.5 , - 0.5 ], labels = ['Foo' ],
82
+ orientations = [- 1 ], unit = 'Bar' )
83
+ sf = s .finish ()
84
+ assert np .all (np .equal (np .array ((0.25 , - 0.25 , 0 , 0.5 , - 0.5 )), sf [0 ].flows ))
85
+ assert sf [0 ].angles == [1 , 3 , None , 1 , 3 ]
86
+ assert np .allclose (np .array (((- 1.375 , - 0.52011255 ),
87
+ (1.375 , - 0.75506044 ),
88
+ (0 , 0 ),
89
+ (- 0.75 , - 0.41522509 ),
90
+ (0.75 , - 0.8599479 ))),
91
+ sf [0 ].tips )
92
+
93
+
94
+ @check_figures_equal (extensions = ['png' ])
95
+ def test_sankey3 (fig_test , fig_ref ):
96
+ ax_test = fig_test .gca ()
97
+ s_test = Sankey (ax = ax_test , flows = [0.25 , - 0.25 , - 0.25 , 0.25 , 0.5 , - 0.5 ],
98
+ orientations = [1 , - 1 , 1 , - 1 , 0 , 0 ])
99
+ s_test .finish ()
100
+
101
+ ax_ref = fig_ref .gca ()
102
+ s_ref = Sankey (ax = ax_ref )
103
+ s_ref .add (flows = [0.25 , - 0.25 , - 0.25 , 0.25 , 0.5 , - 0.5 ],
104
+ orientations = [1 , - 1 , 1 , - 1 , 0 , 0 ])
105
+ s_ref .finish ()
0 commit comments