From 80d0959e4131f695fb94e814ed73b5339c98bb23 Mon Sep 17 00:00:00 2001 From: Antony Lee Date: Sat, 9 Apr 2022 23:12:33 +0200 Subject: [PATCH] Fix pickling of globally available, dynamically generated norm classes. --- lib/matplotlib/colors.py | 23 ++++++++++++++++++++--- lib/matplotlib/tests/test_pickle.py | 5 +++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/lib/matplotlib/colors.py b/lib/matplotlib/colors.py index ec206b926a4d..2e519149527f 100644 --- a/lib/matplotlib/colors.py +++ b/lib/matplotlib/colors.py @@ -43,6 +43,7 @@ from collections.abc import Sized, Sequence import copy import functools +import importlib import inspect import io import itertools @@ -1528,9 +1529,22 @@ def _make_norm_from_scale(scale_cls, base_norm_cls, bound_init_signature): class Norm(base_norm_cls): def __reduce__(self): + cls = type(self) + # If the class is toplevel-accessible, it is possible to directly + # pickle it "by name". This is required to support norm classes + # defined at a module's toplevel, as the inner base_norm_cls is + # otherwise unpicklable (as it gets shadowed by the generated norm + # class). If either import or attribute access fails, fall back to + # the general path. + try: + if cls is getattr(importlib.import_module(cls.__module__), + cls.__qualname__): + return (_create_empty_object_of_class, (cls,), vars(self)) + except (ImportError, AttributeError): + pass return (_picklable_norm_constructor, (scale_cls, base_norm_cls, bound_init_signature), - self.__dict__) + vars(self)) def __init__(self, *args, **kwargs): ba = bound_init_signature.bind(*args, **kwargs) @@ -1603,11 +1617,14 @@ def autoscale_None(self, A): return Norm -def _picklable_norm_constructor(*args): - cls = _make_norm_from_scale(*args) +def _create_empty_object_of_class(cls): return cls.__new__(cls) +def _picklable_norm_constructor(*args): + return _create_empty_object_of_class(_make_norm_from_scale(*args)) + + @make_norm_from_scale( scale.FuncScale, init=lambda functions, vmin=None, vmax=None, clip=False: None) diff --git a/lib/matplotlib/tests/test_pickle.py b/lib/matplotlib/tests/test_pickle.py index 7cd23ea5c0eb..f4e35fb19b87 100644 --- a/lib/matplotlib/tests/test_pickle.py +++ b/lib/matplotlib/tests/test_pickle.py @@ -221,6 +221,11 @@ def test_mpl_toolkits(): assert type(pickle.loads(pickle.dumps(ax))) == parasite_axes.HostAxes +def test_standard_norm(): + assert type(pickle.loads(pickle.dumps(mpl.colors.LogNorm()))) \ + == mpl.colors.LogNorm + + def test_dynamic_norm(): logit_norm_instance = mpl.colors.make_norm_from_scale( mpl.scale.LogitScale, mpl.colors.Normalize)()