From 416ecf63705b3ffedc18f3aff7111b4401ee114b Mon Sep 17 00:00:00 2001 From: idiotsj Date: Thu, 26 Mar 2026 18:16:28 +0800 Subject: [PATCH 1/9] fix(openai): preserve image mime in data uri encoding (#6991) --- astrbot/core/provider/entities.py | 11 +-- .../core/provider/sources/openai_source.py | 10 +-- astrbot/core/utils/io.py | 43 +++++++++ tests/test_io_image_data_uri.py | 49 +++++++++++ tests/test_openai_source.py | 88 +++++++++++++++++++ tests/unit/test_astr_main_agent.py | 32 +++++++ 6 files changed, 218 insertions(+), 15 deletions(-) create mode 100644 tests/test_io_image_data_uri.py diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 20c5a7947d..a5ec6b5ea9 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -1,6 +1,5 @@ from __future__ import annotations -import base64 import enum import json from dataclasses import dataclass, field @@ -21,7 +20,7 @@ from astrbot.core.agent.tool import ToolSet from astrbot.core.db.po import Conversation from astrbot.core.message.message_event_result import MessageChain -from astrbot.core.utils.io import download_image_by_url +from astrbot.core.utils.io import download_image_by_url, image_source_to_data_uri class ProviderType(enum.Enum): @@ -216,12 +215,8 @@ async def assemble_context(self) -> dict: async def _encode_image_bs64(self, image_url: str) -> str: """将图片转换为 base64""" - if image_url.startswith("base64://"): - return image_url.replace("base64://", "data:image/jpeg;base64,") - with open(image_url, "rb") as f: - image_bs64 = base64.b64encode(f.read()).decode("utf-8") - return "data:image/jpeg;base64," + image_bs64 - return "" + data_uri, _ = image_source_to_data_uri(image_url) + return data_uri @dataclass diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 68fad067b0..bc9707f918 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -1,5 +1,4 @@ import asyncio -import base64 import inspect import json import random @@ -22,7 +21,7 @@ from astrbot.core.agent.tool import ToolSet from astrbot.core.message.message_event_result import MessageChain from astrbot.core.provider.entities import LLMResponse, TokenUsage, ToolCallsResult -from astrbot.core.utils.io import download_image_by_url +from astrbot.core.utils.io import download_image_by_url, image_source_to_data_uri from astrbot.core.utils.network_utils import ( create_proxy_client, is_connection_error, @@ -987,11 +986,8 @@ async def resolve_image_part(image_url: str) -> dict | None: async def encode_image_bs64(self, image_url: str) -> str: """将图片转换为 base64""" - if image_url.startswith("base64://"): - return image_url.replace("base64://", "data:image/jpeg;base64,") - with open(image_url, "rb") as f: - image_bs64 = base64.b64encode(f.read()).decode("utf-8") - return "data:image/jpeg;base64," + image_bs64 + data_uri, _ = image_source_to_data_uri(image_url) + return data_uri async def terminate(self): if self.client: diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index b565926749..f0f13a5066 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -206,6 +206,49 @@ def file_to_base64(file_path: str) -> str: return "base64://" + base64_str +def detect_image_mime_type(data: bytes) -> str: + """根据图片二进制数据的 magic bytes 检测 MIME 类型。""" + if data[:8] == b"\x89PNG\r\n\x1a\n": + return "image/png" + if data[:2] == b"\xff\xd8": + return "image/jpeg" + if data[:6] in (b"GIF87a", b"GIF89a"): + return "image/gif" + if data[:4] == b"RIFF" and data[8:12] == b"WEBP": + return "image/webp" + return "image/jpeg" + + +def image_source_to_data_uri(image_source: str) -> tuple[str, str]: + """将图片来源统一转换为 data URI,并尽量保留真实 MIME 类型。""" + lower_source = image_source.lower() + + if lower_source.startswith("data:image/"): + mime_type = "image/jpeg" + prefix = image_source.split(",", 1)[0] + if prefix.startswith("data:"): + mime_type = prefix.split(";", 1)[0].removeprefix("data:") + if not mime_type.startswith("image/"): + mime_type = "image/jpeg" + return image_source, mime_type + + if image_source.startswith("base64://"): + raw_base64 = image_source.removeprefix("base64://") + mime_type = "image/jpeg" + try: + image_bytes = base64.b64decode(raw_base64) + mime_type = detect_image_mime_type(image_bytes) + except Exception: + mime_type = "image/jpeg" + return f"data:{mime_type};base64,{raw_base64}", mime_type + + with open(image_source, "rb") as f: + image_bytes = f.read() + mime_type = detect_image_mime_type(image_bytes) + image_bs64 = base64.b64encode(image_bytes).decode("utf-8") + return f"data:{mime_type};base64,{image_bs64}", mime_type + + def get_local_ip_addresses(): net_interfaces = psutil.net_if_addrs() network_ips = [] diff --git a/tests/test_io_image_data_uri.py b/tests/test_io_image_data_uri.py new file mode 100644 index 0000000000..762d421930 --- /dev/null +++ b/tests/test_io_image_data_uri.py @@ -0,0 +1,49 @@ +import base64 +from pathlib import Path + +from astrbot.core.utils.io import detect_image_mime_type, image_source_to_data_uri + +PNG_BYTES = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR" +GIF_BYTES = b"GIF89a\x01\x00\x01\x00\x80\x00\x00" +WEBP_BYTES = b"RIFF\x0c\x00\x00\x00WEBPVP8 " +JPEG_BYTES = b"\xff\xd8\xff\xe0\x00\x10JFIF" + + +def test_detect_image_mime_type_known_formats(): + assert detect_image_mime_type(PNG_BYTES) == "image/png" + assert detect_image_mime_type(JPEG_BYTES) == "image/jpeg" + assert detect_image_mime_type(GIF_BYTES) == "image/gif" + assert detect_image_mime_type(WEBP_BYTES) == "image/webp" + + +def test_detect_image_mime_type_unknown_fallback_jpeg(): + assert detect_image_mime_type(b"not-an-image") == "image/jpeg" + + +def test_image_source_to_data_uri_passthrough_data_uri(): + data_uri = f"data:image/png;base64,{base64.b64encode(PNG_BYTES).decode('utf-8')}" + encoded, mime_type = image_source_to_data_uri(data_uri) + assert encoded == data_uri + assert mime_type == "image/png" + + +def test_image_source_to_data_uri_detects_base64_mime(): + raw = base64.b64encode(GIF_BYTES).decode("utf-8") + encoded, mime_type = image_source_to_data_uri(f"base64://{raw}") + assert encoded.startswith("data:image/gif;base64,") + assert mime_type == "image/gif" + + +def test_image_source_to_data_uri_invalid_base64_fallback_jpeg(): + encoded, mime_type = image_source_to_data_uri("base64://not-valid-base64") + assert encoded == "data:image/jpeg;base64,not-valid-base64" + assert mime_type == "image/jpeg" + + +def test_image_source_to_data_uri_detects_local_file_mime(tmp_path: Path): + webp_path = tmp_path / "image.webp" + webp_path.write_bytes(WEBP_BYTES) + + encoded, mime_type = image_source_to_data_uri(str(webp_path)) + assert encoded.startswith("data:image/webp;base64,") + assert mime_type == "image/webp" diff --git a/tests/test_openai_source.py b/tests/test_openai_source.py index 0040f0be62..60ebfe32fc 100644 --- a/tests/test_openai_source.py +++ b/tests/test_openai_source.py @@ -1,8 +1,11 @@ +import base64 +from pathlib import Path from types import SimpleNamespace import pytest from openai.types.chat.chat_completion import ChatCompletion +from astrbot.core.agent.message import ImageURLPart from astrbot.core.provider.sources.groq_source import ProviderGroq from astrbot.core.provider.sources.openai_source import ProviderOpenAIOfficial @@ -19,6 +22,14 @@ def __init__(self, message: str, response_text: str): self.response = SimpleNamespace(text=response_text) +PNG_BYTES = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR" +GIF_BYTES = b"GIF89a\x01\x00\x01\x00\x80\x00\x00" +WEBP_BYTES = b"RIFF\x0c\x00\x00\x00WEBPVP8 " +PNG_BASE64 = base64.b64encode(PNG_BYTES).decode("utf-8") +GIF_BASE64 = base64.b64encode(GIF_BYTES).decode("utf-8") +WEBP_BASE64 = base64.b64encode(WEBP_BYTES).decode("utf-8") + + def _make_provider(overrides: dict | None = None) -> ProviderOpenAIOfficial: provider_config = { "id": "test-openai", @@ -533,3 +544,80 @@ async def fake_create(**kwargs): assert extra_body["temperature"] == 0.1 finally: await provider.terminate() + + +@pytest.mark.asyncio +async def test_openai_encode_image_bs64_detects_base64_mime(): + provider = _make_provider() + try: + png_data = await provider.encode_image_bs64(f"base64://{PNG_BASE64}") + gif_data = await provider.encode_image_bs64(f"base64://{GIF_BASE64}") + webp_data = await provider.encode_image_bs64(f"base64://{WEBP_BASE64}") + + assert png_data.startswith("data:image/png;base64,") + assert gif_data.startswith("data:image/gif;base64,") + assert webp_data.startswith("data:image/webp;base64,") + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_openai_encode_image_bs64_detects_local_file_mime(tmp_path: Path): + provider = _make_provider() + png_path = tmp_path / "pixel.png" + webp_path = tmp_path / "pixel.webp" + png_path.write_bytes(PNG_BYTES) + webp_path.write_bytes(WEBP_BYTES) + try: + png_data = await provider.encode_image_bs64(str(png_path)) + webp_data = await provider.encode_image_bs64(str(webp_path)) + + assert png_data.startswith("data:image/png;base64,") + assert webp_data.startswith("data:image/webp;base64,") + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_openai_encode_image_bs64_keeps_data_uri(): + provider = _make_provider() + data_uri = f"data:image/png;base64,{PNG_BASE64}" + try: + assert await provider.encode_image_bs64(data_uri) == data_uri + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_openai_encode_image_bs64_invalid_base64_fallback_to_jpeg(): + provider = _make_provider() + try: + image_data = await provider.encode_image_bs64("base64://not-valid-base64") + assert image_data == "data:image/jpeg;base64,not-valid-base64" + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_openai_assemble_context_extra_image_file_uri_mime(tmp_path: Path): + provider = _make_provider() + png_path = tmp_path / "agent-request.png" + png_path.write_bytes(PNG_BYTES) + try: + assembled = await provider.assemble_context( + text="hello", + extra_user_content_parts=[ + ImageURLPart( + image_url=ImageURLPart.ImageURL( + url=f"file:///{png_path.as_posix()}", + ) + ) + ], + ) + + assert isinstance(assembled["content"], list) + image_part = assembled["content"][1] + assert image_part["type"] == "image_url" + assert image_part["image_url"]["url"].startswith("data:image/png;base64,") + finally: + await provider.terminate() diff --git a/tests/unit/test_astr_main_agent.py b/tests/unit/test_astr_main_agent.py index 9a42abd733..fb8cdbaf51 100644 --- a/tests/unit/test_astr_main_agent.py +++ b/tests/unit/test_astr_main_agent.py @@ -1,5 +1,6 @@ """Tests for astr_main_agent module.""" +import base64 import os from unittest.mock import AsyncMock, MagicMock, patch @@ -721,6 +722,37 @@ def test_modalities_fix_all_supported(self, mock_provider): assert req.func_tool is not None +class TestProviderRequestAssembleContextImage: + @pytest.mark.asyncio + async def test_provider_request_assemble_context_image_mime_local_and_file_uri( + self, tmp_path + ): + png_path = tmp_path / "request.png" + webp_path = tmp_path / "request.webp" + png_path.write_bytes(base64.b64decode("iVBORw0KGgo=")) + webp_path.write_bytes(b"RIFF\x0c\x00\x00\x00WEBPVP8 ") + + req = ProviderRequest( + prompt="Hello", + image_urls=[ + str(png_path), + f"file:///{webp_path.as_posix()}", + ], + ) + + assembled = await req.assemble_context() + assert isinstance(assembled["content"], list) + + image_urls = [ + part["image_url"]["url"] + for part in assembled["content"] + if part.get("type") == "image_url" + ] + assert len(image_urls) == 2 + assert image_urls[0].startswith("data:image/png;base64,") + assert image_urls[1].startswith("data:image/webp;base64,") + + class TestSanitizeContextByModalities: """Tests for _sanitize_context_by_modalities function.""" From 9536cab714a3086b85822996cecada8b5ed9794e Mon Sep 17 00:00:00 2001 From: idiotsj Date: Thu, 26 Mar 2026 18:38:37 +0800 Subject: [PATCH 2/9] refactor(image): normalize file uri handling and test fixtures --- astrbot/core/provider/entities.py | 3 -- .../core/provider/sources/openai_source.py | 3 -- astrbot/core/utils/io.py | 10 +++++++ tests/fixtures/image_samples.py | 10 +++++++ tests/test_io_image_data_uri.py | 28 +++++++++++++------ tests/test_openai_source.py | 16 +++++------ tests/unit/test_astr_main_agent.py | 11 +++++--- 7 files changed, 53 insertions(+), 28 deletions(-) create mode 100644 tests/fixtures/image_samples.py diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index a5ec6b5ea9..281abf3e97 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -189,9 +189,6 @@ async def assemble_context(self) -> dict: if image_url.startswith("http"): image_path = await download_image_by_url(image_url) image_data = await self._encode_image_bs64(image_path) - elif image_url.startswith("file:///"): - image_path = image_url.replace("file:///", "") - image_data = await self._encode_image_bs64(image_path) else: image_data = await self._encode_image_bs64(image_url) if not image_data: diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index bc9707f918..ffbf4d8a30 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -926,9 +926,6 @@ async def resolve_image_part(image_url: str) -> dict | None: if image_url.startswith("http"): image_path = await download_image_by_url(image_url) image_data = await self.encode_image_bs64(image_path) - elif image_url.startswith("file:///"): - image_path = image_url.replace("file:///", "") - image_data = await self.encode_image_bs64(image_path) else: image_data = await self.encode_image_bs64(image_url) if not image_data: diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index f0f13a5066..3dddbc78a7 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -8,6 +8,8 @@ import uuid import zipfile from pathlib import Path +from urllib.parse import unquote, urlsplit +from urllib.request import url2pathname import aiohttp import certifi @@ -242,6 +244,14 @@ def image_source_to_data_uri(image_source: str) -> tuple[str, str]: mime_type = "image/jpeg" return f"data:{mime_type};base64,{raw_base64}", mime_type + if lower_source.startswith("file://"): + parsed = urlsplit(image_source) + if parsed.netloc and parsed.netloc != "localhost": + raw_path = f"//{parsed.netloc}{parsed.path}" + else: + raw_path = parsed.path + image_source = url2pathname(unquote(raw_path)) + with open(image_source, "rb") as f: image_bytes = f.read() mime_type = detect_image_mime_type(image_bytes) diff --git a/tests/fixtures/image_samples.py b/tests/fixtures/image_samples.py new file mode 100644 index 0000000000..4453c28717 --- /dev/null +++ b/tests/fixtures/image_samples.py @@ -0,0 +1,10 @@ +import base64 + +PNG_BYTES = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR" +GIF_BYTES = b"GIF89a\x01\x00\x01\x00\x80\x00\x00" +WEBP_BYTES = b"RIFF\x0c\x00\x00\x00WEBPVP8 " +JPEG_BYTES = b"\xff\xd8\xff\xe0\x00\x10JFIF" + +PNG_BASE64 = base64.b64encode(PNG_BYTES).decode("ascii") +GIF_BASE64 = base64.b64encode(GIF_BYTES).decode("ascii") +WEBP_BASE64 = base64.b64encode(WEBP_BYTES).decode("ascii") diff --git a/tests/test_io_image_data_uri.py b/tests/test_io_image_data_uri.py index 762d421930..e2c0360048 100644 --- a/tests/test_io_image_data_uri.py +++ b/tests/test_io_image_data_uri.py @@ -1,12 +1,14 @@ -import base64 from pathlib import Path from astrbot.core.utils.io import detect_image_mime_type, image_source_to_data_uri - -PNG_BYTES = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR" -GIF_BYTES = b"GIF89a\x01\x00\x01\x00\x80\x00\x00" -WEBP_BYTES = b"RIFF\x0c\x00\x00\x00WEBPVP8 " -JPEG_BYTES = b"\xff\xd8\xff\xe0\x00\x10JFIF" +from tests.fixtures.image_samples import ( + GIF_BASE64, + GIF_BYTES, + JPEG_BYTES, + PNG_BASE64, + PNG_BYTES, + WEBP_BYTES, +) def test_detect_image_mime_type_known_formats(): @@ -21,15 +23,14 @@ def test_detect_image_mime_type_unknown_fallback_jpeg(): def test_image_source_to_data_uri_passthrough_data_uri(): - data_uri = f"data:image/png;base64,{base64.b64encode(PNG_BYTES).decode('utf-8')}" + data_uri = f"data:image/png;base64,{PNG_BASE64}" encoded, mime_type = image_source_to_data_uri(data_uri) assert encoded == data_uri assert mime_type == "image/png" def test_image_source_to_data_uri_detects_base64_mime(): - raw = base64.b64encode(GIF_BYTES).decode("utf-8") - encoded, mime_type = image_source_to_data_uri(f"base64://{raw}") + encoded, mime_type = image_source_to_data_uri(f"base64://{GIF_BASE64}") assert encoded.startswith("data:image/gif;base64,") assert mime_type == "image/gif" @@ -47,3 +48,12 @@ def test_image_source_to_data_uri_detects_local_file_mime(tmp_path: Path): encoded, mime_type = image_source_to_data_uri(str(webp_path)) assert encoded.startswith("data:image/webp;base64,") assert mime_type == "image/webp" + + +def test_image_source_to_data_uri_detects_file_uri_mime(tmp_path: Path): + png_path = tmp_path / "uri-image.png" + png_path.write_bytes(PNG_BYTES) + + encoded, mime_type = image_source_to_data_uri(f"file:///{png_path.as_posix()}") + assert encoded == f"data:image/png;base64,{PNG_BASE64}" + assert mime_type == "image/png" diff --git a/tests/test_openai_source.py b/tests/test_openai_source.py index 60ebfe32fc..23693a7acb 100644 --- a/tests/test_openai_source.py +++ b/tests/test_openai_source.py @@ -1,4 +1,3 @@ -import base64 from pathlib import Path from types import SimpleNamespace @@ -8,6 +7,13 @@ from astrbot.core.agent.message import ImageURLPart from astrbot.core.provider.sources.groq_source import ProviderGroq from astrbot.core.provider.sources.openai_source import ProviderOpenAIOfficial +from tests.fixtures.image_samples import ( + GIF_BASE64, + PNG_BASE64, + PNG_BYTES, + WEBP_BASE64, + WEBP_BYTES, +) class _ErrorWithBody(Exception): @@ -22,14 +28,6 @@ def __init__(self, message: str, response_text: str): self.response = SimpleNamespace(text=response_text) -PNG_BYTES = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR" -GIF_BYTES = b"GIF89a\x01\x00\x01\x00\x80\x00\x00" -WEBP_BYTES = b"RIFF\x0c\x00\x00\x00WEBPVP8 " -PNG_BASE64 = base64.b64encode(PNG_BYTES).decode("utf-8") -GIF_BASE64 = base64.b64encode(GIF_BYTES).decode("utf-8") -WEBP_BASE64 = base64.b64encode(WEBP_BYTES).decode("utf-8") - - def _make_provider(overrides: dict | None = None) -> ProviderOpenAIOfficial: provider_config = { "id": "test-openai", diff --git a/tests/unit/test_astr_main_agent.py b/tests/unit/test_astr_main_agent.py index fb8cdbaf51..5eeb2f8536 100644 --- a/tests/unit/test_astr_main_agent.py +++ b/tests/unit/test_astr_main_agent.py @@ -1,6 +1,5 @@ """Tests for astr_main_agent module.""" -import base64 import os from unittest.mock import AsyncMock, MagicMock, patch @@ -15,6 +14,7 @@ from astrbot.core.platform.platform_metadata import PlatformMetadata from astrbot.core.provider import Provider from astrbot.core.provider.entities import ProviderRequest +from tests.fixtures.image_samples import PNG_BASE64, PNG_BYTES, WEBP_BYTES @pytest.fixture @@ -729,14 +729,16 @@ async def test_provider_request_assemble_context_image_mime_local_and_file_uri( ): png_path = tmp_path / "request.png" webp_path = tmp_path / "request.webp" - png_path.write_bytes(base64.b64decode("iVBORw0KGgo=")) - webp_path.write_bytes(b"RIFF\x0c\x00\x00\x00WEBPVP8 ") + png_path.write_bytes(PNG_BYTES) + webp_path.write_bytes(WEBP_BYTES) + base64_url = f"base64://{PNG_BASE64}" req = ProviderRequest( prompt="Hello", image_urls=[ str(png_path), f"file:///{webp_path.as_posix()}", + base64_url, ], ) @@ -748,9 +750,10 @@ async def test_provider_request_assemble_context_image_mime_local_and_file_uri( for part in assembled["content"] if part.get("type") == "image_url" ] - assert len(image_urls) == 2 + assert len(image_urls) == 3 assert image_urls[0].startswith("data:image/png;base64,") assert image_urls[1].startswith("data:image/webp;base64,") + assert image_urls[2].startswith("data:image/png;base64,") class TestSanitizeContextByModalities: From bbb0bf92f47cedd66fb0a61ec3e39ab2b6a5016f Mon Sep 17 00:00:00 2001 From: idiotsj Date: Thu, 26 Mar 2026 19:48:48 +0800 Subject: [PATCH 3/9] test: remove duplicated io image data uri test --- tests/test_io_image_data_uri.py | 59 --------------------------------- 1 file changed, 59 deletions(-) delete mode 100644 tests/test_io_image_data_uri.py diff --git a/tests/test_io_image_data_uri.py b/tests/test_io_image_data_uri.py deleted file mode 100644 index e2c0360048..0000000000 --- a/tests/test_io_image_data_uri.py +++ /dev/null @@ -1,59 +0,0 @@ -from pathlib import Path - -from astrbot.core.utils.io import detect_image_mime_type, image_source_to_data_uri -from tests.fixtures.image_samples import ( - GIF_BASE64, - GIF_BYTES, - JPEG_BYTES, - PNG_BASE64, - PNG_BYTES, - WEBP_BYTES, -) - - -def test_detect_image_mime_type_known_formats(): - assert detect_image_mime_type(PNG_BYTES) == "image/png" - assert detect_image_mime_type(JPEG_BYTES) == "image/jpeg" - assert detect_image_mime_type(GIF_BYTES) == "image/gif" - assert detect_image_mime_type(WEBP_BYTES) == "image/webp" - - -def test_detect_image_mime_type_unknown_fallback_jpeg(): - assert detect_image_mime_type(b"not-an-image") == "image/jpeg" - - -def test_image_source_to_data_uri_passthrough_data_uri(): - data_uri = f"data:image/png;base64,{PNG_BASE64}" - encoded, mime_type = image_source_to_data_uri(data_uri) - assert encoded == data_uri - assert mime_type == "image/png" - - -def test_image_source_to_data_uri_detects_base64_mime(): - encoded, mime_type = image_source_to_data_uri(f"base64://{GIF_BASE64}") - assert encoded.startswith("data:image/gif;base64,") - assert mime_type == "image/gif" - - -def test_image_source_to_data_uri_invalid_base64_fallback_jpeg(): - encoded, mime_type = image_source_to_data_uri("base64://not-valid-base64") - assert encoded == "data:image/jpeg;base64,not-valid-base64" - assert mime_type == "image/jpeg" - - -def test_image_source_to_data_uri_detects_local_file_mime(tmp_path: Path): - webp_path = tmp_path / "image.webp" - webp_path.write_bytes(WEBP_BYTES) - - encoded, mime_type = image_source_to_data_uri(str(webp_path)) - assert encoded.startswith("data:image/webp;base64,") - assert mime_type == "image/webp" - - -def test_image_source_to_data_uri_detects_file_uri_mime(tmp_path: Path): - png_path = tmp_path / "uri-image.png" - png_path.write_bytes(PNG_BYTES) - - encoded, mime_type = image_source_to_data_uri(f"file:///{png_path.as_posix()}") - assert encoded == f"data:image/png;base64,{PNG_BASE64}" - assert mime_type == "image/png" From 9c9280fc37a25f373db23b21806cf5630d8eafb3 Mon Sep 17 00:00:00 2001 From: idiotsj Date: Thu, 26 Mar 2026 20:25:19 +0800 Subject: [PATCH 4/9] fix(image): validate data uri and scheme; add jpeg gif coverage --- astrbot/core/utils/io.py | 22 ++++++++++++++----- tests/fixtures/image_samples.py | 1 + tests/test_openai_source.py | 39 +++++++++++++++++++++++++++++++-- 3 files changed, 54 insertions(+), 8 deletions(-) diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index 3dddbc78a7..f34e615d52 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -225,15 +225,20 @@ def image_source_to_data_uri(image_source: str) -> tuple[str, str]: """将图片来源统一转换为 data URI,并尽量保留真实 MIME 类型。""" lower_source = image_source.lower() - if lower_source.startswith("data:image/"): - mime_type = "image/jpeg" + if lower_source.startswith("data:"): prefix = image_source.split(",", 1)[0] - if prefix.startswith("data:"): - mime_type = prefix.split(";", 1)[0].removeprefix("data:") - if not mime_type.startswith("image/"): - mime_type = "image/jpeg" + mime_type = prefix.split(";", 1)[0].removeprefix("data:").lower() + if not mime_type.startswith("image/"): + raise ValueError( + f"Only image data URI is supported, got MIME type: {mime_type or 'unknown'}", + ) return image_source, mime_type + if lower_source.startswith(("http://", "https://")): + raise ValueError( + "Remote image URL is not supported in image_source_to_data_uri; download the file before calling this helper.", + ) + if image_source.startswith("base64://"): raw_base64 = image_source.removeprefix("base64://") mime_type = "image/jpeg" @@ -251,6 +256,11 @@ def image_source_to_data_uri(image_source: str) -> tuple[str, str]: else: raw_path = parsed.path image_source = url2pathname(unquote(raw_path)) + elif "://" in image_source: + scheme = image_source.split("://", 1)[0].lower() + raise ValueError( + f"Unsupported image source scheme: {scheme}://", + ) with open(image_source, "rb") as f: image_bytes = f.read() diff --git a/tests/fixtures/image_samples.py b/tests/fixtures/image_samples.py index 4453c28717..2301916b2b 100644 --- a/tests/fixtures/image_samples.py +++ b/tests/fixtures/image_samples.py @@ -8,3 +8,4 @@ PNG_BASE64 = base64.b64encode(PNG_BYTES).decode("ascii") GIF_BASE64 = base64.b64encode(GIF_BYTES).decode("ascii") WEBP_BASE64 = base64.b64encode(WEBP_BYTES).decode("ascii") +JPEG_BASE64 = base64.b64encode(JPEG_BYTES).decode("ascii") diff --git a/tests/test_openai_source.py b/tests/test_openai_source.py index 23693a7acb..1acf884c0a 100644 --- a/tests/test_openai_source.py +++ b/tests/test_openai_source.py @@ -9,6 +9,9 @@ from astrbot.core.provider.sources.openai_source import ProviderOpenAIOfficial from tests.fixtures.image_samples import ( GIF_BASE64, + GIF_BYTES, + JPEG_BASE64, + JPEG_BYTES, PNG_BASE64, PNG_BYTES, WEBP_BASE64, @@ -563,14 +566,22 @@ async def test_openai_encode_image_bs64_detects_base64_mime(): async def test_openai_encode_image_bs64_detects_local_file_mime(tmp_path: Path): provider = _make_provider() png_path = tmp_path / "pixel.png" + gif_path = tmp_path / "pixel.gif" + jpeg_path = tmp_path / "pixel.jpg" webp_path = tmp_path / "pixel.webp" png_path.write_bytes(PNG_BYTES) + gif_path.write_bytes(GIF_BYTES) + jpeg_path.write_bytes(JPEG_BYTES) webp_path.write_bytes(WEBP_BYTES) try: png_data = await provider.encode_image_bs64(str(png_path)) + gif_data = await provider.encode_image_bs64(str(gif_path)) + jpeg_data = await provider.encode_image_bs64(str(jpeg_path)) webp_data = await provider.encode_image_bs64(str(webp_path)) assert png_data.startswith("data:image/png;base64,") + assert gif_data.startswith("data:image/gif;base64,") + assert jpeg_data.startswith("data:image/jpeg;base64,") assert webp_data.startswith("data:image/webp;base64,") finally: await provider.terminate() @@ -579,9 +590,13 @@ async def test_openai_encode_image_bs64_detects_local_file_mime(tmp_path: Path): @pytest.mark.asyncio async def test_openai_encode_image_bs64_keeps_data_uri(): provider = _make_provider() - data_uri = f"data:image/png;base64,{PNG_BASE64}" + png_data_uri = f"data:image/png;base64,{PNG_BASE64}" + gif_data_uri = f"data:image/gif;base64,{GIF_BASE64}" + jpeg_data_uri = f"data:image/jpeg;base64,{JPEG_BASE64}" try: - assert await provider.encode_image_bs64(data_uri) == data_uri + assert await provider.encode_image_bs64(png_data_uri) == png_data_uri + assert await provider.encode_image_bs64(gif_data_uri) == gif_data_uri + assert await provider.encode_image_bs64(jpeg_data_uri) == jpeg_data_uri finally: await provider.terminate() @@ -596,6 +611,26 @@ async def test_openai_encode_image_bs64_invalid_base64_fallback_to_jpeg(): await provider.terminate() +@pytest.mark.asyncio +async def test_openai_encode_image_bs64_rejects_non_image_data_uri(): + provider = _make_provider() + try: + with pytest.raises(ValueError, match="Only image data URI is supported"): + await provider.encode_image_bs64("data:text/plain;base64,SGVsbG8=") + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_openai_encode_image_bs64_rejects_unsupported_uri_scheme(): + provider = _make_provider() + try: + with pytest.raises(ValueError, match="Unsupported image source scheme"): + await provider.encode_image_bs64("s3://bucket/path/image.png") + finally: + await provider.terminate() + + @pytest.mark.asyncio async def test_openai_assemble_context_extra_image_file_uri_mime(tmp_path: Path): provider = _make_provider() From a0512caf4bfddec1c098759757e9c59a0412e4f8 Mon Sep 17 00:00:00 2001 From: idiotsj Date: Thu, 26 Mar 2026 21:08:03 +0800 Subject: [PATCH 5/9] refactor(image): centralize default mime and clarify helper constraints --- astrbot/core/utils/io.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index f34e615d52..a8d93e4263 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -208,21 +208,29 @@ def file_to_base64(file_path: str) -> str: return "base64://" + base64_str +DEFAULT_IMAGE_MIME_TYPE = "image/jpeg" + + def detect_image_mime_type(data: bytes) -> str: """根据图片二进制数据的 magic bytes 检测 MIME 类型。""" if data[:8] == b"\x89PNG\r\n\x1a\n": return "image/png" if data[:2] == b"\xff\xd8": - return "image/jpeg" + return DEFAULT_IMAGE_MIME_TYPE if data[:6] in (b"GIF87a", b"GIF89a"): return "image/gif" if data[:4] == b"RIFF" and data[8:12] == b"WEBP": return "image/webp" - return "image/jpeg" + return DEFAULT_IMAGE_MIME_TYPE def image_source_to_data_uri(image_source: str) -> tuple[str, str]: - """将图片来源统一转换为 data URI,并尽量保留真实 MIME 类型。""" + """将本地/内联图片来源统一转换为 data URI,并尽量保留真实 MIME 类型。 + + 说明: + - 支持 `data:image/...`、`base64://...`、本地路径和 `file://...`。 + - 不支持远程 URL(`http://`、`https://`),调用方应先下载到本地文件。 + """ lower_source = image_source.lower() if lower_source.startswith("data:"): @@ -241,12 +249,12 @@ def image_source_to_data_uri(image_source: str) -> tuple[str, str]: if image_source.startswith("base64://"): raw_base64 = image_source.removeprefix("base64://") - mime_type = "image/jpeg" + mime_type = DEFAULT_IMAGE_MIME_TYPE try: image_bytes = base64.b64decode(raw_base64) mime_type = detect_image_mime_type(image_bytes) except Exception: - mime_type = "image/jpeg" + mime_type = DEFAULT_IMAGE_MIME_TYPE return f"data:{mime_type};base64,{raw_base64}", mime_type if lower_source.startswith("file://"): From 75888d462f8ec1f450ea105e4354eaab8510e1ed Mon Sep 17 00:00:00 2001 From: idiotsj Date: Thu, 26 Mar 2026 21:42:22 +0800 Subject: [PATCH 6/9] refactor(image): simplify image source handling branches --- astrbot/core/provider/entities.py | 11 ++++++----- astrbot/core/provider/sources/openai_source.py | 11 ++++++----- astrbot/core/utils/io.py | 2 +- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 281abf3e97..1adeaac09c 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -186,11 +186,12 @@ async def assemble_context(self) -> dict: # 3. 图片内容 if self.image_urls: for image_url in self.image_urls: - if image_url.startswith("http"): - image_path = await download_image_by_url(image_url) - image_data = await self._encode_image_bs64(image_path) - else: - image_data = await self._encode_image_bs64(image_url) + image_source = ( + await download_image_by_url(image_url) + if image_url.startswith("http") + else image_url + ) + image_data = await self._encode_image_bs64(image_source) if not image_data: logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") continue diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index ffbf4d8a30..07fbaa0432 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -923,11 +923,12 @@ async def assemble_context( """组装成符合 OpenAI 格式的 role 为 user 的消息段""" async def resolve_image_part(image_url: str) -> dict | None: - if image_url.startswith("http"): - image_path = await download_image_by_url(image_url) - image_data = await self.encode_image_bs64(image_path) - else: - image_data = await self.encode_image_bs64(image_url) + image_source = ( + await download_image_by_url(image_url) + if image_url.startswith("http") + else image_url + ) + image_data = await self.encode_image_bs64(image_source) if not image_data: logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") return None diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index a8d93e4263..18217eeada 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -254,7 +254,7 @@ def image_source_to_data_uri(image_source: str) -> tuple[str, str]: image_bytes = base64.b64decode(raw_base64) mime_type = detect_image_mime_type(image_bytes) except Exception: - mime_type = DEFAULT_IMAGE_MIME_TYPE + pass return f"data:{mime_type};base64,{raw_base64}", mime_type if lower_source.startswith("file://"): From 2e9506f60a5f4742559e0bc30585c719e5be45ee Mon Sep 17 00:00:00 2001 From: idiotsj Date: Thu, 26 Mar 2026 23:28:14 +0800 Subject: [PATCH 7/9] refactor(image): align http scheme checks and remove dead branches --- astrbot/core/provider/entities.py | 12 +++---- .../core/provider/sources/openai_source.py | 11 ++++--- astrbot/core/utils/io.py | 7 ++++- tests/test_openai_source.py | 30 ++++++++++++++++++ tests/unit/test_astr_main_agent.py | 31 +++++++++++++++++++ 5 files changed, 79 insertions(+), 12 deletions(-) diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 1adeaac09c..ee29b6c6bc 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -10,7 +10,6 @@ from openai.types.chat.chat_completion import ChatCompletion import astrbot.core.message.components as Comp -from astrbot import logger from astrbot.core.agent.message import ( AssistantMessageSegment, ContentPart, @@ -20,7 +19,11 @@ from astrbot.core.agent.tool import ToolSet from astrbot.core.db.po import Conversation from astrbot.core.message.message_event_result import MessageChain -from astrbot.core.utils.io import download_image_by_url, image_source_to_data_uri +from astrbot.core.utils.io import ( + download_image_by_url, + image_source_to_data_uri, + is_http_or_https_url, +) class ProviderType(enum.Enum): @@ -188,13 +191,10 @@ async def assemble_context(self) -> dict: for image_url in self.image_urls: image_source = ( await download_image_by_url(image_url) - if image_url.startswith("http") + if is_http_or_https_url(image_url) else image_url ) image_data = await self._encode_image_bs64(image_source) - if not image_data: - logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") - continue content_blocks.append( {"type": "image_url", "image_url": {"url": image_data}}, ) diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 07fbaa0432..8b1639e071 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -21,7 +21,11 @@ from astrbot.core.agent.tool import ToolSet from astrbot.core.message.message_event_result import MessageChain from astrbot.core.provider.entities import LLMResponse, TokenUsage, ToolCallsResult -from astrbot.core.utils.io import download_image_by_url, image_source_to_data_uri +from astrbot.core.utils.io import ( + download_image_by_url, + image_source_to_data_uri, + is_http_or_https_url, +) from astrbot.core.utils.network_utils import ( create_proxy_client, is_connection_error, @@ -925,13 +929,10 @@ async def assemble_context( async def resolve_image_part(image_url: str) -> dict | None: image_source = ( await download_image_by_url(image_url) - if image_url.startswith("http") + if is_http_or_https_url(image_url) else image_url ) image_data = await self.encode_image_bs64(image_source) - if not image_data: - logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") - return None return { "type": "image_url", "image_url": {"url": image_data}, diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index 18217eeada..23186a7a1f 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -211,6 +211,11 @@ def file_to_base64(file_path: str) -> str: DEFAULT_IMAGE_MIME_TYPE = "image/jpeg" +def is_http_or_https_url(source: str) -> bool: + """Return whether source is a HTTP(S) URL (case-insensitive).""" + return urlsplit(source).scheme.lower() in ("http", "https") + + def detect_image_mime_type(data: bytes) -> str: """根据图片二进制数据的 magic bytes 检测 MIME 类型。""" if data[:8] == b"\x89PNG\r\n\x1a\n": @@ -242,7 +247,7 @@ def image_source_to_data_uri(image_source: str) -> tuple[str, str]: ) return image_source, mime_type - if lower_source.startswith(("http://", "https://")): + if is_http_or_https_url(image_source): raise ValueError( "Remote image URL is not supported in image_source_to_data_uri; download the file before calling this helper.", ) diff --git a/tests/test_openai_source.py b/tests/test_openai_source.py index 1acf884c0a..4f6e760fda 100644 --- a/tests/test_openai_source.py +++ b/tests/test_openai_source.py @@ -654,3 +654,33 @@ async def test_openai_assemble_context_extra_image_file_uri_mime(tmp_path: Path) assert image_part["image_url"]["url"].startswith("data:image/png;base64,") finally: await provider.terminate() + + +@pytest.mark.asyncio +async def test_openai_assemble_context_uppercase_https_image_url( + tmp_path: Path, monkeypatch +): + provider = _make_provider() + png_path = tmp_path / "remote.png" + png_path.write_bytes(PNG_BYTES) + + async def fake_download(url: str) -> str: + assert url == "HTTPS://example.com/asset.png" + return str(png_path) + + monkeypatch.setattr( + "astrbot.core.provider.sources.openai_source.download_image_by_url", + fake_download, + ) + try: + assembled = await provider.assemble_context( + text="hello", + image_urls=["HTTPS://example.com/asset.png"], + ) + + image_part = next( + part for part in assembled["content"] if part.get("type") == "image_url" + ) + assert image_part["image_url"]["url"].startswith("data:image/png;base64,") + finally: + await provider.terminate() diff --git a/tests/unit/test_astr_main_agent.py b/tests/unit/test_astr_main_agent.py index 5eeb2f8536..0917d65707 100644 --- a/tests/unit/test_astr_main_agent.py +++ b/tests/unit/test_astr_main_agent.py @@ -755,6 +755,37 @@ async def test_provider_request_assemble_context_image_mime_local_and_file_uri( assert image_urls[1].startswith("data:image/webp;base64,") assert image_urls[2].startswith("data:image/png;base64,") + @pytest.mark.asyncio + async def test_provider_request_assemble_context_uppercase_https_image_url( + self, tmp_path, monkeypatch + ): + png_path = tmp_path / "request-upper.png" + png_path.write_bytes(PNG_BYTES) + + async def fake_download(url: str) -> str: + assert url == "HTTPS://example.com/request.png" + return str(png_path) + + monkeypatch.setattr( + "astrbot.core.provider.entities.download_image_by_url", + fake_download, + ) + + req = ProviderRequest( + prompt="Hello", + image_urls=["HTTPS://example.com/request.png"], + ) + + assembled = await req.assemble_context() + assert isinstance(assembled["content"], list) + image_urls = [ + part["image_url"]["url"] + for part in assembled["content"] + if part.get("type") == "image_url" + ] + assert len(image_urls) == 1 + assert image_urls[0].startswith("data:image/png;base64,") + class TestSanitizeContextByModalities: """Tests for _sanitize_context_by_modalities function.""" From 63332139f8b0402db08d65928485bfab568d7410 Mon Sep 17 00:00:00 2001 From: idiotsj Date: Thu, 26 Mar 2026 23:43:09 +0800 Subject: [PATCH 8/9] fix(image): narrow base64 decode error handling --- astrbot/core/utils/io.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index 23186a7a1f..f9e115d956 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -1,4 +1,5 @@ import base64 +import binascii import logging import os import shutil @@ -257,9 +258,15 @@ def image_source_to_data_uri(image_source: str) -> tuple[str, str]: mime_type = DEFAULT_IMAGE_MIME_TYPE try: image_bytes = base64.b64decode(raw_base64) - mime_type = detect_image_mime_type(image_bytes) - except Exception: - pass + except (binascii.Error, ValueError) as exc: + logger.debug( + "Failed to decode base64 image source, fallback to %s: %s", + DEFAULT_IMAGE_MIME_TYPE, + exc, + ) + return f"data:{mime_type};base64,{raw_base64}", mime_type + + mime_type = detect_image_mime_type(image_bytes) return f"data:{mime_type};base64,{raw_base64}", mime_type if lower_source.startswith("file://"): From 7899320e9bfe9c3187dd97bdfd5848ad4967741f Mon Sep 17 00:00:00 2001 From: idiotsj Date: Thu, 26 Mar 2026 23:49:35 +0800 Subject: [PATCH 9/9] fix(image): make base64 scheme matching case-insensitive --- astrbot/core/utils/io.py | 6 +++--- tests/test_openai_source.py | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index f9e115d956..282d07d406 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -253,11 +253,11 @@ def image_source_to_data_uri(image_source: str) -> tuple[str, str]: "Remote image URL is not supported in image_source_to_data_uri; download the file before calling this helper.", ) - if image_source.startswith("base64://"): - raw_base64 = image_source.removeprefix("base64://") + if lower_source.startswith("base64://"): + raw_base64 = image_source[len("base64://") :] mime_type = DEFAULT_IMAGE_MIME_TYPE try: - image_bytes = base64.b64decode(raw_base64) + image_bytes = base64.b64decode(raw_base64, validate=True) except (binascii.Error, ValueError) as exc: logger.debug( "Failed to decode base64 image source, fallback to %s: %s", diff --git a/tests/test_openai_source.py b/tests/test_openai_source.py index 4f6e760fda..4350f72cd1 100644 --- a/tests/test_openai_source.py +++ b/tests/test_openai_source.py @@ -552,10 +552,12 @@ async def test_openai_encode_image_bs64_detects_base64_mime(): provider = _make_provider() try: png_data = await provider.encode_image_bs64(f"base64://{PNG_BASE64}") + png_data_upper = await provider.encode_image_bs64(f"BASE64://{PNG_BASE64}") gif_data = await provider.encode_image_bs64(f"base64://{GIF_BASE64}") webp_data = await provider.encode_image_bs64(f"base64://{WEBP_BASE64}") assert png_data.startswith("data:image/png;base64,") + assert png_data_upper.startswith("data:image/png;base64,") assert gif_data.startswith("data:image/gif;base64,") assert webp_data.startswith("data:image/webp;base64,") finally: