3434import re
3535import io
3636import codecs
37+ import _compat_pickle
3738
3839__all__ = ["PickleError" , "PicklingError" , "UnpicklingError" , "Pickler" ,
3940 "Unpickler" , "dump" , "dumps" , "load" , "loads" ]
@@ -171,12 +172,11 @@ def __init__(self, value):
171172
172173__all__ .extend ([x for x in dir () if re .match ("[A-Z][A-Z0-9_]+$" ,x )])
173174
174-
175175# Pickling machinery
176176
177177class _Pickler :
178178
179- def __init__ (self , file , protocol = None ):
179+ def __init__ (self , file , protocol = None , * , fix_imports = True ):
180180 """This takes a binary file for writing a pickle data stream.
181181
182182 The optional protocol argument tells the pickler to use the
@@ -193,6 +193,10 @@ def __init__(self, file, protocol=None):
193193 bytes argument. It can thus be a file object opened for binary
194194 writing, a io.BytesIO instance, or any other custom object that
195195 meets this interface.
196+
197+ If fix_imports is True and protocol is less than 3, pickle will try to
198+ map the new Python 3.x names to the old module names used in Python
199+ 2.x, so that the pickle data stream is readable with Python 2.x.
196200 """
197201 if protocol is None :
198202 protocol = DEFAULT_PROTOCOL
@@ -208,6 +212,7 @@ def __init__(self, file, protocol=None):
208212 self .proto = int (protocol )
209213 self .bin = protocol >= 1
210214 self .fast = 0
215+ self .fix_imports = fix_imports and protocol < 3
211216
212217 def clear_memo (self ):
213218 """Clears the pickler's "memo".
@@ -698,6 +703,11 @@ def save_global(self, obj, name=None, pack=struct.pack):
698703 write (GLOBAL + bytes (module , "utf-8" ) + b'\n ' +
699704 bytes (name , "utf-8" ) + b'\n ' )
700705 else :
706+ if self .fix_imports :
707+ if (module , name ) in _compat_pickle .REVERSE_NAME_MAPPING :
708+ module , name = _compat_pickle .REVERSE_NAME_MAPPING [(module , name )]
709+ if module in _compat_pickle .REVERSE_IMPORT_MAPPING :
710+ module = _compat_pickle .REVERSE_IMPORT_MAPPING [module ]
701711 try :
702712 write (GLOBAL + bytes (module , "ascii" ) + b'\n ' +
703713 bytes (name , "ascii" ) + b'\n ' )
@@ -766,7 +776,8 @@ def whichmodule(func, funcname):
766776
767777class _Unpickler :
768778
769- def __init__ (self , file , * , encoding = "ASCII" , errors = "strict" ):
779+ def __init__ (self , file , * , fix_imports = True ,
780+ encoding = "ASCII" , errors = "strict" ):
770781 """This takes a binary file for reading a pickle data stream.
771782
772783 The protocol version of the pickle is detected automatically, so no
@@ -779,15 +790,21 @@ def __init__(self, file, *, encoding="ASCII", errors="strict"):
779790 reading, a BytesIO object, or any other custom object that
780791 meets this interface.
781792
782- Optional keyword arguments are encoding and errors, which are
783- used to decode 8-bit string instances pickled by Python 2.x.
784- These default to 'ASCII' and 'strict', respectively.
793+ Optional keyword arguments are *fix_imports*, *encoding* and *errors*,
794+ which are used to control compatiblity support for pickle stream
795+ generated by Python 2.x. If *fix_imports* is True, pickle will try to
796+ map the old Python 2.x names to the new names used in Python 3.x. The
797+ *encoding* and *errors* tell pickle how to decode 8-bit string
798+ instances pickled by Python 2.x; these default to 'ASCII' and
799+ 'strict', respectively.
785800 """
786801 self .readline = file .readline
787802 self .read = file .read
788803 self .memo = {}
789804 self .encoding = encoding
790805 self .errors = errors
806+ self .proto = 0
807+ self .fix_imports = fix_imports
791808
792809 def load (self ):
793810 """Read a pickled object representation from the open file.
@@ -838,6 +855,7 @@ def load_proto(self):
838855 proto = ord (self .read (1 ))
839856 if not 0 <= proto <= HIGHEST_PROTOCOL :
840857 raise ValueError ("unsupported pickle protocol: %d" % proto )
858+ self .proto = proto
841859 dispatch [PROTO [0 ]] = load_proto
842860
843861 def load_persid (self ):
@@ -1088,7 +1106,12 @@ def get_extension(self, code):
10881106 self .append (obj )
10891107
10901108 def find_class (self , module , name ):
1091- # Subclasses may override this
1109+ # Subclasses may override this.
1110+ if self .proto < 3 and self .fix_imports :
1111+ if (module , name ) in _compat_pickle .NAME_MAPPING :
1112+ module , name = _compat_pickle .NAME_MAPPING [(module , name )]
1113+ if module in _compat_pickle .IMPORT_MAPPING :
1114+ module = _compat_pickle .IMPORT_MAPPING [module ]
10921115 __import__ (module , level = 0 )
10931116 mod = sys .modules [module ]
10941117 klass = getattr (mod , name )
@@ -1327,27 +1350,28 @@ def decode_long(data):
13271350
13281351# Shorthands
13291352
1330- def dump (obj , file , protocol = None ):
1331- Pickler (file , protocol ).dump (obj )
1353+ def dump (obj , file , protocol = None , * , fix_imports = True ):
1354+ Pickler (file , protocol , fix_imports = fix_imports ).dump (obj )
13321355
1333- def dumps (obj , protocol = None ):
1356+ def dumps (obj , protocol = None , * , fix_imports = True ):
13341357 f = io .BytesIO ()
1335- Pickler (f , protocol ).dump (obj )
1358+ Pickler (f , protocol , fix_imports = fix_imports ).dump (obj )
13361359 res = f .getvalue ()
13371360 assert isinstance (res , bytes_types )
13381361 return res
13391362
1340- def load (file , * , encoding = "ASCII" , errors = "strict" ):
1341- return Unpickler (file , encoding = encoding , errors = errors ).load ()
1363+ def load (file , * , fix_imports = True , encoding = "ASCII" , errors = "strict" ):
1364+ return Unpickler (file , fix_imports = fix_imports ,
1365+ encoding = encoding , errors = errors ).load ()
13421366
1343- def loads (s , * , encoding = "ASCII" , errors = "strict" ):
1367+ def loads (s , * , fix_imports = True , encoding = "ASCII" , errors = "strict" ):
13441368 if isinstance (s , str ):
13451369 raise TypeError ("Can't load pickle from unicode string" )
13461370 file = io .BytesIO (s )
1347- return Unpickler (file , encoding = encoding , errors = errors ).load ()
1371+ return Unpickler (file , fix_imports = fix_imports ,
1372+ encoding = encoding , errors = errors ).load ()
13481373
13491374# Doctest
1350-
13511375def _test ():
13521376 import doctest
13531377 return doctest .testmod ()
0 commit comments