diff --git a/lib/matplotlib/axes/__init__.pyi b/lib/matplotlib/axes/__init__.pyi index 0c27be62e370..7df38b8bde9e 100644 --- a/lib/matplotlib/axes/__init__.pyi +++ b/lib/matplotlib/axes/__init__.pyi @@ -1,10 +1,13 @@ from typing import TypeVar -from ._axes import * -from ._axes import Axes as Subplot +from ._axes import Axes as Axes + _T = TypeVar("_T") +# Backcompat. +Subplot = Axes + class _SubplotBaseMeta(type): def __instancecheck__(self, obj) -> bool: ... diff --git a/lib/matplotlib/axes/_base.pyi b/lib/matplotlib/axes/_base.pyi index 1f929b6c90c5..c63d81992389 100644 --- a/lib/matplotlib/axes/_base.pyi +++ b/lib/matplotlib/axes/_base.pyi @@ -6,6 +6,7 @@ from matplotlib import cbook from matplotlib.artist import Artist from matplotlib.axis import XAxis, YAxis, Tick from matplotlib.backend_bases import RendererBase, MouseButton, MouseEvent +from matplotlib.cbook import CallbackRegistry from matplotlib.container import Container from matplotlib.collections import Collection from matplotlib.cm import ScalarMappable @@ -25,9 +26,11 @@ from cycler import Cycler import numpy as np from numpy.typing import ArrayLike -from typing import Any, Literal, overload +from typing import Any, Literal, TypeVar, overload from matplotlib.typing import ColorType +_T = TypeVar("_T", bound=Artist) + class _axis_method_wrapper: attr_name: str method_name: str @@ -53,6 +56,11 @@ class _AxesBase(martist.Artist): transData: Transform ignore_existing_data_limits: bool axison: bool + containers: list[Container] + callbacks: CallbackRegistry + child_axes: list[_AxesBase] + legend_: Legend | None + title: Text _projection_init: Any def __init__( @@ -125,8 +133,7 @@ class _AxesBase(martist.Artist): def clear(self) -> None: ... def cla(self) -> None: ... - # Could be made generic, but comments indicate it may be temporary anyway - class ArtistList(Sequence[Artist]): + class ArtistList(Sequence[_T]): def __init__( self, axes: _AxesBase, @@ -135,40 +142,40 @@ class _AxesBase(martist.Artist): invalid_types: type | Iterable[type] | None = ..., ) -> None: ... def __len__(self) -> int: ... - def __iter__(self) -> Iterator[Artist]: ... + def __iter__(self) -> Iterator[_T]: ... @overload - def __getitem__(self, key: int) -> Artist: ... + def __getitem__(self, key: int) -> _T: ... @overload - def __getitem__(self, key: slice) -> list[Artist]: ... + def __getitem__(self, key: slice) -> list[_T]: ... @overload - def __add__(self, other: _AxesBase.ArtistList) -> list[Artist]: ... + def __add__(self, other: _AxesBase.ArtistList[_T]) -> list[_T]: ... @overload def __add__(self, other: list[Any]) -> list[Any]: ... @overload def __add__(self, other: tuple[Any]) -> tuple[Any]: ... @overload - def __radd__(self, other: _AxesBase.ArtistList) -> list[Artist]: ... + def __radd__(self, other: _AxesBase.ArtistList[_T]) -> list[_T]: ... @overload def __radd__(self, other: list[Any]) -> list[Any]: ... @overload def __radd__(self, other: tuple[Any]) -> tuple[Any]: ... @property - def artists(self) -> _AxesBase.ArtistList: ... + def artists(self) -> _AxesBase.ArtistList[Artist]: ... @property - def collections(self) -> _AxesBase.ArtistList: ... + def collections(self) -> _AxesBase.ArtistList[Collection]: ... @property - def images(self) -> _AxesBase.ArtistList: ... + def images(self) -> _AxesBase.ArtistList[AxesImage]: ... @property - def lines(self) -> _AxesBase.ArtistList: ... + def lines(self) -> _AxesBase.ArtistList[Line2D]: ... @property - def patches(self) -> _AxesBase.ArtistList: ... + def patches(self) -> _AxesBase.ArtistList[Patch]: ... @property - def tables(self) -> _AxesBase.ArtistList: ... + def tables(self) -> _AxesBase.ArtistList[Table]: ... @property - def texts(self) -> _AxesBase.ArtistList: ... + def texts(self) -> _AxesBase.ArtistList[Text]: ... def get_facecolor(self) -> ColorType: ... def set_facecolor(self, color: ColorType | None) -> None: ... @overload diff --git a/lib/matplotlib/ticker.pyi b/lib/matplotlib/ticker.pyi index 2ef1c9f53f1d..f026b4943c94 100644 --- a/lib/matplotlib/ticker.pyi +++ b/lib/matplotlib/ticker.pyi @@ -19,7 +19,7 @@ class _DummyAxis: class TickHelper: axis: None | Axis | _DummyAxis | _AxisWrapper - def set_axis(self, axis: Axis | _DummyAxis | None) -> None: ... + def set_axis(self, axis: Axis | _DummyAxis | _AxisWrapper | None) -> None: ... def create_dummy_axis(self, **kwargs) -> None: ... class Formatter(TickHelper):