
Pickling functions that use packages can be extremely slow. Proposal and code for an alternative implementation #576

@27359794

This is an investigation into a performance issue and a proposal and diff for a fix. I would appreciate the dev team's consideration. If it looks OK, I am happy to submit a PR with the fix (including cleaning it up and adding more tests, if desired).


I'm investigating a performance issue with function-pickling behaviour. I narrowed it down to the function _find_imported_submodules. This function repeatedly copies and iterates through sys.modules, doing a string operation on each module name.

With just the imports below, pickling a function that references 3 packages is 6x slower than pickling a function that references 3 non-package modules.

import cloudpickle, timeit                                         
import re, math, os, sys, collections, asyncio                         
                                                                   
def func1():                                                       
    [math, os, sys]  # non-packages                            
                                                                   
def func2():                                                       
    [collections, re, asyncio]  # packages                     
                                                                   
print(timeit.timeit(lambda: cloudpickle.dumps(func1), number=1000))
print(timeit.timeit(lambda: cloudpickle.dumps(func2), number=1000))
0.02538153436034918
0.1656142445281148

The slowness comes from the code doing a full copy and linear scan through sys.modules for each package used by the pickled function. Its duration is proportional to the size of sys.modules, which can be very large: in one of my projects, sys.modules has 3800 elements, and the same benchmark shows a speed difference of >50x:

0.021970179863274097
1.147934407927096
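To make the scaling concrete, here is a minimal sketch that mimics the scan pattern of the existing code against a fake module table. The helper names (`scan_for_submodules`, `build_fake_modules`) and the generated module names are hypothetical and exist only for this illustration; the real code iterates over `sys.modules` itself.

```python
import timeit


def scan_for_submodules(modules, package_name, co_names):
    """Mimic the existing approach: copy all keys, then prefix-test each one."""
    prefix = package_name + "."
    found = []
    for name in list(modules):  # a full copy of the keys on every call
        if name is not None and name.startswith(prefix):
            # Keep the submodule only if every path component is a name
            # that appears in the function's co_names.
            tokens = set(name[len(prefix):].split("."))
            if not tokens - set(co_names):
                found.append(name)
    return found


def build_fake_modules(n):
    """Hypothetical stand-in for sys.modules with n unrelated entries."""
    modules = {f"pkg{i}.mod{i}": object() for i in range(n)}
    modules["concurrent"] = object()
    modules["concurrent.futures"] = object()
    return modules


co_names = ("concurrent", "futures", "ThreadPoolExecutor")
small = build_fake_modules(100)
large = build_fake_modules(4000)

result_small = scan_for_submodules(small, "concurrent", co_names)
result_large = scan_for_submodules(large, "concurrent", co_names)

# The work per call grows with the table size, even though the answer is the same.
t_small = timeit.timeit(
    lambda: scan_for_submodules(small, "concurrent", co_names), number=200
)
t_large = timeit.timeit(
    lambda: scan_for_submodules(large, "concurrent", co_names), number=200
)
print(f"100 entries:  {t_small:.4f}s")
print(f"4000 entries: {t_large:.4f}s")
```

Both calls return the same single submodule, but the second one has to touch every key in the larger table to find it, and the real code repeats this once per package referenced by the function.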

I propose an alternative implementation that scans the instructions of the function's code and checks whether submodules are being accessed. This replaces the linear scan of sys.modules with a handful of direct lookups. It is (almost) a drop-in replacement for _find_imported_submodules that is much faster in most of the cases I've tried, and no slower in any case I've tried.
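As a rough standalone illustration of the idea (a simplified sketch, not the diff itself), the snippet below walks a function's instructions, starts a dotted path at each `LOAD_GLOBAL` that resolves to a module, and extends it for each following attribute load. The helper name `candidate_dotted_names` is hypothetical, and the sketch ignores closures and caching.

```python
import concurrent.futures
import dis
import sys
import types


def candidate_dotted_names(func):
    """Collect dotted names such as 'concurrent.futures' from attribute
    chains that start at a global module reference."""
    instructions = list(dis.get_instructions(func.__code__))
    names = []
    for i, instr in enumerate(instructions):
        if instr.opname != "LOAD_GLOBAL":
            continue
        base = func.__globals__.get(instr.argval)
        if not isinstance(base, types.ModuleType):
            continue
        path = [base.__name__]
        # Follow consecutive attribute loads; stop at the first other opcode.
        for nxt in instructions[i + 1:]:
            if nxt.opname not in ("LOAD_ATTR", "LOAD_METHOD"):
                break
            path.append(nxt.argval)
            names.append(".".join(path))
    return names


def uses_futures():
    return concurrent.futures.ThreadPoolExecutor


candidates = candidate_dotted_names(uses_futures)
# Only names that are actually loaded modules survive: one dict lookup each.
loaded = [name for name in candidates if name in sys.modules]
print(candidates)
print(loaded)
```

The cost here depends on the number of instructions in the function and the number of candidate names, not on how many modules happen to be imported in the process.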

The diff is shown below.

Correctness

I verified that all unit tests pass with my change.

I also checked that the proposed change produces the same output as the existing code in these edge cases:

import concurrent.futures       
import concurrent.futures as cf 
c = concurrent                  

def func1():                                               
    x = concurrent.futures.ThreadPoolExecutor              
                                                           
def func2():                                               
    x = cf.ThreadPoolExecutor                              
                                                           
def func3():                                               
    x = c.futures.ThreadPoolExecutor                       
                                                           
def func4():                                               
    x = getattr(concurrent, 'futures').ThreadPoolExecutor  
                                                           
def func5():
    import concurrent.futures
    def inner():    # <-- testing pickling the inner function, where concurrent.futures is nonlocal but not global
        x = concurrent.futures.ThreadPoolExecutor
    with open('/tmp/pkl5', 'wb') as f:  # close the file deterministically
        f.write(cloudpickle.dumps(inner))

NOTE: both the existing and the proposed algorithms for _find_imported_submodules fail to unpickle the function below, because both act only on names found in the bytecode, and here 'futures' appears only as a string constant. The proposed algorithm is no worse than the existing one in this regard.

def func():                                              
    x = getattr(concurrent, 'futures').ThreadPoolExecutor
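The reason can be seen by inspecting the code objects directly. This small sketch (the function names `direct_access` and `getattr_access` are made up for the illustration) compares where the string `futures` ends up in each case; the functions are only compiled, never called.

```python
def direct_access():
    x = concurrent.futures.ThreadPoolExecutor


def getattr_access():
    x = getattr(concurrent, "futures").ThreadPoolExecutor


# co_names holds identifiers used for global and attribute lookups;
# co_consts holds literal constants such as the string passed to getattr().
direct_names = direct_access.__code__.co_names
getattr_names = getattr_access.__code__.co_names
getattr_consts = getattr_access.__code__.co_consts

print("direct co_names:  ", direct_names)
print("getattr co_names: ", getattr_names)
print("getattr co_consts:", getattr_consts)
```

With direct attribute access, `futures` is an attribute name and therefore visible to any analysis based on names. With `getattr`, it is an ordinary string constant, so neither a `co_names` check nor an attribute-chain walk can tell that it refers to a submodule.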

Performance

If the function-to-be-pickled does not use packages, we skip this codepath entirely so there is no impact to performance.
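A minimal sketch of such a fail-fast check, assuming it mirrors the `__package__` test used in the diff (the helper name `uses_any_package` is hypothetical):

```python
import collections
import math
import types


def uses_any_package(dependencies):
    """Return True if any dependency is a module with a non-empty __package__."""
    return any(
        isinstance(dep, types.ModuleType)
        and bool(getattr(dep, "__package__", None))
        for dep in dependencies
    )


# math is a top-level non-package module, so its __package__ is empty.
no_packages = uses_any_package([math, 42, "text"])
# collections is a package, so its __package__ is "collections".
with_package = uses_any_package([math, collections])
print(no_packages, with_package)
```

Note that a non-empty `__package__` is also set on plain submodules that live inside a package (for example `collections.abc`), so this check is slightly broader than "is a package". That only makes the fast path a little more conservative.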

Running the benchmark from the first code snippet above, which features small functions that use packages, the proposal no longer exhibits the observed performance degradation. Moreover, its duration is no longer sensitive to the size of sys.modules.

0.025113264098763466
0.028617795556783676

On a large dummy function (100 lines, generated by ChatGPT) that uses a single package, the proposed change is also faster. The result of print(timeit.timeit(lambda: cloudpickle.dumps(func), number=1000)) with the proposed change is 0.0399, compared to 0.1123 with the existing code.

In a real-world scenario from a project that involves pickling 10K+ tasks and task parameters for remote execution, we see a 3.2x speedup of cloudpickle.dump overall, which includes an 8.4x speedup of _function_reduce.

Diff

diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py
index 4f4d857..5ae91e5 100644
--- a/cloudpickle/cloudpickle.py
+++ b/cloudpickle/cloudpickle.py
@@ -103,6 +103,7 @@ if PYPY:
     builtin_code_type = type(float.__new__.__code__)
 
 _extract_code_globals_cache = weakref.WeakKeyDictionary()
+_extract_candidate_submodule_names_cache = weakref.WeakKeyDictionary()
 
 
 def _get_or_create_tracker_id(class_def):
@@ -335,60 +336,77 @@ def _extract_code_globals(co):
     return out_names
 
 
-def _find_imported_submodules(code, top_level_dependencies):
-    """Find currently imported submodules used by a function.
-
-    Submodules used by a function need to be detected and referenced for the
-    function to work correctly at depickling time. Because submodules can be
-    referenced as attribute of their parent package (``package.submodule``), we
-    need a special introspection technique that does not rely on GLOBAL-related
-    opcodes to find references of them in a code object.
-
-    Example:
-    ```
-    import concurrent.futures
-    import cloudpickle
-    def func():
-        x = concurrent.futures.ThreadPoolExecutor
-    if __name__ == '__main__':
-        cloudpickle.dumps(func)
-    ```
-    The globals extracted by cloudpickle in the function's state include the
-    concurrent package, but not its submodule (here, concurrent.futures), which
-    is the module used by func. Find_imported_submodules will detect the usage
-    of concurrent.futures. Saving this module alongside with func will ensure
-    that calling func once depickled does not fail due to concurrent.futures
-    not being imported
+def _extract_candidate_submodule_names(func):
+    """Extract strings that look like submodule accesses (e.g. 'concurrent.futures') from a function's
+    bytecode.
     """
+    out_names = _extract_candidate_submodule_names_cache.get(func)
+    if out_names is not None:
+        return out_names
+
+    code = func.__code__
+    f_globals = func.__globals__
+
+    # Create a mapping from closure variable names to their actual values.
+    # co_freevars stores the names of the non-local variables.
+    # func.__closure__ stores the "cell" objects containing their values.
+    closure_map = {
+        name: _get_cell_contents(cell)
+        for name, cell in zip(code.co_freevars, func.__closure__ or [])
+    }
 
-    subimports = []
-    # check if any known dependency is an imported package
-    for x in top_level_dependencies:
-        if (
-            isinstance(x, types.ModuleType)
-            and hasattr(x, "__package__")
-            and x.__package__
-        ):
-            # check if the package has any currently loaded sub-imports
-            prefix = x.__name__ + "."
-            # A concurrent thread could mutate sys.modules,
-            # make sure we iterate over a copy to avoid exceptions
-            for name in list(sys.modules):
-                # Older versions of pytest will add a "None" module to
-                # sys.modules.
-                if name is not None and name.startswith(prefix):
-                    # check whether the function can address the sub-module
-                    tokens = set(name[len(prefix) :].split("."))
-                    if not tokens - set(code.co_names):
-                        subimports.append(sys.modules[name])
-    return subimports
+    instructions = list(dis.get_instructions(code))
+    candidate_submodule_names = []
+
+    for i, instruction in enumerate(instructions):
+        candidate_module = None
+        # Check for a global module load, e.g., `import numpy`.
+        if instruction.opcode == LOAD_GLOBAL:
+            candidate_module = f_globals.get(instruction.argval)
+
+        # Check for a non-local (closure) module load.
+        elif instruction.opcode == LOAD_DEREF:
+            candidate_module = closure_map.get(instruction.argval)
+
+        # If we found a base module object, look ahead for attribute access.
+        if isinstance(candidate_module, types.ModuleType):
+            current_path = [candidate_module.__name__]
+
+            for j in range(i + 1, len(instructions)):
+                next_instr = instructions[j]
+                if next_instr.opname in ('LOAD_ATTR', 'LOAD_METHOD'):
+                    current_path.append(next_instr.argval)
+                    candidate_submodule_names.append(".".join(current_path))
+                else:
+                    # The chain of attribute access was broken.
+                    break
+
+    _extract_candidate_submodule_names_cache[func] = candidate_submodule_names
+    return candidate_submodule_names
+
+def _find_submodules_via_bytecode(func, top_level_dependencies):
+    """
+    Directly finds submodules used by func by analyzing bytecode.
+    """
+    # Fail fast if the function doesn't use packages.
+    if not any(x for x in top_level_dependencies 
+               if isinstance(x, types.ModuleType) and getattr(x, "__package__", None)):
+        return []
+
+    found_submodules = set()
+    for candidate_submodule_name in _extract_candidate_submodule_names(func):
+        candidate_submodule = sys.modules.get(candidate_submodule_name)
+        if candidate_submodule is not None:
+            found_submodules.add(candidate_submodule)
 
+    return list(found_submodules)
 
 # relevant opcodes
 STORE_GLOBAL = opcode.opmap["STORE_GLOBAL"]
 DELETE_GLOBAL = opcode.opmap["DELETE_GLOBAL"]
 LOAD_GLOBAL = opcode.opmap["LOAD_GLOBAL"]
 GLOBAL_OPS = (STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL)
+LOAD_DEREF = opcode.opmap["LOAD_DEREF"]
 HAVE_ARGUMENT = dis.HAVE_ARGUMENT
 EXTENDED_ARG = dis.EXTENDED_ARG
 
@@ -736,8 +754,8 @@ def _function_getstate(func):
     # in a smoke _cloudpickle_subimports attribute of the object's state will
     # trigger the side effect of importing these modules at unpickling time
     # (which is necessary for func to work correctly once depickled)
-    slotstate["_cloudpickle_submodules"] = _find_imported_submodules(
-        func.__code__, itertools.chain(f_globals.values(), closure_values)
+    slotstate["_cloudpickle_submodules"] = _find_submodules_via_bytecode(
+        func, itertools.chain(f_globals.values(), closure_values)
     )
     slotstate["__globals__"] = f_globals
