@@ -688,6 +688,34 @@ def write_png(self, fname):
688688 bytes = True , norm = True )
689689 PIL .Image .fromarray (im ).save (fname , format = "png" )
690690
691+ @staticmethod
692+ def _normalize_image_array (A ):
693+ # Common checks and typecasts on A for various Image subclasses.
694+ A = cbook .safe_masked_invalid (A , copy = True )
695+ if A .dtype != np .uint8 and not np .can_cast (A .dtype , float , "same_kind" ):
696+ raise TypeError (f"Image data of dtype { A .dtype } cannot be "
697+ f"converted to float" )
698+ if A .ndim == 3 and A .shape [- 1 ] == 1 :
699+ A = A .squeeze (- 1 ) # If just (M, N, 1), assume scalar and apply colormap.
700+ if not (A .ndim == 2 or A .ndim == 3 and A .shape [- 1 ] in [3 , 4 ]):
701+ raise TypeError (f"Invalid shape { A .shape } for image data" )
702+ if A .ndim == 3 :
703+ # If the input data has values outside the valid range (after
704+ # normalisation), we issue a warning and then clip X to the bounds
705+ # - otherwise casting wraps extreme values, hiding outliers and
706+ # making reliable interpretation impossible.
707+ high = 255 if np .issubdtype (A .dtype , np .integer ) else 1
708+ if A .min () < 0 or high < A .max ():
709+ _log .warning (
710+ 'Clipping input data to the valid range for imshow with '
711+ 'RGB data ([0..1] for floats or [0..255] for integers).'
712+ )
713+ A = np .clip (A , 0 , high )
714+ # Cast unsupported integer types to uint8
715+ if A .dtype != np .uint8 and np .issubdtype (A .dtype , np .integer ):
716+ A = A .astype (np .uint8 )
717+ return A
718+
691719 def set_data (self , A ):
692720 """
693721 Set the image array.
@@ -700,38 +728,7 @@ def set_data(self, A):
700728 """
701729 if isinstance (A , PIL .Image .Image ):
702730 A = pil_to_array (A ) # Needed e.g. to apply png palette.
703- self ._A = cbook .safe_masked_invalid (A , copy = True )
704-
705- if (self ._A .dtype != np .uint8 and
706- not np .can_cast (self ._A .dtype , float , "same_kind" )):
707- raise TypeError (f"Image data of dtype { self ._A .dtype } cannot be "
708- "converted to float" )
709-
710- if self ._A .ndim == 3 and self ._A .shape [- 1 ] == 1 :
711- # If just one dimension assume scalar and apply colormap
712- self ._A = self ._A [:, :, 0 ]
713-
714- if not (self ._A .ndim == 2
715- or self ._A .ndim == 3 and self ._A .shape [- 1 ] in [3 , 4 ]):
716- raise TypeError (f"Invalid shape { self ._A .shape } for image data" )
717-
718- if self ._A .ndim == 3 :
719- # If the input data has values outside the valid range (after
720- # normalisation), we issue a warning and then clip X to the bounds
721- # - otherwise casting wraps extreme values, hiding outliers and
722- # making reliable interpretation impossible.
723- high = 255 if np .issubdtype (self ._A .dtype , np .integer ) else 1
724- if self ._A .min () < 0 or high < self ._A .max ():
725- _log .warning (
726- 'Clipping input data to the valid range for imshow with '
727- 'RGB data ([0..1] for floats or [0..255] for integers).'
728- )
729- self ._A = np .clip (self ._A , 0 , high )
730- # Cast unsupported integer types to uint8
731- if self ._A .dtype != np .uint8 and np .issubdtype (self ._A .dtype ,
732- np .integer ):
733- self ._A = self ._A .astype (np .uint8 )
734-
731+ self ._A = self ._normalize_image_array (A )
735732 self ._imcache = None
736733 self .stale = True
737734
@@ -1149,23 +1146,15 @@ def set_data(self, x, y, A):
11491146 (M, N) `~numpy.ndarray` or masked array of values to be
11501147 colormapped, or (M, N, 3) RGB array, or (M, N, 4) RGBA array.
11511148 """
1149+ A = self ._normalize_image_array (A )
11521150 x = np .array (x , np .float32 )
11531151 y = np .array (y , np .float32 )
1154- A = cbook .safe_masked_invalid (A , copy = True )
1155- if not (x .ndim == y .ndim == 1 and A .shape [0 :2 ] == y .shape + x .shape ):
1152+ if not (x .ndim == y .ndim == 1 and A .shape [:2 ] == y .shape + x .shape ):
11561153 raise TypeError ("Axes don't match array shape" )
1157- if A .ndim not in [2 , 3 ]:
1158- raise TypeError ("Can only plot 2D or 3D data" )
1159- if A .ndim == 3 and A .shape [2 ] not in [1 , 3 , 4 ]:
1160- raise TypeError ("3D arrays must have three (RGB) "
1161- "or four (RGBA) color components" )
1162- if A .ndim == 3 and A .shape [2 ] == 1 :
1163- A = A .squeeze (axis = - 1 )
11641154 self ._A = A
11651155 self ._Ax = x
11661156 self ._Ay = y
11671157 self ._imcache = None
1168-
11691158 self .stale = True
11701159
11711160 def set_array (self , * args ):
@@ -1307,36 +1296,20 @@ def set_data(self, x, y, A):
13071296 - (M, N, 3): RGB array
13081297 - (M, N, 4): RGBA array
13091298 """
1310- A = cbook .safe_masked_invalid (A , copy = True )
1311- if x is None :
1312- x = np .arange (0 , A .shape [1 ]+ 1 , dtype = np .float64 )
1313- else :
1314- x = np .array (x , np .float64 ).ravel ()
1315- if y is None :
1316- y = np .arange (0 , A .shape [0 ]+ 1 , dtype = np .float64 )
1317- else :
1318- y = np .array (y , np .float64 ).ravel ()
1319-
1320- if A .shape [:2 ] != (y .size - 1 , x .size - 1 ):
1299+ A = self ._normalize_image_array (A )
1300+ x = np .arange (0. , A .shape [1 ] + 1 ) if x is None else np .array (x , float ).ravel ()
1301+ y = np .arange (0. , A .shape [0 ] + 1 ) if y is None else np .array (y , float ).ravel ()
1302+ if A .shape [:2 ] != (y .size - 1 , x .size - 1 ):
13211303 raise ValueError (
13221304 "Axes don't match array shape. Got %s, expected %s." %
13231305 (A .shape [:2 ], (y .size - 1 , x .size - 1 )))
1324- if A .ndim not in [2 , 3 ]:
1325- raise ValueError ("A must be 2D or 3D" )
1326- if A .ndim == 3 :
1327- if A .shape [2 ] == 1 :
1328- A = A .squeeze (axis = - 1 )
1329- elif A .shape [2 ] not in [3 , 4 ]:
1330- raise ValueError ("3D arrays must have RGB or RGBA as last dim" )
1331-
13321306 # For efficient cursor readout, ensure x and y are increasing.
13331307 if x [- 1 ] < x [0 ]:
13341308 x = x [::- 1 ]
13351309 A = A [:, ::- 1 ]
13361310 if y [- 1 ] < y [0 ]:
13371311 y = y [::- 1 ]
13381312 A = A [::- 1 ]
1339-
13401313 self ._A = A
13411314 self ._Ax = x
13421315 self ._Ay = y
0 commit comments