@@ -51,6 +51,12 @@ class ReplayVerificationError(Exception):
5151 pass
5252
5353
54+ class ReplayConfigError (Exception ):
55+ """Exception raised when replay configuration is invalid or missing."""
56+
57+ pass
58+
59+
5460class _InvocationReplayState (BaseModel ):
5561 """Per-invocation replay state to isolate concurrent runs."""
5662
@@ -93,7 +99,7 @@ async def before_model_callback(
9399 return None
94100
95101 if (state := self ._get_invocation_state (callback_context )) is None :
96- raise ValueError (
102+ raise ReplayConfigError (
97103 "Replay state not initialized. Ensure before_run created it."
98104 )
99105
@@ -122,7 +128,7 @@ async def before_tool_callback(
122128 return None
123129
124130 if (state := self ._get_invocation_state (tool_context )) is None :
125- raise ValueError (
131+ raise ReplayConfigError (
126132 "Replay state not initialized. Ensure before_run created it."
127133 )
128134
@@ -188,20 +194,24 @@ def _load_invocation_state(
188194 msg_index = config .get ("user_message_index" )
189195
190196 if not case_dir or msg_index is None :
191- raise ValueError ("Replay parameters are missing from session state" )
197+ raise ReplayConfigError (
198+ "Replay parameters are missing from session state"
199+ )
192200
193201 # Load recordings
194202 recordings_file = Path (case_dir ) / "generated-recordings.yaml"
195203
196204 if not recordings_file .exists ():
197- raise ValueError (f"Recordings file not found: { recordings_file } " )
205+ raise ReplayConfigError (f"Recordings file not found: { recordings_file } " )
198206
199207 try :
200208 with recordings_file .open ("r" , encoding = "utf-8" ) as f :
201209 recordings_data = yaml .safe_load (f )
202210 recordings = Recordings .model_validate (recordings_data )
203211 except Exception as e :
204- raise ValueError (f"Failed to load recordings from { recordings_file } : { e } " )
212+ raise ReplayConfigError (
213+ f"Failed to load recordings from { recordings_file } : { e } "
214+ ) from e
205215
206216 # Load and store invocation state
207217 state = _InvocationReplayState (
@@ -320,62 +330,28 @@ def _verify_llm_request_match(
320330 agent_index : int ,
321331 ) -> None :
322332 """Verify that the current LLM request exactly matches the recorded one."""
323- self ._verify_config_match (
324- recorded_request , current_request , agent_name , agent_index
325- )
326- handled_fields : set [str ] = {"config" }
327- ignored_fields : set [str ] = {"live_connect_config" }
328- exclude_fields = handled_fields | ignored_fields
329- if not self ._compare_fields (
330- recorded_request , current_request , exclude_fields = exclude_fields
331- ):
332- raise ValueError (
333- f"LLM request mismatch for agent '{ agent_name } ' (index"
334- f" { agent_index } ): "
335- "recorded:"
336- f" { recorded_request .model_dump (exclude_none = True , exclude = exclude_fields )} ,"
337- " current:"
338- f" { current_request .model_dump (exclude_none = True , exclude = exclude_fields )} "
339- )
340-
341- def _compare_fields (
342- self ,
343- obj1 : BaseModel ,
344- obj2 : BaseModel ,
345- * ,
346- exclude_fields : Optional [set [str ]] = None ,
347- ) -> bool :
348- """Compare two Pydantic models excluding specified fields."""
349- exclude_fields = exclude_fields or set ()
350- dict1 = obj1 .model_dump (exclude_none = True , exclude = exclude_fields )
351- dict2 = obj2 .model_dump (exclude_none = True , exclude = exclude_fields )
352- return dict1 == dict2
353-
354- def _verify_config_match (
355- self ,
356- recorded_request : LlmRequest ,
357- current_request : LlmRequest ,
358- agent_name : str ,
359- agent_index : int ,
360- ) -> None :
361- """Verify that the config matches between recorded and current requests."""
362- # Fields to ignore when comparing GenerateContentConfig (denylist approach)
363- ignored_fields : set [str ] = {
364- "http_options" ,
365- "labels" ,
333+ # Comprehensive exclude dict for all fields that can differ between runs
334+ excluded_fields = {
335+ "live_connect_config" : True ,
336+ "config" : { # some config fields can vary per run
337+ "http_options" : True ,
338+ "labels" : True ,
339+ },
366340 }
367341
368- if not self ._compare_fields (
369- recorded_request .config ,
370- current_request .config ,
371- exclude_fields = ignored_fields ,
372- ):
373- raise ValueError (
374- f"Config mismatch for agent '{ agent_name } ' (index { agent_index } ): "
375- "recorded:"
376- f" { recorded_request .config .model_dump (exclude_none = True , exclude = ignored_fields )} ,"
377- " current:"
378- f" { current_request .config .model_dump (exclude_none = True , exclude = ignored_fields )} "
342+ # Compare using model dumps with nested exclude dict
343+ recorded_dict = recorded_request .model_dump (
344+ exclude_none = True , exclude = excluded_fields , exclude_defaults = True
345+ )
346+ current_dict = current_request .model_dump (
347+ exclude_none = True , exclude = excluded_fields , exclude_defaults = True
348+ )
349+
350+ if recorded_dict != current_dict :
351+ raise ReplayVerificationError (
352+ f"""LLM request mismatch for agent '{ agent_name } ' (index { agent_index } ):
353+ recorded: { recorded_dict }
354+ current: { current_dict } """
379355 )
380356
381357 def _verify_tool_call_match (
@@ -389,12 +365,14 @@ def _verify_tool_call_match(
389365 """Verify that the current tool call exactly matches the recorded one."""
390366 if recorded_call .name != tool_name :
391367 raise ReplayVerificationError (
392- f"Tool name mismatch for agent '{ agent_name } ' at index { agent_index } :"
393- f" recorded='{ recorded_call .name } ', current='{ tool_name } '"
368+ f"""Tool name mismatch for agent '{ agent_name } ' at index { agent_index } :
369+ recorded: '{ recorded_call .name } '
370+ current: '{ tool_name } '"""
394371 )
395372
396373 if recorded_call .args != tool_args :
397374 raise ReplayVerificationError (
398- f"Tool args mismatch for agent '{ agent_name } ' at index { agent_index } :"
399- f" recorded={ recorded_call .args } , current={ tool_args } "
375+ f"""Tool args mismatch for agent '{ agent_name } ' at index { agent_index } :
376+ recorded: { recorded_call .args }
377+ current: { tool_args } """
400378 )
0 commit comments