Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions astrbot/core/astr_main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ async def _process_quote_message(
img_cap_prov_id: str,
plugin_context: Context,
quoted_message_settings: QuotedMessageParserSettings = DEFAULT_QUOTED_MESSAGE_SETTINGS,
config: MainAgentBuildConfig | None = None,
cfg: dict | None = None,
) -> None:
quote = None
for comp in event.message_obj.message:
Expand Down Expand Up @@ -795,12 +795,19 @@ async def _process_quote_message(
path = await image_seg.convert_to_file_path()
compress_path = await _compress_image_for_provider(
path,
config.provider_settings if config else None,
cfg,
)
if path and _is_generated_compressed_image_path(path, compress_path):
event.track_temporary_local_file(compress_path)
if cfg is None:
cfg = plugin_context.get_config(umo=event.unified_msg_origin).get(
"provider_settings", {}
)
img_cap_prompt = (
cfg.get("image_caption_prompt") or "Please describe the image."
)
llm_resp = await prov.text_chat(
prompt="Please describe the image content.",
prompt=img_cap_prompt,
image_urls=[compress_path],
)
if llm_resp.completion_text:
Expand Down Expand Up @@ -904,7 +911,7 @@ async def _decorate_llm_request(
img_cap_prov_id,
plugin_context,
quoted_message_settings,
config,
cfg,
)

tz = config.timezone
Expand Down Expand Up @@ -1269,9 +1276,11 @@ async def build_main_agent(
reply_comps = [
comp for comp in event.message_obj.message if isinstance(comp, Reply)
]
quoted_message_settings = _get_quoted_message_parser_settings(
config.provider_settings
)
cfg = config.provider_settings or plugin_context.get_config(
umo=event.unified_msg_origin
).get("provider_settings", {})
quoted_message_settings = _get_quoted_message_parser_settings(cfg)
img_cap_prov_id = cfg.get("default_image_caption_provider_id") or ""
fallback_quoted_image_count = 0
for comp in reply_comps:
has_embedded_image = False
Expand All @@ -1286,7 +1295,8 @@ async def build_main_agent(
)
if _is_generated_compressed_image_path(path, image_path):
event.track_temporary_local_file(image_path)
req.image_urls.append(image_path)
if not img_cap_prov_id:
req.image_urls.append(image_path)
_append_quoted_image_attachment(req, image_path)
elif isinstance(reply_comp, Record):
audio_path = await reply_comp.convert_to_file_path()
Expand Down Expand Up @@ -1341,7 +1351,8 @@ async def build_main_agent(
for image_ref in fallback_images:
if image_ref in req.image_urls:
continue
req.image_urls.append(image_ref)
if not img_cap_prov_id:
req.image_urls.append(image_ref)
fallback_quoted_image_count += 1
_append_quoted_image_attachment(req, image_ref)
except Exception as exc: # noqa: BLE001
Expand Down
95 changes: 95 additions & 0 deletions tests/unit/test_astr_main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,101 @@ async def test_build_main_agent_with_existing_request(
assert result is not None
assert result.provider_request == existing_req

@pytest.mark.asyncio
async def test_build_main_agent_with_quoted_image_no_caption_provider(
self, mock_event, mock_context, mock_provider
):
"""Test building main agent with quoted image when no image caption provider is set."""
module = ama
mock_image = Image(file="file:///path/to/quoted.jpg")
mock_reply = Reply(
id="reply-1",
chain=[mock_image],
sender_nickname="",
message_str="quoted message",
)
mock_event.message_obj.message = [Plain(text="Hello"), mock_reply]

mock_context.get_provider_by_id.return_value = None
mock_context.get_using_provider.return_value = mock_provider
mock_context.get_config.return_value = {
"provider_settings": {
"default_image_caption_provider_id": ""
}
}

conv_mgr = mock_context.conversation_manager
_setup_conversation_for_build(conv_mgr)

with (
patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls,
patch("astrbot.core.astr_main_agent.AstrAgentContext"),
patch.object(Image, "convert_to_file_path", AsyncMock(return_value="/path/to/quoted.jpg")),
):
mock_runner = MagicMock()
mock_runner.reset = AsyncMock()
mock_runner_cls.return_value = mock_runner

result = await module.build_main_agent(
event=mock_event,
plugin_context=mock_context,
config=module.MainAgentBuildConfig(
tool_call_timeout=60,
provider_settings={"default_image_caption_provider_id": ""}
),
)

assert result is not None
assert len(result.provider_request.image_urls) > 0
assert result.provider_request.image_urls[0] == "/path/to/quoted.jpg"

@pytest.mark.asyncio
async def test_build_main_agent_with_quoted_image_with_caption_provider(
self, mock_event, mock_context, mock_provider
):
"""Test building main agent with quoted image when image caption provider is set."""
module = ama
mock_image = Image(file="file:///path/to/quoted.jpg")
mock_reply = Reply(
id="reply-1",
chain=[mock_image],
sender_nickname="",
message_str="quoted message",
)
mock_event.message_obj.message = [Plain(text="Hello"), mock_reply]

mock_context.get_provider_by_id.return_value = None
mock_context.get_using_provider.return_value = mock_provider
mock_context.get_config.return_value = {
"provider_settings": {
"default_image_caption_provider_id": "some_captioner"
}
}

conv_mgr = mock_context.conversation_manager
_setup_conversation_for_build(conv_mgr)

with (
patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls,
patch("astrbot.core.astr_main_agent.AstrAgentContext"),
patch.object(Image, "convert_to_file_path", AsyncMock(return_value="/path/to/quoted.jpg")),
):
mock_runner = MagicMock()
mock_runner.reset = AsyncMock()
mock_runner_cls.return_value = mock_runner

result = await module.build_main_agent(
event=mock_event,
plugin_context=mock_context,
config=module.MainAgentBuildConfig(
tool_call_timeout=60,
provider_settings={"default_image_caption_provider_id": "some_captioner"}
),
)

assert result is not None
assert len(result.provider_request.image_urls) == 0


class TestHandleWebchat:
"""Tests for _handle_webchat function."""
Expand Down