Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 93aadf5

Browse files
authored
Merge pull request #11127 from ImportanceOfBeingErnest/legend-for-scatter
Legend for Scatter
2 parents 0e2b9b1 + d9bd109 commit 93aadf5

File tree

5 files changed

+299
-6
lines changed

5 files changed

+299
-6
lines changed

.flake8

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ per-file-ignores =
141141
examples/lines_bars_and_markers/filled_step.py: E402
142142
examples/lines_bars_and_markers/joinstyle.py: E402
143143
examples/lines_bars_and_markers/scatter_piecharts.py: E402
144+
examples/lines_bars_and_markers/scatter_with_legend.py: E402
144145
examples/lines_bars_and_markers/span_regions.py: E402
145146
examples/lines_bars_and_markers/step_demo.py: E402
146147
examples/misc/agg_buffer.py: E402
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
:orphan:
2+
3+
Legend for scatter
4+
------------------
5+
6+
A new method for creating legends for scatter plots has been introduced.
7+
Previously, in order to obtain a legend for a :meth:`~.axes.Axes.scatter`
8+
plot, one could either plot several scatters, each with an individual label,
9+
or create proxy artists to show in the legend manually.
10+
Now, :class:`~.collections.PathCollection` provides a method
11+
:meth:`~.collections.PathCollection.legend_elements` to obtain the handles and labels
12+
for a scatter plot in an automated way. This makes creating a legend for a
13+
scatter plot as easy as::
14+
15+
scatter = plt.scatter([1,2,3], [4,5,6], c=[7,2,3])
16+
plt.legend(*scatter.legend_elements())
17+
18+
An example can be found in
19+
:ref:`automatedlegendcreation`.

examples/lines_bars_and_markers/scatter_with_legend.py

