Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Refactoring: Removing axis parameter from scales #29988

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 59 additions & 4 deletions lib/matplotlib/scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

import inspect
import textwrap
from functools import wraps

import numpy as np

Expand Down Expand Up @@ -103,13 +104,61 @@ def limit_range_for_scale(self, vmin, vmax, minpos):
return vmin, vmax


def _make_axis_parameter_optional(init_func):
"""
Decorator to allow leaving out the *axis* parameter in scale constructors.

This decorator ensures backward compatibility for scale classes that
previously required an *axis* parameter. It allows constructors to be
callerd with or without the *axis* parameter.

For simplicity, this does not handle the case when *axis*
is passed as a keyword. However,
scanning GitHub, there's no evidence that that is used anywhere.

Parameters
----------
init_func : callable
The original __init__ method of a scale class.

Returns
-------
callable
A wrapped version of *init_func* that handles the optional *axis*.

Notes
-----
If the wrapped constructor defines *axis* as its first argument, the
parameter is preserved when present. Otherwise, the value `None` is injected
as the first argument.

Examples
--------
>>> from matplotlib.scale import ScaleBase
>>> class CustomScale(ScaleBase):
... @_make_axis_parameter_optional
... def __init__(self, axis, custom_param=1):
... self.custom_param = custom_param
"""
@wraps(init_func)
def wrapper(self, *args, **kwargs):
if args and isinstance(args[0], mpl.axis.Axis):
return init_func(self, *args, **kwargs)
else:
# Remove 'axis' from kwargs to avoid double assignment
kwargs.pop('axis', None)
return init_func(self, None, *args, **kwargs)
return wrapper


class LinearScale(ScaleBase):
"""
The default linear scale.
"""

name = 'linear'

@_make_axis_parameter_optional
def __init__(self, axis):
# This method is present only to prevent inheritance of the base class'
# constructor docstring, which would otherwise end up interpolated into
Expand Down Expand Up @@ -180,6 +229,7 @@ class FuncScale(ScaleBase):

name = 'function'

@_make_axis_parameter_optional
def __init__(self, axis, functions):
"""
Parameters
Expand Down Expand Up @@ -279,7 +329,8 @@ class LogScale(ScaleBase):
"""
name = 'log'

def __init__(self, axis, *, base=10, subs=None, nonpositive="clip"):
@_make_axis_parameter_optional
def __init__(self, axis=None, *, base=10, subs=None, nonpositive="clip"):
"""
Parameters
----------
Expand Down Expand Up @@ -330,6 +381,7 @@ class FuncScaleLog(LogScale):

name = 'functionlog'

@_make_axis_parameter_optional
def __init__(self, axis, functions, base=10):
"""
Parameters
Expand Down Expand Up @@ -455,7 +507,8 @@ class SymmetricalLogScale(ScaleBase):
"""
name = 'symlog'

def __init__(self, axis, *, base=10, linthresh=2, subs=None, linscale=1):
@_make_axis_parameter_optional
def __init__(self, axis=None, *, base=10, linthresh=2, subs=None, linscale=1):
self._transform = SymmetricalLogTransform(base, linthresh, linscale)
self.subs = subs

Expand Down Expand Up @@ -547,7 +600,8 @@ class AsinhScale(ScaleBase):
1024: (256, 512)
}

def __init__(self, axis, *, linear_width=1.0,
@_make_axis_parameter_optional
def __init__(self, axis=None, *, linear_width=1.0,
base=10, subs='auto', **kwargs):
"""
Parameters
Expand Down Expand Up @@ -645,7 +699,8 @@ class LogitScale(ScaleBase):
"""
name = 'logit'

def __init__(self, axis, nonpositive='mask', *,
@_make_axis_parameter_optional
def __init__(self, axis=None, nonpositive='mask', *,
one_half=r"\frac{1}{2}", use_overline=False):
r"""
Parameters
Expand Down
34 changes: 18 additions & 16 deletions lib/matplotlib/scale.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ from matplotlib.axis import Axis
from matplotlib.transforms import Transform

from collections.abc import Callable, Iterable
from typing import Literal
from typing import Literal, Union
from numpy.typing import ArrayLike

class ScaleBase:
Expand All @@ -15,6 +15,7 @@ class ScaleBase:

class LinearScale(ScaleBase):
name: str
def __init__(self: ScaleBase, axis: Union[Axis, None] = None) -> None: ...

class FuncTransform(Transform):
input_dims: int
Expand Down Expand Up @@ -56,12 +57,12 @@ class LogScale(ScaleBase):
name: str
subs: Iterable[int] | None
def __init__(
self,
axis: Axis | None,
self: LogScale,
axis: Axis | None = ...,
*,
base: float = ...,
subs: Iterable[int] | None = ...,
nonpositive: Literal["clip", "mask"] = ...
base: float = 10,
subs: Union[Iterable[int], None] = None,
nonpositive: Union[Literal['clip'], Literal['mask']] = 'clip'
) -> None: ...
@property
def base(self) -> float: ...
Expand Down Expand Up @@ -104,12 +105,12 @@ class SymmetricalLogScale(ScaleBase):
subs: Iterable[int] | None
def __init__(
self,
axis: Axis | None,
axis: Axis | None = ...,
*,
base: float = ...,
linthresh: float = ...,
subs: Iterable[int] | None = ...,
linscale: float = ...
base: float = 10,
linthresh: float = 2,
subs: Union[Iterable[int], None] = None,
linscale: float = 1
) -> None: ...
@property
def base(self) -> float: ...
Expand Down Expand Up @@ -138,7 +139,7 @@ class AsinhScale(ScaleBase):
auto_tick_multipliers: dict[int, tuple[int, ...]]
def __init__(
self,
axis: Axis | None,
axis: Axis | None = ...,
*,
linear_width: float = ...,
base: float = ...,
Expand All @@ -165,14 +166,15 @@ class LogitScale(ScaleBase):
name: str
def __init__(
self,
axis: Axis | None,
nonpositive: Literal["mask", "clip"] = ...,
axis: Axis | None = ...,
nonpositive: Union[Literal['mask'], Literal['clip']] = 'mask',
*,
one_half: str = ...,
use_overline: bool = ...
one_half: str = '\\frac{1}{2}',
use_overline: bool = False
) -> None: ...
def get_transform(self) -> LogitTransform: ...

def get_scale_names() -> list[str]: ...
def scale_factory(scale: str, axis: Axis, **kwargs) -> ScaleBase: ...
def register_scale(scale_class: type[ScaleBase]) -> None: ...
def handle_axis_parameter(init_func: Callable[..., None]) -> Callable[..., None]: ...
Loading