Skip to content

Commit 8a4e7e0

Browse files
committed
[update] update vad & whisper, vad support kws, whisper support kws.
1 parent ee276a3 commit 8a4e7e0

File tree

2 files changed

+178
-61
lines changed

2 files changed

+178
-61
lines changed

projects/llm_framework/main_vad/src/main.cpp

Lines changed: 169 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
using namespace StackFlows;
2424

2525
int main_exit_flage = 0;
26+
2627
static void __sigint(int iSigNo)
2728
{
2829
SLOGW("llm_vad will be exit!");
@@ -51,13 +52,20 @@ class llm_task {
5152
std::vector<std::string> inputs_;
5253
bool enoutput_;
5354
bool enstream_;
55+
bool ensleep_;
5456
bool printed = false;
55-
task_callback_t out_callback_;
57+
std::atomic_bool superior_flage_;
5658
std::atomic_bool audio_flage_;
59+
std::atomic_bool awake_flage_;
60+
std::string superior_id_;
61+
task_callback_t out_callback_;
62+
int awake_delay_ = 50;
5763
int delay_audio_frame_ = 100;
5864
buffer_t *pcmdata;
5965
std::string wake_wav_file_;
6066

67+
std::function<void(void)> pause;
68+
6169
bool parse_config(const nlohmann::json &config_body)
6270
{
6371
try {
@@ -184,9 +192,17 @@ class llm_task {
184192
if (out_callback_) {
185193
out_callback_(false);
186194
}
195+
if (ensleep_) {
196+
if (pause) pause();
197+
}
187198
}
188199
}
189200

201+
void kws_awake()
202+
{
203+
awake_flage_ = true;
204+
}
205+
190206
bool delete_model()
191207
{
192208
vad_.reset();
@@ -195,7 +211,9 @@ class llm_task {
195211

196212
llm_task(const std::string &workid) : audio_flage_(false)
197213
{
198-
pcmdata = buffer_create();
214+
ensleep_ = false;
215+
awake_flage_ = false;
216+
pcmdata = buffer_create();
199217
}
200218

201219
~llm_task()
@@ -215,9 +233,12 @@ class llm_vad : public StackFlow {
215233
std::unordered_map<int, std::shared_ptr<llm_task>> llm_task_;
216234

217235
public:
236+
enum { EVENT_LOAD_CONFIG = EVENT_EXPORT + 1, EVENT_TASK_PAUSE };
218237
llm_vad() : StackFlow("vad")
219238
{
220239
task_count_ = 1;
240+
event_queue_.appendListener(
241+
EVENT_TASK_PAUSE, std::bind(&llm_vad::_task_pause, this, std::placeholders::_1, std::placeholders::_2));
221242
}
222243

223244
void task_output(const std::weak_ptr<llm_task> llm_task_obj_weak,
@@ -233,20 +254,70 @@ class llm_vad : public StackFlow {
233254
llm_channel->send(llm_task_obj->response_format_, (*next_data), LLM_NO_ERROR);
234255
}
235256

236-
void task_pause(const std::weak_ptr<llm_task> llm_task_obj_weak,
237-
const std::weak_ptr<llm_channel_obj> llm_channel_weak)
257+
void task_user_data(const std::weak_ptr<llm_task> llm_task_obj_weak,
258+
const std::weak_ptr<llm_channel_obj> llm_channel_weak, const std::string &object,
259+
const std::string &data)
238260
{
261+
nlohmann::json error_body;
239262
auto llm_task_obj = llm_task_obj_weak.lock();
240263
auto llm_channel = llm_channel_weak.lock();
241264
if (!(llm_task_obj && llm_channel)) {
265+
error_body["code"] = -11;
266+
error_body["message"] = "Model run failed.";
267+
send("None", "None", error_body, unit_name_);
242268
return;
243269
}
270+
std::string tmp_msg1;
271+
const std::string *next_data = &data;
272+
int ret;
273+
if (object.find("stream") != std::string::npos) {
274+
static std::unordered_map<int, std::string> stream_buff;
275+
try {
276+
if (decode_stream(data, tmp_msg1, stream_buff)) {
277+
return;
278+
};
279+
} catch (...) {
280+
stream_buff.clear();
281+
error_body["code"] = -25;
282+
error_body["message"] = "Stream data index error.";
283+
send("None", "None", error_body, unit_name_);
284+
return;
285+
}
286+
next_data = &tmp_msg1;
287+
}
288+
std::string tmp_msg2;
289+
if (object.find("base64") != std::string::npos) {
290+
ret = decode_base64((*next_data), tmp_msg2);
291+
if (ret == -1) {
292+
error_body["code"] = -23;
293+
error_body["message"] = "Base64 decoding error.";
294+
send("None", "None", error_body, unit_name_);
295+
return;
296+
}
297+
next_data = &tmp_msg2;
298+
}
299+
llm_task_obj->sys_pcm_on_data((*next_data));
300+
}
301+
302+
void _task_pause(const std::string &work_id, const std::string &data)
303+
{
304+
int work_id_num = sample_get_work_id_num(work_id);
305+
if (llm_task_.find(work_id_num) == llm_task_.end()) {
306+
return;
307+
}
308+
auto llm_task_obj = llm_task_[work_id_num];
309+
auto llm_channel = get_channel(work_id_num);
244310
if (llm_task_obj->audio_flage_) {
245311
if (!audio_url_.empty()) llm_channel->stop_subscriber(audio_url_);
246312
llm_task_obj->audio_flage_ = false;
247313
}
248314
}
249315

316+
void task_pause(const std::string &work_id, const std::string &data)
317+
{
318+
event_queue_.enqueue(EVENT_TASK_PAUSE, work_id, "");
319+
}
320+
250321
void task_work(const std::weak_ptr<llm_task> llm_task_obj_weak,
251322
const std::weak_ptr<llm_channel_obj> llm_channel_weak)
252323
{
@@ -264,9 +335,22 @@ class llm_vad : public StackFlow {
264335
}
265336
}
266337

338+
void kws_awake(const std::weak_ptr<llm_task> llm_task_obj_weak,
339+
const std::weak_ptr<llm_channel_obj> llm_channel_weak, const std::string &object,
340+
const std::string &data)
341+
{
342+
auto llm_task_obj = llm_task_obj_weak.lock();
343+
auto llm_channel = llm_channel_weak.lock();
344+
if (!(llm_task_obj && llm_channel)) {
345+
return;
346+
}
347+
std::this_thread::sleep_for(std::chrono::milliseconds(llm_task_obj->awake_delay_));
348+
task_work(llm_task_obj, llm_channel);
349+
}
350+
267351
void work(const std::string &work_id, const std::string &object, const std::string &data) override
268352
{
269-
SLOGI("llm_asr::work:%s", data.c_str());
353+
SLOGI("llm_vad::work:%s", data.c_str());
270354

271355
nlohmann::json error_body;
272356
int work_id_num = sample_get_work_id_num(work_id);
@@ -282,7 +366,7 @@ class llm_vad : public StackFlow {
282366

283367
void pause(const std::string &work_id, const std::string &object, const std::string &data) override
284368
{
285-
SLOGI("llm_asr::work:%s", data.c_str());
369+
SLOGI("llm_vad::work:%s", data.c_str());
286370

287371
nlohmann::json error_body;
288372
int work_id_num = sample_get_work_id_num(work_id);
@@ -292,55 +376,10 @@ class llm_vad : public StackFlow {
292376
send("None", "None", error_body, work_id);
293377
return;
294378
}
295-
task_pause(llm_task_[work_id_num], get_channel(work_id_num));
379+
task_pause(work_id, "");
296380
send("None", "None", LLM_NO_ERROR, work_id);
297381
}
298382

299-
void task_user_data(const std::weak_ptr<llm_task> llm_task_obj_weak,
300-
const std::weak_ptr<llm_channel_obj> llm_channel_weak, const std::string &object,
301-
const std::string &data)
302-
{
303-
nlohmann::json error_body;
304-
auto llm_task_obj = llm_task_obj_weak.lock();
305-
auto llm_channel = llm_channel_weak.lock();
306-
if (!(llm_task_obj && llm_channel)) {
307-
error_body["code"] = -11;
308-
error_body["message"] = "Model run failed.";
309-
send("None", "None", error_body, unit_name_);
310-
return;
311-
}
312-
std::string tmp_msg1;
313-
const std::string *next_data = &data;
314-
int ret;
315-
if (object.find("stream") != std::string::npos) {
316-
static std::unordered_map<int, std::string> stream_buff;
317-
try {
318-
if (decode_stream(data, tmp_msg1, stream_buff)) {
319-
return;
320-
};
321-
} catch (...) {
322-
stream_buff.clear();
323-
error_body["code"] = -25;
324-
error_body["message"] = "Stream data index error.";
325-
send("None", "None", error_body, unit_name_);
326-
return;
327-
}
328-
next_data = &tmp_msg1;
329-
}
330-
std::string tmp_msg2;
331-
if (object.find("base64") != std::string::npos) {
332-
ret = decode_base64((*next_data), tmp_msg2);
333-
if (ret == -1) {
334-
error_body["code"] = -23;
335-
error_body["message"] = "Base64 decoding error.";
336-
send("None", "None", error_body, unit_name_);
337-
return;
338-
}
339-
next_data = &tmp_msg2;
340-
}
341-
llm_task_obj->sys_pcm_on_data((*next_data));
342-
}
343-
344383
int setup(const std::string &work_id, const std::string &object, const std::string &data) override
345384
{
346385
nlohmann::json error_body;
@@ -354,6 +393,7 @@ class llm_vad : public StackFlow {
354393
int work_id_num = sample_get_work_id_num(work_id);
355394
auto llm_channel = get_channel(work_id);
356395
auto llm_task_obj = std::make_shared<llm_task>(work_id);
396+
357397
nlohmann::json config_body;
358398
try {
359399
config_body = nlohmann::json::parse(data);
@@ -368,6 +408,7 @@ class llm_vad : public StackFlow {
368408
if (ret == 0) {
369409
llm_channel->set_output(llm_task_obj->enoutput_);
370410
llm_channel->set_stream(llm_task_obj->enstream_);
411+
llm_task_obj->pause = std::bind(&llm_vad::task_pause, this, work_id, "");
371412
llm_task_obj->set_output(std::bind(&llm_vad::task_output, this, std::weak_ptr<llm_task>(llm_task_obj),
372413
std::weak_ptr<llm_channel_obj>(llm_channel), std::placeholders::_1));
373414

@@ -384,6 +425,13 @@ class llm_vad : public StackFlow {
384425
"", std::bind(&llm_vad::task_user_data, this, std::weak_ptr<llm_task>(llm_task_obj),
385426
std::weak_ptr<llm_channel_obj>(llm_channel), std::placeholders::_1,
386427
std::placeholders::_2));
428+
} else if (input.find("kws") != std::string::npos) {
429+
llm_task_obj->ensleep_ = true;
430+
task_pause(work_id, "");
431+
llm_channel->subscriber_work_id(
432+
input, std::bind(&llm_vad::kws_awake, this, std::weak_ptr<llm_task>(llm_task_obj),
433+
std::weak_ptr<llm_channel_obj>(llm_channel), std::placeholders::_1,
434+
std::placeholders::_2));
387435
}
388436
}
389437
llm_task_[work_id_num] = llm_task_obj;
@@ -399,6 +447,74 @@ class llm_vad : public StackFlow {
399447
}
400448
}
401449

450+
void link(const std::string &work_id, const std::string &object, const std::string &data) override
451+
{
452+
SLOGI("llm_vad::link:%s", data.c_str());
453+
int ret = 1;
454+
nlohmann::json error_body;
455+
int work_id_num = sample_get_work_id_num(work_id);
456+
if (llm_task_.find(work_id_num) == llm_task_.end()) {
457+
error_body["code"] = -6;
458+
error_body["message"] = "Unit Does Not Exist";
459+
send("None", "None", error_body, work_id);
460+
return;
461+
}
462+
auto llm_channel = get_channel(work_id);
463+
auto llm_task_obj = llm_task_[work_id_num];
464+
if (data.find("sys") != std::string::npos) {
465+
if (audio_url_.empty()) audio_url_ = unit_call("audio", "cap", data);
466+
std::weak_ptr<llm_task> _llm_task_obj = llm_task_obj;
467+
llm_channel->subscriber(audio_url_, [_llm_task_obj](pzmq *_pzmq, const std::string &raw) {
468+
_llm_task_obj.lock()->sys_pcm_on_data(raw);
469+
});
470+
llm_task_obj->audio_flage_ = true;
471+
llm_task_obj->inputs_.push_back(data);
472+
} else if (data.find("kws") != std::string::npos) {
473+
llm_task_obj->ensleep_ = true;
474+
ret = llm_channel->subscriber_work_id(
475+
data,
476+
std::bind(&llm_vad::kws_awake, this, std::weak_ptr<llm_task>(llm_task_obj),
477+
std::weak_ptr<llm_channel_obj>(llm_channel), std::placeholders::_1, std::placeholders::_2));
478+
llm_task_obj->inputs_.push_back(data);
479+
}
480+
if (ret) {
481+
error_body["code"] = -20;
482+
error_body["message"] = "link false";
483+
send("None", "None", error_body, work_id);
484+
return;
485+
} else {
486+
send("None", "None", LLM_NO_ERROR, work_id);
487+
}
488+
}
489+
490+
void unlink(const std::string &work_id, const std::string &object, const std::string &data) override
491+
{
492+
SLOGI("llm_vad::unlink:%s", data.c_str());
493+
int ret = 0;
494+
nlohmann::json error_body;
495+
int work_id_num = sample_get_work_id_num(work_id);
496+
if (llm_task_.find(work_id_num) == llm_task_.end()) {
497+
error_body["code"] = -6;
498+
error_body["message"] = "Unit Does Not Exist";
499+
send("None", "None", error_body, work_id);
500+
return;
501+
}
502+
auto llm_channel = get_channel(work_id);
503+
auto llm_task_obj = llm_task_[work_id_num];
504+
if (llm_task_obj->superior_id_ == work_id) {
505+
llm_task_obj->superior_flage_ = false;
506+
}
507+
llm_channel->stop_subscriber_work_id(data);
508+
for (auto it = llm_task_obj->inputs_.begin(); it != llm_task_obj->inputs_.end();) {
509+
if (*it == data) {
510+
it = llm_task_obj->inputs_.erase(it);
511+
} else {
512+
++it;
513+
}
514+
}
515+
send("None", "None", LLM_NO_ERROR, work_id);
516+
}
517+
402518
void taskinfo(const std::string &work_id, const std::string &object, const std::string &data) override
403519
{
404520
SLOGI("llm_vad::taskinfo:%s", data.c_str());
@@ -428,7 +544,7 @@ class llm_vad : public StackFlow {
428544

429545
int exit(const std::string &work_id, const std::string &object, const std::string &data) override
430546
{
431-
SLOGI("llm_kws::exit:%s", data.c_str());
547+
SLOGI("llm_vad::exit:%s", data.c_str());
432548
nlohmann::json error_body;
433549
int work_id_num = sample_get_work_id_num(work_id);
434550
if (llm_task_.find(work_id_num) == llm_task_.end()) {

0 commit comments

Comments
 (0)