diff --git a/pyproject.toml b/pyproject.toml index c5eb301e..b374a91a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,21 +75,23 @@ module = [ ] ignore_missing_imports = true +[[tool.mypy.overrides]] +module = "fixtures.*" +ignore_missing_imports = true +follow_imports = "skip" + [[tool.mypy.overrides]] module = [ # FIXME(stephenfin): We would like to remove all modules from this list # except tests (we're not sadists) "testtools.assertions", "testtools.compat", - "testtools.content", - "testtools.content_type", "testtools.matchers.*", "testtools.monkey", "testtools.run", "testtools.runtest", "testtools.testcase", "testtools.testresult.*", - "testtools.testsuite", "testtools.twistedsupport.*", "tests.*", ] diff --git a/tests/matchers/test_basic.py b/tests/matchers/test_basic.py index 07aa83c9..8d5c3ad7 100644 --- a/tests/matchers/test_basic.py +++ b/tests/matchers/test_basic.py @@ -4,10 +4,7 @@ from typing import ClassVar from testtools import TestCase -from testtools.compat import ( - _b, - text_repr, -) +from testtools.compat import text_repr from testtools.matchers._basic import ( Contains, DoesNotEndWith, @@ -36,7 +33,7 @@ class Test_BinaryMismatch(TestCase): """Mismatches from binary comparisons need useful describe output""" _long_string = "This is a longish multiline non-ascii string\n\xa7" - _long_b = _b(_long_string) + _long_b = _long_string.encode("utf-8") _long_u = _long_string class CustomRepr: @@ -52,12 +49,12 @@ def test_short_objects(self): self.assertEqual(mismatch.describe(), f"{o1!r} !~ {o2!r}") def test_short_mixed_strings(self): - b, u = _b("\xa7"), "\xa7" + b, u = b"\xa7", "\xa7" mismatch = _BinaryMismatch(b, "!~", u) self.assertEqual(mismatch.describe(), f"{b!r} !~ {u!r}") def test_long_bytes(self): - one_line_b = self._long_b.replace(_b("\n"), _b(" ")) + one_line_b = self._long_b.replace(b"\n", b" ") mismatch = _BinaryMismatch(one_line_b, "!~", self._long_b) self.assertEqual( mismatch.describe(), @@ -249,8 +246,8 @@ def test_describe_non_ascii_unicode(self): ) def test_describe_non_ascii_bytes(self): - string = _b("A\xa7") - suffix = _b("B\xa7") + string = b"A\xa7" + suffix = b"B\xa7" mismatch = DoesNotStartWith(string, suffix) self.assertEqual( f"{string!r} does not start with {suffix!r}.", mismatch.describe() @@ -265,7 +262,7 @@ def test_str(self): self.assertEqual("StartsWith('bar')", str(matcher)) def test_str_with_bytes(self): - b = _b("\xa7") + b = b"\xa7" matcher = StartsWith(b) self.assertEqual(f"StartsWith({b!r})", str(matcher)) @@ -310,8 +307,8 @@ def test_describe_non_ascii_unicode(self): ) def test_describe_non_ascii_bytes(self): - string = _b("A\xa7") - suffix = _b("B\xa7") + string = b"A\xa7" + suffix = b"B\xa7" mismatch = DoesNotEndWith(string, suffix) self.assertEqual( f"{string!r} does not end with {suffix!r}.", mismatch.describe() @@ -326,7 +323,7 @@ def test_str(self): self.assertEqual("EndsWith('bar')", str(matcher)) def test_str_with_bytes(self): - b = _b("\xa7") + b = b"\xa7" matcher = EndsWith(b) self.assertEqual(f"EndsWith({b!r})", str(matcher)) @@ -416,7 +413,7 @@ class TestMatchesRegex(TestCase, TestMatchersInterface): ("MatchesRegex('a|b')", MatchesRegex("a|b")), ("MatchesRegex('a|b', re.M)", MatchesRegex("a|b", re.M)), ("MatchesRegex('a|b', re.I|re.M)", MatchesRegex("a|b", re.I | re.M)), - ("MatchesRegex({!r})".format(_b("\xa7")), MatchesRegex(_b("\xa7"))), + ("MatchesRegex({!r})".format(b"\xa7"), MatchesRegex(b"\xa7")), ("MatchesRegex({!r})".format("\xa7"), MatchesRegex("\xa7")), ] @@ -424,9 +421,9 @@ class TestMatchesRegex(TestCase, TestMatchersInterface): ("'c' does not match /a|b/", "c", MatchesRegex("a|b")), ("'c' does not match /a\\d/", "c", MatchesRegex(r"a\d")), ( - "{!r} does not match /\\s+\\xa7/".format(_b("c")), - _b("c"), - MatchesRegex(_b("\\s+\xa7")), + "{!r} does not match /\\s+\\xa7/".format(b"c"), + b"c", + MatchesRegex(b"\\s+\xa7"), ), ("{!r} does not match /\\s+\\xa7/".format("c"), "c", MatchesRegex("\\s+\xa7")), ] diff --git a/tests/matchers/test_doctest.py b/tests/matchers/test_doctest.py index f413d88c..6e16a41a 100644 --- a/tests/matchers/test_doctest.py +++ b/tests/matchers/test_doctest.py @@ -4,9 +4,6 @@ from typing import ClassVar from testtools import TestCase -from testtools.compat import ( - _b, -) from testtools.matchers._doctest import DocTestMatches from ..helpers import FullStackRunTest @@ -75,7 +72,7 @@ def test_describe_non_ascii_bytes(self): permits arbitrary binary inputs. This is a slightly bogus thing to do, and under Python 3 using bytes objects will reasonably raise an error. """ - header = _b("\x89PNG\r\n\x1a\n...") + header = b"\x89PNG\r\n\x1a\n..." self.assertRaises(TypeError, DocTestMatches, header, doctest.ELLIPSIS) diff --git a/tests/test_compat.py b/tests/test_compat.py index 5397935a..957fa095 100644 --- a/tests/test_compat.py +++ b/tests/test_compat.py @@ -5,12 +5,9 @@ import ast import io import sys -import traceback import testtools from testtools.compat import ( - _b, - reraise, text_repr, unicode_output_stream, ) @@ -45,28 +42,28 @@ def test_no_encoding_becomes_ascii(self): """A stream with no encoding attribute gets ascii/replace strings""" sout = _FakeOutputStream() unicode_output_stream(sout).write(self.uni) - self.assertEqual([_b("pa???n")], sout.writelog) + self.assertEqual([b"pa???n"], sout.writelog) def test_encoding_as_none_becomes_ascii(self): """A stream with encoding value of None gets ascii/replace strings""" sout = _FakeOutputStream() sout.encoding = None unicode_output_stream(sout).write(self.uni) - self.assertEqual([_b("pa???n")], sout.writelog) + self.assertEqual([b"pa???n"], sout.writelog) def test_bogus_encoding_becomes_ascii(self): """A stream with a bogus encoding gets ascii/replace strings""" sout = _FakeOutputStream() sout.encoding = "bogus" unicode_output_stream(sout).write(self.uni) - self.assertEqual([_b("pa???n")], sout.writelog) + self.assertEqual([b"pa???n"], sout.writelog) def test_partial_encoding_replace(self): """A string which can be partly encoded correctly should be""" sout = _FakeOutputStream() sout.encoding = "iso-8859-7" unicode_output_stream(sout).write(self.uni) - self.assertEqual([_b("pa?\xe8?n")], sout.writelog) + self.assertEqual([b"pa?\xe8?n"], sout.writelog) def test_stringio(self): """A StringIO object should maybe get an ascii native str type""" @@ -126,11 +123,11 @@ class TestTextRepr(testtools.TestCase): # Bytes with the high bit set should always be escaped bytes_examples = ( - (_b("\x80"), "'\\x80'", "'''\\\n\\x80'''"), - (_b("\xa0"), "'\\xa0'", "'''\\\n\\xa0'''"), - (_b("\xc0"), "'\\xc0'", "'''\\\n\\xc0'''"), - (_b("\xff"), "'\\xff'", "'''\\\n\\xff'''"), - (_b("\xc2\xa7"), "'\\xc2\\xa7'", "'''\\\n\\xc2\\xa7'''"), + (b"\x80", "'\\x80'", "'''\\\n\\x80'''"), + (b"\xa0", "'\\xa0'", "'''\\\n\\xa0'''"), + (b"\xc0", "'\\xc0'", "'''\\\n\\xc0'''"), + (b"\xff", "'\\xff'", "'''\\\n\\xff'''"), + (b"\xc2\xa7", "'\\xc2\\xa7'", "'''\\\n\\xc2\\xa7'''"), ) # Unicode doesn't escape printable characters as per the Python 3 model @@ -153,12 +150,12 @@ class TestTextRepr(testtools.TestCase): # Unprintable general categories not fully tested: Cc, Cf, Co, Cn, Zs ) - b_prefix = repr(_b(""))[:-2] + b_prefix = repr(b"")[:-2] u_prefix = repr("")[:-2] def test_ascii_examples_oneline_bytes(self): for s, expected, _ in self.ascii_examples: - b = _b(s) + b = s.encode("utf-8") actual = text_repr(b, multiline=False) # Add self.assertIsInstance check? self.assertEqual(actual, self.b_prefix + expected) @@ -173,7 +170,7 @@ def test_ascii_examples_oneline_unicode(self): def test_ascii_examples_multiline_bytes(self): for s, _, expected in self.ascii_examples: - b = _b(s) + b = s.encode("utf-8") actual = text_repr(b, multiline=True) self.assertEqual(actual, self.b_prefix + expected) self.assertEqual(ast.literal_eval(actual), b) @@ -187,7 +184,7 @@ def test_ascii_examples_multiline_unicode(self): def test_ascii_examples_defaultline_bytes(self): for s, one, multi in self.ascii_examples: expected = ("\n" in s and multi) or one - self.assertEqual(text_repr(_b(s)), self.b_prefix + expected) + self.assertEqual(text_repr(s.encode("utf-8")), self.b_prefix + expected) def test_ascii_examples_defaultline_unicode(self): for s, one, multi in self.ascii_examples: @@ -219,43 +216,6 @@ def test_unicode_examples_multiline(self): self.assertEqual(ast.literal_eval(actual), u) -class TestReraise(testtools.TestCase): - """Tests for trivial reraise wrapper needed for Python 2/3 changes""" - - def test_exc_info(self): - """After reraise exc_info matches plus some extra traceback""" - try: - raise ValueError("Bad value") - except ValueError: - _exc_info = sys.exc_info() - try: - reraise(*_exc_info) - except ValueError: - _new_exc_info = sys.exc_info() - self.assertIs(_exc_info[0], _new_exc_info[0]) - self.assertIs(_exc_info[1], _new_exc_info[1]) - expected_tb = traceback.extract_tb(_exc_info[2]) - self.assertEqual( - expected_tb, traceback.extract_tb(_new_exc_info[2])[-len(expected_tb) :] - ) - - def test_custom_exception_no_args(self): - """Reraising does not require args attribute to contain params""" - - class CustomException(Exception): - """Exception that expects and sets attrs but not args""" - - def __init__(self, value): - Exception.__init__(self) - self.value = value - - try: - raise CustomException("Some value") - except CustomException: - _exc_info = sys.exc_info() - self.assertRaises(CustomException, reraise, *_exc_info) - - def test_suite(): from unittest import TestLoader diff --git a/tests/test_content.py b/tests/test_content.py index d377346e..83e87698 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -7,9 +7,6 @@ from typing import Any from testtools import TestCase -from testtools.compat import ( - _b, -) from testtools.content import ( JSON, Content, @@ -58,15 +55,15 @@ def test___eq__(self): content_type = ContentType("foo", "bar") def one_chunk(): - return [_b("bytes")] + return [b"bytes"] def two_chunk(): - return [_b("by"), _b("tes")] + return [b"by", b"tes"] content1 = Content(content_type, one_chunk) content2 = Content(content_type, one_chunk) content3 = Content(content_type, two_chunk) - content4 = Content(content_type, lambda: [_b("by"), _b("te")]) + content4 = Content(content_type, lambda: [b"by", b"te"]) content5 = Content(ContentType("f", "b"), two_chunk) self.assertEqual(content1, content2) self.assertEqual(content1, content3) @@ -76,7 +73,7 @@ def two_chunk(): def test___repr__(self): content = Content( ContentType("application", "octet-stream"), - lambda: [_b("\x00bin"), _b("ary\xff")], + lambda: [b"\x00bin", b"ary\xff"], ) self.assertIn("\\x00binary\\xff", repr(content)) @@ -105,12 +102,12 @@ def test_as_text(self): def test_from_file(self): fd, path = tempfile.mkstemp() self.addCleanup(os.remove, path) - os.write(fd, _b("some data")) + os.write(fd, b"some data") os.close(fd) content = content_from_file(path, UTF8_TEXT, chunk_size=2) self.assertThat( list(content.iter_bytes()), - Equals([_b("so"), _b("me"), _b(" d"), _b("at"), _b("a")]), + Equals([b"so", b"me", b" d", b"at", b"a"]), ) def test_from_nonexistent_file(self): @@ -125,7 +122,7 @@ def test_from_file_default_type(self): def test_from_file_eager_loading(self): fd, path = tempfile.mkstemp() - os.write(fd, _b("some data")) + os.write(fd, b"some data") os.close(fd) content = content_from_file(path, UTF8_TEXT, buffer_now=True) os.remove(path) @@ -133,21 +130,21 @@ def test_from_file_eager_loading(self): def test_from_file_with_simple_seek(self): f = tempfile.NamedTemporaryFile() - f.write(_b("some data")) + f.write(b"some data") f.flush() self.addCleanup(f.close) content = content_from_file(f.name, UTF8_TEXT, chunk_size=50, seek_offset=5) - self.assertThat(list(content.iter_bytes()), Equals([_b("data")])) + self.assertThat(list(content.iter_bytes()), Equals([b"data"])) def test_from_file_with_whence_seek(self): f = tempfile.NamedTemporaryFile() - f.write(_b("some data")) + f.write(b"some data") f.flush() self.addCleanup(f.close) content = content_from_file( f.name, UTF8_TEXT, chunk_size=50, seek_offset=-4, seek_whence=2 ) - self.assertThat(list(content.iter_bytes()), Equals([_b("data")])) + self.assertThat(list(content.iter_bytes()), Equals([b"data"])) def test_from_stream(self): data = io.StringIO("some data") @@ -165,24 +162,24 @@ def test_from_stream_eager_loading(self): fd, path = tempfile.mkstemp() self.addCleanup(os.remove, path) self.addCleanup(os.close, fd) - os.write(fd, _b("some data")) + os.write(fd, b"some data") stream = open(path, "rb") self.addCleanup(stream.close) content = content_from_stream(stream, UTF8_TEXT, buffer_now=True) - os.write(fd, _b("more data")) + os.write(fd, b"more data") self.assertThat("".join(content.iter_text()), Equals("some data")) def test_from_stream_with_simple_seek(self): - data = io.BytesIO(_b("some data")) + data = io.BytesIO(b"some data") content = content_from_stream(data, UTF8_TEXT, chunk_size=50, seek_offset=5) - self.assertThat(list(content.iter_bytes()), Equals([_b("data")])) + self.assertThat(list(content.iter_bytes()), Equals([b"data"])) def test_from_stream_with_whence_seek(self): - data = io.BytesIO(_b("some data")) + data = io.BytesIO(b"some data") content = content_from_stream( data, UTF8_TEXT, chunk_size=50, seek_offset=-4, seek_whence=2 ) - self.assertThat(list(content.iter_bytes()), Equals([_b("data")])) + self.assertThat(list(content.iter_bytes()), Equals([b"data"])) def test_from_text(self): data = "some data" @@ -190,7 +187,7 @@ def test_from_text(self): self.assertEqual(expected, text_content(data)) def test_text_content_raises_TypeError_when_passed_bytes(self): - data = _b("Some Bytes") + data = b"Some Bytes" self.assertRaises(TypeError, text_content, data) def test_text_content_raises_TypeError_when_passed_non_text(self): @@ -208,7 +205,7 @@ def test_text_content_raises_TypeError_when_passed_non_text(self): def test_json_content(self): data = {"foo": "bar"} - expected = Content(JSON, lambda: [_b('{"foo": "bar"}')]) + expected = Content(JSON, lambda: [b'{"foo": "bar"}']) self.assertEqual(expected, json_content(data)) @@ -302,7 +299,7 @@ def make_file(self, data): # always close the fd. There must be a better way. fd, path = tempfile.mkstemp() self.addCleanup(os.remove, path) - os.write(fd, _b(data)) + os.write(fd, data.encode("utf-8")) os.close(fd) return path diff --git a/tests/test_fixturesupport.py b/tests/test_fixturesupport.py index da0f9411..178dcb5b 100644 --- a/tests/test_fixturesupport.py +++ b/tests/test_fixturesupport.py @@ -7,7 +7,6 @@ content, content_type, ) -from testtools.compat import _b from testtools.matchers import ( Contains, Equals, @@ -70,7 +69,7 @@ class DetailsFixture(fixtures.Fixture): def setUp(self): fixtures.Fixture.setUp(self) self.addCleanup(delattr, self, "content") - self.content = [_b("content available until cleanUp")] + self.content = [b"content available until cleanUp"] self.addDetail( "content", content.Content(content_type.UTF8_TEXT, self.get_content) ) @@ -86,7 +85,7 @@ def test_foo(self): # Add a colliding detail (both should show up) self.addDetail( "content", - content.Content(content_type.UTF8_TEXT, lambda: [_b("foo")]), + content.Content(content_type.UTF8_TEXT, lambda: [b"foo"]), ) result = ExtendedTestResult() diff --git a/tests/test_run.py b/tests/test_run.py index ec5758bc..514d346d 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -11,7 +11,6 @@ import testtools from testtools import TestCase, run, skipUnless -from testtools.compat import _b from testtools.matchers import ( Contains, DocTestMatches, @@ -40,8 +39,7 @@ def __init__(self, broken=False): :param broken: If True, the sample file will not be importable. """ if not broken: - init_contents = _b( - """\ + init_contents = b"""\ from testtools import TestCase class TestFoo(TestCase): @@ -53,7 +51,6 @@ def test_suite(): from unittest import TestLoader return TestLoader().loadTestsFromName(__name__) """ - ) else: init_contents = b"class not in\n" self.package = fixtures.PythonPackage( @@ -80,8 +77,7 @@ def __init__(self): [ ( "__init__.py", - _b( - """ + b""" from fixtures import Fixture from testresources import ( FixtureResource, @@ -112,8 +108,7 @@ def test_quux(self): def test_suite(): from unittest import TestLoader return OptimisingTestSuite(TestLoader().loadTestsFromName(__name__)) -""" - ), +""", ) ], ) @@ -137,8 +132,7 @@ def __init__(self): [ ( "__init__.py", - _b( - """ + b""" from testtools import TestCase, clone_test_with_new_id class TestExample(TestCase): @@ -148,8 +142,7 @@ def test_foo(self): def load_tests(loader, tests, pattern): tests.addTest(clone_test_with_new_id(tests._tests[1]._tests[0], "fred")) return tests -""" - ), +""", ) ], ) @@ -288,12 +281,10 @@ def test_run_orders_tests(self): f = open(tempname, "wb") try: f.write( - _b( - """ + b""" testtools.runexample.TestFoo.test_bar testtools.runexample.missingtest """ - ) ) finally: f.close() @@ -328,12 +319,10 @@ def test_run_load_list(self): f = open(tempname, "wb") try: f.write( - _b( - """ + b""" testtools.runexample.TestFoo.test_bar testtools.runexample.missingtest """ - ) ) finally: f.close() @@ -368,12 +357,10 @@ def test_load_list_preserves_custom_suites(self): f = open(tempname, "wb") try: f.write( - _b( - """ + b""" testtools.resourceexample.TestFoo.test_bar testtools.resourceexample.TestFoo.test_foo """ - ) ) finally: f.close() diff --git a/tests/test_testcase.py b/tests/test_testcase.py index 9340f73d..a10ea6b3 100644 --- a/tests/test_testcase.py +++ b/tests/test_testcase.py @@ -22,9 +22,6 @@ skipUnless, testcase, ) -from testtools.compat import ( - _b, -) from testtools.content import ( TracebackContent, text_content, @@ -1240,7 +1237,7 @@ def assertDetailsProvided(self, case, expected_outcome, expected_keys): self.assertEqual(expected[-1], result._events[-1]) def get_content(self): - return content.Content(content.ContentType("text", "foo"), lambda: [_b("foo")]) + return content.Content(content.ContentType("text", "foo"), lambda: [b"foo"]) class TestExpectedFailure(TestWithDetails): @@ -1435,12 +1432,12 @@ def test_cloned_testcase_does_not_share_details(self): class Test(TestCase): def test_foo(self): - self.addDetail("foo", content.Content("text/plain", lambda: "foo")) + self.addDetail("foo", content.Content("text/plain", lambda: [b"foo"])) orig_test = Test("test_foo") cloned_test = clone_test_with_new_id(orig_test, self.getUniqueString()) orig_test.run(unittest.TestResult()) - self.assertEqual("foo", orig_test.getDetails()["foo"].iter_bytes()) + self.assertEqual(b"foo", b"".join(orig_test.getDetails()["foo"].iter_bytes())) self.assertEqual(None, cloned_test.getDetails().get("foo")) diff --git a/tests/test_testresult.py b/tests/test_testresult.py index b068b0ed..65b74919 100644 --- a/tests/test_testresult.py +++ b/tests/test_testresult.py @@ -44,10 +44,6 @@ TimestampingStreamResult, testresult, ) -from testtools.compat import ( - _b, - _get_exception_encoding, -) from testtools.content import ( Content, TracebackContent, @@ -605,8 +601,8 @@ def test_files(self): ) param_dicts = self._power_set(inputs) for kwargs in param_dicts: - result.status(file_name="foo", file_bytes=_b(""), **kwargs) - result.status(file_name="foo", file_bytes=_b("bar"), **kwargs) + result.status(file_name="foo", file_bytes=b"", **kwargs) + result.status(file_name="foo", file_bytes=b"bar", **kwargs) def test_test_status(self): # Tests non-file attachment parameter combinations. @@ -1129,14 +1125,14 @@ def test_files_reported(self): result.startTestRun() result.status( file_name="some log.txt", - file_bytes=_b("1234 log message"), + file_bytes=b"1234 log message", eof=True, mime_type="text/plain; charset=utf8", test_id="foo.bar", ) result.status( file_name="another file", - file_bytes=_b("""Traceback..."""), + file_bytes=b"""Traceback...""", test_id="foo.bar", ) result.stopTestRun() @@ -1147,7 +1143,7 @@ def test_files_reported(self): details = test["details"] self.assertEqual("1234 log message", details["some log.txt"].as_text()) self.assertEqual( - _b("Traceback..."), _b("").join(details["another file"].iter_bytes()) + b"Traceback...", b"".join(details["another file"].iter_bytes()) ) self.assertEqual( "application/octet-stream", repr(details["another file"].content_type) @@ -1290,7 +1286,7 @@ def test_empty_detail_status_correct(self): None, True, "foo", - _b(""), + b"", True, 'text/plain; charset="utf8"', None, @@ -1507,7 +1503,7 @@ def test_status_skip(self): result.startTestRun() result.status( file_name="reason", - file_bytes=_b("Missing dependency"), + file_bytes=b"Missing dependency", eof=True, mime_type="text/plain; charset=utf8", test_id="foo.bar", @@ -1520,23 +1516,21 @@ def test_status_skip(self): def _report_files(self, result): result.status( file_name="some log.txt", - file_bytes=_b("1234 log message"), + file_bytes=b"1234 log message", eof=True, mime_type="text/plain; charset=utf8", test_id="foo.bar", ) result.status( file_name="traceback", - file_bytes=_b( - """Traceback (most recent call last): + file_bytes=b"""Traceback (most recent call last): File "tests/test_testresult.py", line 607, in test_stopTestRun AllMatch(Equals([('startTestRun',), ('stopTestRun',)]))) testtools.matchers._impl.MismatchError: Differences: [ [('startTestRun',), ('stopTestRun',)] != [] [('startTestRun',), ('stopTestRun',)] != [] ] -""" - ), +""", eof=True, mime_type="text/plain; charset=utf8", test_id="foo.bar", @@ -2675,7 +2669,7 @@ def check_event(event_dict, route=None, time=None): self.assertEqual({"quux"}, event_dict["test_tags"]) self.assertEqual(False, event_dict["runnable"]) self.assertEqual("file", event_dict["file_name"]) - self.assertEqual(_b("content"), event_dict["file_bytes"]) + self.assertEqual(b"content", event_dict["file_bytes"]) self.assertEqual(True, event_dict["eof"]) self.assertEqual("quux", event_dict["mime_type"]) self.assertEqual("test", event_dict["test_id"]) @@ -2689,7 +2683,7 @@ def check_event(event_dict, route=None, time=None): test_tags={"quux"}, runnable=False, file_name="file", - file_bytes=_b("content"), + file_bytes=b"content", eof=True, mime_type="quux", route_code=None, @@ -2703,7 +2697,7 @@ def check_event(event_dict, route=None, time=None): test_tags={"quux"}, runnable=False, file_name="file", - file_bytes=_b("content"), + file_bytes=b"content", eof=True, mime_type="quux", route_code="bar", @@ -2750,13 +2744,13 @@ def get_details_and_string(self): """Get a details dict and expected string.""" def text1(): - return [_b("1\n2\n")] + return [b"1\n2\n"] def text2(): - return [_b("3\n4\n")] + return [b"3\n4\n"] def bin1(): - return [_b("5\n")] + return [b"5\n"] details = { "text 1": Content(ContentType("text", "plain"), text1), @@ -2999,7 +2993,7 @@ def test_outcome_Extended_py3_no_reason(self): def test_outcome_Extended_py3_reason(self): self.make_result() self.check_outcome_details_to_arg( - self.outcome, "foo", {"reason": Content(UTF8_TEXT, lambda: [_b("foo")])} + self.outcome, "foo", {"reason": Content(UTF8_TEXT, lambda: [b"foo"])} ) def test_outcome_Extended_pyextended(self): @@ -3167,13 +3161,13 @@ def _as_output(self, text): def test_non_ascii_failure_string(self): """Assertion contents can be non-ascii and should get decoded""" - text, raw = self._get_sample_text(_get_exception_encoding()) + text, raw = self._get_sample_text("utf-8") textoutput = self._test_external_case(f"self.fail({raw!a})") self.assertIn(self._as_output(text), textoutput) def test_non_ascii_failure_string_via_exec(self): """Assertion via exec can be non-ascii and still gets decoded""" - text, raw = self._get_sample_text(_get_exception_encoding()) + text, raw = self._get_sample_text("utf-8") textoutput = self._test_external_case(testline=f'exec ("self.fail({raw!a})")') self.assertIn(self._as_output(text), textoutput) diff --git a/testtools/compat.py b/testtools/compat.py index b6ece06b..5c85a71d 100644 --- a/testtools/compat.py +++ b/testtools/compat.py @@ -1,47 +1,19 @@ # Copyright (c) 2008-2015 testtools developers. See LICENSE for details. -"""Compatibility support for python 2 and 3.""" +"""Compatibility support - kept for backwards compatibility.""" __all__ = [ "BytesIO", "StringIO", - "_b", - "advance_iterator", - "reraise", + "text_repr", "unicode_output_stream", ] import codecs import io -import locale -import os import sys -import types import unicodedata from io import BytesIO, StringIO # for backwards-compat -from typing import Any, NoReturn - - -def reraise( - exc_class: type[BaseException], - exc_obj: BaseException, - exc_tb: types.TracebackType, - _marker: Any = object(), -) -> NoReturn: - """Re-raise an exception received from sys.exc_info() or similar.""" - raise exc_obj.with_traceback(exc_tb) - - -def _u(s): - return s - - -def _b(s): - """A byte literal.""" - return s.encode("latin-1") - - -advance_iterator = next def _slow_escape(text): @@ -149,15 +121,3 @@ def unicode_output_stream(stream): except AttributeError: pass return writer(stream, "replace") - - -def _get_exception_encoding(): - """Return the encoding we expect messages from the OS to be encoded in""" - if os.name == "nt": - # GZ 2010-05-24: Really want the codepage number instead, the error - # handling of standard codecs is more deterministic - return "mbcs" - # GZ 2010-05-23: We need this call to be after initialisation, but there's - # no benefit in asking more than once as it's a global - # setting that can change after the message is formatted. - return locale.getlocale(locale.LC_MESSAGES)[1] or "ascii" diff --git a/testtools/content.py b/testtools/content.py index ef606068..002fbfd6 100644 --- a/testtools/content.py +++ b/testtools/content.py @@ -17,11 +17,29 @@ import json import os import traceback +import types +from collections.abc import Callable, Iterable, Iterator +from typing import IO, Any, Protocol, TypeAlias, runtime_checkable -from testtools.compat import _b from testtools.content_type import JSON, UTF8_TEXT, ContentType -_join_b = _b("").join +# Type for JSON-serializable data +JSONType: TypeAlias = ( + dict[str, "JSONType"] | list["JSONType"] | str | int | float | bool | None +) + + +class _Detailed(Protocol): + """Protocol for objects that have an addDetail method.""" + + def addDetail(self, name: str, content_object: "Content") -> None: ... + + +@runtime_checkable +class _TestCase(Protocol): + """Protocol for test objects used in TracebackContent.""" + + failureException: Any # Can be type[BaseException], tuple, or None DEFAULT_CHUNK_SIZE = 4096 @@ -30,7 +48,12 @@ STDERR_LINE = "\nStderr:\n%s" -def _iter_chunks(stream, chunk_size, seek_offset=None, seek_whence=0): +def _iter_chunks( + stream: IO[bytes], + chunk_size: int, + seek_offset: int | None = None, + seek_whence: int = 0, +) -> Iterator[bytes]: """Read 'stream' in chunks of 'chunk_size'. :param stream: A file-like object to read from. @@ -58,19 +81,24 @@ class Content: :ivar content_type: The content type of this Content. """ - def __init__(self, content_type, get_bytes): + def __init__( + self, content_type: ContentType, get_bytes: Callable[[], Iterable[bytes]] + ) -> None: """Create a ContentType.""" if None in (content_type, get_bytes): raise ValueError(f"None not permitted in {content_type!r}, {get_bytes!r}") self.content_type = content_type self._get_bytes = get_bytes - def __eq__(self, other): - return self.content_type == other.content_type and _join_b( - self.iter_bytes() - ) == _join_b(other.iter_bytes()) + def __eq__(self, other: object) -> bool: + if not isinstance(other, Content): + return NotImplemented + return bool( + self.content_type == other.content_type + and b"".join(self.iter_bytes()) == b"".join(other.iter_bytes()) + ) - def as_text(self): + def as_text(self) -> str: """Return all of the content as text. This is only valid where ``iter_text`` is. It will load all of the @@ -79,11 +107,11 @@ def as_text(self): """ return "".join(self.iter_text()) - def iter_bytes(self): + def iter_bytes(self) -> Iterator[bytes]: """Iterate over bytestrings of the serialised content.""" - return self._get_bytes() + return iter(self._get_bytes()) - def iter_text(self): + def iter_text(self) -> Iterator[str]: """Iterate over the text of the serialised content. This is only valid for text MIME types, and will use ISO-8859-1 if @@ -96,20 +124,20 @@ def iter_text(self): raise ValueError(f"Not a text type {self.content_type!r}") return self._iter_text() - def _iter_text(self): + def _iter_text(self) -> Iterator[str]: """Worker for iter_text - does the decoding.""" encoding = self.content_type.parameters.get("charset", "ISO-8859-1") decoder = codecs.getincrementaldecoder(encoding)() for bytes in self.iter_bytes(): yield decoder.decode(bytes) - final = decoder.decode(_b(""), True) + final = decoder.decode(b"", True) if final: yield final - def __repr__(self): + def __repr__(self) -> str: return ( f"" + f"value={b''.join(self.iter_bytes())!r}>" ) @@ -129,7 +157,12 @@ class StackLinesContent(Content): # system-under-test is rarely unittest or testtools. HIDE_INTERNAL_STACK = True - def __init__(self, stack_lines, prefix_content="", postfix_content=""): + def __init__( + self, + stack_lines: traceback.StackSummary, + prefix_content: str = "", + postfix_content: str = "", + ) -> None: """Create a StackLinesContent for ``stack_lines``. :param stack_lines: A list of preprocessed stack lines, probably @@ -148,7 +181,7 @@ def __init__(self, stack_lines, prefix_content="", postfix_content=""): ) super().__init__(content_type, lambda: [value.encode("utf8")]) - def _stack_lines_to_unicode(self, stack_lines): + def _stack_lines_to_unicode(self, stack_lines: traceback.StackSummary) -> str: """Converts a list of pre-processed stack lines into a unicode string.""" msg_lines = traceback.format_list(stack_lines) return "".join(msg_lines) @@ -162,7 +195,12 @@ class TracebackContent(Content): provide room for other languages to format their tracebacks differently. """ - def __init__(self, err, test, capture_locals=False): + def __init__( + self, + err: tuple[type[BaseException], BaseException, types.TracebackType | None], + test: _TestCase | None, + capture_locals: bool = False, + ) -> None: """Create a TracebackContent for ``err``. :param err: An exc_info error tuple. @@ -183,6 +221,7 @@ def __init__(self, err, test, capture_locals=False): if ( False and StackLinesContent.HIDE_INTERNAL_STACK + and test is not None and test.failureException and isinstance(value, test.failureException) ): @@ -203,7 +242,7 @@ def __init__(self, err, test, capture_locals=False): super().__init__(content_type, lambda: [x.encode("utf8") for x in stack_lines]) -def StacktraceContent(prefix_content="", postfix_content=""): +def StacktraceContent(prefix_content: str = "", postfix_content: str = "") -> Content: """Content object for stack traces. This function will create and return a 'Content' object that contains a @@ -217,7 +256,9 @@ def StacktraceContent(prefix_content="", postfix_content=""): """ stack = traceback.walk_stack(None) - def filter_stack(stack): + def filter_stack( + stack: Iterator[tuple[types.FrameType, int]], + ) -> Iterator[tuple[types.FrameType, int]]: # Discard the filter_stack frame. next(stack) # Discard the StacktraceContent frame. @@ -233,7 +274,7 @@ def filter_stack(stack): return StackLinesContent(extract, prefix_content, postfix_content) -def json_content(json_data): +def json_content(json_data: JSONType) -> Content: """Create a JSON Content object from JSON-encodeable data.""" json_str = json.dumps(json_data) # The json module perversely returns native str not bytes @@ -241,7 +282,7 @@ def json_content(json_data): return Content(JSON, lambda: [data]) -def text_content(text): +def text_content(text: str) -> Content: """Create a Content object from some text. This is useful for adding details which are short strings. @@ -253,7 +294,9 @@ def text_content(text): return Content(UTF8_TEXT, lambda: [text.encode("utf8")]) -def maybe_wrap(wrapper, func): +def maybe_wrap( + wrapper: Callable[..., Any], func: Callable[..., Any] +) -> Callable[..., Any]: """Merge metadata for func into wrapper if functools is present.""" if functools is not None: wrapper = functools.update_wrapper(wrapper, func) @@ -261,13 +304,13 @@ def maybe_wrap(wrapper, func): def content_from_file( - path, - content_type=None, - chunk_size=DEFAULT_CHUNK_SIZE, - buffer_now=False, - seek_offset=None, - seek_whence=0, -): + path: str, + content_type: ContentType | None = None, + chunk_size: int = DEFAULT_CHUNK_SIZE, + buffer_now: bool = False, + seek_offset: int | None = None, + seek_whence: int = 0, +) -> Content: """Create a Content object from a file on disk. Note that unless ``buffer_now`` is explicitly passed in as True, the file @@ -286,7 +329,7 @@ def content_from_file( if content_type is None: content_type = UTF8_TEXT - def reader(): + def reader() -> Iterable[bytes]: with open(path, "rb") as stream: yield from _iter_chunks(stream, chunk_size, seek_offset, seek_whence) @@ -294,13 +337,13 @@ def reader(): def content_from_stream( - stream, - content_type=None, - chunk_size=DEFAULT_CHUNK_SIZE, - buffer_now=False, - seek_offset=None, - seek_whence=0, -): + stream: IO[bytes], + content_type: ContentType | None = None, + chunk_size: int = DEFAULT_CHUNK_SIZE, + buffer_now: bool = False, + seek_offset: int | None = None, + seek_whence: int = 0, +) -> Content: """Create a Content object from a file-like stream. Note that unless ``buffer_now`` is explicitly passed in as True, the stream @@ -320,13 +363,17 @@ def content_from_stream( if content_type is None: content_type = UTF8_TEXT - def reader(): + def reader() -> Iterator[bytes]: return _iter_chunks(stream, chunk_size, seek_offset, seek_whence) return content_from_reader(reader, content_type, buffer_now) -def content_from_reader(reader, content_type, buffer_now): +def content_from_reader( + reader: Callable[[], Iterable[bytes]], + content_type: ContentType | None, + buffer_now: bool, +) -> Content: """Create a Content object that will obtain the content from reader. :param reader: A callback to read the content. Should return an iterable of @@ -340,20 +387,22 @@ def content_from_reader(reader, content_type, buffer_now): if buffer_now: contents = list(reader()) - def reader(): + def buffered_reader() -> Iterable[bytes]: return contents + return Content(content_type, buffered_reader) + return Content(content_type, reader) def attach_file( - detailed, - path, - name=None, - content_type=None, - chunk_size=DEFAULT_CHUNK_SIZE, - buffer_now=True, -): + detailed: _Detailed, + path: str, + name: str | None = None, + content_type: ContentType | None = None, + chunk_size: int = DEFAULT_CHUNK_SIZE, + buffer_now: bool = True, +) -> None: """Attach a file to this test as a detail. This is a convenience method wrapping around ``addDetail``. diff --git a/testtools/content_type.py b/testtools/content_type.py index 1736ceaf..5ed96908 100644 --- a/testtools/content_type.py +++ b/testtools/content_type.py @@ -12,7 +12,9 @@ class ContentType: content type. """ - def __init__(self, primary_type, sub_type, parameters=None): + def __init__( + self, primary_type: str, sub_type: str, parameters: dict[str, str] | None = None + ) -> None: """Create a ContentType.""" if None in (primary_type, sub_type): raise ValueError(f"None not permitted in {primary_type!r}, {sub_type!r}") @@ -20,12 +22,12 @@ def __init__(self, primary_type, sub_type, parameters=None): self.subtype = sub_type self.parameters = parameters or {} - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if type(other) is not ContentType: return False return self.__dict__ == other.__dict__ - def __repr__(self): + def __repr__(self) -> str: if self.parameters: params = "; " params += "; ".join( diff --git a/testtools/matchers/_impl.py b/testtools/matchers/_impl.py index 05aee234..6253eaab 100644 --- a/testtools/matchers/_impl.py +++ b/testtools/matchers/_impl.py @@ -17,7 +17,78 @@ "MismatchError", ] -from testtools.compat import text_repr +import unicodedata + + +def _slow_escape(text): + """Escape unicode ``text`` leaving printable characters unmodified + + The behaviour emulates the Python 3 implementation of repr, see + unicode_repr in unicodeobject.c and isprintable definition. + + Because this iterates over the input a codepoint at a time, it's slow, and + does not handle astral characters correctly on Python builds with 16 bit + rather than 32 bit unicode type. + """ + output = [] + for c in text: + o = ord(c) + if o < 256: + if o < 32 or 126 < o < 161: + output.append(c.encode("unicode-escape")) + elif o == 92: + # Separate due to bug in unicode-escape codec in Python 2.4 + output.append("\\\\") + else: + output.append(c) + else: + # To get correct behaviour would need to pair up surrogates here + if unicodedata.category(c)[0] in "CZ": + output.append(c.encode("unicode-escape")) + else: + output.append(c) + return "".join(output) + + +def text_repr(text, multiline=None): + """Rich repr for ``text`` returning unicode, triple quoted if ``multiline``.""" + nl = (isinstance(text, bytes) and bytes((0xA,))) or "\n" + if multiline is None: + multiline = nl in text + if not multiline: + # Use normal repr for single line of unicode + return repr(text) + prefix = repr(text[:0])[:-2] + if multiline: + # To escape multiline strings, split and process each line in turn, + # making sure that quotes are not escaped. + offset = len(prefix) + 1 + lines = [] + for line in text.split(nl): + r = repr(line) + q = r[-1] + lines.append(r[offset:-1].replace("\\" + q, q)) + # Combine the escaped lines and append two of the closing quotes, + # then iterate over the result to escape triple quotes correctly. + _semi_done = "\n".join(lines) + "''" + p = 0 + while True: + p = _semi_done.find("'''", p) + if p == -1: + break + _semi_done = "\\".join([_semi_done[:p], _semi_done[p:]]) + p += 2 + return "".join([prefix, "'''\\\n", _semi_done, "'"]) + escaped_text = _slow_escape(text) + # Determine which quote character to use and if one gets prefixed with a + # backslash following the same logic Python uses for repr() on strings + quote = "'" + if "'" in text: + if '"' in text: + escaped_text = escaped_text.replace("'", "\\'") + else: + quote = '"' + return "".join([prefix, quote, escaped_text, quote]) class Matcher: diff --git a/testtools/testcase.py b/testtools/testcase.py index be335a71..a5b2729e 100644 --- a/testtools/testcase.py +++ b/testtools/testcase.py @@ -24,7 +24,6 @@ from unittest.case import SkipTest from testtools import content -from testtools.compat import reraise from testtools.matchers import ( Annotate, Contains, @@ -478,7 +477,7 @@ def assertRaises(self, expected_exception, callable=None, *args, **kwargs): class ReRaiseOtherTypes: def match(self, matchee): if not issubclass(matchee[0], expected_exception): - reraise(*matchee) + raise matchee[1].with_traceback(matchee[2]) class CaptureMatchee: def match(self, matchee): @@ -813,7 +812,9 @@ def useFixture(self, fixture: UseFixtureT) -> UseFixtureT: else: # Gather_details worked, so raise the exception setUp # encountered. - reraise(*exc_info) + if exc_info[1] is not None: + raise exc_info[1].with_traceback(exc_info[2]) + raise else: self.addCleanup(fixture.cleanUp) self.addCleanup(gather_details, fixture.getDetails(), self.getDetails()) @@ -907,24 +908,26 @@ def debug(self): def id(self): return self._test_id - def _result(self, result): + def _result( + self, result: unittest.TestResult | None + ) -> TestResult | ExtendedToOriginalDecorator: if result is None: return TestResult() else: return ExtendedToOriginalDecorator(result) - def run(self, result=None): - result = self._result(result) + def run(self, result: unittest.TestResult | None = None) -> None: + result_obj: TestResult | ExtendedToOriginalDecorator = self._result(result) if self._timestamps[0] is not None: - result.time(self._timestamps[0]) - result.tags(self._tags, set()) - result.startTest(self) + result_obj.time(self._timestamps[0]) + result_obj.tags(self._tags, set()) + result_obj.startTest(self) if self._timestamps[1] is not None: - result.time(self._timestamps[1]) - outcome = getattr(result, self._outcome) + result_obj.time(self._timestamps[1]) + outcome = getattr(result_obj, self._outcome) outcome(self, details=self._details) - result.stopTest(self) - result.tags(set(), self._tags) + result_obj.stopTest(self) + result_obj.tags(set(), self._tags) def shortDescription(self): if self._short_description is None: @@ -933,7 +936,12 @@ def shortDescription(self): return self._short_description -def ErrorHolder(test_id, error, short_description=None, details=None): +def ErrorHolder( + test_id: str, + error: tuple, + short_description: str | None = None, + details: dict | None = None, +) -> PlaceHolder: """Construct an `ErrorHolder`. :param test_id: The id of the test. diff --git a/testtools/testresult/real.py b/testtools/testresult/real.py index 3f1cb269..c80c0270 100644 --- a/testtools/testresult/real.py +++ b/testtools/testresult/real.py @@ -26,11 +26,12 @@ import email.message import math import sys +import threading import unittest from operator import methodcaller -from typing import ClassVar +from queue import Queue +from typing import ClassVar, TypeAlias -from testtools.compat import _b from testtools.content import ( Content, TracebackContent, @@ -39,6 +40,9 @@ from testtools.content_type import ContentType from testtools.tags import TagContext +# Type for event dicts that go into the queue +EventDict: TypeAlias = "dict[str, str | bytes | bool | None | StreamResult]" + # circular import # from testtools.testcase import PlaceHolder PlaceHolder = None @@ -375,17 +379,17 @@ def stopTestRun(self): def status( self, - test_id=None, - test_status=None, - test_tags=None, - runnable=True, - file_name=None, - file_bytes=None, - eof=False, - mime_type=None, - route_code=None, - timestamp=None, - ): + test_id: str | None = None, + test_status: str | None = None, + test_tags: set[str] | None = None, + runnable: bool = True, + file_name: str | None = None, + file_bytes: bytes | None = None, + eof: bool = False, + mime_type: str | None = None, + route_code: str | None = None, + timestamp: datetime.datetime | None = None, + ) -> None: """Inform the result about a test status. :param test_id: The test whose status is being reported. None to @@ -762,7 +766,7 @@ def got_file(self, file_name, file_bytes, mime_type=None): ["details", file_name], Content(content_type, lambda: content_bytes) ) - case.details[file_name].iter_bytes().append(file_bytes) + case.details[file_name]._get_bytes().append(file_bytes) return case def to_test_case(self): @@ -1112,7 +1116,7 @@ def __init__(self): super().__init__() self.shouldStop = False - def stop(self): + def stop(self) -> None: """Indicate that tests should stop running.""" self.shouldStop = True @@ -1389,7 +1393,9 @@ class ThreadsafeForwardingResult(TestResult): opportunity for bugs around global state in the target. """ - def __init__(self, target, semaphore): + def __init__( + self, target: unittest.TestResult, semaphore: threading.Semaphore + ) -> None: """Create a ThreadsafeForwardingResult forwarding to target. :param target: A ``TestResult``. @@ -1749,7 +1755,7 @@ class ExtendedToStreamDecorator(CopyStreamResult, StreamSummary, TestControl): any StreamResult. """ - def __init__(self, decorated): + def __init__(self, decorated: StreamResult) -> None: super().__init__([decorated]) # Deal with mismatched base class constructors. TestControl.__init__(self) @@ -1810,7 +1816,7 @@ def _convert(self, test, err, details, status, reason=None): ) file_bytes = next_bytes if file_bytes is None: - file_bytes = _b("") + file_bytes = b"" self.status( file_name=name, file_bytes=file_bytes, @@ -2078,7 +2084,7 @@ class StreamToQueue(StreamResult): is the result that invoked ``startTestRun``. """ - def __init__(self, queue, routing_code): + def __init__(self, queue: Queue[EventDict], routing_code: str | None) -> None: """Create a StreamToQueue forwarding to target. :param queue: A ``queue.Queue`` to receive events. @@ -2123,10 +2129,12 @@ def status( def stopTestRun(self): self.queue.put(dict(event="stopTestRun", result=self)) - def route_code(self, route_code): + def route_code(self, route_code: str | None) -> str | None: """Adjust route_code on the way through.""" if route_code is None: return self.routing_code + if self.routing_code is None: + return route_code return self.routing_code + "/" + route_code @@ -2304,7 +2312,7 @@ class TimestampingStreamResult(CopyStreamResult): This is convenient for ensuring events are timestamped. """ - def __init__(self, target): + def __init__(self, target: StreamResult) -> None: super().__init__([target]) def status(self, *args, **kwargs): diff --git a/testtools/testsuite.py b/testtools/testsuite.py index 4a35811e..8db7549d 100644 --- a/testtools/testsuite.py +++ b/testtools/testsuite.py @@ -14,28 +14,68 @@ import threading import unittest from collections import Counter +from collections.abc import Callable, Container, Generator, Iterable from pprint import pformat from queue import Queue -from typing import Any +from typing import Any, Protocol, TypeAlias, TypeVar, runtime_checkable import testtools +_T = TypeVar("_T") -def iterate_tests(test_suite_or_case): +# Type alias for objects that can be test suites or test cases +TestSuiteOrCase: TypeAlias = unittest.TestSuite | unittest.TestCase + + +class _Fixture(Protocol): + """Protocol for fixture objects.""" + + def setUp(self) -> None: ... + def cleanUp(self) -> None: ... + + +class _Stoppable(Protocol): + """Protocol for result objects that can be stopped.""" + + def stop(self) -> None: ... + + +class _StreamResultLike(_Stoppable, Protocol): + """Protocol for stream result objects with test run lifecycle.""" + + def startTestRun(self) -> None: ... + def stopTestRun(self) -> None: ... + + +@runtime_checkable +class _Sortable(Protocol): + """Protocol for test suites that can be sorted.""" + + def sort_tests(self) -> None: ... + + +def iterate_tests( + test_suite_or_case: TestSuiteOrCase, +) -> Generator[unittest.TestCase, None, None]: """Iterate through all of the test cases in 'test_suite_or_case'.""" - try: - suite = iter(test_suite_or_case) - except TypeError: - yield test_suite_or_case - else: - for test in suite: + if isinstance(test_suite_or_case, unittest.TestSuite): + # It's a suite, iterate through it + for test in test_suite_or_case: yield from iterate_tests(test) + else: + # It's a test case (could be unittest.TestCase or duck-typed) + yield test_suite_or_case class ConcurrentTestSuite(unittest.TestSuite): """A TestSuite whose run() calls out to a concurrency strategy.""" - def __init__(self, suite, make_tests, wrap_result=None): + def __init__( + self, + suite: unittest.TestSuite, + make_tests: Callable[[unittest.TestSuite], Iterable[unittest.TestCase]], + wrap_result: Callable[[Any, int], unittest.TestResult] | None = None, + ) -> None: """Create a ConcurrentTestSuite to execute suite. :param suite: A suite to run concurrently. @@ -54,7 +94,9 @@ def __init__(self, suite, make_tests, wrap_result=None): self.make_tests = make_tests self._custom_wrap_result = wrap_result - def _wrap_result(self, thread_safe_result, thread_number): + def _wrap_result( + self, thread_safe_result: unittest.TestResult, thread_number: int + ) -> unittest.TestResult: """Wrap a thread-safe result before sending it test results. You can either override this in a subclass or pass your own @@ -64,7 +106,9 @@ def _wrap_result(self, thread_safe_result, thread_number): return self._custom_wrap_result(thread_safe_result, thread_number) return thread_safe_result - def run(self, result, debug=False): + def run( + self, result: unittest.TestResult, debug: bool = False + ) -> unittest.TestResult: """Run the tests concurrently. This calls out to the provided make_tests helper, and then serialises @@ -84,7 +128,8 @@ def run(self, result, debug=False): semaphore = threading.Semaphore(1) for i, test in enumerate(tests): process_result = self._wrap_result( - testtools.ThreadsafeForwardingResult(result, semaphore), i + testtools.ThreadsafeForwardingResult(result, semaphore), + i, # type: ignore[no-untyped-call] ) reader_thread = threading.Thread( target=self._run_test, args=(test, process_result, queue) @@ -99,8 +144,14 @@ def run(self, result, debug=False): for thread, process_result in threads.values(): process_result.stop() raise + return result - def _run_test(self, test, process_result, queue): + def _run_test( + self, + test: unittest.TestCase, + process_result: unittest.TestResult, + queue: Queue[unittest.TestCase], + ) -> None: try: try: test.run(process_result) @@ -115,7 +166,9 @@ def _run_test(self, test, process_result, queue): class ConcurrentStreamTestSuite: """A TestSuite whose run() parallelises.""" - def __init__(self, make_tests): + def __init__( + self, make_tests: Callable[[], Iterable[tuple[TestSuiteOrCase, str | None]]] + ) -> None: """Create a ConcurrentTestSuite to execute tests returned by make_tests. :param make_tests: A helper function that should return some number @@ -128,7 +181,7 @@ def __init__(self, make_tests): super().__init__() self.make_tests = make_tests - def run(self, result, debug=False): + def run(self, result: testtools.StreamResult, debug: bool = False) -> None: """Run the tests concurrently. This calls out to the provided make_tests helper to determine the @@ -182,7 +235,12 @@ def run(self, result, debug=False): process_result.stop() raise - def _run_test(self, test, process_result, route_code): + def _run_test( + self, + test: TestSuiteOrCase, + process_result: unittest.TestResult, + route_code: str | None, + ) -> None: process_result.startTestRun() try: try: @@ -198,48 +256,55 @@ def _run_test(self, test, process_result, route_code): class FixtureSuite(unittest.TestSuite): - def __init__(self, fixture, tests): + def __init__(self, fixture: _Fixture, tests: Iterable[unittest.TestCase]) -> None: super().__init__(tests) self._fixture = fixture - def run(self, result, debug=False): + def run( + self, result: unittest.TestResult, debug: bool = False + ) -> unittest.TestResult: self._fixture.setUp() try: - super().run(result, debug) + return super().run(result, debug) finally: self._fixture.cleanUp() - def sort_tests(self): - self._tests = sorted_tests(self, True) + def sort_tests(self) -> None: + sorted_suite = sorted_tests(self, True) + self._tests[:] = sorted_suite._tests -def _flatten_tests(suite_or_case, unpack_outer=False): - try: - tests = iter(suite_or_case) - except TypeError: - # Not iterable, assume it's a test case. +def _flatten_tests( + suite_or_case: TestSuiteOrCase, unpack_outer: bool = False +) -> list[tuple[str | None, TestSuiteOrCase]]: + if isinstance(suite_or_case, unittest.TestCase): + # Not iterable, it's a test case. return [(suite_or_case.id(), suite_or_case)] + + # It's a suite, try to iterate if type(suite_or_case) in (unittest.TestSuite,) or unpack_outer: # Plain old test suite (or any others we may add). - result = [] - for test in tests: - # Recurse to flatten. + result: list[tuple[str | None, TestSuiteOrCase]] = [] + for test in suite_or_case: + # Recurse to flatten - test is TestSuiteOrCase from the suite result.extend(_flatten_tests(test)) return result else: # Find any old actual test and grab its id. - suite_id = None - tests = iterate_tests(suite_or_case) - for test in tests: + suite_id: str | None = None + tests_iter = iterate_tests(suite_or_case) + for test in tests_iter: suite_id = test.id() break # If it has a sort_tests method, call that. - if hasattr(suite_or_case, "sort_tests"): + if isinstance(suite_or_case, _Sortable): suite_or_case.sort_tests() return [(suite_id, suite_or_case)] -def filter_by_ids(suite_or_case, test_ids): +def filter_by_ids( + suite_or_case: _T, test_ids: Container[str] +) -> _T | unittest.TestSuite: """Remove tests from suite_or_case where their id is not in test_ids. :param suite_or_case: A test suite or test case. @@ -282,7 +347,7 @@ def filter_by_ids(suite_or_case, test_ids): """ # Compatible objects if hasattr(suite_or_case, "filter_by_ids"): - return suite_or_case.filter_by_ids(test_ids) + return suite_or_case.filter_by_ids(test_ids) # type: ignore[no-any-return] # TestCase objects. if hasattr(suite_or_case, "id"): if suite_or_case.id() in test_ids: @@ -291,15 +356,17 @@ def filter_by_ids(suite_or_case, test_ids): return unittest.TestSuite() # Standard TestSuites or derived classes [assumed to be mutable]. if isinstance(suite_or_case, unittest.TestSuite): - filtered = [] + filtered: list[TestSuiteOrCase] = [] for item in suite_or_case: filtered.append(filter_by_ids(item, test_ids)) - suite_or_case._tests[:] = filtered + suite_or_case._tests[:] = filtered # type: ignore[assignment] # Everything else: return suite_or_case -def sorted_tests(suite_or_case, unpack_outer=False): +def sorted_tests( + suite_or_case: TestSuiteOrCase, unpack_outer: bool = False +) -> unittest.TestSuite: """Sort suite_or_case while preserving non-vanilla TestSuites.""" # Duplicate test id can induce TypeError in Python 3.3. # Detect the duplicate test ids, raise exception when found.