Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Treat units info as nu, make nu a list rather than single function #17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions data_prototype/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ class PatchWrapper(ProxyWrapper):
}

def __init__(self, data: DataContainer, nus=None, /, **kwargs):
super().__init__(data, nus)
super().__init__(data, nus, xunits=self._xunits, yunits=self._yunits)
self._wrapped_instance = self._wrapped_class([0, 0], 0, 0, **kwargs)

@_stale_wrapper
def draw(self, renderer):
self._update_wrapped(self._query_and_transform(renderer, xunits=self._xunits, yunits=self._yunits))
self._update_wrapped(self._query_and_transform(renderer))
return self._wrapped_instance.draw(renderer)

def _update_wrapped(self, data):
Expand Down
91 changes: 49 additions & 42 deletions data_prototype/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Dict, Any, Protocol, Tuple, get_type_hints
from typing import Any, Protocol, get_type_hints
import inspect

import numpy as np
Expand All @@ -19,7 +19,7 @@


class _BBox(Protocol):
size: Tuple[float, float]
size: tuple[float, float]


class _Axis(Protocol):
Expand All @@ -34,10 +34,10 @@ class _Axes(Protocol):
transData: _MatplotlibTransform
transAxes: _MatplotlibTransform

def get_xlim(self) -> Tuple[float, float]:
def get_xlim(self) -> tuple[float, float]:
...

def get_ylim(self) -> Tuple[float, float]:
def get_ylim(self) -> tuple[float, float]:
...

def get_window_extent(self, renderer) -> _BBox:
Expand All @@ -48,13 +48,17 @@ class _Aritst(Protocol):
axes: _Axes


def _make_identity(k):
def identity(**kwargs):
(_,) = kwargs.values()
return _
def _make_param_name(k, func):
def wrapped(**kwargs):
(arg,) = kwargs.values()
return func(arg)

wrapped.__signature__ = inspect.Signature([inspect.Parameter(k, inspect.Parameter.POSITIONAL_OR_KEYWORD)])
return wrapped

identity.__signature__ = inspect.Signature([inspect.Parameter(k, inspect.Parameter.POSITIONAL_OR_KEYWORD)])
return identity

def _make_identity(k):
return _make_param_name(k, lambda x: x)


def _forwarder(forwards, cls=None):
Expand Down Expand Up @@ -109,7 +113,7 @@ def draw(self, renderer):
def _update_wrapped(self, data):
raise NotImplementedError

