diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index e522ce5453..4d0e4055c9 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -751,7 +751,7 @@ async def _process_quote_message( img_cap_prov_id: str, plugin_context: Context, quoted_message_settings: QuotedMessageParserSettings = DEFAULT_QUOTED_MESSAGE_SETTINGS, - config: MainAgentBuildConfig | None = None, + cfg: dict | None = None, ) -> None: quote = None for comp in event.message_obj.message: @@ -795,12 +795,19 @@ async def _process_quote_message( path = await image_seg.convert_to_file_path() compress_path = await _compress_image_for_provider( path, - config.provider_settings if config else None, + cfg, ) if path and _is_generated_compressed_image_path(path, compress_path): event.track_temporary_local_file(compress_path) + if cfg is None: + cfg = plugin_context.get_config(umo=event.unified_msg_origin).get( + "provider_settings", {} + ) + img_cap_prompt = ( + cfg.get("image_caption_prompt") or "Please describe the image." + ) llm_resp = await prov.text_chat( - prompt="Please describe the image content.", + prompt=img_cap_prompt, image_urls=[compress_path], ) if llm_resp.completion_text: @@ -904,7 +911,7 @@ async def _decorate_llm_request( img_cap_prov_id, plugin_context, quoted_message_settings, - config, + cfg, ) tz = config.timezone @@ -1269,9 +1276,11 @@ async def build_main_agent( reply_comps = [ comp for comp in event.message_obj.message if isinstance(comp, Reply) ] - quoted_message_settings = _get_quoted_message_parser_settings( - config.provider_settings - ) + cfg = config.provider_settings or plugin_context.get_config( + umo=event.unified_msg_origin + ).get("provider_settings", {}) + quoted_message_settings = _get_quoted_message_parser_settings(cfg) + img_cap_prov_id = cfg.get("default_image_caption_provider_id") or "" fallback_quoted_image_count = 0 for comp in reply_comps: has_embedded_image = False @@ -1286,7 +1295,8 @@ async def build_main_agent( ) if _is_generated_compressed_image_path(path, image_path): event.track_temporary_local_file(image_path) - req.image_urls.append(image_path) + if not img_cap_prov_id: + req.image_urls.append(image_path) _append_quoted_image_attachment(req, image_path) elif isinstance(reply_comp, Record): audio_path = await reply_comp.convert_to_file_path() @@ -1341,7 +1351,8 @@ async def build_main_agent( for image_ref in fallback_images: if image_ref in req.image_urls: continue - req.image_urls.append(image_ref) + if not img_cap_prov_id: + req.image_urls.append(image_ref) fallback_quoted_image_count += 1 _append_quoted_image_attachment(req, image_ref) except Exception as exc: # noqa: BLE001 diff --git a/tests/unit/test_astr_main_agent.py b/tests/unit/test_astr_main_agent.py index fdd06d34a0..305a04f82b 100644 --- a/tests/unit/test_astr_main_agent.py +++ b/tests/unit/test_astr_main_agent.py @@ -1296,6 +1296,101 @@ async def test_build_main_agent_with_existing_request( assert result is not None assert result.provider_request == existing_req + @pytest.mark.asyncio + async def test_build_main_agent_with_quoted_image_no_caption_provider( + self, mock_event, mock_context, mock_provider + ): + """Test building main agent with quoted image when no image caption provider is set.""" + module = ama + mock_image = Image(file="file:///path/to/quoted.jpg") + mock_reply = Reply( + id="reply-1", + chain=[mock_image], + sender_nickname="", + message_str="quoted message", + ) + mock_event.message_obj.message = [Plain(text="Hello"), mock_reply] + + mock_context.get_provider_by_id.return_value = None + mock_context.get_using_provider.return_value = mock_provider + mock_context.get_config.return_value = { + "provider_settings": { + "default_image_caption_provider_id": "" + } + } + + conv_mgr = mock_context.conversation_manager + _setup_conversation_for_build(conv_mgr) + + with ( + patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls, + patch("astrbot.core.astr_main_agent.AstrAgentContext"), + patch.object(Image, "convert_to_file_path", AsyncMock(return_value="/path/to/quoted.jpg")), + ): + mock_runner = MagicMock() + mock_runner.reset = AsyncMock() + mock_runner_cls.return_value = mock_runner + + result = await module.build_main_agent( + event=mock_event, + plugin_context=mock_context, + config=module.MainAgentBuildConfig( + tool_call_timeout=60, + provider_settings={"default_image_caption_provider_id": ""} + ), + ) + + assert result is not None + assert len(result.provider_request.image_urls) > 0 + assert result.provider_request.image_urls[0] == "/path/to/quoted.jpg" + + @pytest.mark.asyncio + async def test_build_main_agent_with_quoted_image_with_caption_provider( + self, mock_event, mock_context, mock_provider + ): + """Test building main agent with quoted image when image caption provider is set.""" + module = ama + mock_image = Image(file="file:///path/to/quoted.jpg") + mock_reply = Reply( + id="reply-1", + chain=[mock_image], + sender_nickname="", + message_str="quoted message", + ) + mock_event.message_obj.message = [Plain(text="Hello"), mock_reply] + + mock_context.get_provider_by_id.return_value = None + mock_context.get_using_provider.return_value = mock_provider + mock_context.get_config.return_value = { + "provider_settings": { + "default_image_caption_provider_id": "some_captioner" + } + } + + conv_mgr = mock_context.conversation_manager + _setup_conversation_for_build(conv_mgr) + + with ( + patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls, + patch("astrbot.core.astr_main_agent.AstrAgentContext"), + patch.object(Image, "convert_to_file_path", AsyncMock(return_value="/path/to/quoted.jpg")), + ): + mock_runner = MagicMock() + mock_runner.reset = AsyncMock() + mock_runner_cls.return_value = mock_runner + + result = await module.build_main_agent( + event=mock_event, + plugin_context=mock_context, + config=module.MainAgentBuildConfig( + tool_call_timeout=60, + provider_settings={"default_image_caption_provider_id": "some_captioner"} + ), + ) + + assert result is not None + assert len(result.provider_request.image_urls) == 0 + class TestHandleWebchat: """Tests for _handle_webchat function."""