diff --git a/docs/source/api/layouts/figure.rst b/docs/source/api/layouts/figure.rst index a2d5e5758..817284e18 100644 --- a/docs/source/api/layouts/figure.rst +++ b/docs/source/api/layouts/figure.rst @@ -37,6 +37,7 @@ Methods Figure.add_animations Figure.clear Figure.close + Figure.export Figure.remove_animation Figure.render Figure.show diff --git a/fastplotlib/layouts/_figure.py b/fastplotlib/layouts/_figure.py index 17bb28095..d330c6928 100644 --- a/fastplotlib/layouts/_figure.py +++ b/fastplotlib/layouts/_figure.py @@ -128,6 +128,7 @@ def __init__( # if controller instances have been specified for each subplot if controllers is not None: + # one controller for all subplots if isinstance(controllers, pygfx.Controller): controllers = [controllers] * len(self) @@ -579,6 +580,52 @@ def clear(self): for subplot in self: subplot.clear() + def export(self, uri: str | Path | bytes, **kwargs): + """ + Use ``imageio`` for writing the current Figure to a file, or return a byte string. + Must have ``imageio`` installed. + + Parameters + ---------- + uri: str | Path | bytes + + kwargs: passed to imageio.v3.imwrite, see: https://imageio.readthedocs.io/en/stable/_autosummary/imageio.v3.imwrite.html + + Returns + ------- + None | bytes + see https://imageio.readthedocs.io/en/stable/_autosummary/imageio.v3.imwrite.html + """ + try: + import imageio.v3 as iio + except ModuleNotFoundError: + raise ImportError( + "imageio is required to use Figure.export(). Install it using pip or conda:\n" + "pip install imageio\n" + "conda install -c conda-forge imageio\n" + ) + else: + snapshot = self.renderer.snapshot() + remove_alpha = True + + # image formats that support alpha channel: + # https://en.wikipedia.org/wiki/Alpha_compositing#Image_formats_supporting_alpha_channels + alpha_support = [".png", ".exr", ".tiff", ".tif", ".gif", ".jxl", ".svg"] + + if isinstance(uri, str): + if any([uri.endswith(ext) for ext in alpha_support]): + remove_alpha = False + + elif isinstance(uri, Path): + if uri.suffix in alpha_support: + remove_alpha = False + + if remove_alpha: + # remove alpha channel if it's not supported + snapshot = snapshot[..., :-1].shape + + return iio.imwrite(uri, snapshot, **kwargs) + def _get_iterator(self): return product(range(self.shape[0]), range(self.shape[1]))