Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 68 additions & 48 deletions cloudpickle/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
builtin_code_type = type(float.__new__.__code__)

_extract_code_globals_cache = weakref.WeakKeyDictionary()
_extract_candidate_submodule_names_cache = weakref.WeakKeyDictionary()


def _get_or_create_tracker_id(class_def):
Expand Down Expand Up @@ -337,60 +338,71 @@ def _extract_code_globals(co):
return out_names


def _find_imported_submodules(code, top_level_dependencies):
"""Find currently imported submodules used by a function.

Submodules used by a function need to be detected and referenced for the
function to work correctly at depickling time. Because submodules can be
referenced as attribute of their parent package (``package.submodule``), we
need a special introspection technique that does not rely on GLOBAL-related
opcodes to find references of them in a code object.

Example:
```
import concurrent.futures
import cloudpickle
def func():
x = concurrent.futures.ThreadPoolExecutor
if __name__ == '__main__':
cloudpickle.dumps(func)
```
The globals extracted by cloudpickle in the function's state include the
concurrent package, but not its submodule (here, concurrent.futures), which
is the module used by func. Find_imported_submodules will detect the usage
of concurrent.futures. Saving this module alongside with func will ensure
that calling func once depickled does not fail due to concurrent.futures
not being imported
def _extract_candidate_submodule_names(func):
"""Extract strings that look like submodule accesses (e.g. 'concurrent.futures') from a function's
bytecode.
"""
candidate_submodule_names = _extract_candidate_submodule_names_cache.get(func)
if candidate_submodule_names is None:
code = func.__code__
f_globals = func.__globals__

# Create a mapping from closure variable names to their actual values.
# co_freevars stores the names of the non-local variables.
# func.__closure__ stores the "cell" objects containing their values.
closure_map = {
name: _get_cell_contents(cell)
for name, cell in zip(code.co_freevars, func.__closure__ or [])
}
instructions = list(dis.get_instructions(code))
candidate_submodule_names = []

for i, instruction in enumerate(instructions):
candidate_module = None
# Check for a global module load, e.g., `import numpy`.
if instruction.opcode == LOAD_GLOBAL:
candidate_module = f_globals.get(instruction.argval)

# Check for a non-local (closure) module load.
elif instruction.opcode == LOAD_DEREF:
candidate_module = closure_map.get(instruction.argval)

# If we found a base module object, look ahead for attribute access.
if isinstance(candidate_module, types.ModuleType):
current_path = [candidate_module.__name__]

for j in range(i + 1, len(instructions)):
next_instr = instructions[j]
if next_instr.opname in ('LOAD_ATTR', 'LOAD_METHOD'):
current_path.append(next_instr.argval)
candidate_submodule_names.append(".".join(current_path))
else:
# The chain of attribute access was broken.
break

_extract_candidate_submodule_names_cache[func] = candidate_submodule_names
return candidate_submodule_names


def _find_submodules_via_bytecode(func):
"""
Directly finds submodules used by func by analyzing bytecode.
"""
found_submodules = set()
for candidate_submodule_name in _extract_candidate_submodule_names(func):
candidate_submodule = sys.modules.get(candidate_submodule_name)
if candidate_submodule is not None:
found_submodules.add(candidate_submodule)

subimports = []
# check if any known dependency is an imported package
for x in top_level_dependencies:
if (
isinstance(x, types.ModuleType)
and hasattr(x, "__package__")
and x.__package__
):
# check if the package has any currently loaded sub-imports
prefix = x.__name__ + "."
# A concurrent thread could mutate sys.modules,
# make sure we iterate over a copy to avoid exceptions
for name in list(sys.modules):
# Older versions of pytest will add a "None" module to
# sys.modules.
if name is not None and name.startswith(prefix):
# check whether the function can address the sub-module
tokens = set(name[len(prefix) :].split("."))
if not tokens - set(code.co_names):
subimports.append(sys.modules[name])
return subimports
return list(found_submodules)


# relevant opcodes
STORE_GLOBAL = opcode.opmap["STORE_GLOBAL"]
DELETE_GLOBAL = opcode.opmap["DELETE_GLOBAL"]
LOAD_GLOBAL = opcode.opmap["LOAD_GLOBAL"]
GLOBAL_OPS = (STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL)
LOAD_DEREF = opcode.opmap["LOAD_DEREF"]
HAVE_ARGUMENT = dis.HAVE_ARGUMENT
EXTENDED_ARG = dis.EXTENDED_ARG

Expand Down Expand Up @@ -738,9 +750,17 @@ def _function_getstate(func):
# in a smoke _cloudpickle_subimports attribute of the object's state will
# trigger the side effect of importing these modules at unpickling time
# (which is necessary for func to work correctly once depickled)
slotstate["_cloudpickle_submodules"] = _find_imported_submodules(
func.__code__, itertools.chain(f_globals.values(), closure_values)
)
top_level_dependencies = itertools.chain(f_globals.values(), closure_values)
if any(isinstance(x, types.ModuleType) and getattr(x, "__package__", None)
for x in top_level_dependencies):
# The func uses a package, so check if it uses any submodules of that
# package. (Sometimes there won't be any; for example function may
# use the package `numpy`, but not use any of its submodules such as
# numpy.random.)
slotstate["_cloudpickle_submodules"] = _find_submodules_via_bytecode(func)
else:
slotstate["_cloudpickle_submodules"] = []

slotstate["__globals__"] = f_globals

# Hack to circumvent non-predictable memoization caused by string interning.
Expand Down
44 changes: 44 additions & 0 deletions tests/cloudpickle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,50 @@ def example():
f = pickle.loads(s)
f() # smoke test

def test_submodule_aliased(self):
# Same as test_submodule except the submodule is aliased
import xml.etree.ElementTree as et

def example():
_ = et.Comment # noqa: F821

example() # smoke test

s = cloudpickle.dumps(example, protocol=self.protocol)

# refresh the environment, i.e., unimport the dependency
del et
for item in list(sys.modules):
if item.split(".")[0] == "xml":
del sys.modules[item]

# deserialise
f = pickle.loads(s)
f() # smoke test

def test_submodule_aliased_package(self):
# Same as test_submodule except the xml package is aliased and
# the submodule is accessed through the alias
import xml.etree.ElementTree
_xml = xml

def example():
_ = _xml.etree.ElementTree.Comment # noqa: F821

example() # smoke test

s = cloudpickle.dumps(example, protocol=self.protocol)

# refresh the environment, i.e., unimport the dependency
del xml, _xml
for item in list(sys.modules):
if item.split(".")[0] == "xml":
del sys.modules[item]

# deserialise
f = pickle.loads(s)
f() # smoke test

def test_submodule_closure(self):
# Same as test_submodule except the xml package has not been imported
def scope():
Expand Down