Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Conversion Node implementation of 'nu' #31

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions data_prototype/conversion_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from __future__ import annotations

from collections.abc import Iterable, Callable, Sequence
from collections import Counter
from dataclasses import dataclass
import inspect
from functools import cached_property

from typing import Any


def evaluate_pipeline(nodes: Sequence[ConversionNode], input: dict[str, Any]):
    """Run *input* through each node in *nodes*, feeding each result onward.

    Returns the dict produced by the last node, or *input* unchanged when
    *nodes* is empty.
    """
    current = input
    for node in nodes:
        current = node.evaluate(current)
    return current


@dataclass
class ConversionNode:
    """Base unit of a key-to-key conversion pipeline.

    A node consumes a ``dict[str, Any]`` and produces another. Attributes:

    required_keys : keys that must be present in the input.
    output_keys   : keys this node guarantees in its output.
    trim_keys     : when True, the output contains *only* ``output_keys``;
                    otherwise all input keys are passed through as well.
    """

    required_keys: tuple[str, ...]
    output_keys: tuple[str, ...]
    trim_keys: bool

    def preview_keys(self, input_keys: Iterable[str]) -> tuple[str, ...]:
        """Return the sorted keys this node would output for *input_keys*.

        Raises ``ValueError`` if any of ``required_keys`` is absent.
        """
        if missing_keys := set(self.required_keys) - set(input_keys):
            raise ValueError(f"Missing keys: {missing_keys}")
        if self.trim_keys:
            return tuple(sorted(set(self.output_keys)))
        return tuple(sorted(set(input_keys) | set(self.output_keys)))

    def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
        """Validate *input* and return it, trimmed to ``output_keys`` if set.

        Raises ``ValueError`` when an output key is missing. (Previously the
        ``trim_keys`` branch raised a bare ``KeyError`` here; the check is now
        performed up front so both branches fail consistently.)
        """
        if missing_keys := set(self.output_keys) - set(input):
            raise ValueError(f"Missing keys: {missing_keys}")
        if self.trim_keys:
            return {k: input[k] for k in self.output_keys}
        return input


@dataclass
class UnionConversionNode(ConversionNode):
    """Evaluate several child nodes and merge their outputs into one dict."""

    nodes: tuple[ConversionNode, ...]

    @classmethod
    def from_nodes(cls, *nodes: ConversionNode, trim_keys=False):
        """Build a union of *nodes*.

        Raises ``ValueError`` if two children declare the same output key,
        since the merge order would silently decide which value wins.
        """
        required = {key for node in nodes for key in node.required_keys}
        produced = Counter(key for node in nodes for key in node.output_keys)
        clashes = {key for key, count in produced.items() if count > 1}
        if clashes:
            raise ValueError(f"Duplicate keys from multiple input nodes: {clashes}")
        return cls(tuple(required), tuple(produced), trim_keys, nodes)

    def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
        merged: dict[str, Any] = {}
        for node in self.nodes:
            merged.update(node.evaluate(input))
        return super().evaluate(merged)


@dataclass
class RenameConversionNode(ConversionNode):
    """Copy values under new key names, per *mapping* (source -> target)."""

    mapping: dict[str, str]

    @classmethod
    def from_mapping(cls, mapping: dict[str, str], trim_keys=False):
        """Build a renamer requiring every source key of *mapping*.

        Raises ``ValueError`` when two sources map onto one target name.
        """
        counts = Counter(mapping.values())
        clashes = {name for name, seen in counts.items() if seen > 1}
        if clashes:
            raise ValueError(f"Duplicate output keys in mapping: {clashes}")
        return cls(tuple(mapping), tuple(counts), trim_keys, mapping)

    def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
        renamed = dict(input)
        # Read from the original input so renames never see each other.
        for source, target in self.mapping.items():
            renamed[target] = input[source]
        return super().evaluate(renamed)


@dataclass
class FunctionConversionNode(ConversionNode):
    """Compute output keys by calling functions whose parameter names
    are looked up as input keys."""

    funcs: dict[str, Callable]

    @cached_property
    def _sigs(self):
        # Signatures are introspected once and reused on every evaluate();
        # they determine which input keys each function receives.
        return {name: (fn, inspect.signature(fn)) for name, fn in self.funcs.items()}

    @classmethod
    def from_funcs(cls, funcs: dict[str, Callable], trim_keys=False):
        """Infer required keys from the functions' parameter names; the
        output keys are the dict's keys."""
        needed: set[str] = set()
        for fn in funcs.values():
            needed.update(inspect.signature(fn).parameters)
        return cls(tuple(needed), tuple(funcs), trim_keys, funcs)

    def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
        result = dict(input)
        for name, (fn, sig) in self._sigs.items():
            # Each function is fed from the original input, not from
            # earlier results, so functions cannot observe each other.
            result[name] = fn(**{p: input[p] for p in sig.parameters})
        return super().evaluate(result)


