Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion cli/decompose/decompose.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class DecompVersion(StrEnum):
latest = "latest"
v1 = "v1"
v2 = "v2"
# v3 = "v3"
v3 = "v3"


this_file_dir = Path(__file__).resolve().parent
Expand Down Expand Up @@ -226,6 +226,15 @@ def run(
case_sensitive=False,
),
] = LogMode.demo,
enable_script_run: Annotated[
bool,
typer.Option(
help=(
"When true, generated scripts expose argparse runtime options "
"for backend, model, endpoint, and API key overrides."
)
),
] = False,
) -> None:
"""Runs the ``m decompose`` CLI workflow and writes generated outputs.

Expand Down Expand Up @@ -253,6 +262,8 @@ def run(
prompts and programs. Each name must be a valid non-keyword Python
identifier.
log_mode: Logging verbosity for CLI and pipeline execution.
enable_script_run: Whether generated scripts should expose argparse
runtime options. Defaults to ``False``.

Raises:
AssertionError: If ``out_name`` is invalid, ``out_dir`` does not name an
Expand All @@ -277,6 +288,7 @@ def run(
logger.info("model_id : %s", model_id)
logger.info("version : %s", version.value)
logger.info("log_mode : %s", log_mode.value)
logger.info("script options : %s", enable_script_run)
logger.info("input_vars : %s", input_var or "[]")

environment = Environment(
Expand Down Expand Up @@ -393,6 +405,11 @@ def run(
subtasks=decomp_data["subtasks"],
user_inputs=input_var,
identified_constraints=decomp_data["identified_constraints"],
model_id=model_id,
backend=backend.value,
backend_endpoint=backend_endpoint,
backend_api_key=backend_api_key,
enable_script_run=enable_script_run,
)
+ "\n"
)
Expand Down
100 changes: 100 additions & 0 deletions cli/decompose/m_decomp_result_v3.py.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
{% if user_inputs -%}
import os
{% endif -%}
import textwrap

import mellea

{%- set ns = namespace(need_req=false) -%}
{%- for item in subtasks -%}
{%- for c in item.constraints or [] -%}
{%- if c.val_fn -%}
{%- set ns.need_req = true -%}
{%- endif -%}
{%- endfor -%}
{%- endfor %}

{%- if ns.need_req %}
from mellea.stdlib.requirements import req
{%- for c in identified_constraints %}
{%- if c.val_fn and c.val_fn_name %}
from validations.{{ c.val_fn_name }} import validate_input as {{ c.val_fn_name }}
{%- endif %}
{%- endfor %}
{%- endif %}

m = mellea.start_session(model_id="mistral-small3.2:latest")
{%- if user_inputs %}


# User Input Variables
try:
{%- for var in user_inputs %}
{{ var | lower }} = os.environ["{{ var | upper }}"]
{%- endfor %}
except KeyError as e:
raise SystemExit(f"ERROR: One or more required environment variables are not set: {e}")
{%- endif %}
{%- for item in subtasks %}


{{ item.tag | lower }}_gnrl = textwrap.dedent(
R"""
{{ item.general_instructions | trim | indent(width=4, first=False) }}
""".strip()
)
{{ item.tag | lower }} = m.instruct(
{%- if not (item.input_vars_required or []) %}
{{ item.subtask[3:] | trim | tojson }},
{%- else %}
textwrap.dedent(
R"""
{{ item.subtask[3:] | trim }}

Here are the input variables and their content:
{%- for var in item.input_vars_required or [] %}

