@@ -259,7 +259,7 @@ class TestMyModel(MyModelTestConfig, ModelTesterMixin):
         pass
     """

-    def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=0):
+    def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5):
         torch.manual_seed(0)
         model = self.model_class(**self.get_init_dict())
         model.to(torch_device)
@@ -278,15 +278,8 @@ def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=0):
         )

         with torch.no_grad():
-            image = model(**self.get_dummy_inputs())
-
-            if isinstance(image, dict):
-                image = image.to_tuple()[0]
-
-            new_image = new_model(**self.get_dummy_inputs())
-
-            if isinstance(new_image, dict):
-                new_image = new_image.to_tuple()[0]
+            image = model(**self.get_dummy_inputs(), return_dict=False)[0]
+            new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]

         assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")

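Aside from the tolerance tweak in test_from_save_pretrained (rtol goes from 0 to 5e-5), every hunk in this commit makes the same substitution shown above: instead of unwrapping a dict-like ModelOutput by hand, the tests pass return_dict=False so the forward returns a plain tuple whose first element is the main output. A minimal, self-contained sketch of the equivalence; ToyModel and ToyOutput are illustrative stand-ins, not part of this commit. The old isinstance(output, dict) check works because diffusers' ModelOutput subclasses OrderedDict:

import torch


class ToyOutput(dict):
    # Stand-in for diffusers' ModelOutput; subclassing dict is what made
    # the `isinstance(output, dict)` check in the old test code succeed.
    def to_tuple(self):
        return tuple(self.values())


class ToyModel(torch.nn.Module):
    def forward(self, x, return_dict=True):
        sample = x * 2.0
        if not return_dict:
            return (sample,)  # plain tuple: index with [0]
        return ToyOutput(sample=sample)


model = ToyModel()
x = torch.ones(2)

# Old pattern (removed by this commit): unwrap the output by hand.
out = model(x)
if isinstance(out, dict):
    out = out.to_tuple()[0]

# New pattern: request a tuple up front and take the first element.
new_out = model(x, return_dict=False)[0]

torch.testing.assert_close(out, new_out)  # identical tensors

Besides trimming four lines of boilerplate per call site, the tuple form makes the [0] indexing explicit, which is why the sharded-checkpoint and parallelism hunks below can also drop their extra base_output[0] indexing at the comparison step.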
@@ -308,14 +301,8 @@ def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0):
         new_model.to(torch_device)

         with torch.no_grad():
-            image = model(**self.get_dummy_inputs())
-            if isinstance(image, dict):
-                image = image.to_tuple()[0]
-
-            new_image = new_model(**self.get_dummy_inputs())
-
-            if isinstance(new_image, dict):
-                new_image = new_image.to_tuple()[0]
+            image = model(**self.get_dummy_inputs(), return_dict=False)[0]
+            new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]

         assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")

@@ -343,13 +330,8 @@ def test_determinism(self, atol=1e-5, rtol=0):
         model.eval()

         with torch.no_grad():
-            first = model(**self.get_dummy_inputs())
-            if isinstance(first, dict):
-                first = first.to_tuple()[0]
-
-            second = model(**self.get_dummy_inputs())
-            if isinstance(second, dict):
-                second = second.to_tuple()[0]
+            first = model(**self.get_dummy_inputs(), return_dict=False)[0]
+            second = model(**self.get_dummy_inputs(), return_dict=False)[0]

         # Filter out NaN values before comparison
         first_flat = first.flatten()
@@ -369,10 +351,7 @@ def test_output(self, expected_output_shape=None):

         inputs_dict = self.get_dummy_inputs()
         with torch.no_grad():
-            output = model(**inputs_dict)
-
-            if isinstance(output, dict):
-                output = output.to_tuple()[0]
+            output = model(**inputs_dict, return_dict=False)[0]

         assert output is not None, "Model output is None"
         assert output[0].shape == expected_output_shape or self.output_shape, (
@@ -501,13 +480,8 @@ def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
             assert param.data.dtype == dtype

         with torch.no_grad():
-            output = model(**self.get_dummy_inputs())
-            if isinstance(output, dict):
-                output = output.to_tuple()[0]
-
-            output_loaded = model_loaded(**self.get_dummy_inputs())
-            if isinstance(output_loaded, dict):
-                output_loaded = output_loaded.to_tuple()[0]
+            output = model(**self.get_dummy_inputs(), return_dict=False)[0]
+            output_loaded = model_loaded(**self.get_dummy_inputs(), return_dict=False)[0]

         assert_tensors_close(output, output_loaded, atol=1e-4, rtol=0, msg=f"Loaded model output differs for {dtype}")

@@ -519,7 +493,7 @@ def test_sharded_checkpoints(self, tmp_path):
         model = self.model_class(**config).eval()
         model = model.to(torch_device)

-        base_output = model(**inputs_dict)
+        base_output = model(**inputs_dict, return_dict=False)[0]

         model_size = compute_module_persistent_sizes(model)[""]
         max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small
@@ -539,10 +513,10 @@ def test_sharded_checkpoints(self, tmp_path):

         torch.manual_seed(0)
         inputs_dict_new = self.get_dummy_inputs()
-        new_output = new_model(**inputs_dict_new)
+        new_output = new_model(**inputs_dict_new, return_dict=False)[0]

         assert_tensors_close(
-            base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match after sharded save/load"
+            base_output, new_output, atol=1e-5, rtol=0, msg="Output should match after sharded save/load"
         )

     @require_accelerator
@@ -553,7 +527,7 @@ def test_sharded_checkpoints_with_variant(self, tmp_path):
         model = self.model_class(**config).eval()
         model = model.to(torch_device)

-        base_output = model(**inputs_dict)
+        base_output = model(**inputs_dict, return_dict=False)[0]

         model_size = compute_module_persistent_sizes(model)[""]
         max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small
@@ -578,10 +552,10 @@ def test_sharded_checkpoints_with_variant(self, tmp_path):

         torch.manual_seed(0)
         inputs_dict_new = self.get_dummy_inputs()
-        new_output = new_model(**inputs_dict_new)
+        new_output = new_model(**inputs_dict_new, return_dict=False)[0]

         assert_tensors_close(
-            base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match after variant sharded save/load"
+            base_output, new_output, atol=1e-5, rtol=0, msg="Output should match after variant sharded save/load"
         )

     def test_sharded_checkpoints_with_parallel_loading(self, tmp_path):
@@ -593,7 +567,7 @@ def test_sharded_checkpoints_with_parallel_loading(self, tmp_path):
             model = self.model_class(**config).eval()
             model = model.to(torch_device)

-            base_output = model(**inputs_dict)
+            base_output = model(**inputs_dict, return_dict=False)[0]

             model_size = compute_module_persistent_sizes(model)[""]
             max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small
@@ -628,10 +602,10 @@ def test_sharded_checkpoints_with_parallel_loading(self, tmp_path):

             torch.manual_seed(0)
             inputs_dict_parallel = self.get_dummy_inputs()
-            output_parallel = model_parallel(**inputs_dict_parallel)
+            output_parallel = model_parallel(**inputs_dict_parallel, return_dict=False)[0]

             assert_tensors_close(
-                base_output[0], output_parallel[0], atol=1e-5, rtol=0, msg="Output should match with parallel loading"
+                base_output, output_parallel, atol=1e-5, rtol=0, msg="Output should match with parallel loading"
             )

         finally:
@@ -652,7 +626,7 @@ def test_model_parallelism(self, tmp_path):
         model = model.to(torch_device)

         torch.manual_seed(0)
-        base_output = model(**inputs_dict)
+        base_output = model(**inputs_dict, return_dict=False)[0]

         model_size = compute_module_sizes(model)[""]
         max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
@@ -668,8 +642,8 @@ def test_model_parallelism(self, tmp_path):
             check_device_map_is_respected(new_model, new_model.hf_device_map)

         torch.manual_seed(0)
-        new_output = new_model(**inputs_dict)
+        new_output = new_model(**inputs_dict, return_dict=False)[0]

         assert_tensors_close(
-            base_output[0], new_output[0], atol=1e-5, rtol=0, msg="Output should match with model parallelism"
+            base_output, new_output, atol=1e-5, rtol=0, msg="Output should match with model parallelism"
         )
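
All comparisons above funnel through assert_tensors_close, a helper defined elsewhere in the file and not shown in this diff. Judging only from the call sites, it is presumably a thin wrapper around torch.testing.assert_close; the following reconstruction is a hypothetical sketch, not the actual helper:

import torch


def assert_tensors_close(actual, expected, atol, rtol, msg=""):
    # Hypothetical sketch: delegate to torch.testing.assert_close, which
    # raises AssertionError with a detailed mismatch report. `msg` may be
    # passed as a callable that receives the generated report, so the
    # caller's context can be prepended to it.
    torch.testing.assert_close(
        actual, expected, atol=atol, rtol=rtol, msg=lambda generated: f"{msg}\n{generated}"
    )


# Example: passes silently for equal tensors, raises with both messages otherwise.
assert_tensors_close(torch.zeros(3), torch.zeros(3), atol=1e-5, rtol=0, msg="sanity check")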