Skip to content

Commit f3d117b

Browse files
author
songyuqin0686
committed
XSched integrated with priority scheduling for high-priority task acceleration, auto-detects GPU architecture, adapts preemption level, and runs stably without CUDA errors
1 parent cdeff47 commit f3d117b

3 files changed

Lines changed: 80 additions & 17 deletions

File tree

ggml/src/ggml-cuda/common.cuh

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include <cstdio>
2626
#include <string>
2727
#include <vector>
28+
#include <mutex>
2829

2930
#if defined(GGML_USE_HIP)
3031
#include "vendors/hip.h"
@@ -751,7 +752,7 @@ struct ggml_tensor_extra_gpu {
751752

752753

753754
#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS))
754-
#define USE_CUDA_GRAPH
755+
// #define USE_CUDA_GRAPH
755756
#endif
756757

757758
struct ggml_graph_node_properties {
@@ -799,8 +800,11 @@ struct ggml_backend_cuda_context {
799800
cudaEvent_t copy_event = nullptr;
800801

801802
cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
803+
HwQueueHandle hwqueues[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { 0 } };
804+
XQueueHandle xqueues[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { 0 } };
802805
cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
803806

807+
mutable std::mutex streams_mutex;
804808
std::unique_ptr<ggml_cuda_graph> cuda_graph;
805809

806810
int priority = 0;
@@ -812,16 +816,63 @@ struct ggml_backend_cuda_context {
812816

813817
~ggml_backend_cuda_context();
814818

819+
// Disable copying and moving to prevent resource management issues
820+
ggml_backend_cuda_context(const ggml_backend_cuda_context&) = delete;
821+
ggml_backend_cuda_context& operator=(const ggml_backend_cuda_context&) = delete;
822+
ggml_backend_cuda_context(ggml_backend_cuda_context&&) = delete;
823+
ggml_backend_cuda_context& operator=(ggml_backend_cuda_context&&) = delete;
824+
825+
int get_max_supported_preempt_level(int device_id) {
826+
cudaDeviceProp prop;
827+
CUDA_CHECK(cudaGetDeviceProperties(&prop, device_id));
828+
int arch = prop.major * 10 + prop.minor;
829+
switch (arch) {
830+
case 35: // Kepler: K20, K40, GTX TITAN
831+
return 2; // kPreemptLevelDeactivate
832+
case 70: // Volta: V100
833+
case 86: // Ampere: RTX 30 series
834+
return 3; // kPreemptLevelInterrupt
835+
default: // other architectures (including A100, arch = 80)
836+
return 1; // kPreemptLevelBlock
837+
}
838+
}
839+
815840
cudaStream_t stream(int device, int stream) {
841+
std::lock_guard<std::mutex> lock(streams_mutex);
842+
843+
// If the stream does not exist yet, create the stream and its XSched queues
816844
if (streams[device][stream] == nullptr) {
817845
ggml_cuda_set_device(device);
818846
CUDA_CHECK(cudaStreamCreateWithFlags(&streams[device][stream], cudaStreamNonBlocking));
819-
HwQueueHandle hwqueue;
820-
CudaQueueCreate(&hwqueue,streams[device][stream]);
821-
XQueueHandle xqueue;
822-
XQueueCreate(&xqueue, hwqueue, kPreemptLevelDeactivate, kQueueCreateFlagNone);
823-
XHintPriority(xqueue, priority); // In XSched, lower number means lower priority
847+
848+
HwQueueHandle hwqueue = 0;
849+
XResult res = CudaQueueCreate(&hwqueue, streams[device][stream]);
850+
if (res != kXSchedSuccess) {
851+
CUDA_CHECK(cudaStreamDestroy(streams[device][stream]));
852+
streams[device][stream] = nullptr;
853+
GGML_ABORT("CudaQueueCreate failed: %d", res);
854+
}
855+
856+
XQueueHandle xqueue = 0;
857+
res = XQueueCreate(&xqueue, hwqueue, get_max_supported_preempt_level(device), kQueueCreateFlagNone);
858+
if (res != kXSchedSuccess) {
859+
HwQueueDestroy(hwqueue);
860+
CUDA_CHECK(cudaStreamDestroy(streams[device][stream]));
861+
streams[device][stream] = nullptr;
862+
GGML_ABORT("XQueueCreate failed: %d", res);
863+
}
864+
865+
hwqueues[device][stream] = hwqueue;
866+
xqueues[device][stream] = xqueue;
867+
868+
// Set the initial priority (always set it, including priority 0)
869+
XHintPriority(xqueue, priority);
870+
}
871+
// If the stream exists but no XSched queue is bound (should not happen; checked for robustness)
872+
else if (xqueues[device][stream] == 0) {
873+
GGML_ABORT("Stream exists but XQueue not bound - internal error");
824874
}
875+
825876
return streams[device][stream];
826877
}
827878

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -531,14 +531,26 @@ ggml_backend_cuda_context::~ggml_backend_cuda_context() {
531531
if (copy_event != nullptr) {
532532
CUDA_CHECK(cudaEventDestroy(copy_event));
533533
}
534+
535+
// Destroy all XSched queues and hardware queues
534536
for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
535537
for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
538+
if (xqueues[i][j] != 0) {
539+
XQueueDestroy(xqueues[i][j]);
540+
xqueues[i][j] = 0;
541+
}
542+
if (hwqueues[i][j] != 0) {
543+
HwQueueDestroy(hwqueues[i][j]);
544+
hwqueues[i][j] = 0;
545+
}
536546
if (streams[i][j] != nullptr) {
537547
CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
548+
streams[i][j] = nullptr;
538549
}
539550
}
540551
if (cublas_handles[i] != nullptr) {
541552
CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
553+
cublas_handles[i] = nullptr;
542554
}
543555
}
544556
}
@@ -2855,19 +2867,19 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
28552867

28562868
static void ggml_backend_cuda_set_priority(ggml_backend_t backend, int prio) {
28572869
ggml_backend_cuda_context *cuda_ctx = (ggml_backend_cuda_context *)backend->context;
2870+
2871+
std::lock_guard<std::mutex> lock(cuda_ctx->streams_mutex);
2872+
2873+
// Update the priority of every existing XQueue
28582874
for (int device = 0; device < GGML_CUDA_MAX_DEVICES; device++) {
28592875
for (int idx = 0; idx < GGML_CUDA_MAX_STREAMS; idx++) {
2860-
auto stream = cuda_ctx->streams[device][idx];
2861-
if(stream == nullptr) {
2862-
continue;
2876+
if (cuda_ctx->xqueues[device][idx] != 0) {
2877+
XHintPriority(cuda_ctx->xqueues[device][idx], prio);
28632878
}
2864-
HwQueueHandle hwqueue;
2865-
CudaQueueCreate(&hwqueue,stream);
2866-
XQueueHandle xqueue;
2867-
XQueueCreate(&xqueue, hwqueue, kPreemptLevelDeactivate, kQueueCreateFlagNone);
2868-
XHintPriority(xqueue, prio); // In XSched, lower number means lower priority
28692879
}
28702880
}
2881+
2882+
// Store the new priority; XQueues created afterwards will use it
28712883
cuda_ctx->priority = prio;
28722884
}
28732885

tools/server/server.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5010,9 +5010,9 @@ int main(int argc, char ** argv) {
50105010
std::vector<std::thread> threads;
50115011
// this call blocks the main thread until queue_tasks.terminate() is called
50125012
for(int i = 0; i < SERVER_TASK_PRIO_COUNT; i++) {
5013-
threads.emplace_back([&ctx_server, &i]() {
5014-
ctx_server[i].queue_tasks.start_loop();
5015-
});
5013+
threads.emplace_back([&ctx_server](int ind) {
5014+
ctx_server[ind].queue_tasks.start_loop();
5015+
},i);
50165016
}
50175017

50185018
for(auto &thread: threads) {

0 commit comments

Comments
 (0)