- {{ var | upper }} = {{ "{{" }}{{ var | upper }}{{ "}}" }}
{%- endfor %}
""".strip()
),
{%- endif %}
{%- if item.constraints %}
requirements=[
{%- for c in item.constraints %}
{%- if c.val_fn and c.val_fn_name %}
req(
{{ c.constraint | tojson}},
validation_fn={{ c.val_fn_name }},
),
{%- else %}
{{ c.constraint | tojson}},
{%- endif %}
{%- endfor %}
],
{%- else %}
requirements=None,
{%- endif %}
{%- if item.input_vars_required %}
user_variables={
{%- for var in item.input_vars_required or [] %}
{{ var | upper | tojson }}: {{ var | lower }},
{%- endfor %}
},
{%- endif %}
grounding_context={
"GENERAL_INSTRUCTIONS": {{ item.tag | lower }}_gnrl,
{%- for var in item.depends_on or [] %}
{{ var | upper | tojson }}: {{ var | lower }}.value,
{%- endfor %}
},
)
assert {{ item.tag | lower }}.value is not None, 'ERROR: task "{{ item.tag | lower }}" execution failed'
{%- if loop.last %}


final_answer = {{ item.tag | lower }}.value

print(final_answer)
{%- endif -%}
{%- endfor -%}
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,41 @@
T = TypeVar("T")

RE_GENERAL_INSTRUCTIONS = re.compile(
r"<general_instructions>(.+?)</general_instructions>",
r"<general_instructions>(.*?)</general_instructions>",
flags=re.IGNORECASE | re.DOTALL,
)

RE_GENERAL_INSTRUCTIONS_OPEN = re.compile(
r"<general_instructions>(.*)", flags=re.IGNORECASE | re.DOTALL
)

RE_FINAL_SENTENCE = re.compile(
r"\n*All tags are closed and my assignment is finished\.\s*$", flags=re.IGNORECASE
)


@final
class _GeneralInstructions(PromptModule):
@staticmethod
def _default_parser(generated_str: str) -> str:
general_instructions_match = re.search(RE_GENERAL_INSTRUCTIONS, generated_str)

general_instructions_str: str | None = (
general_instructions_match.group(1).strip()
if general_instructions_match
else None
)

if general_instructions_str is None:
raise TagExtractionError(
'LLM failed to generate correct tags for extraction: "<general_instructions>"'
if general_instructions_match:
general_instructions_str = general_instructions_match.group(1).strip()
else:
# fallback: opening tag only (in case the closing tag is missing)
general_instructions_match = re.search(
RE_GENERAL_INSTRUCTIONS_OPEN, generated_str
)
if not general_instructions_match:
raise TagExtractionError(
'LLM failed to generate correct tags for extraction: "<general_instructions>"'
)
general_instructions_str = general_instructions_match.group(1).strip()

general_instructions_str = re.sub(
RE_FINAL_SENTENCE, "", general_instructions_str
).strip()

return general_instructions_str

Expand All @@ -50,20 +64,19 @@ def generate(

system_prompt = get_system_prompt()
user_prompt = get_user_prompt(task_prompt=input_str)

action = Message("user", user_prompt)

model_options = {
ModelOption.SYSTEM_PROMPT: system_prompt,
ModelOption.TEMPERATURE: 0,
ModelOption.MAX_NEW_TOKENS: max_new_tokens,
}

try:
gen_result = mellea_session.act(
action=action,
model_options={
ModelOption.SYSTEM_PROMPT: system_prompt,
ModelOption.TEMPERATURE: 0,
ModelOption.MAX_NEW_TOKENS: max_new_tokens,
},
).value
response = mellea_session.act(action=action, model_options=model_options)
gen_result = response.value
except Exception as e:
raise BackendGenerationError(f"LLM generation failed: {e}")
raise BackendGenerationError(f"LLM generation failed: {e}") from e

if gen_result is None:
raise BackendGenerationError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Do not write anything between </general_instructions> and the final sentence exc
Here are some complete examples to guide you on how to complete your assignment:

{% for item in icl_examples -%}
<example>
<example_{{ loop.index }}>
<task_prompt>
{{ item["task_prompt"] }}
</task_prompt>
Expand All @@ -22,7 +22,7 @@ Here are some complete examples to guide you on how to complete your assignment:
</general_instructions>

All tags are closed and my assignment is finished.
</example>
</example_{{ loop.index }}>

{% endfor -%}
That concludes the complete examples of your assignment.
Expand Down
Loading
Loading