Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.adk.JsonBaseModel;
import com.google.adk.models.BaseLlm;
import com.google.adk.models.BaseLlmConnection;
import com.google.adk.models.LlmRequest;
Expand Down Expand Up @@ -428,8 +429,26 @@ private List<ToolSpecification> toToolSpecifications(LlmRequest llmRequest) {
baseTool -> {
if (baseTool.declaration().isPresent()) {
FunctionDeclaration functionDeclaration = baseTool.declaration().get();
if (functionDeclaration.parameters().isPresent()) {
Schema schema = functionDeclaration.parameters().get();
Schema schema = null;
if (functionDeclaration.parametersJsonSchema().isPresent()) {
Object jsonSchemaObj = functionDeclaration.parametersJsonSchema().get();
try {
if (jsonSchemaObj instanceof Schema) {
schema = (Schema) jsonSchemaObj;
} else {
ObjectMapper adkMapper = JsonBaseModel.getMapper();
String jsonSchemaStr = adkMapper.writeValueAsString(jsonSchemaObj);
schema = adkMapper.readValue(jsonSchemaStr, Schema.class);
Comment on lines +439 to +441
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For efficiency and conciseness, you can directly convert the jsonSchemaObj to a Schema object using ObjectMapper.convertValue(). This avoids the intermediate step of serializing to a JSON string and then deserializing back.

Suggested change
ObjectMapper adkMapper = JsonBaseModel.getMapper();
String jsonSchemaStr = adkMapper.writeValueAsString(jsonSchemaObj);
schema = adkMapper.readValue(jsonSchemaStr, Schema.class);
schema = JsonBaseModel.getMapper().convertValue(jsonSchemaObj, Schema.class);

}
} catch (Exception e) {
throw new IllegalStateException(
"Failed to convert parametersJsonSchema to Schema: " + e.getMessage(), e);
}
} else if (functionDeclaration.parameters().isPresent()) {
schema = functionDeclaration.parameters().get();
}

if (schema != null) {
ToolSpecification toolSpecification =
ToolSpecification.builder()
.name(baseTool.name())
Expand All @@ -438,11 +457,9 @@ private List<ToolSpecification> toToolSpecifications(LlmRequest llmRequest) {
.build();
toolSpecifications.add(toolSpecification);
} else {
// TODO exception or something else?
throw new IllegalStateException("Tool lacking parameters: " + baseTool);
}
} else {
// TODO exception or something else?
throw new IllegalStateException("Tool lacking declaration: " + baseTool);
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -688,4 +688,131 @@ void testGenerateContentWithStructuredResponseJsonSchema() {
final UserMessage userMessage = (UserMessage) capturedRequest.messages().get(0);
assertThat(userMessage.singleText()).isEqualTo("Give me information about John Doe");
}

@Test
@DisplayName("Should handle MCP tools with parametersJsonSchema")
void testGenerateContentWithMcpToolParametersJsonSchema() {
// Given
// Create a mock BaseTool for MCP tool
final com.google.adk.tools.BaseTool mcpTool = mock(com.google.adk.tools.BaseTool.class);
when(mcpTool.name()).thenReturn("mcpTool");
when(mcpTool.description()).thenReturn("An MCP tool");

// Create a mock FunctionDeclaration
final FunctionDeclaration functionDeclaration = mock(FunctionDeclaration.class);
when(mcpTool.declaration()).thenReturn(Optional.of(functionDeclaration));

// MCP tools use parametersJsonSchema() instead of parameters()
// Create a JSON schema object (Map representation)
final Map<String, Object> jsonSchemaMap =
Map.of(
"type",
"object",
"properties",
Map.of("city", Map.of("type", "string", "description", "City name")),
"required",
List.of("city"));

// Mock parametersJsonSchema() to return the JSON schema object
when(functionDeclaration.parametersJsonSchema()).thenReturn(Optional.of(jsonSchemaMap));
when(functionDeclaration.parameters()).thenReturn(Optional.empty());

// Create a LlmRequest with the MCP tool
final LlmRequest llmRequest =
LlmRequest.builder()
.contents(List.of(Content.fromParts(Part.fromText("Use the MCP tool"))))
.tools(Map.of("mcpTool", mcpTool))
.build();

// Mock the AI response
final AiMessage aiMessage = AiMessage.from("Tool executed successfully");

final ChatResponse chatResponse = mock(ChatResponse.class);
when(chatResponse.aiMessage()).thenReturn(aiMessage);
when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse);

// When
final LlmResponse response = langChain4j.generateContent(llmRequest, false).blockingFirst();

// Then
// Verify the response
assertThat(response).isNotNull();
assertThat(response.content()).isPresent();
assertThat(response.content().get().text()).isEqualTo("Tool executed successfully");

// Verify the request was built correctly with the tool specification
final ArgumentCaptor<ChatRequest> requestCaptor = ArgumentCaptor.forClass(ChatRequest.class);
verify(chatModel).chat(requestCaptor.capture());
final ChatRequest capturedRequest = requestCaptor.getValue();

// Verify tool specifications were created from parametersJsonSchema
assertThat(capturedRequest.toolSpecifications()).isNotEmpty();
assertThat(capturedRequest.toolSpecifications().get(0).name()).isEqualTo("mcpTool");
assertThat(capturedRequest.toolSpecifications().get(0).description()).isEqualTo("An MCP tool");
}

@Test
@DisplayName("Should handle MCP tools with parametersJsonSchema when it's already a Schema")
void testGenerateContentWithMcpToolParametersJsonSchemaAsSchema() {
// Given
// Create a mock BaseTool for MCP tool
final com.google.adk.tools.BaseTool mcpTool = mock(com.google.adk.tools.BaseTool.class);
when(mcpTool.name()).thenReturn("mcpTool");
when(mcpTool.description()).thenReturn("An MCP tool");

// Create a mock FunctionDeclaration
final FunctionDeclaration functionDeclaration = mock(FunctionDeclaration.class);
when(mcpTool.declaration()).thenReturn(Optional.of(functionDeclaration));

// Create a Schema object directly (when parametersJsonSchema returns Schema)
final Schema cityPropertySchema =
Schema.builder()
.type(Type.builder().knownEnum(Type.Known.STRING).build())
.description("City name")
.build();

final Schema objectSchema =
Schema.builder()
.type(Type.builder().knownEnum(Type.Known.OBJECT).build())
.properties(Map.of("city", cityPropertySchema))
.required(List.of("city"))
.build();

// Mock parametersJsonSchema() to return Schema directly
when(functionDeclaration.parametersJsonSchema()).thenReturn(Optional.of(objectSchema));
when(functionDeclaration.parameters()).thenReturn(Optional.empty());

// Create a LlmRequest with the MCP tool
final LlmRequest llmRequest =
LlmRequest.builder()
.contents(List.of(Content.fromParts(Part.fromText("Use the MCP tool"))))
.tools(Map.of("mcpTool", mcpTool))
.build();

// Mock the AI response
final AiMessage aiMessage = AiMessage.from("Tool executed successfully");

final ChatResponse chatResponse = mock(ChatResponse.class);
when(chatResponse.aiMessage()).thenReturn(aiMessage);
when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse);

// When
final LlmResponse response = langChain4j.generateContent(llmRequest, false).blockingFirst();

// Then
// Verify the response
assertThat(response).isNotNull();
assertThat(response.content()).isPresent();
assertThat(response.content().get().text()).isEqualTo("Tool executed successfully");

// Verify the request was built correctly with the tool specification
final ArgumentCaptor<ChatRequest> requestCaptor = ArgumentCaptor.forClass(ChatRequest.class);
verify(chatModel).chat(requestCaptor.capture());
final ChatRequest capturedRequest = requestCaptor.getValue();

// Verify tool specifications were created from parametersJsonSchema
assertThat(capturedRequest.toolSpecifications()).isNotEmpty();
assertThat(capturedRequest.toolSpecifications().get(0).name()).isEqualTo("mcpTool");
assertThat(capturedRequest.toolSpecifications().get(0).description()).isEqualTo("An MCP tool");
}
}