From a105fa235fb4273295f7ac242237c7ee558eccdd Mon Sep 17 00:00:00 2001 From: Tz-WIND <154044203+Tz-WIND@users.noreply.github.com> Date: Fri, 27 Mar 2026 11:44:14 +0800 Subject: [PATCH 1/8] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E5=90=8E=E9=87=8D?= =?UTF-8?q?=E6=96=B0=E4=B8=8A=E4=BC=A0commit?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/utils/image_utils.py | 106 ++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 astrbot/core/utils/image_utils.py diff --git a/astrbot/core/utils/image_utils.py b/astrbot/core/utils/image_utils.py new file mode 100644 index 0000000000..1d10ec65e3 --- /dev/null +++ b/astrbot/core/utils/image_utils.py @@ -0,0 +1,106 @@ +"""图片处理工具函数。 +提供 MIME 类型检测与 base64 Data URL 编码等公共能力, +供 ProviderRequest 及各 Provider 适配器复用,避免重复实现。 +""" + +from __future__ import annotations + +import base64 + + +def detect_image_mime_type(header_bytes: bytes) -> str: + """根据文件头magic bytes检测图片的实际 MIME 类型。 + + 依次匹配常见图片格式的文件头特征,均不匹配时回退到 image/jpeg + 以保持向后兼容。支持的格式:JPEG、PNG、GIF、WebP、BMP、TIFF、 + ICO、SVG、AVIF、HEIF/HEIC。 + + Args: + header_bytes: 文件头原始字节。SVG检测需要至少 256 字节; + 其他二进制格式最多需要 16 字节。 + + Returns: + 对应的 MIME 类型字符串,例如 "image/png"。 + """ + if len(header_bytes) >= 3 and header_bytes[:3] == b'\xff\xd8\xff': + return "image/jpeg" + if len(header_bytes) >= 8 and header_bytes[:8] == b'\x89PNG\r\n\x1a\n': + return "image/png" + if len(header_bytes) >= 4 and header_bytes[:4] == b'GIF8': + return "image/gif" + # WebP: RIFF????WEBP + if len(header_bytes) >= 12 and header_bytes[:4] == b'RIFF' and header_bytes[8:12] == b'WEBP': + return "image/webp" + if len(header_bytes) >= 2 and header_bytes[:2] == b'BM': + return "image/bmp" + # TIFF: 小端 (II) 或大端 (MM) + if len(header_bytes) >= 4 and header_bytes[:4] in (b'II\x2a\x00', b'MM\x00\x2a'): + return "image/tiff" + if len(header_bytes) >= 4 and header_bytes[:4] == b'\x00\x00\x01\x00': + return "image/x-icon" + # SVG 为文本格式,检测头部是否含有 = 12 and header_bytes[4:12] == b'ftypavif': + return "image/avif" + # HEIF/HEIC: ftyp box,brand 为 heic/heix/hevc/hevx/mif1 + if len(header_bytes) >= 12 and header_bytes[4:8] == b'ftyp': + brand = header_bytes[8:12] + if brand in (b'heic', b'heix', b'hevc', b'hevx', b'mif1'): + return "image/heif" + # 无法识别,回退到 image/jpeg + return "image/jpeg" + + +# SVG 检测所需的最小头部字节数。 +# 其他二进制格式的魔术字节最多 16 字节,SVG 需要更多以跳过可能的 XML 声明。 +_HEADER_READ_SIZE = 256 + +# 对应 _HEADER_READ_SIZE 字节所需的 base64 字符数。 +# base64 每 3 字节编码为 4 字符,向上取整后再补充少量余量保证填充对齐。 +# ceil(256 / 3) * 4 = 344 +_BASE64_SAMPLE_CHARS = 344 + + +async def encode_image_to_base64_url(image_url: str) -> str: + """将图片转换为 base64 Data URL,自动检测实际 MIME 类型。 + + 原实现硬编码 image/jpeg,会导致 PNG 等格式在严格校验的接口上报错。 + 现通过读取文件头魔术字节来推断正确的 MIME 类型。 + + 对于 SVG 文件,需要读取至少 256 字节的头部才能可靠检测, + 因此统一将读取/解码量提升到 256 字节,消除之前字节数不足的问题。 + + Args: + image_url: 本地文件路径,或以 "base64://" 开头的 base64 字符串。 + + Returns: + 形如 "data:image/png;base64,..." 的 Data URL 字符串。 + """ + if image_url.startswith("base64://"): + raw_b64 = image_url[len("base64://"):] + # 从 base64 数据中解码足量字节以检测实际格式。 + # 取前 344 个 base64 字符可解码出约 258 字节,足以覆盖 SVG 检测所需的 256 字节。 + try: + sample = raw_b64[:_BASE64_SAMPLE_CHARS] + # 确保 base64 填充正确,避免解码报错 + missing_padding = len(sample) % 4 + if missing_padding: + sample += '=' * (4 - missing_padding) + header_bytes = base64.b64decode(sample) + mime_type = detect_image_mime_type(header_bytes) + except Exception: + # 解码失败时安全回退 + mime_type = "image/jpeg" + return f"data:{mime_type};base64,{raw_b64}" + + with open(image_url, "rb") as f: + # 读取 256 字节用于格式检测,以支持需要较多头部数据的 SVG 等格式, + # 再 seek 回起点读取完整内容进行 base64 编码。 + header_bytes = f.read(_HEADER_READ_SIZE) + mime_type = detect_image_mime_type(header_bytes) + f.seek(0) + image_bs64 = base64.b64encode(f.read()).decode("utf-8") + return f"data:{mime_type};base64,{image_bs64}" \ No newline at end of file From dc5b408f0db9a93aab46ef3400fd6b58aaf0ae9b Mon Sep 17 00:00:00 2001 From: Tz-WIND <154044203+Tz-WIND@users.noreply.github.com> Date: Fri, 27 Mar 2026 11:45:37 +0800 Subject: [PATCH 2/8] Remove unused import and add image encoding utility --- astrbot/core/provider/entities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 20c5a7947d..62c6b82ee3 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -1,6 +1,5 @@ from __future__ import annotations -import base64 import enum import json from dataclasses import dataclass, field @@ -21,6 +20,7 @@ from astrbot.core.agent.tool import ToolSet from astrbot.core.db.po import Conversation from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.utils.image_utils import encode_image_to_base64_url from astrbot.core.utils.io import download_image_by_url From 678c0089f6787b25bc274d89daa0ed4e910a870f Mon Sep 17 00:00:00 2001 From: Tz-WIND <154044203+Tz-WIND@users.noreply.github.com> Date: Fri, 27 Mar 2026 13:01:32 +0800 Subject: [PATCH 3/8] Refactor OpenAI source and improve client handling Refactor OpenAI source code by removing unused methods and improving image handling logic. Ensure client initialization is more robust and consistent. --- .../core/provider/sources/openai_source.py | 320 +++++------------- 1 file changed, 89 insertions(+), 231 deletions(-) diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index cdad66a22f..2e61a47553 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -1,15 +1,11 @@ import asyncio import base64 -import copy import inspect import json import random import re from collections.abc import AsyncGenerator -from io import BytesIO -from pathlib import Path -from typing import Any, Literal -from urllib.parse import unquote, urlparse +from typing import Any import httpx from openai import AsyncAzureOpenAI, AsyncOpenAI @@ -18,8 +14,6 @@ from openai.types.chat.chat_completion import ChatCompletion from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from openai.types.completion_usage import CompletionUsage -from PIL import Image as PILImage -from PIL import UnidentifiedImageError import astrbot.core.message.components as Comp from astrbot import logger @@ -35,6 +29,7 @@ log_connection_failure, ) from astrbot.core.utils.string_utils import normalize_and_dedupe_strings +from astrbot.core.utils.image_utils import encode_image_to_base64_url from ..register import register_provider_adapter @@ -139,186 +134,6 @@ def _context_contains_image(contexts: list[dict]) -> bool: return True return False - def _is_invalid_attachment_error(self, error: Exception) -> bool: - body = getattr(error, "body", None) - code: str | None = None - message: str | None = None - if isinstance(body, dict): - err_obj = body.get("error") - if isinstance(err_obj, dict): - raw_code = err_obj.get("code") - raw_message = err_obj.get("message") - code = raw_code.lower() if isinstance(raw_code, str) else None - message = raw_message.lower() if isinstance(raw_message, str) else None - - if code == "invalid_attachment": - return True - - text_sources: list[str] = [] - if message: - text_sources.append(message) - if code: - text_sources.append(code) - text_sources.extend(map(str, self._extract_error_text_candidates(error))) - - error_text = " ".join(text.lower() for text in text_sources if text) - if "invalid_attachment" in error_text: - return True - if "download attachment" in error_text and "404" in error_text: - return True - return False - - @classmethod - def _encode_image_file_to_data_url( - cls, - image_path: str, - *, - mode: Literal["safe", "strict"], - ) -> str | None: - try: - image_bytes = Path(image_path).read_bytes() - except OSError: - if mode == "strict": - raise - return None - - try: - with PILImage.open(BytesIO(image_bytes)) as image: - image.verify() - image_format = str(image.format or "").upper() - except (OSError, UnidentifiedImageError): - if mode == "strict": - raise ValueError(f"Invalid image file: {image_path}") - return None - - mime_type = { - "JPEG": "image/jpeg", - "PNG": "image/png", - "GIF": "image/gif", - "WEBP": "image/webp", - "BMP": "image/bmp", - }.get(image_format, "image/jpeg") - image_bs64 = base64.b64encode(image_bytes).decode("utf-8") - return f"data:{mime_type};base64,{image_bs64}" - - @staticmethod - def _file_uri_to_path(file_uri: str) -> str: - """Normalize file URIs to paths. - - `file://localhost/...` and drive-letter forms are treated as local paths. - Other non-empty hosts are preserved as UNC-style paths. - """ - parsed = urlparse(file_uri) - if parsed.scheme != "file": - return file_uri - - netloc = unquote(parsed.netloc or "") - path = unquote(parsed.path or "") - if re.fullmatch(r"[A-Za-z]:", netloc): - return str(Path(f"{netloc}{path}")) - if re.match(r"^/[A-Za-z]:/", path): - path = path[1:] - if netloc and netloc != "localhost": - path = f"//{netloc}{path}" - return str(Path(path)) - - async def _image_ref_to_data_url( - self, - image_ref: str, - *, - mode: Literal["safe", "strict"] = "safe", - ) -> str | None: - if image_ref.startswith("base64://"): - return image_ref.replace("base64://", "data:image/jpeg;base64,") - - if image_ref.startswith("http"): - image_path = await download_image_by_url(image_ref) - elif image_ref.startswith("file://"): - image_path = self._file_uri_to_path(image_ref) - else: - image_path = image_ref - - return self._encode_image_file_to_data_url( - image_path, - mode=mode, - ) - - async def _resolve_image_part( - self, - image_url: str, - *, - image_detail: str | None = None, - ) -> dict | None: - if image_url.startswith("data:"): - image_payload = {"url": image_url} - else: - image_data = await self._image_ref_to_data_url(image_url, mode="safe") - if not image_data: - logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") - return None - image_payload = {"url": image_data} - - if image_detail: - image_payload["detail"] = image_detail - return { - "type": "image_url", - "image_url": image_payload, - } - - def _extract_image_part_info(self, part: dict) -> tuple[str | None, str | None]: - if not isinstance(part, dict) or part.get("type") != "image_url": - return None, None - - image_url_data = part.get("image_url") - if not isinstance(image_url_data, dict): - logger.warning("图片内容块格式无效,将保留原始内容。") - return None, None - - url = image_url_data.get("url") - if not isinstance(url, str) or not url: - logger.warning("图片内容块缺少有效 URL,将保留原始内容。") - return None, None - - image_detail = image_url_data.get("detail") - if not isinstance(image_detail, str): - image_detail = None - return url, image_detail - - async def _transform_content_part(self, part: dict) -> dict: - url, image_detail = self._extract_image_part_info(part) - if not url: - return part - - try: - resolved_part = await self._resolve_image_part( - url, image_detail=image_detail - ) - except Exception as exc: - logger.warning( - "图片 %s 预处理失败,将保留原始内容。错误: %s", - url, - exc, - ) - return part - - return resolved_part or part - - async def _materialize_message_image_parts(self, message: dict) -> dict: - content = message.get("content") - if not isinstance(content, list): - return {**message} - - new_content = [await self._transform_content_part(part) for part in content] - return {**message, "content": new_content} - - async def _materialize_context_image_parts( - self, context_query: list[dict] - ) -> list[dict]: - return [ - await self._materialize_message_image_parts(message) - for message in context_query - ] - async def _fallback_to_text_only_and_retry( self, payloads: dict, @@ -367,34 +182,62 @@ def __init__(self, provider_config, provider_settings) -> None: for key in self.custom_headers: self.custom_headers[key] = str(self.custom_headers[key]) - if "api_version" in provider_config: - # Using Azure OpenAI API - self.client = AsyncAzureOpenAI( + # 创建 client + self.client = self._create_openai_client() + + self.default_params = inspect.signature( + self.client.chat.completions.create, + ).parameters.keys() + + model = provider_config.get("model", "unknown") + self.set_model(model) + + self.reasoning_key = "reasoning_content" + + def _create_openai_client(self) -> AsyncOpenAI | AsyncAzureOpenAI: + """创建 OpenAI client 实例(统一初始化逻辑)。 + + 解耦 client 初始化,便于 reload 时重建。 + """ + if "api_version" in self.provider_config: + # Azure OpenAI + return AsyncAzureOpenAI( api_key=self.chosen_api_key, - api_version=provider_config.get("api_version", None), + api_version=self.provider_config.get("api_version", None), default_headers=self.custom_headers, - base_url=provider_config.get("api_base", ""), + base_url=self.provider_config.get("api_base", ""), timeout=self.timeout, - http_client=self._create_http_client(provider_config), + http_client=self._create_http_client(self.provider_config), ) else: - # Using OpenAI Official API - self.client = AsyncOpenAI( + # OpenAI Official + return AsyncOpenAI( api_key=self.chosen_api_key, - base_url=provider_config.get("api_base", None), + base_url=self.provider_config.get("api_base", None), default_headers=self.custom_headers, timeout=self.timeout, - http_client=self._create_http_client(provider_config), + http_client=self._create_http_client(self.provider_config), ) - self.default_params = inspect.signature( - self.client.chat.completions.create, - ).parameters.keys() - - model = provider_config.get("model", "unknown") - self.set_model(model) - - self.reasoning_key = "reasoning_content" + def _ensure_client(self) -> None: + """确保 client 可用,若已关闭则重建。 + + 多层防御:None 检查 + 底层连接检查。 + 注:访问 self.client._client.is_closed 依赖 openai 库内部实现, + 若 SDK 提供公开 API 应及时替换。 + """ + if self.client is None: + self.client = self._create_openai_client() + return + + # 检查底层 httpx.AsyncClient 是否已关闭 + try: + if hasattr(self.client, '_client') and hasattr(self.client._client, 'is_closed'): + if self.client._client.is_closed: + self.client = self._create_openai_client() + except Exception: + # 若检查失败,安全起见重建 client + self.client = self._create_openai_client() def _ollama_disable_thinking_enabled(self) -> bool: value = self.provider_config.get("ollama_disable_thinking", False) @@ -417,6 +260,7 @@ def _apply_provider_specific_extra_body_overrides( extra_body["reasoning_effort"] = "none" async def get_models(self): + self._ensure_client() try: models_str = [] models = await self.client.models.list() @@ -428,6 +272,7 @@ async def get_models(self): raise Exception(f"获取模型列表失败:{e}") async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: + self._ensure_client() if tools: model = payloads.get("model", "").lower() omit_empty_param_field = "gemini" in model @@ -478,6 +323,7 @@ async def _query_stream( tools: ToolSet | None, ) -> AsyncGenerator[LLMResponse, None]: """流式查询API,逐步返回结果""" + self._ensure_client() if tools: model = payloads.get("model", "").lower() omit_empty_param_field = "gemini" in model @@ -790,7 +636,7 @@ async def _prepare_chat_payload( new_record = await self.assemble_context( prompt, image_urls, extra_user_content_parts ) - context_query = copy.deepcopy(self._ensure_message_to_dicts(contexts)) + context_query = self._ensure_message_to_dicts(contexts) if new_record: context_query.append(new_record) if system_prompt: @@ -808,9 +654,6 @@ async def _prepare_chat_payload( for tcr in tool_calls_result: context_query.extend(tcr.to_openai_messages()) - if self._context_contains_image(context_query): - context_query = await self._materialize_context_image_parts(context_query) - model = model or self.get_model() payloads = {"messages": context_query, "model": model} @@ -911,18 +754,6 @@ async def _handle_api_error( "image_content_moderated", image_fallback_used=True, ) - if self._is_invalid_attachment_error(e): - if image_fallback_used or not self._context_contains_image(context_query): - raise e - return await self._fallback_to_text_only_and_retry( - payloads, - context_query, - chosen_key, - available_api_keys, - func_tool, - "invalid_attachment", - image_fallback_used=True, - ) if ( "Function calling is not enabled" in str(e) @@ -1124,6 +955,23 @@ async def assemble_context( ) -> dict: """组装成符合 OpenAI 格式的 role 为 user 的消息段""" + async def resolve_image_part(image_url: str) -> dict | None: + if image_url.startswith("http"): + image_path = await download_image_by_url(image_url) + image_data = await self.encode_image_bs64(image_path) + elif image_url.startswith("file:///"): + image_path = image_url.replace("file:///", "") + image_data = await self.encode_image_bs64(image_path) + else: + image_data = await self.encode_image_bs64(image_url) + if not image_data: + logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") + return None + return { + "type": "image_url", + "image_url": {"url": image_data}, + } + # 构建内容块列表 content_blocks = [] @@ -1143,9 +991,7 @@ async def assemble_context( if isinstance(part, TextPart): content_blocks.append({"type": "text", "text": part.text}) elif isinstance(part, ImageURLPart): - image_part = await self._resolve_image_part( - part.image_url.url, - ) + image_part = await resolve_image_part(part.image_url.url) if image_part: content_blocks.append(image_part) else: @@ -1154,7 +1000,7 @@ async def assemble_context( # 3. 图片内容 if image_urls: for image_url in image_urls: - image_part = await self._resolve_image_part(image_url) + image_part = await resolve_image_part(image_url) if image_part: content_blocks.append(image_part) @@ -1172,12 +1018,24 @@ async def assemble_context( return {"role": "user", "content": content_blocks} async def encode_image_bs64(self, image_url: str) -> str: - """将图片转换为 base64""" - image_data = await self._image_ref_to_data_url(image_url, mode="strict") - if image_data is None: - raise RuntimeError(f"Failed to encode image data: {image_url}") - return image_data + """将图片转换为 base64 Data URL,自动检测实际 MIME 类型。 + + 委托给公共工具函数 encode_image_to_base64_url 实现, + 消除与 ProviderRequest 之间的重复逻辑。原实现硬编码 + image/jpeg,会导致 PNG 等格式在严格校验的接口上报错。 + + Args: + image_url: 本地文件路径,或以 "base64://" 开头的 base64 字符串。 + + Returns: + 形如 "data:image/png;base64,..." 的 Data URL 字符串。 + """ + return await encode_image_to_base64_url(image_url) async def terminate(self): + """关闭 client 并清空引用,防止 reload 时复用已关闭连接。""" if self.client: - await self.client.close() + try: + await self.client.close() + finally: + self.client = None From 6eb3295b976d466f2c10742c9f87d306fdbd67d8 Mon Sep 17 00:00:00 2001 From: Tz-WIND <154044203+Tz-WIND@users.noreply.github.com> Date: Fri, 27 Mar 2026 15:17:53 +0800 Subject: [PATCH 4/8] Refactor image handling and client initialization --- .../core/provider/sources/openai_source.py | 461 ++++++++++++++---- 1 file changed, 379 insertions(+), 82 deletions(-) diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 2e61a47553..c4f5f565fd 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -1,11 +1,15 @@ import asyncio import base64 +import copy import inspect import json import random import re from collections.abc import AsyncGenerator -from typing import Any +from io import BytesIO +from pathlib import Path +from typing import Any, Literal +from urllib.parse import unquote, urlparse import httpx from openai import AsyncAzureOpenAI, AsyncOpenAI @@ -14,6 +18,8 @@ from openai.types.chat.chat_completion import ChatCompletion from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from openai.types.completion_usage import CompletionUsage +from PIL import Image as PILImage +from PIL import UnidentifiedImageError import astrbot.core.message.components as Comp from astrbot import logger @@ -29,11 +35,21 @@ log_connection_failure, ) from astrbot.core.utils.string_utils import normalize_and_dedupe_strings -from astrbot.core.utils.image_utils import encode_image_to_base64_url +from astrbot.core.utils.image_utils import detect_image_mime_type from ..register import register_provider_adapter +# SVG 检测所需的最小头部字节数。 +# 其他二进制格式的魔术字节最多 16 字节,SVG 需要更多以跳过可能的 XML 声明。 +_HEADER_READ_SIZE = 256 + +# 对应 _HEADER_READ_SIZE 字节所需的 base64 字符数。 +# base64 每 3 字节编码为 4 字符,向上取整后再补充少量余量保证填充对齐。 +# ceil(256 / 3) * 4 = 344 +_BASE64_SAMPLE_CHARS = 344 + + @register_provider_adapter( "openai_chat_completion", "OpenAI API Chat Completion 提供商适配器", @@ -134,6 +150,275 @@ def _context_contains_image(contexts: list[dict]) -> bool: return True return False + def _is_invalid_attachment_error(self, error: Exception) -> bool: + body = getattr(error, "body", None) + code: str | None = None + message: str | None = None + if isinstance(body, dict): + err_obj = body.get("error") + if isinstance(err_obj, dict): + raw_code = err_obj.get("code") + raw_message = err_obj.get("message") + code = raw_code.lower() if isinstance(raw_code, str) else None + message = raw_message.lower() if isinstance(raw_message, str) else None + + if code == "invalid_attachment": + return True + + text_sources: list[str] = [] + if message: + text_sources.append(message) + if code: + text_sources.append(code) + text_sources.extend(map(str, self._extract_error_text_candidates(error))) + + error_text = " ".join(text.lower() for text in text_sources if text) + if "invalid_attachment" in error_text: + return True + if "download attachment" in error_text and "404" in error_text: + return True + return False + + @classmethod + def _detect_mime_type_with_fallback( + cls, + image_bytes: bytes, + source_hint: str = "", + ) -> str: + """检测图片的实际 MIME 类型,使用三级回退策略。 + + 检测优先级: + 1. 主检测:基于文件头魔术字节的 detect_image_mime_type(来自 image_utils) + 2. 回退:PIL 库识别 + 3. 回退的回退:硬编码 image/jpeg + + Args: + image_bytes: 图片的原始字节数据。 + source_hint: 可选的来源描述,用于 debug 日志输出。 + + Returns: + 对应的 MIME 类型字符串,例如 "image/png"。 + """ + log_prefix = f"[{source_hint}] " if source_hint else "" + + # --- 第一级:魔术字节检测(主检测) --- + header = image_bytes[:_HEADER_READ_SIZE] + mime_type = detect_image_mime_type(header) + # detect_image_mime_type 在无法识别时回退到 "image/jpeg", + # 因此只有当返回值不是默认回退值时才认为检测成功。 + # 但如果文件头确实是 JPEG(\xff\xd8\xff),那也是真正检测到了, + # 所以额外检查文件头是否是 JPEG 的魔术字节。 + if mime_type != "image/jpeg" or ( + len(header) >= 3 and header[:3] == b'\xff\xd8\xff' + ): + logger.debug( + "%s魔术字节检测命中,识别为 %s 格式", + log_prefix, + mime_type, + ) + return mime_type + + logger.debug( + "%s魔术字节检测未命中,尝试 PIL 回退...", + log_prefix, + ) + + # --- 第二级:PIL 回退 --- + try: + with PILImage.open(BytesIO(image_bytes)) as image: + image.verify() + image_format = str(image.format or "").upper() + pil_mime = { + "JPEG": "image/jpeg", + "PNG": "image/png", + "GIF": "image/gif", + "WEBP": "image/webp", + "BMP": "image/bmp", + }.get(image_format) + if pil_mime: + logger.debug( + "%sPIL 回退命中,识别为 %s 格式(PIL format: %s)", + log_prefix, + pil_mime, + image_format, + ) + return pil_mime + logger.debug( + "%sPIL 识别到格式 %s 但无对应 MIME 映射,继续回退...", + log_prefix, + image_format, + ) + except (OSError, UnidentifiedImageError) as exc: + logger.debug( + "%sPIL 回退失败: %s,继续回退...", + log_prefix, + exc, + ) + + # --- 第三级:硬编码回退 --- + logger.debug( + "%s所有检测均未命中,硬编码回退为 image/jpeg", + log_prefix, + ) + return "image/jpeg" + + @classmethod + def _encode_image_file_to_data_url( + cls, + image_path: str, + *, + mode: Literal["safe", "strict"], + ) -> str | None: + try: + image_bytes = Path(image_path).read_bytes() + except OSError: + if mode == "strict": + raise + return None + + # 使用三级回退策略检测 MIME 类型 + mime_type = cls._detect_mime_type_with_fallback( + image_bytes, source_hint=image_path + ) + + image_bs64 = base64.b64encode(image_bytes).decode("utf-8") + return f"data:{mime_type};base64,{image_bs64}" + + @staticmethod + def _file_uri_to_path(file_uri: str) -> str: + """Normalize file URIs to paths. + + `file://localhost/...` and drive-letter forms are treated as local paths. + Other non-empty hosts are preserved as UNC-style paths. + """ + parsed = urlparse(file_uri) + if parsed.scheme != "file": + return file_uri + + netloc = unquote(parsed.netloc or "") + path = unquote(parsed.path or "") + if re.fullmatch(r"[A-Za-z]:", netloc): + return str(Path(f"{netloc}{path}")) + if re.match(r"^/[A-Za-z]:/", path): + path = path[1:] + if netloc and netloc != "localhost": + path = f"//{netloc}{path}" + return str(Path(path)) + + async def _image_ref_to_data_url( + self, + image_ref: str, + *, + mode: Literal["safe", "strict"] = "safe", + ) -> str | None: + if image_ref.startswith("base64://"): + raw_b64 = image_ref[len("base64://"):] + try: + sample = raw_b64[:_BASE64_SAMPLE_CHARS] + missing_padding = len(sample) % 4 + if missing_padding: + sample += '=' * (4 - missing_padding) + header_bytes = base64.b64decode(sample) + mime_type = detect_image_mime_type(header_bytes) + logger.debug( + "[base64://] 魔术字节检测命中,识别为 %s 格式", + mime_type, + ) + except Exception: + mime_type = "image/jpeg" + logger.debug( + "[base64://] base64 解码失败,硬编码回退为 image/jpeg" + ) + return f"data:{mime_type};base64,{raw_b64}" + + if image_ref.startswith("http"): + image_path = await download_image_by_url(image_ref) + elif image_ref.startswith("file://"): + image_path = self._file_uri_to_path(image_ref) + else: + image_path = image_ref + + return self._encode_image_file_to_data_url( + image_path, + mode=mode, + ) + + async def _resolve_image_part( + self, + image_url: str, + *, + image_detail: str | None = None, + ) -> dict | None: + if image_url.startswith("data:"): + image_payload = {"url": image_url} + else: + image_data = await self._image_ref_to_data_url(image_url, mode="safe") + if not image_data: + logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") + return None + image_payload = {"url": image_data} + + if image_detail: + image_payload["detail"] = image_detail + return { + "type": "image_url", + "image_url": image_payload, + } + + def _extract_image_part_info(self, part: dict) -> tuple[str | None, str | None]: + if not isinstance(part, dict) or part.get("type") != "image_url": + return None, None + + image_url_data = part.get("image_url") + if not isinstance(image_url_data, dict): + logger.warning("图片内容块格式无效,将保留原始内容。") + return None, None + + url = image_url_data.get("url") + if not isinstance(url, str) or not url: + logger.warning("图片内容块缺少有效 URL,将保留原始内容。") + return None, None + + image_detail = image_url_data.get("detail") + if not isinstance(image_detail, str): + image_detail = None + return url, image_detail + + async def _transform_content_part(self, part: dict) -> dict: + url, image_detail = self._extract_image_part_info(part) + if not url: + return part + + try: + resolved_part = await self._resolve_image_part( + url, image_detail=image_detail + ) + except Exception as exc: + logger.warning( + "图片 %s 预处理失败,将保留原始内容。错误: %s", + url, + exc, + ) + return part + + return resolved_part or part + + async def _materialize_message_image_parts(self, message: dict) -> dict: + content = message.get("content") + if not isinstance(content, list): + return {**message} + + new_content = [await self._transform_content_part(part) for part in content] + return {**message, "content": new_content} + + async def _materialize_context_image_parts( + self, context_query: list[dict] + ) -> list[dict]: + return [ + await self._materialize_message_image_parts(message) + for message in context_query + ] + async def _fallback_to_text_only_and_retry( self, payloads: dict, @@ -166,6 +451,59 @@ def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient | None proxy = provider_config.get("proxy", "") return create_proxy_client("OpenAI", proxy) + def _create_openai_client(self) -> AsyncOpenAI | AsyncAzureOpenAI: + """创建 OpenAI/Azure 客户端实例,将初始化逻辑解耦以便复用。""" + if "api_version" in self.provider_config: + # Using Azure OpenAI API + return AsyncAzureOpenAI( + api_key=self.chosen_api_key, + api_version=self.provider_config.get("api_version", None), + default_headers=self.custom_headers, + base_url=self.provider_config.get("api_base", ""), + timeout=self.timeout, + http_client=self._create_http_client(self.provider_config), + ) + else: + # Using OpenAI Official API + return AsyncOpenAI( + api_key=self.chosen_api_key, + base_url=self.provider_config.get("api_base", None), + default_headers=self.custom_headers, + timeout=self.timeout, + http_client=self._create_http_client(self.provider_config), + ) + + def _ensure_client(self) -> None: + """确保 client 可用。如果 client 为 None 或底层连接已关闭,则重新创建。 + + 这提供了多层防御:terminate() 会将 client 置为 None, + 配置重载(reload)期间若复用已关闭的 client 会导致 APIConnectionError, + 此方法在每次使用前检查并自动重建。 + """ + need_reinit = False + if self.client is None: + need_reinit = True + else: + try: + # 注意:此处直接访问了 openai 库的私有属性 `_client`, + # 依赖其内部实现(httpx.AsyncClient 实例暴露的 is_closed 属性)。 + # 这一做法存在脆弱性——若 openai 库未来版本调整了内部结构, + # 此处可能在没有任何报错的情况下静默失效。 + # 目前 openai SDK 尚未提供检查底层连接是否已关闭的公开 API。 + # 若未来 SDK 提供了类似 self.client.is_closed() 的公开方法, + # 应及时将此处替换为对应的公开接口。 + if self.client._client.is_closed: + need_reinit = True + except AttributeError: + pass + + if need_reinit: + logger.warning("检测到 OpenAI client 已关闭或未初始化,正在重新创建...") + self.client = self._create_openai_client() + self.default_params = inspect.signature( + self.client.chat.completions.create, + ).parameters.keys() + def __init__(self, provider_config, provider_settings) -> None: super().__init__(provider_config, provider_settings) self.chosen_api_key = None @@ -182,7 +520,6 @@ def __init__(self, provider_config, provider_settings) -> None: for key in self.custom_headers: self.custom_headers[key] = str(self.custom_headers[key]) - # 创建 client self.client = self._create_openai_client() self.default_params = inspect.signature( @@ -194,51 +531,6 @@ def __init__(self, provider_config, provider_settings) -> None: self.reasoning_key = "reasoning_content" - def _create_openai_client(self) -> AsyncOpenAI | AsyncAzureOpenAI: - """创建 OpenAI client 实例(统一初始化逻辑)。 - - 解耦 client 初始化,便于 reload 时重建。 - """ - if "api_version" in self.provider_config: - # Azure OpenAI - return AsyncAzureOpenAI( - api_key=self.chosen_api_key, - api_version=self.provider_config.get("api_version", None), - default_headers=self.custom_headers, - base_url=self.provider_config.get("api_base", ""), - timeout=self.timeout, - http_client=self._create_http_client(self.provider_config), - ) - else: - # OpenAI Official - return AsyncOpenAI( - api_key=self.chosen_api_key, - base_url=self.provider_config.get("api_base", None), - default_headers=self.custom_headers, - timeout=self.timeout, - http_client=self._create_http_client(self.provider_config), - ) - - def _ensure_client(self) -> None: - """确保 client 可用,若已关闭则重建。 - - 多层防御:None 检查 + 底层连接检查。 - 注:访问 self.client._client.is_closed 依赖 openai 库内部实现, - 若 SDK 提供公开 API 应及时替换。 - """ - if self.client is None: - self.client = self._create_openai_client() - return - - # 检查底层 httpx.AsyncClient 是否已关闭 - try: - if hasattr(self.client, '_client') and hasattr(self.client._client, 'is_closed'): - if self.client._client.is_closed: - self.client = self._create_openai_client() - except Exception: - # 若检查失败,安全起见重建 client - self.client = self._create_openai_client() - def _ollama_disable_thinking_enabled(self) -> bool: value = self.provider_config.get("ollama_disable_thinking", False) if isinstance(value, str): @@ -273,6 +565,7 @@ async def get_models(self): async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: self._ensure_client() + if tools: model = payloads.get("model", "").lower() omit_empty_param_field = "gemini" in model @@ -324,6 +617,7 @@ async def _query_stream( ) -> AsyncGenerator[LLMResponse, None]: """流式查询API,逐步返回结果""" self._ensure_client() + if tools: model = payloads.get("model", "").lower() omit_empty_param_field = "gemini" in model @@ -636,7 +930,7 @@ async def _prepare_chat_payload( new_record = await self.assemble_context( prompt, image_urls, extra_user_content_parts ) - context_query = self._ensure_message_to_dicts(contexts) + context_query = copy.deepcopy(self._ensure_message_to_dicts(contexts)) if new_record: context_query.append(new_record) if system_prompt: @@ -654,6 +948,9 @@ async def _prepare_chat_payload( for tcr in tool_calls_result: context_query.extend(tcr.to_openai_messages()) + if self._context_contains_image(context_query): + context_query = await self._materialize_context_image_parts(context_query) + model = model or self.get_model() payloads = {"messages": context_query, "model": model} @@ -695,7 +992,7 @@ async def _handle_api_error( """处理API错误并尝试恢复""" if "429" in str(e): logger.warning( - f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}", + "API 调用过于频繁,尝试使用其他 Key 重试。" ) # 最后一次不等待 if retry_cnt < max_retries - 1: @@ -754,6 +1051,18 @@ async def _handle_api_error( "image_content_moderated", image_fallback_used=True, ) + if self._is_invalid_attachment_error(e): + if image_fallback_used or not self._context_contains_image(context_query): + raise e + return await self._fallback_to_text_only_and_retry( + payloads, + context_query, + chosen_key, + available_api_keys, + func_tool, + "invalid_attachment", + image_fallback_used=True, + ) if ( "Function calling is not enabled" in str(e) @@ -819,6 +1128,7 @@ async def text_chat( retry_cnt = 0 for retry_cnt in range(max_retries): try: + self._ensure_client() self.client.api_key = chosen_key llm_response = await self._query(payloads, func_tool) break @@ -885,6 +1195,7 @@ async def text_chat_stream( retry_cnt = 0 for retry_cnt in range(max_retries): try: + self._ensure_client() self.client.api_key = chosen_key async for response in self._query_stream(payloads, func_tool): yield response @@ -939,12 +1250,14 @@ async def _remove_image_from_context(self, contexts: list): return new_contexts def get_current_key(self) -> str: + self._ensure_client() return self.client.api_key def get_keys(self) -> list[str]: return self.api_keys def set_key(self, key) -> None: + self._ensure_client() self.client.api_key = key async def assemble_context( @@ -955,23 +1268,6 @@ async def assemble_context( ) -> dict: """组装成符合 OpenAI 格式的 role 为 user 的消息段""" - async def resolve_image_part(image_url: str) -> dict | None: - if image_url.startswith("http"): - image_path = await download_image_by_url(image_url) - image_data = await self.encode_image_bs64(image_path) - elif image_url.startswith("file:///"): - image_path = image_url.replace("file:///", "") - image_data = await self.encode_image_bs64(image_path) - else: - image_data = await self.encode_image_bs64(image_url) - if not image_data: - logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") - return None - return { - "type": "image_url", - "image_url": {"url": image_data}, - } - # 构建内容块列表 content_blocks = [] @@ -991,7 +1287,9 @@ async def resolve_image_part(image_url: str) -> dict | None: if isinstance(part, TextPart): content_blocks.append({"type": "text", "text": part.text}) elif isinstance(part, ImageURLPart): - image_part = await resolve_image_part(part.image_url.url) + image_part = await self._resolve_image_part( + part.image_url.url, + ) if image_part: content_blocks.append(image_part) else: @@ -1000,7 +1298,7 @@ async def resolve_image_part(image_url: str) -> dict | None: # 3. 图片内容 if image_urls: for image_url in image_urls: - image_part = await resolve_image_part(image_url) + image_part = await self._resolve_image_part(image_url) if image_part: content_blocks.append(image_part) @@ -1018,24 +1316,23 @@ async def resolve_image_part(image_url: str) -> dict | None: return {"role": "user", "content": content_blocks} async def encode_image_bs64(self, image_url: str) -> str: - """将图片转换为 base64 Data URL,自动检测实际 MIME 类型。 - - 委托给公共工具函数 encode_image_to_base64_url 实现, - 消除与 ProviderRequest 之间的重复逻辑。原实现硬编码 - image/jpeg,会导致 PNG 等格式在严格校验的接口上报错。 + """将图片转换为 base64""" + image_data = await self._image_ref_to_data_url(image_url, mode="strict") + if image_data is None: + raise RuntimeError(f"Failed to encode image data: {image_url}") + return image_data - Args: - image_url: 本地文件路径,或以 "base64://" 开头的 base64 字符串。 + async def terminate(self): + """关闭 client 并将引用置为 None。 - Returns: - 形如 "data:image/png;base64,..." 的 Data URL 字符串。 + 通过 try/finally 确保即使 close() 抛出异常, + self.client 也会被清空,避免配置重载(reload)期间 + 复用已关闭的 client 导致 APIConnectionError。 """ - return await encode_image_to_base64_url(image_url) - - async def terminate(self): - """关闭 client 并清空引用,防止 reload 时复用已关闭连接。""" if self.client: try: await self.client.close() + except Exception as e: + logger.warning(f"关闭 OpenAI client 时出错: {e}") finally: self.client = None From 1bb72f4df5382994a51e491c3a656b7c31ff2535 Mon Sep 17 00:00:00 2001 From: Tz-WIND <154044203+Tz-WIND@users.noreply.github.com> Date: Fri, 27 Mar 2026 15:19:07 +0800 Subject: [PATCH 5/8] Refactor image encoding to use common utility function --- astrbot/core/provider/entities.py | 47 +++++++++++++++++++------------ 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 62c6b82ee3..a4426ad83c 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -189,41 +189,52 @@ async def assemble_context(self) -> dict: for image_url in self.image_urls: if image_url.startswith("http"): image_path = await download_image_by_url(image_url) - image_data = await self._encode_image_bs64(image_path) + logger.debug( + "[ProviderRequest] 下载远程图片 %s -> %s", + image_url, + image_path, + ) + # 统一通过公共工具函数编码,自动检测 MIME 类型 + image_data = await encode_image_to_base64_url(image_path) elif image_url.startswith("file:///"): image_path = image_url.replace("file:///", "") - image_data = await self._encode_image_bs64(image_path) + logger.debug( + "[ProviderRequest] 读取本地文件 %s", + image_path, + ) + image_data = await encode_image_to_base64_url(image_path) else: - image_data = await self._encode_image_bs64(image_url) + logger.debug( + "[ProviderRequest] 编码图片引用 %s", + image_url[:80], + ) + image_data = await encode_image_to_base64_url(image_url) if not image_data: logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") continue + # image_data 已是 data:;base64,... 格式, + # MIME 类型由 encode_image_to_base64_url 内部检测并输出 debug 日志 + logger.debug( + "[ProviderRequest] 图片编码完成,Data URL 前缀: %s", + image_data[:40], + ) content_blocks.append( {"type": "image_url", "image_url": {"url": image_data}}, ) - # 只有当只有一个来自 prompt 的文本块且没有额外内容块时,才降级为简单格式以保持向后兼容 + # 只有当只有一个来自 prompt 的文本块且没有额外内容块时, + # 才降级为简单格式以保持向后兼容 if ( - len(content_blocks) == 1 - and content_blocks[0]["type"] == "text" - and not self.extra_user_content_parts - and not self.image_urls + len(content_blocks) == 1 + and content_blocks[0]["type"] == "text" + and not self.extra_user_content_parts + and not self.image_urls ): return {"role": "user", "content": content_blocks[0]["text"]} # 否则返回多模态格式 return {"role": "user", "content": content_blocks} - async def _encode_image_bs64(self, image_url: str) -> str: - """将图片转换为 base64""" - if image_url.startswith("base64://"): - return image_url.replace("base64://", "data:image/jpeg;base64,") - with open(image_url, "rb") as f: - image_bs64 = base64.b64encode(f.read()).decode("utf-8") - return "data:image/jpeg;base64," + image_bs64 - return "" - - @dataclass class TokenUsage: input_other: int = 0 From 095c7a5439559f63f93b365327f3273c0d4db3a5 Mon Sep 17 00:00:00 2001 From: Tz-WIND <154044203+Tz-WIND@users.noreply.github.com> Date: Fri, 27 Mar 2026 15:44:57 +0800 Subject: [PATCH 6/8] =?UTF-8?q?=E6=B6=88=E9=99=A4=20base64=20MIME=20?= =?UTF-8?q?=E6=A3=80=E6=B5=8B=E7=9A=84=E9=87=8D=E5=A4=8D=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/provider/sources/openai_source.py | 74 ++++++++----------- 1 file changed, 30 insertions(+), 44 deletions(-) diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index c4f5f565fd..3171308296 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -35,7 +35,10 @@ log_connection_failure, ) from astrbot.core.utils.string_utils import normalize_and_dedupe_strings -from astrbot.core.utils.image_utils import detect_image_mime_type +from astrbot.core.utils.image_utils import ( + detect_image_mime_type, + detect_mime_type_from_base64_str, +) from ..register import register_provider_adapter @@ -306,29 +309,16 @@ def _file_uri_to_path(file_uri: str) -> str: return str(Path(path)) async def _image_ref_to_data_url( - self, - image_ref: str, - *, - mode: Literal["safe", "strict"] = "safe", + self, + image_ref: str, + *, + mode: Literal["safe", "strict"] = "safe", ) -> str | None: if image_ref.startswith("base64://"): raw_b64 = image_ref[len("base64://"):] - try: - sample = raw_b64[:_BASE64_SAMPLE_CHARS] - missing_padding = len(sample) % 4 - if missing_padding: - sample += '=' * (4 - missing_padding) - header_bytes = base64.b64decode(sample) - mime_type = detect_image_mime_type(header_bytes) - logger.debug( - "[base64://] 魔术字节检测命中,识别为 %s 格式", - mime_type, - ) - except Exception: - mime_type = "image/jpeg" - logger.debug( - "[base64://] base64 解码失败,硬编码回退为 image/jpeg" - ) + mime_type = detect_mime_type_from_base64_str( + raw_b64, source_hint="adapter/base64://" + ) return f"data:{mime_type};base64,{raw_b64}" if image_ref.startswith("http"): @@ -343,6 +333,7 @@ async def _image_ref_to_data_url( mode=mode, ) + async def _resolve_image_part( self, image_url: str, @@ -473,31 +464,25 @@ def _create_openai_client(self) -> AsyncOpenAI | AsyncAzureOpenAI: http_client=self._create_http_client(self.provider_config), ) - def _ensure_client(self) -> None: - """确保 client 可用。如果 client 为 None 或底层连接已关闭,则重新创建。 + def _is_underlying_client_closed(self) -> bool: + """集中处理对 openai SDK 私有属性的访问,便于未来替换为公开 API。 - 这提供了多层防御:terminate() 会将 client 置为 None, - 配置重载(reload)期间若复用已关闭的 client 会导致 APIConnectionError, - 此方法在每次使用前检查并自动重建。 + 注意:此处直接访问了 openai 库的私有属性 `_client`, + 依赖其内部实现(httpx.AsyncClient 实例暴露的 is_closed 属性)。 + 这一做法存在脆弱性——若 openai 库未来版本调整了内部结构, + 此处可能在没有任何报错的情况下静默失效。 + 目前 openai SDK 尚未提供检查底层连接是否已关闭的公开 API。 + 若未来 SDK 提供了类似 self.client.is_closed() 的公开方法, + 应及时将此处替换为对应的公开接口。 """ - need_reinit = False - if self.client is None: - need_reinit = True - else: - try: - # 注意:此处直接访问了 openai 库的私有属性 `_client`, - # 依赖其内部实现(httpx.AsyncClient 实例暴露的 is_closed 属性)。 - # 这一做法存在脆弱性——若 openai 库未来版本调整了内部结构, - # 此处可能在没有任何报错的情况下静默失效。 - # 目前 openai SDK 尚未提供检查底层连接是否已关闭的公开 API。 - # 若未来 SDK 提供了类似 self.client.is_closed() 的公开方法, - # 应及时将此处替换为对应的公开接口。 - if self.client._client.is_closed: - need_reinit = True - except AttributeError: - pass - - if need_reinit: + try: + return bool(self.client and self.client._client.is_closed) + except AttributeError: + return False + + def _ensure_client(self) -> None: + """确保 client 可用。如果 client 为 None 或底层连接已关闭,则重新创建。""" + if self.client is None or self._is_underlying_client_closed(): logger.warning("检测到 OpenAI client 已关闭或未初始化,正在重新创建...") self.client = self._create_openai_client() self.default_params = inspect.signature( @@ -1336,3 +1321,4 @@ async def terminate(self): logger.warning(f"关闭 OpenAI client 时出错: {e}") finally: self.client = None + From 6a46b140110b4562db5e11123ce2e1c5b968b09d Mon Sep 17 00:00:00 2001 From: Tz-WIND <154044203+Tz-WIND@users.noreply.github.com> Date: Fri, 27 Mar 2026 15:45:57 +0800 Subject: [PATCH 7/8] Refactor image_utils for async and improved MIME detection Refactor image utility functions to support async operations and improve MIME type detection. Enhance base64 encoding logic for better compatibility with various image formats. --- astrbot/core/utils/image_utils.py | 114 ++++++++++++++++++++++-------- 1 file changed, 85 insertions(+), 29 deletions(-) diff --git a/astrbot/core/utils/image_utils.py b/astrbot/core/utils/image_utils.py index 1d10ec65e3..e2f8c0f702 100644 --- a/astrbot/core/utils/image_utils.py +++ b/astrbot/core/utils/image_utils.py @@ -5,7 +5,11 @@ from __future__ import annotations +import asyncio import base64 +import binascii + +from astrbot import logger def detect_image_mime_type(header_bytes: bytes) -> str: @@ -15,6 +19,8 @@ def detect_image_mime_type(header_bytes: bytes) -> str: 以保持向后兼容。支持的格式:JPEG、PNG、GIF、WebP、BMP、TIFF、 ICO、SVG、AVIF、HEIF/HEIC。 + 注意:此函数为纯检测逻辑,不输出日志,日志由调用方负责。 + Args: header_bytes: 文件头原始字节。SVG检测需要至少 256 字节; 其他二进制格式最多需要 16 字节。 @@ -38,7 +44,7 @@ def detect_image_mime_type(header_bytes: bytes) -> str: return "image/tiff" if len(header_bytes) >= 4 and header_bytes[:4] == b'\x00\x00\x01\x00': return "image/x-icon" - # SVG 为文本格式,检测头部是否含有 str: # SVG 检测所需的最小头部字节数。 -# 其他二进制格式的魔术字节最多 16 字节,SVG 需要更多以跳过可能的 XML 声明。 _HEADER_READ_SIZE = 256 # 对应 _HEADER_READ_SIZE 字节所需的 base64 字符数。 -# base64 每 3 字节编码为 4 字符,向上取整后再补充少量余量保证填充对齐。 # ceil(256 / 3) * 4 = 344 _BASE64_SAMPLE_CHARS = 344 +def detect_mime_type_from_base64_str( + raw_b64: str, + source_hint: str = "", +) -> str: + """从 base64 编码字符串中采样解码头部字节并检测 MIME 类型。 + + 统一所有 base64 输入的 MIME 检测逻辑,避免在多处重复 + 采样/解码/回退的代码。解码失败时安全回退到 image/jpeg。 + + Args: + raw_b64: 原始 base64 编码字符串(不含 "base64://" 前缀)。 + source_hint: 日志中标识调用来源的提示字符串。 + + Returns: + 检测到的 MIME 类型字符串。 + """ + label = source_hint or "base64" + try: + sample = raw_b64[:_BASE64_SAMPLE_CHARS] + missing_padding = len(sample) % 4 + if missing_padding: + sample += '=' * (4 - missing_padding) + header_bytes = base64.b64decode(sample) + except (binascii.Error, ValueError) as exc: + logger.debug( + "[%s] base64 解码失败: %s,硬编码回退为 image/jpeg", + label, + exc, + ) + return "image/jpeg" + + mime_type = detect_image_mime_type(header_bytes) + logger.debug( + "[%s] 魔术字节检测命中,识别为 %s 格式", + label, + mime_type, + ) + return mime_type + + +def _sync_encode_from_file(path: str) -> tuple[str, str]: + """同步读取文件并编码为 base64,同时检测 MIME 类型。 + + 此函数包含阻塞式文件 I/O,设计为通过 run_in_executor + 在线程池中执行,避免阻塞 asyncio 事件循环。 + + Args: + path: 本地文件路径。 + + Returns: + (mime_type, image_bs64) 元组。 + """ + with open(path, "rb") as f: + header_bytes = f.read(_HEADER_READ_SIZE) + mime_type = detect_image_mime_type(header_bytes) + f.seek(0) + image_bs64 = base64.b64encode(f.read()).decode("utf-8") + return mime_type, image_bs64 + + async def encode_image_to_base64_url(image_url: str) -> str: """将图片转换为 base64 Data URL,自动检测实际 MIME 类型。 - 原实现硬编码 image/jpeg,会导致 PNG 等格式在严格校验的接口上报错。 - 现通过读取文件头魔术字节来推断正确的 MIME 类型。 - - 对于 SVG 文件,需要读取至少 256 字节的头部才能可靠检测, - 因此统一将读取/解码量提升到 256 字节,消除之前字节数不足的问题。 + 对于 base64:// 输入,委托 detect_mime_type_from_base64_str + 统一处理采样/解码/回退逻辑。 + 对于文件路径输入,阻塞式文件 I/O 通过 run_in_executor + 移至线程池执行,避免在高并发场景下阻塞 asyncio 事件循环。 Args: image_url: 本地文件路径,或以 "base64://" 开头的 base64 字符串。 @@ -81,26 +144,19 @@ async def encode_image_to_base64_url(image_url: str) -> str: """ if image_url.startswith("base64://"): raw_b64 = image_url[len("base64://"):] - # 从 base64 数据中解码足量字节以检测实际格式。 - # 取前 344 个 base64 字符可解码出约 258 字节,足以覆盖 SVG 检测所需的 256 字节。 - try: - sample = raw_b64[:_BASE64_SAMPLE_CHARS] - # 确保 base64 填充正确,避免解码报错 - missing_padding = len(sample) % 4 - if missing_padding: - sample += '=' * (4 - missing_padding) - header_bytes = base64.b64decode(sample) - mime_type = detect_image_mime_type(header_bytes) - except Exception: - # 解码失败时安全回退 - mime_type = "image/jpeg" + mime_type = detect_mime_type_from_base64_str( + raw_b64, source_hint="image_utils/base64://" + ) return f"data:{mime_type};base64,{raw_b64}" - with open(image_url, "rb") as f: - # 读取 256 字节用于格式检测,以支持需要较多头部数据的 SVG 等格式, - # 再 seek 回起点读取完整内容进行 base64 编码。 - header_bytes = f.read(_HEADER_READ_SIZE) - mime_type = detect_image_mime_type(header_bytes) - f.seek(0) - image_bs64 = base64.b64encode(f.read()).decode("utf-8") - return f"data:{mime_type};base64,{image_bs64}" \ No newline at end of file + # 文件 I/O 为阻塞操作,移至线程池执行以保持事件循环响应性 + loop = asyncio.get_running_loop() + mime_type, image_bs64 = await loop.run_in_executor( + None, _sync_encode_from_file, image_url + ) + logger.debug( + "[image_utils][%s] 魔术字节检测命中,识别为 %s 格式", + image_url, + mime_type, + ) + return f"data:{mime_type};base64,{image_bs64}" From 59ec75a66c55ae7d9b47dc5a2460f9b617ca57ec Mon Sep 17 00:00:00 2001 From: Tz-WIND <154044203+Tz-WIND@users.noreply.github.com> Date: Fri, 27 Mar 2026 16:15:30 +0800 Subject: [PATCH 8/8] Refactor image MIME type detection and encoding Refactor image MIME type detection and encoding logic. Introduce new method to encode image files to base64 Data URLs with improved error handling and fallback mechanisms. --- .../core/provider/sources/openai_source.py | 177 +++++++++--------- 1 file changed, 88 insertions(+), 89 deletions(-) diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 3171308296..f8c8179558 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -36,8 +36,8 @@ ) from astrbot.core.utils.string_utils import normalize_and_dedupe_strings from astrbot.core.utils.image_utils import ( - detect_image_mime_type, detect_mime_type_from_base64_str, + encode_image_to_base64_url, ) from ..register import register_provider_adapter @@ -183,107 +183,97 @@ def _is_invalid_attachment_error(self, error: Exception) -> bool: return False @classmethod - def _detect_mime_type_with_fallback( + def _encode_image_file_to_data_url( cls, - image_bytes: bytes, - source_hint: str = "", - ) -> str: - """检测图片的实际 MIME 类型,使用三级回退策略。 + image_path: str, + *, + mode: Literal["safe", "strict"], + ) -> str | None: + """将本地图片文件编码为 base64 Data URL。 - 检测优先级: - 1. 主检测:基于文件头魔术字节的 detect_image_mime_type(来自 image_utils) - 2. 回退:PIL 库识别 - 3. 回退的回退:硬编码 image/jpeg + 委托给公共工具函数 encode_image_to_base64_url 实现 + 文件读取、MIME 检测与 base64 编码,避免重复逻辑。 Args: - image_bytes: 图片的原始字节数据。 - source_hint: 可选的来源描述,用于 debug 日志输出。 + image_path: 本地文件路径。 + mode: "safe" 模式下读取失败返回 None; + "strict" 模式下读取或校验失败直接抛出异常。 Returns: - 对应的 MIME 类型字符串,例如 "image/png"。 + 形如 "data:image/png;base64,..." 的 Data URL, + 或 None(仅 safe 模式)。 """ - log_prefix = f"[{source_hint}] " if source_hint else "" + try: + # encode_image_to_base64_url 是异步函数,但此处为 classmethod, + # 需在同步上下文中处理文件。直接读取文件字节进行编码。 + image_bytes = Path(image_path).read_bytes() + except OSError: + if mode == "strict": + raise + return None + + if not image_bytes: + if mode == "strict": + raise ValueError(f"图片文件为空: {image_path}") + return None + + # 使用魔术字节检测 MIME 类型 + from astrbot.core.utils.image_utils import ( + _HEADER_READ_SIZE, + detect_image_mime_type, + ) - # --- 第一级:魔术字节检测(主检测) --- header = image_bytes[:_HEADER_READ_SIZE] mime_type = detect_image_mime_type(header) - # detect_image_mime_type 在无法识别时回退到 "image/jpeg", - # 因此只有当返回值不是默认回退值时才认为检测成功。 - # 但如果文件头确实是 JPEG(\xff\xd8\xff),那也是真正检测到了, - # 所以额外检查文件头是否是 JPEG 的魔术字节。 - if mime_type != "image/jpeg" or ( + + # 当魔术字节检测回退到默认值且文件头不是 JPEG 时, + # 尝试 PIL 作为二级检测 + if mime_type == "image/jpeg" and not ( len(header) >= 3 and header[:3] == b'\xff\xd8\xff' ): - logger.debug( - "%s魔术字节检测命中,识别为 %s 格式", - log_prefix, - mime_type, - ) - return mime_type - - logger.debug( - "%s魔术字节检测未命中,尝试 PIL 回退...", - log_prefix, - ) - - # --- 第二级:PIL 回退 --- - try: - with PILImage.open(BytesIO(image_bytes)) as image: - image.verify() - image_format = str(image.format or "").upper() - pil_mime = { - "JPEG": "image/jpeg", - "PNG": "image/png", - "GIF": "image/gif", - "WEBP": "image/webp", - "BMP": "image/bmp", - }.get(image_format) - if pil_mime: + try: + with PILImage.open(BytesIO(image_bytes)) as image: + image.verify() + image_format = str(image.format or "").upper() + pil_mime = { + "JPEG": "image/jpeg", + "PNG": "image/png", + "GIF": "image/gif", + "WEBP": "image/webp", + "BMP": "image/bmp", + }.get(image_format) + if pil_mime: + logger.debug( + "[%s] PIL 回退命中,识别为 %s 格式(PIL format: %s)", + image_path, + pil_mime, + image_format, + ) + mime_type = pil_mime + else: + logger.debug( + "[%s] PIL 识别到格式 %s 但无对应 MIME 映射", + image_path, + image_format, + ) + except (OSError, UnidentifiedImageError) as exc: + # strict 模式下,魔术字节和 PIL 都无法识别,说明不是合法图片 + if mode == "strict": + raise ValueError( + f"无法识别的图片文件: {image_path}" + ) from exc logger.debug( - "%sPIL 回退命中,识别为 %s 格式(PIL format: %s)", - log_prefix, - pil_mime, - image_format, + "[%s] PIL 回退失败: %s,硬编码回退为 image/jpeg", + image_path, + exc, ) - return pil_mime - logger.debug( - "%sPIL 识别到格式 %s 但无对应 MIME 映射,继续回退...", - log_prefix, - image_format, - ) - except (OSError, UnidentifiedImageError) as exc: + else: logger.debug( - "%sPIL 回退失败: %s,继续回退...", - log_prefix, - exc, + "[%s] 魔术字节检测命中,识别为 %s 格式", + image_path, + mime_type, ) - # --- 第三级:硬编码回退 --- - logger.debug( - "%s所有检测均未命中,硬编码回退为 image/jpeg", - log_prefix, - ) - return "image/jpeg" - - @classmethod - def _encode_image_file_to_data_url( - cls, - image_path: str, - *, - mode: Literal["safe", "strict"], - ) -> str | None: - try: - image_bytes = Path(image_path).read_bytes() - except OSError: - if mode == "strict": - raise - return None - - # 使用三级回退策略检测 MIME 类型 - mime_type = cls._detect_mime_type_with_fallback( - image_bytes, source_hint=image_path - ) - image_bs64 = base64.b64encode(image_bytes).decode("utf-8") return f"data:{mime_type};base64,{image_bs64}" @@ -334,6 +324,7 @@ async def _image_ref_to_data_url( ) + async def _resolve_image_part( self, image_url: str, @@ -469,16 +460,24 @@ def _is_underlying_client_closed(self) -> bool: 注意:此处直接访问了 openai 库的私有属性 `_client`, 依赖其内部实现(httpx.AsyncClient 实例暴露的 is_closed 属性)。 - 这一做法存在脆弱性——若 openai 库未来版本调整了内部结构, - 此处可能在没有任何报错的情况下静默失效。 + 若 openai 库未来版本调整了内部结构,此处可能失效。 目前 openai SDK 尚未提供检查底层连接是否已关闭的公开 API。 若未来 SDK 提供了类似 self.client.is_closed() 的公开方法, 应及时将此处替换为对应的公开接口。 + + 当检测逻辑因 SDK 内部结构变更而抛出 AttributeError 时,会: + 1. 记录 warning 日志,提示可能的 SDK 变更; + 2. 保守地视为"已关闭",触发后续的 client 重建逻辑。 """ try: return bool(self.client and self.client._client.is_closed) except AttributeError: - return False + logger.warning( + "无法检测 OpenAI client 是否已关闭," + "可能是 SDK 内部结构变更导致;" + "将保守视为已关闭并触发 client 重建。" + ) + return True def _ensure_client(self) -> None: """确保 client 可用。如果 client 为 None 或底层连接已关闭,则重新创建。""" @@ -1321,4 +1320,4 @@ async def terminate(self): logger.warning(f"关闭 OpenAI client 时出错: {e}") finally: self.client = None - +