@@ -688,6 +688,34 @@ def write_png(self, fname):
688
688
bytes = True , norm = True )
689
689
PIL .Image .fromarray (im ).save (fname , format = "png" )
690
690
691
+ @staticmethod
692
+ def _normalize_image_array (A ):
693
+ # Common checks and typecasts on A for various Image subclasses.
694
+ A = cbook .safe_masked_invalid (A , copy = True )
695
+ if A .dtype != np .uint8 and not np .can_cast (A .dtype , float , "same_kind" ):
696
+ raise TypeError (f"Image data of dtype { A .dtype } cannot be "
697
+ f"converted to float" )
698
+ if A .ndim == 3 and A .shape [- 1 ] == 1 :
699
+ A = A .squeeze (- 1 ) # If just (M, N, 1), assume scalar and apply colormap.
700
+ if not (A .ndim == 2 or A .ndim == 3 and A .shape [- 1 ] in [3 , 4 ]):
701
+ raise TypeError (f"Invalid shape { A .shape } for image data" )
702
+ if A .ndim == 3 :
703
+ # If the input data has values outside the valid range (after
704
+ # normalisation), we issue a warning and then clip X to the bounds
705
+ # - otherwise casting wraps extreme values, hiding outliers and
706
+ # making reliable interpretation impossible.
707
+ high = 255 if np .issubdtype (A .dtype , np .integer ) else 1
708
+ if A .min () < 0 or high < A .max ():
709
+ _log .warning (
710
+ 'Clipping input data to the valid range for imshow with '
711
+ 'RGB data ([0..1] for floats or [0..255] for integers).'
712
+ )
713
+ A = np .clip (A , 0 , high )
714
+ # Cast unsupported integer types to uint8
715
+ if A .dtype != np .uint8 and np .issubdtype (A .dtype , np .integer ):
716
+ A = A .astype (np .uint8 )
717
+ return A
718
+
691
719
def set_data (self , A ):
692
720
"""
693
721
Set the image array.
@@ -700,38 +728,7 @@ def set_data(self, A):
700
728
"""
701
729
if isinstance (A , PIL .Image .Image ):
702
730
A = pil_to_array (A ) # Needed e.g. to apply png palette.
703
- self ._A = cbook .safe_masked_invalid (A , copy = True )
704
-
705
- if (self ._A .dtype != np .uint8 and
706
- not np .can_cast (self ._A .dtype , float , "same_kind" )):
707
- raise TypeError (f"Image data of dtype { self ._A .dtype } cannot be "
708
- "converted to float" )
709
-
710
- if self ._A .ndim == 3 and self ._A .shape [- 1 ] == 1 :
711
- # If just one dimension assume scalar and apply colormap
712
- self ._A = self ._A [:, :, 0 ]
713
-
714
- if not (self ._A .ndim == 2
715
- or self ._A .ndim == 3 and self ._A .shape [- 1 ] in [3 , 4 ]):
716
- raise TypeError (f"Invalid shape { self ._A .shape } for image data" )
717
-
718
- if self ._A .ndim == 3 :
719
- # If the input data has values outside the valid range (after
720
- # normalisation), we issue a warning and then clip X to the bounds
721
- # - otherwise casting wraps extreme values, hiding outliers and
722
- # making reliable interpretation impossible.
723
- high = 255 if np .issubdtype (self ._A .dtype , np .integer ) else 1
724
- if self ._A .min () < 0 or high < self ._A .max ():
725
- _log .warning (
726
- 'Clipping input data to the valid range for imshow with '
727
- 'RGB data ([0..1] for floats or [0..255] for integers).'
728
- )
729
- self ._A = np .clip (self ._A , 0 , high )
730
- # Cast unsupported integer types to uint8
731
- if self ._A .dtype != np .uint8 and np .issubdtype (self ._A .dtype ,
732
- np .integer ):
733
- self ._A = self ._A .astype (np .uint8 )
734
-
731
+ self ._A = self ._normalize_image_array (A )
735
732
self ._imcache = None
736
733
self .stale = True
737
734
@@ -1149,23 +1146,15 @@ def set_data(self, x, y, A):
1149
1146
(M, N) `~numpy.ndarray` or masked array of values to be
1150
1147
colormapped, or (M, N, 3) RGB array, or (M, N, 4) RGBA array.
1151
1148
"""
1149
+ A = self ._normalize_image_array (A )
1152
1150
x = np .array (x , np .float32 )
1153
1151
y = np .array (y , np .float32 )
1154
- A = cbook .safe_masked_invalid (A , copy = True )
1155
- if not (x .ndim == y .ndim == 1 and A .shape [0 :2 ] == y .shape + x .shape ):
1152
+ if not (x .ndim == y .ndim == 1 and A .shape [:2 ] == y .shape + x .shape ):
1156
1153
raise TypeError ("Axes don't match array shape" )
1157
- if A .ndim not in [2 , 3 ]:
1158
- raise TypeError ("Can only plot 2D or 3D data" )
1159
- if A .ndim == 3 and A .shape [2 ] not in [1 , 3 , 4 ]:
1160
- raise TypeError ("3D arrays must have three (RGB) "
1161
- "or four (RGBA) color components" )
1162
- if A .ndim == 3 and A .shape [2 ] == 1 :
1163
- A = A .squeeze (axis = - 1 )
1164
1154
self ._A = A
1165
1155
self ._Ax = x
1166
1156
self ._Ay = y
1167
1157
self ._imcache = None
1168
-
1169
1158
self .stale = True
1170
1159
1171
1160
def set_array (self , * args ):
@@ -1307,36 +1296,20 @@ def set_data(self, x, y, A):
1307
1296
- (M, N, 3): RGB array
1308
1297
- (M, N, 4): RGBA array
1309
1298
"""
1310
- A = cbook .safe_masked_invalid (A , copy = True )
1311
- if x is None :
1312
- x = np .arange (0 , A .shape [1 ]+ 1 , dtype = np .float64 )
1313
- else :
1314
- x = np .array (x , np .float64 ).ravel ()
1315
- if y is None :
1316
- y = np .arange (0 , A .shape [0 ]+ 1 , dtype = np .float64 )
1317
- else :
1318
- y = np .array (y , np .float64 ).ravel ()
1319
-
1320
- if A .shape [:2 ] != (y .size - 1 , x .size - 1 ):
1299
+ A = self ._normalize_image_array (A )
1300
+ x = np .arange (0. , A .shape [1 ] + 1 ) if x is None else np .array (x , float ).ravel ()
1301
+ y = np .arange (0. , A .shape [0 ] + 1 ) if y is None else np .array (y , float ).ravel ()
1302
+ if A .shape [:2 ] != (y .size - 1 , x .size - 1 ):
1321
1303
raise ValueError (
1322
1304
"Axes don't match array shape. Got %s, expected %s." %
1323
1305
(A .shape [:2 ], (y .size - 1 , x .size - 1 )))
1324
- if A .ndim not in [2 , 3 ]:
1325
- raise ValueError ("A must be 2D or 3D" )
1326
- if A .ndim == 3 :
1327
- if A .shape [2 ] == 1 :
1328
- A = A .squeeze (axis = - 1 )
1329
- elif A .shape [2 ] not in [3 , 4 ]:
1330
- raise ValueError ("3D arrays must have RGB or RGBA as last dim" )
1331
-
1332
1306
# For efficient cursor readout, ensure x and y are increasing.
1333
1307
if x [- 1 ] < x [0 ]:
1334
1308
x = x [::- 1 ]
1335
1309
A = A [:, ::- 1 ]
1336
1310
if y [- 1 ] < y [0 ]:
1337
1311
y = y [::- 1 ]
1338
1312
A = A [::- 1 ]
1339
-
1340
1313
self ._A = A
1341
1314
self ._Ax = x
1342
1315
self ._Ay = y
0 commit comments