Skip to content

Commit 18ad846

Browse files
committed
feat: add ToolTask implementation and update scheduled trigger functions
1 parent 00f9a89 commit 18ad846

3 files changed

Lines changed: 149 additions & 19 deletions

File tree

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎虎
5+
@file: application_task.py
6+
@date:2026/1/14 19:14
7+
@desc:
8+
"""
9+
10+
import uuid_utils.compat as uuid
11+
from django.db.models import QuerySet
12+
13+
from common.utils.tool_code import ToolExecutor
14+
from knowledge.models.knowledge_action import State
15+
from tools.models import Tool
16+
from trigger.handler.base_task import BaseTriggerTask
17+
from trigger.models import TaskRecord
18+
19+
20+
def get_reference(fields, obj):
21+
for field in fields:
22+
value = obj.get(field)
23+
if value is None:
24+
return None
25+
else:
26+
obj = value
27+
return obj
28+
29+
30+
def get_field_value(value, kwargs):
31+
source = value.get('source')
32+
if source == 'custom':
33+
return value.get('value')
34+
else:
35+
return get_reference(value.get('value'), kwargs)
36+
37+
38+
def get_application_execute_parameters(parameter_setting, kwargs):
39+
parameters = {'form_data': {}}
40+
question_setting = parameter_setting.get('question')
41+
if question_setting:
42+
parameters['message'] = get_field_value(question_setting, kwargs)
43+
filed_list = ['image_list', 'document_list', 'audio_list', 'video_list', 'other_list']
44+
for field in filed_list:
45+
field_setting = parameter_setting.get(field)
46+
if field_setting:
47+
parameters[field] = get_field_value(field_setting, kwargs)
48+
api_input_field_list = parameter_setting.get('api_input_field_list')
49+
if api_input_field_list:
50+
for key, value in api_input_field_list.items():
51+
parameters['form_data'][key] = get_field_value(value, kwargs)
52+
user_input_field_list = parameter_setting.get('user_input_field_list')
53+
if user_input_field_list:
54+
for key, value in user_input_field_list.items():
55+
parameters['form_data'][key] = get_field_value(value, kwargs)
56+
return parameters
57+
58+
59+
def get_loop_workflow_node(node_list):
60+
result = []
61+
for item in node_list:
62+
if item.get('type') == 'loop-node':
63+
for loop_item in item.get('loop_node_data') or []:
64+
for inner_item in loop_item.values():
65+
result.append(inner_item)
66+
return result
67+
68+
69+
def get_workflow_state(details):
70+
node_list = details.values()
71+
all_node = [*node_list, *get_loop_workflow_node(node_list)]
72+
err = any([True for value in all_node if value.get('status') == 500 and not value.get('enableException')])
73+
if err:
74+
return State.FAILURE
75+
return State.SUCCESS
76+
77+
78+
class ToolTask(BaseTriggerTask):
79+
def support(self, trigger_task, **kwargs):
80+
return trigger_task.get('source_type') == 'TOOL'
81+
82+
def execute(self, trigger_task, **kwargs):
83+
parameter_setting = trigger_task.get('parameter')
84+
parameters = get_application_execute_parameters(parameter_setting, kwargs)
85+
tool_id = trigger_task.get('source_id')
86+
task_record_id = uuid.uuid7()
87+
88+
TaskRecord(
89+
id=task_record_id,
90+
trigger_id=trigger_task.get('trigger_id'),
91+
trigger_task_id=trigger_task.get('id'),
92+
source_type="TOOL",
93+
source_id=tool_id,
94+
task_record_id=task_record_id,
95+
meta={},
96+
state=State.STARTED
97+
).save()
98+
99+
try:
100+
tool = QuerySet(Tool).filter(id=tool_id).first()
101+
executor = ToolExecutor()
102+
# executor.exec_code(tool.code, parameters)
103+
print(tool)
104+
print(parameters)
105+
106+
QuerySet(TaskRecord).filter(id=task_record_id).update(state=State.SUCCESS, run_time=0)
107+
except Exception as e:
108+
state = State.FAILURE
109+
QuerySet(TaskRecord).filter(id=task_record_id).update(state=state, run_time=0)

apps/trigger/handler/impl/trigger/scheduled_trigger.py

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# coding=utf-8
2-
from __future__ import annotations
32

43
import random
54

@@ -37,7 +36,7 @@ def _get_active_trigger_tasks(trigger_id: str) -> list[dict]:
3736
)
3837

3938

40-
def _deploy_daily(trigger: dict, trigger_tasks: list[dict], setting: dict, trigger_id: str, func) -> None:
39+
def _deploy_daily(trigger: dict, trigger_tasks: list[dict], setting: dict, trigger_id: str) -> None:
4140
from common.job import scheduler
4241

4342
times = setting.get("time") or []
@@ -51,17 +50,17 @@ def _deploy_daily(trigger: dict, trigger_tasks: list[dict], setting: dict, trigg
5150
for task in trigger_tasks:
5251
job_id = f"trigger:{trigger_id}:task:{task['id']}:daily:{hour:02d}{minute:02d}"
5352
scheduler.add_job(
54-
func,
53+
ScheduledTrigger.execute,
5554
trigger="cron",
5655
hour=str(hour),
5756
minute=str(minute),
5857
id=job_id,
59-
kwargs={"trigger": trigger, "trigger_tasks": trigger_tasks},
58+
kwargs={"trigger": trigger, "trigger_task": task},
6059
replace_existing=True,
6160
)
6261

6362

64-
def _deploy_weekly(trigger: dict, trigger_tasks: list[dict], setting: dict, trigger_id: str, func) -> None:
63+
def _deploy_weekly(trigger: dict, trigger_tasks: list[dict], setting: dict, trigger_id: str) -> None:
6564
from common.job import scheduler
6665

6766
times = setting.get("time") or []
@@ -87,18 +86,18 @@ def _deploy_weekly(trigger: dict, trigger_tasks: list[dict], setting: dict, trig
8786
for task in trigger_tasks:
8887
job_id = f"trigger:{trigger_id}:task:{task['id']}:weekly:{dow}:{hour:02d}{minute:02d}"
8988
scheduler.add_job(
90-
func,
89+
ScheduledTrigger.execute,
9190
trigger="cron",
9291
day_of_week=dow,
9392
hour=str(hour),
9493
minute=str(minute),
9594
id=job_id,
96-
kwargs={"trigger": trigger, "trigger_tasks": trigger_tasks},
95+
kwargs={"trigger": trigger, "trigger_task": task},
9796
replace_existing=True,
9897
)
9998

10099

101-
def _deploy_monthly(trigger: dict, trigger_tasks: list[dict], setting: dict, trigger_id: str, func) -> None:
100+
def _deploy_monthly(trigger: dict, trigger_tasks: list[dict], setting: dict, trigger_id: str) -> None:
102101
from common.job import scheduler
103102

104103
times = setting.get("time") or []
@@ -126,18 +125,18 @@ def _deploy_monthly(trigger: dict, trigger_tasks: list[dict], setting: dict, tri
126125
for task in trigger_tasks:
127126
job_id = f"trigger:{trigger_id}:task:{task['id']}:monthly:{dom:02d}:{hour:02d}{minute:02d}"
128127
scheduler.add_job(
129-
func,
128+
ScheduledTrigger.execute,
130129
trigger="cron",
131130
day=str(dom),
132131
hour=str(hour),
133132
minute=str(minute),
134133
id=job_id,
135-
kwargs={"trigger": trigger, "trigger_tasks": trigger_tasks},
134+
kwargs={"trigger": trigger, "trigger_task": task},
136135
replace_existing=True,
137136
)
138137

139138

140-
def _deploy_interval(trigger: dict, trigger_tasks: list[dict], setting: dict, trigger_id: str, func) -> None:
139+
def _deploy_interval(trigger: dict, trigger_tasks: list[dict], setting: dict, trigger_id: str) -> None:
141140
from common.job import scheduler
142141

143142
unit = (setting.get("interval_unit") or "").strip()
@@ -158,10 +157,10 @@ def _deploy_interval(trigger: dict, trigger_tasks: list[dict], setting: dict, tr
158157
for task in trigger_tasks:
159158
job_id = f"trigger:{trigger_id}:task:{task['id']}:interval:{unit}:{value_i}"
160159
scheduler.add_job(
161-
func,
160+
ScheduledTrigger.execute,
162161
trigger="interval",
163162
id=job_id,
164-
kwargs={"trigger": trigger, "trigger_tasks": trigger_tasks},
163+
kwargs={"trigger": trigger, "trigger_task": task},
165164
replace_existing=True,
166165
**{unit: value_i},
167166
)
@@ -179,9 +178,9 @@ def _remove_trigger_jobs(trigger_id: str) -> None:
179178
maxkb_logger.warning(f"remove job failed, job_id={job.id}, err={e}")
180179

181180

182-
@celery_app.task(name='celery:deploy_scheduled_trigger')
183-
def deploy_scheduled_trigger(trigger: dict, trigger_tasks: list[dict], setting: dict, schedule_type: str, func) -> None:
184-
_remove_trigger_jobs(trigger['id'])
181+
@celery_app.task(name="celery:deploy_scheduled_trigger")
182+
def deploy_scheduled_trigger(trigger: dict, trigger_tasks: list[dict], setting: dict, schedule_type: str) -> None:
183+
_remove_trigger_jobs(trigger["id"])
185184

186185
deployers = {
187186
"daily": _deploy_daily,
@@ -194,7 +193,7 @@ def deploy_scheduled_trigger(trigger: dict, trigger_tasks: list[dict], setting:
194193
maxkb_logger.warning(f"unsupported schedule_type={schedule_type}, trigger_id={trigger['id']}")
195194
return
196195

197-
fn(trigger, trigger_tasks, setting, trigger['id'], func)
196+
fn(trigger, trigger_tasks, setting, trigger["id"])
198197

199198

200199
class ScheduledTrigger(BaseTrigger):
@@ -205,6 +204,23 @@ class ScheduledTrigger(BaseTrigger):
205204
@staticmethod
206205
def execute(trigger, **kwargs):
207206
n = random.randint(1, 1_000_000_000)
207+
trigger_task = kwargs.get("trigger_task")
208+
if not trigger_task:
209+
maxkb_logger.warning(f"unsupported task={trigger_task}")
210+
return
211+
source_type = trigger_task["source_type"]
212+
213+
if source_type == "APPLICATION":
214+
from trigger.handler.impl.task.application_task import ApplicationTask
215+
216+
ApplicationTask.execute(trigger_task, **kwargs)
217+
elif source_type == "TOOL":
218+
from trigger.handler.impl.task.tool_task import ToolTask
219+
220+
ToolTask.execute(trigger_task, **kwargs)
221+
else:
222+
maxkb_logger.warning(f"unsupported source_type={source_type}, task_id={trigger_task['id']}")
223+
return
208224

209225
maxkb_logger.info(f"scheduled trigger execute, trigger={n}")
210226

@@ -233,8 +249,7 @@ def deploy(self, trigger, **kwargs):
233249

234250
try:
235251
maxkb_logger.debug(f"get lock {lock_key}")
236-
deploy_scheduled_trigger.delay(trigger, trigger_tasks, setting, schedule_type, self.execute)
237-
252+
deploy_scheduled_trigger.delay(trigger, trigger_tasks, setting, schedule_type)
238253
finally:
239254
rlock.un_lock(lock_key)
240255

apps/trigger/tasks.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# apps/trigger/tasks.py
2+
# coding=utf-8
3+
from __future__ import annotations
4+
5+
# 作为 Celery autodiscover 的入口,确保任务模块被导入从而完成注册
6+
from trigger.handler.impl.trigger.scheduled_trigger import deploy_scheduled_trigger # noqa: F401

0 commit comments

Comments
 (0)