diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py index 61068b6b..63402fe0 100644 --- a/cloudpickle/cloudpickle.py +++ b/cloudpickle/cloudpickle.py @@ -102,6 +102,7 @@ builtin_code_type = type(float.__new__.__code__) _extract_code_globals_cache = weakref.WeakKeyDictionary() +_extract_candidate_submodule_names_cache = weakref.WeakKeyDictionary() def _get_or_create_tracker_id(class_def): @@ -337,53 +338,63 @@ def _extract_code_globals(co): return out_names -def _find_imported_submodules(code, top_level_dependencies): - """Find currently imported submodules used by a function. - - Submodules used by a function need to be detected and referenced for the - function to work correctly at depickling time. Because submodules can be - referenced as attribute of their parent package (``package.submodule``), we - need a special introspection technique that does not rely on GLOBAL-related - opcodes to find references of them in a code object. - - Example: - ``` - import concurrent.futures - import cloudpickle - def func(): - x = concurrent.futures.ThreadPoolExecutor - if __name__ == '__main__': - cloudpickle.dumps(func) - ``` - The globals extracted by cloudpickle in the function's state include the - concurrent package, but not its submodule (here, concurrent.futures), which - is the module used by func. Find_imported_submodules will detect the usage - of concurrent.futures. Saving this module alongside with func will ensure - that calling func once depickled does not fail due to concurrent.futures - not being imported +def _extract_candidate_submodule_names(func): + """Extract strings that look like submodule accesses (e.g. 'concurrent.futures') from a function's + bytecode. """ + candidate_submodule_names = _extract_candidate_submodule_names_cache.get(func) + if candidate_submodule_names is None: + code = func.__code__ + f_globals = func.__globals__ + + # Create a mapping from closure variable names to their actual values. + # co_freevars stores the names of the non-local variables. + # func.__closure__ stores the "cell" objects containing their values. + closure_map = { + name: _get_cell_contents(cell) + for name, cell in zip(code.co_freevars, func.__closure__ or []) + } + instructions = list(dis.get_instructions(code)) + candidate_submodule_names = [] + + for i, instruction in enumerate(instructions): + candidate_module = None + # Check for a global module load, e.g., `import numpy`. + if instruction.opcode == LOAD_GLOBAL: + candidate_module = f_globals.get(instruction.argval) + + # Check for a non-local (closure) module load. + elif instruction.opcode == LOAD_DEREF: + candidate_module = closure_map.get(instruction.argval) + + # If we found a base module object, look ahead for attribute access. + if isinstance(candidate_module, types.ModuleType): + current_path = [candidate_module.__name__] + + for j in range(i + 1, len(instructions)): + next_instr = instructions[j] + if next_instr.opname in ('LOAD_ATTR', 'LOAD_METHOD'): + current_path.append(next_instr.argval) + candidate_submodule_names.append(".".join(current_path)) + else: + # The chain of attribute access was broken. + break + + _extract_candidate_submodule_names_cache[func] = candidate_submodule_names + return candidate_submodule_names + + +def _find_submodules_via_bytecode(func): + """ + Directly finds submodules used by func by analyzing bytecode. + """ + found_submodules = set() + for candidate_submodule_name in _extract_candidate_submodule_names(func): + candidate_submodule = sys.modules.get(candidate_submodule_name) + if candidate_submodule is not None: + found_submodules.add(candidate_submodule) - subimports = [] - # check if any known dependency is an imported package - for x in top_level_dependencies: - if ( - isinstance(x, types.ModuleType) - and hasattr(x, "__package__") - and x.__package__ - ): - # check if the package has any currently loaded sub-imports - prefix = x.__name__ + "." - # A concurrent thread could mutate sys.modules, - # make sure we iterate over a copy to avoid exceptions - for name in list(sys.modules): - # Older versions of pytest will add a "None" module to - # sys.modules. - if name is not None and name.startswith(prefix): - # check whether the function can address the sub-module - tokens = set(name[len(prefix) :].split(".")) - if not tokens - set(code.co_names): - subimports.append(sys.modules[name]) - return subimports + return list(found_submodules) # relevant opcodes @@ -391,6 +402,7 @@ def func(): DELETE_GLOBAL = opcode.opmap["DELETE_GLOBAL"] LOAD_GLOBAL = opcode.opmap["LOAD_GLOBAL"] GLOBAL_OPS = (STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL) +LOAD_DEREF = opcode.opmap["LOAD_DEREF"] HAVE_ARGUMENT = dis.HAVE_ARGUMENT EXTENDED_ARG = dis.EXTENDED_ARG @@ -738,9 +750,17 @@ def _function_getstate(func): # in a smoke _cloudpickle_subimports attribute of the object's state will # trigger the side effect of importing these modules at unpickling time # (which is necessary for func to work correctly once depickled) - slotstate["_cloudpickle_submodules"] = _find_imported_submodules( - func.__code__, itertools.chain(f_globals.values(), closure_values) - ) + top_level_dependencies = itertools.chain(f_globals.values(), closure_values) + if any(isinstance(x, types.ModuleType) and getattr(x, "__package__", None) + for x in top_level_dependencies): + # The func uses a package, so check if it uses any submodules of that + # package. (Sometimes there won't be any; for example function may + # use the package `numpy`, but not use any of its submodules such as + # numpy.random.) + slotstate["_cloudpickle_submodules"] = _find_submodules_via_bytecode(func) + else: + slotstate["_cloudpickle_submodules"] = [] + slotstate["__globals__"] = f_globals # Hack to circumvent non-predictable memoization caused by string interning. diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py index e2097d1c..94e5337f 100644 --- a/tests/cloudpickle_test.py +++ b/tests/cloudpickle_test.py @@ -1103,6 +1103,50 @@ def example(): f = pickle.loads(s) f() # smoke test + def test_submodule_aliased(self): + # Same as test_submodule except the submodule is aliased + import xml.etree.ElementTree as et + + def example(): + _ = et.Comment # noqa: F821 + + example() # smoke test + + s = cloudpickle.dumps(example, protocol=self.protocol) + + # refresh the environment, i.e., unimport the dependency + del et + for item in list(sys.modules): + if item.split(".")[0] == "xml": + del sys.modules[item] + + # deserialise + f = pickle.loads(s) + f() # smoke test + + def test_submodule_aliased_package(self): + # Same as test_submodule except the xml package is aliased and + # the submodule is accessed through the alias + import xml.etree.ElementTree + _xml = xml + + def example(): + _ = _xml.etree.ElementTree.Comment # noqa: F821 + + example() # smoke test + + s = cloudpickle.dumps(example, protocol=self.protocol) + + # refresh the environment, i.e., unimport the dependency + del xml, _xml + for item in list(sys.modules): + if item.split(".")[0] == "xml": + del sys.modules[item] + + # deserialise + f = pickle.loads(s) + f() # smoke test + def test_submodule_closure(self): # Same as test_submodule except the xml package has not been imported def scope():