@@ -102,7 +102,7 @@ def _test():
102102import sys
103103import traceback
104104import unittest
105- from io import StringIO
105+ from io import StringIO # XXX: RUSTPYTHON; , IncrementalNewlineDecoder
106106from collections import namedtuple
107107
108108TestResults = namedtuple ('TestResults' , 'failed attempted' )
@@ -211,17 +211,28 @@ def _normalize_module(module, depth=2):
211211 else :
212212 raise TypeError ("Expected a module, string, or None" )
213213
214+ def _newline_convert (data ):
215+ # The IO module provides a handy decoder for universal newline conversion
216+ return IncrementalNewlineDecoder (None , True ).decode (data , True )
217+
214218def _load_testfile (filename , package , module_relative , encoding ):
215219 if module_relative :
216220 package = _normalize_module (package , 3 )
217221 filename = _module_relative_path (package , filename )
218- if getattr (package , '__loader__' , None ) is not None :
219- if hasattr (package .__loader__ , 'get_data' ):
220- file_contents = package .__loader__ .get_data (filename )
221- file_contents = file_contents .decode (encoding )
222- # get_data() opens files as 'rb', so one must do the equivalent
223- # conversion as universal newlines would do.
224- return file_contents .replace (os .linesep , '\n ' ), filename
222+ if (loader := getattr (package , '__loader__' , None )) is None :
223+ try :
224+ loader = package .__spec__ .loader
225+ except AttributeError :
226+ pass
227+ if hasattr (loader , 'get_data' ):
228+ file_contents = loader .get_data (filename )
229+ file_contents = file_contents .decode (encoding )
230+ # get_data() opens files as 'rb', so one must do the equivalent
231+ # conversion as universal newlines would do.
232+
233+ # TODO: RUSTPYTHON; use _newline_convert once io.IncrementalNewlineDecoder is implemented
234+ return file_contents .replace (os .linesep , '\n ' ), filename
235+ # return _newline_convert(file_contents), filename
225236 with open (filename , encoding = encoding ) as f :
226237 return f .read (), filename
227238
@@ -965,6 +976,17 @@ def _from_module(self, module, object):
965976 else :
966977 raise ValueError ("object must be a class or function" )
967978
979+ def _is_routine (self , obj ):
980+ """
981+ Safely unwrap objects and determine if they are functions.
982+ """
983+ maybe_routine = obj
984+ try :
985+ maybe_routine = inspect .unwrap (maybe_routine )
986+ except ValueError :
987+ pass
988+ return inspect .isroutine (maybe_routine )
989+
968990 def _find (self , tests , obj , name , module , source_lines , globs , seen ):
969991 """
970992 Find tests for the given object and any contained objects, and
@@ -987,9 +1009,9 @@ def _find(self, tests, obj, name, module, source_lines, globs, seen):
9871009 if inspect .ismodule (obj ) and self ._recurse :
9881010 for valname , val in obj .__dict__ .items ():
9891011 valname = '%s.%s' % (name , valname )
1012+
9901013 # Recurse to functions & classes.
991- if ((inspect .isroutine (inspect .unwrap (val ))
992- or inspect .isclass (val )) and
1014+ if ((self ._is_routine (val ) or inspect .isclass (val )) and
9931015 self ._from_module (module , val )):
9941016 self ._find (tests , val , valname , module , source_lines ,
9951017 globs , seen )
@@ -1015,10 +1037,8 @@ def _find(self, tests, obj, name, module, source_lines, globs, seen):
10151037 if inspect .isclass (obj ) and self ._recurse :
10161038 for valname , val in obj .__dict__ .items ():
10171039 # Special handling for staticmethod/classmethod.
1018- if isinstance (val , staticmethod ):
1019- val = getattr (obj , valname )
1020- if isinstance (val , classmethod ):
1021- val = getattr (obj , valname ).__func__
1040+ if isinstance (val , (staticmethod , classmethod )):
1041+ val = val .__func__
10221042
10231043 # Recurse to methods, properties, and nested classes.
10241044 if ((inspect .isroutine (val ) or inspect .isclass (val ) or
@@ -1068,19 +1088,21 @@ def _get_test(self, obj, name, module, globs, source_lines):
10681088
10691089 def _find_lineno (self , obj , source_lines ):
10701090 """
1071- Return a line number of the given object's docstring. Note:
1072- this method assumes that the object has a docstring.
1091+ Return a line number of the given object's docstring.
1092+
1093+ Returns `None` if the given object does not have a docstring.
10731094 """
10741095 lineno = None
1096+ docstring = getattr (obj , '__doc__' , None )
10751097
10761098 # Find the line number for modules.
1077- if inspect .ismodule (obj ):
1099+ if inspect .ismodule (obj ) and docstring is not None :
10781100 lineno = 0
10791101
10801102 # Find the line number for classes.
10811103 # Note: this could be fooled if a class is defined multiple
10821104 # times in a single file.
1083- if inspect .isclass (obj ):
1105+ if inspect .isclass (obj ) and docstring is not None :
10841106 if source_lines is None :
10851107 return None
10861108 pat = re .compile (r'^\s*class\s*%s\b' %
@@ -1092,7 +1114,9 @@ def _find_lineno(self, obj, source_lines):
10921114
10931115 # Find the line number for functions & methods.
10941116 if inspect .ismethod (obj ): obj = obj .__func__
1095- if inspect .isfunction (obj ): obj = obj .__code__
1117+ if inspect .isfunction (obj ) and getattr (obj , '__doc__' , None ):
1118+ # We don't use `docstring` var here, because `obj` can be changed.
1119+ obj = obj .__code__
10961120 if inspect .istraceback (obj ): obj = obj .tb_frame
10971121 if inspect .isframe (obj ): obj = obj .f_code
10981122 if inspect .iscode (obj ):
@@ -1327,7 +1351,7 @@ def __run(self, test, compileflags, out):
13271351 try :
13281352 # Don't blink! This is where the user's code gets run.
13291353 exec (compile (example .source , filename , "single" ,
1330- compileflags , 1 ), test .globs )
1354+ compileflags , True ), test .globs )
13311355 self .debugger .set_continue () # ==== Example Finished ====
13321356 exception = None
13331357 except KeyboardInterrupt :
@@ -2154,6 +2178,7 @@ def __init__(self, test, optionflags=0, setUp=None, tearDown=None,
21542178 unittest .TestCase .__init__ (self )
21552179 self ._dt_optionflags = optionflags
21562180 self ._dt_checker = checker
2181+ self ._dt_globs = test .globs .copy ()
21572182 self ._dt_test = test
21582183 self ._dt_setUp = setUp
21592184 self ._dt_tearDown = tearDown
@@ -2170,7 +2195,9 @@ def tearDown(self):
21702195 if self ._dt_tearDown is not None :
21712196 self ._dt_tearDown (test )
21722197
2198+ # restore the original globs
21732199 test .globs .clear ()
2200+ test .globs .update (self ._dt_globs )
21742201
21752202 def runTest (self ):
21762203 test = self ._dt_test
0 commit comments