diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 20c5a7947d..a4426ad83c 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -1,6 +1,5 @@ from __future__ import annotations -import base64 import enum import json from dataclasses import dataclass, field @@ -21,6 +20,7 @@ from astrbot.core.agent.tool import ToolSet from astrbot.core.db.po import Conversation from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.utils.image_utils import encode_image_to_base64_url from astrbot.core.utils.io import download_image_by_url @@ -189,41 +189,52 @@ async def assemble_context(self) -> dict: for image_url in self.image_urls: if image_url.startswith("http"): image_path = await download_image_by_url(image_url) - image_data = await self._encode_image_bs64(image_path) + logger.debug( + "[ProviderRequest] 下载远程图片 %s -> %s", + image_url, + image_path, + ) + # 统一通过公共工具函数编码,自动检测 MIME 类型 + image_data = await encode_image_to_base64_url(image_path) elif image_url.startswith("file:///"): image_path = image_url.replace("file:///", "") - image_data = await self._encode_image_bs64(image_path) + logger.debug( + "[ProviderRequest] 读取本地文件 %s", + image_path, + ) + image_data = await encode_image_to_base64_url(image_path) else: - image_data = await self._encode_image_bs64(image_url) + logger.debug( + "[ProviderRequest] 编码图片引用 %s", + image_url[:80], + ) + image_data = await encode_image_to_base64_url(image_url) if not image_data: logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") continue + # image_data 已是 data:;base64,... 
格式, + # MIME 类型由 encode_image_to_base64_url 内部检测并输出 debug 日志 + logger.debug( + "[ProviderRequest] 图片编码完成,Data URL 前缀: %s", + image_data[:40], + ) content_blocks.append( {"type": "image_url", "image_url": {"url": image_data}}, ) - # 只有当只有一个来自 prompt 的文本块且没有额外内容块时,才降级为简单格式以保持向后兼容 + # 只有当只有一个来自 prompt 的文本块且没有额外内容块时, + # 才降级为简单格式以保持向后兼容 if ( - len(content_blocks) == 1 - and content_blocks[0]["type"] == "text" - and not self.extra_user_content_parts - and not self.image_urls + len(content_blocks) == 1 + and content_blocks[0]["type"] == "text" + and not self.extra_user_content_parts + and not self.image_urls ): return {"role": "user", "content": content_blocks[0]["text"]} # 否则返回多模态格式 return {"role": "user", "content": content_blocks} - async def _encode_image_bs64(self, image_url: str) -> str: - """将图片转换为 base64""" - if image_url.startswith("base64://"): - return image_url.replace("base64://", "data:image/jpeg;base64,") - with open(image_url, "rb") as f: - image_bs64 = base64.b64encode(f.read()).decode("utf-8") - return "data:image/jpeg;base64," + image_bs64 - return "" - - @dataclass class TokenUsage: input_other: int = 0 diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index cdad66a22f..f8c8179558 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -35,10 +35,24 @@ log_connection_failure, ) from astrbot.core.utils.string_utils import normalize_and_dedupe_strings +from astrbot.core.utils.image_utils import ( + detect_mime_type_from_base64_str, + encode_image_to_base64_url, +) from ..register import register_provider_adapter +# SVG 检测所需的最小头部字节数。 +# 其他二进制格式的魔术字节最多 16 字节,SVG 需要更多以跳过可能的 XML 声明。 +_HEADER_READ_SIZE = 256 + +# 对应 _HEADER_READ_SIZE 字节所需的 base64 字符数。 +# base64 每 3 字节编码为 4 字符,向上取整后再补充少量余量保证填充对齐。 +# ceil(256 / 3) * 4 = 344 +_BASE64_SAMPLE_CHARS = 344 + + @register_provider_adapter( "openai_chat_completion", "OpenAI API Chat Completion 
@classmethod
def _encode_image_file_to_data_url(
    cls,
    image_path: str,
    *,
    mode: Literal["safe", "strict"],
) -> str | None:
    """Encode a local image file as a base64 data URL.

    The file is read synchronously and its MIME type is detected from the
    leading magic bytes via ``detect_image_mime_type``. When that detector
    falls back to its default (``image/jpeg``) although the file does not
    start with the JPEG signature, Pillow is used as a second-level
    detector.

    This method does NOT call the async ``encode_image_to_base64_url``
    helper; it only shares the magic-byte detector with it.

    Args:
        image_path: Path of the local image file.
        mode: ``"safe"`` returns ``None`` when the file cannot be read or is
            empty; ``"strict"`` re-raises the ``OSError`` from reading and
            raises ``ValueError`` for an empty or unrecognisable file.

    Returns:
        A ``data:<mime>;base64,...`` URL, or ``None`` (safe mode only).

    Raises:
        OSError: strict mode, the file cannot be read.
        ValueError: strict mode, the file is empty or neither the magic
            bytes nor Pillow can identify it.
    """
    try:
        image_bytes = Path(image_path).read_bytes()
    except OSError:
        if mode == "strict":
            raise
        return None

    if not image_bytes:
        if mode == "strict":
            raise ValueError(f"图片文件为空: {image_path}")
        return None

    # Function-scope import of the public detector only. The header size
    # comes from this module's own _HEADER_READ_SIZE, so there is a single
    # source of truth here instead of a second private cross-module import.
    from astrbot.core.utils.image_utils import detect_image_mime_type

    header = image_bytes[:_HEADER_READ_SIZE]
    mime_type = detect_image_mime_type(header)

    # "image/jpeg" is also the detector's fallback value, so it only counts
    # as a real hit when the JPEG signature is actually present.
    magic_hit = mime_type != "image/jpeg" or header.startswith(b"\xff\xd8\xff")

    # NOTE(review): when the magic bytes match, Pillow verification is
    # skipped entirely, so strict mode accepts truncated/corrupt files (and
    # anything that merely starts with a known signature such as "BM").
    # Confirm this relaxation of the strict contract is intended.
    if magic_hit:
        logger.debug(
            "[%s] 魔术字节检测命中,识别为 %s 格式",
            image_path,
            mime_type,
        )
    else:
        try:
            with PILImage.open(BytesIO(image_bytes)) as image:
                image.verify()
                image_format = str(image.format or "").upper()
        except (OSError, UnidentifiedImageError) as exc:
            # Neither magic bytes nor Pillow recognise the payload.
            if mode == "strict":
                raise ValueError(
                    f"无法识别的图片文件: {image_path}"
                ) from exc
            # Safe mode keeps the image/jpeg default and still returns a URL.
            logger.debug(
                "[%s] PIL 回退失败: %s,硬编码回退为 image/jpeg",
                image_path,
                exc,
            )
        else:
            pil_mime = {
                "JPEG": "image/jpeg",
                "PNG": "image/png",
                "GIF": "image/gif",
                "WEBP": "image/webp",
                "BMP": "image/bmp",
            }.get(image_format)
            if pil_mime:
                logger.debug(
                    "[%s] PIL 回退命中,识别为 %s 格式(PIL format: %s)",
                    image_path,
                    pil_mime,
                    image_format,
                )
                mime_type = pil_mime
            else:
                # Pillow knows the format but there is no MIME mapping for
                # it; the image/jpeg default is kept.
                logger.debug(
                    "[%s] PIL 识别到格式 %s 但无对应 MIME 映射",
                    image_path,
                    image_format,
                )

    image_bs64 = base64.b64encode(image_bytes).decode("utf-8")
    return f"data:{mime_type};base64,{image_bs64}"
