From 65b295bc822b7daf2b22bb3ce9c459c033091468 Mon Sep 17 00:00:00 2001 From: Kyle Sunden Date: Fri, 4 Nov 2022 17:21:31 -0500 Subject: [PATCH 1/2] Treat units info as nu, make nu a list rather than single function --- data_prototype/patches.py | 4 +- data_prototype/wrappers.py | 91 ++++++++++++++++++++------------------ 2 files changed, 50 insertions(+), 45 deletions(-) diff --git a/data_prototype/patches.py b/data_prototype/patches.py index 94a50af..acf98f1 100644 --- a/data_prototype/patches.py +++ b/data_prototype/patches.py @@ -45,12 +45,12 @@ class PatchWrapper(ProxyWrapper): } def __init__(self, data: DataContainer, nus=None, /, **kwargs): - super().__init__(data, nus) + super().__init__(data, nus, xunits=self._xunits, yunits=self._yunits) self._wrapped_instance = self._wrapped_class([0, 0], 0, 0, **kwargs) @_stale_wrapper def draw(self, renderer): - self._update_wrapped(self._query_and_transform(renderer, xunits=self._xunits, yunits=self._yunits)) + self._update_wrapped(self._query_and_transform(renderer)) return self._wrapped_instance.draw(renderer) def _update_wrapped(self, data): diff --git a/data_prototype/wrappers.py b/data_prototype/wrappers.py index 83ab843..e7ccf26 100644 --- a/data_prototype/wrappers.py +++ b/data_prototype/wrappers.py @@ -1,9 +1,10 @@ -from typing import List, Dict, Any, Protocol, Tuple, get_type_hints +from typing import Any, Protocol, get_type_hints import inspect import numpy as np from cachetools import LFUCache +from collections.abc import Sequence from functools import partial, wraps import matplotlib as mpl @@ -19,7 +20,7 @@ class _BBox(Protocol): - size: Tuple[float, float] + size: tuple[float, float] class _Axis(Protocol): @@ -34,10 +35,10 @@ class _Axes(Protocol): transData: _MatplotlibTransform transAxes: _MatplotlibTransform - def get_xlim(self) -> Tuple[float, float]: + def get_xlim(self) -> tuple[float, float]: ... - def get_ylim(self) -> Tuple[float, float]: + def get_ylim(self) -> tuple[float, float]: ... def get_window_extent(self, renderer) -> _BBox: @@ -47,15 +48,16 @@ def get_window_extent(self, renderer) -> _BBox: class _Aritst(Protocol): axes: _Axes +def _make_param_name(k, func): + def wrapped(**kwargs): + (arg,) = kwargs.values() + return func(arg) -def _make_identity(k): - def identity(**kwargs): - (_,) = kwargs.values() - return _ - - identity.__signature__ = inspect.Signature([inspect.Parameter(k, inspect.Parameter.POSITIONAL_OR_KEYWORD)]) - return identity + wrapped.__signature__ = inspect.Signature([inspect.Parameter(k, inspect.Parameter.POSITIONAL_OR_KEYWORD)]) + return wrapped +def _make_identity(k): + return _make_param_name(k, lambda x: x) def _forwarder(forwards, cls=None): if cls is None: @@ -109,7 +111,7 @@ def draw(self, renderer): def _update_wrapped(self, data): raise NotImplementedError - def _query_and_transform(self, renderer, *, xunits: List[str], yunits: List[str]) -> Dict[str, Any]: + def _query_and_transform(self, renderer) -> dict[str, Any]: """ Helper to centralize the data querying and python-side transforms @@ -139,38 +141,43 @@ def _query_and_transform(self, renderer, *, xunits: List[str], yunits: List[str] return self._cache[cache_key] except KeyError: ... - # TODO decide if units go pre-nu or post-nu? - for x_like in xunits: - if x_like in data: - data[x_like] = ax.xaxis.convert_units(data[x_like]) - for y_like in yunits: - if y_like in data: - data[y_like] = ax.xaxis.convert_units(data[y_like]) - # doing the nu work here is nice because we can write it once, but we # really want to push this computation down a layer # TODO sort out how this interoperates with the transform stack transformed_data = {} - for k, (nu, sig) in self._sigs.items(): - to_pass = set(sig.parameters) - transformed_data[k] = nu(**{k: data[k] for k in to_pass}) + for k, nu_list in self._sigs.items(): + for nu, sig in nu_list: + to_pass = set(sig.parameters) + transformed_data[k] = nu(**{k: transformed_data.get(k, data[k]) for k in to_pass}) self._cache[cache_key] = transformed_data return transformed_data - def __init__(self, data, nus, **kwargs): + def __init__(self, data, nus, xunits: tuple[str, ...] = (), yunits: tuple[str, ...] = (), **kwargs): super().__init__(**kwargs) self.data = data self._cache = LFUCache(64) # TODO make sure mutating this will invalidate the cache! self._nus = nus or {} for k in self.required_keys: - self._nus.setdefault(k, _make_identity(k)) + self._nus.setdefault(k, [_make_identity(k)]) + desc = data.describe() for k in self.expected_keys: if k in desc: - self._nus.setdefault(k, _make_identity(k)) - self._sigs = {k: (nu, inspect.signature(nu)) for k, nu in self._nus.items()} + self._nus.setdefault(k, [_make_identity(k)]) + + for field in self._nus: + if inspect.isfunction(self._nus[field]): + self._nus[field] = [self._nus[field]] + + for field in xunits: + self._nus[field].append(_make_param_name(field, lambda x: self.axes.xaxis.convert_units(x))) + + for field in yunits: + self._nus[field].append(_make_param_name(field, lambda y: self.axes.yaxis.convert_units(y))) + + self._sigs = {k: [(nu, inspect.signature(nu)) for nu in nu_list] for k, nu_list in self._nus.items()} self.stale = True # TODO add a setter @@ -180,7 +187,7 @@ def nus(self): class ProxyWrapper(ProxyWrapperBase): - _privtized_methods: Tuple[str, ...] = () + _privtized_methods: tuple[str, ...] = () _wrapped_class = None _wrapped_instance: _Aritst @@ -206,13 +213,13 @@ class LineWrapper(ProxyWrapper): required_keys = {"x", "y"} def __init__(self, data: DataContainer, nus=None, /, **kwargs): - super().__init__(data, nus) + super().__init__(data, nus, xunits=["x"], yunits=["y"]) self._wrapped_instance = self._wrapped_class(np.array([]), np.array([]), **kwargs) @_stale_wrapper def draw(self, renderer): self._update_wrapped( - self._query_and_transform(renderer, xunits=["x"], yunits=["y"]), + self._query_and_transform(renderer), ) return self._wrapped_instance.draw(renderer) @@ -239,14 +246,14 @@ class PathCollectionWrapper(ProxyWrapper): ) def __init__(self, data: DataContainer, nus=None, /, **kwargs): - super().__init__(data, nus) + super().__init__(data, nus, xunits = ("x",), yunits = ("y",)) self._wrapped_instance = self._wrapped_class([], **kwargs) self._wrapped_instance.set_transform(mtransforms.IdentityTransform()) @_stale_wrapper def draw(self, renderer): self._update_wrapped( - self._query_and_transform(renderer, xunits=["x"], yunits=["y"]), + self._query_and_transform(renderer), ) return self._wrapped_instance.draw(renderer) @@ -272,14 +279,14 @@ def __init__(self, data: DataContainer, nus=None, /, cmap=None, norm=None, **kwa if norm is None: raise ValueError("not sure how to do autoscaling yet") nus["image"] = lambda image: cmap(norm(image)) - super().__init__(data, nus) + super().__init__(data, nus, xunits=["xextent"], yunits=["yextent"]) kwargs.setdefault("origin", "lower") self._wrapped_instance = self._wrapped_class(None, **kwargs) @_stale_wrapper def draw(self, renderer): self._update_wrapped( - self._query_and_transform(renderer, xunits=["xextent"], yunits=["yextent"]), + self._query_and_transform(renderer), ) return self._wrapped_instance.draw(renderer) @@ -294,13 +301,13 @@ class StepWrapper(ProxyWrapper): required_keys = {"edges", "density"} def __init__(self, data: DataContainer, nus=None, /, **kwargs): - super().__init__(data, nus) + super().__init__(data, nus, xunits=["edges"], yunits=["density"]) self._wrapped_instance = self._wrapped_class([], [1], **kwargs) @_stale_wrapper def draw(self, renderer): self._update_wrapped( - self._query_and_transform(renderer, xunits=["edges"], yunits=["density"]), + self._query_and_transform(renderer), ) return self._wrapped_instance.draw(renderer) @@ -319,7 +326,7 @@ def __init__(self, data: DataContainer, nus=None, /, **kwargs): @_stale_wrapper def draw(self, renderer): self._update_wrapped( - self._query_and_transform(renderer, xunits=[], yunits=[]), + self._query_and_transform(renderer), ) return self._wrapped_instance.draw(renderer) @@ -342,8 +349,8 @@ def _update_wrapped(self, data): ) # _Artist has to go last for now because it is not (yet) MI friendly. class MultiProxyWrapper(ProxyWrapperBase, _Artist): - _privtized_methods: Tuple[str, ...] = () - _wrapped_instances: Dict[str, _Aritst] + _privtized_methods: tuple[str, ...] = () + _wrapped_instances: dict[str, _Aritst] def __setattr__(self, key, value): attrs = set(get_type_hints(type(self))) @@ -369,7 +376,7 @@ class ErrorbarWrapper(MultiProxyWrapper): expected_keys = {f"{axis}{dirc}" for axis in ["x", "y"] for dirc in ["upper", "lower"]} def __init__(self, data: DataContainer, nus=None, /, **kwargs): - super().__init__(data, nus) + super().__init__(data, nus, xunits=["x", "xupper", "xlower"], yunits=["y", "yupper", "ylower"]) # TODO all of the kwarg teasing apart that is needed color = kwargs.pop("color", "k") lw = kwargs.pop("lw", 2) @@ -396,9 +403,7 @@ def __init__(self, data: DataContainer, nus=None, /, **kwargs): @_stale_wrapper def draw(self, renderer): self._update_wrapped( - self._query_and_transform( - renderer, xunits=["x", "xupper", "xlower"], yunits=["y", "yupper", "ylower"] - ), + self._query_and_transform(renderer), ) for k, v in self._wrapped_instances.items(): v.draw(renderer) From baebd64896bf1752dc69d03e9b598c47553233c2 Mon Sep 17 00:00:00 2001 From: Kyle Sunden Date: Wed, 8 Feb 2023 19:09:02 -0600 Subject: [PATCH 2/2] STY: blacken --- data_prototype/wrappers.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/data_prototype/wrappers.py b/data_prototype/wrappers.py index e7ccf26..68a2b8b 100644 --- a/data_prototype/wrappers.py +++ b/data_prototype/wrappers.py @@ -4,7 +4,6 @@ import numpy as np from cachetools import LFUCache -from collections.abc import Sequence from functools import partial, wraps import matplotlib as mpl @@ -48,6 +47,7 @@ def get_window_extent(self, renderer) -> _BBox: class _Aritst(Protocol): axes: _Axes + def _make_param_name(k, func): def wrapped(**kwargs): (arg,) = kwargs.values() @@ -56,9 +56,11 @@ def wrapped(**kwargs): wrapped.__signature__ = inspect.Signature([inspect.Parameter(k, inspect.Parameter.POSITIONAL_OR_KEYWORD)]) return wrapped + def _make_identity(k): return _make_param_name(k, lambda x: x) + def _forwarder(forwards, cls=None): if cls is None: return partial(_forwarder, forwards) @@ -246,7 +248,7 @@ class PathCollectionWrapper(ProxyWrapper): ) def __init__(self, data: DataContainer, nus=None, /, **kwargs): - super().__init__(data, nus, xunits = ("x",), yunits = ("y",)) + super().__init__(data, nus, xunits=("x",), yunits=("y",)) self._wrapped_instance = self._wrapped_class([], **kwargs) self._wrapped_instance.set_transform(mtransforms.IdentityTransform())