@@ -688,50 +688,50 @@ def write_png(self, fname):
688
688
bytes = True , norm = True )
689
689
PIL .Image .fromarray (im ).save (fname , format = "png" )
690
690
691
- def set_data (self , A ):
691
+ @staticmethod
692
+ def _normalize_image_array (A ):
692
693
"""
693
- Set the image array.
694
-
695
- Note that this function does *not* update the normalization used.
696
-
697
- Parameters
698
- ----------
699
- A : array-like or `PIL.Image.Image`
694
+ Check validity of image-like input *A* and normalize it to a format suitable for
695
+ Image subclasses.
700
696
"""
701
- if isinstance (A , PIL .Image .Image ):
702
- A = pil_to_array (A ) # Needed e.g. to apply png palette.
703
- self ._A = cbook .safe_masked_invalid (A , copy = True )
704
-
705
- if (self ._A .dtype != np .uint8 and
706
- not np .can_cast (self ._A .dtype , float , "same_kind" )):
707
- raise TypeError (f"Image data of dtype { self ._A .dtype } cannot be "
708
- "converted to float" )
709
-
710
- if self ._A .ndim == 3 and self ._A .shape [- 1 ] == 1 :
711
- # If just one dimension assume scalar and apply colormap
712
- self ._A = self ._A [:, :, 0 ]
713
-
714
- if not (self ._A .ndim == 2
715
- or self ._A .ndim == 3 and self ._A .shape [- 1 ] in [3 , 4 ]):
716
- raise TypeError (f"Invalid shape { self ._A .shape } for image data" )
717
-
718
- if self ._A .ndim == 3 :
697
+ A = cbook .safe_masked_invalid (A , copy = True )
698
+ if A .dtype != np .uint8 and not np .can_cast (A .dtype , float , "same_kind" ):
699
+ raise TypeError (f"Image data of dtype { A .dtype } cannot be "
700
+ f"converted to float" )
701
+ if A .ndim == 3 and A .shape [- 1 ] == 1 :
702
+ A = A .squeeze (- 1 ) # If just (M, N, 1), assume scalar and apply colormap.
703
+ if not (A .ndim == 2 or A .ndim == 3 and A .shape [- 1 ] in [3 , 4 ]):
704
+ raise TypeError (f"Invalid shape { A .shape } for image data" )
705
+ if A .ndim == 3 :
719
706
# If the input data has values outside the valid range (after
720
707
# normalisation), we issue a warning and then clip X to the bounds
721
708
# - otherwise casting wraps extreme values, hiding outliers and
722
709
# making reliable interpretation impossible.
723
- high = 255 if np .issubdtype (self . _A .dtype , np .integer ) else 1
724
- if self . _A . min () < 0 or high < self . _A .max ():
710
+ high = 255 if np .issubdtype (A .dtype , np .integer ) else 1
711
+ if A . min () < 0 or high < A .max ():
725
712
_log .warning (
726
713
'Clipping input data to the valid range for imshow with '
727
714
'RGB data ([0..1] for floats or [0..255] for integers).'
728
715
)
729
- self . _A = np .clip (self . _A , 0 , high )
716
+ A = np .clip (A , 0 , high )
730
717
# Cast unsupported integer types to uint8
731
- if self . _A . dtype != np .uint8 and np .issubdtype (self . _A . dtype ,
732
- np .integer ):
733
- self . _A = self . _A . astype ( np . uint8 )
718
+ if A . dtype != np .uint8 and np .issubdtype (A . dtype , np . integer ):
719
+ A = A . astype ( np .uint8 )
720
+ return A
734
721
722
+ def set_data (self , A ):
723
+ """
724
+ Set the image array.
725
+
726
+ Note that this function does *not* update the normalization used.
727
+
728
+ Parameters
729
+ ----------
730
+ A : array-like or `PIL.Image.Image`
731
+ """
732
+ if isinstance (A , PIL .Image .Image ):
733
+ A = pil_to_array (A ) # Needed e.g. to apply png palette.
734
+ self ._A = self ._normalize_image_array (A )
735
735
self ._imcache = None
736
736
self .stale = True
737
737
@@ -1149,23 +1149,15 @@ def set_data(self, x, y, A):
1149
1149
(M, N) `~numpy.ndarray` or masked array of values to be
1150
1150
colormapped, or (M, N, 3) RGB array, or (M, N, 4) RGBA array.
1151
1151
"""
1152
+ A = self ._normalize_image_array (A )
1152
1153
x = np .array (x , np .float32 )
1153
1154
y = np .array (y , np .float32 )
1154
- A = cbook .safe_masked_invalid (A , copy = True )
1155
- if not (x .ndim == y .ndim == 1 and A .shape [0 :2 ] == y .shape + x .shape ):
1155
+ if not (x .ndim == y .ndim == 1 and A .shape [:2 ] == y .shape + x .shape ):
1156
1156
raise TypeError ("Axes don't match array shape" )
1157
- if A .ndim not in [2 , 3 ]:
1158
- raise TypeError ("Can only plot 2D or 3D data" )
1159
- if A .ndim == 3 and A .shape [2 ] not in [1 , 3 , 4 ]:
1160
- raise TypeError ("3D arrays must have three (RGB) "
1161
- "or four (RGBA) color components" )
1162
- if A .ndim == 3 and A .shape [2 ] == 1 :
1163
- A = A .squeeze (axis = - 1 )
1164
1157
self ._A = A
1165
1158
self ._Ax = x
1166
1159
self ._Ay = y
1167
1160
self ._imcache = None
1168
-
1169
1161
self .stale = True
1170
1162
1171
1163
def set_array (self , * args ):
@@ -1307,36 +1299,20 @@ def set_data(self, x, y, A):
1307
1299
- (M, N, 3): RGB array
1308
1300
- (M, N, 4): RGBA array
1309
1301
"""
1310
- A = cbook .safe_masked_invalid (A , copy = True )
1311
- if x is None :
1312
- x = np .arange (0 , A .shape [1 ]+ 1 , dtype = np .float64 )
1313
- else :
1314
- x = np .array (x , np .float64 ).ravel ()
1315
- if y is None :
1316
- y = np .arange (0 , A .shape [0 ]+ 1 , dtype = np .float64 )
1317
- else :
1318
- y = np .array (y , np .float64 ).ravel ()
1319
-
1320
- if A .shape [:2 ] != (y .size - 1 , x .size - 1 ):
1302
+ A = self ._normalize_image_array (A )
1303
+ x = np .arange (0. , A .shape [1 ] + 1 ) if x is None else np .array (x , float ).ravel ()
1304
+ y = np .arange (0. , A .shape [0 ] + 1 ) if y is None else np .array (y , float ).ravel ()
1305
+ if A .shape [:2 ] != (y .size - 1 , x .size - 1 ):
1321
1306
raise ValueError (
1322
1307
"Axes don't match array shape. Got %s, expected %s." %
1323
1308
(A .shape [:2 ], (y .size - 1 , x .size - 1 )))
1324
- if A .ndim not in [2 , 3 ]:
1325
- raise ValueError ("A must be 2D or 3D" )
1326
- if A .ndim == 3 :
1327
- if A .shape [2 ] == 1 :
1328
- A = A .squeeze (axis = - 1 )
1329
- elif A .shape [2 ] not in [3 , 4 ]:
1330
- raise ValueError ("3D arrays must have RGB or RGBA as last dim" )
1331
-
1332
1309
# For efficient cursor readout, ensure x and y are increasing.
1333
1310
if x [- 1 ] < x [0 ]:
1334
1311
x = x [::- 1 ]
1335
1312
A = A [:, ::- 1 ]
1336
1313
if y [- 1 ] < y [0 ]:
1337
1314
y = y [::- 1 ]
1338
1315
A = A [::- 1 ]
1339
-
1340
1316
self ._A = A
1341
1317
self ._Ax = x
1342
1318
self ._Ay = y
0 commit comments