-
-
Notifications
You must be signed in to change notification settings - Fork 7.9k
Refactoring: Removing axis parameter from scales #29988
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
d8fb083
1eea948
93bf34b
d241761
075c4b3
5bef6c2
5cceb7a
d18d7a1
538dbd2
851b25c
1a516b1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -31,6 +31,7 @@ | |||||||||||||||
|
||||||||||||||||
import inspect | ||||||||||||||||
import textwrap | ||||||||||||||||
from functools import wraps | ||||||||||||||||
|
||||||||||||||||
import numpy as np | ||||||||||||||||
|
||||||||||||||||
|
@@ -103,14 +104,58 @@ | |||||||||||||||
return vmin, vmax | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
def handle_axis_parameter(init_func): | ||||||||||||||||
""" | ||||||||||||||||
Decorator to handle the optional *axis* parameter in scale constructors. | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||
|
||||||||||||||||
This decorator ensures backward compatibility for scale classes that | ||||||||||||||||
previously required an *axis* parameter. It allows constructors to work | ||||||||||||||||
seamlessly with or without the *axis* parameter. | ||||||||||||||||
Comment on lines
+112
to
+113
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||
|
||||||||||||||||
Parameters | ||||||||||||||||
---------- | ||||||||||||||||
init_func : callable | ||||||||||||||||
The original __init__ method of a scale class. | ||||||||||||||||
|
||||||||||||||||
Returns | ||||||||||||||||
------- | ||||||||||||||||
callable | ||||||||||||||||
A wrapped version of *init_func* that handles the optional *axis*. | ||||||||||||||||
|
||||||||||||||||
Notes | ||||||||||||||||
----- | ||||||||||||||||
If the wrapped constructor defines *axis* as its first argument, the | ||||||||||||||||
parameter is preserved when present. Otherwise, the value `None` is injected | ||||||||||||||||
as the first argument. | ||||||||||||||||
|
||||||||||||||||
Examples | ||||||||||||||||
-------- | ||||||||||||||||
>>> from matplotlib.scale import ScaleBase | ||||||||||||||||
>>> class CustomScale(ScaleBase): | ||||||||||||||||
... @handle_axis_parameter | ||||||||||||||||
... def __init__(self, axis, custom_param=1): | ||||||||||||||||
... self.custom_param = custom_param | ||||||||||||||||
""" | ||||||||||||||||
@wraps(init_func) | ||||||||||||||||
def wrapper(self, *args, **kwargs): | ||||||||||||||||
if args and isinstance(args[0], mpl.axis.Axis): | ||||||||||||||||
return init_func(self, *args, **kwargs) | ||||||||||||||||
else: | ||||||||||||||||
# Remove 'axis' from kwargs to avoid double assignment | ||||||||||||||||
kwargs.pop('axis', None) | ||||||||||||||||
return init_func(self, None, *args, **kwargs) | ||||||||||||||||
return wrapper | ||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
class LinearScale(ScaleBase): | ||||||||||||||||
""" | ||||||||||||||||
The default linear scale. | ||||||||||||||||
""" | ||||||||||||||||
|
||||||||||||||||
name = 'linear' | ||||||||||||||||
|
||||||||||||||||
def __init__(self, axis): | ||||||||||||||||
@handle_axis_parameter | ||||||||||||||||
def __init__(self, axis=None): | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there an advantage in making |
||||||||||||||||
# This method is present only to prevent inheritance of the base class' | ||||||||||||||||
# constructor docstring, which would otherwise end up interpolated into | ||||||||||||||||
# the docstring of Axis.set_scale. | ||||||||||||||||
|
@@ -180,6 +225,7 @@ | |||||||||||||||
|
||||||||||||||||
name = 'function' | ||||||||||||||||
|
||||||||||||||||
@handle_axis_parameter | ||||||||||||||||
def __init__(self, axis, functions): | ||||||||||||||||
""" | ||||||||||||||||
Parameters | ||||||||||||||||
|
@@ -279,7 +325,8 @@ | |||||||||||||||
""" | ||||||||||||||||
name = 'log' | ||||||||||||||||
|
||||||||||||||||
def __init__(self, axis, *, base=10, subs=None, nonpositive="clip"): | ||||||||||||||||
@handle_axis_parameter | ||||||||||||||||
def __init__(self, axis=None, *, base=10, subs=None, nonpositive="clip"): | ||||||||||||||||
""" | ||||||||||||||||
Parameters | ||||||||||||||||
---------- | ||||||||||||||||
|
@@ -330,6 +377,7 @@ | |||||||||||||||
|
||||||||||||||||
name = 'functionlog' | ||||||||||||||||
|
||||||||||||||||
@handle_axis_parameter | ||||||||||||||||
def __init__(self, axis, functions, base=10): | ||||||||||||||||
""" | ||||||||||||||||
Parameters | ||||||||||||||||
|
@@ -455,7 +503,8 @@ | |||||||||||||||
""" | ||||||||||||||||
name = 'symlog' | ||||||||||||||||
|
||||||||||||||||
def __init__(self, axis, *, base=10, linthresh=2, subs=None, linscale=1): | ||||||||||||||||
@handle_axis_parameter | ||||||||||||||||
def __init__(self, axis=None, *, base=10, linthresh=2, subs=None, linscale=1): | ||||||||||||||||
self._transform = SymmetricalLogTransform(base, linthresh, linscale) | ||||||||||||||||
self.subs = subs | ||||||||||||||||
|
||||||||||||||||
|
@@ -547,7 +596,8 @@ | |||||||||||||||
1024: (256, 512) | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
def __init__(self, axis, *, linear_width=1.0, | ||||||||||||||||
@handle_axis_parameter | ||||||||||||||||
def __init__(self, axis=None, *, linear_width=1.0, | ||||||||||||||||
base=10, subs='auto', **kwargs): | ||||||||||||||||
""" | ||||||||||||||||
Parameters | ||||||||||||||||
|
@@ -645,7 +695,8 @@ | |||||||||||||||
""" | ||||||||||||||||
name = 'logit' | ||||||||||||||||
|
||||||||||||||||
def __init__(self, axis, nonpositive='mask', *, | ||||||||||||||||
@handle_axis_parameter | ||||||||||||||||
def __init__(self, axis=None, nonpositive='mask', *, | ||||||||||||||||
one_half=r"\frac{1}{2}", use_overline=False): | ||||||||||||||||
r""" | ||||||||||||||||
Parameters | ||||||||||||||||
|
@@ -725,7 +776,12 @@ | |||||||||||||||
axis : `~matplotlib.axis.Axis` | ||||||||||||||||
""" | ||||||||||||||||
scale_cls = _api.check_getitem(_scale_mapping, scale=scale) | ||||||||||||||||
return scale_cls(axis, **kwargs) | ||||||||||||||||
try: | ||||||||||||||||
return scale_cls(axis, **kwargs) | ||||||||||||||||
except TypeError as e: | ||||||||||||||||
if 'unexpected keyword argument' in str(e) or 'positional argument' in str(e): | ||||||||||||||||
return scale_cls(**kwargs) | ||||||||||||||||
raise | ||||||||||||||||
Comment on lines
+781
to
+784
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are you sure this is safe? Basing logic on the exception message only can be brittle. Minimally, I'd want an explicit comment on what we catch here. But I'd rather split this out into a separte PR as we need tests. |
||||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
if scale_factory.__doc__: | ||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -2,7 +2,7 @@ from matplotlib.axis import Axis | |||||
from matplotlib.transforms import Transform | ||||||
|
||||||
from collections.abc import Callable, Iterable | ||||||
from typing import Literal | ||||||
from typing import Literal, Union | ||||||
from numpy.typing import ArrayLike | ||||||
|
||||||
class ScaleBase: | ||||||
|
@@ -15,6 +15,7 @@ class ScaleBase: | |||||
|
||||||
class LinearScale(ScaleBase): | ||||||
name: str | ||||||
def __init__(self: ScaleBase, axis: Union[Axis, None] = None) -> None: ... | ||||||
|
||||||
class FuncTransform(Transform): | ||||||
input_dims: int | ||||||
|
@@ -56,12 +57,12 @@ class LogScale(ScaleBase): | |||||
name: str | ||||||
subs: Iterable[int] | None | ||||||
def __init__( | ||||||
self, | ||||||
axis: Axis | None, | ||||||
self: LogScale, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You don't type |
||||||
axis: Union[Axis, None] = None, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
*, | ||||||
base: float = ..., | ||||||
subs: Iterable[int] | None = ..., | ||||||
nonpositive: Literal["clip", "mask"] = ... | ||||||
base: float = 10, | ||||||
subs: Union[Iterable[int], None] = None, | ||||||
nonpositive: Union[Literal['clip'], Literal['mask']] = 'clip' | ||||||
) -> None: ... | ||||||
@property | ||||||
def base(self) -> float: ... | ||||||
|
@@ -103,13 +104,13 @@ class SymmetricalLogScale(ScaleBase): | |||||
name: str | ||||||
subs: Iterable[int] | None | ||||||
def __init__( | ||||||
self, | ||||||
axis: Axis | None, | ||||||
self: SymmetricalLogScale, | ||||||
axis: Union[Axis, None] = None, | ||||||
*, | ||||||
base: float = ..., | ||||||
linthresh: float = ..., | ||||||
subs: Iterable[int] | None = ..., | ||||||
linscale: float = ... | ||||||
base: float = 10, | ||||||
linthresh: float = 2, | ||||||
subs: Union[Iterable[int], None] = None, | ||||||
linscale: float = 1 | ||||||
) -> None: ... | ||||||
@property | ||||||
def base(self) -> float: ... | ||||||
|
@@ -138,7 +139,7 @@ class AsinhScale(ScaleBase): | |||||
auto_tick_multipliers: dict[int, tuple[int, ...]] | ||||||
def __init__( | ||||||
self, | ||||||
axis: Axis | None, | ||||||
axis: Union[Axis, None] = None, | ||||||
*, | ||||||
linear_width: float = ..., | ||||||
base: float = ..., | ||||||
|
@@ -164,15 +165,16 @@ class LogisticTransform(Transform): | |||||
class LogitScale(ScaleBase): | ||||||
name: str | ||||||
def __init__( | ||||||
self, | ||||||
axis: Axis | None, | ||||||
nonpositive: Literal["mask", "clip"] = ..., | ||||||
self: LogitScale, | ||||||
axis: Union[Axis, None] = None, | ||||||
nonpositive: Union[Literal['mask'], Literal['clip']] = 'mask', | ||||||
*, | ||||||
one_half: str = ..., | ||||||
use_overline: bool = ... | ||||||
one_half: str = '\\frac{1}{2}', | ||||||
use_overline: bool = False | ||||||
) -> None: ... | ||||||
def get_transform(self) -> LogitTransform: ... | ||||||
|
||||||
def get_scale_names() -> list[str]: ... | ||||||
def scale_factory(scale: str, axis: Axis, **kwargs) -> ScaleBase: ... | ||||||
def register_scale(scale_class: type[ScaleBase]) -> None: ... | ||||||
def handle_axis_parameter(init_func: Callable[..., None]) -> Callable[..., None]: ... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's be very explicit about the naming, and keep it internal. I don't expect that possible downstream Child classes will need to support optional axis parameters.