Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix _expand_kwargs and its use in plot_series
  • Loading branch information
TLouf committed Sep 25, 2021
commit 68fa0f9cd522d1bfb3244ad8ee128c820ff3df04
111 changes: 69 additions & 42 deletions geopandas/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _flatten_multi_geoms(geoms, prefix="Multi"):
component_index = exploded_geoms.index.codes[1]
component_index = (component_index == 0).cumsum() - 1

else:
else:
not_empty = ~geoms.is_empty
geo_series = geoms
component_index = np.arange(len(geoms))
Expand All @@ -70,37 +70,31 @@ def _expand_kwargs(kwargs, multiindex):
`multiindex` should be list_like
"""
import matplotlib
from matplotlib.colors import is_color_like
from typing import Iterable

single_value_kwargs = ["hatch", "marker"]
mpl = matplotlib.__version__
if mpl >= LooseVersion("3.4") or (mpl > LooseVersion("3.3.2") and "+" in mpl):
if not (mpl >= LooseVersion("3.4") or (mpl > LooseVersion("3.3.2") and "+" in mpl)):
# alpha is supported as array argument with matplotlib 3.4+
scalar_kwargs = ["marker", "path_effects"]
else:
scalar_kwargs = ["marker", "alpha", "path_effects"]

pick_first_kwargs = ["hatch", "marker"]
for att, value in kwargs.items():
if "color" in att: # color(s), edgecolor(s), facecolor(s)
if is_color_like(value):
continue
elif "linestyle" in att: # linestyle(s)
if "linestyle" in att: # linestyle(s)
# A single linestyle can be 2-tuple of a number and an iterable.
if (
isinstance(value, tuple)
and len(value) == 2
and isinstance(value[1], Iterable)
):
continue
elif att in scalar_kwargs:
# For these attributes, only a single value is allowed, so never expand.
continue

if pd.api.types.is_list_like(value):
kwargs[att] = np.take(value, multiindex, axis=0)
if att in pick_first_kwargs:
kwargs[att] = kwargs[att][0]
if att in single_value_kwargs:
kwargs[att] = np.take(value, multiindex[0], axis=0)
else:
kwargs[att] = np.take(value, multiindex, axis=0)


def _PolygonPatch(polygon, **kwargs):
Expand Down Expand Up @@ -160,10 +154,10 @@ def _plot_polygon_collection(

geoms, multiindex = _flatten_multi_geoms(geoms)
poly_patches = geoms.apply(_PolygonPatch)
_expand_kwargs(kwargs, multiindex)

if isinstance(values, pd.Categorical):
# This should never be entered when called through `plot_series`.
# This should never be entered when called through `plot_series`.
_expand_kwargs(kwargs, multiindex)
values = values[multiindex]
codes = values.codes
ucodes = np.unique(codes)
Expand All @@ -176,13 +170,15 @@ def _plot_polygon_collection(
_expand_kwargs(cat_kwargs, cat_idx[[0]])
cat_patches = np.take(poly_patches, cat_idx, axis=0)
collection = PatchCollection(cat_patches, label=cat, **cat_kwargs)
collection.set_facecolor(color.get(cat, 'none'))
collection.set_facecolor(color.get(cat, "none"))
_ = ax.add_collection(collection, autolim=True)

else:
# Add to kwargs for easier checking below.
if color is not None:
kwargs["color"] = color
_expand_kwargs(kwargs, multiindex)

collection = PatchCollection(poly_patches, **kwargs)

if values is not None:
Expand Down Expand Up @@ -228,10 +224,10 @@ def _plot_linestring_collection(

geoms, multiindex = _flatten_multi_geoms(geoms)
segments = [np.array(linestring.coords)[:, :2] for linestring in geoms]
_expand_kwargs(kwargs, multiindex)

if isinstance(values, pd.Categorical):
# This should never be entered when called through `plot_series`.
# This should never be entered when called through `plot_series`.
_expand_kwargs(kwargs, multiindex)
values = values[multiindex]
codes = values.codes
ucodes = np.unique(codes)
Expand All @@ -243,13 +239,15 @@ def _plot_linestring_collection(
_expand_kwargs(cat_kwargs, cat_idx[[0]])
cat_segments = np.take(segments, cat_idx, axis=0)
collection = LineCollection(cat_segments, label=cat, **cat_kwargs)
collection.set_color(color.get(cat, 'none'))
collection.set_color(color.get(cat, "none"))
_ = ax.add_collection(collection, autolim=True)

else:
# Add to kwargs for easier checking below.
if color is not None:
kwargs["color"] = color
_expand_kwargs(kwargs, multiindex)

collection = LineCollection(segments, **kwargs)

if values is not None:
Expand All @@ -267,7 +265,13 @@ def _plot_linestring_collection(


def _plot_point_collection(
ax, geoms, values=None, color=None, cmap=None, norm=None, **kwargs,
ax,
geoms,
values=None,
color=None,
cmap=None,
norm=None,
**kwargs,
):
"""
Plots a collection of Point and MultiPoint geometries to `ax`
Expand All @@ -287,14 +291,14 @@ def _plot_point_collection(
kwargs.pop("edgecolor", None)
if kwargs.get("markersize") is not None:
# We square to match the units.
kwargs["s"] = kwargs.pop("markersize")**2
kwargs["s"] = kwargs.pop("markersize") ** 2

geoms, multiindex = _flatten_multi_geoms(geoms)
x, y = geoms.x, geoms.y
_expand_kwargs(kwargs, multiindex)

if isinstance(values, pd.Categorical):
# This should never be entered when called through `plot_series`.
# This should never be entered when called through `plot_series`.
_expand_kwargs(kwargs, multiindex)
values = values[multiindex]
codes = values.codes
ucodes = np.unique(codes)
Expand All @@ -307,7 +311,7 @@ def _plot_point_collection(
cat_x = np.take(x, cat_idx, axis=0)
cat_y = np.take(y, cat_idx, axis=0)
collection = ax.scatter(cat_x, cat_y, label=cat, **cat_kwargs)
collection.set_color(color.get(cat, 'none'))
collection.set_color(color.get(cat, "none"))

else:
# Add to kwargs for easier checking below.
Expand All @@ -319,14 +323,23 @@ def _plot_point_collection(
if values is not None:
kwargs["c"] = np.take(values, multiindex, axis=0)

_expand_kwargs(kwargs, multiindex)

_ = ax.scatter(x, y, cmap=cmap, norm=norm, **kwargs)


plot_point_collection = deprecated(_plot_point_collection)


def plot_series(
s, cmap=None, norm=None, color=None, ax=None, figsize=None, aspect="auto", **style_kwds
s,
cmap=None,
norm=None,
color=None,
ax=None,
figsize=None,
aspect="auto",
**style_kwds,
):
"""
Plot a GeoSeries.
Expand Down Expand Up @@ -436,6 +449,7 @@ def plot_series(
vmin = style_kwds.get("vmin", values.min())
vmax = style_kwds.get("vmax", values.max())
from matplotlib.colors import Normalize

norm = Normalize(vmin=vmin, vmax=vmax)

# decompose GeometryCollections
Expand All @@ -450,34 +464,45 @@ def plot_series(
)
point_idx = np.asarray((geom_types == "Point") | (geom_types == "MultiPoint"))

if color is not None:
style_kwds["color"] = color

# plot all Polygons and all MultiPolygon components in the same collection
polys = geoms[poly_idx]
if not polys.empty:
# color overrides both face and edgecolor. As we want people to be
# able to use edgecolor as well, pass color to facecolor
facecolor = style_kwds.pop("facecolor", None)
if color is not None:
facecolor = color
polys_style_kwds = style_kwds.copy()
_expand_kwargs(polys_style_kwds, multiindex[poly_idx])
# `color` overrides both `facecolor` and `edgecolor`. As we want users
# to be able to use `edgecolor` as well, pass `color` to `facecolor`:
default_fc = polys_style_kwds.pop("color", None)
if polys_style_kwds.get("facecolor") is None:
polys_style_kwds["facecolor"] = default_fc

values_ = values[poly_idx] if cmap else None
_plot_polygon_collection(
ax, polys, values_, facecolor=facecolor, cmap=cmap, norm=norm, **style_kwds
ax, polys, values_, cmap=cmap, norm=norm, **polys_style_kwds
)

# plot all LineStrings and MultiLineString components in same collection
lines = geoms[line_idx]
if not lines.empty:
lines_style_kwds = style_kwds.copy()
_expand_kwargs(lines_style_kwds, multiindex[line_idx])

values_ = values[line_idx] if cmap else None
_plot_linestring_collection(
ax, lines, values_, color=color, cmap=cmap, norm=norm, **style_kwds
ax, lines, values_, cmap=cmap, norm=norm, **lines_style_kwds
)

# plot all Points in the same collection
points = geoms[point_idx]
if not points.empty:
pts_style_kwds = style_kwds.copy()
_expand_kwargs(pts_style_kwds, multiindex[point_idx])

values_ = values[point_idx] if cmap else None
_plot_point_collection(
ax, points, values_, color=color, cmap=cmap, norm=norm, **style_kwds
ax, points, values_, cmap=cmap, norm=norm, **pts_style_kwds
)

plt.draw()
Expand Down Expand Up @@ -661,6 +686,7 @@ def plot_dataframe(
ax = style_kwds.pop("axes")

from matplotlib.colors import is_color_like

if column is not None and is_color_like(color):
warnings.warn(
"Only specify one of 'column' or 'color'. Using 'color'.", UserWarning
Expand Down Expand Up @@ -834,6 +860,7 @@ def plot_dataframe(
# `vmin` and `vmax`, or based on the `values` array.
if norm is None:
from matplotlib.colors import Normalize

# If categorical vmin and vmax cannot be None, so the fact that
# value.min() is a string does not matter.
mn = values[~np.isnan(values)].min() if vmin is None else vmin
Expand All @@ -844,14 +871,13 @@ def plot_dataframe(
# category:
if categorical:
from matplotlib import cm

n_cmap = cm.ScalarMappable(norm=norm, cmap=cmap)
cat_colors = n_cmap.to_rgba(np.unique(values.codes))
color = {
cat: cat_color
for cat, cat_color in zip(values.categories, cat_colors)
cat: cat_color for cat, cat_color in zip(values.categories, cat_colors)
}

# decompose GeometryCollections
geoms, multiindex = _flatten_multi_geoms(df.geometry, prefix="Geom")
values = values[multiindex]
nan_idx = np.take(nan_idx, multiindex, axis=0)
Expand Down Expand Up @@ -884,7 +910,6 @@ def plot_dataframe(
subset = values[lines_mask]
if not lines.empty:
lines_style_kwds = style_kwds.copy()
# _expand_kwargs(lines_style_kwds, np.where(lines_mask)[0])
_expand_kwargs(lines_style_kwds, multiindex[lines_mask])
_plot_linestring_collection(
ax, lines, subset, color=color, cmap=cmap, norm=norm, **lines_style_kwds
Expand All @@ -903,7 +928,7 @@ def plot_dataframe(

if missing_kwds is not None and not geoms[nan_idx].empty:
if color:
missing_kwds["color"] = missing_kwds.get('color', color)
missing_kwds["color"] = missing_kwds.get("color", color)
merged_kwds = style_kwds.copy()
merged_kwds.update(missing_kwds)
merged_kwds["label"] = merged_kwds.get("label", "NaN")
Expand All @@ -922,6 +947,7 @@ def plot_dataframe(

else:
from matplotlib import cm

if cax is not None:
legend_kwds.setdefault("cax", cax)
else:
Expand Down Expand Up @@ -956,13 +982,14 @@ def geo(self, *args, **kwargs):


def _legend_with_poly_wrapper(fun):
'''
"""
Decorator for ax.legend that enables `PatchCollection` objects plotted by
`_plot_polygon_collection` to be rendered correctly in the legend.
'''
"""
from matplotlib.legend_handler import HandlerPolyCollection
from matplotlib.collections import PatchCollection

def legend(*args, **kwargs):
kwargs['handler_map'] = {PatchCollection: HandlerPolyCollection()}
kwargs["handler_map"] = {PatchCollection: HandlerPolyCollection()}
return fun(*args, **kwargs)
return legend