def _is_underlying_client_closed(self) -> bool:
    """Report whether the HTTP transport behind ``self.client`` is closed.

    The public ``is_closed()`` method of the openai SDK client is preferred.
    If it is not available, the private ``client._client.is_closed``
    attribute (the underlying httpx client) is probed instead.

    Returns:
        ``False`` when ``self.client`` is falsy (the caller handles the
        ``None`` case itself), otherwise the closed state. If neither probe
        is available, a warning is logged and ``True`` is returned so that
        the caller rebuilds the client.
    """
    client = self.client
    if not client:
        return False

    # Public API first: no dependency on SDK internals.
    public_probe = getattr(client, "is_closed", None)
    if callable(public_probe):
        return bool(public_probe())

    # Fallback that relies on the SDK's private structure.
    try:
        return bool(client._client.is_closed)
    except AttributeError:
        logger.warning(
            "无法检测 OpenAI client 是否已关闭,"
            "可能是 SDK 内部结构变更导致;"
            "将保守视为已关闭并触发 client 重建。"
        )
        return True


def _ensure_client(self) -> None:
    """Make sure ``self.client`` is usable, recreating it when necessary.

    When the client is ``None`` or its transport is reported closed, a new
    client is built with ``_create_openai_client`` and ``default_params`` is
    refreshed from the new client's ``chat.completions.create`` signature.
    """
    if self.client is not None and not self._is_underlying_client_closed():
        return

    logger.warning("检测到 OpenAI client 已关闭或未初始化,正在重新创建...")
    self.client = self._create_openai_client()
    self.default_params = inspect.signature(
        self.client.chat.completions.create,
    ).parameters.keys()