Lines changed: 98 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,116 @@
33
Scatter plots with a legend
44
===========================
55
6-
Also demonstrates how transparency of the markers
7-
can be adjusted by giving ``alpha`` a value between
8-
0 and 1.
6+
To create a scatter plot with a legend one may use a loop and create one
7+
`~.Axes.scatter` plot per item to appear in the legend and set the ``label``
8+
accordingly.
9+
10+
The following also demonstrates how transparency of the markers
11+
can be adjusted by giving ``alpha`` a value between 0 and 1.
912
"""
1013

14+
import numpy as np
15+
np.random.seed(19680801)
1116
import matplotlib.pyplot as plt
12-
from numpy.random import rand
1317

1418

1519
fig, ax = plt.subplots()
1620
for color in ['tab:blue', 'tab:orange', 'tab:green']:
1721
n = 750
18-
x, y = rand(2, n)
19-
scale = 200.0 * rand(n)
22+
x, y = np.random.rand(2, n)
23+
scale = 200.0 * np.random.rand(n)
2024
ax.scatter(x, y, c=color, s=scale, label=color,
2125
alpha=0.3, edgecolors='none')
2226

2327
ax.legend()
2428
ax.grid(True)
2529

2630
plt.show()
31+
32+
33+
##############################################################################
34+
# .. _automatedlegendcreation:
35+
#
36+
# Automated legend creation
37+
# -------------------------
38+
#
39+
# Another option for creating a legend for a scatter is to use the
40+
# :class:`~matplotlib.collections.PathCollection`'s
41+
# :meth:`~.PathCollection.legend_elements` method.
42+
# It will automatically try to determine a useful number of legend entries
43+
# to be shown and return a tuple of handles and labels. Those can be passed
44+
# to the call to :meth:`~.axes.Axes.legend`.
45+
46+
47+
N = 45
48+
x, y = np.random.rand(2, N)
49+
c = np.random.randint(1, 5, size=N)
50+
s = np.random.randint(10, 220, size=N)
51+
52+
fig, ax = plt.subplots()
53+
54+
scatter = ax.scatter(x, y, c=c, s=s)
55+
56+
# produce a legend with the unique colors from the scatter
57+
legend1 = ax.legend(*scatter.legend_elements(),
58+
loc="lower left", title="Classes")
59+
ax.add_artist(legend1)
60+
61+
# produce a legend with a cross section of sizes from the scatter
62+
handles, labels = scatter.legend_elements(prop="sizes", alpha=0.6)
63+
legend2 = ax.legend(handles, labels, loc="upper right", title="Sizes")
64+
65+
plt.show()
66+
67+
68+
##############################################################################
69+
# Further arguments to the :meth:`~.PathCollection.legend_elements` method
70+
# can be used to steer how many legend entries are to be created and how they
71+
# should be labeled. The following shows how to use some of them.
72+
#
73+
74+
volume = np.random.rayleigh(27, size=40)
75+
amount = np.random.poisson(10, size=40)
76+
ranking = np.random.normal(size=40)
77+
price = np.random.uniform(1, 10, size=40)
78+
79+
fig, ax = plt.subplots()
80+
81+
# Because the price is much too small when being provided as size for ``s``,
82+
# we normalize it to some useful point sizes, s=0.3*(price*3)**2
83+
scatter = ax.scatter(volume, amount, c=ranking, s=0.3*(price*3)**2,
84+
vmin=-3, vmax=3, cmap="Spectral")
85+
86+
# Produce a legend for the ranking (colors). Even though there are 40 different
87+
# rankings, we only want to show 5 of them in the legend.
88+
legend1 = ax.legend(*scatter.legend_elements(num=5),
89+
loc="upper left", title="Ranking")
90+
ax.add_artist(legend1)
91+
92+
# Produce a legend for the price (sizes). Because we want to show the prices
93+
# in dollars, we use the *func* argument to supply the inverse of the function
94+
# used to calculate the sizes from above. The *fmt* ensures to show the price
95+
# in dollars. Note how we target at 5 elements here, but obtain only 4 in the
96+
# created legend due to the automatic round prices that are chosen for us.
97+
kw = dict(prop="sizes", num=5, color=scatter.cmap(0.7), fmt="$ {x:.2f}",
98+
func=lambda s: np.sqrt(s/.3)/3)
99+
legend2 = ax.legend(*scatter.legend_elements(**kw),
100+
loc="lower right", title="Price")
101+
102+
plt.show()
103+
104+
#############################################################################
105+
#
106+
# ------------
107+
#
108+
# References
109+
# """"""""""
110+
#
111+
# The usage of the following functions and methods is shown in this example:
112+
113+
import matplotlib
114+
matplotlib.axes.Axes.scatter
115+
matplotlib.pyplot.scatter
116+
matplotlib.axes.Axes.legend
117+
matplotlib.pyplot.legend
118+
matplotlib.collections.PathCollection.legend_elements

lib/matplotlib/collections.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import matplotlib as mpl
1717
from . import (_path, artist, cbook, cm, colors as mcolors, docstring,
1818
lines as mlines, path as mpath, transforms)
19+
import warnings
1920

2021