@dataclass
class LimitKeysConversionNode(ConversionNode):
    """Drop every input key not listed in *keys*."""

    keys: set[str]

    @classmethod
    def from_keys(cls, keys: Sequence[str]):
        """Build a limiter that passes through only *keys*."""
        return cls((), tuple(keys), trim_keys=True, keys=set(keys))

    def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
        # Deliberately bypasses the base-class check: keys absent from the
        # input are simply skipped rather than treated as an error.
        allowed = self.keys
        return {name: value for name, value in input.items() if name in allowed}
4 changes: 2 additions & 2 deletions data_prototype/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ class PatchWrapper(ProxyWrapper):
"joinstyle",
}

def __init__(self, data: DataContainer, nus=None, /, **kwargs):
super().__init__(data, nus)
def __init__(self, data: DataContainer, converters=None, /, **kwargs):
super().__init__(data, converters)
self._wrapped_instance = self._wrapped_class([0, 0], 0, 0, **kwargs)

@_stale_wrapper
Expand Down
1 change: 0 additions & 1 deletion data_prototype/tests/test_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ def ac():


def _verify_describe(container):

data, cache_key = container.query(IdentityTransform(), [100, 100])
desc = container.describe()

Expand Down
82 changes: 36 additions & 46 deletions data_prototype/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
from matplotlib.artist import Artist as _Artist

from data_prototype.containers import DataContainer, _MatplotlibTransform
from data_prototype.conversion_node import (
ConversionNode,
RenameConversionNode,
evaluate_pipeline,
FunctionConversionNode,
LimitKeysConversionNode,
)


class _BBox(Protocol):
Expand Down Expand Up @@ -139,45 +146,26 @@ def _query_and_transform(self, renderer, *, xunits: List[str], yunits: List[str]
return self._cache[cache_key]
except KeyError:
...
# TODO decide if units go pre-nu or post-nu?
for x_like in xunits:
if x_like in data:
data[x_like] = ax.xaxis.convert_units(data[x_like])
for y_like in yunits:
if y_like in data:
data[y_like] = ax.xaxis.convert_units(data[y_like])

# doing the nu work here is nice because we can write it once, but we
# really want to push this computation down a layer
# TODO sort out how this interoperates with the transform stack
transformed_data = {}
for k, (nu, sig) in self._sigs.items():
to_pass = set(sig.parameters)
transformed_data[k] = nu(**{k: data[k] for k in to_pass})
# TODO units
transformed_data = evaluate_pipeline(self._converters, data)

self._cache[cache_key] = transformed_data
return transformed_data

def __init__(self, data, nus, **kwargs):
def __init__(self, data, converters: ConversionNode | list[ConversionNode] | None, **kwargs):
super().__init__(**kwargs)
self.data = data
self._cache = LFUCache(64)
# TODO make sure mutating this will invalidate the cache!
self._nus = nus or {}
for k in self.required_keys:
self._nus.setdefault(k, _make_identity(k))
desc = data.describe()
for k in self.expected_keys:
if k in desc:
self._nus.setdefault(k, _make_identity(k))
self._sigs = {k: (nu, inspect.signature(nu)) for k, nu in self._nus.items()}
if isinstance(converters, ConversionNode):
converters = [converters]
self._converters: list[ConversionNode] = converters or []
setters = list(self.expected_keys | self.required_keys)
if hasattr(self, "_wrapped_class"):
setters += [f[4:] for f in dir(self._wrapped_class) if f.startswith("set_")]
self._converters.append(LimitKeysConversionNode.from_keys(setters))
self.stale = True

# TODO add a setter
@property
def nus(self):
return dict(self._nus)


class ProxyWrapper(ProxyWrapperBase):
_privtized_methods: Tuple[str, ...] = ()
Expand All @@ -192,7 +180,7 @@ def __getattr__(self, key):
return getattr(self._wrapped_instance, key)

def __setattr__(self, key, value):
if key in ("_wrapped_instance", "data", "_cache", "_nus", "stale", "_sigs"):
if key in ("_wrapped_instance", "data", "_cache", "_converters", "stale", "_sigs"):
super().__setattr__(key, value)
elif hasattr(self, "_wrapped_instance") and hasattr(self._wrapped_instance, key):
setattr(self._wrapped_instance, key, value)
Expand All @@ -205,9 +193,12 @@ class LineWrapper(ProxyWrapper):
_privtized_methods = ("set_xdata", "set_ydata", "set_data", "get_xdata", "get_ydata", "get_data")
required_keys = {"x", "y"}

def __init__(self, data: DataContainer, nus=None, /, **kwargs):
super().__init__(data, nus)
def __init__(self, data: DataContainer, converters=None, /, **kwargs):
super().__init__(data, converters)
self._wrapped_instance = self._wrapped_class(np.array([]), np.array([]), **kwargs)
self._converters.insert(-1, RenameConversionNode.from_mapping({"x": "xdata", "y": "ydata"}))
setters = [f[4:] for f in dir(self._wrapped_class) if f.startswith("set_")]
self._converters[-1] = LimitKeysConversionNode.from_keys(setters)

@_stale_wrapper
def draw(self, renderer):
Expand All @@ -218,7 +209,6 @@ def draw(self, renderer):

def _update_wrapped(self, data):
for k, v in data.items():
k = {"x": "xdata", "y": "ydata"}.get(k, k)
getattr(self._wrapped_instance, f"set_{k}")(v)


Expand All @@ -238,8 +228,8 @@ class PathCollectionWrapper(ProxyWrapper):
"get_paths",
)

