Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/error_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Error classification and error budget for fatal-error detection.

This module is intentionally dependency-free (no CUDA, no C++ extensions)
so that it can be imported and tested in any environment.
"""

import dataclasses
import time

# Message fragments indicating the CUDA context is corrupted beyond
# recovery. Matching is case-insensitive.
IMMEDIATE_FATAL_PATTERNS: list[str] = [
    "cudaerrorillegaladdress",
    "cudaerrorlaunchfailure",
    "illegal memory access",
    "device-side assert",
    "unrecoverable",
]

# Message fragments for serious-but-possibly-transient failures (e.g. a
# single OOM during a traffic spike). These drain the error budget 5x
# faster than transient errors.
SEVERE_ERROR_PATTERNS: list[str] = [
    "cuda out of memory",
    "cuda error",
    "nccl error",
]


def classify_error(error_msg: str) -> str:
    """Classify an error message by severity.

    Args:
        error_msg: The error message string to classify.

    Returns:
        One of ``"immediate_fatal"``, ``"severe"``, or ``"transient"``.

        - **immediate_fatal**: CUDA context is corrupted (device-side
          assert, illegal address, launch failure). No future CUDA call
          can succeed.
        - **severe**: The operation failed but the CUDA context is
          intact (e.g. OOM, NCCL timeout). Recovery is possible if
          workload decreases.
        - **transient**: All other errors (bad input, timeout, etc.).
    """
    # Lowercase once; the pattern tables are already lowercase.
    msg = error_msg.lower()
    # Immediate-fatal patterns take precedence over severe ones (e.g.
    # "CUDA error: an illegal memory access" matches both tables).
    if any(pattern in msg for pattern in IMMEDIATE_FATAL_PATTERNS):
        return "immediate_fatal"
    if any(pattern in msg for pattern in SEVERE_ERROR_PATTERNS):
        return "severe"
    return "transient"


@dataclasses.dataclass
class ErrorBudget:
    """Token-bucket error budget for fatal-error promotion.

    Each error deducts a cost from the budget (0.1 for transient, 0.5
    for severe). The budget recovers at ``recovery_rate`` per second
    of error-free wall time. When exhausted, the error is promoted to
    fatal. Immediate-fatal errors bypass the budget entirely.

    Attributes:
        budget: Current budget level (starts at 1.0, clamped to
            [0.0, 1.0]).
        last_error_time: Monotonic timestamp of the last budgeted error,
            or None before the first one.
        recovery_rate: Budget recovered per second of error-free time.
        cost: Cost per transient error (severe costs 5x this).
    """

    budget: float = 1.0
    last_error_time: float | None = None
    recovery_rate: float = 0.1
    cost: float = 0.1

    def consume(self, error_msg: str) -> bool:
        """Deduct from the budget and return True if exhausted.

        Args:
            error_msg: The error message to classify and budget.

        Returns:
            True if the error should be treated as fatal.
        """
        now = time.monotonic()
        classification = classify_error(error_msg)

        # Context-corrupting errors are always fatal; they neither
        # consume nor replenish the budget.
        if classification == "immediate_fatal":
            return True

        # Replenish for the error-free wall time since the previous
        # budgeted error, capped at the full budget.
        if self.last_error_time is not None:
            elapsed = now - self.last_error_time
            self.budget = min(1.0, self.budget + elapsed * self.recovery_rate)
        self.last_error_time = now

        deduction = self.cost
        if classification == "severe":
            deduction *= 5

        # Clamp at zero so a burst of errors cannot drive the budget
        # arbitrarily negative, which would otherwise delay recovery far
        # beyond the intended window (standard token-bucket behavior).
        # The fatal decision is unchanged: any drained bucket (<= 0
        # before clamping) still reports fatal below.
        self.budget = max(0.0, self.budget - deduction)
        return self.budget < 1e-9
115 changes: 109 additions & 6 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from ..speculative.speculation_gate import SpeculationGate
from .connectors.kv_cache_connector import KvCacheConnectorManager
from .dwdp import DwdpManager
from .error_classification import ErrorBudget
from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem
from .guided_decoder import GuidedDecoder
from .handle_additional_outputs import HandleAdditionalOutputs
Expand Down Expand Up @@ -462,6 +463,8 @@ def __init__(
self.kv_cache_manager.snapshot_warmup_baseline()

self.is_shutdown = False
self._fatal_error: Optional[BaseException] = None
self._error_budget = ErrorBudget()
self.max_batch_size = max_batch_size
self.adp_ctx_waiting_iters_count = 0
self.adp_ctx_batching_wait_iters_count = 0
Expand Down Expand Up @@ -2784,7 +2787,9 @@ def _respond_if_invalid(request: LlmRequest) -> bool:
self._validate_request(request)
return False
except Exception as e:
self._handle_errors(str(e), requests=[request])
self._handle_errors(str(e),
requests=[request],
charge_budget=False)
return True

new_requests_cur_rank = self._fetch_new_requests(
Expand Down Expand Up @@ -3278,7 +3283,8 @@ def _check_cache_transfer_errors(self, error_msg_prefix: str):
if error_requests:
self._handle_errors(
f"Error in kv cache transfer for {error_msg_prefix}",
requests=error_requests)
requests=error_requests,
charge_budget=False)

@nvtx_range("_check_disagg_ctx_cache_transfer_status")
def _check_disagg_ctx_cache_transfer_status(self, atLeastNum: int = 0):
Expand Down Expand Up @@ -3499,10 +3505,101 @@ def _update_requests(self,
def _handle_errors(self,
error_msg: Optional[str] = None,
*,
requests: Optional[List[LlmRequest]] = None):
requests: Optional[List[LlmRequest]] = None,
charge_budget: bool = True) -> None:
"""Fail requests and optionally initiate shutdown on fatal errors.

