1010from numpy import ma
1111
1212from matplotlib import rcParams
13- from matplotlib import artist as martist
14- from matplotlib import colors as mcolors
15- from matplotlib import cm
13+ import matplotlib .artist as martist
14+ import matplotlib .colors as mcolors
15+ import matplotlib .cm as cm
16+ import matplotlib .cbook as cbook
1617
1718# For clarity, names from _image are given explicitly in this module:
18- from matplotlib import _image
19- from matplotlib import _png
19+ import matplotlib . _image as _image
20+ import matplotlib . _png as _png
2021
2122# For user convenience, the names from _image are also imported into
2223# the image namespace:
@@ -238,7 +239,8 @@ def draw(self, renderer, *args, **kwargs):
238239 clippath , affine )
239240
240241 def contains (self , mouseevent ):
241- """Test whether the mouse event occured within the image.
242+ """
243+ Test whether the mouse event occured within the image.
242244 """
243245 if callable (self ._contains ): return self ._contains (self ,mouseevent )
244246 # TODO: make sure this is consistent with patch and patch
@@ -271,18 +273,17 @@ def write_png(self, fname, noscale=False):
271273 rows , cols , buffer = im .as_rgba_str ()
272274 _png .write_png (buffer , cols , rows , fname )
273275
274- def set_data (self , A , shape = None ):
276+ def set_data (self , A ):
275277 """
276278 Set the image array
277279
278- ACCEPTS: numpy/PIL Image A"""
280+ ACCEPTS: numpy/PIL Image A
281+ """
279282 # check if data is PIL Image without importing Image
280283 if hasattr (A ,'getpixel' ):
281284 self ._A = pil_to_array (A )
282- elif ma .isMA (A ):
283- self ._A = A
284285 else :
285- self ._A = np . asarray (A ) # assume array
286+ self ._A = cbook . safe_masked_invalid (A )
286287
287288 if self ._A .dtype != np .uint8 and not np .can_cast (self ._A .dtype , np .float ):
288289 raise TypeError ("Image data can not convert to float" )
@@ -310,7 +311,8 @@ def set_array(self, A):
310311
311312
312313 def set_extent (self , extent ):
313- """extent is data axes (left, right, bottom, top) for making image plots
314+ """
315+ extent is data axes (left, right, bottom, top) for making image plots
314316 """
315317 self ._extent = extent
316318
@@ -375,7 +377,8 @@ def get_extent(self):
375377 return (- 0.5 , numcols - 0.5 , - 0.5 , numrows - 0.5 )
376378
377379 def set_filternorm (self , filternorm ):
378- """Set whether the resize filter norms the weights -- see
380+ """
381+ Set whether the resize filter norms the weights -- see
379382 help for imshow
380383
381384 ACCEPTS: 0 or 1
@@ -390,7 +393,8 @@ def get_filternorm(self):
390393 return self ._filternorm
391394
392395 def set_filterrad (self , filterrad ):
393- """Set the resize filter radius only applicable to some
396+ """
397+ Set the resize filter radius only applicable to some
394398 interpolation schemes -- see help for imshow
395399
396400 ACCEPTS: positive float
@@ -405,9 +409,11 @@ def get_filterrad(self):
405409
406410
407411class NonUniformImage (AxesImage ):
408- def __init__ (self , ax ,
409- ** kwargs
410- ):
412+ def __init__ (self , ax , ** kwargs ):
413+ """
414+ kwargs are identical to those for AxesImage, except
415+ that 'interpolation' defaults to 'nearest'
416+ """
411417 interp = kwargs .pop ('interpolation' , 'nearest' )
412418 AxesImage .__init__ (self , ax ,
413419 ** kwargs )
@@ -434,10 +440,19 @@ def make_image(self, magnification=1.0):
434440 return im
435441
436442 def set_data (self , x , y , A ):
443+ """
444+ Set the grid for the pixel centers, and the pixel values.
445+
446+ *x* and *y* are 1-D ndarrays of lengths N and M, respectively,
447+ specifying pixel centers
448+
449+ *A* is an (M,N) ndarray or masked array of values to be
450+ colormapped, or a (M,N,3) RGB array, or a (M,N,4) RGBA
451+ array.
452+ """
437453 x = np .asarray (x ,np .float32 )
438454 y = np .asarray (y ,np .float32 )
439- if not ma .isMA (A ):
440- A = np .asarray (A )
455+ A = cbook .safe_masked_invalid (A )
441456 if len (x .shape ) != 1 or len (y .shape ) != 1 \
442457 or A .shape [0 :2 ] != (y .shape [0 ], x .shape [0 ]):
443458 raise TypeError ("Axes don't match array shape" )
@@ -567,8 +582,7 @@ def draw(self, renderer, *args, **kwargs):
567582
568583
569584 def set_data (self , x , y , A ):
570- if not ma .isMA (A ):
571- A = np .asarray (A )
585+ A = cbook .safe_masked_invalid (A )
572586 if x is None :
573587 x = np .arange (0 , A .shape [1 ]+ 1 , dtype = np .float64 )
574588 else :
@@ -666,6 +680,19 @@ def get_extent(self):
666680 return (- 0.5 + self .ox , numcols - 0.5 + self .ox ,
667681 - 0.5 + self .oy , numrows - 0.5 + self .oy )
668682
683+ def set_data (self , A ):
684+ """
685+ Set the image array
686+
687+ """
688+ cm .ScalarMappable .set_array (self , cbook .safe_masked_invalid (A ))
689+
690+ def set_array (self , A ):
691+ """
692+ Deprecated; use set_data for consistency with other image types.
693+ """
694+ self .set_data (A )
695+
669696 def make_image (self , magnification = 1.0 ):
670697 if self ._A is None :
671698 raise RuntimeError ('You must first set the image array' )
0 commit comments