Skip to content

Commit e12b3a6

Browse files
CPython developersyouknowone
authored andcommitted
Update doctest from CPython 3.10.5
1 parent c87c8dc commit e12b3a6

File tree

1 file changed

+43
-19
lines changed

1 file changed

+43
-19
lines changed

Lib/doctest.py

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -211,17 +211,25 @@ def _normalize_module(module, depth=2):
211211
else:
212212
raise TypeError("Expected a module, string, or None")
213213

214+
def _newline_convert(data):
215+
# The IO module provides a handy decoder for universal newline conversion
216+
return IncrementalNewlineDecoder(None, True).decode(data, True)
217+
214218
def _load_testfile(filename, package, module_relative, encoding):
215219
if module_relative:
216220
package = _normalize_module(package, 3)
217221
filename = _module_relative_path(package, filename)
218-
if getattr(package, '__loader__', None) is not None:
219-
if hasattr(package.__loader__, 'get_data'):
220-
file_contents = package.__loader__.get_data(filename)
221-
file_contents = file_contents.decode(encoding)
222-
# get_data() opens files as 'rb', so one must do the equivalent
223-
# conversion as universal newlines would do.
224-
return file_contents.replace(os.linesep, '\n'), filename
222+
if (loader := getattr(package, '__loader__', None)) is None:
223+
try:
224+
loader = package.__spec__.loader
225+
except AttributeError:
226+
pass
227+
if hasattr(loader, 'get_data'):
228+
file_contents = loader.get_data(filename)
229+
file_contents = file_contents.decode(encoding)
230+
# get_data() opens files as 'rb', so one must do the equivalent
231+
# conversion as universal newlines would do.
232+
return _newline_convert(file_contents), filename
225233
with open(filename, encoding=encoding) as f:
226234
return f.read(), filename
227235

@@ -965,6 +973,17 @@ def _from_module(self, module, object):
965973
else:
966974
raise ValueError("object must be a class or function")
967975

976+
def _is_routine(self, obj):
977+
"""
978+
Safely unwrap objects and determine if they are functions.
979+
"""
980+
maybe_routine = obj
981+
try:
982+
maybe_routine = inspect.unwrap(maybe_routine)
983+
except ValueError:
984+
pass
985+
return inspect.isroutine(maybe_routine)
986+
968987
def _find(self, tests, obj, name, module, source_lines, globs, seen):
969988
"""
970989
Find tests for the given object and any contained objects, and
@@ -987,9 +1006,9 @@ def _find(self, tests, obj, name, module, source_lines, globs, seen):
9871006
if inspect.ismodule(obj) and self._recurse:
9881007
for valname, val in obj.__dict__.items():
9891008
valname = '%s.%s' % (name, valname)
1009+
9901010
# Recurse to functions & classes.
991-
if ((inspect.isroutine(inspect.unwrap(val))
992-
or inspect.isclass(val)) and
1011+
if ((self._is_routine(val) or inspect.isclass(val)) and
9931012
self._from_module(module, val)):
9941013
self._find(tests, val, valname, module, source_lines,
9951014
globs, seen)
@@ -1015,10 +1034,8 @@ def _find(self, tests, obj, name, module, source_lines, globs, seen):
10151034
if inspect.isclass(obj) and self._recurse:
10161035
for valname, val in obj.__dict__.items():
10171036
# Special handling for staticmethod/classmethod.
1018-
if isinstance(val, staticmethod):
1019-
val = getattr(obj, valname)
1020-
if isinstance(val, classmethod):
1021-
val = getattr(obj, valname).__func__
1037+
if isinstance(val, (staticmethod, classmethod)):
1038+
val = val.__func__
10221039

10231040
# Recurse to methods, properties, and nested classes.
10241041
if ((inspect.isroutine(val) or inspect.isclass(val) or
@@ -1068,19 +1085,21 @@ def _get_test(self, obj, name, module, globs, source_lines):
10681085

10691086
def _find_lineno(self, obj, source_lines):
10701087
"""
1071-
Return a line number of the given object's docstring. Note:
1072-
this method assumes that the object has a docstring.
1088+
Return a line number of the given object's docstring.
1089+
1090+
Returns `None` if the given object does not have a docstring.
10731091
"""
10741092
lineno = None
1093+
docstring = getattr(obj, '__doc__', None)
10751094

10761095
# Find the line number for modules.
1077-
if inspect.ismodule(obj):
1096+
if inspect.ismodule(obj) and docstring is not None:
10781097
lineno = 0
10791098

10801099
# Find the line number for classes.
10811100
# Note: this could be fooled if a class is defined multiple
10821101
# times in a single file.
1083-
if inspect.isclass(obj):
1102+
if inspect.isclass(obj) and docstring is not None:
10841103
if source_lines is None:
10851104
return None
10861105
pat = re.compile(r'^\s*class\s*%s\b' %
@@ -1092,7 +1111,9 @@ def _find_lineno(self, obj, source_lines):
10921111

10931112
# Find the line number for functions & methods.
10941113
if inspect.ismethod(obj): obj = obj.__func__
1095-
if inspect.isfunction(obj): obj = obj.__code__
1114+
if inspect.isfunction(obj) and getattr(obj, '__doc__', None):
1115+
# We don't use `docstring` var here, because `obj` can be changed.
1116+
obj = obj.__code__
10961117
if inspect.istraceback(obj): obj = obj.tb_frame
10971118
if inspect.isframe(obj): obj = obj.f_code
10981119
if inspect.iscode(obj):
@@ -1327,7 +1348,7 @@ def __run(self, test, compileflags, out):
13271348
try:
13281349
# Don't blink! This is where the user's code gets run.
13291350
exec(compile(example.source, filename, "single",
1330-
compileflags, 1), test.globs)
1351+
compileflags, True), test.globs)
13311352
self.debugger.set_continue() # ==== Example Finished ====
13321353
exception = None
13331354
except KeyboardInterrupt:
@@ -2154,6 +2175,7 @@ def __init__(self, test, optionflags=0, setUp=None, tearDown=None,
21542175
unittest.TestCase.__init__(self)
21552176
self._dt_optionflags = optionflags
21562177
self._dt_checker = checker
2178+
self._dt_globs = test.globs.copy()
21572179
self._dt_test = test
21582180
self._dt_setUp = setUp
21592181
self._dt_tearDown = tearDown
@@ -2170,7 +2192,9 @@ def tearDown(self):
21702192
if self._dt_tearDown is not None:
21712193
self._dt_tearDown(test)
21722194

2195+
# restore the original globs
21732196
test.globs.clear()
2197+
test.globs.update(self._dt_globs)
21742198

21752199
def runTest(self):
21762200
test = self._dt_test

0 commit comments

Comments
 (0)