diff --git a/data_prototype/wrappers.py b/data_prototype/wrappers.py index 2d4875b..66bbc33 100644 --- a/data_prototype/wrappers.py +++ b/data_prototype/wrappers.py @@ -1,4 +1,5 @@ from typing import List, Dict, Any, Protocol, Tuple, get_type_hints +import inspect import numpy as np @@ -46,6 +47,15 @@ class _Aritst(Protocol): axes: _Axes +def _make_identity(k): + def identity(**kwargs): + (_,) = kwargs.values() + return _ + + identity.__signature__ = inspect.Signature([inspect.Parameter(k, inspect.Parameter.POSITIONAL_OR_KEYWORD)]) + return identity + + def _forwarder(forwards, cls=None): if cls is None: return partial(_forwarder, forwards) @@ -88,6 +98,8 @@ class ProxyWrapperBase: data: DataContainer axes: _Axes stale: bool + required_keys: set = set() + expected_keys: set = set() @_stale_wrapper def draw(self, renderer): @@ -137,18 +149,34 @@ def _query_and_transform(self, renderer, *, xunits: List[str], yunits: List[str] # doing the nu work here is nice because we can write it once, but we # really want to push this computation down a layer # TODO sort out how this interoperates with the transform stack - data = {k: self.nus.get(k, lambda x: x)(v) for k, v in data.items()} - self._cache[cache_key] = data - return data + transformed_data = {} + for k, (nu, sig) in self._sigs.items(): + to_pass = set(sig.parameters) + transformed_data[k] = nu(**{k: data[k] for k in to_pass}) + + self._cache[cache_key] = transformed_data + return transformed_data def __init__(self, data, nus, **kwargs): super().__init__(**kwargs) self.data = data self._cache = LFUCache(64) # TODO make sure mutating this will invalidate the cache! - self.nus = nus or {} + self._nus = nus or {} + for k in self.required_keys: + self._nus.setdefault(k, _make_identity(k)) + desc = data.describe() + for k in self.expected_keys: + if k in desc: + self._nus.setdefault(k, _make_identity(k)) + self._sigs = {k: (nu, inspect.signature(nu)) for k, nu in self._nus.items()} self.stale = True + # TODO add a setter + @property + def nus(self): + return dict(self._nus) + class ProxyWrapper(ProxyWrapperBase): _privtized_methods: Tuple[str, ...] = () @@ -163,7 +191,7 @@ def __getattr__(self, key): return getattr(self._wrapped_instance, key) def __setattr__(self, key, value): - if key in ("_wrapped_instance", "data", "_cache", "nus", "stale"): + if key in ("_wrapped_instance", "data", "_cache", "_nus", "stale", "_sigs"): super().__setattr__(key, value) elif hasattr(self, "_wrapped_instance") and hasattr(self._wrapped_instance, key): setattr(self._wrapped_instance, key, value) @@ -174,6 +202,7 @@ def __setattr__(self, key, value): class LineWrapper(ProxyWrapper): _wrapped_class = _Line2D _privtized_methods = ("set_xdata", "set_ydata", "set_data", "get_xdata", "get_ydata", "get_data") + required_keys = {"x", "y"} def __init__(self, data: DataContainer, nus=None, /, **kwargs): super().__init__(data, nus) @@ -187,14 +216,16 @@ def draw(self, renderer): return self._wrapped_instance.draw(renderer) def _update_wrapped(self, data): - self._wrapped_instance.set_data(data["x"], data["y"]) + for k, v in data.items(): + k = {"x": "xdata", "y": "ydata"}.get(k, k) + getattr(self._wrapped_instance, f"set_{k}")(v) class ImageWrapper(ProxyWrapper): _wrapped_class = _AxesImage + required_keys = {"xextent", "yextent", "image"} def __init__(self, data: DataContainer, nus=None, /, cmap=None, norm=None, **kwargs): - print(kwargs, nus) nus = dict(nus or {}) if cmap is not None or norm is not None: if nus is not None and "image" in nus: @@ -223,6 +254,7 @@ def _update_wrapped(self, data): class StepWrapper(ProxyWrapper): _wrapped_class = _StepPatch _privtized_methods = () # ("set_data", "get_data") + required_keys = {"edges", "density"} def __init__(self, data: DataContainer, nus=None, /, **kwargs): super().__init__(data, nus) @@ -243,10 +275,9 @@ class FormatedText(ProxyWrapper): _wrapped_class = _Text _privtized_methods = ("set_text",) - def __init__(self, data: DataContainer, format_func, nus=None, /, **kwargs): + def __init__(self, data: DataContainer, nus=None, /, **kwargs): super().__init__(data, nus) self._wrapped_instance = self._wrapped_class(text="", **kwargs) - self._format_func = format_func @_stale_wrapper def draw(self, renderer): @@ -256,7 +287,8 @@ def draw(self, renderer): return self._wrapped_instance.draw(renderer) def _update_wrapped(self, data): - self._wrapped_instance.set_text(self._format_func(**data)) + for k, v in data.items(): + getattr(self._wrapped_instance, f"set_{k}")(v) @_forwarder( @@ -296,6 +328,9 @@ def get_children(self): class ErrorbarWrapper(MultiProxyWrapper): + required_keys = {"x", "y"} + expected_keys = {f"{axis}{dirc}" for axis in ["x", "y"] for dirc in ["upper", "lower"]} + def __init__(self, data: DataContainer, nus=None, /, **kwargs): super().__init__(data, nus) # TODO all of the kwarg teasing apart that is needed diff --git a/docs/source/_static/logo2.svg b/docs/source/_static/logo2.svg new file mode 100644 index 0000000..f2d289c --- /dev/null +++ b/docs/source/_static/logo2.svg @@ -0,0 +1,552 @@ + + + + + + + + image/svg+xml + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/source/conf.py b/docs/source/conf.py index 38ee482..03e060d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -99,6 +99,7 @@ def matplotlib_reduced_latex_scraper(block, block_vars, gallery_conf, **kwargs): "matplotlib_animations": True, "image_srcset": ["2x"], "junit": "../test-results/sphinx-gallery/junit.xml" if CIRCLECI else "", + "run_stale_examples": True, } mathmpl_fontsize = 11.0 @@ -163,8 +164,10 @@ def matplotlib_reduced_latex_scraper(block, block_vars, gallery_conf, **kwargs): # further. For a list of options available for each theme, see the # documentation. # - -html_theme_options = {"logo": {}} +html_logo = "_static/logo2.svg" +html_theme_options = { + "logo": {"link": "index", "image_light": "images/logo2.svg", "image_dark": "images/logo_dark.svg"}, +} # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, diff --git a/examples/animation.py b/examples/animation.py index 4468eb8..e92c176 100644 --- a/examples/animation.py +++ b/examples/animation.py @@ -61,7 +61,7 @@ def update(frame, art): lw = LineWrapper(sot_c, lw=5, color="green", label="sin(time)") fc = FormatedText( sot_c, - "ϕ={phase:.2f} ".format, + {"text": lambda phase: f"ϕ={phase:.2f}"}, x=2 * np.pi, y=1, ha="right", diff --git a/examples/mapped.py b/examples/mapped.py new file mode 100644 index 0000000..1da081c --- /dev/null +++ b/examples/mapped.py @@ -0,0 +1,71 @@ +""" +======================= +Mapping Line Properties +======================= + +Leveraging the nu functions to transform users space data to visualization data. + +""" + +import matplotlib.pyplot as plt +import numpy as np + +from matplotlib.colors import Normalize + +from data_prototype.wrappers import LineWrapper, FormatedText +from data_prototype.containers import ArrayContainer + +cmap = plt.colormaps["viridis"] +cmap.set_over("k") +cmap.set_under("r") +norm = Normalize(1, 8) + +line_nus = { + # arbitrary functions + "lw": lambda lw: min(1 + lw, 5), + # standard color mapping + "color": lambda j: cmap(norm(j)), + # categorical + "ls": lambda cat: {"A": "-", "B": ":", "C": "--"}[cat[()]], +} + +text_nus = { + "text": lambda j, cat: f"index={j[()]} class={cat[()]!r}", + "y": lambda j: j, +} + + +th = np.linspace(0, 2 * np.pi, 128) +delta = np.pi / 9 + +fig, ax = plt.subplots() + +for j in range(10): + ac = ArrayContainer( + **{ + "x": th, + "y": np.sin(th + j * delta) + j, + "j": np.asarray(j), + "lw": np.asarray(j), + "cat": np.asarray({0: "A", 1: "B", 2: "C"}[j % 3]), + } + ) + ax.add_artist( + LineWrapper( + ac, + line_nus, + ) + ) + ax.add_artist( + FormatedText( + ac, + text_nus, + x=2 * np.pi, + ha="right", + bbox={"facecolor": "gray", "alpha": 0.5}, + ) + ) +ax.set_xlim(0, np.pi * 2) +ax.set_ylim(-1.1, 10.1) + +plt.show()