-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTaskScheduler.h
More file actions
87 lines (72 loc) · 2.36 KB
/
TaskScheduler.h
File metadata and controls
87 lines (72 loc) · 2.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#pragma once
#include <functional>
#include <atomic>
#include <vector>
#include <memory>
#include <future>
#include <mutex>
#include "ThreadPool.h"
/// Runs a DAG of tasks on a ThreadPool: a task is handed to the pool once every
/// task it depends on has finished running.
///
/// Ownership & lifetime: the scheduler owns every node it creates (via `nodes`)
/// and never releases them, so the raw ITaskNode* links between nodes stay valid
/// for the scheduler's lifetime. Work queued on the pool captures `this`, so the
/// scheduler must outlive every task it has submitted.
///
/// Threading: tasks completing on pool threads may race with submit_with_deps();
/// dependency wiring is guarded per node. submit()/submit_with_deps() append to
/// `nodes` without a lock, so they must not be called concurrently with each other.
class TaskGraphScheduler {
public:
    /// Type-erased graph node.
    struct ITaskNode {
        /// Nodes waiting on this one. Guarded by depMutex; detached (and left
        /// empty) exactly once, when this node finishes running.
        std::vector<ITaskNode*> dependents;
        /// Number of outstanding dependencies; whoever decrements it to zero
        /// submits the node to the pool.
        std::atomic<size_t> remainingDeps{0};
        /// Protects `dependents` and `finished`, so a dependent registered while
        /// this node is completing is either notified or sees it as finished.
        std::mutex depMutex;
        /// Set under depMutex once run() has returned.
        bool finished = false;

        virtual void run() = 0;
        virtual ~ITaskNode() = default;
    };

    template<typename T>
    struct TaskNode : ITaskNode {
        std::packaged_task<T()> task;
        std::future<T> fut;  // declared after `task`, so it is initialised from it safely

        explicit TaskNode(std::function<T()> t)
            : task(std::move(t)), fut(task.get_future()) {}

        // packaged_task stores the callable's result, or any exception it throws, in `fut`.
        void run() override { task(); }
    };

    explicit TaskGraphScheduler(ThreadPool& pool) : pool(pool) {}

    /// Caller-facing handle; keeps the node alive and exposes its result.
    template<typename T>
    struct TaskHandle {
        std::shared_ptr<TaskNode<T>> node;
        explicit TaskHandle(std::shared_ptr<TaskNode<T>> n) : node(std::move(n)) {}
        /// Future for the task's result (std::future::get may be called only once).
        std::future<T>& get_future() { return node->fut; }
    };

    /// Submits a task with no dependencies; it is queued on the pool immediately.
    /// @param func work to run
    /// @return handle to the new task
    template<typename T>
    TaskHandle<T> submit(std::function<T()> func) {
        auto node = std::make_shared<TaskNode<T>>(std::move(func));
        nodes.push_back(node);
        submitIfReady(node.get());
        return TaskHandle<T>(node);
    }

    /// Submits a task that runs only after every node in `deps` has finished.
    /// Dependencies that have already finished count as satisfied.
    /// @param func work to run
    /// @param deps prerequisite nodes (e.g. TaskHandle::node); must be non-null
    ///             and created by this scheduler
    /// @return handle to the new task
    template<typename T>
    TaskHandle<T> submit_with_deps(std::function<T()> func,
                                   const std::vector<std::shared_ptr<ITaskNode>>& deps) {
        auto node = std::make_shared<TaskNode<T>>(std::move(func));
        // Take ownership first so the node outlives any raw pointer to it that is
        // stored in a dependency's `dependents` list below.
        nodes.push_back(node);

        // One count per dependency plus a "registration guard" held by this
        // thread: a dependency that finishes mid-loop can never drive the counter
        // to zero (and submit the node) before all wiring is complete.
        node->remainingDeps.store(deps.size() + 1, std::memory_order_relaxed);

        size_t satisfied = 1;  // the guard itself
        for (const auto& dep : deps) {
            std::lock_guard<std::mutex> lock(dep->depMutex);
            if (dep->finished) {
                // It already ran and already notified its dependents, so it will
                // never decrement our counter: account for it here instead.
                ++satisfied;
            } else {
                dep->dependents.push_back(node.get());
            }
        }

        // Release the guard and the already-finished deps in one step. Whoever
        // brings the counter to zero (this thread or the last dependency) submits.
        if (node->remainingDeps.fetch_sub(satisfied, std::memory_order_acq_rel) == satisfied) {
            submitNode(node.get());
        }
        return TaskHandle<T>(node);
    }

private:
    ThreadPool& pool;
    /// Owns every node ever submitted. It never shrinks, so memory grows with
    /// the total number of tasks submitted over the scheduler's lifetime.
    std::vector<std::shared_ptr<ITaskNode>> nodes;

    /// Queues `node` if it has no outstanding dependencies.
    void submitIfReady(ITaskNode* node) {
        if (node->remainingDeps.load(std::memory_order_acquire) == 0) {
            submitNode(node);
        }
    }

    /// Queues `node` on the pool. After it runs, the node is marked finished and
    /// each dependent is released; a dependent is submitted if this was its last
    /// outstanding dependency.
    void submitNode(ITaskNode* node) {
        pool.submit([this, node]() {
            node->run();

            // Flag completion and detach the dependents list under the lock. A
            // concurrent submit_with_deps either registered before this point (and
            // is notified below) or will observe `finished` and count itself.
            std::vector<ITaskNode*> toNotify;
            {
                std::lock_guard<std::mutex> lock(node->depMutex);
                node->finished = true;
                toNotify.swap(node->dependents);
            }

            for (ITaskNode* child : toNotify) {
                if (child->remainingDeps.fetch_sub(1, std::memory_order_acq_rel) == 1) {
                    submitNode(child);
                }
            }
        });
    }
};