Skip to content

Commit 43fc7f4

Browse files
authored
Merge pull request #99 from chenht2022/main
Adapt Windows
2 parents a81a7ff + 14869b5 commit 43fc7f4

File tree

2 files changed

+101
-37
lines changed

2 files changed

+101
-37
lines changed

ktransformers/ktransformers_ext/cpu_backend/task_queue.cpp

+16-12
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
/**
2-
* @Description :
3-
* @Author : chenht2022
4-
* @Date : 2024-07-17 12:25:51
5-
* @Version : 1.0.0
6-
* @LastEditors : chenht2022
7-
* @LastEditTime : 2024-07-25 10:33:44
2+
* @Description :
3+
* @Author : chenht2022
4+
* @Date : 2024-07-17 12:25:51
5+
* @Version : 1.0.0
6+
* @LastEditors : chenht2022
7+
* @LastEditTime : 2024-10-09 11:08:10
88
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
99
**/
1010
#include "task_queue.h"
@@ -17,8 +17,9 @@ TaskQueue::TaskQueue() {
1717

1818
TaskQueue::~TaskQueue() {
1919
{
20-
std::unique_lock<std::mutex> lock(mutex);
20+
mutex.lock();
2121
exit_flag.store(true, std::memory_order_seq_cst);
22+
mutex.unlock();
2223
}
2324
cv.notify_all();
2425
if (worker.joinable()) {
@@ -28,9 +29,10 @@ TaskQueue::~TaskQueue() {
2829

2930
void TaskQueue::enqueue(std::function<void()> task) {
3031
{
31-
std::unique_lock<std::mutex> lock(mutex);
32+
mutex.lock();
3233
tasks.push(task);
3334
sync_flag.store(false, std::memory_order_seq_cst);
35+
mutex.unlock();
3436
}
3537
cv.notify_one();
3638
}
@@ -44,20 +46,22 @@ void TaskQueue::processTasks() {
4446
while (true) {
4547
std::function<void()> task;
4648
{
47-
std::unique_lock<std::mutex> lock(mutex);
48-
cv.wait(lock, [this]() { return !tasks.empty() || exit_flag.load(std::memory_order_seq_cst); });
49+
mutex.lock();
50+
cv.wait(mutex, [this]() { return !tasks.empty() || exit_flag.load(std::memory_order_seq_cst); });
4951
if (exit_flag.load(std::memory_order_seq_cst) && tasks.empty()) {
5052
return;
5153
}
5254
task = tasks.front();
5355
tasks.pop();
56+
mutex.unlock();
5457
}
5558
task();
5659
{
57-
std::lock_guard<std::mutex> lock(mutex);
60+
mutex.lock();
5861
if (tasks.empty()) {
5962
sync_flag.store(true, std::memory_order_seq_cst);
6063
}
64+
mutex.unlock();
6165
}
6266
}
63-
}
67+
}
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
/**
2-
* @Description :
3-
* @Author : chenht2022
4-
* @Date : 2024-07-16 10:43:18
5-
* @Version : 1.0.0
6-
* @LastEditors : chenxl
7-
* @LastEditTime : 2024-08-08 04:23:51
2+
* @Description :
3+
* @Author : chenht2022
4+
* @Date : 2024-07-16 10:43:18
5+
* @Version : 1.0.0
6+
* @LastEditors : chenht
7+
* @LastEditTime : 2024-10-09 11:08:07
88
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
99
**/
1010
#ifndef CPUINFER_TASKQUEUE_H
@@ -22,36 +22,96 @@
2222
#endif
2323

2424
class custom_mutex {
25-
private:
25+
private:
2626
#ifdef _WIN32
27-
HANDLE global_mutex;
27+
CRITICAL_SECTION cs;
2828
#else
29-
std::mutex global_mutex;
29+
std::mutex mtx;
3030
#endif
31-
32-
public:
33-
custom_mutex()
34-
{
31+
32+
public:
33+
custom_mutex() {
3534
#ifdef _WIN32
36-
HANDLE global_mutex;
35+
InitializeCriticalSection(&cs);
36+
#else
37+
// No initialization required for std::mutex
3738
#endif
3839
}
3940

40-
void lock()
41-
{
41+
~custom_mutex() {
4242
#ifdef _WIN32
43-
WaitForSingleObject(global_mutex, INFINITE);
43+
DeleteCriticalSection(&cs);
44+
#endif
45+
}
46+
47+
void lock() {
48+
#ifdef _WIN32
49+
EnterCriticalSection(&cs);
4450
#else
45-
global_mutex.lock();
51+
mtx.lock();
4652
#endif
4753
}
4854

49-
void unlock()
50-
{
55+
void unlock() {
5156
#ifdef _WIN32
52-
ReleaseMutex(global_mutex);
57+
LeaveCriticalSection(&cs);
5358
#else
54-
global_mutex.unlock();
59+
mtx.unlock();
60+
#endif
61+
}
62+
63+
#ifdef _WIN32
64+
CRITICAL_SECTION* get_handle() {
65+
return &cs;
66+
}
67+
#else
68+
std::mutex* get_handle() {
69+
return &mtx;
70+
}
71+
#endif
72+
};
73+
74+
class custom_condition_variable {
75+
private:
76+
#ifdef _WIN32
77+
CONDITION_VARIABLE cond_var;
78+
#else
79+
std::condition_variable cond_var;
80+
#endif
81+
82+
public:
83+
custom_condition_variable() {
84+
#ifdef _WIN32
85+
InitializeConditionVariable(&cond_var);
86+
#endif
87+
}
88+
89+
template <typename Predicate>
90+
void wait(custom_mutex& mutex, Predicate pred) {
91+
#ifdef _WIN32
92+
while (!pred()) {
93+
SleepConditionVariableCS(&cond_var, mutex.get_handle(), INFINITE);
94+
}
95+
#else
96+
std::unique_lock<std::mutex> lock(*mutex.get_handle(), std::adopt_lock);
97+
cond_var.wait(lock, pred);
98+
lock.release();
99+
#endif
100+
}
101+
102+
void notify_one() {
103+
#ifdef _WIN32
104+
WakeConditionVariable(&cond_var);
105+
#else
106+
cond_var.notify_one();
107+
#endif
108+
}
109+
110+
void notify_all() {
111+
#ifdef _WIN32
112+
WakeAllConditionVariable(&cond_var);
113+
#else
114+
cond_var.notify_all();
55115
#endif
56116
}
57117
};
@@ -69,10 +129,10 @@ class TaskQueue {
69129
void processTasks();
70130

71131
std::queue<std::function<void()>> tasks;
72-
std::mutex mutex;
73-
std::condition_variable cv;
132+
custom_mutex mutex;
133+
custom_condition_variable cv;
74134
std::thread worker;
75135
std::atomic<bool> sync_flag;
76136
std::atomic<bool> exit_flag;
77137
};
78-
#endif
138+
#endif

0 commit comments

Comments
 (0)