diff --git a/lib/matplotlib/scale.py b/lib/matplotlib/scale.py index 44fbe5209c4d..4517b8946b03 100644 --- a/lib/matplotlib/scale.py +++ b/lib/matplotlib/scale.py @@ -31,6 +31,7 @@ import inspect import textwrap +from functools import wraps import numpy as np @@ -103,6 +104,53 @@ def limit_range_for_scale(self, vmin, vmax, minpos): return vmin, vmax +def _make_axis_parameter_optional(init_func): + """ + Decorator to allow leaving out the *axis* parameter in scale constructors. + + This decorator ensures backward compatibility for scale classes that + previously required an *axis* parameter. It allows constructors to be + callerd with or without the *axis* parameter. + + For simplicity, this does not handle the case when *axis* + is passed as a keyword. However, + scanning GitHub, there's no evidence that that is used anywhere. + + Parameters + ---------- + init_func : callable + The original __init__ method of a scale class. + + Returns + ------- + callable + A wrapped version of *init_func* that handles the optional *axis*. + + Notes + ----- + If the wrapped constructor defines *axis* as its first argument, the + parameter is preserved when present. Otherwise, the value `None` is injected + as the first argument. + + Examples + -------- + >>> from matplotlib.scale import ScaleBase + >>> class CustomScale(ScaleBase): + ... @_make_axis_parameter_optional + ... def __init__(self, axis, custom_param=1): + ... self.custom_param = custom_param + """ + @wraps(init_func) + def wrapper(self, *args, **kwargs): + if args and isinstance(args[0], mpl.axis.Axis): + return init_func(self, *args, **kwargs) + else: + # Remove 'axis' from kwargs to avoid double assignment + axis = kwargs.pop('axis', None) + return init_func(self, axis, *args, **kwargs) + return wrapper + + class LinearScale(ScaleBase): """ The default linear scale. @@ -110,6 +158,7 @@ class LinearScale(ScaleBase): name = 'linear' + @_make_axis_parameter_optional def __init__(self, axis): # This method is present only to prevent inheritance of the base class' # constructor docstring, which would otherwise end up interpolated into @@ -180,6 +229,7 @@ class FuncScale(ScaleBase): name = 'function' + @_make_axis_parameter_optional def __init__(self, axis, functions): """ Parameters @@ -279,7 +329,8 @@ class LogScale(ScaleBase): """ name = 'log' - def __init__(self, axis, *, base=10, subs=None, nonpositive="clip"): + @_make_axis_parameter_optional + def __init__(self, axis=None, *, base=10, subs=None, nonpositive="clip"): """ Parameters ---------- @@ -330,6 +381,7 @@ class FuncScaleLog(LogScale): name = 'functionlog' + @_make_axis_parameter_optional def __init__(self, axis, functions, base=10): """ Parameters @@ -455,7 +507,8 @@ class SymmetricalLogScale(ScaleBase): """ name = 'symlog' - def __init__(self, axis, *, base=10, linthresh=2, subs=None, linscale=1): + @_make_axis_parameter_optional + def __init__(self, axis=None, *, base=10, linthresh=2, subs=None, linscale=1): self._transform = SymmetricalLogTransform(base, linthresh, linscale) self.subs = subs @@ -547,7 +600,8 @@ class AsinhScale(ScaleBase): 1024: (256, 512) } - def __init__(self, axis, *, linear_width=1.0, + @_make_axis_parameter_optional + def __init__(self, axis=None, *, linear_width=1.0, base=10, subs='auto', **kwargs): """ Parameters @@ -645,7 +699,8 @@ class LogitScale(ScaleBase): """ name = 'logit' - def __init__(self, axis, nonpositive='mask', *, + @_make_axis_parameter_optional + def __init__(self, axis=None, nonpositive='mask', *, one_half=r"\frac{1}{2}", use_overline=False): r""" Parameters diff --git a/lib/matplotlib/scale.pyi b/lib/matplotlib/scale.pyi index 7fec8e68cc5a..ba9f269b8c78 100644 --- a/lib/matplotlib/scale.pyi +++ b/lib/matplotlib/scale.pyi @@ -15,6 +15,10 @@ class ScaleBase: class LinearScale(ScaleBase): name: str + def __init__( + self, + axis: Axis | None, + ) -> None: ... class FuncTransform(Transform): input_dims: int @@ -57,7 +61,7 @@ class LogScale(ScaleBase): subs: Iterable[int] | None def __init__( self, - axis: Axis | None, + axis: Axis | None = ..., *, base: float = ..., subs: Iterable[int] | None = ..., @@ -104,7 +108,7 @@ class SymmetricalLogScale(ScaleBase): subs: Iterable[int] | None def __init__( self, - axis: Axis | None, + axis: Axis | None = ..., *, base: float = ..., linthresh: float = ..., @@ -138,7 +142,7 @@ class AsinhScale(ScaleBase): auto_tick_multipliers: dict[int, tuple[int, ...]] def __init__( self, - axis: Axis | None, + axis: Axis | None = ..., *, linear_width: float = ..., base: float = ..., @@ -165,7 +169,7 @@ class LogitScale(ScaleBase): name: str def __init__( self, - axis: Axis | None, + axis: Axis | None = ..., nonpositive: Literal["mask", "clip"] = ..., *, one_half: str = ..., @@ -176,3 +180,4 @@ class LogitScale(ScaleBase): def get_scale_names() -> list[str]: ... def scale_factory(scale: str, axis: Axis, **kwargs) -> ScaleBase: ... def register_scale(scale_class: type[ScaleBase]) -> None: ... +def _make_axis_parameter_optional(init_func: Callable[..., None]) -> Callable[..., None]: ...