11"""Unit tests for contextlib.py, and other context managers."""
22
33import io
4+ import os
45import sys
56import tempfile
67import threading
8+ import traceback
79import unittest
810from contextlib import * # Tests __all__
911from test import support
@@ -86,6 +88,56 @@ def woohoo():
8688 raise ZeroDivisionError ()
8789 self .assertEqual (state , [1 , 42 , 999 ])
8890
91+ def test_contextmanager_traceback (self ):
92+ @contextmanager
93+ def f ():
94+ yield
95+
96+ try :
97+ with f ():
98+ 1 / 0
99+ except ZeroDivisionError as e :
100+ frames = traceback .extract_tb (e .__traceback__ )
101+
102+ self .assertEqual (len (frames ), 1 )
103+ self .assertEqual (frames [0 ].name , 'test_contextmanager_traceback' )
104+ self .assertEqual (frames [0 ].line , '1/0' )
105+
106+ # Repeat with RuntimeError (which goes through a different code path)
107+ class RuntimeErrorSubclass (RuntimeError ):
108+ pass
109+
110+ try :
111+ with f ():
112+ raise RuntimeErrorSubclass (42 )
113+ except RuntimeErrorSubclass as e :
114+ frames = traceback .extract_tb (e .__traceback__ )
115+
116+ self .assertEqual (len (frames ), 1 )
117+ self .assertEqual (frames [0 ].name , 'test_contextmanager_traceback' )
118+ self .assertEqual (frames [0 ].line , 'raise RuntimeErrorSubclass(42)' )
119+
120+ class StopIterationSubclass (StopIteration ):
121+ pass
122+
123+ for stop_exc in (
124+ StopIteration ('spam' ),
125+ StopIterationSubclass ('spam' ),
126+ ):
127+ with self .subTest (type = type (stop_exc )):
128+ try :
129+ with f ():
130+ raise stop_exc
131+ except type (stop_exc ) as e :
132+ self .assertIs (e , stop_exc )
133+ frames = traceback .extract_tb (e .__traceback__ )
134+ else :
135+ self .fail (f'{ stop_exc } was suppressed' )
136+
137+ self .assertEqual (len (frames ), 1 )
138+ self .assertEqual (frames [0 ].name , 'test_contextmanager_traceback' )
139+ self .assertEqual (frames [0 ].line , 'raise stop_exc' )
140+
89141 def test_contextmanager_no_reraise (self ):
90142 @contextmanager
91143 def whee ():
@@ -126,19 +178,22 @@ def woohoo():
126178 self .assertEqual (state , [1 , 42 , 999 ])
127179
128180 def test_contextmanager_except_stopiter (self ):
129- stop_exc = StopIteration ('spam' )
130181 @contextmanager
131182 def woohoo ():
132183 yield
133- try :
134- with self .assertWarnsRegex (DeprecationWarning ,
135- "StopIteration" ):
136- with woohoo ():
137- raise stop_exc
138- except Exception as ex :
139- self .assertIs (ex , stop_exc )
140- else :
141- self .fail ('StopIteration was suppressed' )
184+
185+ class StopIterationSubclass (StopIteration ):
186+ pass
187+
188+ for stop_exc in (StopIteration ('spam' ), StopIterationSubclass ('spam' )):
189+ with self .subTest (type = type (stop_exc )):
190+ try :
191+ with woohoo ():
192+ raise stop_exc
193+ except Exception as ex :
194+ self .assertIs (ex , stop_exc )
195+ else :
196+ self .fail (f'{ stop_exc } was suppressed' )
142197
143198 # TODO: RUSTPYTHON
144199 @unittest .expectedFailure
@@ -230,6 +285,8 @@ class A:
230285 def woohoo (a , b ):
231286 a = weakref .ref (a )
232287 b = weakref .ref (b )
288+ # Allow test to work with a non-refcounted GC
289+ support .gc_collect ()
233290 self .assertIsNone (a ())
234291 self .assertIsNone (b ())
235292 yield
@@ -318,13 +375,13 @@ def testWithOpen(self):
318375 tfn = tempfile .mktemp ()
319376 try :
320377 f = None
321- with open (tfn , "w" ) as f :
378+ with open (tfn , "w" , encoding = "utf-8" ) as f :
322379 self .assertFalse (f .closed )
323380 f .write ("Booh\n " )
324381 self .assertTrue (f .closed )
325382 f = None
326383 with self .assertRaises (ZeroDivisionError ):
327- with open (tfn , "r" ) as f :
384+ with open (tfn , "r" , encoding = "utf-8" ) as f :
328385 self .assertFalse (f .closed )
329386 self .assertEqual (f .read (), "Booh\n " )
330387 1 / 0
@@ -493,7 +550,7 @@ def __unter__(self):
493550 def __exit__ (self , * exc ):
494551 pass
495552
496- with self .assertRaises ( AttributeError ):
553+ with self .assertRaisesRegex ( TypeError , 'the context manager' ):
497554 with mycontext ():
498555 pass
499556
@@ -505,7 +562,7 @@ def __enter__(self):
505562 def __uxit__ (self , * exc ):
506563 pass
507564
508- with self .assertRaises ( AttributeError ):
565+ with self .assertRaisesRegex ( TypeError , 'the context manager.*__exit__' ):
509566 with mycontext ():
510567 pass
511568
@@ -608,9 +665,9 @@ def _exit(*args, **kwds):
608665 stack .callback (arg = 1 )
609666 with self .assertRaises (TypeError ):
610667 self .exit_stack .callback (arg = 2 )
611- with self .assertWarns ( DeprecationWarning ):
668+ with self .assertRaises ( TypeError ):
612669 stack .callback (callback = _exit , arg = 3 )
613- self .assertEqual (result , [((), { 'arg' : 3 }) ])
670+ self .assertEqual (result , [])
614671
615672 def test_push (self ):
616673 exc_raised = ZeroDivisionError
@@ -665,6 +722,25 @@ def _exit():
665722 result .append (2 )
666723 self .assertEqual (result , [1 , 2 , 3 , 4 ])
667724
725+ def test_enter_context_errors (self ):
726+ class LacksEnterAndExit :
727+ pass
728+ class LacksEnter :
729+ def __exit__ (self , * exc_info ):
730+ pass
731+ class LacksExit :
732+ def __enter__ (self ):
733+ pass
734+
735+ with self .exit_stack () as stack :
736+ with self .assertRaisesRegex (TypeError , 'the context manager' ):
737+ stack .enter_context (LacksEnterAndExit ())
738+ with self .assertRaisesRegex (TypeError , 'the context manager' ):
739+ stack .enter_context (LacksEnter ())
740+ with self .assertRaisesRegex (TypeError , 'the context manager' ):
741+ stack .enter_context (LacksExit ())
742+ self .assertFalse (stack ._exit_callbacks )
743+
668744 def test_close (self ):
669745 result = []
670746 with self .exit_stack () as stack :
@@ -700,6 +776,38 @@ def test_exit_suppress(self):
700776 stack .push (lambda * exc : True )
701777 1 / 0
702778
779+ def test_exit_exception_traceback (self ):
780+ # This test captures the current behavior of ExitStack so that we know
781+ # if we ever unintendedly change it. It is not a statement of what the
782+ # desired behavior is (for instance, we may want to remove some of the
783+ # internal contextlib frames).
784+
785+ def raise_exc (exc ):
786+ raise exc
787+
788+ try :
789+ with self .exit_stack () as stack :
790+ stack .callback (raise_exc , ValueError )
791+ 1 / 0
792+ except ValueError as e :
793+ exc = e
794+
795+ self .assertIsInstance (exc , ValueError )
796+ ve_frames = traceback .extract_tb (exc .__traceback__ )
797+ expected = \
798+ [('test_exit_exception_traceback' , 'with self.exit_stack() as stack:' )] + \
799+ self .callback_error_internal_frames + \
800+ [('_exit_wrapper' , 'callback(*args, **kwds)' ),
801+ ('raise_exc' , 'raise exc' )]
802+
803+ self .assertEqual (
804+ [(f .name , f .line ) for f in ve_frames ], expected )
805+
806+ self .assertIsInstance (exc .__context__ , ZeroDivisionError )
807+ zde_frames = traceback .extract_tb (exc .__context__ .__traceback__ )
808+ self .assertEqual ([(f .name , f .line ) for f in zde_frames ],
809+ [('test_exit_exception_traceback' , '1/0' )])
810+
703811 def test_exit_exception_chaining_reference (self ):
704812 # Sanity check to make sure that ExitStack chaining matches
705813 # actual nested with statements
@@ -781,6 +889,40 @@ def suppress_exc(*exc_details):
781889 self .assertIsInstance (inner_exc , ValueError )
782890 self .assertIsInstance (inner_exc .__context__ , ZeroDivisionError )
783891
892+ def test_exit_exception_explicit_none_context (self ):
893+ # Ensure ExitStack chaining matches actual nested `with` statements
894+ # regarding explicit __context__ = None.
895+
896+ class MyException (Exception ):
897+ pass
898+
899+ @contextmanager
900+ def my_cm ():
901+ try :
902+ yield
903+ except BaseException :
904+ exc = MyException ()
905+ try :
906+ raise exc
907+ finally :
908+ exc .__context__ = None
909+
910+ @contextmanager
911+ def my_cm_with_exit_stack ():
912+ with self .exit_stack () as stack :
913+ stack .enter_context (my_cm ())
914+ yield stack
915+
916+ for cm in (my_cm , my_cm_with_exit_stack ):
917+ with self .subTest ():
918+ try :
919+ with cm ():
920+ raise IndexError ()
921+ except MyException as exc :
922+ self .assertIsNone (exc .__context__ )
923+ else :
924+ self .fail ("Expected IndexError, but no exception was raised" )
925+
784926 def test_exit_exception_non_suppressing (self ):
785927 # http://bugs.python.org/issue19092
786928 def raise_exc (exc ):
@@ -896,9 +1038,11 @@ def test_excessive_nesting(self):
8961038 def test_instance_bypass (self ):
8971039 class Example (object ): pass
8981040 cm = Example ()
1041+ cm .__enter__ = object ()
8991042 cm .__exit__ = object ()
9001043 stack = self .exit_stack ()
901- self .assertRaises (AttributeError , stack .enter_context , cm )
1044+ with self .assertRaisesRegex (TypeError , 'the context manager' ):
1045+ stack .enter_context (cm )
9021046 stack .push (cm )
9031047 self .assertIs (stack ._exit_callbacks [- 1 ][1 ], cm )
9041048
@@ -939,6 +1083,10 @@ def first():
9391083
9401084class TestExitStack (TestBaseExitStack , unittest .TestCase ):
9411085 exit_stack = ExitStack
1086+ callback_error_internal_frames = [
1087+ ('__exit__' , 'raise exc_details[1]' ),
1088+ ('__exit__' , 'if cb(*exc_details):' ),
1089+ ]
9421090
9431091
9441092class TestRedirectStream :
@@ -1064,5 +1212,53 @@ def test_cm_is_reentrant(self):
10641212 1 / 0
10651213 self .assertTrue (outer_continued )
10661214
1215+
1216+ class TestChdir (unittest .TestCase ):
1217+ def make_relative_path (self , * parts ):
1218+ return os .path .join (
1219+ os .path .dirname (os .path .realpath (__file__ )),
1220+ * parts ,
1221+ )
1222+
1223+ def test_simple (self ):
1224+ old_cwd = os .getcwd ()
1225+ target = self .make_relative_path ('data' )
1226+ self .assertNotEqual (old_cwd , target )
1227+
1228+ with chdir (target ):
1229+ self .assertEqual (os .getcwd (), target )
1230+ self .assertEqual (os .getcwd (), old_cwd )
1231+
1232+ def test_reentrant (self ):
1233+ old_cwd = os .getcwd ()
1234+ target1 = self .make_relative_path ('data' )
1235+ target2 = self .make_relative_path ('ziptestdata' )
1236+ self .assertNotIn (old_cwd , (target1 , target2 ))
1237+ chdir1 , chdir2 = chdir (target1 ), chdir (target2 )
1238+
1239+ with chdir1 :
1240+ self .assertEqual (os .getcwd (), target1 )
1241+ with chdir2 :
1242+ self .assertEqual (os .getcwd (), target2 )
1243+ with chdir1 :
1244+ self .assertEqual (os .getcwd (), target1 )
1245+ self .assertEqual (os .getcwd (), target2 )
1246+ self .assertEqual (os .getcwd (), target1 )
1247+ self .assertEqual (os .getcwd (), old_cwd )
1248+
1249+ def test_exception (self ):
1250+ old_cwd = os .getcwd ()
1251+ target = self .make_relative_path ('data' )
1252+ self .assertNotEqual (old_cwd , target )
1253+
1254+ try :
1255+ with chdir (target ):
1256+ self .assertEqual (os .getcwd (), target )
1257+ raise RuntimeError ("boom" )
1258+ except RuntimeError as re :
1259+ self .assertEqual (str (re ), "boom" )
1260+ self .assertEqual (os .getcwd (), old_cwd )
1261+
1262+
10671263if __name__ == "__main__" :
10681264 unittest .main ()
0 commit comments