Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,8 +979,11 @@ def forward_chunk(
enable_alltoall=False)
return x

def load_weights(self, weights: Dict[str, torch.Tensor]):
super().load_weights(weights)
def load_weights(self,
weights: List[Dict],
allow_partial_loading: bool = False):
super().load_weights(weights,
allow_partial_loading=allow_partial_loading)
dwdp_handle_collector = getattr(self, "dwdp_handle_collector", None)
if dwdp_handle_collector is not None:
dwdp_handle_collector.register_weights(self)
19 changes: 19 additions & 0 deletions tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,3 +962,22 @@ def test_nvfp4_gather_grouped_gemm_swiglu_blackwell(
c_sf_multi_valid.append(c_sf_multi_unswizzled[i])
c_sf_multi_valid = torch.cat(c_sf_multi_valid)
check_accuracy(c_sf_multi_valid, c_sf_ref_valid, atol=1e-4, rtol=1e-4, percent=0.95)


def test_cutedsl_load_weights_signature_matches_base():
    """Verify the parameter names exposed by CuteDslFusedMoE.load_weights.

    Protects against a regression in which CuteDslFusedMoE.load_weights did
    not declare the allow_partial_loading parameter, which produced a
    TypeError when the method was invoked from
    modeling_utils._load_weights_impl through the params_map path.

    Only the presence of the parameter names is checked; defaults and
    annotations are not inspected.
    """
    from inspect import signature

    from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import CuteDslFusedMoE

    # Introspect the unbound method so no module instance has to be built.
    accepted = signature(CuteDslFusedMoE.load_weights).parameters

    missing_weights_msg = "CuteDslFusedMoE.load_weights must accept 'weights' parameter"
    missing_partial_msg = (
        "CuteDslFusedMoE.load_weights must accept 'allow_partial_loading' "
        "to match the MoE base class interface"
    )

    assert "weights" in accepted, missing_weights_msg
    assert "allow_partial_loading" in accepted, missing_partial_msg
Loading