@@ -1506,9 +1506,26 @@ class norm_cls(Normalize):
1506
1506
1507
1507
if init is None :
1508
1508
def init (vmin = None , vmax = None , clip = False ): pass
1509
- bound_init_signature = inspect .signature (init )
1509
+
1510
+ return _make_norm_from_scale (
1511
+ scale_cls , base_norm_cls , inspect .signature (init ))
1512
+
1513
+
1514
+ @functools .lru_cache (None )
1515
+ def _make_norm_from_scale (scale_cls , base_norm_cls , bound_init_signature ):
1516
+ """
1517
+ Helper for `make_norm_from_scale`.
1518
+
1519
+ This function is split out so that it takes a signature object as third
1520
+ argument (as signatures are picklable, contrary to arbitrary lambdas);
1521
+ caching is also used so that different unpickles reuse the same class.
1522
+ """
1510
1523
1511
1524
class Norm (base_norm_cls ):
1525
+ def __reduce__ (self ):
1526
+ return (_picklable_norm_constructor ,
1527
+ (scale_cls , base_norm_cls , bound_init_signature ),
1528
+ self .__dict__ )
1512
1529
1513
1530
def __init__ (self , * args , ** kwargs ):
1514
1531
ba = bound_init_signature .bind (* args , ** kwargs )
@@ -1518,6 +1535,10 @@ def __init__(self, *args, **kwargs):
1518
1535
self ._scale = scale_cls (axis = None , ** ba .arguments )
1519
1536
self ._trf = self ._scale .get_transform ()
1520
1537
1538
+ __init__ .__signature__ = bound_init_signature .replace (parameters = [
1539
+ inspect .Parameter ("self" , inspect .Parameter .POSITIONAL_OR_KEYWORD ),
1540
+ * bound_init_signature .parameters .values ()])
1541
+
1521
1542
def __call__ (self , value , clip = None ):
1522
1543
value , is_scalar = self .process_value (value )
1523
1544
self .autoscale_None (value )
@@ -1555,17 +1576,23 @@ def inverse(self, value):
1555
1576
.reshape (np .shape (value )))
1556
1577
return value [0 ] if is_scalar else value
1557
1578
1558
- Norm .__name__ = (f"{ scale_cls .__name__ } Norm" if base_norm_cls is Normalize
1559
- else base_norm_cls .__name__ )
1560
- Norm .__qualname__ = base_norm_cls .__qualname__
1579
+ Norm .__name__ = (
1580
+ f"{ scale_cls .__name__ } Norm" if base_norm_cls is Normalize
1581
+ else base_norm_cls .__name__ )
1582
+ Norm .__qualname__ = (
1583
+ f"{ scale_cls .__qualname__ } Norm" if base_norm_cls is Normalize
1584
+ else base_norm_cls .__qualname__ )
1561
1585
Norm .__module__ = base_norm_cls .__module__
1562
1586
Norm .__doc__ = base_norm_cls .__doc__
1563
- Norm .__init__ .__signature__ = bound_init_signature .replace (parameters = [
1564
- inspect .Parameter ("self" , inspect .Parameter .POSITIONAL_OR_KEYWORD ),
1565
- * bound_init_signature .parameters .values ()])
1587
+
1566
1588
return Norm
1567
1589
1568
1590
1591
+ def _picklable_norm_constructor (* args , ** kwargs ):
1592
+ cls = _make_norm_from_scale (* args , ** kwargs )
1593
+ return cls .__new__ (cls )
1594
+
1595
+
1569
1596
@make_norm_from_scale (
1570
1597
scale .FuncScale ,
1571
1598
init = lambda functions , vmin = None , vmax = None , clip = False : None )
0 commit comments