def __init__(self, data: DataContainer, nus=None, /, **kwargs):
super().__init__(data, nus)
def __init__(self, data: DataContainer, converters=None, /, **kwargs):
super().__init__(data, converters)
self._wrapped_instance = self._wrapped_class([], **kwargs)
self._wrapped_instance.set_transform(mtransforms.IdentityTransform())

Expand All @@ -262,17 +252,17 @@ class ImageWrapper(ProxyWrapper):
_wrapped_class = _AxesImage
required_keys = {"xextent", "yextent", "image"}

def __init__(self, data: DataContainer, nus=None, /, cmap=None, norm=None, **kwargs):
nus = dict(nus or {})
def __init__(self, data: DataContainer, converters=None, /, cmap=None, norm=None, **kwargs):
converters = converters or []
if cmap is not None or norm is not None:
if nus is not None and "image" in nus:
if converters is not None and "image" in converters:
raise ValueError("Conflicting input")
if cmap is None:
cmap = mpl.colormaps["viridis"]
if norm is None:
raise ValueError("not sure how to do autoscaling yet")
nus["image"] = lambda image: cmap(norm(image))
super().__init__(data, nus)
converters.append(FunctionConversionNode.from_funcs({"image": lambda image: cmap(norm(image))}))
super().__init__(data, converters)
kwargs.setdefault("origin", "lower")
self._wrapped_instance = self._wrapped_class(None, **kwargs)

Expand All @@ -293,8 +283,8 @@ class StepWrapper(ProxyWrapper):
_privtized_methods = () # ("set_data", "get_data")
required_keys = {"edges", "density"}

def __init__(self, data: DataContainer, nus=None, /, **kwargs):
super().__init__(data, nus)
def __init__(self, data: DataContainer, converters=None, /, **kwargs):
super().__init__(data, converters)
self._wrapped_instance = self._wrapped_class([], [1], **kwargs)

@_stale_wrapper
Expand All @@ -312,8 +302,8 @@ class FormatedText(ProxyWrapper):
_wrapped_class = _Text
_privtized_methods = ("set_text",)

def __init__(self, data: DataContainer, nus=None, /, **kwargs):
super().__init__(data, nus)
def __init__(self, data: DataContainer, converters=None, /, **kwargs):
super().__init__(data, converters)
self._wrapped_instance = self._wrapped_class(text="", **kwargs)

@_stale_wrapper
Expand Down Expand Up @@ -368,8 +358,8 @@ class ErrorbarWrapper(MultiProxyWrapper):
required_keys = {"x", "y"}
expected_keys = {f"{axis}{dirc}" for axis in ["x", "y"] for dirc in ["upper", "lower"]}

def __init__(self, data: DataContainer, nus=None, /, **kwargs):
super().__init__(data, nus)
def __init__(self, data: DataContainer, converters=None, /, **kwargs):
super().__init__(data, converters)
# TODO all of the kwarg teasing apart that is needed
color = kwargs.pop("color", "k")
lw = kwargs.pop("lw", 2)
Expand Down
6 changes: 2 additions & 4 deletions examples/2Dfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from data_prototype.wrappers import ImageWrapper
from data_prototype.containers import FuncContainer

import matplotlib as mpl
from matplotlib.colors import Normalize


Expand All @@ -25,9 +24,8 @@
"image": (("N", "M"), lambda x, y: np.sin(x).reshape(1, -1) * np.cos(y).reshape(-1, 1)),
},
)
cmap = mpl.colormaps["viridis"]
norm = Normalize(-1, 1)
im = ImageWrapper(fc, {"image": lambda image: cmap(norm(image))})
norm = Normalize(vmin=-1, vmax=1)
im = ImageWrapper(fc, norm=norm)

fig, ax = plt.subplots()
ax.add_artist(im)
Expand Down
7 changes: 4 additions & 3 deletions examples/animation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from matplotlib.animation import FuncAnimation

from data_prototype.containers import _MatplotlibTransform, Desc
from data_prototype.conversion_node import FunctionConversionNode

from data_prototype.wrappers import LineWrapper, FormatedText

Expand Down Expand Up @@ -63,9 +64,9 @@ def update(frame, art):
lw = LineWrapper(sot_c, lw=5, color="green", label="sin(time)")
fc = FormatedText(
sot_c,
{"text": lambda phase: f"ϕ={phase:.2f}"},
x=2 * np.pi,
y=1,
FunctionConversionNode.from_funcs(
{"text": lambda phase: f"ϕ={phase:.2f}", "x": lambda: 2 * np.pi, "y": lambda: 1}
),
ha="right",
)
fig, ax = plt.subplots()
Expand Down
Loading