2122
@cbook._define_aliases({
@@ -868,6 +869,7 @@ def draw(self, renderer):
868869
class PathCollection(_CollectionWithSizes):
869870
"""
870871
This is the most basic :class:`Collection` subclass.
872+
A :class:`PathCollection` is e.g. created by a :meth:`~.Axes.scatter` plot.
871873
"""
872874
@docstring.dedent_interpd
873875
def __init__(self, paths, sizes=None, **kwargs):
@@ -890,6 +892,133 @@ def set_paths(self, paths):
890892
def get_paths(self):
891893
return self._paths
892894

895+
def legend_elements(self, prop="colors", num="auto",
896+
fmt=None, func=lambda x: x, **kwargs):
897+
"""
898+
Creates legend handles and labels for a PathCollection. This is useful
899+
for obtaining a legend for a :meth:`~.Axes.scatter` plot. E.g.::
900+
901+
scatter = plt.scatter([1,2,3], [4,5,6], c=[7,2,3])
902+
plt.legend(*scatter.legend_elements())
903+
904+
Also see the :ref:`automatedlegendcreation` example.
905+
906+
Parameters
907+
----------
908+
prop : string, optional, default *"colors"*
909+
Can be *"colors"* or *"sizes"*. In case of *"colors"*, the legend
910+
handles will show the different colors of the collection. In case
911+
of "sizes", the legend will show the different sizes.
912+
num : int, None, "auto" (default), array-like, or `~.ticker.Locator`,
913+
optional
914+
Target number of elements to create.
915+
If None, use all unique elements of the mappable array. If an
916+
integer, target to use *num* elements in the normed range.
917+
If *"auto"*, try to determine which option better suits the nature
918+
of the data.
919+
The number of created elements may slightly deviate from *num* due
920+
to a `~.ticker.Locator` being used to find useful locations.
921+
If a list or array, use exactly those elements for the legend.
922+
Finally, a `~.ticker.Locator` can be provided.
923+
fmt : string, `~matplotlib.ticker.Formatter`, or None (default)
924+
The format or formatter to use for the labels. If a string must be
925+
a valid input for a `~.StrMethodFormatter`. If None (the default),
926+
use a `~.ScalarFormatter`.
927+
func : function, default *lambda x: x*
928+
Function to calculate the labels. Often the size (or color)
929+
argument to :meth:`~.Axes.scatter` will have been pre-processed
930+
by the user using a function *s = f(x)* to make the markers
931+
visible; e.g. *size = np.log10(x)*. Providing the inverse of this
932+
function here allows that pre-processing to be inverted, so that
933+
the legend labels have the correct values;
934+
e.g. *func = np.exp(x, 10)*.
935+
kwargs : further parameters
936+
Allowed kwargs are *color* and *size*. E.g. it may be useful to
937+
set the color of the markers if *prop="sizes"* is used; similarly
938+
to set the size of the markers if *prop="colors"* is used.
939+
Any further parameters are passed onto the `.Line2D` instance.
940+
This may be useful to e.g. specify a different *markeredgecolor* or
941+
*alpha* for the legend handles.
942+
943+
Returns
944+
-------
945+
tuple (handles, labels)
946+
with *handles* being a list of `.Line2D` objects
947+
and *labels* a matching list of strings.
948+
"""
949+
handles = []
950+
labels = []
951+
hasarray = self.get_array() is not None
952+
if fmt is None:
953+
fmt = mpl.ticker.ScalarFormatter(useOffset=False, useMathText=True)
954+
elif isinstance(fmt, str):
955+
fmt = mpl.ticker.StrMethodFormatter(fmt)
956+
fmt.create_dummy_axis()
957+
958+
if prop == "colors":
959+
if not hasarray:
960+
warnings.warn("Collection without array used. Make sure to "
961+
"specify the values to be colormapped via the "
962+
"`c` argument.")
963+
return handles, labels
964+
u = np.unique(self.get_array())
965+
size = kwargs.pop("size", mpl.rcParams["lines.markersize"])
966+
elif prop == "sizes":
967+
u = np.unique(self.get_sizes())
968+
color = kwargs.pop("color", "k")
969+
else:
970+
raise ValueError("Valid values for `prop` are 'colors' or "
971+
f"'sizes'. You supplied '{prop}' instead.")
972+
973+
fmt.set_bounds(func(u).min(), func(u).max())
974+
if num == "auto":
975+
num = 9
976+
if len(u) <= num:
977+
num = None
978+
if num is None:
979+
values = u
980+
label_values = func(values)
981+
else:
982+
if prop == "colors":
983+
arr = self.get_array()
984+
elif prop == "sizes":
985+
arr = self.get_sizes()
986+
if isinstance(num, mpl.ticker.Locator):
987+
loc = num
988+
elif np.iterable(num):
989+
loc = mpl.ticker.FixedLocator(num)
990+
else:
991+
num = int(num)
992+
loc = mpl.ticker.MaxNLocator(nbins=num, min_n_ticks=num-1,
993+
steps=[1, 2, 2.5, 3, 5, 6, 8, 10])
994+
label_values = loc.tick_values(func(arr).min(), func(arr).max())
995+
cond = ((label_values >= func(arr).min()) &
996+
(label_values <= func(arr).max()))
997+
label_values = label_values[cond]
998+
xarr = np.linspace(arr.min(), arr.max(), 256)
999+
values = np.interp(label_values, func(xarr), xarr)
1000+
1001+
kw = dict(markeredgewidth=self.get_linewidths()[0],
1002+
alpha=self.get_alpha())
1003+
kw.update(kwargs)
1004+
1005+
for val, lab in zip(values, label_values):
1006+
if prop == "colors":
1007+
color = self.cmap(self.norm(val))
1008+
elif prop == "sizes":
1009+
size = np.sqrt(val)
1010+
if np.isclose(size, 0.0):
1011+
continue
1012+
h = mlines.Line2D([0], [0], ls="", color=color, ms=size,
1013+
marker=self.get_paths()[0], **kw)
1014+
handles.append(h)
1015+
if hasattr(fmt, "set_locs"):
1016+
fmt.set_locs(label_values)
1017+
l = fmt(lab)
1018+
labels.append(l)
1019+
1020+
return handles, labels
1021+
8931022

8941023
class PolyCollection(_CollectionWithSizes):
8951024
@docstring.dedent_interpd

lib/matplotlib/tests/test_collections.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,3 +669,55 @@ def test_scatter_post_alpha():
669669
# this needs to be here to update internal state
670670
fig.canvas.draw()
671671
sc.set_alpha(.1)
672+
673+
674+
def test_pathcollection_legend_elements():
675+
np.random.seed(19680801)
676+
x, y = np.random.rand(2, 10)
677+
y = np.random.rand(10)
678+
c = np.random.randint(0, 5, size=10)
679+
s = np.random.randint(10, 300, size=10)
680+
681+
fig, ax = plt.subplots()
682+
sc = ax.scatter(x, y, c=c, s=s, cmap="jet", marker="o", linewidths=0)
683+
684+
h, l = sc.legend_elements(fmt="{x:g}")
685+
assert len(h) == 5
686+
assert_array_equal(np.array(l).astype(float), np.arange(5))
687+
colors = np.array([line.get_color() for line in h])
688+
colors2 = sc.cmap(np.arange(5)/4)
689+
assert_array_equal(colors, colors2)
690+
l1 = ax.legend(h, l, loc=1)
691+
692+
h2, lab2 = sc.legend_elements(num=9)
693+
assert len(h2) == 9
694+
l2 = ax.legend(h2, lab2, loc=2)
695+
696+
h, l = sc.legend_elements(prop="sizes", alpha=0.5, color="red")
697+
alpha = np.array([line.get_alpha() for line in h])
698+
assert_array_equal(alpha, 0.5)
699+
color = np.array([line.get_markerfacecolor() for line in h])
700+
assert_array_equal(color, "red")
701+
l3 = ax.legend(h, l, loc=4)
702+
703+
h, l = sc.legend_elements(prop="sizes", num=4, fmt="{x:.2f}",
704+
func=lambda x: 2*x)
705+
actsizes = [line.get_markersize() for line in h]
706+
labeledsizes = np.sqrt(np.array(l).astype(float)/2)
707+
assert_array_almost_equal(actsizes, labeledsizes)
708+
l4 = ax.legend(h, l, loc=3)
709+
710+
import matplotlib.ticker as mticker
711+
loc = mticker.MaxNLocator(nbins=9, min_n_ticks=9-1,
712+
steps=[1, 2, 2.5, 3, 5, 6, 8, 10])
713+
h5, lab5 = sc.legend_elements(num=loc)
714+
assert len(h2) == len(h5)
715+
716+
levels = [-1, 0, 55.4, 260]
717+
h6, lab6 = sc.legend_elements(num=levels, prop="sizes", fmt="{x:g}")
718+
assert_array_equal(np.array(lab6).astype(float), levels[2:])
719+
720+
for l in [l1, l2, l3, l4]:
721+
ax.add_artist(l)
722+
723+
fig.canvas.draw()

0 commit comments

Comments
 (0)