Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch
import torch.fx
from torch._dispatch.python import enable_python_dispatcher
import torch._inductor.fx_passes.reinplace
from torch._inductor.fx_utils import get_fake_args_kwargs, get_node_storage, get_storage
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import (
Expand Down
19 changes: 17 additions & 2 deletions tools/llm/static_cache_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,18 @@ def get_static_tensor(tensor: torch.Tensor):
min_max_opt = extract_var_range_info(seq_len)
max_seq_len = min_max_opt["max"]

from torch.fx.experimental.symbolic_shapes import ShapeEnv
# Get the ShapeEnv from the existing fake tensors in the graph rather than
# creating a new one. Using a fresh ShapeEnv causes a KeyError in
# FakeTensorUpdater because the unbacked symints (u0, u1) are unknown to
# the FakeTensorMode's ShapeEnv.
fake_tensors = [
node.meta["val"]
for node in gm.graph.nodes
if "val" in node.meta
and isinstance(node.meta["val"], torch._subclasses.fake_tensor.FakeTensor)
]
shape_env = fake_tensors[0].fake_mode.shape_env

shape_env = ShapeEnv()
# Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive
start_idx_unbacked_symint = shape_env.create_unbacked_symint()
torch._check(start_idx_unbacked_symint >= 0)
Expand All @@ -120,6 +129,12 @@ def get_static_tensor(tensor: torch.Tensor):
start_idx_input.meta["val"] = start_idx_unbacked_symint
end_idx_input.meta["val"] = end_idx_unbacked_symint

# u0/u1 are scalar index values, not tensor shape dimensions, so they will
# never appear in any output tensor shape. Clear them from the pending list
# so FakeTensorUpdater doesn't raise PendingUnbackedSymbolNotFound when
# processing subsequent call_function nodes (placeholder nodes are skipped).
shape_env.pending_fresh_unbacked_symbols.clear()

return kv_inputs, start_idx_input, end_idx_input


Expand Down
19 changes: 17 additions & 2 deletions tools/llm/static_cache_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,18 @@ def get_static_tensor(tensor: torch.Tensor):
else:
max_seq_len = seq_len

from torch.fx.experimental.symbolic_shapes import ShapeEnv
# Get the ShapeEnv from the existing fake tensors in the graph rather than
# creating a new one. Using a fresh ShapeEnv causes a KeyError in
# FakeTensorUpdater because the unbacked symints (u0, u1) are unknown to
# the FakeTensorMode's ShapeEnv.
fake_tensors = [
node.meta["val"]
for node in gm.graph.nodes
if "val" in node.meta
and isinstance(node.meta["val"], torch._subclasses.fake_tensor.FakeTensor)
]
shape_env = fake_tensors[0].fake_mode.shape_env

shape_env = ShapeEnv()
# Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive
start_idx_unbacked_symint = shape_env.create_unbacked_symint()
torch._check(start_idx_unbacked_symint >= 0)
Expand All @@ -125,6 +134,12 @@ def get_static_tensor(tensor: torch.Tensor):
start_idx_input.meta["val"] = start_idx_unbacked_symint
end_idx_input.meta["val"] = end_idx_unbacked_symint

# u0/u1 are scalar index values, not tensor shape dimensions, so they will
# never appear in any output tensor shape. Clear them from the pending list
# so FakeTensorUpdater doesn't raise PendingUnbackedSymbolNotFound when
# processing subsequent call_function nodes (placeholder nodes are skipped).
shape_env.pending_fresh_unbacked_symbols.clear()

return kv_inputs, start_idx_input, end_idx_input


Expand Down
Loading