diff --git a/src/google/adk/flows/llm_flows/contents.py b/src/google/adk/flows/llm_flows/contents.py index be471a44eb..0d45950196 100644 --- a/src/google/adk/flows/llm_flows/contents.py +++ b/src/google/adk/flows/llm_flows/contents.py @@ -686,6 +686,29 @@ async def _add_instructions_to_user_content( llm_request: The LLM request to modify instruction_contents: List of instruction-related contents to insert """ + + def is_valid_instruction_position( + llm_request: LlmRequest, index: int + ) -> bool: + """Checks if instructions can be inserted after a given index. + + A valid insertion point is after a model response that is not a tool call. + This prevents injecting instructions in the middle of a user's turn or a + tool-use sequence. + + Args: + llm_request: The LLM request containing the conversation contents. + index: The index of the content to check. + + Returns: + True if the position after this index is a valid insertion point. + """ + content_at_index = llm_request.contents[index] + is_user_message = content_at_index.role == 'user' + is_tool_request = any(part.function_call for part in content_at_index.parts) + + return not is_user_message and not is_tool_request + if not instruction_contents: return @@ -695,7 +718,7 @@ async def _add_instructions_to_user_content( if llm_request.contents: for i in range(len(llm_request.contents) - 1, -1, -1): - if llm_request.contents[i].role != 'user': + if is_valid_instruction_position(llm_request, i): insert_index = i + 1 break elif i == 0: diff --git a/tests/unittests/flows/llm_flows/test_contents.py b/tests/unittests/flows/llm_flows/test_contents.py index 9e77407b27..549c09440d 100644 --- a/tests/unittests/flows/llm_flows/test_contents.py +++ b/tests/unittests/flows/llm_flows/test_contents.py @@ -427,3 +427,89 @@ async def test_events_with_empty_content_are_skipped(): types.UserContent("Hello"), types.UserContent("How are you?"), ] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "initial_contents, expected_insertion_index", + [ + ( + [ + types.UserContent("First user message"), + types.ModelContent("Model response"), + types.ModelContent([ + types.Part( + function_call=types.FunctionCall( + name="test_tool", args={} + ) + ) + ]), + types.Content( + parts=[ + types.Part( + function_response=types.FunctionResponse( + name="test_tool", response={} + ) + ) + ], + role="user", + ), + types.UserContent("Final user message"), + ], + 2, + ), + ( + [ + types.UserContent("First user message"), + types.UserContent("Second user message"), + types.ModelContent("Model response"), + types.UserContent("Third user message"), + types.UserContent("Fourth user message"), + ], + 3, + ), + ( + [ + types.UserContent("First user message"), + types.UserContent("Second user message"), + ], + 0, + ), + ([], 0), + ( + [ + types.UserContent("User message"), + types.ModelContent("Model response"), + ], + 2, + ), + ], + ids=[ + "skips_function_call_and_user_content", + "skips_trailing_user_content", + "inserts_at_start_when_all_user_content", + "inserts_at_start_for_empty_content", + "inserts_at_end_when_last_is_model_content", + ], +) +async def test_add_instructions_to_user_content( + initial_contents, expected_insertion_index +): + """Tests that instructions are correctly inserted into the content list.""" + agent = Agent(model="gemini-2.5-flash", name="test_agent") + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + instruction_contents = [ + types.Content(parts=[types.Part(text="System instruction")], role="user") + ] + llm_request = LlmRequest(model="gemini-2.5-flash", contents=initial_contents) + + await contents._add_instructions_to_user_content( + invocation_context, llm_request, instruction_contents + ) + + assert len(llm_request.contents) == len(initial_contents) + 1 + assert ( + llm_request.contents[expected_insertion_index] == instruction_contents[0] + )