diff --git a/core/src/main/java/com/google/adk/agents/RunConfig.java b/core/src/main/java/com/google/adk/agents/RunConfig.java index 4fe3365fe..174066073 100644 --- a/core/src/main/java/com/google/adk/agents/RunConfig.java +++ b/core/src/main/java/com/google/adk/agents/RunConfig.java @@ -48,8 +48,12 @@ public enum StreamingMode { public abstract @Nullable AudioTranscriptionConfig outputAudioTranscription(); + public abstract @Nullable AudioTranscriptionConfig inputAudioTranscription(); + public abstract int maxLlmCalls(); + public abstract Builder toBuilder(); + public static Builder builder() { return new AutoValue_RunConfig.Builder() .setSaveInputBlobsAsArtifacts(false) @@ -65,7 +69,8 @@ public static Builder builder(RunConfig runConfig) { .setMaxLlmCalls(runConfig.maxLlmCalls()) .setResponseModalities(runConfig.responseModalities()) .setSpeechConfig(runConfig.speechConfig()) - .setOutputAudioTranscription(runConfig.outputAudioTranscription()); + .setOutputAudioTranscription(runConfig.outputAudioTranscription()) + .setInputAudioTranscription(runConfig.inputAudioTranscription()); } /** Builder for {@link RunConfig}. */ @@ -88,6 +93,10 @@ public abstract static class Builder { public abstract Builder setOutputAudioTranscription( AudioTranscriptionConfig outputAudioTranscription); + @CanIgnoreReturnValue + public abstract Builder setInputAudioTranscription( + AudioTranscriptionConfig inputAudioTranscription); + @CanIgnoreReturnValue public abstract Builder setMaxLlmCalls(int maxLlmCalls); diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Basic.java b/core/src/main/java/com/google/adk/flows/llmflows/Basic.java index fa447625a..68cd4e1a5 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Basic.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Basic.java @@ -48,6 +48,8 @@ public Single processRequest( .ifPresent(liveConnectConfigBuilder::speechConfig); Optional.ofNullable(context.runConfig().outputAudioTranscription()) .ifPresent(liveConnectConfigBuilder::outputAudioTranscription); + Optional.ofNullable(context.runConfig().inputAudioTranscription()) + .ifPresent(liveConnectConfigBuilder::inputAudioTranscription); LlmRequest.Builder builder = request.toBuilder() diff --git a/core/src/main/java/com/google/adk/runner/Runner.java b/core/src/main/java/com/google/adk/runner/Runner.java index ec19c404f..19e6a06f9 100644 --- a/core/src/main/java/com/google/adk/runner/Runner.java +++ b/core/src/main/java/com/google/adk/runner/Runner.java @@ -397,8 +397,9 @@ private void copySessionStates(Session source, Session target) { private InvocationContext newInvocationContextForLive( Session session, Optional liveRequestQueue, RunConfig runConfig) { RunConfig.Builder runConfigBuilder = RunConfig.builder(runConfig); - if (!CollectionUtils.isNullOrEmpty(runConfig.responseModalities()) - && liveRequestQueue.isPresent()) { + if (liveRequestQueue.isPresent() && !this.agent.subAgents().isEmpty()) { + // Parity with Python: apply modality defaults and transcription settings + // only for multi-agent live scenarios. // Default to AUDIO modality if not specified. if (CollectionUtils.isNullOrEmpty(runConfig.responseModalities())) { runConfigBuilder.setResponseModalities( @@ -411,6 +412,10 @@ private InvocationContext newInvocationContextForLive( runConfigBuilder.setOutputAudioTranscription(AudioTranscriptionConfig.builder().build()); } } + // Need input transcription for agent transferring in live mode. + if (runConfig.inputAudioTranscription() == null) { + runConfigBuilder.setInputAudioTranscription(AudioTranscriptionConfig.builder().build()); + } } return newInvocationContext( session, /* newMessage= */ Optional.empty(), liveRequestQueue, runConfigBuilder.build()); diff --git a/core/src/test/java/com/google/adk/agents/RunConfigTest.java b/core/src/test/java/com/google/adk/agents/RunConfigTest.java index ac2a6840e..83bc82c93 100644 --- a/core/src/test/java/com/google/adk/agents/RunConfigTest.java +++ b/core/src/test/java/com/google/adk/agents/RunConfigTest.java @@ -25,6 +25,7 @@ public void testBuilderWithVariousValues() { .setSaveInputBlobsAsArtifacts(true) .setStreamingMode(RunConfig.StreamingMode.SSE) .setOutputAudioTranscription(audioTranscriptionConfig) + .setInputAudioTranscription(audioTranscriptionConfig) .setMaxLlmCalls(10) .build(); @@ -33,6 +34,7 @@ public void testBuilderWithVariousValues() { assertThat(runConfig.saveInputBlobsAsArtifacts()).isTrue(); assertThat(runConfig.streamingMode()).isEqualTo(RunConfig.StreamingMode.SSE); assertThat(runConfig.outputAudioTranscription()).isEqualTo(audioTranscriptionConfig); + assertThat(runConfig.inputAudioTranscription()).isEqualTo(audioTranscriptionConfig); assertThat(runConfig.maxLlmCalls()).isEqualTo(10); } @@ -45,6 +47,7 @@ public void testBuilderDefaults() { assertThat(runConfig.saveInputBlobsAsArtifacts()).isFalse(); assertThat(runConfig.streamingMode()).isEqualTo(RunConfig.StreamingMode.NONE); assertThat(runConfig.outputAudioTranscription()).isNull(); + assertThat(runConfig.inputAudioTranscription()).isNull(); assertThat(runConfig.maxLlmCalls()).isEqualTo(500); } @@ -66,6 +69,7 @@ public void testBuilderWithDifferentValues() { .setSaveInputBlobsAsArtifacts(true) .setStreamingMode(RunConfig.StreamingMode.BIDI) .setOutputAudioTranscription(audioTranscriptionConfig) + .setInputAudioTranscription(audioTranscriptionConfig) .setMaxLlmCalls(20) .build(); @@ -74,6 +78,24 @@ public void testBuilderWithDifferentValues() { assertThat(runConfig.saveInputBlobsAsArtifacts()).isTrue(); assertThat(runConfig.streamingMode()).isEqualTo(RunConfig.StreamingMode.BIDI); assertThat(runConfig.outputAudioTranscription()).isEqualTo(audioTranscriptionConfig); + assertThat(runConfig.inputAudioTranscription()).isEqualTo(audioTranscriptionConfig); assertThat(runConfig.maxLlmCalls()).isEqualTo(20); } + + @Test + public void testInputAudioTranscriptionOnly() { + AudioTranscriptionConfig inputTranscriptionConfig = AudioTranscriptionConfig.builder().build(); + + RunConfig runConfig = + RunConfig.builder() + .setStreamingMode(RunConfig.StreamingMode.BIDI) + .setResponseModalities(ImmutableList.of(new Modality(Modality.Known.AUDIO))) + .setInputAudioTranscription(inputTranscriptionConfig) + .build(); + + assertThat(runConfig.inputAudioTranscription()).isEqualTo(inputTranscriptionConfig); + assertThat(runConfig.outputAudioTranscription()).isNull(); + assertThat(runConfig.streamingMode()).isEqualTo(RunConfig.StreamingMode.BIDI); + assertThat(runConfig.responseModalities()).containsExactly(new Modality(Modality.Known.AUDIO)); + } } diff --git a/core/src/test/java/com/google/adk/flows/llmflows/BasicTest.java b/core/src/test/java/com/google/adk/flows/llmflows/BasicTest.java index 07a231985..f9b6614d1 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/BasicTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/BasicTest.java @@ -220,6 +220,25 @@ public void processRequest_buildsLiveConnectConfigFromRunConfig_outputAudioTrans assertThat(result.events()).isEmpty(); } + @Test + public void processRequest_buildsLiveConnectConfigFromRunConfig_inputAudioTranscription() { + RunConfig runConfig = + RunConfig.builder().setInputAudioTranscription(TEST_AUDIO_TRANSCRIPTION_CONFIG).build(); + LlmAgent agentWithConfig = LlmAgent.builder().name("agentWithConfig").model(testLlm).build(); + InvocationContext contextWithRunConfig = createInvocationContext(agentWithConfig, runConfig); + + RequestProcessingResult result = + basicProcessor.processRequest(contextWithRunConfig, initialRequest).blockingGet(); + + LlmRequest updatedRequest = result.updatedRequest(); + assertThat(updatedRequest.liveConnectConfig()).isNotNull(); + assertThat(updatedRequest.liveConnectConfig().responseModalities().get()).isEmpty(); + assertThat(updatedRequest.liveConnectConfig().speechConfig()).isEmpty(); + assertThat(updatedRequest.liveConnectConfig().inputAudioTranscription()) + .hasValue(TEST_AUDIO_TRANSCRIPTION_CONFIG); + assertThat(result.events()).isEmpty(); + } + @Test public void processRequest_buildsLiveConnectConfigFromRunConfig_allFields() { RunConfig runConfig = @@ -227,6 +246,7 @@ public void processRequest_buildsLiveConnectConfigFromRunConfig_allFields() { .setResponseModalities(ImmutableList.of(new Modality(Modality.Known.AUDIO))) .setSpeechConfig(TEST_SPEECH_CONFIG) .setOutputAudioTranscription(TEST_AUDIO_TRANSCRIPTION_CONFIG) + .setInputAudioTranscription(TEST_AUDIO_TRANSCRIPTION_CONFIG) .build(); LlmAgent agentWithConfig = LlmAgent.builder().name("agentWithConfig").model(testLlm).build(); InvocationContext contextWithRunConfig = createInvocationContext(agentWithConfig, runConfig); @@ -241,6 +261,8 @@ public void processRequest_buildsLiveConnectConfigFromRunConfig_allFields() { assertThat(updatedRequest.liveConnectConfig().speechConfig()).hasValue(TEST_SPEECH_CONFIG); assertThat(updatedRequest.liveConnectConfig().outputAudioTranscription()) .hasValue(TEST_AUDIO_TRANSCRIPTION_CONFIG); + assertThat(updatedRequest.liveConnectConfig().inputAudioTranscription()) + .hasValue(TEST_AUDIO_TRANSCRIPTION_CONFIG); assertThat(result.events()).isEmpty(); } } diff --git a/core/src/test/java/com/google/adk/runner/InputAudioTranscriptionTest.java b/core/src/test/java/com/google/adk/runner/InputAudioTranscriptionTest.java new file mode 100644 index 000000000..5719355f1 --- /dev/null +++ b/core/src/test/java/com/google/adk/runner/InputAudioTranscriptionTest.java @@ -0,0 +1,116 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.runner; + +import static com.google.adk.testing.TestUtils.createLlmResponse; +import static com.google.adk.testing.TestUtils.createTestAgentBuilder; +import static com.google.adk.testing.TestUtils.createTestLlm; +import static com.google.common.truth.Truth.assertThat; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LiveRequestQueue; +import com.google.adk.agents.LlmAgent; +import com.google.adk.agents.RunConfig; +import com.google.adk.sessions.Session; +import com.google.adk.testing.TestLlm; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.AudioTranscriptionConfig; +import com.google.genai.types.Content; +import com.google.genai.types.Modality; +import com.google.genai.types.Part; +import java.lang.reflect.Method; +import java.util.Optional; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class InputAudioTranscriptionTest { + + private Content createContent(String text) { + return Content.builder().parts(Part.builder().text(text).build()).build(); + } + + private InvocationContext invokeNewInvocationContextForLive( + Runner runner, Session session, LiveRequestQueue liveRequestQueue, RunConfig runConfig) + throws Exception { + Method method = + Runner.class.getDeclaredMethod( + "newInvocationContextForLive", Session.class, Optional.class, RunConfig.class); + method.setAccessible(true); + return (InvocationContext) + method.invoke(runner, session, Optional.of(liveRequestQueue), runConfig); + } + + @Test + public void newInvocationContextForLive_multiAgent_autoConfiguresInputAudioTranscription() + throws Exception { + TestLlm testLlm = createTestLlm(createLlmResponse(createContent("response"))); + LlmAgent subAgent = createTestAgentBuilder(testLlm).name("sub_agent").build(); + LlmAgent rootAgent = + createTestAgentBuilder(testLlm) + .name("root_agent") + .subAgents(ImmutableList.of(subAgent)) + .build(); + + Runner runner = new InMemoryRunner(rootAgent, "test", ImmutableList.of()); + Session session = runner.sessionService().createSession("test", "user").blockingGet(); + + RunConfig initialConfig = + RunConfig.builder() + .setResponseModalities(ImmutableList.of(new Modality(Modality.Known.AUDIO))) + .setStreamingMode(RunConfig.StreamingMode.BIDI) + .build(); + + assertThat(initialConfig.inputAudioTranscription()).isNull(); + + LiveRequestQueue liveQueue = new LiveRequestQueue(); + InvocationContext context = + invokeNewInvocationContextForLive(runner, session, liveQueue, initialConfig); + + assertThat(context.runConfig().inputAudioTranscription()).isNotNull(); + } + + @Test + public void newInvocationContextForLive_explicitConfig_preservesUserInputAudioTranscription() + throws Exception { + TestLlm testLlm = createTestLlm(createLlmResponse(createContent("response"))); + LlmAgent subAgent = createTestAgentBuilder(testLlm).name("sub_agent").build(); + LlmAgent rootAgent = + createTestAgentBuilder(testLlm) + .name("root_agent") + .subAgents(ImmutableList.of(subAgent)) + .build(); + + Runner runner = new InMemoryRunner(rootAgent, "test", ImmutableList.of()); + Session session = runner.sessionService().createSession("test", "user").blockingGet(); + + AudioTranscriptionConfig userConfig = AudioTranscriptionConfig.builder().build(); + RunConfig configWithUserSetting = + RunConfig.builder() + .setResponseModalities(ImmutableList.of(new Modality(Modality.Known.AUDIO))) + .setStreamingMode(RunConfig.StreamingMode.BIDI) + .setInputAudioTranscription(userConfig) + .build(); + + LiveRequestQueue liveQueue = new LiveRequestQueue(); + InvocationContext context = + invokeNewInvocationContextForLive(runner, session, liveQueue, configWithUserSetting); + + assertThat(context.runConfig().inputAudioTranscription()).isSameInstanceAs(userConfig); + } +}