Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit a56afb3

Browse files
authored
Merge pull request #23533 from oscargus/sankeytest
Add tests for sankey and minor fixes
2 parents 60b4104 + 0d8b319 commit a56afb3

File tree

2 files changed

+93
-16
lines changed

2 files changed

+93
-16
lines changed

lib/matplotlib/sankey.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -204,12 +204,12 @@ def _arc(self, quadrant=0, cw=True, radius=1, center=(0, 0)):
204204
# Insignificant
205205
# [6.12303177e-17, 1.00000000e+00]])
206206
[0.00000000e+00, 1.00000000e+00]])
207-
if quadrant == 0 or quadrant == 2:
207+
if quadrant in (0, 2):
208208
if cw:
209209
vertices = ARC_VERTICES
210210
else:
211211
vertices = ARC_VERTICES[:, ::-1] # Swap x and y.
212-
elif quadrant == 1 or quadrant == 3:
212+
else: # 1, 3
213213
# Negate x.
214214
if cw:
215215
# Swap x and y.
@@ -299,15 +299,11 @@ def _add_output(self, path, angle, flow, length):
299299
else: # Vertical
300300
x += self.gap
301301
if angle == UP:
302-
sign = 1
302+
sign, quadrant = 1, 3
303303
else:
304-
sign = -1
304+
sign, quadrant = -1, 0
305305

