diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_FakeTensorUpdater.py b/py/torch_tensorrt/dynamo/lowering/passes/_FakeTensorUpdater.py index ee668c720c..438a1b8ffb 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_FakeTensorUpdater.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_FakeTensorUpdater.py @@ -7,6 +7,7 @@ import torch import torch.fx from torch._dispatch.python import enable_python_dispatcher +import torch._inductor.fx_passes.reinplace from torch._inductor.fx_utils import get_fake_args_kwargs, get_node_storage, get_storage from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx.experimental.symbolic_shapes import ( diff --git a/tools/llm/static_cache_v1.py b/tools/llm/static_cache_v1.py index fa2a90b129..1d9a3cca19 100644 --- a/tools/llm/static_cache_v1.py +++ b/tools/llm/static_cache_v1.py @@ -105,9 +105,18 @@ def get_static_tensor(tensor: torch.Tensor): min_max_opt = extract_var_range_info(seq_len) max_seq_len = min_max_opt["max"] - from torch.fx.experimental.symbolic_shapes import ShapeEnv + # Get the ShapeEnv from the existing fake tensors in the graph rather than + # creating a new one. Using a fresh ShapeEnv causes a KeyError in + # FakeTensorUpdater because the unbacked symints (u0, u1) are unknown to + # the FakeTensorMode's ShapeEnv. + fake_tensors = [ + node.meta["val"] + for node in gm.graph.nodes + if "val" in node.meta + and isinstance(node.meta["val"], torch._subclasses.fake_tensor.FakeTensor) + ] + shape_env = fake_tensors[0].fake_mode.shape_env - shape_env = ShapeEnv() # Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive start_idx_unbacked_symint = shape_env.create_unbacked_symint() torch._check(start_idx_unbacked_symint >= 0) @@ -120,6 +129,12 @@ def get_static_tensor(tensor: torch.Tensor): start_idx_input.meta["val"] = start_idx_unbacked_symint end_idx_input.meta["val"] = end_idx_unbacked_symint + # u0/u1 are scalar index values, not tensor shape dimensions, so they will + # never appear in any output tensor shape. Clear them from the pending list + # so FakeTensorUpdater doesn't raise PendingUnbackedSymbolNotFound when + # processing subsequent call_function nodes (placeholder nodes are skipped). + shape_env.pending_fresh_unbacked_symbols.clear() + return kv_inputs, start_idx_input, end_idx_input diff --git a/tools/llm/static_cache_v2.py b/tools/llm/static_cache_v2.py index 4634b79a52..b1bc37c552 100644 --- a/tools/llm/static_cache_v2.py +++ b/tools/llm/static_cache_v2.py @@ -110,9 +110,18 @@ def get_static_tensor(tensor: torch.Tensor): else: max_seq_len = seq_len - from torch.fx.experimental.symbolic_shapes import ShapeEnv + # Get the ShapeEnv from the existing fake tensors in the graph rather than + # creating a new one. Using a fresh ShapeEnv causes a KeyError in + # FakeTensorUpdater because the unbacked symints (u0, u1) are unknown to + # the FakeTensorMode's ShapeEnv. + fake_tensors = [ + node.meta["val"] + for node in gm.graph.nodes + if "val" in node.meta + and isinstance(node.meta["val"], torch._subclasses.fake_tensor.FakeTensor) + ] + shape_env = fake_tensors[0].fake_mode.shape_env - shape_env = ShapeEnv() # Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive start_idx_unbacked_symint = shape_env.create_unbacked_symint() torch._check(start_idx_unbacked_symint >= 0) @@ -125,6 +134,12 @@ def get_static_tensor(tensor: torch.Tensor): start_idx_input.meta["val"] = start_idx_unbacked_symint end_idx_input.meta["val"] = end_idx_unbacked_symint + # u0/u1 are scalar index values, not tensor shape dimensions, so they will + # never appear in any output tensor shape. Clear them from the pending list + # so FakeTensorUpdater doesn't raise PendingUnbackedSymbolNotFound when + # processing subsequent call_function nodes (placeholder nodes are skipped). + shape_env.pending_fresh_unbacked_symbols.clear() + return kv_inputs, start_idx_input, end_idx_input