diff --git a/lib/matplotlib/axes/_axes.pyi b/lib/matplotlib/axes/_axes.pyi index c3eb28d2f095..c9889049a9d8 100644 --- a/lib/matplotlib/axes/_axes.pyi +++ b/lib/matplotlib/axes/_axes.pyi @@ -36,7 +36,7 @@ from collections.abc import Callable, Iterable, Sequence from typing import Any, Literal, overload import numpy as np from numpy.typing import ArrayLike -from matplotlib.typing import ColorType, MarkerType, LineStyleType +from matplotlib.typing import ColorType, MarkerType, LegendLocType, LineStyleType class Axes(_AxesBase): def get_title(self, loc: Literal["left", "center", "right"] = ...) -> str: ... @@ -58,13 +58,16 @@ class Axes(_AxesBase): @overload def legend(self) -> Legend: ... @overload - def legend(self, handles: Iterable[Artist | tuple[Artist, ...]], labels: Iterable[str], **kwargs) -> Legend: ... + def legend(self, handles: Iterable[Artist | tuple[Artist, ...]], labels: Iterable[str], + *, loc: LegendLocType | None = ..., **kwargs) -> Legend: ... @overload - def legend(self, *, handles: Iterable[Artist | tuple[Artist, ...]], **kwargs) -> Legend: ... + def legend(self, *, handles: Iterable[Artist | tuple[Artist, ...]], + loc: LegendLocType | None = ..., **kwargs) -> Legend: ... @overload - def legend(self, labels: Iterable[str], **kwargs) -> Legend: ... + def legend(self, labels: Iterable[str], + *, loc: LegendLocType | None = ..., **kwargs) -> Legend: ... @overload - def legend(self, **kwargs) -> Legend: ... + def legend(self, *, loc: LegendLocType | None = ..., **kwargs) -> Legend: ... def inset_axes( self, diff --git a/lib/matplotlib/figure.pyi b/lib/matplotlib/figure.pyi index e7c5175d8af9..5119de15c83b 100644 --- a/lib/matplotlib/figure.pyi +++ b/lib/matplotlib/figure.pyi @@ -27,7 +27,7 @@ from matplotlib.text import Text from matplotlib.transforms import Affine2D, Bbox, BboxBase, Transform from mpl_toolkits.mplot3d import Axes3D -from .typing import ColorType, HashableList +from .typing import ColorType, HashableList, LegendLocType _T = TypeVar("_T") @@ -151,13 +151,16 @@ class FigureBase(Artist): @overload def legend(self) -> Legend: ... @overload - def legend(self, handles: Iterable[Artist], labels: Iterable[str], **kwargs) -> Legend: ... + def legend(self, handles: Iterable[Artist], labels: Iterable[str], + *, loc: LegendLocType | None = ..., **kwargs) -> Legend: ... @overload - def legend(self, *, handles: Iterable[Artist], **kwargs) -> Legend: ... + def legend(self, *, handles: Iterable[Artist], + loc: LegendLocType | None = ..., **kwargs) -> Legend: ... @overload - def legend(self, labels: Iterable[str], **kwargs) -> Legend: ... + def legend(self, labels: Iterable[str], + *, loc: LegendLocType | None = ..., **kwargs) -> Legend: ... @overload - def legend(self, **kwargs) -> Legend: ... + def legend(self, *, loc: LegendLocType | None = ..., **kwargs) -> Legend: ... def text( self, diff --git a/lib/matplotlib/legend.pyi b/lib/matplotlib/legend.pyi index dde5882da69d..c03471fc54d1 100644 --- a/lib/matplotlib/legend.pyi +++ b/lib/matplotlib/legend.pyi @@ -14,12 +14,13 @@ from matplotlib.transforms import ( BboxBase, Transform, ) +from matplotlib.typing import ColorType, LegendLocType import pathlib from collections.abc import Iterable from typing import Any, Literal, overload -from .typing import ColorType + class DraggableLegend(DraggableOffsetBox): legend: Legend @@ -55,7 +56,7 @@ class Legend(Artist): handles: Iterable[Artist | tuple[Artist, ...]], labels: Iterable[str], *, - loc: str | tuple[float, float] | int | None = ..., + loc: LegendLocType | None = ..., numpoints: int | None = ..., markerscale: float | None = ..., markerfirst: bool = ..., @@ -118,7 +119,7 @@ class Legend(Artist): def get_texts(self) -> list[Text]: ... def set_alignment(self, alignment: Literal["center", "left", "right"]) -> None: ... def get_alignment(self) -> Literal["center", "left", "right"]: ... - def set_loc(self, loc: str | tuple[float, float] | int | None = ...) -> None: ... + def set_loc(self, loc: LegendLocType | None = ...) -> None: ... def set_title( self, title: str, prop: FontProperties | str | pathlib.Path | None = ... ) -> None: ... diff --git a/lib/matplotlib/typing.py b/lib/matplotlib/typing.py index df192df76b33..17f1bcf41720 100644 --- a/lib/matplotlib/typing.py +++ b/lib/matplotlib/typing.py @@ -107,3 +107,24 @@ _HT = TypeVar("_HT", bound=Hashable) HashableList: TypeAlias = list[_HT | "HashableList[_HT]"] """A nested list of Hashable values.""" + + +LegendLocType: TypeAlias = ( + Literal[ + # for simplicity, we don't distinguish the between allowed positions for + # Axes legend and figure legend. It's still better to limit the allowed + # range to the union of both rather than to accept arbitrary strings + "upper right", "upper left", "lower left", "lower right", + "right", "center left", "center right", "lower center", "upper center", + "center", + # Axes only + "best", + # Figure only + "outside upper left", "outside upper center", "outside upper right", + "outside right upper", "outside right center", "outside right lower", + "outside lower right", "outside lower center", "outside lower left", + "outside left lower", "outside left center", "outside left upper", + ] | + tuple[float, float] | + int +)