diff --git a/ci/codespell-ignore-words.txt b/ci/codespell-ignore-words.txt index e138f26e216a..8e5163842c51 100644 --- a/ci/codespell-ignore-words.txt +++ b/ci/codespell-ignore-words.txt @@ -1,5 +1,6 @@ aas ABD +aother axises coo curvelinear diff --git a/lib/matplotlib/_data_containers/__init__.py b/lib/matplotlib/_data_containers/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/lib/matplotlib/_data_containers/_helpers.py b/lib/matplotlib/_data_containers/_helpers.py new file mode 100644 index 000000000000..72c0f18deb3b --- /dev/null +++ b/lib/matplotlib/_data_containers/_helpers.py @@ -0,0 +1,49 @@ +from .description import Desc, desc_like +from .conversion_edge import Graph, TransformEdge + + +def containerize_draw(draw_func): + def draw(self, renderer, *, graph=None): + if graph is None: + graph = Graph([]) + + implicit_graph = _get_graph(self.axes) + return draw_func(self, renderer, graph=graph+implicit_graph) + + return draw + + +def _get_graph(ax): + if ax is None: + return Graph([]) + desc: Desc = Desc(("N",), coordinates="data") + xy: dict[str, Desc] = {"x": desc, "y": desc} + implicit_graph = Graph( + [ + TransformEdge( + "data", + xy, + desc_like(xy, coordinates="axes"), + transform=ax.transData - ax.transAxes, + ), + TransformEdge( + "axes", + desc_like(xy, coordinates="axes"), + desc_like(xy, coordinates="display"), + transform=ax.transAxes, + ), + TransformEdge( + "dpi", + desc_like(xy, coordinates="display_inches"), + desc_like(xy, coordinates="display"), + transform=ax.figure.dpi_scale_trans, + ), + ], + aliases=(("parent", "axes"),), + ) + return implicit_graph + + +def check_container(artist, container_cls, operation="This operation"): + if not isinstance(artist._container, container_cls): + raise TypeError(f"{operation} is not available with a custom container class") diff --git a/lib/matplotlib/_data_containers/containers.py b/lib/matplotlib/_data_containers/containers.py new file mode 100644 index 000000000000..cd487a4b2c59 --- /dev/null +++ b/lib/matplotlib/_data_containers/containers.py @@ -0,0 +1,416 @@ +from __future__ import annotations + +from typing import ( + Protocol, + Optional, + Any, + Union, +) +from collections.abc import Callable, MutableMapping +import uuid + +from cachetools import LFUCache # type: ignore[import-untyped] + +import numpy as np +import pandas as pd + +from .description import Desc, desc_like + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .conversion_edge import Graph + + +class _MatplotlibTransform(Protocol): + def transform(self, verts): ... + + def __sub__(self, other) -> "_MatplotlibTransform": ... + + +class DataContainer(Protocol): + def query( + self, + graph: Graph, + parent_coordinates: str = "axes", + /, + ) -> tuple[dict[str, Any], Union[str, int]]: + """ + Query the data container for data. + + We are given the data limits and the screen size so that we have an + estimate of how finely (or not) we need to sample the data we wrapping. + + Parameters + ---------- + coord_transform : matplotlib.transform.Transform + Must go from axes fraction space -> data space + + size : 2 integers + xpixels, ypixels + + The size in screen / render units that we have to fill. + + Returns + ------- + data : dict[str, Any] + The values are really array-likes + + cache_key : str + This is a key that clients can use to cache down-stream + computations on this data. + """ + ... + + def describe(self) -> dict[str, Desc]: + """ + Describe the data a query will return + + Returns + ------- + dict[str, Desc] + """ + ... + + +class NoNewKeys(ValueError): ... + + +class ArrayContainer: + def __init__(self, coordinates: dict[str, str] | None = None, /, **data): + coordinates = coordinates or {} + self._data = data + self._cache_key = str(uuid.uuid4()) + self._desc = { + k: ( + Desc(v.shape, coordinates.get(k, "auto")) + if hasattr(v, "shape") + else Desc((), coordinates.get(k, "auto")) + ) + for k, v in data.items() + } + + def query( + self, + graph: Graph, + parent_coordinates: str = "axes", + ) -> tuple[dict[str, Any], Union[str, int]]: + return dict(self._data), self._cache_key + + def describe(self) -> dict[str, Desc]: + return dict(self._desc) + + def update(self, **data): + # TODO check that this is still consistent with desc! + if not all(k in self._data for k in data): + raise NoNewKeys( + f"The keys that currently exist are {set(self._data)}. You " + f"tried to add {set(data) - set(self._data)!r}." + ) + self._data.update(data) + self._cache_key = str(uuid.uuid4()) + + +class RandomContainer: + def __init__(self, **shapes): + self._desc = {k: Desc(s) for k, s in shapes.items()} + + def query( + self, + graph: Graph, + parent_coordinates: str = "axes", + ) -> tuple[dict[str, Any], Union[str, int]]: + return {k: np.random.randn(*d.shape) for k, d in self._desc.items()}, str( + uuid.uuid4() + ) + + def describe(self) -> dict[str, Desc]: + return dict(self._desc) + + +class FuncContainer: + def __init__( + self, + # TODO: is this really the best spelling?! + xfuncs: Optional[ + dict[str, tuple[tuple[Union[str, int], ...], Callable[[Any], Any]]] + ] = None, + yfuncs: Optional[ + dict[str, tuple[tuple[Union[str, int], ...], Callable[[Any], Any]]] + ] = None, + xyfuncs: Optional[ + dict[str, tuple[tuple[Union[str, int], ...], Callable[[Any, Any], Any]]] + ] = None, + ): + """ + A container that wraps several functions. They are split into 3 categories: + + - functions that are offered x-like values as input + - functions that are offered y-like values as input + - functions that are offered both x and y like values as two inputs + + In addition to the callable, the user needs to provide a spelling of + what the (relative) shapes will be in relation to each other. For now this + is a list of integers and strings, where the strings are "generic" values. + + For example if two functions report shapes: ``{'bins':[N], 'edges': [N + 1]`` + then when called, *edges* will always have one more entry than bins. + + Parameters + ---------- + xfuncs, yfuncs, xyfuncs : dict[str, tuple[shape, func]] + + """ + # TODO validate no collisions + self._desc: dict[str, Desc] = {} + + def _split(input_dict): + out = {} + for k, (shape, func) in input_dict.items(): + self._desc[k] = Desc(shape) + out[k] = func + return out + + self._xfuncs = _split(xfuncs) if xfuncs is not None else {} + self._yfuncs = _split(yfuncs) if yfuncs is not None else {} + self._xyfuncs = _split(xyfuncs) if xyfuncs is not None else {} + self._cache: MutableMapping[Union[str, int], Any] = LFUCache(64) + + def _query_hash(self, coord_transform, size): + # TODO find a better way to compute the hash key, this is not sentative to + # scale changes, only limit changes + data_bounds = tuple(coord_transform.transform([[0, 0], [1, 1]]).flatten()) + hash_key = hash((data_bounds, size)) + return hash_key + + def query( + self, + graph: Graph, + parent_coordinates: str = "axes", + ) -> tuple[dict[str, Any], Union[str, int]]: + # hash_key = self._query_hash(coord_transform, size) + # if hash_key in self._cache: + # return self._cache[hash_key], hash_key + + desc = Desc(("N",)) + xy = {"x": desc, "y": desc} + data_lim = graph.evaluator( + desc_like(xy, coordinates="data"), + desc_like(xy, coordinates=parent_coordinates), + ).inverse + + screen_size = graph.evaluator( + desc_like(xy, coordinates=parent_coordinates), + desc_like(xy, coordinates="display"), + ) + + screen_dims = screen_size.evaluate({"x": [0, 1], "y": [0, 1]}) + xpix, ypix = np.ceil(np.abs(np.diff(screen_dims["x"]))), np.ceil( + np.abs(np.diff(screen_dims["y"])) + ) + + x_data = data_lim.evaluate( + { + "x": np.linspace(0, 1, int(xpix) * 2), + "y": np.zeros(int(xpix) * 2), + } + )["x"] + y_data = data_lim.evaluate( + { + "x": np.zeros(int(ypix) * 2), + "y": np.linspace(0, 1, int(ypix) * 2), + } + )["y"] + + hash_key = str(uuid.uuid4()) + ret = self._cache[hash_key] = dict( + **{k: f(x_data) for k, f in self._xfuncs.items()}, + **{k: f(y_data) for k, f in self._yfuncs.items()}, + **{k: f(x_data, y_data) for k, f in self._xyfuncs.items()}, + ) + return ret, hash_key + + def describe(self) -> dict[str, Desc]: + return dict(self._desc) + + +class HistContainer: + def __init__(self, raw_data, num_bins: int): + self._raw_data = raw_data + self._num_bins = num_bins + self._desc = { + "edges": Desc((num_bins + 1 + 2,)), + "density": Desc((num_bins + 2,)), + } + self._full_range = (raw_data.min(), raw_data.max()) + self._cache: MutableMapping[Union[str, int], Any] = LFUCache(64) + + def query( + self, + graph: Graph, + parent_coordinates: str = "axes", + ) -> tuple[dict[str, Any], Union[str, int]]: + dmin, dmax = self._full_range + + desc = Desc(("N",)) + xy = {"x": desc, "y": desc} + data_lim = graph.evaluator( + desc_like(xy, coordinates="data"), + desc_like(xy, coordinates=parent_coordinates), + ).inverse + + pts = data_lim.evaluate({"x": (0, 1), "y": (0, 1)}) + xmin, xmax = pts["x"] + ymin, ymax = pts["y"] + + xmin, xmax = np.clip([xmin, xmax], dmin, dmax) + hash_key = hash((xmin, xmax)) + if hash_key in self._cache: + return self._cache[hash_key], hash_key + # TODO this gives an artifact with high lw + edges_in = [] + if dmin < xmin: + edges_in.append(np.array([dmin])) + edges_in.append(np.linspace(xmin, xmax, self._num_bins)) + if xmax < dmax: + edges_in.append(np.array([dmax])) + + density, edges = np.histogram( + self._raw_data, + bins=np.concatenate(edges_in), + density=True, + ) + ret = self._cache[hash_key] = {"edges": edges, "density": density} + return ret, hash_key + + def describe(self) -> dict[str, Desc]: + return dict(self._desc) + + +class SeriesContainer: + _data: pd.Series + _index_name: str + _hash_key: str + + def __init__(self, series: pd.Series, *, index_name: str, col_name: str): + # TODO make a copy? + self._data = series + self._index_name = index_name + self._col_name = col_name + self._desc = { + index_name: Desc((len(series),)), + col_name: Desc((len(series),)), + } + self._hash_key = str(uuid.uuid4()) + + def query( + self, + graph: Graph, + parent_coordinates: str = "axes", + ) -> tuple[dict[str, Any], Union[str, int]]: + return { + self._index_name: self._data.index.values, + self._col_name: self._data.values, + }, self._hash_key + + def describe(self) -> dict[str, Desc]: + return dict(self._desc) + + +class DataFrameContainer: + _data: pd.DataFrame + + def __init__( + self, + df: pd.DataFrame, + *, + col_names: Union[Callable[[str], str], dict[str, str]], + index_name: Optional[str] = None, + ): + # TODO make a copy? + self._data = df + self._index_name = index_name + + if callable(col_names): + # TODO cache the function so we can replace the dataframe later? + self._col_name_dict = {k: col_names(k) for k in df.columns} + else: + self._col_name_dict = dict(col_names) + + self._desc: dict[str, Desc] = {} + if self._index_name is not None: + self._desc[self._index_name] = Desc((len(df),)) + for col, out in self._col_name_dict.items(): + self._desc[out] = Desc((len(df),)) + + self._hash_key = str(uuid.uuid4()) + + def query( + self, + graph: Graph, + parent_coordinates: str = "axes", + ) -> tuple[dict[str, Any], Union[str, int]]: + ret: dict[str, Any] = {} + if self._index_name is not None: + ret[self._index_name] = self._data.index.values + for col, out in self._col_name_dict.items(): + ret[out] = self._data[col].values + + return ret, self._hash_key + + def describe(self) -> dict[str, Desc]: + return dict(self._desc) + + +class ReNamer: + def __init__(self, data: DataContainer, mapping: dict[str, str]): + # TODO: check all the asked for key exist + self._data = data + self._mapping = mapping + + def query( + self, + graph: Graph, + parent_coordinates: str = "axes", + ) -> tuple[dict[str, Any], Union[str, int]]: + base, cache_key = self._data.query(graph, parent_coordinates) + return {v: base[k] for k, v in self._mapping.items()}, cache_key + + def describe(self): + base = self._data.describe() + return {v: base[k] for k, v in self._mapping.items()} + + +class DataUnion: + def __init__(self, *data: DataContainer): + # TODO check no collisions + self._datas = data + + def query( + self, + graph: Graph, + parent_coordinates: str = "axes", + ) -> tuple[dict[str, Any], Union[str, int]]: + cache_keys = [] + ret = {} + for data in self._datas: + base, cache_key = data.query(graph, parent_coordinates) + ret.update(base) + cache_keys.append(cache_key) + return ret, hash(tuple(cache_keys)) + + def describe(self): + return {k: v for d in self._datas for k, v in d.describe().items()} + + +class WebServiceContainer: + def query( + self, + graph: Graph, + parent_coordinates: str = "axes", + ) -> tuple[dict[str, Any], Union[str, int]]: + def hit_some_database(): + return {}, "1" + + data, etag = hit_some_database() + return data, etag diff --git a/lib/matplotlib/_data_containers/conversion_edge.py b/lib/matplotlib/_data_containers/conversion_edge.py new file mode 100644 index 000000000000..8002d9879b9a --- /dev/null +++ b/lib/matplotlib/_data_containers/conversion_edge.py @@ -0,0 +1,401 @@ +from __future__ import annotations + +from collections.abc import Sequence +from collections.abc import Callable +from dataclasses import dataclass +from queue import PriorityQueue +from typing import Any +import numpy as np + +from .description import Desc, desc_like, ShapeSpec + +from matplotlib.transforms import Transform + + +@dataclass +class Edge: + name: str + input: dict[str, Desc] + output: dict[str, Desc] + weight: float = 1 + invertable: bool = True + + def evaluate(self, input: dict[str, Any]) -> dict[str, Any]: + return input + + @property + def inverse(self) -> "Edge": + return Edge(self.name + "_r", self.output, self.input, self.weight) + + +@dataclass +class SequenceEdge(Edge): + edges: Sequence[Edge] = () + + @classmethod + def from_edges( + cls, + name: str, + edges: Sequence[Edge], + output: dict[str, Desc], + weight: float | None = None, + ): + input: dict[str, Desc] = {} + intermediates: dict[str, Desc] = {} + invertable = True + edge_sum: float = 0 + for edge in edges: + edge_sum += edge.weight + input |= {k: v for k, v in edge.input.items() if k not in intermediates} + intermediates |= edge.output + if not edge.invertable: + invertable = False + + if weight is None: + weight = edge_sum + + return cls(name, input, output, weight, invertable, edges) + + def evaluate(self, input: dict[str, Any]) -> dict[str, Any]: + for edge in self.edges: + input |= edge.evaluate({k: input[k] for k in edge.input}) + return {k: input[k] for k in self.output} + + @property + def inverse(self) -> "SequenceEdge": + return SequenceEdge.from_edges( + self.name + "_r", + [e.inverse for e in self.edges[::-1]], + self.input, + self.weight, + ) + + +@dataclass +class CoordinateEdge(Edge): + """Change coordinates without changing values""" + + @classmethod + def from_coords( + cls, name: str, input: dict[str, Desc | str], output: str, weight: float = 1 + ): + # dtype/shape is reductive here, but I like the idea of being able to just + # supply only the input/output coordinates for many things + # could also see lowering default weight for this edge, but just defaulting + # everything to 1 for now + inp = { + k: v if isinstance(v, Desc) else Desc(("N",), v) for k, v in input.items() + } + outp = {k: desc_like(v, coordinates=output) for k, v in inp.items()} + + return cls(name, inp, outp, weight) + + @property + def inverse(self) -> Edge: + return Edge(f"{self.name}_r", self.output, self.input, self.weight) + + +@dataclass +class DefaultEdge(Edge): + """Provide default values with a high weight""" + + weight = 1e6 + value: Any = None + + @classmethod + def from_default_value( + cls, + name: str, + key: str, + output: Desc, + value: Any, + weight=1e6, + ) -> "DefaultEdge": + return cls(name, {}, {key: output}, weight, invertable=False, value=value) + + @classmethod + def from_rc( + cls, rc_name: str, key: str | None = None, coordinates: str = "display" + ): + from matplotlib import rcParams + + if key is None: + key = rc_name.split(".")[-1] + scalar = Desc((), coordinates) + return cls.from_default_value(f"{rc_name}_rc", key, scalar, rcParams[rc_name]) + + def evaluate(self, input: dict[str, Any]) -> dict[str, Any]: + return {k: self.value for k in self.output} + + +@dataclass +class FuncEdge(Edge): + # TODO: more explicit callable boundaries? + func: Callable = lambda: {} + inverse_func: Callable | None = None + + @classmethod + def from_func( + cls, + name: str, + func: Callable, + input: str | dict[str, Desc], + output: str | dict[str, Desc], + weight: float = 1, + inverse: Callable | None = None, + ): + # dtype/shape is reductive here, but I like the idea of being able to just + # supply a function and the input/output coordinates for many things + if isinstance(input, str): + import inspect + + input_vars = inspect.signature(func).parameters.keys() + input = {k: Desc(("N",), input) for k in input_vars} + if isinstance(output, str): + output = {k: Desc(("N",), output) for k in input.keys()} + + return cls(name, input, output, weight, inverse is not None, func, inverse) + + def evaluate(self, input: dict[str, Any]) -> dict[str, Any]: + res = self.func(**{k: input[k] for k in self.input}) + + if isinstance(res, dict): + # TODO: more sanity checks here? + # How forgiving do we _really_ wish to be? + return res + elif isinstance(res, tuple): + if len(res) != len(self.output): + if len(self.output) == 1: + return {k: res for k in self.output} + raise RuntimeError( + f"Expected {len(self.output)} return values," + f"got {len(res)} in {self.name}" + ) + return {k: v for k, v in zip(self.output, res)} + elif len(self.output) == 1: + return {k: res for k in self.output} + raise RuntimeError("Output of function does not match expected output") + + @property + def inverse(self) -> "FuncEdge": + if self.inverse_func is None: + raise RuntimeError("Trying to invert a non-invertable edge") + + return FuncEdge.from_func( + self.name + "_r", + self.inverse_func, + self.output, + self.input, + self.weight, + self.func, + ) + + +@dataclass +class TransformEdge(Edge): + transform: Transform | Callable[[], Transform] | None = None + + # TODO: helper for common cases/validation? + + def evaluate(self, input: dict[str, Any]) -> dict[str, Any]: + # TODO: ensure ordering? + # Stacking and unstacking at every step seems inefficient, + # especially if initially given as stacked + if self.transform is None: + return input + elif isinstance(self.transform, Callable): # type: ignore[arg-type] + trf = self.transform() # type: ignore[operator] + else: + trf = self.transform + inp = np.stack([input[k] for k in self.input], axis=-1) + outp = trf.transform(inp) + return {k: v for k, v in zip(self.output, outp.T)} + + @property + def inverse(self) -> "TransformEdge": + if self.transform is None: + raise RuntimeError("Trying to invert a non-invertable edge") + + if isinstance(self.transform, Callable): # type: ignore[arg-type] + return TransformEdge( + self.name + "_r", + self.output, + self.input, + self.weight, + True, + lambda: self.transform().inverted(), # type: ignore[misc,operator] + ) + + return TransformEdge( + self.name + "_r", + self.output, + self.input, + self.weight, + True, + self.transform.inverted(), # type: ignore[union-attr] + ) + + +class Graph: + def __init__( + self, edges: Sequence[Edge], aliases: tuple[tuple[str, str], ...] = () + ): + self._edges = tuple(edges) + self._aliases = aliases + + self._subgraphs: list[tuple[set[str], list[Edge]]] = [] + for edge in self._edges: + keys = set(edge.input) | set(edge.output) + + overlapping = [] + + for n, (sub_keys, sub_edges) in enumerate(self._subgraphs): + if keys & sub_keys: + overlapping.append(n) + + if not overlapping: + self._subgraphs.append((keys, [edge])) + elif len(overlapping) == 1: + s = self._subgraphs[overlapping[0]][0] + s |= keys + self._subgraphs[overlapping[0]][1].append(edge) + else: + edges_combined = [edge] + for n in overlapping: + keys |= self._subgraphs[n][0] + edges_combined.extend(self._subgraphs[n][1]) + for n in overlapping[::-1]: + self._subgraphs.pop(n) + self._subgraphs.append((keys, edges_combined)) + + def _resolve_alias(self, coord: str) -> str: + while True: + for coa, cob in self._aliases: + if coord == coa: + coord = cob + break + else: + break + return coord + + def evaluator(self, input: dict[str, Desc], output: dict[str, Desc]) -> Edge: + out_edges = [] + + for sub_keys, sub_edges in self._subgraphs: + if not (sub_keys & set(output) or sub_keys & set(input)): + continue + + output_subset = {k: v for k, v in output.items() if k in sub_keys} + sub_edges = sorted(sub_edges, key=lambda x: x.weight) + + @dataclass + class Node: + weight: float + desc: dict[str, Desc] + prev_node: Node | None = None + edge: Edge | None = None + + def __le__(self, other): + return self.weight <= other.weight + + def __lt__(self, other): + return self.weight < other.weight + + def __ge__(self, other): + return self.weight >= other.weight + + def __gt__(self, other): + return self.weight > other.weight + + @property + def edges(self): + if self.prev_node is None: + return [self.edge] + return self.prev_node.edges + [self.edge] + + q: PriorityQueue[Node] = PriorityQueue() + q.put(Node(0, input)) + + best: Node = Node(np.inf, {}) + while not q.empty(): + n = q.get() + if n.weight > best.weight: + continue + if Desc.compatible(n.desc, output_subset, aliases=self._aliases): + if n.weight < best.weight: + best = n + continue + for e in sub_edges: + if e in n.edges: + continue + if Desc.compatible(n.desc, e.input, aliases=self._aliases): + d = n.desc | e.output + w = n.weight + e.weight + + q.put(Node(w, d, n, e)) + if np.isinf(best.weight): + raise NotImplementedError( + "This may be possible, but is not a simple case already considered" + ) + + edges: list[Edge] = [] + n = best + while n.prev_node is not None: + if n.edge is not None: + edges.insert(0, n.edge) + n = n.prev_node + if len(edges) == 0: + continue + elif len(edges) == 1: + out_edges.append(edges[0]) + else: + out_edges.append(SequenceEdge.from_edges("eval", edges, output_subset)) + + found_outputs = set(input) + for out in out_edges: + found_outputs |= set(out.output) + if missing := set(output) - found_outputs: + raise RuntimeError(f"Could not find path to resolve all outputs: {missing}") + + if len(out_edges) == 0: + return Edge("noop", input, output) + if len(out_edges) == 1: + return out_edges[0] + return SequenceEdge.from_edges("eval", out_edges, output) + + def __add__(self, other: Graph) -> Graph: + aself = {k: v for k, v in self._aliases} + aother = {k: v for k, v in other._aliases} + aliases = tuple((aself | aother).items()) + return Graph(self._edges + other._edges, aliases) + + def cache_key(self): + """A cache key representing the graph. + + Current implementation is a new UUID, that is to say uncachable. + """ + import uuid + + return str(uuid.uuid4()) + + +def coord_and_default( + key: str, + shape: ShapeSpec = (), + coordinates: str = "display", + default_value: Any = None, + default_rc: str | None = None, +): + if default_rc is not None: + if default_value is not None: + raise ValueError( + "Only one of 'default_value' and 'default_rc' may be specified" + ) + def_edge = DefaultEdge.from_rc(default_rc, key, coordinates) + else: + scalar = Desc((), coordinates) + def_edge = DefaultEdge.from_default_value( + f"{key}_def", key, scalar, default_value + ) + coord_edge = CoordinateEdge.from_coords(key, {key: Desc(shape)}, coordinates) + return coord_edge, def_edge diff --git a/lib/matplotlib/_data_containers/description.py b/lib/matplotlib/_data_containers/description.py new file mode 100644 index 000000000000..095faab7f6da --- /dev/null +++ b/lib/matplotlib/_data_containers/description.py @@ -0,0 +1,162 @@ +from dataclasses import dataclass +from typing import TypeAlias, Union, overload + + +ShapeSpec: TypeAlias = tuple[Union[str, int], ...] + + +@dataclass(frozen=True) +class Desc: + # TODO: sort out how to actually spell this. We need to know: + # - what the number of dimensions is (1d vs 2d vs ...) + # - is this a fixed size dimension (e.g. 2 for xextent) + # - is this a variable size depending on the query (e.g. N) + # - what is the relative size to the other variable values (N vs N+1) + # We are probably going to have to implement a DSL for this (😞) + shape: ShapeSpec + coordinates: str = "auto" + + @staticmethod + def validate_shapes( + specification: dict[str, ShapeSpec | "Desc"], + actual: dict[str, ShapeSpec | "Desc"], + *, + broadcast: bool = False, + ) -> None: + """Validate specified shape relationships against a provided set of shapes. + + Shapes provided are tuples of int | str. If a specification calls for an int, + the exact size is expected. + If it is a str, it must be a single capital letter optionally followed by ``+`` + or ``-`` an integer value. + The same letter used in the specification must represent the same value in all + appearances. The value may, however, be a variable (with an offset) in the + actual shapes (which does not need to have the same letter). + + Shapes may be provided as raw tuples or as ``Desc`` objects. + + Parameters + ---------- + specification: dict[str, ShapeSpec | "Desc"] + The desired shape relationships + actual: dict[str, ShapeSpec | "Desc"] + The shapes to test for compliance + + Keyword Parameters + ------------------ + broadcast: bool + Whether to allow broadcasted shapes to pass (i.e. actual shapes with a ``1`` + will not cause exceptions regardless of what the specified shape value is) + + Raises + ------ + KeyError: + If a required field from the specification is missing in the provided actual + values. + ValueError: + If shapes are incompatible in any other way + """ + specvars: dict[str, int | tuple[str, int]] = {} + for fieldname in specification: + spec = specification[fieldname] + if fieldname not in actual: + raise KeyError( + f"Actual is missing {fieldname!r}, required by specification." + ) + desc = actual[fieldname] + if isinstance(spec, Desc): + spec = spec.shape + if isinstance(desc, Desc): + desc = desc.shape + if not broadcast: + if len(spec) != len(desc): + raise ValueError( + f"{fieldname!r} shape {desc} incompatible with specification " + f"{spec}." + ) + elif len(desc) > len(spec): + raise ValueError( + f"{fieldname!r} shape {desc} incompatible with specification " + f"{spec}." + ) + for speccomp, desccomp in zip(spec[::-1], desc[::-1]): + if broadcast and desccomp == 1: + continue + if isinstance(speccomp, str): + specv, specoff = speccomp[0], int(speccomp[1:] or 0) + entry: tuple[str, int] | int + + if isinstance(desccomp, str): + descv, descoff = desccomp[0], int(desccomp[1:] or 0) + entry = (descv, descoff - specoff) + else: + entry = desccomp - specoff + + if specv in specvars and entry != specvars[specv]: + raise ValueError(f"Found two incompatible values for {specv!r}") + + specvars[specv] = entry + elif speccomp != desccomp: + raise ValueError( + f"{fieldname!r} shape {desc} incompatible with specification " + f"{spec}" + ) + return None + + @staticmethod + def compatible( + a: dict[str, "Desc"], + b: dict[str, "Desc"], + aliases: tuple[tuple[str, str], ...] = (), + ) -> bool: + """Determine if ``a`` is a valid input for ``b``. + + Note: ``a`` _may_ have additional keys. + """ + + def resolve_aliases(coord): + while True: + for coa, cob in aliases: + if coord == coa: + coord = cob + break + else: + break + return coord + + try: + Desc.validate_shapes(b, a) # type: ignore[arg-type] + except (KeyError, ValueError): + return False + for k, v in b.items(): + if resolve_aliases(a[k].coordinates) != resolve_aliases(v.coordinates): + return False + return True + + +@overload +def desc_like(desc: Desc, shape=None, coordinates=None) -> Desc: ... + + +@overload +def desc_like( + desc: dict[str, Desc], shape=None, coordinates=None +) -> dict[str, Desc]: ... + + +def desc_like(desc, shape=None, coordinates=None): + if isinstance(desc, dict): + return {k: desc_like(v, shape, coordinates) for k, v in desc.items()} + if shape is None: + shape = desc.shape + if coordinates is None: + coordinates = desc.coordinates + return Desc(shape, coordinates) + + +# Monkey patch mpl_data_containers for Desc isinstance checks +try: + from mpl_data_containers import description + description.Desc = Desc +except ImportError: + pass diff --git a/lib/matplotlib/_data_containers/meson.build b/lib/matplotlib/_data_containers/meson.build new file mode 100644 index 000000000000..9607203ae74f --- /dev/null +++ b/lib/matplotlib/_data_containers/meson.build @@ -0,0 +1,13 @@ +python_sources = [ + '__init__.py', + 'containers.py', + 'conversion_edge.py', + 'description.py', + '_helpers.py', +] + +typing_sources = [ +] + +py3.install_sources(python_sources, typing_sources, + subdir: 'matplotlib/_data_containers') diff --git a/lib/matplotlib/axes/_axes.py b/lib/matplotlib/axes/_axes.py index 2bdb6ffd6a3f..4c70f5514880 100644 --- a/lib/matplotlib/axes/_axes.py +++ b/lib/matplotlib/axes/_axes.py @@ -6352,7 +6352,7 @@ def imshow(self, X, cmap=None, norm=None, *, aspect=None, `~matplotlib.pyplot.imshow` expects RGB images adopting the straight (unassociated) alpha representation. """ - im = mimage.AxesImage(self, cmap=cmap, norm=norm, colorizer=colorizer, + im = mimage.AxesImage(self, A=X, cmap=cmap, norm=norm, colorizer=colorizer, interpolation=interpolation, origin=origin, extent=extent, filternorm=filternorm, filterrad=filterrad, resample=resample, @@ -6366,7 +6366,6 @@ def imshow(self, X, cmap=None, norm=None, *, aspect=None, if aspect is not None: self.set_aspect(aspect) - im.set_data(X) im.set_alpha(alpha) if im.get_clip_path() is None: # image does not already have clipping set, clip to Axes patch diff --git a/lib/matplotlib/cbook.py b/lib/matplotlib/cbook.py index 2e416486baf4..7a8d21e72bf8 100644 --- a/lib/matplotlib/cbook.py +++ b/lib/matplotlib/cbook.py @@ -28,6 +28,7 @@ except ImportError: from numpy import VisibleDeprecationWarning + import matplotlib from matplotlib import _api, _c_internal_utils, mlab diff --git a/lib/matplotlib/collections.py b/lib/matplotlib/collections.py index ceae9fc308a0..e109d0151cfd 100644 --- a/lib/matplotlib/collections.py +++ b/lib/matplotlib/collections.py @@ -17,12 +17,188 @@ import numpy as np +from ._data_containers.description import Desc, desc_like + import matplotlib as mpl from . import (_api, _path, artist, cbook, colorizer as mcolorizer, colors as mcolors, _docstring, hatch as mhatch, lines as mlines, path as mpath, transforms) +from ._data_containers._helpers import _get_graph, check_container from ._enums import JoinStyle, CapStyle + +class CollectionContainer(): + def __init__( + self, + x: np.array, + y: np.array, + ): + self.x = x + self.y = y + self.paths = None + + def describe(self): + return { + "x": Desc(("N",), "data"), + "y": Desc(("N",), "data"), + # Colors are weird because it could look like (N, 3) or (N, 4), + # But also accepts strings or cmapped data at this level... + "transforms": Desc(("N", 3, 3), "data"), + "paths": Desc(("N",), "path"), + } + + def query(self, graph, parent_coordinates="axes"): + transforms = np.eye(3)[np.newaxis, :, :] + d = { + "x": self.x, + "y": self.y, + "transforms": transforms, + "paths": self.paths, + } + return d, "" + # TODO hash + + +class SizedCollectionContainer(CollectionContainer): + def __init__( + self, + x: np.array, + y: np.array, + sizes: np.array, + factor: float = 1.0, + ): + super().__init__(x, y) + self.sizes = np.atleast_1d(sizes) + self.factor = factor + + def query(self, graph, parent_coordinates="axes"): + desc = Desc(("N",)) + dpi_eval = graph.evaluator( + desc_like({"x": desc, "y": desc}, coordinates="display_inches"), + desc_like({"x": desc, "y": desc}, coordinates="display"), + ) + dpi = dpi_eval.evaluate({"x": [1], "y": [1]})["x"][0] + + d, hash = super().query(graph, parent_coordinates) + transforms = np.zeros((len(self.sizes), 3, 3)) + scale = np.sqrt(self.sizes) * dpi / 72.0 * self.factor + transforms[:, 0, 0] = scale + transforms[:, 1, 1] = scale + transforms[:, 2, 2] = 1.0 + d["transforms"] = transforms + + return d, hash + + +class RegularPolyCollectionContainer(SizedCollectionContainer): + def __init__( + self, + x: np.array, + y: np.array, + sizes: np.array, + rotation: float, + ): + factor = np.pi ** (-1/2) + super().__init__(x, y, sizes, factor) + self.rotation = rotation + + def query(self, graph, parent_coordinates="axes"): + d, hash = super().query(graph, parent_coordinates) + for i, t in enumerate(d["transforms"]): + d["transforms"][i, :, :] = ( + transforms.Affine2D(t) + .rotate(-self.rotation) + .get_matrix() + ) + + return d, hash + + +class EllipseCollectionContainer(CollectionContainer): + def __init__( + self, + x: np.array, + y: np.array, + widths: np.array, + heights: np.array, + angles: np.array, + units: str, + ): + super().__init__(x, y) + self.widths = np.atleast_1d(widths) + self.heights = np.atleast_1d(heights) + self.angles = np.atleast_1d(angles) + self.units = units + + def query(self, graph, parent_coordinates="axes"): + desc = Desc(("N",)) + dpi_eval = graph.evaluator( + desc_like({"x": desc, "y": desc}, coordinates="display_inches"), + desc_like({"x": desc, "y": desc}, coordinates="display"), + ) + dpi = dpi_eval.evaluate({"x": [1], "y": [1]})["x"][0] + + d, hash = super().query(graph, parent_coordinates) + + # TODO: this section is verbose and likely to be useful elsewhere + # Consider moving to one or more helper methods + # For reference, this was originally from FuncContainer, with modifications + desc = Desc(("N",)) + xy = {"x": desc, "y": desc} + data_lim = graph.evaluator( + desc_like(xy, coordinates="data"), + desc_like(xy, coordinates=parent_coordinates), + ).inverse + + screen_size = graph.evaluator( + desc_like(xy, coordinates=parent_coordinates), + desc_like(xy, coordinates="display"), + ) + + screen_dims = screen_size.evaluate({"x": [0, 1], "y": [0, 1]}) + xpix, ypix = np.ceil(np.abs(np.diff(screen_dims["x"]))), np.ceil( + np.abs(np.diff(screen_dims["y"])) + ) + data_dims = data_lim.evaluate({"x": [0, 1], "y": [0, 1]}) + xdata, ydata = np.abs(np.diff(data_dims["x"])), np.abs(np.diff(data_dims["y"])) + + if self.units == 'xy': + sc = 1 + elif self.units == 'x': + sc = xpix / xdata + elif self.units == 'y': + sc = ypix / ydata + elif self.units == 'inches': + sc = dpi + elif self.units == 'points': + sc = dpi / 72.0 + elif self.units == 'width': + sc = xpix + elif self.units == 'height': + sc = ypix + elif self.units == 'dots': + sc = 1.0 + else: + raise ValueError(f'Unrecognized units: {self._units!r}') + + + transforms = np.zeros((len(self.widths), 3, 3)) + widths = self.widths * sc + heights = self.heights * sc + sin_angle = np.sin(self.angles) + cos_angle = np.cos(self.angles) + transforms[:, 0, 0] = widths * cos_angle + transforms[:, 0, 1] = heights * -sin_angle + transforms[:, 1, 0] = widths * sin_angle + transforms[:, 1, 1] = heights * cos_angle + transforms[:, 2, 2] = 1.0 + + d["transforms"] = transforms + + return d, hash + + + # "color" is excluded; it is a compound setter, and its docstring differs # in LineCollection. @_api.define_aliases({ @@ -164,6 +340,10 @@ def __init__(self, *, """ super().__init__(self._get_colorizer(cmap, norm, colorizer)) + + self._container = self._init_container() + self.__query = None + # list of un-scaled dash patterns # this is needed scaling the dash pattern by linewidth self._us_linestyles = [(0, None)] @@ -202,27 +382,48 @@ def __init__(self, *, self._joinstyle = None if offsets is not None: - offsets = np.asanyarray(offsets, float) - # Broadcast (2,) -> (1, 2) but nothing else. - if offsets.shape == (2,): - offsets = offsets[None, :] + self.set_offsets(offsets) - self._offsets = offsets self._offset_transform = offset_transform self._path_effects = None self._internal_update(kwargs) - self._paths = None + + def set_container(self, container): + self._container = container + self.stale = True + + def get_container(self): + return self._container + + def _init_container(self): + return CollectionContainer( + x=np.array([]), + y=np.array([]), + ) + + @property + def _query(self): + if self.__query is not None: + return self.__query + return self._container.query(_get_graph(self.axes))[0] + + def _cache_query(self): + self.__query = self._container.query(_get_graph(self.axes))[0] + def get_paths(self): - return self._paths + check_container(self, CollectionContainer, "'get_paths'") + return self._container.paths def set_paths(self, paths): - self._paths = paths + check_container(self, CollectionContainer, "'set_paths'") + self._container.paths = paths self.stale = True def get_transforms(self): - return self._transforms + q = self._query + return q["transforms"] def get_offset_transform(self): """Return the `.Transform` instance used by this artist offset.""" @@ -260,6 +461,7 @@ def get_datalim(self, transData): # for the limits (i.e. for scatter) # # 3. otherwise return a null Bbox. + q = self._query transform = self.get_transform() offset_trf = self.get_offset_transform() @@ -298,7 +500,7 @@ def get_datalim(self, transData): offset_trf.get_affine().frozen()) # NOTE: None is the default case where no offsets were passed in - if self._offsets is not None: + if len(q["x"]): # this is for collections that have their paths (shapes) # in physical, axes-relative, or figure-relative units # (i.e. like scatter). We can't uniquely set limits based on @@ -358,6 +560,7 @@ def _prepare_points(self): def draw(self, renderer): if not self.get_visible(): return + self._cache_query() renderer.open_group(self.__class__.__name__, self.get_gid()) self.update_scalarmappable() @@ -627,20 +830,26 @@ def set_offsets(self, offsets): ---------- offsets : (N, 2) or (2,) array-like """ + check_container(self, CollectionContainer, "'set_offsets'") offsets = np.asanyarray(offsets) if offsets.shape == (2,): # Broadcast (2,) -> (1, 2) but nothing else. offsets = offsets[None, :] - cstack = (np.ma.column_stack if isinstance(offsets, np.ma.MaskedArray) - else np.column_stack) - self._offsets = cstack( - (np.asanyarray(self.convert_xunits(offsets[:, 0]), float), - np.asanyarray(self.convert_yunits(offsets[:, 1]), float))) + + self._container.x = np.asanyarray(self.convert_xunits(offsets[:, 0]), float) + self._container.y = np.asanyarray(self.convert_yunits(offsets[:, 1]), float) self.stale = True def get_offsets(self): """Return the offsets for the collection.""" # Default to zeros in the no-offset (None) case - return np.zeros((1, 2)) if self._offsets is None else self._offsets + q = self._query + if len(q["x"]) == 0: + return np.zeros((1,2)) + cstack = (np.ma.column_stack if + isinstance(q["x"], np.ma.MaskedArray) + or isinstance(q["y"], np.ma.MaskedArray) + else np.column_stack) + return cstack([q["x"], q["y"]]) def _get_default_linewidth(self): # This may be overridden in a subclass. @@ -1080,6 +1289,17 @@ class _CollectionWithSizes(Collection): """ _factor = 1.0 + def __init__(self, sizes=None, **kwargs): + super().__init__(**kwargs) + self.set_sizes(sizes) + + def _init_container(self): + return SizedCollectionContainer( + x=np.array([]), + y=np.array([]), + sizes=np.array([]), + ) + def get_sizes(self): """ Return the sizes ('areas') of the elements in the collection. @@ -1089,7 +1309,8 @@ def get_sizes(self): array The 'area' of each element. """ - return self._sizes + check_container(self, CollectionContainer, "'get_sizes'") + return self._container.sizes def set_sizes(self, sizes, dpi=72.0): """ @@ -1103,23 +1324,12 @@ def set_sizes(self, sizes, dpi=72.0): dpi : float, default: 72 The dpi of the canvas. """ + check_container(self, CollectionContainer, "'set_sizes'") if sizes is None: - self._sizes = np.array([]) - self._transforms = np.empty((0, 3, 3)) - else: - self._sizes = np.asarray(sizes) - self._transforms = np.zeros((len(self._sizes), 3, 3)) - scale = np.sqrt(self._sizes) * dpi / 72.0 * self._factor - self._transforms[:, 0, 0] = scale - self._transforms[:, 1, 1] = scale - self._transforms[:, 2, 2] = 1.0 + sizes = np.array([]) + self._container.sizes = np.atleast_1d(sizes) self.stale = True - @artist.allow_rasterization - def draw(self, renderer): - self.set_sizes(self._sizes, self.get_figure(root=True).dpi) - super().draw(renderer) - class PathCollection(_CollectionWithSizes): r""" @@ -1146,7 +1356,7 @@ def __init__(self, paths, sizes=None, **kwargs): self.stale = True def get_paths(self): - return self._paths + return self._container.paths def legend_elements(self, prop="colors", num="auto", fmt=None, func=lambda x: x, **kwargs): @@ -1336,7 +1546,7 @@ def set_verts(self, verts, closed=True): # No need to do anything fancy if the path isn't closed. if not closed: - self._paths = [mpath.Path(xy) for xy in verts] + self._container.paths = [mpath.Path(xy) for xy in verts] return # Fast path for arrays @@ -1347,16 +1557,16 @@ def set_verts(self, verts, closed=True): template_path = mpath.Path(verts_pad[0], closed=True) codes = template_path.codes _make_path = mpath.Path._fast_from_codes_and_verts - self._paths = [_make_path(xy, codes, internals_from=template_path) - for xy in verts_pad] + self._container.paths = [_make_path(xy, codes, internals_from=template_path) + for xy in verts_pad] return - self._paths = [] + self._container.paths = [] for xy in verts: if len(xy): - self._paths.append(mpath.Path._create_closed(xy)) + self._container.paths.append(mpath.Path._create_closed(xy)) else: - self._paths.append(mpath.Path(xy)) + self._container.paths.append(mpath.Path(xy)) set_paths = set_verts @@ -1365,7 +1575,7 @@ def set_verts_and_codes(self, verts, codes): if len(verts) != len(codes): raise ValueError("'codes' must be a 1D list or array " "with the same length of 'verts'") - self._paths = [mpath.Path(xy, cds) if len(xy) else mpath.Path(xy) + self._container.paths = [mpath.Path(xy, cds) if len(xy) else mpath.Path(xy) for xy, cds in zip(verts, codes)] self.stale = True @@ -1653,29 +1863,24 @@ def __init__(self, offset_transform=ax.transData, ) """ - super().__init__(**kwargs) - self.set_sizes(sizes) + super().__init__(sizes=sizes, **kwargs) + self._container.rotation = rotation self._numsides = numsides - self._paths = [self._path_generator(numsides)] - self._rotation = rotation self.set_transform(transforms.IdentityTransform()) - + self._container.paths = [self._path_generator(numsides)] + + def _init_container(self): + return RegularPolyCollectionContainer( + x=np.array([]), + y=np.array([]), + sizes=np.array([]), + rotation=0.0, + ) def get_numsides(self): return self._numsides def get_rotation(self): - return self._rotation - - @artist.allow_rasterization - def draw(self, renderer): - self.set_sizes(self._sizes, self.get_figure(root=True).dpi) - self._transforms = [ - transforms.Affine2D(x).rotate(-self._rotation).get_matrix() - for x in self._transforms - ] - # Explicitly not super().draw, because set_sizes must be called before - # updating self._transforms. - Collection.draw(self, renderer) + return self._collection.rotation class StarPolygonCollection(RegularPolyCollection): @@ -1756,9 +1961,9 @@ def set_segments(self, segments): if segments is None: return - self._paths = [mpath.Path(seg) if isinstance(seg, np.ma.MaskedArray) - else mpath.Path(np.asarray(seg, float)) - for seg in segments] + self._container.paths = [mpath.Path(seg) if isinstance(seg, np.ma.MaskedArray) + else mpath.Path(np.asarray(seg, float)) + for seg in segments] self.stale = True set_verts = set_segments # for compatibility with PolyCollection @@ -1774,7 +1979,7 @@ def get_segments(self): """ segments = [] - for path in self._paths: + for path in self._container.paths: vertices = [ vertex for vertex, _ @@ -1869,7 +2074,7 @@ def _get_inverse_paths_linestyles(self): if ls == (0, None) else (path, mlines._get_inverse_dash_pattern(*ls)) for (path, ls) in - zip(self._paths, itertools.cycle(self._linestyles))] + zip(self._container.paths, itertools.cycle(self._linestyles))] return zip(*path_patterns) @@ -2074,7 +2279,7 @@ def __init__(self, sizes, **kwargs): super().__init__(**kwargs) self.set_sizes(sizes) self.set_transform(transforms.IdentityTransform()) - self._paths = [mpath.Path.unit_circle()] + self._container.paths = [mpath.Path.unit_circle()] class EllipseCollection(Collection): @@ -2105,83 +2310,63 @@ def __init__(self, widths, heights, angles, *, units='points', **kwargs): self.set_widths(widths) self.set_heights(heights) self.set_angles(angles) - self._units = units + self._container.units = units self.set_transform(transforms.IdentityTransform()) - self._transforms = np.empty((0, 3, 3)) - self._paths = [mpath.Path.unit_circle()] + self._container.paths = [mpath.Path.unit_circle()] - def _set_transforms(self): - """Calculate transforms immediately before drawing.""" + def _init_container(self): + return EllipseCollectionContainer( + x=np.array([]), + y=np.array([]), + widths=np.array([]), + heights=np.array([]), + angles=np.array([]), + units="xy" + ) - ax = self.axes - fig = self.get_figure(root=False) - if self._units == 'xy': - sc = 1 - elif self._units == 'x': - sc = ax.bbox.width / ax.viewLim.width - elif self._units == 'y': - sc = ax.bbox.height / ax.viewLim.height - elif self._units == 'inches': - sc = fig.dpi - elif self._units == 'points': - sc = fig.dpi / 72.0 - elif self._units == 'width': - sc = ax.bbox.width - elif self._units == 'height': - sc = ax.bbox.height - elif self._units == 'dots': - sc = 1.0 - else: - raise ValueError(f'Unrecognized units: {self._units!r}') - - self._transforms = np.zeros((len(self._widths), 3, 3)) - widths = self._widths * sc - heights = self._heights * sc - sin_angle = np.sin(self._angles) - cos_angle = np.cos(self._angles) - self._transforms[:, 0, 0] = widths * cos_angle - self._transforms[:, 0, 1] = heights * -sin_angle - self._transforms[:, 1, 0] = widths * sin_angle - self._transforms[:, 1, 1] = heights * cos_angle - self._transforms[:, 2, 2] = 1.0 - - _affine = transforms.Affine2D - if self._units == 'xy': - m = ax.transData.get_affine().get_matrix().copy() - m[:2, 2:] = 0 - self.set_transform(_affine(m)) + def set_angles(self, angles): + """Set the angles of the first axes, degrees CCW from the x-axis.""" + check_container(self, EllipseCollectionContainer, "'set_angles'") + self._container.angles = np.deg2rad(angles).ravel() + self.stale = True def set_widths(self, widths): """Set the lengths of the first axes (e.g., major axis).""" - self._widths = 0.5 * np.asarray(widths).ravel() + check_container(self, EllipseCollectionContainer, "'set_widths'") + self._container.widths = 0.5 * np.asarray(widths).ravel() self.stale = True def set_heights(self, heights): """Set the lengths of second axes (e.g., minor axes).""" - self._heights = 0.5 * np.asarray(heights).ravel() - self.stale = True - - def set_angles(self, angles): - """Set the angles of the first axes, degrees CCW from the x-axis.""" - self._angles = np.deg2rad(angles).ravel() + check_container(self, EllipseCollectionContainer, "'set_heights'") + self._container.heights = 0.5 * np.asarray(heights).ravel() self.stale = True def get_widths(self): """Get the lengths of the first axes (e.g., major axis).""" - return self._widths * 2 + check_container(self, EllipseCollectionContainer, "'get_widths'") + return self._container.widths * 2 def get_heights(self): - """Set the lengths of second axes (e.g., minor axes).""" - return self._heights * 2 + """Get the lengths of second axes (e.g., minor axes).""" + check_container(self, EllipseCollectionContainer, "'get_heights'") + return self._container.heights * 2 def get_angles(self): """Get the angles of the first axes, degrees CCW from the x-axis.""" - return np.rad2deg(self._angles) + check_container(self, EllipseCollectionContainer, "'get_angles'") + return np.rad2deg(self._container.angles) @artist.allow_rasterization def draw(self, renderer): - self._set_transforms() + if ( + isinstance(self._container, EllipseCollectionContainer) + and self._container.units == "xy" + ): + m = self.axes.transData.get_affine().get_matrix().copy() + m[:2, 2:] = 0 + self.set_transform(transforms.Affine2D(m)) super().draw(renderer) @@ -2242,7 +2427,7 @@ def determine_facecolor(patch): def set_paths(self, patches): paths = [p.get_transform().transform_path(p.get_path()) for p in patches] - self._paths = paths + self._container.paths = paths class TriMesh(Collection): @@ -2265,12 +2450,12 @@ def __init__(self, triangulation, **kwargs): self._bbox.update_from_data_xy(xy) def get_paths(self): - if self._paths is None: + if self._container.paths is None: self.set_paths() - return self._paths + return self._container.paths def set_paths(self): - self._paths = self.convert_mesh_to_paths(self._triangulation) + self._container.paths = self.convert_mesh_to_paths(self._triangulation) @staticmethod def convert_mesh_to_paths(tri): @@ -2288,6 +2473,7 @@ def convert_mesh_to_paths(tri): def draw(self, renderer): if not self.get_visible(): return + self._cache_query() renderer.open_group(self.__class__.__name__, gid=self.get_gid()) transform = self.get_transform() @@ -2508,12 +2694,12 @@ def __init__(self, coordinates, *, antialiased=True, shading='flat', self.set_mouseover(False) def get_paths(self): - if self._paths is None: + if self._container.paths is None: self.set_paths() - return self._paths + return self._container.paths def set_paths(self): - self._paths = self._convert_mesh_to_paths(self._coordinates) + self._container.paths = self._convert_mesh_to_paths(self._coordinates) self.stale = True def get_datalim(self, transData): @@ -2523,6 +2709,7 @@ def get_datalim(self, transData): def draw(self, renderer): if not self.get_visible(): return + self._cache_query() renderer.open_group(self.__class__.__name__, self.get_gid()) transform = self.get_transform() offset_trf = self.get_offset_transform() diff --git a/lib/matplotlib/contour.py b/lib/matplotlib/contour.py index dfc39ed664f9..333666d2c9de 100644 --- a/lib/matplotlib/contour.py +++ b/lib/matplotlib/contour.py @@ -455,7 +455,7 @@ def add_label_near(self, x, y, inline=True, inline_spacing=5, idx_level_min, idx_vtx_min, proj = self._find_nearest_contour( (x, y), self.labelIndiceList) - path = self._paths[idx_level_min] + path = self._container.paths[idx_level_min] level = self.labelIndiceList.index(idx_level_min) label_width = self._get_nth_label_width(level) rotation, path = self._split_path_and_get_label_rotation( @@ -464,7 +464,7 @@ def add_label_near(self, x, y, inline=True, inline_spacing=5, self.labelCValueList[idx_level_min]) if inline: - self._paths[idx_level_min] = path + self._container.paths[idx_level_min] = path def pop_label(self, index=-1): """Defaults to removing last label, but any index can be supplied""" @@ -481,7 +481,7 @@ def labels(self, inline, inline_spacing): trans = self.get_transform() label_width = self._get_nth_label_width(idx) additions = [] - for subpath in self._paths[icon]._iter_connected_components(): + for subpath in self._container.paths[icon]._iter_connected_components(): screen_xys = trans.transform(subpath.vertices) # Check if long enough for a label if self.print_label(screen_xys, label_width): @@ -497,7 +497,7 @@ def labels(self, inline, inline_spacing): # After looping over all segments on a contour, replace old path by new one # if inlining. if inline: - self._paths[icon] = Path.make_compound_path(*additions) + self._container.paths[icon] = Path.make_compound_path(*additions) def remove(self): super().remove() @@ -757,8 +757,8 @@ def __init__(self, ax, *args, self.norm._changed() self._process_colors() - if self._paths is None: - self._paths = self._make_paths_from_contour_generator() + if self._container.paths is None: + self._container.paths = self._make_paths_from_contour_generator() if self.filled: if linewidths is not None: @@ -839,7 +839,7 @@ def legend_elements(self, variable_name='x', str_format=str): if self.filled: lowers, uppers = self._get_lowers_and_uppers() - n_levels = len(self._paths) + n_levels = len(self._container.paths) for idx in range(n_levels): artists.append(mpatches.Rectangle( (0, 0), 1, 1, @@ -905,15 +905,15 @@ def _process_args(self, *args, **kwargs): # pathcodes. However, kinds can also be None; in which case all paths in that # list are codeless (this case is normalized above). These lists are used to # construct paths, which then get concatenated. - self._paths = [Path.make_compound_path(*map(Path, segs, kinds)) + self._container.paths = [Path.make_compound_path(*map(Path, segs, kinds)) for segs, kinds in zip(allsegs, allkinds)] return kwargs def _make_paths_from_contour_generator(self): """Compute ``paths`` using C extension.""" - if self._paths is not None: - return self._paths + if self._container.paths is not None: + return self._container.paths cg = self._contour_generator empty_path = Path(np.empty((0, 2))) vertices_and_codes = ( @@ -1180,13 +1180,13 @@ def _find_nearest_contour(self, xy, indices=None): raise ValueError("Method does not support filled contours") if indices is None: - indices = range(len(self._paths)) + indices = range(len(self._container.paths)) d2min = np.inf idx_level_min = idx_vtx_min = proj_min = None for idx_level in indices: - path = self._paths[idx_level] + path = self._container.paths[idx_level] idx_vtx_start = 0 for subpath in path._iter_connected_components(): if not len(subpath.vertices): @@ -1249,7 +1249,8 @@ def find_nearest_contour(self, x, y, indices=None, pixel=True): if i_level is not None: cc_cumlens = np.cumsum( - [*map(len, self._paths[i_level]._iter_connected_components())]) + [*map(len, self._container.paths[i_level]._iter_connected_components())] + ) segment = cc_cumlens.searchsorted(i_vtx, "right") index = i_vtx if segment == 0 else i_vtx - cc_cumlens[segment - 1] d2 = (xmin-x)**2 + (ymin-y)**2 @@ -1258,7 +1259,7 @@ def find_nearest_contour(self, x, y, indices=None, pixel=True): @artist.allow_rasterization def draw(self, renderer): - paths = self._paths + paths = self._container.paths n_paths = len(paths) if not self.filled or all(hatch is None for hatch in self.hatches): super().draw(renderer) @@ -1268,15 +1269,22 @@ def draw(self, renderer): if edgecolors.size == 0: edgecolors = ("none",) for idx in range(n_paths): - with self._cm_set( - paths=[paths[idx]], - hatch=self.hatches[idx % len(self.hatches)], - array=[self.get_array()[idx]], - linewidths=[self.get_linewidths()[idx % len(self.get_linewidths())]], - linestyles=[self.get_linestyles()[idx % len(self.get_linestyles())]], - edgecolors=edgecolors[idx % len(edgecolors)], - ): - super().draw(renderer) + contour = mcoll.PathCollection(paths=[paths[idx]]) + contour.update_from(self) + contour.set_linewidths( + [self.get_linewidths()[idx % len(self.get_linewidths())]] + ) + contour.set_linestyles( + [self.get_linestyles()[idx % len(self.get_linestyles())]] + ) + contour.set_edgecolors(edgecolors[idx % len(edgecolors)]) + contour.set_hatch(self.hatches[idx % len(self.hatches)]) + contour.set_array([self.get_array()[idx]]) + contour.set_norm(self.norm) + contour.set_cmap(self.cmap) + + contour.set_transform(self.get_transform()) + contour.draw(renderer) @_docstring.interpd diff --git a/lib/matplotlib/image.py b/lib/matplotlib/image.py index c1846f92608c..0e11dbfb7aba 100644 --- a/lib/matplotlib/image.py +++ b/lib/matplotlib/image.py @@ -6,6 +6,7 @@ import math import os import logging +from dataclasses import dataclass from pathlib import Path import warnings @@ -13,6 +14,7 @@ import PIL.Image import PIL.PngImagePlugin + import matplotlib as mpl from matplotlib import _api, cbook # For clarity, names from _image are given explicitly in this module @@ -20,6 +22,8 @@ # For user convenience, the names from _image are also imported into # the image namespace from matplotlib._image import * # noqa: F401, F403 +from ._data_containers.description import Desc +from ._data_containers._helpers import _get_graph, check_container import matplotlib.artist as martist import matplotlib.colorizer as mcolorizer from matplotlib.backend_bases import FigureCanvasBase @@ -28,6 +32,7 @@ Affine2D, BboxBase, Bbox, BboxTransform, BboxTransformTo, IdentityTransform, TransformedBbox) + _log = logging.getLogger(__name__) # map interpolation strings to module constants @@ -230,6 +235,71 @@ def _rgb_to_rgba(A): return rgba +@dataclass +class ImageContainer: + x: np.ndarray + y: np.ndarray + image: np.ndarray + + def describe(self): + imshape = list(self.image.shape) + imshape[:2] = ("M", "N") + + return { + "x": Desc((2,), "data"), + "y": Desc((2,), "data"), + "image": Desc(tuple(imshape), "data"), + } + + def query(self, graph, parent_coordinates="axes"): + return { + "x": self.x, + "y": self.y, + "image": self.image, + }, "" + # TODO hash + + +@dataclass +class NonUniformImageContainer(ImageContainer): + def describe(self): + imshape = list(self.image.shape) + imshape[:2] = ("M", "N") + + return { + "x": Desc(("M",), "data"), + "y": Desc(("N",), "data"), + "image": Desc(tuple(imshape), "data"), + } + + +@dataclass +class PcolorImageContainer(ImageContainer): + def describe(self): + imshape = list(self.image.shape) + imshape[:2] = ("M", "N") + + return { + "x": Desc(("M+1",), "data"), + "y": Desc(("N+1",), "data"), + "image": Desc(tuple(imshape), "data"), + } + + +@dataclass +class FigureImageContainer(ImageContainer): + def describe(self): + imshape = list(self.image.shape) + imshape[:2] = ("M", "N") + + return { + "x": Desc((), "data"), + "y": Desc((), "data"), + "image": Desc(tuple(imshape), "data"), + } + + + class _ImageBase(mcolorizer.ColorizingArtist): """ Base class for images. @@ -272,10 +342,51 @@ def __init__(self, ax, self.set_resample(resample) self.axes = ax + self._container = ImageContainer( + np.array([0.,1.]), + np.array([0.,1.]), + np.array([[]]), + ) + self.__query = None self._imcache = None self._internal_update(kwargs) + @property + def _query(self): + if self.__query is not None: + return self.__query + return self._container.query(_get_graph(self.axes))[0] + + def _cache_query(self): + self.__query = self._container.query(_get_graph(self.axes))[0] + + @property + def _image_array(self): + return self._query["image"] + + @property + def _A(self): + return self._image_array + + @_A.setter + def _A(self, val): + if val is None: + # This case is needed for the transition because + # ColorizingArtist sets `_A = None` during init + return + check_container(self, ImageContainer, "Setting _A") + self._container.image = self._normalize_image_array(val) + self._imcache = None + self.stale = True + + def set_container(self, container): + self._container = container + self.stale = True + + def get_container(self): + return self._container + def __str__(self): try: shape = self.get_shape() @@ -285,7 +396,11 @@ def __str__(self): def __getstate__(self): # Save some space on the pickle by not saving the cache. - return {**super().__getstate__(), "_imcache": None} + return { + **super().__getstate__(), + "_imcache": None, + "_ImageBase__query": None, + } def get_size(self): """Return the size of the image as tuple (numrows, numcols).""" @@ -295,10 +410,7 @@ def get_shape(self): """ Return the shape of the image as tuple (numrows, numcols, channels). """ - if self._A is None: - raise RuntimeError('You must first set the image array') - - return self._A.shape + return self._image_array.shape def set_alpha(self, alpha): """ @@ -388,6 +500,8 @@ def _make_image(self, A, in_bbox, out_bbox, clip_bbox, magnification=1.0, "Your Artist's draw method must filter before " "this method is called.") + A = np.ma.asanyarray(A) + clipped_bbox = Bbox.intersection(out_bbox, clip_bbox) if clipped_bbox is None: @@ -596,6 +710,8 @@ def draw(self, renderer): if not self.get_visible(): self.stale = False return + # Update the cached version of the query + self._cache_query() # for empty images, there is nothing to draw! if self.get_array().size == 0: self.stale = False @@ -688,12 +804,16 @@ def set_data(self, A): ---------- A : array-like or `PIL.Image.Image` """ + check_container(self, ImageContainer, "'set_data'") if isinstance(A, PIL.Image.Image): A = pil_to_array(A) # Needed e.g. to apply png palette. - self._A = self._normalize_image_array(A) + self._container.image = self._normalize_image_array(A) self._imcache = None self.stale = True + def get_array(self): + return self._image_array + def set_array(self, A): """ Retained for backwards compatibility - use set_data instead. @@ -874,6 +994,7 @@ class AxesImage(_ImageBase): def __init__(self, ax, *, + A=None, cmap=None, norm=None, colorizer=None, @@ -887,8 +1008,6 @@ def __init__(self, ax, **kwargs ): - self._extent = extent - super().__init__( ax, cmap=cmap, @@ -903,21 +1022,31 @@ def __init__(self, ax, **kwargs ) + if A is not None: + self.set_data(A) + self.set_extent(extent) + elif extent is not None: + self.set_extent(extent) + def get_window_extent(self, renderer=None): - x0, x1, y0, y1 = self._extent + x0, x1, y0, y1 = self.get_extent() bbox = Bbox.from_extents([x0, y0, x1, y1]) return bbox.transformed(self.get_transform()) def make_image(self, renderer, magnification=1.0, unsampled=False): + q = self._query + x1, x2 = q["x"] + y1, y2 = q["y"] + + A = q["image"] + # docstring inherited trans = self.get_transform() - # image is created in the canvas coordinate. - x1, x2, y1, y2 = self.get_extent() bbox = Bbox(np.array([[x1, y1], [x2, y2]])) transformed_bbox = TransformedBbox(bbox, trans) clip = ((self.get_clip_box() or self.axes.bbox) if self.get_clip_on() else self.get_figure(root=True).bbox) - return self._make_image(self._A, bbox, transformed_bbox, clip, + return self._make_image(A, bbox, transformed_bbox, clip, magnification, unsampled=unsampled) def _check_unsampled_image(self): @@ -945,6 +1074,16 @@ def set_extent(self, extent, **kwargs): state is not changed, so a subsequent call to `.Axes.autoscale_view` will redo the autoscaling in accord with `~.Axes.dataLim`. """ + check_container(self, ImageContainer, "'set_extent'") + + if extent is None: + sz = self.get_size() + numrows, numcols = sz + if self.origin == 'upper': + extent = (-0.5, numcols-0.5, numrows-0.5, -0.5) + else: + extent = (-0.5, numcols-0.5, -0.5, numrows-0.5) + (xmin, xmax), (ymin, ymax) = self.axes._process_unit_info( [("x", [extent[0], extent[1]]), ("y", [extent[2], extent[3]])], @@ -961,7 +1100,15 @@ def set_extent(self, extent, **kwargs): ymax, self.convert_yunits) extent = [xmin, xmax, ymin, ymax] - self._extent = extent + self._container.x[:] = extent[:2] + self._container.y[:] = extent[2:] + self._update_autolims(xmin, xmax, ymin, ymax) + + def set_container(self, container): + super().set_container(container) + self._update_autolims(*self.get_extent()) + + def _update_autolims(self, xmin, xmax, ymin, ymax): corners = (xmin, ymin), (xmax, ymax) self.axes.update_datalim(corners) self.sticky_edges.x[:] = [xmin, xmax] @@ -974,15 +1121,10 @@ def set_extent(self, extent, **kwargs): def get_extent(self): """Return the image extent as tuple (left, right, bottom, top).""" - if self._extent is not None: - return self._extent - else: - sz = self.get_size() - numrows, numcols = sz - if self.origin == 'upper': - return (-0.5, numcols-0.5, numrows-0.5, -0.5) - else: - return (-0.5, numcols-0.5, -0.5, numrows-0.5) + q = self._query + x = q["x"] + y = q["y"] + return x[0], x[-1], y[0], y[-1] def get_cursor_data(self, event): """ @@ -1033,8 +1175,17 @@ def __init__(self, ax, *, interpolation='nearest', **kwargs): **kwargs All other keyword arguments are identical to those of `.AxesImage`. """ + if "A" in kwargs: + raise RuntimeError( + "'NonUniformImage' does not support setting array in init" + ) super().__init__(ax, **kwargs) self.set_interpolation(interpolation) + self._container = NonUniformImageContainer( + np.array([0.,1.]), + np.array([0.,1.]), + np.array([[np.nan]]), + ) def _check_unsampled_image(self): """Return False. Do not use unsampled image.""" @@ -1042,11 +1193,15 @@ def _check_unsampled_image(self): def make_image(self, renderer, magnification=1.0, unsampled=False): # docstring inherited - if self._A is None: - raise RuntimeError('You must first set the image array') if unsampled: raise ValueError('unsampled not supported on NonUniformImage') - A = self._A + + q = self._query + Ax = q["x"] + Ay = q["y"] + + A = q["image"] + if A.ndim == 2: if A.dtype != np.uint8: A = self.to_rgba(A, bytes=True) @@ -1072,8 +1227,8 @@ def make_image(self, renderer, magnification=1.0, unsampled=False): [(l, y) for y in np.linspace(b, t, height)])[:, 1] if self._interpolation == "nearest": - x_mid = (self._Ax[:-1] + self._Ax[1:]) / 2 - y_mid = (self._Ay[:-1] + self._Ay[1:]) / 2 + x_mid = (Ax[:-1] + Ax[1:]) / 2 + y_mid = (Ay[:-1] + Ay[1:]) / 2 x_int = x_mid.searchsorted(x_pix) y_int = y_mid.searchsorted(y_pix) # The following is equal to `A[y_int[:, None], x_int[None, :]]`, @@ -1086,16 +1241,16 @@ def make_image(self, renderer, magnification=1.0, unsampled=False): else: # self._interpolation == "bilinear" # Use np.interp to compute x_int/x_float has similar speed. x_int = np.clip( - self._Ax.searchsorted(x_pix) - 1, 0, len(self._Ax) - 2) + Ax.searchsorted(x_pix) - 1, 0, len(Ax) - 2) y_int = np.clip( - self._Ay.searchsorted(y_pix) - 1, 0, len(self._Ay) - 2) + Ay.searchsorted(y_pix) - 1, 0, len(Ay) - 2) idx_int = np.add.outer(y_int * A.shape[1], x_int) x_frac = np.clip( - np.divide(x_pix - self._Ax[x_int], np.diff(self._Ax)[x_int], + np.divide(x_pix - Ax[x_int], np.diff(Ax)[x_int], dtype=np.float32), # Downcasting helps with speed. 0, 1) y_frac = np.clip( - np.divide(y_pix - self._Ay[y_int], np.diff(self._Ay)[y_int], + np.divide(y_pix - Ay[y_int], np.diff(Ay)[y_int], dtype=np.float32), 0, 1) f00 = np.outer(1 - y_frac, 1 - x_frac) @@ -1127,14 +1282,15 @@ def set_data(self, x, y, A): (M, N) `~numpy.ndarray` or masked array of values to be colormapped, or (M, N, 3) RGB array, or (M, N, 4) RGBA array. """ + check_container(self, NonUniformImageContainer, "'set_data'") A = self._normalize_image_array(A) x = np.array(x, np.float32) y = np.array(y, np.float32) if not (x.ndim == y.ndim == 1 and A.shape[:2] == y.shape + x.shape): raise TypeError("Axes don't match array shape") - self._A = A - self._Ax = x - self._Ay = y + self._container.image = A + self._container.x = x + self._container.y = y self._imcache = None self.stale = True @@ -1153,11 +1309,6 @@ def set_interpolation(self, s): 'bilinear interpolations are supported') super().set_interpolation(s) - def get_extent(self): - if self._A is None: - raise RuntimeError('Must set data first') - return self._Ax[0], self._Ax[-1], self._Ay[0], self._Ay[-1] - def set_filternorm(self, filternorm): pass @@ -1165,24 +1316,29 @@ def set_filterrad(self, filterrad): pass def set_norm(self, norm): - if self._A is not None: - raise RuntimeError('Cannot change colors after loading data') + #if self._A is not None: + # raise RuntimeError('Cannot change colors after loading data') super().set_norm(norm) def set_cmap(self, cmap): - if self._A is not None: - raise RuntimeError('Cannot change colors after loading data') + #if self._A is not None: + # raise RuntimeError('Cannot change colors after loading data') super().set_cmap(cmap) def get_cursor_data(self, event): # docstring inherited + q = self._query + Ax = q["x"] + Ay = q["y"] + A = q["image"] + x, y = event.xdata, event.ydata - if (x < self._Ax[0] or x > self._Ax[-1] or - y < self._Ay[0] or y > self._Ay[-1]): + if (x < Ax[0] or x > Ax[-1] or + y < Ay[0] or y > Ay[-1]): return None - j = np.searchsorted(self._Ax, x) - 1 - i = np.searchsorted(self._Ay, y) - 1 - return self._A[i, j] + j = np.searchsorted(Ax, x) - 1 + i = np.searchsorted(Ay, y) - 1 + return A[i, j] class PcolorImage(AxesImage): @@ -1229,18 +1385,27 @@ def __init__(self, ax, """ super().__init__(ax, norm=norm, cmap=cmap, colorizer=colorizer) self._internal_update(kwargs) + self._container = PcolorImageContainer( + np.array([0.,1.]), + np.array([0.,1.]), + np.array([[np.nan]]), + ) if A is not None: self.set_data(x, y, A) def make_image(self, renderer, magnification=1.0, unsampled=False): # docstring inherited - if self._A is None: - raise RuntimeError('You must first set the image array') if unsampled: - raise ValueError('unsampled not supported on PColorImage') + raise ValueError('unsampled not supported on PcolorImage') + + q = self._query + Ax = q["x"] + Ay = q["y"] + + A = q["image"] if self._imcache is None: - A = self.to_rgba(self._A, bytes=True) + A = self.to_rgba(A, bytes=True) self._imcache = np.pad(A, [(1, 1), (1, 1), (0, 0)], "constant") padded_A = self._imcache bg = mcolors.to_rgba(self.axes.patch.get_facecolor(), 0) @@ -1257,8 +1422,8 @@ def make_image(self, renderer, magnification=1.0, unsampled=False): x_pix = np.linspace(vl.x0, vl.x1, width) y_pix = np.linspace(vl.y0, vl.y1, height) - x_int = self._Ax.searchsorted(x_pix) - y_int = self._Ay.searchsorted(y_pix) + x_int = Ax.searchsorted(x_pix) + y_int = Ay.searchsorted(y_pix) im = ( # See comment in NonUniformImage.make_image re: performance. padded_A.view(np.uint32).ravel()[ np.add.outer(y_int * padded_A.shape[1], x_int)] @@ -1286,6 +1451,7 @@ def set_data(self, x, y, A): - (M, N, 3): RGB array - (M, N, 4): RGBA array """ + check_container(self, PcolorImageContainer, "'set_data'") A = self._normalize_image_array(A) x = np.arange(0., A.shape[1] + 1) if x is None else np.array(x, float).ravel() y = np.arange(0., A.shape[0] + 1) if y is None else np.array(y, float).ravel() @@ -1300,9 +1466,9 @@ def set_data(self, x, y, A): if y[-1] < y[0]: y = y[::-1] A = A[::-1] - self._A = A - self._Ax = x - self._Ay = y + self._container.image = A + self._container.x = x + self._container.y = y self._imcache = None self.stale = True @@ -1311,13 +1477,18 @@ def set_array(self, *args): def get_cursor_data(self, event): # docstring inherited + q = self._query + Ax = q["x"] + Ay = q["y"] + A = q["image"] + x, y = event.xdata, event.ydata - if (x < self._Ax[0] or x > self._Ax[-1] or - y < self._Ay[0] or y > self._Ay[-1]): + if (x < Ax[0] or x > Ax[-1] or + y < Ay[0] or y > Ay[-1]): return None - j = np.searchsorted(self._Ax, x) - 1 - i = np.searchsorted(self._Ay, y) - 1 - return self._A[i, j] + j = np.searchsorted(Ax, x) - 1 + i = np.searchsorted(Ay, y) - 1 + return A[i, j] class FigureImage(_ImageBase): @@ -1351,40 +1522,48 @@ def __init__(self, fig, origin=origin ) self.set_figure(fig) - self.ox = offsetx - self.oy = offsety + self._container = FigureImageContainer( + np.array(offsetx), + np.array(offsety), + np.array([[]]), + ) self._internal_update(kwargs) self.magnification = 1.0 def get_extent(self): """Return the image extent as tuple (left, right, bottom, top).""" - numrows, numcols = self.get_size() - return (-0.5 + self.ox, numcols-0.5 + self.ox, - -0.5 + self.oy, numrows-0.5 + self.oy) + q = self._query + ox = q["x"] + oy = q["y"] + A = q["image"] + + numrows, numcols, *_ = A.shape + return (-0.5 + ox, numcols-0.5 + ox, + -0.5 + oy, numrows-0.5 + oy) def make_image(self, renderer, magnification=1.0, unsampled=False): # docstring inherited + q = self._query + ox = q["x"] + oy = q["y"] + A = q["image"] + fig = self.get_figure(root=True) fac = renderer.dpi/fig.dpi # fac here is to account for pdf, eps, svg backends where # figure.dpi is set to 72. This means we need to scale the # image (using magnification) and offset it appropriately. - bbox = Bbox([[self.ox/fac, self.oy/fac], - [(self.ox/fac + self._A.shape[1]), - (self.oy/fac + self._A.shape[0])]]) + bbox = Bbox([[ox/fac, oy/fac], + [(ox/fac + A.shape[1]), + (oy/fac + A.shape[0])]]) width, height = fig.get_size_inches() width *= renderer.dpi height *= renderer.dpi clip = Bbox([[0, 0], [width, height]]) return self._make_image( - self._A, bbox, bbox, clip, magnification=magnification / fac, + A, bbox, bbox, clip, magnification=magnification / fac, unsampled=unsampled, round_to_pixel_border=False) - def set_data(self, A): - """Set the image array.""" - super().set_data(A) - self.stale = True - class BboxImage(_ImageBase): """ @@ -1460,6 +1639,12 @@ def __init__(self, bbox, ) self.bbox = bbox + self._container = ImageContainer( + np.asarray([]), # Unused for BboxImage, kept for container hierarchy + np.asarray([]), # Unused for BboxImage, kept for container hierarchy + np.asarray([[]]), + ) + def get_window_extent(self, renderer=None): if isinstance(self.bbox, BboxBase): return self.bbox @@ -1480,6 +1665,9 @@ def contains(self, mouseevent): def make_image(self, renderer, magnification=1.0, unsampled=False): # docstring inherited + q = self._query + A = q["image"] + width, height = renderer.get_canvas_width_height() bbox_in = self.get_window_extent(renderer).frozen() bbox_in._points /= [width, height] @@ -1487,7 +1675,7 @@ def make_image(self, renderer, magnification=1.0, unsampled=False): clip = Bbox([[0, 0], [width, height]]) self._transform = BboxTransformTo(clip) return self._make_image( - self._A, + A, bbox_in, bbox_out, clip, magnification, unsampled=unsampled) diff --git a/lib/matplotlib/lines.py b/lib/matplotlib/lines.py index 7c374843b5c1..13f4cbe662d4 100644 --- a/lib/matplotlib/lines.py +++ b/lib/matplotlib/lines.py @@ -4,6 +4,7 @@ import copy +from dataclasses import dataclass from numbers import Integral, Number, Real import logging @@ -18,6 +19,8 @@ from .path import Path from .transforms import Bbox, BboxTransformTo, TransformedPath from ._enums import JoinStyle, CapStyle +from ._data_containers._helpers import containerize_draw, _get_graph, check_container +from ._data_containers.description import Desc # Imported here for backward compatibility, even though they don't # really belong. @@ -227,6 +230,26 @@ def _slice_or_none(in_v, slc): raise ValueError(f"markevery={markevery!r} is not a recognized value") +@dataclass +class LineContainer: + x: np.ndarray + y: np.ndarray + + def describe(self): + + return { + "x": Desc(("N",), "data"), + "y": Desc(("N",), "data"), + } + + def query(self, graph, parent_coordinates="axes"): + return { + "x": self.x, + "y": self.y, + }, "" + # TODO hash + + @_docstring.interpd @_api.define_aliases({ "antialiased": ["aa"], @@ -335,6 +358,9 @@ def __init__(self, xdata, ydata, *, """ super().__init__() + self._container = self._init_container() + self.__query = None + # Convert sequences to NumPy arrays. if not np.iterable(xdata): raise RuntimeError('xdata must be a sequence') @@ -413,13 +439,8 @@ def __init__(self, xdata, ydata, *, not isinstance(self._picker, bool)): self._pickradius = self._picker - self._xorig = np.asarray([]) - self._yorig = np.asarray([]) self._invalidx = True self._invalidy = True - self._x = None - self._y = None - self._xy = None self._path = None self._transformed_path = None self._subslice = False @@ -427,6 +448,51 @@ def __init__(self, xdata, ydata, *, self.set_data(xdata, ydata) + def set_container(self, container): + self._container = container + self.stale = True + + def get_container(self): + return self._container + + def _init_container(self): + return LineContainer( + x=np.array([]), + y=np.array([]), + ) + + @property + def _xorig(self): + return self._query["x"] + + @property + def _x(self): + xconv = self.convert_xunits(self._xorig) + return _to_unmasked_float_array(xconv).ravel() + + @property + def _yorig(self): + return self._query["y"] + + @property + def _y(self): + yconv = self.convert_yunits(self._yorig) + return _to_unmasked_float_array(yconv).ravel() + + @property + def _xy(self): + x, y = self._x, self._y + return np.column_stack(np.broadcast_arrays(x, y)).astype(float) + + @property + def _query(self): + if self.__query is not None: + return self.__query + return self._container.query(_get_graph(self.axes))[0] + + def _cache_query(self): + self.__query = self._container.query(_get_graph(self.axes))[0] + def contains(self, mouseevent): """ Test whether *mouseevent* occurred on the line. @@ -683,9 +749,6 @@ def recache(self, always=False): else: y = self._y - self._xy = np.column_stack(np.broadcast_arrays(x, y)).astype(float) - self._x, self._y = self._xy.T # views - self._subslice = False if (self.axes and len(x) > self._subslice_optim_min_size @@ -744,12 +807,15 @@ def set_transform(self, t): super().set_transform(t) @allow_rasterization - def draw(self, renderer): + @containerize_draw + def draw(self, renderer, *, graph=None): # docstring inherited if not self.get_visible(): return + self._cache_query() + if self._invalidy or self._invalidx: self.recache() self.ind_offset = 0 # Needed for contains() method. @@ -1297,9 +1363,11 @@ def set_xdata(self, x): set_data set_ydata """ + check_container(self, LineContainer, "'set_xdata'") if not np.iterable(x): raise RuntimeError('x must be a sequence') - self._xorig = copy.copy(x) + self._container.x = copy.copy(x) + self.__query = None self._invalidx = True self.stale = True @@ -1316,9 +1384,11 @@ def set_ydata(self, y): set_data set_xdata """ + check_container(self, LineContainer, "'set_ydata'") if not np.iterable(y): raise RuntimeError('y must be a sequence') - self._yorig = copy.copy(y) + self._container.y = copy.copy(y) + self.__query = None self._invalidy = True self.stale = True diff --git a/lib/matplotlib/meson.build b/lib/matplotlib/meson.build index c0bfdb227e2e..e799b4338ef8 100644 --- a/lib/matplotlib/meson.build +++ b/lib/matplotlib/meson.build @@ -168,3 +168,4 @@ subdir('style') subdir('testing') subdir('tests') subdir('tri') +subdir('_data_containers') diff --git a/lib/matplotlib/quiver.py b/lib/matplotlib/quiver.py index 9ffcec5117d9..43ab1e37602b 100644 --- a/lib/matplotlib/quiver.py +++ b/lib/matplotlib/quiver.py @@ -980,10 +980,11 @@ def __init__(self, ax, *args, # Make a collection barb_size = self._length ** 2 / 4 # Empirically determined super().__init__( - [], (barb_size,), offsets=xy, offset_transform=transform, **kwargs) + [], (barb_size,), offsets=None, offset_transform=transform, **kwargs) self.set_transform(transforms.IdentityTransform()) self.set_UVC(u, v, c) + self.set_offsets(xy) # Call after super init/set_UVC because it references UVC def _find_tails(self, mag, rounding=True, half=5, full=10, flag=50): """ diff --git a/lib/matplotlib/tests/test_collections.py b/lib/matplotlib/tests/test_collections.py index c062e8c12b9c..964c57055e67 100644 --- a/lib/matplotlib/tests/test_collections.py +++ b/lib/matplotlib/tests/test_collections.py @@ -430,9 +430,9 @@ def test_EllipseCollection_setter_getter(): offset_transform=ax.transData, ) - assert_array_almost_equal(ec._widths, np.array(widths).ravel() * 0.5) - assert_array_almost_equal(ec._heights, np.array(heights).ravel() * 0.5) - assert_array_almost_equal(ec._angles, np.deg2rad(angles).ravel()) + assert_array_almost_equal(ec._container.widths, np.array(widths).ravel() * 0.5) + assert_array_almost_equal(ec._container.heights, np.array(heights).ravel() * 0.5) + assert_array_almost_equal(ec._container.angles, np.deg2rad(angles).ravel()) assert_array_almost_equal(ec.get_widths(), widths) assert_array_almost_equal(ec.get_heights(), heights) @@ -837,16 +837,16 @@ def test_collection_set_verts_array(): verts = np.arange(80, dtype=np.double).reshape(10, 4, 2) col_arr = PolyCollection(verts) col_list = PolyCollection(list(verts)) - assert len(col_arr._paths) == len(col_list._paths) - for ap, lp in zip(col_arr._paths, col_list._paths): + assert len(col_arr.get_paths()) == len(col_list.get_paths()) + for ap, lp in zip(col_arr.get_paths(), col_list.get_paths()): assert np.array_equal(ap._vertices, lp._vertices) assert np.array_equal(ap._codes, lp._codes) verts_tuple = np.empty(10, dtype=object) verts_tuple[:] = [tuple(tuple(y) for y in x) for x in verts] col_arr_tuple = PolyCollection(verts_tuple) - assert len(col_arr._paths) == len(col_arr_tuple._paths) - for ap, atp in zip(col_arr._paths, col_arr_tuple._paths): + assert len(col_arr.get_paths()) == len(col_arr_tuple.get_paths()) + for ap, atp in zip(col_arr.get_paths(), col_arr_tuple.get_paths()): assert np.array_equal(ap._vertices, atp._vertices) assert np.array_equal(ap._codes, atp._codes) diff --git a/lib/matplotlib/tests/test_image.py b/lib/matplotlib/tests/test_image.py index 9b598fbf7193..bdf24126a301 100644 --- a/lib/matplotlib/tests/test_image.py +++ b/lib/matplotlib/tests/test_image.py @@ -818,7 +818,9 @@ def test_setdata_xya(image_cls, x, y, a): im = image_cls(ax) im.set_data(x, y, a) x[0] = y[0] = a[0, 0] = 9.9 - assert im._A[0, 0] == im._Ax[0] == im._Ay[0] == 0, 'value changed' + Ax = im._container.x + Ay = im._container.y + assert im._A[0, 0] == Ax[0] == Ay[0] == 0, 'value changed' im.set_data(x, y, a.reshape((*a.shape, -1))) # Just a smoketest. @@ -1697,8 +1699,8 @@ def test_axesimage_get_shape(): # generate dummy image to test get_shape method ax = plt.gca() im = AxesImage(ax) - with pytest.raises(RuntimeError, match="You must first set the image array"): - im.get_shape() + # Initial behavior is an empty 2D array + assert im.get_shape() == (1, 0) z = np.arange(12, dtype=float).reshape((4, 3)) im.set_data(z) assert im.get_shape() == (4, 3)