def _query_and_transform(self, renderer, *, xunits: List[str], yunits: List[str]) -> Dict[str, Any]:
def _query_and_transform(self, renderer) -> dict[str, Any]:
"""
Helper to centralize the data querying and python-side transforms

Expand Down Expand Up @@ -139,38 +143,43 @@ def _query_and_transform(self, renderer, *, xunits: List[str], yunits: List[str]
return self._cache[cache_key]
except KeyError:
...
# TODO decide if units go pre-nu or post-nu?
for x_like in xunits:
if x_like in data:
data[x_like] = ax.xaxis.convert_units(data[x_like])
for y_like in yunits:
if y_like in data:
data[y_like] = ax.xaxis.convert_units(data[y_like])

# doing the nu work here is nice because we can write it once, but we
# really want to push this computation down a layer
# TODO sort out how this interoperates with the transform stack
transformed_data = {}
for k, (nu, sig) in self._sigs.items():
to_pass = set(sig.parameters)
transformed_data[k] = nu(**{k: data[k] for k in to_pass})
for k, nu_list in self._sigs.items():
for nu, sig in nu_list:
to_pass = set(sig.parameters)
transformed_data[k] = nu(**{k: transformed_data.get(k, data[k]) for k in to_pass})

self._cache[cache_key] = transformed_data
return transformed_data

def __init__(self, data, nus, **kwargs):
def __init__(self, data, nus, xunits: tuple[str, ...] = (), yunits: tuple[str, ...] = (), **kwargs):
super().__init__(**kwargs)
self.data = data
self._cache = LFUCache(64)
# TODO make sure mutating this will invalidate the cache!
self._nus = nus or {}
for k in self.required_keys:
self._nus.setdefault(k, _make_identity(k))
self._nus.setdefault(k, [_make_identity(k)])

desc = data.describe()
for k in self.expected_keys:
if k in desc:
self._nus.setdefault(k, _make_identity(k))
self._sigs = {k: (nu, inspect.signature(nu)) for k, nu in self._nus.items()}
self._nus.setdefault(k, [_make_identity(k)])

for field in self._nus:
if inspect.isfunction(self._nus[field]):
self._nus[field] = [self._nus[field]]

for field in xunits:
self._nus[field].append(_make_param_name(field, lambda x: self.axes.xaxis.convert_units(x)))

for field in yunits:
self._nus[field].append(_make_param_name(field, lambda y: self.axes.yaxis.convert_units(y)))

self._sigs = {k: [(nu, inspect.signature(nu)) for nu in nu_list] for k, nu_list in self._nus.items()}
self.stale = True

# TODO add a setter
Expand All @@ -180,7 +189,7 @@ def nus(self):


class ProxyWrapper(ProxyWrapperBase):
_privtized_methods: Tuple[str, ...] = ()
_privtized_methods: tuple[str, ...] = ()
_wrapped_class = None
_wrapped_instance: _Aritst

Expand All @@ -206,13 +215,13 @@ class LineWrapper(ProxyWrapper):
required_keys = {"x", "y"}

def __init__(self, data: DataContainer, nus=None, /, **kwargs):
super().__init__(data, nus)
super().__init__(data, nus, xunits=["x"], yunits=["y"])
self._wrapped_instance = self._wrapped_class(np.array([]), np.array([]), **kwargs)

@_stale_wrapper
def draw(self, renderer):
self._update_wrapped(
self._query_and_transform(renderer, xunits=["x"], yunits=["y"]),
self._query_and_transform(renderer),
)
return self._wrapped_instance.draw(renderer)

Expand All @@ -239,14 +248,14 @@ class PathCollectionWrapper(ProxyWrapper):
)

def __init__(self, data: DataContainer, nus=None, /, **kwargs):
super().__init__(data, nus)
super().__init__(data, nus, xunits=("x",), yunits=("y",))
self._wrapped_instance = self._wrapped_class([], **kwargs)
self._wrapped_instance.set_transform(mtransforms.IdentityTransform())

@_stale_wrapper
def draw(self, renderer):
self._update_wrapped(
self._query_and_transform(renderer, xunits=["x"], yunits=["y"]),
self._query_and_transform(renderer),
)
return self._wrapped_instance.draw(renderer)

Expand All @@ -272,14 +281,14 @@ def __init__(self, data: DataContainer, nus=None, /, cmap=None, norm=None, **kwa
if norm is None:
raise ValueError("not sure how to do autoscaling yet")
nus["image"] = lambda image: cmap(norm(image))
super().__init__(data, nus)
super().__init__(data, nus, xunits=["xextent"], yunits=["yextent"])
kwargs.setdefault("origin", "lower")
self._wrapped_instance = self._wrapped_class(None, **kwargs)

@_stale_wrapper
def draw(self, renderer):
self._update_wrapped(
self._query_and_transform(renderer, xunits=["xextent"], yunits=["yextent"]),
self._query_and_transform(renderer),
)
return self._wrapped_instance.draw(renderer)

Expand All @@ -294,13 +303,13 @@ class StepWrapper(ProxyWrapper):
required_keys = {"edges", "density"}

def __init__(self, data: DataContainer, nus=None, /, **kwargs):
super().__init__(data, nus)
super().__init__(data, nus, xunits=["edges"], yunits=["density"])
self._wrapped_instance = self._wrapped_class([], [1], **kwargs)

@_stale_wrapper
def draw(self, renderer):
self._update_wrapped(
self._query_and_transform(renderer, xunits=["edges"], yunits=["density"]),
self._query_and_transform(renderer),
)
return self._wrapped_instance.draw(renderer)

Expand All @@ -319,7 +328,7 @@ def __init__(self, data: DataContainer, nus=None, /, **kwargs):
@_stale_wrapper
def draw(self, renderer):
self._update_wrapped(
self._query_and_transform(renderer, xunits=[], yunits=[]),
self._query_and_transform(renderer),
)
return self._wrapped_instance.draw(renderer)

Expand All @@ -342,8 +351,8 @@ def _update_wrapped(self, data):
)
# _Artist has to go last for now because it is not (yet) MI friendly.
class MultiProxyWrapper(ProxyWrapperBase, _Artist):
_privtized_methods: Tuple[str, ...] = ()
_wrapped_instances: Dict[str, _Aritst]
_privtized_methods: tuple[str, ...] = ()
_wrapped_instances: dict[str, _Aritst]

def __setattr__(self, key, value):
attrs = set(get_type_hints(type(self)))
Expand All @@ -369,7 +378,7 @@ class ErrorbarWrapper(MultiProxyWrapper):
expected_keys = {f"{axis}{dirc}" for axis in ["x", "y"] for dirc in ["upper", "lower"]}

def __init__(self, data: DataContainer, nus=None, /, **kwargs):
super().__init__(data, nus)
super().__init__(data, nus, xunits=["x", "xupper", "xlower"], yunits=["y", "yupper", "ylower"])
# TODO all of the kwarg teasing apart that is needed
color = kwargs.pop("color", "k")
lw = kwargs.pop("lw", 2)
Expand All @@ -396,9 +405,7 @@ def __init__(self, data: DataContainer, nus=None, /, **kwargs):
@_stale_wrapper
def draw(self, renderer):
self._update_wrapped(
self._query_and_transform(
renderer, xunits=["x", "xupper", "xlower"], yunits=["y", "yupper", "ylower"]
),
self._query_and_transform(renderer),
)
for k, v in self._wrapped_instances.items():
v.draw(renderer)
Expand Down