Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Refactoring: Removing axis parameter from scales #29988

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
68 changes: 62 additions & 6 deletions lib/matplotlib/scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

import inspect
import textwrap
from functools import wraps

import numpy as np

Expand Down Expand Up @@ -103,14 +104,58 @@
return vmin, vmax


def handle_axis_parameter(init_func):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def handle_axis_parameter(init_func):
def _make_axis_parameter_optional(init_func):

Let's be very explicit about the naming, and keep it internal. I don't expect that possible downstream Child classes will need to support optional axis parameters.

"""
Decorator to handle the optional *axis* parameter in scale constructors.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Decorator to handle the optional *axis* parameter in scale constructors.
Decorator to allow leaving out the *axis* parameter in scale constructors.


This decorator ensures backward compatibility for scale classes that
previously required an *axis* parameter. It allows constructors to work
seamlessly with or without the *axis* parameter.
Comment on lines +112 to +113
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
previously required an *axis* parameter. It allows constructors to work
seamlessly with or without the *axis* parameter.
previously required an *axis* parameter. It allows constructors to be
callerd with or without the *axis* parameter.
For simplicity, this does not handle the case when *axis* is passed as a keyword. Howver,
scanning GitHub, there's no evidence that that is used anywhere.


Parameters
----------
init_func : callable
The original __init__ method of a scale class.

Returns
-------
callable
A wrapped version of *init_func* that handles the optional *axis*.

Notes
-----
If the wrapped constructor defines *axis* as its first argument, the
parameter is preserved when present. Otherwise, the value `None` is injected
as the first argument.

Examples
--------
>>> from matplotlib.scale import ScaleBase
>>> class CustomScale(ScaleBase):
... @handle_axis_parameter
... def __init__(self, axis, custom_param=1):
... self.custom_param = custom_param
"""
@wraps(init_func)
def wrapper(self, *args, **kwargs):
if args and isinstance(args[0], mpl.axis.Axis):
return init_func(self, *args, **kwargs)
else:
# Remove 'axis' from kwargs to avoid double assignment
kwargs.pop('axis', None)
return init_func(self, None, *args, **kwargs)
return wrapper


class LinearScale(ScaleBase):
"""
The default linear scale.
"""

name = 'linear'

