@@ -356,24 +356,28 @@ def _make_image(self, A, in_bbox, out_bbox, clip_bbox, magnification=1.0,
356356 if not unsampled :
357357 created_rgba_mask = False
358358
359+ if A .ndim not in (2 , 3 ):
360+ raise ValueError ("Invalid dimensions, got %s" % (A .shape ,))
361+
359362 if A .ndim == 2 :
360363 A = self .norm (A )
361- # If the image is greyscale, convert to RGBA with the
362- # correct alpha channel for resizing
363- rgba = np .empty ((A .shape [0 ], A .shape [1 ], 4 ), dtype = A .dtype )
364- rgba [..., 0 ] = A # normalized data
365- rgba [..., 1 ] = A < 0 # under data
366- # TODO, ask the norm or colormap what this threshold should be
367- rgba [..., 2 ] = A > 1 # over data
368364 if A .dtype .kind == 'f' :
365+ # If the image is greyscale, convert to RGBA with the
366+ # correct alpha channel for resizing
367+ rgba = np .empty ((A .shape [0 ], A .shape [1 ], 4 ), dtype = A .dtype )
368+ rgba [..., 0 ] = A # normalized data
369+ rgba [..., 1 ] = A < 0 # under data
370+ # TODO, ask the norm or colormap what this threshold should be
371+ rgba [..., 2 ] = A > 1 # over data
369372 rgba [..., 3 ] = ~ A .mask
373+ A = rgba
374+ output = np .zeros ((out_height , out_width , 4 ), dtype = A .dtype )
375+ alpha = 1.0
376+ created_rgba_mask = True
370377 else :
371- rgba [..., 3 ] = np .where (A .mask , 0 , np .iinfo (A .dtype ).max )
372- A = rgba
373- output = np .zeros ((out_height , out_width , 4 ), dtype = A .dtype )
374- alpha = 1.0
375- created_rgba_mask = True
376- elif A .ndim == 3 :
378+ A = self .cmap (A , alpha = self .get_alpha (), bytes = True )
379+
380+ if not created_rgba_mask :
377381 # Always convert to RGBA, even if only RGB input
378382 if A .shape [2 ] == 3 :
379383 A = _rgb_to_rgba (A )
@@ -385,8 +389,6 @@ def _make_image(self, A, in_bbox, out_bbox, clip_bbox, magnification=1.0,
385389 alpha = self .get_alpha ()
386390 if alpha is None :
387391 alpha = 1.0
388- else :
389- raise ValueError ("Invalid dimensions, got %s" % (A .shape ,))
390392
391393 _image .resample (
392394 A , output , t , _interpd_ [self .get_interpolation ()],
0 commit comments