306306
tip = [x - flow / 2.0, y + sign * (length + tipheight)]
307-
if angle == UP:
308-
quadrant = 3
309-
else:
310-
quadrant = 0
311307
# Inner arc isn't needed if inner radius is zero
312308
if self.radius:
313309
path.extend(self._arc(quadrant=quadrant,
@@ -525,7 +521,7 @@ def add(self, patchlabel='', flows=None, orientations=None, labels='',
525521
if orient == 1:
526522
if is_input:
527523
angles[i] = DOWN
528-
elif not is_input:
524+
elif is_input is False:
529525
# Be specific since is_input can be None.
530526
angles[i] = UP
531527
elif orient == 0:
@@ -538,7 +534,7 @@ def add(self, patchlabel='', flows=None, orientations=None, labels='',
538534
f"but it must be -1, 0, or 1")
539535
if is_input:
540536
angles[i] = UP
541-
elif not is_input:
537+
elif is_input is False:
542538
angles[i] = DOWN
543539

544540
# Justify the lengths of the paths.
@@ -561,7 +557,7 @@ def add(self, patchlabel='', flows=None, orientations=None, labels='',
561557
if angle == DOWN and is_input:
562558
pathlengths[i] = ullength
563559
ullength += flow
564-
elif angle == UP and not is_input:
560+
elif angle == UP and is_input is False:
565561
pathlengths[i] = urlength
566562
urlength -= flow # Flow is negative for outputs.
567563
# Determine the lengths of the bottom-side arrows
@@ -571,7 +567,7 @@ def add(self, patchlabel='', flows=None, orientations=None, labels='',
571567
if angle == UP and is_input:
572568
pathlengths[n - i - 1] = lllength
573569
lllength += flow
574-
elif angle == DOWN and not is_input:
570+
elif angle == DOWN and is_input is False:
575571
pathlengths[n - i - 1] = lrlength
576572
lrlength -= flow
577573
# Determine the lengths of the left-side arrows
@@ -591,7 +587,7 @@ def add(self, patchlabel='', flows=None, orientations=None, labels='',
591587
for i, (angle, is_input, spec) in enumerate(zip(
592588
angles, are_inputs, list(zip(scaled_flows, pathlengths)))):
593589
if angle == RIGHT:
594-
if not is_input:
590+
if is_input is False:
595591
if has_right_output:
596592
pathlengths[i] = 0
597593
else:
@@ -637,7 +633,7 @@ def add(self, patchlabel='', flows=None, orientations=None, labels='',
637633
if angle == DOWN and is_input:
638634
tips[i, :], label_locations[i, :] = self._add_input(
639635
ulpath, angle, *spec)
640-
elif angle == UP and not is_input:
636+
elif angle == UP and is_input is False:
641637
tips[i, :], label_locations[i, :] = self._add_output(
642638
urpath, angle, *spec)
643639
# Add the bottom-side inputs and outputs from the middle outwards.
@@ -647,7 +643,7 @@ def add(self, patchlabel='', flows=None, orientations=None, labels='',
647643
tip, label_location = self._add_input(llpath, angle, *spec)
648644
tips[n - i - 1, :] = tip
649645
label_locations[n - i - 1, :] = label_location
650-
elif angle == DOWN and not is_input:
646+
elif angle == DOWN and is_input is False:
651647
tip, label_location = self._add_output(lrpath, angle, *spec)
652648
tips[n - i - 1, :] = tip
653649
label_locations[n - i - 1, :] = label_location
@@ -670,7 +666,7 @@ def add(self, patchlabel='', flows=None, orientations=None, labels='',
670666
has_right_output = False
671667
for i, (angle, is_input, spec) in enumerate(zip(
672668
angles, are_inputs, list(zip(scaled_flows, pathlengths)))):
673-
if angle == RIGHT and not is_input:
669+
if angle == RIGHT and is_input is False:
674670
if not has_right_output:
675671
# Make sure the upper path extends
676672
# at least as far as the lower one.

lib/matplotlib/tests/test_sankey.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1+
import pytest
2+
import numpy as np
3+
14
from matplotlib.sankey import Sankey
5+
from matplotlib.testing.decorators import check_figures_equal
26

37

48
def test_sankey():
@@ -22,3 +26,80 @@ def show_three_decimal_places(value):
2226
format=show_three_decimal_places)
2327

2428
assert s.diagrams[0].texts[0].get_text() == 'First\n0.250'
29+
30+
31+
@pytest.mark.parametrize('kwargs, msg', (
32+
({'gap': -1}, "'gap' is negative"),
33+
({'gap': 1, 'radius': 2}, "'radius' is greater than 'gap'"),
34+
({'head_angle': -1}, "'head_angle' is negative"),
35+
({'tolerance': -1}, "'tolerance' is negative"),
36+
({'flows': [1, -1], 'orientations': [-1, 0, 1]},
37+
r"The shapes of 'flows' \(2,\) and 'orientations'"),
38+
({'flows': [1, -1], 'labels': ['a', 'b', 'c']},
39+
r"The shapes of 'flows' \(2,\) and 'labels'"),
40+
))
41+
def test_sankey_errors(kwargs, msg):
42+
with pytest.raises(ValueError, match=msg):
43+
Sankey(**kwargs)
44+
45+
46+
@pytest.mark.parametrize('kwargs, msg', (
47+
({'trunklength': -1}, "'trunklength' is negative"),
48+
({'flows': [0.2, 0.3], 'prior': 0}, "The scaled sum of the connected"),
49+
({'prior': -1}, "The index of the prior diagram is negative"),
50+
({'prior': 1}, "The index of the prior diagram is 1"),
51+
({'connect': (-1, 1), 'prior': 0}, "At least one of the connection"),
52+
({'connect': (2, 1), 'prior': 0}, "The connection index to the source"),
53+
({'connect': (1, 3), 'prior': 0}, "The connection index to this dia"),
54+
({'connect': (1, 1), 'prior': 0, 'flows': [-0.2, 0.2],
55+
'orientations': [2]}, "The value of orientations"),
56+
({'connect': (1, 1), 'prior': 0, 'flows': [-0.2, 0.2],
57+
'pathlengths': [2]}, "The lengths of 'flows'"),
58+
))
59+
def test_sankey_add_errors(kwargs, msg):
60+
sankey = Sankey()
61+
with pytest.raises(ValueError, match=msg):
62+
sankey.add(flows=[0.2, -0.2])
63+
sankey.add(**kwargs)
64+
65+
66+
def test_sankey2():
67+
s = Sankey(flows=[0.25, -0.25, 0.5, -0.5], labels=['Foo'],
68+
orientations=[-1], unit='Bar')
69+
sf = s.finish()
70+
assert np.all(np.equal(np.array((0.25, -0.25, 0.5, -0.5)), sf[0].flows))
71+
assert sf[0].angles == [1, 3, 1, 3]
72+
assert all([text.get_text()[0:3] == 'Foo' for text in sf[0].texts])
73+
assert all([text.get_text()[-3:] == 'Bar' for text in sf[0].texts])
74+
assert sf[0].text.get_text() == ''
75+
assert np.allclose(np.array(((-1.375, -0.52011255),
76+
(1.375, -0.75506044),
77+
(-0.75, -0.41522509),
78+
(0.75, -0.8599479))),
79+
sf[0].tips)
80+
81+
s = Sankey(flows=[0.25, -0.25, 0, 0.5, -0.5], labels=['Foo'],
82+
orientations=[-1], unit='Bar')
83+
sf = s.finish()
84+
assert np.all(np.equal(np.array((0.25, -0.25, 0, 0.5, -0.5)), sf[0].flows))
85+
assert sf[0].angles == [1, 3, None, 1, 3]
86+
assert np.allclose(np.array(((-1.375, -0.52011255),
87+
(1.375, -0.75506044),
88+
(0, 0),
89+
(-0.75, -0.41522509),
90+
(0.75, -0.8599479))),
91+
sf[0].tips)
92+
93+
94+
@check_figures_equal(extensions=['png'])
95+
def test_sankey3(fig_test, fig_ref):
96+
ax_test = fig_test.gca()
97+
s_test = Sankey(ax=ax_test, flows=[0.25, -0.25, -0.25, 0.25, 0.5, -0.5],
98+
orientations=[1, -1, 1, -1, 0, 0])
99+
s_test.finish()
100+
101+
ax_ref = fig_ref.gca()
102+
s_ref = Sankey(ax=ax_ref)
103+
s_ref.add(flows=[0.25, -0.25, -0.25, 0.25, 0.5, -0.5],
104+
orientations=[1, -1, 1, -1, 0, 0])
105+
s_ref.finish()

0 commit comments

Comments
 (0)