diff --git a/lib/matplotlib/scale.py b/lib/matplotlib/scale.py index 44fbe5209c4d..3126c2c82095 100644 --- a/lib/matplotlib/scale.py +++ b/lib/matplotlib/scale.py @@ -31,6 +31,7 @@ import inspect import textwrap +from functools import wraps import numpy as np @@ -103,6 +104,49 @@ def limit_range_for_scale(self, vmin, vmax, minpos): return vmin, vmax +def handle_axis_parameter(init_func): + """ + Decorator to handle the optional *axis* parameter in scale constructors. + + This decorator ensures backward compatibility for scale classes that + previously required an *axis* parameter. It allows constructors to work + seamlessly with or without the *axis* parameter. + + Parameters + ---------- + init_func : callable + The original __init__ method of a scale class. + + Returns + ------- + callable + A wrapped version of *init_func* that handles the optional *axis*. + + Notes + ----- + If the wrapped constructor defines *axis* as its first argument, the + parameter is preserved when present. Otherwise, the value `None` is injected + as the first argument. + + Examples + -------- + >>> from matplotlib.scale import ScaleBase + >>> class CustomScale(ScaleBase): + ... @handle_axis_parameter + ... def __init__(self, axis, custom_param=1): + ... self.custom_param = custom_param + """ + @wraps(init_func) + def wrapper(self, *args, **kwargs): + if args and isinstance(args[0], mpl.axis.Axis): + return init_func(self, *args, **kwargs) + else: + # Remove 'axis' from kwargs to avoid double assignment + kwargs.pop('axis', None) + return init_func(self, None, *args, **kwargs) + return wrapper + + class LinearScale(ScaleBase): """ The default linear scale. @@ -110,7 +154,8 @@ class LinearScale(ScaleBase): name = 'linear' - def __init__(self, axis): + @handle_axis_parameter + def __init__(self, axis=None): # This method is present only to prevent inheritance of the base class' # constructor docstring, which would otherwise end up interpolated into # the docstring of Axis.set_scale. @@ -180,6 +225,7 @@ class FuncScale(ScaleBase): name = 'function' + @handle_axis_parameter def __init__(self, axis, functions): """ Parameters @@ -279,7 +325,8 @@ class LogScale(ScaleBase): """ name = 'log' - def __init__(self, axis, *, base=10, subs=None, nonpositive="clip"): + @handle_axis_parameter + def __init__(self, axis=None, *, base=10, subs=None, nonpositive="clip"): """ Parameters ---------- @@ -330,6 +377,7 @@ class FuncScaleLog(LogScale): name = 'functionlog' + @handle_axis_parameter def __init__(self, axis, functions, base=10): """ Parameters @@ -455,7 +503,8 @@ class SymmetricalLogScale(ScaleBase): """ name = 'symlog' - def __init__(self, axis, *, base=10, linthresh=2, subs=None, linscale=1): + @handle_axis_parameter + def __init__(self, axis=None, *, base=10, linthresh=2, subs=None, linscale=1): self._transform = SymmetricalLogTransform(base, linthresh, linscale) self.subs = subs @@ -547,7 +596,8 @@ class AsinhScale(ScaleBase): 1024: (256, 512) } - def __init__(self, axis, *, linear_width=1.0, + @handle_axis_parameter + def __init__(self, axis=None, *, linear_width=1.0, base=10, subs='auto', **kwargs): """ Parameters @@ -645,7 +695,8 @@ class LogitScale(ScaleBase): """ name = 'logit' - def __init__(self, axis, nonpositive='mask', *, + @handle_axis_parameter + def __init__(self, axis=None, nonpositive='mask', *, one_half=r"\frac{1}{2}", use_overline=False): r""" Parameters @@ -725,7 +776,12 @@ def scale_factory(scale, axis, **kwargs): axis : `~matplotlib.axis.Axis` """ scale_cls = _api.check_getitem(_scale_mapping, scale=scale) - return scale_cls(axis, **kwargs) + try: + return scale_cls(axis, **kwargs) + except TypeError as e: + if 'unexpected keyword argument' in str(e) or 'positional argument' in str(e): + return scale_cls(**kwargs) + raise if scale_factory.__doc__: diff --git a/lib/matplotlib/scale.pyi b/lib/matplotlib/scale.pyi index 7fec8e68cc5a..d7f34457ad4f 100644 --- a/lib/matplotlib/scale.pyi +++ b/lib/matplotlib/scale.pyi @@ -2,7 +2,7 @@ from matplotlib.axis import Axis from matplotlib.transforms import Transform from collections.abc import Callable, Iterable -from typing import Literal +from typing import Literal, Union from numpy.typing import ArrayLike class ScaleBase: @@ -15,6 +15,7 @@ class ScaleBase: class LinearScale(ScaleBase): name: str + def __init__(self: ScaleBase, axis: Union[Axis, None] = None) -> None: ... class FuncTransform(Transform): input_dims: int @@ -56,12 +57,12 @@ class LogScale(ScaleBase): name: str subs: Iterable[int] | None def __init__( - self, - axis: Axis | None, + self: LogScale, + axis: Union[Axis, None] = None, *, - base: float = ..., - subs: Iterable[int] | None = ..., - nonpositive: Literal["clip", "mask"] = ... + base: float = 10, + subs: Union[Iterable[int], None] = None, + nonpositive: Union[Literal['clip'], Literal['mask']] = 'clip' ) -> None: ... @property def base(self) -> float: ... @@ -103,13 +104,13 @@ class SymmetricalLogScale(ScaleBase): name: str subs: Iterable[int] | None def __init__( - self, - axis: Axis | None, + self: SymmetricalLogScale, + axis: Union[Axis, None] = None, *, - base: float = ..., - linthresh: float = ..., - subs: Iterable[int] | None = ..., - linscale: float = ... + base: float = 10, + linthresh: float = 2, + subs: Union[Iterable[int], None] = None, + linscale: float = 1 ) -> None: ... @property def base(self) -> float: ... @@ -138,7 +139,7 @@ class AsinhScale(ScaleBase): auto_tick_multipliers: dict[int, tuple[int, ...]] def __init__( self, - axis: Axis | None, + axis: Union[Axis, None] = None, *, linear_width: float = ..., base: float = ..., @@ -164,15 +165,16 @@ class LogisticTransform(Transform): class LogitScale(ScaleBase): name: str def __init__( - self, - axis: Axis | None, - nonpositive: Literal["mask", "clip"] = ..., + self: LogitScale, + axis: Union[Axis, None] = None, + nonpositive: Union[Literal['mask'], Literal['clip']] = 'mask', *, - one_half: str = ..., - use_overline: bool = ... + one_half: str = '\\frac{1}{2}', + use_overline: bool = False ) -> None: ... def get_transform(self) -> LogitTransform: ... def get_scale_names() -> list[str]: ... def scale_factory(scale: str, axis: Axis, **kwargs) -> ScaleBase: ... def register_scale(scale_class: type[ScaleBase]) -> None: ... +def handle_axis_parameter(init_func: Callable[..., None]) -> Callable[..., None]: ...