When ``charge_budget`` is True (the default), classifies the error
via the error budget. If deemed fatal (immediate-fatal pattern or
budget exhausted), **all** active requests are failed and a shutdown
is enqueued. Otherwise only the requests in *requests* are failed.

When ``charge_budget`` is False, the error is treated as a
per-request failure: only the specified requests are failed, the
error budget is not consumed, and shutdown is never triggered.
Use this for request-scoped errors (validation, KV-transfer
timeout, guided-decoder) that should not affect server health.

.. note::
The ``charge_budget=False`` path reuses the full
``_handle_errors`` machinery (queue drain, response
enqueue, terminate) even though it only needs to fail a
single request. A future improvement would be to extract
a lightweight ``_fail_request(request, error_msg)`` helper
for request-scoped failures, keeping ``_handle_errors``
focused on system-level errors that may crash the engine.

Args:
error_msg: Human-readable error description. Defaults to
``"error"`` when ``None``.
requests: Subset of active requests to fail. When ``None``
(or when the error is fatal), all ``active_requests`` are
failed.
charge_budget: Whether to consume the error budget. Set to
False for request-scoped errors that should not affect
server health.
"""
error_responses: Dict[int, LlmResponse] = {}
error_msg = error_msg or "error"
failed_requests = requests if requests is not None else self.active_requests

is_fatal = (self._error_budget.consume(error_msg)
if charge_budget else False)
if is_fatal and self._error_budget.budget < 1e-9:
logger.error(f"Error budget exhausted "
f"(budget={self._error_budget.budget:.3f}), "
"treating as fatal")

Comment thread
coderabbitai[bot] marked this conversation as resolved.
if is_fatal:
self._fatal_error = RuntimeError(f"Fatal error: {error_msg}")
self.is_shutdown = True
logger.error(
f"Fatal error detected, initiating shutdown: {error_msg}")
requests = None

# Drain waiting_queue so that queued-but-not-yet-activated
# requests don't get picked up on the next iteration.
# These are RequestQueueItems (not yet LlmRequests), so we
# fail them via error responses. Buffer all responses and
# call _enqueue_responses once after the loop so every rank
# enters the same number of collectives (attention-DP /
# gather-all modes use collective gathers internally).
waiting_responses: List[Tuple[int, LlmResponse]] = []
while self.waiting_queue:
item = self.waiting_queue.pop_request()
if (self.gather_all_responses
or self.dist.rank == 0) and item.request is not None:
waiting_responses.append(
(item.id,
LlmResponse(request_id=item.id,
error_msg=error_msg,
client_id=getattr(item.request,
'client_id', None))))
# Also drain executor_request_queue so items already queued
# but not yet fetched by the main loop are not scheduled
# after the CUDA context is corrupted. Safe to use empty()
# here because is_shutdown is True and the queue's active
# flag is about to be set False, so no new items arrive.
raw_queue = self.executor_request_queue.get_request_queue()
while not raw_queue.empty():
item = raw_queue.get_nowait()
if item.is_shutdown_request:
continue
if ((self.gather_all_responses or self.dist.rank == 0)
and item.request is not None):
waiting_responses.append(
(item.id,
LlmResponse(request_id=item.id,
error_msg=error_msg,
client_id=getattr(item.request,
'client_id', None))))

if waiting_responses:
self._enqueue_responses(waiting_responses)
logger.info(f"Drained {len(waiting_responses)} queued requests "
"on fatal error")

failed_requests = (list(self.active_requests)
if requests is None else requests)
for request in failed_requests:
req_id = request.py_request_id
request.state = LlmRequestState.GENERATION_COMPLETE
Expand All @@ -3521,6 +3618,9 @@ def _handle_errors(self,
for request in failed_requests:
self._terminate_request(request)

if self._fatal_error is not None:
self.executor_request_queue.enqueue_shutdown_request()

def _terminate_request(self, request: LlmRequest):
# Dummy requests don't participate in disagg KV cache transfers,
# so they must bypass the PP termination handler to avoid stale
Expand Down Expand Up @@ -3686,7 +3786,8 @@ def _handle_responses(self):
if is_cancelled:
self._handle_errors(
error_msg=f"Request {request.py_request_id} timed out",
requests=[request])
requests=[request],
charge_budget=False)
continue

if request.is_generation_only_request() and not request.is_finished:
Expand Down Expand Up @@ -3894,7 +3995,9 @@ def _handle_guided_decoder_errors(
if request.py_request_id not in failed_req_id_to_err:
continue
error_msg = failed_req_id_to_err[request.py_request_id]
self._handle_errors(error_msg, requests=[request])
self._handle_errors(error_msg,
requests=[request],
charge_budget=False)


class DisaggPPTerminationHandler:
Expand Down
Loading
Loading