def __init__(self, axis):
@handle_axis_parameter
def __init__(self, axis=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there an advantage in making axis optional here? That should be handled be handled by the decorator. If there's no particular reason, then I'd prefer to keep the signature formally unchanged for now.

# This method is present only to prevent inheritance of the base class'
# constructor docstring, which would otherwise end up interpolated into
# the docstring of Axis.set_scale.
Expand Down Expand Up @@ -180,6 +225,7 @@

name = 'function'

@handle_axis_parameter
def __init__(self, axis, functions):
"""
Parameters
Expand Down Expand Up @@ -279,7 +325,8 @@
"""
name = 'log'

def __init__(self, axis, *, base=10, subs=None, nonpositive="clip"):
@handle_axis_parameter
def __init__(self, axis=None, *, base=10, subs=None, nonpositive="clip"):
"""
Parameters
----------
Expand Down Expand Up @@ -330,6 +377,7 @@

name = 'functionlog'

@handle_axis_parameter
def __init__(self, axis, functions, base=10):
"""
Parameters
Expand Down Expand Up @@ -455,7 +503,8 @@
"""
name = 'symlog'

def __init__(self, axis, *, base=10, linthresh=2, subs=None, linscale=1):
@handle_axis_parameter
def __init__(self, axis=None, *, base=10, linthresh=2, subs=None, linscale=1):
self._transform = SymmetricalLogTransform(base, linthresh, linscale)
self.subs = subs

Expand Down Expand Up @@ -547,7 +596,8 @@
1024: (256, 512)
}

def __init__(self, axis, *, linear_width=1.0,
@handle_axis_parameter
def __init__(self, axis=None, *, linear_width=1.0,
base=10, subs='auto', **kwargs):
"""
Parameters
Expand Down Expand Up @@ -645,7 +695,8 @@
"""
name = 'logit'

def __init__(self, axis, nonpositive='mask', *,
@handle_axis_parameter
def __init__(self, axis=None, nonpositive='mask', *,
one_half=r"\frac{1}{2}", use_overline=False):
r"""
Parameters
Expand Down Expand Up @@ -725,7 +776,12 @@
axis : `~matplotlib.axis.Axis`
"""
scale_cls = _api.check_getitem(_scale_mapping, scale=scale)
return scale_cls(axis, **kwargs)
try:
return scale_cls(axis, **kwargs)
except TypeError as e:
if 'unexpected keyword argument' in str(e) or 'positional argument' in str(e):
return scale_cls(**kwargs)
raise

Check warning on line 784 in lib/matplotlib/scale.py

View check run for this annotation

Codecov / codecov/patch

lib/matplotlib/scale.py#L784

Added line #L784 was not covered by tests
Comment on lines +781 to +784
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure this is safe? Basing logic on the exception message only can be brittle. Minimally, I'd want an explicit comment on what we catch here. But I'd rather split this out into a separte PR as we need tests.



if scale_factory.__doc__:
Expand Down
38 changes: 20 additions & 18 deletions lib/matplotlib/scale.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ from matplotlib.axis import Axis
from matplotlib.transforms import Transform

from collections.abc import Callable, Iterable
from typing import Literal
from typing import Literal, Union
from numpy.typing import ArrayLike

class ScaleBase:
Expand All @@ -15,6 +15,7 @@ class ScaleBase:

class LinearScale(ScaleBase):
name: str
def __init__(self: ScaleBase, axis: Union[Axis, None] = None) -> None: ...

class FuncTransform(Transform):
input_dims: int
Expand Down Expand Up @@ -56,12 +57,12 @@ class LogScale(ScaleBase):
name: str
subs: Iterable[int] | None
def __init__(
self,
axis: Axis | None,
self: LogScale,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't type self. It's implicitly understood by typecheckers.

axis: Union[Axis, None] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
axis: Union[Axis, None] = None,
axis: Axis | None = ...,

Axis | None is equivalent to Union[Axis, None], but the preferred spelling. You don't repeat default values in stub files. Instead, you place an ellipsis.

*,
base: float = ...,
subs: Iterable[int] | None = ...,
nonpositive: Literal["clip", "mask"] = ...
base: float = 10,
subs: Union[Iterable[int], None] = None,
nonpositive: Union[Literal['clip'], Literal['mask']] = 'clip'
) -> None: ...
@property
def base(self) -> float: ...
Expand Down Expand Up @@ -103,13 +104,13 @@ class SymmetricalLogScale(ScaleBase):
name: str
subs: Iterable[int] | None
def __init__(
self,
axis: Axis | None,
self: SymmetricalLogScale,
axis: Union[Axis, None] = None,
*,
base: float = ...,
linthresh: float = ...,
subs: Iterable[int] | None = ...,
linscale: float = ...
base: float = 10,
linthresh: float = 2,
subs: Union[Iterable[int], None] = None,
linscale: float = 1
) -> None: ...
@property
def base(self) -> float: ...
Expand Down Expand Up @@ -138,7 +139,7 @@ class AsinhScale(ScaleBase):
auto_tick_multipliers: dict[int, tuple[int, ...]]
def __init__(
self,
axis: Axis | None,
axis: Union[Axis, None] = None,
*,
linear_width: float = ...,
base: float = ...,
Expand All @@ -164,15 +165,16 @@ class LogisticTransform(Transform):
class LogitScale(ScaleBase):
name: str
def __init__(
self,
axis: Axis | None,
nonpositive: Literal["mask", "clip"] = ...,
self: LogitScale,
axis: Union[Axis, None] = None,
nonpositive: Union[Literal['mask'], Literal['clip']] = 'mask',
*,
one_half: str = ...,
use_overline: bool = ...
one_half: str = '\\frac{1}{2}',
use_overline: bool = False
) -> None: ...
def get_transform(self) -> LogitTransform: ...

def get_scale_names() -> list[str]: ...
def scale_factory(scale: str, axis: Axis, **kwargs) -> ScaleBase: ...
def register_scale(scale_class: type[ScaleBase]) -> None: ...
def handle_axis_parameter(init_func: Callable[..., None]) -> Callable[..., None]: ...
Loading