66import math
77import os
88import logging
9+ from dataclasses import dataclass
910from pathlib import Path
1011import warnings
1112
1213import numpy as np
1314import PIL .Image
1415import PIL .PngImagePlugin
1516
17+ from mpl_data_containers .description import Desc
18+ from mpl_data_containers .conversion_edge import Graph
19+
1620import matplotlib as mpl
1721from matplotlib import _api , cbook
1822# For clarity, names from _image are given explicitly in this module
2832 Affine2D , BboxBase , Bbox , BboxTransform , BboxTransformTo ,
2933 IdentityTransform , TransformedBbox )
3034
35+
3136_log = logging .getLogger (__name__ )
3237
3338# map interpolation strings to module constants
@@ -230,6 +235,31 @@ def _rgb_to_rgba(A):
230235 return rgba
231236
232237
238+ @dataclass
239+ class ImageContainer :
240+ x : np .ndarray
241+ y : np .ndarray
242+ image : np .ndarray
243+
244+ def describe (self ):
245+ imshape = list (self .image .shape )
246+ imshape [:2 ] = ("M" , "N" )
247+
248+ return {
249+ "x" : Desc ((2 ,), "data" ),
250+ "y" : Desc ((2 ,), "data" ),
251+ "image" : Desc (tuple (imshape ), "data" ),
252+ }
253+
254+ def query (self , graph , parent_coordinates = "axes" ):
255+ return {
256+ "x" : self .x ,
257+ "y" : self .y ,
258+ "image" : self .image ,
259+ }, ""
260+ # TODO hash
261+
262+
233263class _ImageBase (mcolorizer .ColorizingArtist ):
234264 """
235265 Base class for images.
@@ -272,10 +302,38 @@ def __init__(self, ax,
272302 self .set_resample (resample )
273303 self .axes = ax
274304
305+ self ._container = ImageContainer (
306+ np .array ([0. ,1. ]),
307+ np .array ([0. ,1. ]),
308+ np .array ([[]]),
309+ )
275310 self ._imcache = None
276311
277312 self ._internal_update (kwargs )
278313
314+ @property
315+ def _image_array (self ):
316+ return self ._container .query (self ._get_graph ())[0 ]["image" ]
317+
318+ @property
319+ def _A (self ):
320+ return self ._image_array
321+
322+ @_A .setter
323+ def _A (self , val ):
324+ return
325+
326+ def set_container (self , container ):
327+ self ._container = container
328+ self .stale = True
329+
330+ def get_container (self ):
331+ return self ._container
332+
333+ def _get_graph (self ):
334+ # TODO actually fill out graph
335+ return Graph ([])
336+
279337 def __str__ (self ):
280338 try :
281339 shape = self .get_shape ()
@@ -295,10 +353,7 @@ def get_shape(self):
295353 """
296354 Return the shape of the image as tuple (numrows, numcols, channels).
297355 """
298- if self ._A is None :
299- raise RuntimeError ('You must first set the image array' )
300-
301- return self ._A .shape
356+ return self ._image_array .shape
302357
303358 def set_alpha (self , alpha ):
304359 """
@@ -388,6 +443,8 @@ def _make_image(self, A, in_bbox, out_bbox, clip_bbox, magnification=1.0,
388443 "Your Artist's draw method must filter before "
389444 "this method is called." )
390445
446+ A = np .ma .asanyarray (A )
447+
391448 clipped_bbox = Bbox .intersection (out_bbox , clip_bbox )
392449
393450 if clipped_bbox is None :
@@ -688,12 +745,18 @@ def set_data(self, A):
688745 ----------
689746 A : array-like or `PIL.Image.Image`
690747 """
748+ if not isinstance (self ._container , ImageContainer ):
749+ raise TypeError ("Cannot use 'set_data' on custom container types" )
691750 if isinstance (A , PIL .Image .Image ):
692751 A = pil_to_array (A ) # Needed e.g. to apply png palette.
693- self ._A = self ._normalize_image_array (A )
752+ # self._A = self._normalize_image_array(A)
753+ self ._container .image = self ._normalize_image_array (A )
694754 self ._imcache = None
695755 self .stale = True
696756
757+ def get_array (self ):
758+ return self ._image_array
759+
697760 def set_array (self , A ):
698761 """
699762 Retained for backwards compatibility - use set_data instead.
@@ -874,6 +937,7 @@ class AxesImage(_ImageBase):
874937
875938 def __init__ (self , ax ,
876939 * ,
940+ A = None ,
877941 cmap = None ,
878942 norm = None ,
879943 colorizer = None ,
@@ -887,8 +951,6 @@ def __init__(self, ax,
887951 ** kwargs
888952 ):
889953
890- self ._extent = extent
891-
892954 super ().__init__ (
893955 ax ,
894956 cmap = cmap ,
@@ -903,21 +965,31 @@ def __init__(self, ax,
903965 ** kwargs
904966 )
905967
968+ if A is not None :
969+ self .set_data (A )
970+ self .set_extent (extent )
971+ elif extent is not None :
972+ self .set_extent (extent )
973+
906974 def get_window_extent (self , renderer = None ):
907- x0 , x1 , y0 , y1 = self ._extent
975+ x0 , x1 , y0 , y1 = self .get_extent ()
908976 bbox = Bbox .from_extents ([x0 , y0 , x1 , y1 ])
909977 return bbox .transformed (self .get_transform ())
910978
911979 def make_image (self , renderer , magnification = 1.0 , unsampled = False ):
980+ q , _ = self ._container .query (self ._get_graph ())
981+ x1 , x2 = q ["x" ]
982+ y1 , y2 = q ["y" ]
983+
984+ A = q ["image" ]
985+
912986 # docstring inherited
913987 trans = self .get_transform ()
914- # image is created in the canvas coordinate.
915- x1 , x2 , y1 , y2 = self .get_extent ()
916988 bbox = Bbox (np .array ([[x1 , y1 ], [x2 , y2 ]]))
917989 transformed_bbox = TransformedBbox (bbox , trans )
918990 clip = ((self .get_clip_box () or self .axes .bbox ) if self .get_clip_on ()
919991 else self .get_figure (root = True ).bbox )
920- return self ._make_image (self . _A , bbox , transformed_bbox , clip ,
992+ return self ._make_image (A , bbox , transformed_bbox , clip ,
921993 magnification , unsampled = unsampled )
922994
923995 def _check_unsampled_image (self ):
@@ -945,6 +1017,17 @@ def set_extent(self, extent, **kwargs):
9451017 state is not changed, so a subsequent call to `.Axes.autoscale_view`
9461018 will redo the autoscaling in accord with `~.Axes.dataLim`.
9471019 """
1020+ if not isinstance (self ._container , ImageContainer ):
1021+ raise TypeError ("Cannot use 'set_data' on custom container types" )
1022+
1023+ if extent is None :
1024+ sz = self .get_size ()
1025+ numrows , numcols = sz
1026+ if self .origin == 'upper' :
1027+ extent = (- 0.5 , numcols - 0.5 , numrows - 0.5 , - 0.5 )
1028+ else :
1029+ extent = (- 0.5 , numcols - 0.5 , - 0.5 , numrows - 0.5 )
1030+
9481031 (xmin , xmax ), (ymin , ymax ) = self .axes ._process_unit_info (
9491032 [("x" , [extent [0 ], extent [1 ]]),
9501033 ("y" , [extent [2 ], extent [3 ]])],
@@ -961,7 +1044,15 @@ def set_extent(self, extent, **kwargs):
9611044 ymax , self .convert_yunits )
9621045 extent = [xmin , xmax , ymin , ymax ]
9631046
964- self ._extent = extent
1047+ self ._container .x [:] = extent [:2 ]
1048+ self ._container .y [:] = extent [2 :]
1049+ self ._update_autolims (xmin , xmax , ymin , ymax )
1050+
1051+ def set_container (self , container ):
1052+ super ().set_container (container )
1053+ self ._update_autolims (* self .get_extent ())
1054+
1055+ def _update_autolims (self , xmin , xmax , ymin , ymax ):
9651056 corners = (xmin , ymin ), (xmax , ymax )
9661057 self .axes .update_datalim (corners )
9671058 self .sticky_edges .x [:] = [xmin , xmax ]
@@ -974,15 +1065,10 @@ def set_extent(self, extent, **kwargs):
9741065
9751066 def get_extent (self ):
9761067 """Return the image extent as tuple (left, right, bottom, top)."""
977- if self ._extent is not None :
978- return self ._extent
979- else :
980- sz = self .get_size ()
981- numrows , numcols = sz
982- if self .origin == 'upper' :
983- return (- 0.5 , numcols - 0.5 , numrows - 0.5 , - 0.5 )
984- else :
985- return (- 0.5 , numcols - 0.5 , - 0.5 , numrows - 0.5 )
1068+ q , _ = self ._container .query (self ._get_graph ())
1069+ x = q ["x" ]
1070+ y = q ["y" ]
1071+ return x [0 ], x [- 1 ], y [0 ], y [- 1 ]
9861072
9871073 def get_cursor_data (self , event ):
9881074 """
@@ -1033,6 +1119,10 @@ def __init__(self, ax, *, interpolation='nearest', **kwargs):
10331119 **kwargs
10341120 All other keyword arguments are identical to those of `.AxesImage`.
10351121 """
1122+ if "A" in kwargs :
1123+ raise RuntimeError (
1124+ "'NonUniformImage' does not support setting array in init"
1125+ )
10361126 super ().__init__ (ax , ** kwargs )
10371127 self .set_interpolation (interpolation )
10381128
0 commit comments