|
5 | 5 | from matplotlib import verbose, __version__, rcParams |
6 | 6 | from matplotlib.backend_bases import RendererBase, GraphicsContextBase,\ |
7 | 7 | FigureManagerBase, FigureCanvasBase |
| 8 | +from matplotlib.cbook import is_string_like |
8 | 9 | from matplotlib.colors import rgb2hex |
9 | 10 | from matplotlib.figure import Figure |
10 | 11 | from matplotlib.font_manager import findfont, FontProperties |
@@ -458,23 +459,36 @@ class FigureCanvasSVG(FigureCanvasBase): |
458 | 459 | 'svgz': 'Scalable Vector Graphics'} |
459 | 460 |
|
460 | 461 | def print_svg(self, filename, *args, **kwargs): |
461 | | - svgwriter = codecs.open(filename, 'w', 'utf-8') |
462 | | - return self._print_svg(filename, svgwriter) |
463 | | - |
| 462 | + if is_string_like(filename): |
| 463 | + fh_to_close = svgwriter = codecs.open(filename, 'w', 'utf-8') |
| 464 | + elif hasattr(filename, 'write') and callable(filename.write): |
| 465 | + svgwriter = codecs.EncodedFile(filename, 'utf-8') |
| 466 | + fh_to_close = None |
| 467 | + else: |
| 468 | + raise ValueError("filename must be a path or a file-like object") |
| 469 | + return self._print_svg(filename, svgwriter, fh_to_close) |
| 470 | + |
464 | 471 | def print_svgz(self, filename, *args, **kwargs): |
465 | | - gzipwriter = gzip.GzipFile(filename, 'w') |
466 | | - svgwriter = codecs.EncodedFile(gzipwriter, 'utf-8') |
467 | | - return self._print_svg(filename, svgwriter) |
| 472 | + if is_string_like(filename): |
| 473 | + gzipwriter = gzip.GzipFile(filename, 'w') |
| 474 | + fh_to_close = svgwriter = codecs.EncodedFile(gzipwriter, 'utf-8') |
| 475 | + elif hasattr(filename, 'write') and callable(filename.write): |
| 476 | + fh_to_close = gzipwriter = gzip.GzipFile(fileobj=filename, mode='w') |
| 477 | + svgwriter = codecs.EncodedFile(gzipwriter, 'utf-8') |
| 478 | + else: |
| 479 | + raise ValueError("filename must be a path or a file-like object") |
| 480 | + return self._print_svg(filename, svgwriter, fh_to_close) |
468 | 481 |
|
469 | | - def _print_svg(self, filename, svgwriter): |
| 482 | + def _print_svg(self, filename, svgwriter, fh_to_close=None): |
470 | 483 | self.figure.dpi.set(72) |
471 | 484 | width, height = self.figure.get_size_inches() |
472 | 485 | w, h = width*72, height*72 |
473 | 486 |
|
474 | 487 | renderer = RendererSVG(w, h, svgwriter, filename) |
475 | 488 | self.figure.draw(renderer) |
476 | 489 | renderer.finish() |
477 | | - svgwriter.close() |
| 490 | + if fh_to_close is not None: |
| 491 | + svgwriter.close() |
478 | 492 |
|
479 | 493 | def get_default_filetype(self): |
480 | 494 | return 'svg' |
|
0 commit comments