base_url=provider_config.get("api_base", None), - default_headers=self.custom_headers, - timeout=self.timeout, - http_client=self._create_http_client(provider_config), - ) + self.client = self._create_openai_client() self.default_params = inspect.signature( self.client.chat.completions.create, @@ -417,6 +536,7 @@ def _apply_provider_specific_extra_body_overrides( extra_body["reasoning_effort"] = "none" async def get_models(self): + self._ensure_client() try: models_str = [] models = await self.client.models.list() @@ -428,6 +548,8 @@ async def get_models(self): raise Exception(f"获取模型列表失败:{e}") async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: + self._ensure_client() + if tools: model = payloads.get("model", "").lower() omit_empty_param_field = "gemini" in model @@ -478,6 +600,8 @@ async def _query_stream( tools: ToolSet | None, ) -> AsyncGenerator[LLMResponse, None]: """流式查询API,逐步返回结果""" + self._ensure_client() + if tools: model = payloads.get("model", "").lower() omit_empty_param_field = "gemini" in model @@ -852,7 +976,7 @@ async def _handle_api_error( """处理API错误并尝试恢复""" if "429" in str(e): logger.warning( - f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}", + "API 调用过于频繁,尝试使用其他 Key 重试。" ) # 最后一次不等待 if retry_cnt < max_retries - 1: @@ -988,6 +1112,7 @@ async def text_chat( retry_cnt = 0 for retry_cnt in range(max_retries): try: + self._ensure_client() self.client.api_key = chosen_key llm_response = await self._query(payloads, func_tool) break @@ -1054,6 +1179,7 @@ async def text_chat_stream( retry_cnt = 0 for retry_cnt in range(max_retries): try: + self._ensure_client() self.client.api_key = chosen_key async for response in self._query_stream(payloads, func_tool): yield response @@ -1108,12 +1234,14 @@ async def _remove_image_from_context(self, contexts: list): return new_contexts def get_current_key(self) -> str: + self._ensure_client() return self.client.api_key def get_keys(self) -> list[str]: return self.api_keys def 
# Number of leading bytes inspected for format detection. Binary magic numbers
# need at most 16 bytes; SVG needs more so that an XML declaration preceding
# the <svg> tag can be skipped.
_HEADER_READ_SIZE = 256

# base64 characters needed to cover _HEADER_READ_SIZE bytes:
# ceil(256 / 3) * 4 = 344
_BASE64_SAMPLE_CHARS = 344

# Fixed-offset signatures checked with a simple prefix match.
_PREFIX_SIGNATURES = (
    (b"\xff\xd8\xff", "image/jpeg"),
    (b"\x89PNG\r\n\x1a\n", "image/png"),
    (b"GIF8", "image/gif"),
    (b"BM", "image/bmp"),
    (b"II\x2a\x00", "image/tiff"),  # little-endian TIFF
    (b"MM\x00\x2a", "image/tiff"),  # big-endian TIFF
    (b"\x00\x00\x01\x00", "image/x-icon"),
)

# ISO-BMFF major brands reported as HEIF/HEIC.
_HEIF_BRANDS = frozenset((b"heic", b"heix", b"hevc", b"hevx", b"mif1"))


def detect_image_mime_type(header_bytes: bytes) -> str:
    """Detect an image MIME type from the leading bytes of its content.

    Recognises JPEG, PNG, GIF, WebP, BMP, TIFF, ICO, SVG, AVIF and
    HEIF/HEIC. Anything else falls back to ``image/jpeg`` for backward
    compatibility. This function is pure detection and never logs.

    Args:
        header_bytes: Raw leading bytes. SVG detection searches the whole
            argument for ``<svg``, so callers should pass about 256 bytes;
            the binary formats need at most 12.

    Returns:
        The MIME type string, e.g. ``"image/png"``.
    """
    for signature, mime_type in _PREFIX_SIGNATURES:
        if header_bytes.startswith(signature):
            return mime_type

    # WebP: "RIFF" <4-byte size> "WEBP". Short inputs yield short slices
    # that simply do not compare equal, so no explicit length check is needed.
    if header_bytes[:4] == b"RIFF" and header_bytes[8:12] == b"WEBP":
        return "image/webp"

    # SVG is text: look for the opening <svg tag anywhere in the header.
    if b"<svg" in header_bytes:
        return "image/svg+xml"

    # AVIF / HEIF share the ISO-BMFF "ftyp" box; the brand follows it.
    if header_bytes[4:8] == b"ftyp":
        brand = header_bytes[8:12]
        if brand == b"avif":
            return "image/avif"
        if brand in _HEIF_BRANDS:
            return "image/heif"

    # Unknown content: keep the historical default.
    return "image/jpeg"


def _decode_base64_header(raw_b64: str) -> bytes:
    """Decode just enough of *raw_b64* to inspect the image header.

    The sample is cut to ``_BASE64_SAMPLE_CHARS`` characters and re-padded.
    A sample of length 4k+1 can never be valid base64, so the dangling
    character is dropped instead of being padded with ``===`` (which would
    always fail to decode).

    Raises:
        binascii.Error / ValueError: propagated from ``base64.b64decode``.
    """
    sample = raw_b64[:_BASE64_SAMPLE_CHARS].rstrip("=")
    if len(sample) % 4 == 1:
        sample = sample[:-1]
    sample += "=" * (-len(sample) % 4)
    return base64.b64decode(sample)


def detect_mime_type_from_base64_str(
    raw_b64: str,
    source_hint: str = "",
) -> str:
    """Detect the MIME type of base64-encoded image data.

    Only the head of the string is decoded. If decoding fails the function
    logs at debug level and returns ``image/jpeg``.

    Args:
        raw_b64: base64 payload without the ``base64://`` prefix.
        source_hint: Label used in log messages to identify the caller;
            defaults to ``"base64"`` when empty.

    Returns:
        The detected MIME type string.
    """
    label = source_hint or "base64"
    try:
        header_bytes = _decode_base64_header(raw_b64)
    except (binascii.Error, ValueError) as exc:
        logger.debug(
            "[%s] base64 解码失败: %s,硬编码回退为 image/jpeg",
            label,
            exc,
        )
        return "image/jpeg"

    mime_type = detect_image_mime_type(header_bytes)
    logger.debug(
        "[%s] 魔术字节检测命中,识别为 %s 格式",
        label,
        mime_type,
    )
    return mime_type


def _sync_encode_from_file(path: str) -> tuple[str, str]:
    """Read *path*, detect its MIME type and base64-encode its content.

    Performs blocking file I/O; ``encode_image_to_base64_url`` runs it
    through ``run_in_executor``.

    Args:
        path: Local file path.

    Returns:
        A ``(mime_type, image_bs64)`` tuple.
    """
    # Single read: the header is a slice of the payload, no seek/re-read.
    with open(path, "rb") as f:
        payload = f.read()
    mime_type = detect_image_mime_type(payload[:_HEADER_READ_SIZE])
    return mime_type, base64.b64encode(payload).decode("utf-8")


async def encode_image_to_base64_url(image_url: str) -> str:
    """Convert an image reference into a base64 data URL.

    ``base64://`` inputs are inspected with
    ``detect_mime_type_from_base64_str``; any other input is treated as a
    local file path and read via ``run_in_executor`` so the blocking read
    does not run on the event loop thread.

    Args:
        image_url: A local file path, or a string starting with
            ``base64://``.

    Returns:
        A ``data:<mime>;base64,...`` URL, or ``""`` when the payload is
        empty (empty file or nothing after ``base64://``) so that callers
        testing ``if not image_data`` can skip it.
    """
    if image_url.startswith("base64://"):
        raw_b64 = image_url[len("base64://"):]
        if not raw_b64:
            return ""
        mime_type = detect_mime_type_from_base64_str(
            raw_b64, source_hint="image_utils/base64://"
        )
        return f"data:{mime_type};base64,{raw_b64}"

    loop = asyncio.get_running_loop()
    mime_type, image_bs64 = await loop.run_in_executor(
        None, _sync_encode_from_file, image_url
    )
    if not image_bs64:
        return ""
    logger.debug(
        "[image_utils][%s] 魔术字节检测命中,识别为 %s 格式",
        image_url,
        mime_type,
    )
    return f"data:{mime_type};base64,{image_bs64}"