diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index d160d5f244..abed4cd3dc 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -27,6 +27,7 @@ from typing import Literal from typing import Optional from typing import Tuple +from typing import TypedDict from typing import Union from google.genai import types @@ -35,15 +36,9 @@ from litellm import ChatCompletionAssistantMessage from litellm import ChatCompletionAssistantToolCall from litellm import ChatCompletionDeveloperMessage -from litellm import ChatCompletionFileObject -from litellm import ChatCompletionImageObject -from litellm import ChatCompletionImageUrlObject from litellm import ChatCompletionMessageToolCall -from litellm import ChatCompletionTextObject from litellm import ChatCompletionToolMessage from litellm import ChatCompletionUserMessage -from litellm import ChatCompletionVideoObject -from litellm import ChatCompletionVideoUrlObject from litellm import completion from litellm import CustomStreamWrapper from litellm import Function @@ -67,6 +62,11 @@ _EXCLUDED_PART_FIELD = {"inline_data": {"data"}} +class ChatCompletionFileUrlObject(TypedDict): + file_data: str + format: str + + class FunctionChunk(BaseModel): id: Optional[str] name: Optional[str] @@ -237,12 +237,10 @@ def _get_content( if part.text: if len(parts) == 1: return part.text - content_objects.append( - ChatCompletionTextObject( - type="text", - text=part.text, - ) - ) + content_objects.append({ + "type": "text", + "text": part.text, + }) elif ( part.inline_data and part.inline_data.data @@ -252,33 +250,32 @@ def _get_content( data_uri = f"data:{part.inline_data.mime_type};base64,{base64_string}" if part.inline_data.mime_type.startswith("image"): - # Extract format from mime type (e.g., "image/png" -> "png") - format_type = part.inline_data.mime_type.split("/")[-1] - content_objects.append( - ChatCompletionImageObject( - type="image_url", - image_url=ChatCompletionImageUrlObject( - url=data_uri, format=format_type - ), - ) - ) + # Use full MIME type (e.g., "image/png") for providers that validate it + format_type = part.inline_data.mime_type + content_objects.append({ + "type": "image_url", + "image_url": {"url": data_uri, "format": format_type}, + }) elif part.inline_data.mime_type.startswith("video"): - # Extract format from mime type (e.g., "video/mp4" -> "mp4") - format_type = part.inline_data.mime_type.split("/")[-1] - content_objects.append( - ChatCompletionVideoObject( - type="video_url", - video_url=ChatCompletionVideoUrlObject( - url=data_uri, format=format_type - ), - ) - ) + # Use full MIME type (e.g., "video/mp4") for providers that validate it + format_type = part.inline_data.mime_type + content_objects.append({ + "type": "video_url", + "video_url": {"url": data_uri, "format": format_type}, + }) + elif part.inline_data.mime_type.startswith("audio"): + # Use full MIME type (e.g., "audio/mpeg") for providers that validate it + format_type = part.inline_data.mime_type + content_objects.append({ + "type": "audio_url", + "audio_url": {"url": data_uri, "format": format_type}, + }) elif part.inline_data.mime_type == "application/pdf": - content_objects.append( - ChatCompletionFileObject( - type="file", file={"file_data": data_uri, "format": "pdf"} - ) - ) + format_type = part.inline_data.mime_type + content_objects.append({ + "type": "file", + "file": {"file_data": data_uri, "format": format_type}, + }) else: raise ValueError("LiteLlm(BaseLlm) does not support this content part.") diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 4bd6424236..ad72d3c375 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -1036,7 +1036,7 @@ def test_get_content_image(): content[0]["image_url"]["url"] == "" ) - assert content[0]["image_url"]["format"] == "png" + assert content[0]["image_url"]["format"] == "image/png" def test_get_content_video(): @@ -1049,7 +1049,33 @@ def test_get_content_video(): content[0]["video_url"]["url"] == "data:video/mp4;base64,dGVzdF92aWRlb19kYXRh" ) - assert content[0]["video_url"]["format"] == "mp4" + assert content[0]["video_url"]["format"] == "video/mp4" + + +def test_get_content_pdf(): + parts = [ + types.Part.from_bytes(data=b"test_pdf_data", mime_type="application/pdf") + ] + content = _get_content(parts) + assert content[0]["type"] == "file" + assert ( + content[0]["file"]["file_data"] + == "data:application/pdf;base64,dGVzdF9wZGZfZGF0YQ==" + ) + assert content[0]["file"]["format"] == "application/pdf" + + +def test_get_content_audio(): + parts = [ + types.Part.from_bytes(data=b"test_audio_data", mime_type="audio/mpeg") + ] + content = _get_content(parts) + assert content[0]["type"] == "audio_url" + assert ( + content[0]["audio_url"]["url"] + == "data:audio/mpeg;base64,dGVzdF9hdWRpb19kYXRh" + ) + assert content[0]["audio_url"]["format"] == "audio/mpeg" def test_to_litellm_role():