@@ -313,38 +313,45 @@ def load_frame(self, frame_size):
313313
314314# Tools used for pickling.
315315
316- def _getattribute (obj , name ):
317- top = obj
318- for subpath in name .split ('.' ):
319- if subpath == '<locals>' :
320- raise AttributeError ("Can't get local attribute {!r} on {!r}"
321- .format (name , top ))
322- try :
323- parent = obj
324- obj = getattr (obj , subpath )
325- except AttributeError :
326- raise AttributeError ("Can't get attribute {!r} on {!r}"
327- .format (name , top )) from None
328- return obj , parent
316+ def _getattribute (obj , dotted_path ):
317+ for subpath in dotted_path :
318+ obj = getattr (obj , subpath )
319+ return obj
329320
330321def whichmodule (obj , name ):
331322 """Find the module an object belong to."""
323+ dotted_path = name .split ('.' )
332324 module_name = getattr (obj , '__module__' , None )
333- if module_name is not None :
334- return module_name
335- # Protect the iteration by using a list copy of sys.modules against dynamic
336- # modules that trigger imports of other modules upon calls to getattr.
337- for module_name , module in sys .modules .copy ().items ():
338- if (module_name == '__main__'
339- or module_name == '__mp_main__' # bpo-42406
340- or module is None ):
341- continue
342- try :
343- if _getattribute (module , name )[0 ] is obj :
344- return module_name
345- except AttributeError :
346- pass
347- return '__main__'
325+ if module_name is None and '<locals>' not in dotted_path :
326+ # Protect the iteration by using a list copy of sys.modules against dynamic
327+ # modules that trigger imports of other modules upon calls to getattr.
328+ for module_name , module in sys .modules .copy ().items ():
329+ if (module_name == '__main__'
330+ or module_name == '__mp_main__' # bpo-42406
331+ or module is None ):
332+ continue
333+ try :
334+ if _getattribute (module , dotted_path ) is obj :
335+ return module_name
336+ except AttributeError :
337+ pass
338+ module_name = '__main__'
339+ elif module_name is None :
340+ module_name = '__main__'
341+
342+ try :
343+ __import__ (module_name , level = 0 )
344+ module = sys .modules [module_name ]
345+ if _getattribute (module , dotted_path ) is obj :
346+ return module_name
347+ except (ImportError , KeyError , AttributeError ):
348+ raise PicklingError (
349+ "Can't pickle %r: it's not found as %s.%s" %
350+ (obj , module_name , name )) from None
351+
352+ raise PicklingError (
353+ "Can't pickle %r: it's not the same object as %s.%s" %
354+ (obj , module_name , name ))
348355
349356def encode_long (x ):
350357 r"""Encode a long to a two's complement little-endian binary string.
@@ -1074,24 +1081,10 @@ def save_global(self, obj, name=None):
10741081
10751082 if name is None :
10761083 name = getattr (obj , '__qualname__' , None )
1077- if name is None :
1078- name = obj .__name__
1084+ if name is None :
1085+ name = obj .__name__
10791086
10801087 module_name = whichmodule (obj , name )
1081- try :
1082- __import__ (module_name , level = 0 )
1083- module = sys .modules [module_name ]
1084- obj2 , parent = _getattribute (module , name )
1085- except (ImportError , KeyError , AttributeError ):
1086- raise PicklingError (
1087- "Can't pickle %r: it's not found as %s.%s" %
1088- (obj , module_name , name )) from None
1089- else :
1090- if obj2 is not obj :
1091- raise PicklingError (
1092- "Can't pickle %r: it's not the same object as %s.%s" %
1093- (obj , module_name , name ))
1094-
10951088 if self .proto >= 2 :
10961089 code = _extension_registry .get ((module_name , name ))
10971090 if code :
@@ -1103,10 +1096,7 @@ def save_global(self, obj, name=None):
11031096 else :
11041097 write (EXT4 + pack ("<i" , code ))
11051098 return
1106- lastname = name .rpartition ('.' )[2 ]
1107- if parent is module :
1108- name = lastname
1109- # Non-ASCII identifiers are supported only with protocols >= 3.
1099+
11101100 if self .proto >= 4 :
11111101 self .save (module_name )
11121102 self .save (name )
@@ -1616,7 +1606,16 @@ def find_class(self, module, name):
16161606 module = _compat_pickle .IMPORT_MAPPING [module ]
16171607 __import__ (module , level = 0 )
16181608 if self .proto >= 4 :
1619- return _getattribute (sys .modules [module ], name )[0 ]
1609+ module = sys .modules [module ]
1610+ dotted_path = name .split ('.' )
1611+ if '<locals>' in dotted_path :
1612+ raise AttributeError (
1613+ f"Can't get local attribute { name !r} on { module !r} " )
1614+ try :
1615+ return _getattribute (module , dotted_path )
1616+ except AttributeError :
1617+ raise AttributeError (
1618+ f"Can't get attribute { name !r} on { module !r} " ) from None
16201619 else :
16211620 return getattr (sys .modules [module ], name )
16221621
0 commit comments