@mock.patch.object(_evals_common, "_run_agent")
def test_run_agent_internal_multi_turn_success(self, mock_run_agent):
    """Checks that multi-turn agent output lands in the `agent_data` column.

    The dataset carries a `conversation_plan` column, and `_run_agent` is
    patched to return one row holding two turns; the turns are expected to
    be exposed unchanged under `agent_data["turns"]`.
    """

    def build_turns():
        # Built fresh on every call so the mocked return value and the
        # expectation never share list or dict objects.
        return [
            {"turn_index": index, "turn_id": turn_id, "events": []}
            for index, turn_id in enumerate(("t1", "t2"))
        ]

    mock_run_agent.return_value = [build_turns()]
    dataset = pd.DataFrame({"prompt": ["p1"], "conversation_plan": ["plan"]})

    result_df = _evals_common._run_agent_internal(
        api_client=mock.Mock(),
        agent_engine=mock.Mock(),
        agent=None,
        prompt_dataset=dataset,
    )

    assert "agent_data" in result_df.columns
    first_row_agent_data = result_df["agent_data"][0]
    assert first_row_agent_data["turns"] == build_turns()
tool calls) are parsed successfully.""" + row = pd.Series( + { + "starting_prompt": "I want a laptop.", + "conversation_plan": "Ask for a laptop", + "session_inputs": json.dumps({"user_id": "u1"}), + } + ) + mock_agent = mock.Mock() + + mock_invocation = mock.Mock() + mock_invocation.invocation_id = "turn_123" + mock_invocation.creation_timestamp = 1771811084.88 + mock_invocation.user_content.model_dump.return_value = { + "parts": [{"text": "I want a laptop."}], + "role": "user", + } + mock_event_1 = mock.Mock() + mock_event_1.author = "ecommerce_agent" + mock_event_1.content.model_dump.return_value = { + "parts": [ + { + "function_call": { + "name": "search_products", + "args": {"query": "laptop"}, + } + } + ] + } + mock_event_2 = mock.Mock() + mock_event_2.author = "ecommerce_agent" + mock_event_2.content.model_dump.return_value = { + "parts": [ + { + "function_response": { + "name": "search_products", + "response": {"products": []}, + } + } + ] + } + + mock_invocation.intermediate_data.invocation_events = [ + mock_event_1, + mock_event_2, + ] + mock_invocation.final_response.model_dump.return_value = { + "parts": [{"text": "There are no laptops matching your search."}], + "role": "model", + } + mock_generator._generate_inferences_from_root_agent = mock.AsyncMock( + return_value=[mock_invocation] + ) + turns = await _evals_common._run_adk_user_simulation(row, mock_agent) + + assert len(turns) == 1 + turn = turns[0] + assert turn["turn_index"] == 0 + assert turn["turn_id"] == "turn_123" + assert len(turn["events"]) == 4 + assert turn["events"][0]["author"] == "user" + assert turn["events"][0]["content"]["parts"][0]["text"] == "I want a laptop." 
class TestIsMultiTurnAgentRun:
    """Unit tests for the _is_multi_turn_agent_run function."""

    def test_is_multi_turn_agent_run_with_config(self):
        """A simulator config alone makes the run multi-turn."""
        simulator_config = vertexai_genai_types.evals.UserSimulatorConfig(
            model_name="gemini-pro"
        )
        is_multi_turn = _evals_common._is_multi_turn_agent_run(
            user_simulator_config=simulator_config,
            prompt_dataset=pd.DataFrame(),
        )
        assert is_multi_turn

    def test_is_multi_turn_agent_run_with_conversation_plan(self):
        """A `conversation_plan` column makes the run multi-turn without a config."""
        dataset_with_plan = pd.DataFrame({"conversation_plan": ["plan"]})
        is_multi_turn = _evals_common._is_multi_turn_agent_run(
            user_simulator_config=None,
            prompt_dataset=dataset_with_plan,
        )
        assert is_multi_turn

    def test_is_multi_turn_agent_run_false(self):
        """Neither a config nor a `conversation_plan` column means single-turn."""
        prompt_only_dataset = pd.DataFrame({"prompt": ["prompt"]})
        is_multi_turn = _evals_common._is_multi_turn_agent_run(
            user_simulator_config=None,
            prompt_dataset=prompt_only_dataset,
        )
        assert not is_multi_turn
"""Unit tests for the _run_adk_user_simulation function.""" + + @mock.patch( + "vertexai._genai._evals_common.ADK_SessionInput" + ) + @mock.patch( + "vertexai._genai._evals_common.EvaluationGenerator" + ) + @mock.patch( + "vertexai._genai._evals_common.LlmBackedUserSimulator" + ) + @mock.patch( + "vertexai._genai._evals_common.ConversationScenario" + ) + @mock.patch( + "vertexai._genai._evals_common.LlmBackedUserSimulatorConfig" + ) + @pytest.mark.asyncio + async def test_run_adk_user_simulation_success( + self, + mock_config_cls, + mock_scenario_cls, + mock_simulator_cls, + mock_generator_cls, + mock_session_input_cls, + ): + row = pd.Series( + { + "starting_prompt": "start", + "conversation_plan": "plan", + "session_inputs": json.dumps({"user_id": "u1"}), + } + ) + mock_agent = mock.Mock() + mock_invocation = mock.Mock() + mock_invocation.user_content.model_dump.return_value = {"text": "user msg"} + mock_invocation.final_response.model_dump.return_value = {"text": "agent msg"} + mock_invocation.intermediate_data = None + mock_invocation.creation_timestamp = 12345 + mock_invocation.invocation_id = "turn1" + + mock_generator_cls._generate_inferences_from_root_agent = mock.AsyncMock( + return_value=[mock_invocation] + ) + + turns = await _evals_common._run_adk_user_simulation(row, mock_agent) + + assert len(turns) == 1 + turn = turns[0] + assert turn["turn_index"] == 0 + assert turn["turn_id"] == "turn1" + assert len(turn["events"]) == 2 + assert turn["events"][0]["author"] == "user" + assert turn["events"][0]["content"] == {"text": "user msg"} + assert turn["events"][1]["author"] == "agent" + assert turn["events"][1]["content"] == {"text": "agent msg"} + + mock_scenario_cls.assert_called_once_with( + starting_prompt="start", conversation_plan="plan" + ) + mock_session_input_cls.assert_called_once() + + @mock.patch( + "vertexai._genai._evals_common.ADK_SessionInput" + ) + @mock.patch( + "vertexai._genai._evals_common.EvaluationGenerator" + ) + @mock.patch( + 
"vertexai._genai._evals_common.LlmBackedUserSimulator" + ) + @mock.patch( + "vertexai._genai._evals_common.ConversationScenario" + ) + @mock.patch( + "vertexai._genai._evals_common.LlmBackedUserSimulatorConfig" + ) + @pytest.mark.asyncio + async def test_run_adk_user_simulation_missing_columns( + self, + mock_config_cls, + mock_scenario_cls, + mock_simulator_cls, + mock_generator_cls, + mock_session_input_cls, + ): + row = pd.Series({"conversation_plan": "plan"}) + mock_agent = mock.Mock() + + with pytest.raises(ValueError, match="User simulation requires"): + await _evals_common._run_adk_user_simulation(row, mock_agent) + + @pytest.mark.usefixtures("google_auth_mock") class TestLLMMetricHandlerPayload: def setup_method(self): diff --git a/vertexai/_genai/_evals_common.py b/vertexai/_genai/_evals_common.py index 0bc28994ed..f3b5757cd7 100644 --- a/vertexai/_genai/_evals_common.py +++ b/vertexai/_genai/_evals_common.py @@ -45,6 +45,8 @@ from . import evals from . import types +logger = logging.getLogger(__name__) + try: import litellm except ImportError: @@ -54,13 +56,29 @@ from google.adk.agents import LlmAgent from google.adk.runners import Runner from google.adk.sessions import InMemorySessionService + from google.adk.evaluation.simulation.llm_backed_user_simulator import ( + LlmBackedUserSimulator, + ) + from google.adk.evaluation.simulation.llm_backed_user_simulator import ( + LlmBackedUserSimulatorConfig, + ) + from google.adk.evaluation.conversation_scenarios import ConversationScenario + from google.adk.evaluation.evaluation_generator import EvaluationGenerator + from google.adk.evaluation.eval_case import SessionInput as ADK_SessionInput except ImportError: + logging.getLogger(__name__).warning( + "ADK is not installed. 
Please install it using" " 'pip install google-adk'" + ) LlmAgent = None Runner = None InMemorySessionService = None + LlmBackedUserSimulator = None + LlmBackedUserSimulatorConfig = None + ConversationScenario = None + EvaluationGenerator = None + ADK_SessionInput = None -logger = logging.getLogger(__name__) _thread_local_data = threading.local() MAX_WORKERS = 100 @@ -68,6 +86,7 @@ CONTENT = _evals_constant.CONTENT PARTS = _evals_constant.PARTS USER_AUTHOR = _evals_constant.USER_AUTHOR +AGENT_DATA = _evals_constant.AGENT_DATA @contextlib.contextmanager @@ -341,6 +360,7 @@ def _execute_inference_concurrently( inference_fn: Optional[Callable[..., Any]] = None, agent_engine: Optional[Union[str, types.AgentEngine]] = None, agent: Optional[LlmAgent] = None, + user_simulator_config: Optional[types.evals.UserSimulatorConfig] = None, ) -> list[ Union[ genai_types.GenerateContentResponse, @@ -364,12 +384,15 @@ def _execute_inference_concurrently( ] = [None] * len(prompt_dataset) tasks = [] - primary_prompt_column = ( - "request" if "request" in prompt_dataset.columns else "prompt" - ) - if primary_prompt_column not in prompt_dataset.columns: + if "request" in prompt_dataset.columns: + primary_prompt_column = "request" + elif "prompt" in prompt_dataset.columns: + primary_prompt_column = "prompt" + elif "starting_prompt" in prompt_dataset.columns: + primary_prompt_column = "starting_prompt" + else: raise ValueError( - "Dataset must contain either 'prompt' or 'request'." + "Dataset must contain either 'prompt', 'request', or 'starting_prompt'." 
def _invocation_to_events(invocation: Any) -> list[dict[str, Any]]:
    """Flattens one ADK invocation into an ordered list of event dicts.

    Events are emitted in conversational order: the user content (if any),
    then the intermediate events, then the final agent response (if any).
    Every event is stamped with the invocation's `creation_timestamp`.

    Args:
      invocation: An invocation object returned by the ADK evaluation
        generator. Only `user_content`, `intermediate_data`,
        `final_response` and `creation_timestamp` are read.

    Returns:
      A list of dicts with the keys `author`, `content` and `event_time`.
    """
    event_time = invocation.creation_timestamp

    def make_event(author: Any, content: Any) -> dict[str, Any]:
        return {"author": author, "content": content, "event_time": event_time}

    events: list[dict[str, Any]] = []

    if invocation.user_content:
        events.append(
            make_event("user", invocation.user_content.model_dump(mode="json"))
        )

    intermediate_data = invocation.intermediate_data
    if intermediate_data:
        invocation_events = getattr(intermediate_data, "invocation_events", None)
        if invocation_events:
            for intermediate_event in invocation_events:
                content = intermediate_event.content
                events.append(
                    make_event(
                        intermediate_event.author,
                        content.model_dump(mode="json") if content else None,
                    )
                )
        else:
            # Fallback for intermediate data that only exposes `tool_uses`.
            # The attribute may be missing or None, so guard before iterating.
            tool_uses = getattr(intermediate_data, "tool_uses", None) or []
            for tool_call in tool_uses:
                events.append(
                    make_event("tool_call", tool_call.model_dump(mode="json"))
                )

    if invocation.final_response:
        events.append(
            make_event("agent", invocation.final_response.model_dump(mode="json"))
        )

    return events


async def _run_adk_user_simulation(
    row: pd.Series,
    agent: LlmAgent,
    config: Optional[types.evals.UserSimulatorConfig] = None,
) -> list[dict[str, Any]]:
    """Runs a multi-turn user simulation using ADK's EvaluationGenerator.

    Args:
      row: A dataset row. Must provide non-empty `starting_prompt` and
        `conversation_plan` values; session inputs are read from it via
        `_get_session_inputs`.
      agent: The local agent to converse with.
      config: Optional user simulator settings. `model_name`,
        `model_configuration` and `max_turn` are forwarded to the ADK
        simulator config when set.

    Returns:
      One dict per simulated turn with the keys `turn_index`, `turn_id` and
      `events`.

    Raises:
      ImportError: If the ADK evaluation components could not be imported.
      ValueError: If `starting_prompt` or `conversation_plan` is missing or
        empty in `row`.
    """
    # The ADK names are set to None at import time when google-adk is not
    # installed. Fail with an actionable message instead of the opaque
    # "'NoneType' object is not callable" raised further down.
    adk_dependencies = (
        ConversationScenario,
        LlmBackedUserSimulator,
        LlmBackedUserSimulatorConfig,
        EvaluationGenerator,
        ADK_SessionInput,
    )
    if any(dependency is None for dependency in adk_dependencies):
        raise ImportError(
            "User simulation requires ADK. Please install it using"
            " 'pip install google-adk'"
        )

    starting_prompt = row.get("starting_prompt")
    conversation_plan = row.get("conversation_plan")

    if not starting_prompt or not conversation_plan:
        raise ValueError(
            "User simulation requires 'starting_prompt' and 'conversation_plan'"
            " columns."
        )

    scenario = ConversationScenario(
        starting_prompt=starting_prompt, conversation_plan=conversation_plan
    )

    # Only forward options that were set so the ADK defaults apply otherwise.
    user_simulator_kwargs: dict[str, Any] = {}
    if config:
        if config.model_name:
            user_simulator_kwargs["model"] = config.model_name
        if config.model_configuration is not None:
            user_simulator_kwargs["model_configuration"] = (
                config.model_configuration.model_dump(exclude_none=True)
            )
        if config.max_turn is not None:
            user_simulator_kwargs["max_allowed_invocations"] = config.max_turn

    user_simulator_config = LlmBackedUserSimulatorConfig(**user_simulator_kwargs)
    user_simulator = LlmBackedUserSimulator(
        conversation_scenario=scenario, config=user_simulator_config
    )

    initial_session = _get_session_inputs(row)

    invocations = await EvaluationGenerator._generate_inferences_from_root_agent(  # pylint: disable=protected-access
        root_agent=agent,
        user_simulator=user_simulator,
        reset_func=getattr(agent, "reset_data", None),
        initial_session=ADK_SessionInput(
            app_name=initial_session.app_name or "user_simulation_app",
            user_id=initial_session.user_id or "user_simulation_default_user",
            state=initial_session.state or {},
        ),
    )

    return [
        {
            "turn_index": turn_index,
            # Fall back to a random id when the invocation does not carry one.
            "turn_id": invocation.invocation_id or str(uuid.uuid4()),
            "events": _invocation_to_events(invocation),
        }
        for turn_index, invocation in enumerate(invocations)
    ]
@@ -956,6 +1093,7 @@ def _execute_inference( agent_engine=agent_engine, agent=agent, prompt_dataset=prompt_dataset, + user_simulator_config=user_simulator_config, ) end_time = time.time() logger.info("Agent Run completed in %.2f seconds.", end_time - start_time) @@ -1360,11 +1498,23 @@ def _get_session_inputs(row: pd.Series) -> types.evals.SessionInput: ) +def _is_multi_turn_agent_run( + user_simulator_config: Optional[types.evals.UserSimulatorConfig] = None, + prompt_dataset: pd.DataFrame = None, +) -> bool: + """Checks if the agent run is multi-turn.""" + return ( + user_simulator_config is not None + or "conversation_plan" in prompt_dataset.columns + ) + + def _run_agent_internal( api_client: BaseApiClient, agent_engine: Optional[Union[str, types.AgentEngine]], agent: Optional[LlmAgent], prompt_dataset: pd.DataFrame, + user_simulator_config: Optional[types.evals.UserSimulatorConfig] = None, ) -> pd.DataFrame: """Runs an agent.""" raw_responses = _run_agent( @@ -1372,66 +1522,94 @@ def _run_agent_internal( agent_engine=agent_engine, agent=agent, prompt_dataset=prompt_dataset, + user_simulator_config=user_simulator_config, ) processed_intermediate_events = [] processed_responses = [] + processed_agent_data = [] + for resp_item in raw_responses: intermediate_events_row: list[dict[str, Any]] = [] response_row = None - if isinstance(resp_item, list): - try: - response_row = resp_item[-1]["content"]["parts"][0]["text"] - for intermediate_event in resp_item[:-1]: - intermediate_events_row.append( - { - "event_id": intermediate_event["id"], - "content": intermediate_event["content"], - "creation_timestamp": intermediate_event["timestamp"], - "author": intermediate_event["author"], - } - ) - except Exception as e: # pylint: disable=broad-exception-caught + agent_data_row = None + + if _is_multi_turn_agent_run(user_simulator_config, prompt_dataset): + if isinstance(resp_item, dict) and "error" in resp_item: + response_row = json.dumps(resp_item) + else: + # TODO: Migrate 
single turn agent run result to AgentData. + agent_data_row = types.evals.AgentData(turns=resp_item).model_dump() + else: + if isinstance(resp_item, list): + try: + response_row = resp_item[-1]["content"]["parts"][0]["text"] + for intermediate_event in resp_item[:-1]: + intermediate_events_row.append( + { + "event_id": intermediate_event.get("id"), + "content": intermediate_event.get("content"), + "creation_timestamp": intermediate_event.get( + "timestamp" + ), + "author": intermediate_event.get("author"), + } + ) + except Exception as e: # pylint: disable=broad-exception-caught + error_payload = { + "error": ( + f"Failed to parse agent run response {str(resp_item)} to " + f"agent data: {e}" + ), + } + response_row = json.dumps(error_payload) + elif isinstance(resp_item, dict) and "error" in resp_item: + response_row = json.dumps(resp_item) + else: error_payload = { - "error": ( - f"Failed to parse agent run response {str(resp_item)} to " - f"intermediate events and final response: {e}" - ), + "error": "Unexpected response type from agent run", + "response_type": str(type(resp_item)), + "details": str(resp_item), } response_row = json.dumps(error_payload) - else: - error_payload = { - "error": "Unexpected response type from agent run", - "response_type": str(type(resp_item)), - "details": str(resp_item), - } - response_row = json.dumps(error_payload) processed_intermediate_events.append(intermediate_events_row) processed_responses.append(response_row) - - if len(processed_responses) != len(prompt_dataset) or len( - processed_responses - ) != len(processed_intermediate_events): - raise RuntimeError( - "Critical prompt/response/intermediate_events count mismatch: %d" - " prompts vs %d vs %d responses. This indicates an issue in response" - " collection." 
- % ( - len(prompt_dataset), - len(processed_responses), - len(processed_intermediate_events), + processed_agent_data.append(agent_data_row) + + df_dict: dict[str, Any] = {} + if _is_multi_turn_agent_run(user_simulator_config, prompt_dataset): + df_dict[AGENT_DATA] = processed_agent_data + if len(processed_agent_data) != len(prompt_dataset): + raise RuntimeError( + "Critical prompt/agent_data count mismatch: %d" + " prompts vs %d agent_data. This indicates an issue in response" + " collection." + % ( + len(prompt_dataset), + len(processed_agent_data), + ) + ) + else: + df_dict[_evals_constant.INTERMEDIATE_EVENTS] = processed_intermediate_events + df_dict[_evals_constant.RESPONSE] = processed_responses + if len(processed_responses) != len(prompt_dataset) or len( + processed_responses + ) != len(processed_intermediate_events): + raise RuntimeError( + "Critical prompt/response/intermediate_events count mismatch: %d" + " prompts vs %d vs %d responses. This indicates an issue in response" + " collection." 
+ % ( + len(prompt_dataset), + len(processed_responses), + len(processed_intermediate_events), + ) ) - ) - results_df_responses_only = pd.DataFrame( - { - _evals_constant.INTERMEDIATE_EVENTS: processed_intermediate_events, - _evals_constant.RESPONSE: processed_responses, - } - ) + results_df_raw = pd.DataFrame(df_dict) prompt_dataset_indexed = prompt_dataset.reset_index(drop=True) - results_df_responses_only_indexed = results_df_responses_only.reset_index(drop=True) + results_df_responses_only_indexed = results_df_raw.reset_index(drop=True) results_df = pd.concat( [prompt_dataset_indexed, results_df_responses_only_indexed], axis=1 @@ -1444,6 +1622,7 @@ def _run_agent( agent_engine: Optional[Union[str, types.AgentEngine]], agent: Optional[LlmAgent], prompt_dataset: pd.DataFrame, + user_simulator_config: Optional[types.evals.UserSimulatorConfig] = None, ) -> list[ Union[ list[dict[str, Any]], @@ -1459,6 +1638,7 @@ def _run_agent( prompt_dataset=prompt_dataset, progress_desc="Agent Run", gemini_config=None, + user_simulator_config=None, inference_fn=_execute_agent_run_with_retry, ) elif agent: @@ -1468,6 +1648,7 @@ def _run_agent( prompt_dataset=prompt_dataset, progress_desc="Local Agent Run", gemini_config=None, + user_simulator_config=user_simulator_config, inference_fn=_execute_local_agent_run_with_retry, ) else: @@ -1535,10 +1716,13 @@ def _execute_local_agent_run_with_retry( contents: Union[genai_types.ContentListUnion, genai_types.ContentListUnionDict], agent: LlmAgent, max_retries: int = 3, + user_simulator_config: Optional[types.evals.UserSimulatorConfig] = None, ) -> Union[list[dict[str, Any]], dict[str, Any]]: """Executes agent run locally for a single prompt synchronously.""" return asyncio.run( - _execute_local_agent_run_with_retry_async(row, contents, agent, max_retries) + _execute_local_agent_run_with_retry_async( + row, contents, agent, max_retries, user_simulator_config + ) ) @@ -1547,8 +1731,18 @@ async def _execute_local_agent_run_with_retry_async( 
contents: Union[genai_types.ContentListUnion, genai_types.ContentListUnionDict], agent: LlmAgent, max_retries: int = 3, + user_simulator_config: Optional[types.evals.UserSimulatorConfig] = None, ) -> Union[list[dict[str, Any]], dict[str, Any]]: """Executes agent run locally for a single prompt asynchronously.""" + + # Multi-turn agent scraping with user simulation. + if user_simulator_config or "conversation_plan" in row: + try: + return await _run_adk_user_simulation(row, agent, user_simulator_config) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error("Multi-turn agent run with user simulation failed: %s", e) + return {"error": f"Multi-turn agent run with user simulation failed: {e}"} + session_inputs = _get_session_inputs(row) user_id = session_inputs.user_id session_id = str(uuid.uuid4()) diff --git a/vertexai/_genai/_evals_constant.py b/vertexai/_genai/_evals_constant.py index 6fc27d94e0..aa4b188e5f 100644 --- a/vertexai/_genai/_evals_constant.py +++ b/vertexai/_genai/_evals_constant.py @@ -53,6 +53,7 @@ CONTENT = "content" PARTS = "parts" USER_AUTHOR = "user" +AGENT_DATA = "agent_data" COMMON_DATASET_COLUMNS = frozenset( { diff --git a/vertexai/_genai/evals.py b/vertexai/_genai/evals.py index 03363dd1c6..5a4eab1de5 100644 --- a/vertexai/_genai/evals.py +++ b/vertexai/_genai/evals.py @@ -1309,6 +1309,7 @@ def run_inference( prompt_template=config.prompt_template, location=location, config=config.generate_content_config, + user_simulator_config=getattr(config, "user_simulator_config", None), ) def evaluate(