Skip to content

Commit 0897467

Browse files
committed
feat(handle_batch): bypass ramp_up/ramp_down for batches
1 parent 679c243 commit 0897467

2 files changed

Lines changed: 57 additions & 42 deletions

File tree

dags/google_api_helper.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,21 +223,27 @@ def ramp_up_cluster(key, initial_size, total_size):
223223
run_metadata = Variable.get("run_metadata", deserialize_json=True, default_var={})
224224
if not run_metadata.get("manage_clusters", True):
225225
return
226+
already_at_size = False
226227
try:
227228
target_sizes = Variable.get("cluster_target_size", deserialize_json=True)
229+
already_at_size = target_sizes.get(key, 0) >= total_size
228230
target_sizes[key] = total_size
229231
Variable.set("cluster_target_size", target_sizes, serialize_json=True)
230232
slack_message(":information_source: ramping up cluster {} to {} instances, starting from {} instances".format(key, total_size, min(initial_size, total_size)))
231233
increase_instance_group_size(key, min(initial_size, total_size))
232234
except:
233235
increase_instance_group_size(key, total_size)
234-
sleep(60)
236+
if not already_at_size:
237+
sleep(60)
235238
Variable.set("cluster_target_size", target_sizes, serialize_json=True)
236239

237240
def ramp_down_cluster(key, total_size):
238241
run_metadata = Variable.get("run_metadata", deserialize_json=True, default_var={})
239242
if not run_metadata.get("manage_clusters", True):
240243
return
244+
if Variable.get("batch_keep_cluster", default_var="false") == "true":
245+
slack_message(f":recycle: Batch mode: keeping {key} cluster alive for next job")
246+
return
241247
try:
242248
target_sizes = Variable.get("cluster_target_size", deserialize_json=True)
243249
target_sizes[key] = total_size

slackbot/pipeline_commands.py

Lines changed: 50 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -207,53 +207,62 @@ def handle_batch(task, msg):
207207
replyto(msg, "Batch jobs will reuse the parameters from the first job unless new parameters are specified, *including those with default values*")
208208

209209
default_param = json_obj[0]
210-
for i, p in enumerate(json_obj):
211-
if visible_messages(broker_url, "seuronbot_cmd") != 0:
212-
cmd = get_message(broker_url, "seuronbot_cmd")
213-
if cmd == "cancel":
214-
replyto(msg, "Cancel batch process")
215-
break
216-
217-
if p.get("INHERIT_PARAMETERS", True):
218-
param = deepcopy(default_param)
219-
else:
220-
param = {}
221-
222-
if i > 0:
223-
if 'NAME' in param:
224-
del param['NAME']
225-
for k in p:
226-
param[k] = p[k]
227-
supply_default_param(param)
228-
replyto(msg, "*Sanity check: batch job {} out of {}*".format(i+1, len(json_obj)))
210+
is_batch = len(json_obj) > 1
211+
try:
212+
for i, p in enumerate(json_obj):
213+
if visible_messages(broker_url, "seuronbot_cmd") != 0:
214+
cmd = get_message(broker_url, "seuronbot_cmd")
215+
if cmd == "cancel":
216+
replyto(msg, "Cancel batch process")
217+
break
218+
219+
if p.get("INHERIT_PARAMETERS", True):
220+
param = deepcopy(default_param)
221+
else:
222+
param = {}
223+
224+
if i > 0:
225+
if 'NAME' in param:
226+
del param['NAME']
227+
for k in p:
228+
param[k] = p[k]
229+
supply_default_param(param)
230+
replyto(msg, "*Sanity check: batch job {} out of {}*".format(i+1, len(json_obj)))
231+
state = "unknown"
232+
current_task = guess_run_type(param)
233+
if current_task == "seg_run":
234+
set_variable('param', param, serialize_json=True)
235+
state = run_dag("sanity_check", wait_for_completion=True).state
236+
elif current_task == "inf_run":
237+
set_variable('inference_param', param, serialize_json=True)
238+
state = run_dag("chunkflow_generator", wait_for_completion=True).state
239+
elif current_task == "syn_run":
240+
set_variable("synaptor_param.json", param, serialize_json=True)
241+
state = run_dag("synaptor_sanity_check", wait_for_completion=True).state
242+
243+
if state != "success":
244+
replyto(msg, "*Sanity check failed, abort!*")
245+
break
246+
247+
is_last_job = (i == len(json_obj) - 1)
229248
state = "unknown"
230-
current_task = guess_run_type(param)
249+
replyto(msg, "*Starting batch job {} out of {}*".format(i+1, len(json_obj)), broadcast=True)
250+
231251
if current_task == "seg_run":
232-
set_variable('param', param, serialize_json=True)
233-
state = run_dag("sanity_check", wait_for_completion=True).state
252+
state = run_dag('segmentation', wait_for_completion=True).state
234253
elif current_task == "inf_run":
235-
set_variable('inference_param', param, serialize_json=True)
236-
state = run_dag("chunkflow_generator", wait_for_completion=True).state
254+
if is_batch and not is_last_job:
255+
set_variable("batch_keep_cluster", "true")
256+
else:
257+
set_variable("batch_keep_cluster", "false")
258+
state = run_dag("chunkflow_worker", wait_for_completion=True).state
237259
elif current_task == "syn_run":
238-
set_variable("synaptor_param.json", param, serialize_json=True)
239-
state = run_dag("synaptor_sanity_check", wait_for_completion=True).state
260+
state = run_dag("synaptor", wait_for_completion=True).state
240261

241262
if state != "success":
242-
replyto(msg, "*Sanity check failed, abort!*")
263+
replyto(msg, f"*Batch job failed, abort!* ({state})")
243264
break
244-
245-
state = "unknown"
246-
replyto(msg, "*Starting batch job {} out of {}*".format(i+1, len(json_obj)), broadcast=True)
247-
248-
if current_task == "seg_run":
249-
state = run_dag('segmentation', wait_for_completion=True).state
250-
elif current_task == "inf_run":
251-
state = run_dag("chunkflow_worker", wait_for_completion=True).state
252-
elif current_task == "syn_run":
253-
state = run_dag("synaptor", wait_for_completion=True).state
254-
255-
if state != "success":
256-
replyto(msg, f"*Bach job failed, abort!* ({state})")
257-
break
265+
finally:
266+
set_variable("batch_keep_cluster", "false")
258267

259268
replyto(msg, "*Batch process finished*")

0 commit comments

Comments
 (0)