Description
This is an investigation into a performance issue and a proposal and diff for a fix. I would appreciate the dev team's consideration. If it looks OK, I am happy to submit a PR with the fix (including cleaning it up and adding more tests, if desired).
I'm investigating a performance issue with function-pickling behaviour. I narrowed it down to the function _find_imported_submodules. This function repeatedly copies and iterates through sys.modules, doing a string operation on each module name.
With just the imports below, pickling a function that references 3 packages is 6x slower than pickling a function that references 3 non-package modules.
import cloudpickle, timeit
import re, math, os, sys, collections, asyncio
def func1():
    [math, os, sys]  # non-packages

def func2():
    [collections, re, asyncio]  # packages

print(timeit.timeit(lambda: cloudpickle.dumps(func1), number=1000))
print(timeit.timeit(lambda: cloudpickle.dumps(func2), number=1000))

0.02538153436034918
0.1656142445281148
The slowness comes from the code doing a full copy and linear scan through sys.modules for each package used by the pickled function. Its duration is proportional to the size of sys.modules, which can be very large: in one of my projects, sys.modules has 3800 elements, and the same benchmark shows a speed difference of >50x:
0.021970179863274097
1.147934407927096
I propose an alternative implementation based on scanning the instructions of the function's code object, checking whether submodules are being accessed, and then looking the candidates up directly in sys.modules, which is a much more efficient access pattern. It is an (almost) drop-in replacement for _find_imported_submodules that is much faster in most of the cases I've tried, and no slower in any case I've tried.
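For intuition on why bytecode scanning is enough: the attribute chain that follows a module load is visible directly in the instructions, so a candidate name like concurrent.futures can be read off without touching the rest of sys.modules. A small illustration (not part of the patch; exact opnames vary across Python versions):

import dis
import concurrent.futures

def func():
    x = concurrent.futures.ThreadPoolExecutor

for instr in dis.get_instructions(func):
    if instr.opname in ("LOAD_GLOBAL", "LOAD_ATTR", "LOAD_METHOD"):
        print(instr.opname, instr.argval)

# Typical output:
# LOAD_GLOBAL concurrent
# LOAD_ATTR futures
# LOAD_ATTR ThreadPoolExecutor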
The diff is shown below.
Correctness
I verified that all unit tests pass with my change.
I also checked that the proposed change produces the same output as the existing code in these edge cases:
import concurrent.futures
import concurrent.futures as cf
c = concurrent

def func1():
    x = concurrent.futures.ThreadPoolExecutor

def func2():
    x = cf.ThreadPoolExecutor

def func3():
    x = c.futures.ThreadPoolExecutor

def func4():
    x = getattr(concurrent, 'futures').ThreadPoolExecutor

def func5():
    import concurrent.futures
    def inner():  # <-- testing pickling the inner function, where concurrent.futures is nonlocal but not global
        x = concurrent.futures.ThreadPoolExecutor
    open('/tmp/pkl5', 'wb').write(cloudpickle.dumps(inner))

NOTE: both the existing and the proposed algorithms for _find_imported_submodules fail to unpickle the function below, because they both act only on names found in the bytecode. But the proposed algorithm is no worse than the existing one in this regard.

def func():
    x = getattr(concurrent, 'futures').ThreadPoolExecutor
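For completeness, this is the kind of round-trip check I have in mind here; the failure shows up when the unpickled function is called in a fresh interpreter that has not imported concurrent.futures. This harness is my own sketch (not part of the patch) and assumes the definitions from the snippet above are in scope:

import subprocess
import sys
import cloudpickle

def roundtrips_in_fresh_interpreter(func):
    # Pickle here, then unpickle and call in a child interpreter.
    payload = cloudpickle.dumps(func)
    child_code = (
        "import sys, pickle\n"
        "func = pickle.loads(sys.stdin.buffer.read())\n"
        "func()\n"
    )
    result = subprocess.run(
        [sys.executable, "-c", child_code],
        input=payload,
        capture_output=True,
    )
    return result.returncode == 0

print(roundtrips_in_fresh_interpreter(func1))  # True: concurrent.futures is detected and saved
print(roundtrips_in_fresh_interpreter(func))   # False: the getattr-based access is not detected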
Performance
If the function to be pickled does not use any packages, we skip this code path entirely, so there is no performance impact in that case.
Running the benchmark shown in the first code snippet above, which features small functions that use packages, the proposed implementation no longer exhibits the performance degradation. Moreover, its duration is no longer sensitive to the size of sys.modules:
0.025113264098763466
0.028617795556783676
On a large (100-line) dummy function generated by ChatGPT that uses a single package, the proposed change is also faster: print(timeit.timeit(lambda: cloudpickle.dumps(func), number=1000)) reports 0.0399 with the proposed change, compared to 0.1123 with the existing code.
In a real-world scenario from a project that pickles 10K+ tasks and task parameters for remote execution, we see a 3.2x speedup of overall cloudpickle.dump, which includes an 8.4x speedup of _function_reduce.
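For reference, a minimal sketch of how such per-helper numbers can be obtained (this is one possible setup, not necessarily the exact measurement used for the figures above):

import cProfile
import pstats
import cloudpickle
import re, collections, asyncio

def func2():
    [collections, re, asyncio]  # packages

profiler = cProfile.Profile()
profiler.enable()
for _ in range(1000):
    cloudpickle.dumps(func2)
profiler.disable()

# Restrict the report to cloudpickle's own frames to see how much time is spent
# in helpers such as _function_reduce and the submodule search.
pstats.Stats(profiler).sort_stats("cumulative").print_stats("cloudpickle")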
Diff
diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py
index 4f4d857..5ae91e5 100644
--- a/cloudpickle/cloudpickle.py
+++ b/cloudpickle/cloudpickle.py
@@ -103,6 +103,7 @@ if PYPY:
    builtin_code_type = type(float.__new__.__code__)
_extract_code_globals_cache = weakref.WeakKeyDictionary()
+_extract_candidate_submodule_names_cache = weakref.WeakKeyDictionary()
def _get_or_create_tracker_id(class_def):
@@ -335,60 +336,77 @@ def _extract_code_globals(co):
    return out_names
-def _find_imported_submodules(code, top_level_dependencies):
-    """Find currently imported submodules used by a function.
-
-    Submodules used by a function need to be detected and referenced for the
-    function to work correctly at depickling time. Because submodules can be
-    referenced as attribute of their parent package (``package.submodule``), we
-    need a special introspection technique that does not rely on GLOBAL-related
-    opcodes to find references of them in a code object.
-
-    Example:
-    ```
-    import concurrent.futures
-    import cloudpickle
-    def func():
-        x = concurrent.futures.ThreadPoolExecutor
-    if __name__ == '__main__':
-        cloudpickle.dumps(func)
-    ```
-    The globals extracted by cloudpickle in the function's state include the
-    concurrent package, but not its submodule (here, concurrent.futures), which
-    is the module used by func. Find_imported_submodules will detect the usage
-    of concurrent.futures. Saving this module alongside with func will ensure
-    that calling func once depickled does not fail due to concurrent.futures
-    not being imported
+def _extract_candidate_submodule_names(func):
+    """Extract strings that look like submodule accesses (e.g. 'concurrent.futures') from a function's
+    bytecode.
    """
+    out_names = _extract_candidate_submodule_names_cache.get(func)
+    if out_names is not None:
+        return out_names
+
+    code = func.__code__
+    f_globals = func.__globals__
+
+    # Create a mapping from closure variable names to their actual values.
+    # co_freevars stores the names of the non-local variables.
+    # func.__closure__ stores the "cell" objects containing their values.
+    closure_map = {
+        name: _get_cell_contents(cell)
+        for name, cell in zip(code.co_freevars, func.__closure__ or [])
+    }
-    subimports = []
-    # check if any known dependency is an imported package
-    for x in top_level_dependencies:
-        if (
-            isinstance(x, types.ModuleType)
-            and hasattr(x, "__package__")
-            and x.__package__
-        ):
-            # check if the package has any currently loaded sub-imports
-            prefix = x.__name__ + "."
-            # A concurrent thread could mutate sys.modules,
-            # make sure we iterate over a copy to avoid exceptions
-            for name in list(sys.modules):
-                # Older versions of pytest will add a "None" module to
-                # sys.modules.
-                if name is not None and name.startswith(prefix):
-                    # check whether the function can address the sub-module
-                    tokens = set(name[len(prefix) :].split("."))
-                    if not tokens - set(code.co_names):
-                        subimports.append(sys.modules[name])
-    return subimports
+    instructions = list(dis.get_instructions(code))
+    candidate_submodule_names = []
+
+    for i, instruction in enumerate(instructions):
+        candidate_module = None
+        # Check for a global module load, e.g., `import numpy`.
+        if instruction.opcode == LOAD_GLOBAL:
+            candidate_module = f_globals.get(instruction.argval)
+
+        # Check for a non-local (closure) module load.
+        elif instruction.opcode == LOAD_DEREF:
+            candidate_module = closure_map.get(instruction.argval)
+
+        # If we found a base module object, look ahead for attribute access.
+        if isinstance(candidate_module, types.ModuleType):
+            current_path = [candidate_module.__name__]
+
+            for j in range(i + 1, len(instructions)):
+                next_instr = instructions[j]
+                if next_instr.opname in ('LOAD_ATTR', 'LOAD_METHOD'):
+                    current_path.append(next_instr.argval)
+                    candidate_submodule_names.append(".".join(current_path))
+                else:
+                    # The chain of attribute access was broken.
+                    break
+
+    _extract_candidate_submodule_names_cache[func] = candidate_submodule_names
+    return candidate_submodule_names
+
+def _find_submodules_via_bytecode(func, top_level_dependencies):
+    """
+    Directly finds submodules used by func by analyzing bytecode.
+    """
+    # Fail fast if the function doesn't use packages.
+    if not any(x for x in top_level_dependencies
+               if isinstance(x, types.ModuleType) and getattr(x, "__package__", None)):
+        return []
+
+    found_submodules = set()
+    for candidate_submodule_name in _extract_candidate_submodule_names(func):
+        candidate_submodule = sys.modules.get(candidate_submodule_name)
+        if candidate_submodule is not None:
+            found_submodules.add(candidate_submodule)
+    return list(found_submodules)
# relevant opcodes
STORE_GLOBAL = opcode.opmap["STORE_GLOBAL"]
DELETE_GLOBAL = opcode.opmap["DELETE_GLOBAL"]
LOAD_GLOBAL = opcode.opmap["LOAD_GLOBAL"]
GLOBAL_OPS = (STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL)
+LOAD_DEREF = opcode.opmap["LOAD_DEREF"]
HAVE_ARGUMENT = dis.HAVE_ARGUMENT
EXTENDED_ARG = dis.EXTENDED_ARG
@@ -736,8 +754,8 @@ def _function_getstate(func):
    # in a smoke _cloudpickle_subimports attribute of the object's state will
    # trigger the side effect of importing these modules at unpickling time
    # (which is necessary for func to work correctly once depickled)
-    slotstate["_cloudpickle_submodules"] = _find_imported_submodules(
-        func.__code__, itertools.chain(f_globals.values(), closure_values)
+    slotstate["_cloudpickle_submodules"] = _find_submodules_via_bytecode(
+        func, itertools.chain(f_globals.values(), closure_values)
    )
    slotstate["__globals__"] = f_globals
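As a quick sanity check once the patch is applied, the new helper can also be exercised directly (hypothetical usage; in the patch it is only called from _function_getstate with the function's globals and closure values):

import concurrent.futures
from cloudpickle.cloudpickle import _find_submodules_via_bytecode

def func():
    x = concurrent.futures.ThreadPoolExecutor

found = _find_submodules_via_bytecode(func, [concurrent])
print(sorted(m.__name__ for m in found))  # expected to contain 'concurrent.futures'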