
TransformerEngine conflicts with Triton #2579

@xiyu-Luoy

Description

Describe the bug

When I used Megatron to train Qwen3-Omni, I ran into an incompatibility between the installed TransformerEngine and Triton versions: Triton's JIT fails while compiling TransformerEngine's fused MoE permutation kernel.

[rank1]: Traceback (most recent call last):
[rank1]:   File "/app/ms-swift/swift/cli/_megatron/sft.py", line 7, in <module>
[rank1]:     megatron_sft_main()
[rank1]:   File "/app/ms-swift/swift/megatron/train/sft.py", line 87, in megatron_sft_main
[rank1]:     return MegatronSft(args).main()
[rank1]:   File "/app/ms-swift/swift/llm/base.py", line 49, in main
[rank1]:     result = self.run()
[rank1]:   File "/app/ms-swift/swift/megatron/train/sft.py", line 77, in run
[rank1]:     self.trainer.train(train_dataset, val_dataset, data_collator)
[rank1]:   File "/app/ms-swift/swift/megatron/trainers/base.py", line 1098, in train
[rank1]:     pretrain(
[rank1]:   File "/home/xiyu.luo/.cache/modelscope/hub/_github/Megatron-LM/megatron/training/training.py", line 737, in pretrain
[rank1]:     iteration, num_floating_point_operations_so_far = train(
[rank1]:   File "/home/xiyu.luo/.cache/modelscope/hub/_github/Megatron-LM/megatron/training/training.py", line 2298, in train
[rank1]:     ) = train_step(
[rank1]:   File "/app/ms-swift/swift/megatron/trainers/base.py", line 565, in train_step
[rank1]:     return self._origin_train_step(forward_step_func, new_data_iterator, model, optimizer, opt_param_scheduler,
[rank1]:   File "/home/xiyu.luo/.cache/modelscope/hub/_github/Megatron-LM/megatron/training/training.py", line 1268, in train_step
[rank1]:     losses_reduced = forward_backward_func(
[rank1]:   File "/home/xiyu.luo/.cache/modelscope/hub/_github/Megatron-LM/megatron/core/pipeline_parallel/schedules.py", line 615, in forward_backward_no_pipelining
[rank1]:     output_tensor, num_tokens = forward_step(
[rank1]:   File "/home/xiyu.luo/.cache/modelscope/hub/_github/Megatron-LM/megatron/core/pipeline_parallel/schedules.py", line 402, in forward_step
[rank1]:     output_tensor, loss_func = forward_step_func(data_iterator, model)
[rank1]:   File "/app/ms-swift/swift/megatron/trainers/trainer.py", line 150, in forward_step
[rank1]:     output_tensor = model(**data)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/xiyu.luo/.cache/modelscope/hub/_github/Megatron-LM/megatron/core/distributed/data_parallel_base.py", line 22, in forward
[rank1]:     return self.module(*inputs, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/xiyu.luo/.cache/modelscope/hub/_github/Megatron-LM/megatron/core/transformer/module.py", line 429, in forward
[rank1]:     outputs = self.module(*inputs, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/app/ms-swift/swift/megatron/model/mm_gpt_model.py", line 106, in forward
[rank1]:     return self.language_model(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/app/ms-swift/swift/megatron/model/gpt_model.py", line 299, in forward
[rank1]:     hidden_states = self.decoder(
[rank1]:   File "/home/xiyu.luo/.cache/modelscope/hub/_github/Megatron-LM/megatron/core/transformer/transformer_block.py", line 553, in __call__
[rank1]:     return super().__call__(*args, **kwargs)
[rank1]:   File "/home/xiyu.luo/.cache/modelscope/hub/_github/Megatron-LM/megatron/core/transformer/module.py", line 305, in __call__
[rank1]:     return super().__call__(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/app/ms-swift/swift/megatron/model/mm_gpt/qwen3_vl.py", line 376, in forward
[rank1]:     hidden_states = self._checkpointed_forward(
[rank1]:   File "/app/ms-swift/swift/megatron/model/mm_gpt/qwen3_vl.py", line 257, in _checkpointed_forward
[rank1]:     hidden_states, context = checkpoint_handler(
[rank1]:   File "/app/ms-swift/swift/megatron/model/mm_gpt/qwen3_vl.py", line 239, in checkpoint_handler
[rank1]:     return tensor_parallel.checkpoint(
[rank1]:   File "/home/xiyu.luo/.cache/modelscope/hub/_github/Megatron-LM/megatron/core/tensor_parallel/random.py", line 480, in checkpoint
[rank1]:     return CheckpointFunction.apply(function, distribute_saved_activations, *args)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 581, in apply
[rank1]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank1]:   File "/home/xiyu.luo/.cache/modelscope/hub/_github/Megatron-LM/megatron/core/tensor_parallel/random.py", line 426, in forward
[rank1]:     outputs = run_function(*args)
[rank1]:   File "/app/ms-swift/swift/megatron/model/mm_gpt/qwen3_vl.py", line 200, in custom_forward
[rank1]:     hidden_states, context = layer(
[rank1]:   File "/home/xiyu.luo/.cache/modelscope/hub/_github/Megatron-LM/megatron/core/transformer/transformer_layer.py", line 852, in __call__
[rank1]:     return super().__call__(*args, **kwargs)
[rank1]:   File "/home/xiyu.luo/.cache/modelscope/hub/_github/Megatron-LM/megatron/core/transformer/module.py", line 305, in __call__
[rank1]:     return super().__call__(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/app/ms-swift/swift/megatron/init.py", line 551, in forward
[rank1]:     output = self._forward_mlp(hidden_states, kwargs.get('inference_context', None))
[rank1]:   File "/home/xiyu.luo/.cache/modelscope/hub/_github/Megatron-LM/megatron/core/transformer/transformer_layer.py", line 618, in _forward_mlp
[rank1]:     mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/xiyu.luo/.cache/modelscope/hub/_github/Megatron-LM/megatron/core/transformer/moe/moe_layer.py", line 292, in forward
[rank1]:     output, mlp_bias = custom_forward(hidden_states)
[rank1]:   File "/home/xiyu.luo/.cache/modelscope/hub/_github/Megatron-LM/megatron/core/transformer/moe/moe_layer.py", line 274, in custom_forward
[rank1]:     hidden_states, probs, residual = self.router_and_preprocess(hidden_states)
[rank1]:   File "/home/xiyu.luo/.cache/modelscope/hub/_github/Megatron-LM/megatron/core/transformer/moe/moe_layer.py", line 179, in router_and_preprocess
[rank1]:     hidden_states, probs = self.token_dispatcher.dispatch_preprocess(
[rank1]:   File "/home/xiyu.luo/.cache/modelscope/hub/_github/Megatron-LM/megatron/core/transformer/moe/token_dispatcher.py", line 597, in dispatch_preprocess
[rank1]:     ) = permute(
[rank1]:   File "/home/xiyu.luo/.cache/modelscope/hub/_github/Megatron-LM/megatron/core/transformer/moe/moe_utils.py", line 260, in permute
[rank1]:     return fused_permute_with_probs(tokens, probs, routing_map, num_out_tokens=num_out_tokens)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/permutation.py", line 572, in moe_permute_with_probs
[rank1]:     output, row_id_map, permuted_probs = _moe_permute_mask_map.apply(
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 581, in apply
[rank1]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/permutation.py", line 212, in forward
[rank1]:     row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/triton/permutation.py", line 280, in make_row_id_map
[rank1]:     _row_id_map_pass_3_kernel[grid](
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 419, in <lambda>
[rank1]:     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 733, in run
[rank1]:     kernel = self._do_compile(key, signature, device, constexprs, options, attrs, warmup)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 861, in _do_compile
[rank1]:     kernel = self.compile(src, target=target, options=options.__dict__)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 242, in compile
[rank1]:     key = get_cache_key(src, backend, options, env_vars=env_vars)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/cache.py", line 308, in get_cache_key
[rank1]:     key = f"{triton_key()}-{src.hash()}-{backend.hash()}-{backend_options.hash()}-{str(sorted(env_vars.items()))}"
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 75, in hash
[rank1]:     key = f"{self.fn.cache_key}-{str(self.attrs)}-{sorted_sig}-{constants_key}"
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 548, in cache_key
[rank1]:     dependencies_finder.visit(self.parse())
[rank1]:   File "/usr/lib/python3.10/ast.py", line 418, in visit
[rank1]:     return visitor(node)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 426, in generic_visit
[rank1]:     self.visit(item)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 418, in visit
[rank1]:     return visitor(node)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 194, in visit_FunctionDef
[rank1]:     self.generic_visit(node)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 426, in generic_visit
[rank1]:     self.visit(item)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 418, in visit
[rank1]:     return visitor(node)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 248, in visit_Assign
[rank1]:     self.generic_visit(node)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 428, in generic_visit
[rank1]:     self.visit(value)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 418, in visit
[rank1]:     return visitor(node)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 428, in generic_visit
[rank1]:     self.visit(value)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 418, in visit
[rank1]:     return visitor(node)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 172, in visit_Name
[rank1]:     self.record_reference(val, var_dict, node.id)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 133, in record_reference
[rank1]:     self._update_hash(val)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 113, in _update_hash
[rank1]:     func_key = func.cache_key
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 548, in cache_key
[rank1]:     dependencies_finder.visit(self.parse())
[rank1]:   File "/usr/lib/python3.10/ast.py", line 418, in visit
[rank1]:     return visitor(node)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 426, in generic_visit
[rank1]:     self.visit(item)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 418, in visit
[rank1]:     return visitor(node)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 194, in visit_FunctionDef
[rank1]:     self.generic_visit(node)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 426, in generic_visit
[rank1]:     self.visit(item)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 418, in visit
[rank1]:     return visitor(node)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 260, in visit_For
[rank1]:     self.generic_visit(node)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 426, in generic_visit
[rank1]:     self.visit(item)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 418, in visit
[rank1]:     return visitor(node)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 248, in visit_Assign
[rank1]:     self.generic_visit(node)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 428, in generic_visit
[rank1]:     self.visit(value)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 418, in visit
[rank1]:     return visitor(node)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 428, in generic_visit
[rank1]:     self.visit(value)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 418, in visit
[rank1]:     return visitor(node)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 172, in visit_Name
[rank1]:     self.record_reference(val, var_dict, node.id)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 133, in record_reference
[rank1]:     self._update_hash(val)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 113, in _update_hash
[rank1]:     func_key = func.cache_key
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 548, in cache_key
[rank1]:     dependencies_finder.visit(self.parse())
[rank1]:   File "/usr/lib/python3.10/ast.py", line 418, in visit
[rank1]:     return visitor(node)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 426, in generic_visit
[rank1]:     self.visit(item)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 418, in visit
[rank1]:     return visitor(node)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 194, in visit_FunctionDef
[rank1]:     self.generic_visit(node)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 426, in generic_visit
[rank1]:     self.visit(item)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 418, in visit
[rank1]:     return visitor(node)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 260, in visit_For
[rank1]:     self.generic_visit(node)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 426, in generic_visit
[rank1]:     self.visit(item)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 418, in visit
[rank1]:     return visitor(node)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 248, in visit_Assign
[rank1]:     self.generic_visit(node)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 428, in generic_visit
[rank1]:     self.visit(value)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 418, in visit
[rank1]:     return visitor(node)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 428, in generic_visit
[rank1]:     self.visit(value)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 418, in visit
[rank1]:     return visitor(node)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 172, in visit_Name
[rank1]:     self.record_reference(val, var_dict, node.id)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 133, in record_reference
[rank1]:     self._update_hash(val)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 113, in _update_hash
[rank1]:     func_key = func.cache_key
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 548, in cache_key
[rank1]:     dependencies_finder.visit(self.parse())
[rank1]:   File "/usr/lib/python3.10/ast.py", line 418, in visit
[rank1]:     return visitor(node)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 426, in generic_visit
[rank1]:     self.visit(item)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 418, in visit
[rank1]:     return visitor(node)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 194, in visit_FunctionDef
[rank1]:     self.generic_visit(node)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 426, in generic_visit
[rank1]:     self.visit(item)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 418, in visit
[rank1]:     return visitor(node)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 248, in visit_Assign
[rank1]:     self.generic_visit(node)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 428, in generic_visit
[rank1]:     self.visit(value)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 418, in visit
[rank1]:     return visitor(node)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 428, in generic_visit
[rank1]:     self.visit(value)
[rank1]:   File "/usr/lib/python3.10/ast.py", line 418, in visit
[rank1]:     return visitor(node)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 188, in visit_Attribute
[rank1]:     self.record_reference(ret)
[rank1]:   File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 137, in record_reference
[rank1]:     raise RuntimeError(f"Unsupported function referenced: {val}")
[rank1]: RuntimeError: Unsupported function referenced: <function get_int_dtype at 0x7f6f7f073c70>
[rank6]:[W108 07:48:30.078154380 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank6]:[W108 07:48:33.795695654 AllocatorConfig.cpp:28] Warning: PYTORCH_CUDA_ALLOC_CONF is deprecated, use PYTORCH_ALLOC_CONF instead (function operator())
[rank4]:[W108 07:48:33.920804378 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
W0108 07:48:34.125000 60929 usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 61007 closing signal SIGTERM
W0108 07:48:34.126000 60929 usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 61008 closing signal SIGTERM
W0108 07:48:34.127000 60929 usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 61009 closing signal SIGTERM
W0108 07:48:34.128000 60929 usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 61010 closing signal SIGTERM
W0108 07:48:34.128000 60929 usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 61011 closing signal SIGTERM
W0108 07:48:34.129000 60929 usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 61012 closing signal SIGTERM
W0108 07:48:34.129000 60929 usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/api.py:908] Sending process 61014 closing signal SIGTERM
E0108 07:48:36.430000 60929 usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/api.py:882] failed (exitcode: 1) local_rank: 6 (pid: 61013) of binary: /usr/bin/python
Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 940, in <module>
    main()
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 936, in main
    run(args)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 927, in run
    elastic_launch(
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 156, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 293, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
/app/ms-swift/swift/cli/_megatron/sft.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2026-01-08_07:48:34
  host      : mother-valid-clicks-iz9-xiyu-luo-master-0.mother-valid-clicks-iz9-xiyu-luo.wsgo-xiyu-luo.svc.cluster.local
  rank      : 6 (local_rank: 6)
  exitcode  : 1 (pid: 61013)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
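
The RuntimeError is raised inside Triton's DependenciesFinder while it hashes TransformerEngine's _row_id_map_pass_3_kernel for JIT compilation, so a first step is to confirm which Triton and TransformerEngine builds the failing rank actually imports. A minimal check (distribution names are assumed to follow the usual pip metadata; adjust if TransformerEngine was built from source):

# Which builds are installed, and where do they import from?
pip show torch triton transformer_engine | grep -E '^(Name|Version|Location)'
python -c "import triton, transformer_engine; print(triton.__version__); print(transformer_engine.__file__)"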

Steps/Code to reproduce bug
Using ms-swift, the failure is reproduced with the launch command below (a possible workaround is sketched right after it):

PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
NVTE_DEBUG=1 \
NVTE_DEBUG_LEVEL=2 \
ENABLE_AUDIO_OUTPUT=0 \
NPROC_PER_NODE=8 \
MAX_PIXELS=1003520 \
VIDEO_MAX_PIXELS=50176 \
FPS_MAX_FRAMES=12 \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
megatron sft \
    --model /Qwen3-Omni-30B-A3B-Instruct \
    --load_safetensors true \
    --save_safetensors true \
    --merge_lora false \
    --dataset train.jsonl \
    --val_dataset val.jsonl \
    --load_from_cache_file true \
    --train_type lora \
    --lora_rank 8 \
    --lora_alpha 32 \
    --target_modules all-linear \
    --sequence_parallel true \
    --packing true \
    --freeze_llm false \
    --freeze_vit true \
    --freeze_aligner true \
    --split_dataset_ratio 0.01 \
    --expert_model_parallel_size 2 \
    --moe_permute_fusion true \
    --moe_grouped_gemm true \
    --moe_shared_expert_overlap true \
    --moe_aux_loss_coeff 1e-3 \
    --micro_batch_size 1 \
    --global_batch_size 8 \
    --recompute_granularity full \
    --recompute_method uniform \
    --recompute_num_layers 1 \
    --finetune true \
    --cross_entropy_loss_fusion true \
    --lr 1e-4 \
    --lr_warmup_fraction 0.05 \
    --min_lr 1e-5 \
    --max_epochs 1 \
    --save /qwen3omni \
    --eval_interval 100 \
    --save_interval 100 \
    --vit_gradient_checkpointing true \
    --max_length 4096 \
    --num_workers 8 \
    --dataset_num_proc 8 \
    --no_save_optim true \
    --no_save_rng true \
    --attention_backend flash
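
The failing frame is only reached through fused_permute_with_probs, which Megatron uses because --moe_permute_fusion true is set. As a sketch of a possible workaround (not verified on this exact stack), the fused path can be bypassed by flipping that single flag in the command above, at the cost of some MoE dispatch throughput:

# Trimmed to the flags relevant here; in practice change only this one flag in
# the full command above. With fusion off, Megatron uses its non-fused
# permute() and TransformerEngine's Triton kernel (make_row_id_map) is never
# JIT-compiled, so the DependenciesFinder error should not be hit.
megatron sft \
    --model /Qwen3-Omni-30B-A3B-Instruct \
    --dataset train.jsonl \
    --train_type lora \
    --expert_model_parallel_size 2 \
    --moe_permute_fusion false \
    --moe_grouped_gemm true \
    --save /qwen3omni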

Expected behavior

vllm==0.13.0 installs torch==2.9.0;
torch==2.9.0 in turn pins triton==3.5.0;
and triton==3.5.0 conflicts with the installed TransformerEngine build (2.9.0+f7df8d8).
How should I resolve this conflict?
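
Two directions that might resolve it, both sketches rather than verified fixes: rebuild or upgrade TransformerEngine so its Triton kernels work with triton 3.5.x, or temporarily pin Triton to an earlier release that the installed TE build accepts (torch 2.9.0 declares triton==3.5.0, so pip will report the pin as inconsistent; treat it as a stopgap):

# Option 1: rebuild/upgrade TransformerEngine against the current Triton
# (package name and extra as documented by NVIDIA; adjust for a source build).
pip install --no-build-isolation --upgrade "transformer_engine[pytorch]"

# Option 2 (stopgap): pin Triton below 3.5 so the JIT dependency hashing that
# currently rejects get_int_dtype is avoided. pip will warn that torch==2.9.0
# expects triton==3.5.0.
pip install "triton<3.5"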

Environment details

pytorch==2.9.0
transformers==4.57.3
vllm==0.13.0
ms-swift==3.13.0dev0
triton==3.5.0
transformer_engine==2.9.0+f7df8d8
