@@ -127,21 +127,140 @@ def changed(self):
127127 def make_image (self , magnification = 1.0 ):
128128 raise RuntimeError ('The make_image method must be overridden.' )
129129
130+
131+ def _get_unsampled_image (self , A , image_extents , viewlim ):
132+ """
133+ convert numpy array A with given extents ([x1, x2, y1, y2] in
134+ data coordinate) into the Image, given the vielim (should be a
135+ bbox instance). Image will be clipped if the extents is
136+ significantly larger than the viewlim.
137+ """
138+ xmin , xmax , ymin , ymax = image_extents
139+ dxintv = xmax - xmin
140+ dyintv = ymax - ymin
141+
142+ # the viewport scale factor
143+ sx = dxintv / viewlim .width
144+ sy = dyintv / viewlim .height
145+ numrows , numcols = A .shape [:2 ]
146+ if sx > 2 :
147+ x0 = (viewim .x0 - xmin )/ dxintv * numcols
148+ ix0 = max (0 , int (x0 - self ._filterrad ))
149+ x1 = (viewlim .x1 - xmin )/ dxintv * numcols
150+ ix1 = min (numcols , int (x1 + self ._filterrad ))
151+ xslice = slice (ix0 , ix1 )
152+ xmin_old = xmin
153+ xmin = xmin_old + ix0 * dxintv / numcols
154+ xmax = xmin_old + ix1 * dxintv / numcols
155+ dxintv = xmax - xmin
156+ sx = dxintv / viewlim .width
157+ else :
158+ xslice = slice (0 , numcols )
159+
160+ if sy > 2 :
161+ y0 = (viewlim .y0 - ymin )/ dyintv * numrows
162+ iy0 = max (0 , int (y0 - self ._filterrad ))
163+ y1 = (viewlim .y1 - ymin )/ dyintv * numrows
164+ iy1 = min (numrows , int (y1 + self ._filterrad ))
165+ if self .origin == 'upper' :
166+ yslice = slice (numrows - iy1 , numrows - iy0 )
167+ else :
168+ yslice = slice (iy0 , iy1 )
169+ ymin_old = ymin
170+ ymin = ymin_old + iy0 * dyintv / numrows
171+ ymax = ymin_old + iy1 * dyintv / numrows
172+ dyintv = ymax - ymin
173+ sy = dyintv / self .axes .viewLim .height
174+ else :
175+ yslice = slice (0 , numrows )
176+
177+ if xslice != self ._oldxslice or yslice != self ._oldyslice :
178+ self ._imcache = None
179+ self ._oldxslice = xslice
180+ self ._oldyslice = yslice
181+
182+ if self ._imcache is None :
183+ if self ._A .dtype == np .uint8 and len (self ._A .shape ) == 3 :
184+ im = _image .frombyte (self ._A [yslice ,xslice ,:], 0 )
185+ im .is_grayscale = False
186+ else :
187+ if self ._rgbacache is None :
188+ x = self .to_rgba (self ._A , self ._alpha )
189+ self ._rgbacache = x
190+ else :
191+ x = self ._rgbacache
192+ im = _image .fromarray (x [yslice ,xslice ], 0 )
193+ if len (self ._A .shape ) == 2 :
194+ im .is_grayscale = self .cmap .is_gray ()
195+ else :
196+ im .is_grayscale = False
197+ self ._imcache = im
198+
199+ if self .origin == 'upper' :
200+ im .flipud_in ()
201+ else :
202+ im = self ._imcache
203+
204+ return im , xmin , ymin , dxintv , dyintv , sx , sy
205+
206+
207+ def _draw_unsampled_image (self , renderer , gc ):
208+ """
209+ draw unsampled image. The renderer should support a draw_image method
210+ with scale parameter.
211+ """
212+ im , xmin , ymin , dxintv , dyintv , sx , sy = \
213+ self ._get_unsampled_image (self ._A , self .get_extent (), self .axes .viewLim )
214+
215+ if im is None : return # I'm not if this check is required. -JJL
216+
217+ transData = self .axes .transData
218+ xx1 , yy1 = transData .transform_point ((xmin , ymin ))
219+ xx2 , yy2 = transData .transform_point ((xmin + dxintv , ymin + dyintv ))
220+
221+ fc = self .axes .patch .get_facecolor ()
222+ bg = mcolors .colorConverter .to_rgba (fc , 0 )
223+ im .set_bg ( * bg )
224+
225+ # image input dimensions
226+ im .reset_matrix ()
227+ numrows , numcols = im .get_size ()
228+
229+ im .resize (numcols , numrows ) # just to create im.bufOut that is required by backends. There may be better solution -JJL
230+
231+ sx = (xx2 - xx1 )/ numcols
232+ sy = (yy2 - yy1 )/ numrows
233+ im ._url = self .get_url ()
234+ renderer .draw_image (gc , xx1 , yy1 , im , sx , sy )
235+
236+
237+ def _check_unsampled_image (self , renderer ):
238+ """
239+ return True if the image is better to be drawn unsampled.
240+ The derived class needs to override it.
241+ """
242+ return False
243+
130244 @allow_rasterization
131245 def draw (self , renderer , * args , ** kwargs ):
132246 if not self .get_visible (): return
133247 if (self .axes .get_xscale () != 'linear' or
134248 self .axes .get_yscale () != 'linear' ):
135249 warnings .warn ("Images are not supported on non-linear axes." )
136- im = self .make_image (renderer .get_image_magnification ())
137- if im is None :
138- return
139- im ._url = self .get_url ()
250+
140251 l , b , widthDisplay , heightDisplay = self .axes .bbox .bounds
141252 gc = renderer .new_gc ()
142253 gc .set_clip_rectangle (self .axes .bbox .frozen ())
143254 gc .set_clip_path (self .get_clip_path ())
144- renderer .draw_image (gc , l , b , im )
255+
256+ if self ._check_unsampled_image (renderer ):
257+ self ._draw_unsampled_image (renderer , gc )
258+ else :
259+ im = self .make_image (renderer .get_image_magnification ())
260+ if im is None :
261+ return
262+ im ._url = self .get_url ()
263+ renderer .draw_image (gc , l , b , im )
145264 gc .restore ()
146265
147266 def contains (self , mouseevent ):
@@ -338,71 +457,8 @@ def make_image(self, magnification=1.0):
338457 if self ._A is None :
339458 raise RuntimeError ('You must first set the image array or the image attribute' )
340459
341- xmin , xmax , ymin , ymax = self .get_extent ()
342- dxintv = xmax - xmin
343- dyintv = ymax - ymin
344-
345- # the viewport scale factor
346- sx = dxintv / self .axes .viewLim .width
347- sy = dyintv / self .axes .viewLim .height
348- numrows , numcols = self ._A .shape [:2 ]
349- if sx > 2 :
350- x0 = (self .axes .viewLim .x0 - xmin )/ dxintv * numcols
351- ix0 = max (0 , int (x0 - self ._filterrad ))
352- x1 = (self .axes .viewLim .x1 - xmin )/ dxintv * numcols
353- ix1 = min (numcols , int (x1 + self ._filterrad ))
354- xslice = slice (ix0 , ix1 )
355- xmin_old = xmin
356- xmin = xmin_old + ix0 * dxintv / numcols
357- xmax = xmin_old + ix1 * dxintv / numcols
358- dxintv = xmax - xmin
359- sx = dxintv / self .axes .viewLim .width
360- else :
361- xslice = slice (0 , numcols )
362-
363- if sy > 2 :
364- y0 = (self .axes .viewLim .y0 - ymin )/ dyintv * numrows
365- iy0 = max (0 , int (y0 - self ._filterrad ))
366- y1 = (self .axes .viewLim .y1 - ymin )/ dyintv * numrows
367- iy1 = min (numrows , int (y1 + self ._filterrad ))
368- if self .origin == 'upper' :
369- yslice = slice (numrows - iy1 , numrows - iy0 )
370- else :
371- yslice = slice (iy0 , iy1 )
372- ymin_old = ymin
373- ymin = ymin_old + iy0 * dyintv / numrows
374- ymax = ymin_old + iy1 * dyintv / numrows
375- dyintv = ymax - ymin
376- sy = dyintv / self .axes .viewLim .height
377- else :
378- yslice = slice (0 , numrows )
379-
380- if xslice != self ._oldxslice or yslice != self ._oldyslice :
381- self ._imcache = None
382- self ._oldxslice = xslice
383- self ._oldyslice = yslice
384-
385- if self ._imcache is None :
386- if self ._A .dtype == np .uint8 and len (self ._A .shape ) == 3 :
387- im = _image .frombyte (self ._A [yslice ,xslice ,:], 0 )
388- im .is_grayscale = False
389- else :
390- if self ._rgbacache is None :
391- x = self .to_rgba (self ._A , self ._alpha )
392- self ._rgbacache = x
393- else :
394- x = self ._rgbacache
395- im = _image .fromarray (x [yslice ,xslice ], 0 )
396- if len (self ._A .shape ) == 2 :
397- im .is_grayscale = self .cmap .is_gray ()
398- else :
399- im .is_grayscale = False
400- self ._imcache = im
401-
402- if self .origin == 'upper' :
403- im .flipud_in ()
404- else :
405- im = self ._imcache
460+ im , xmin , ymin , dxintv , dyintv , sx , sy = \
461+ self ._get_unsampled_image (self ._A , self .get_extent (), self .axes .viewLim )
406462
407463 fc = self .axes .patch .get_facecolor ()
408464 bg = mcolors .colorConverter .to_rgba (fc , 0 )
@@ -435,6 +491,15 @@ def make_image(self, magnification=1.0):
435491 return im
436492
437493
494+ def _check_unsampled_image (self , renderer ):
495+ """
496+ return True if the image is better to be drawn unsampled.
497+ """
498+ if renderer .option_scale_image () and self .get_interpolation () == "nearest" :
499+ return True
500+ else :
501+ return False
502+
438503 def set_extent (self , extent ):
439504 """
440505 extent is data axes (left, right, bottom, top) for making image plots
0 commit comments