2525#include < cstdio>
2626#include < string>
2727#include < vector>
28+ #include < mutex>
2829
2930#if defined(GGML_USE_HIP)
3031#include " vendors/hip.h"
@@ -751,7 +752,7 @@ struct ggml_tensor_extra_gpu {
751752
752753
#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS))
// NOTE(review): USE_CUDA_GRAPH is commented out in this revision, i.e. CUDA
// graph capture is disabled even when graph support is compiled in — confirm
// this is intentional and not a leftover debugging change.
// #define USE_CUDA_GRAPH
#endif
756757
757758struct ggml_graph_node_properties {
@@ -799,8 +800,11 @@ struct ggml_backend_cuda_context {
799800 cudaEvent_t copy_event = nullptr ;
800801
801802 cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
803+ HwQueueHandle hwqueues[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { 0 } };
804+ XQueueHandle xqueues[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { 0 } };
802805 cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr };
803806
807+ mutable std::mutex streams_mutex;
804808 std::unique_ptr<ggml_cuda_graph> cuda_graph;
805809
806810 int priority = 0 ;
@@ -812,16 +816,63 @@ struct ggml_backend_cuda_context {
812816
813817 ~ggml_backend_cuda_context ();
814818
819+ // Disable copying and moving to prevent resource management issues
820+ ggml_backend_cuda_context (const ggml_backend_cuda_context&) = delete ;
821+ ggml_backend_cuda_context& operator =(const ggml_backend_cuda_context&) = delete ;
822+ ggml_backend_cuda_context (ggml_backend_cuda_context&&) = delete ;
823+ ggml_backend_cuda_context& operator =(ggml_backend_cuda_context&&) = delete ;
824+
825+ int get_max_supported_preempt_level (int device_id) {
826+ cudaDeviceProp prop;
827+ CUDA_CHECK (cudaGetDeviceProperties (&prop, device_id));
828+ int arch = prop.major * 10 + prop.minor ;
829+ switch (arch) {
830+ case 35 : // Kepler: K20, K40, GTX TITAN
831+ return 2 ; // kPreemptLevelDeactivate
832+ case 70 : // Volta: V100
833+ case 86 : // Ampere: RTX 30系列
834+ return 3 ; // kPreemptLevelInterrupt
835+ default : // 其他架构(包括 A100 arch=80)
836+ return 1 ; // kPreemptLevelBlock
837+ }
838+ }
839+
815840 cudaStream_t stream (int device, int stream) {
841+ std::lock_guard<std::mutex> lock (streams_mutex);
842+
843+ // 如果流不存在,创建流和XSched队列
816844 if (streams[device][stream] == nullptr ) {
817845 ggml_cuda_set_device (device);
818846 CUDA_CHECK (cudaStreamCreateWithFlags (&streams[device][stream], cudaStreamNonBlocking));
819- HwQueueHandle hwqueue;
820- CudaQueueCreate (&hwqueue,streams[device][stream]);
821- XQueueHandle xqueue;
822- XQueueCreate (&xqueue, hwqueue, kPreemptLevelDeactivate , kQueueCreateFlagNone );
823- XHintPriority (xqueue, priority); // In XSched, lower number means lower priority
847+
848+ HwQueueHandle hwqueue = 0 ;
849+ XResult res = CudaQueueCreate (&hwqueue, streams[device][stream]);
850+ if (res != kXSchedSuccess ) {
851+ CUDA_CHECK (cudaStreamDestroy (streams[device][stream]));
852+ streams[device][stream] = nullptr ;
853+ GGML_ABORT (" CudaQueueCreate failed: %d" , res);
854+ }
855+
856+ XQueueHandle xqueue = 0 ;
857+ res = XQueueCreate (&xqueue, hwqueue, get_max_supported_preempt_level (device), kQueueCreateFlagNone );
858+ if (res != kXSchedSuccess ) {
859+ HwQueueDestroy (hwqueue);
860+ CUDA_CHECK (cudaStreamDestroy (streams[device][stream]));
861+ streams[device][stream] = nullptr ;
862+ GGML_ABORT (" XQueueCreate failed: %d" , res);
863+ }
864+
865+ hwqueues[device][stream] = hwqueue;
866+ xqueues[device][stream] = xqueue;
867+
868+ // 设置初始优先级(总是设置,包括优先级0)
869+ XHintPriority (xqueue, priority);
870+ }
871+ // 如果流已存在但XSched队列未绑定(不应该发生,但确保健壮性)
872+ else if (xqueues[device][stream] == 0 ) {
873+ GGML_ABORT (" Stream exists but XQueue not bound - internal error" );
824874 }
875+
825876 return streams[device][stream];
826877 }
827878
0 commit comments