From bd00fcbd52e0aa6cffac626edc504f9897054de9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Wed, 2 Jul 2025 19:53:38 +0300 Subject: [PATCH 1/5] gh-125862: Improve context decorator support for generators and async functions --- Lib/contextlib.py | 54 +++++++++++++++++++++++++-- Lib/test/test_contextlib.py | 62 ++++++++++++++++++++++++++++++- Lib/test/test_contextlib_async.py | 57 ++++++++++++++++++++++++++++ 3 files changed, 168 insertions(+), 5 deletions(-) diff --git a/Lib/contextlib.py b/Lib/contextlib.py index 5b646fabca0225..9fcbc237ca2c5a 100644 --- a/Lib/contextlib.py +++ b/Lib/contextlib.py @@ -1,4 +1,7 @@ """Utilities for with-statement contexts. See PEP 343.""" +from inspect import isasyncgenfunction, iscoroutinefunction, \ + isgeneratorfunction + import abc import os import sys @@ -79,11 +82,32 @@ def _recreate_cm(self): return self def __call__(self, func): - @wraps(func) def inner(*args, **kwds): with self._recreate_cm(): return func(*args, **kwds) - return inner + + def gen_inner(*args, **kwds): + with self._recreate_cm(): + yield from func(*args, **kwds) + + async def async_inner(*args, **kwds): + with self._recreate_cm(): + return await func(*args, **kwds) + + async def asyncgen_inner(*args, **kwds): + with self._recreate_cm(): + async for value in func(*args, **kwds): + yield value + + wrapper = wraps(func) + if isasyncgenfunction(func): + return wrapper(asyncgen_inner) + elif iscoroutinefunction(func): + return wrapper(async_inner) + elif isgeneratorfunction(func): + return wrapper(gen_inner) + else: + return wrapper(inner) class AsyncContextDecorator(object): @@ -95,11 +119,33 @@ def _recreate_cm(self): return self def __call__(self, func): - @wraps(func) async def inner(*args, **kwds): + async with self._recreate_cm(): + return func(*args, **kwds) + + async def gen_inner(*args, **kwds): + async with self._recreate_cm(): + for value in func(*args, **kwds): + yield value + + async def async_inner(*args, **kwds): async with self._recreate_cm(): return await func(*args, **kwds) - return inner + + async def asyncgen_inner(*args, **kwds): + async with self._recreate_cm(): + async for value in func(*args, **kwds): + yield value + + wrapper = wraps(func) + if isasyncgenfunction(func): + return wrapper(asyncgen_inner) + elif iscoroutinefunction(func): + return wrapper(async_inner) + elif isgeneratorfunction(func): + return wrapper(gen_inner) + else: + return wrapper(inner) class _GeneratorContextManagerBase: diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py index 6a3329fa5aaace..baeb0d254044b5 100644 --- a/Lib/test/test_contextlib.py +++ b/Lib/test/test_contextlib.py @@ -1,5 +1,5 @@ """Unit tests for contextlib.py, and other context managers.""" - +import asyncio import io import os import sys @@ -680,6 +680,66 @@ def test(x): self.assertEqual(state, [1, 'something else', 999]) + def test_contextmanager_decorate_generator_function(self): + @contextmanager + def woohoo(y): + state.append(y) + yield + state.append(999) + + state = [] + @woohoo(1) + def test(x): + self.assertEqual(state, [1]) + state.append(x) + yield + state.append("second item") + + for _ in test("something"): + self.assertEqual(state, [1, "something"]) + self.assertEqual(state, [1, "something", "second item", 999]) + + + def test_contextmanager_decorate_coroutine_function(self): + @contextmanager + def woohoo(y): + state.append(y) + yield + state.append(999) + + state = [] + @woohoo(1) + async def test(x): + self.assertEqual(state, [1]) + state.append(x) + + asyncio.run(test('something')) + self.assertEqual(state, [1, 'something', 999]) + + + def test_contextmanager_decorate_asyncgen_function(self): + @contextmanager + def woohoo(y): + state.append(y) + yield + state.append(999) + + state = [] + @woohoo(1) + async def test(x): + self.assertEqual(state, [1]) + state.append(x) + yield + state.append("second item") + + async def run_test(): + async for _ in test("something"): + self.assertEqual(state, [1, "something"]) + + asyncio.run(run_test()) + self.assertEqual(state, [1, 'something', "second item", 999]) + + class TestBaseExitStack: exit_stack = None diff --git a/Lib/test/test_contextlib_async.py b/Lib/test/test_contextlib_async.py index dcd0072037950e..0d1ce9a3bd6be7 100644 --- a/Lib/test/test_contextlib_async.py +++ b/Lib/test/test_contextlib_async.py @@ -402,6 +402,63 @@ async def test(): await test() self.assertFalse(entered) + @_async_test + async def test_decorator_decorate_sync_function(self): + @asynccontextmanager + async def context(): + state.append(1) + yield + state.append(999) + + state = [] + @context() + def test(x): + self.assertEqual(state, [1]) + state.append(x) + + await test("something") + self.assertEqual(state, [1, "something", 999]) + + @_async_test + async def test_decorator_decorate_generator_function(self): + @asynccontextmanager + async def context(): + state.append(1) + yield + state.append(999) + + state = [] + @context() + def test(x): + self.assertEqual(state, [1]) + state.append(x) + yield + state.append("second item") + + async for _ in test("something"): + self.assertEqual(state, [1, "something"]) + self.assertEqual(state, [1, "something", "second item", 999]) + + @_async_test + async def test_decorator_decorate_asyncgen_function(self): + @asynccontextmanager + async def context(): + state.append(1) + yield + state.append(999) + + state = [] + @context() + async def test(x): + self.assertEqual(state, [1]) + state.append(x) + yield + state.append("second item") + + async for _ in test("something"): + self.assertEqual(state, [1, "something"]) + self.assertEqual(state, [1, "something", "second item", 999]) + @_async_test async def test_decorator_with_exception(self): entered = False From 9b3ba1304aa11e32374fa3fa4de082277b2f7cf1 Mon Sep 17 00:00:00 2001 From: "blurb-it[bot]" <43283697+blurb-it[bot]@users.noreply.github.com> Date: Wed, 2 Jul 2025 17:01:18 +0000 Subject: [PATCH 2/5] =?UTF-8?q?=F0=9F=93=9C=F0=9F=A4=96=20Added=20by=20blu?= =?UTF-8?q?rb=5Fit.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../next/Library/2025-07-02-17-01-17.gh-issue-125862.WgFYj3.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 Misc/NEWS.d/next/Library/2025-07-02-17-01-17.gh-issue-125862.WgFYj3.rst diff --git a/Misc/NEWS.d/next/Library/2025-07-02-17-01-17.gh-issue-125862.WgFYj3.rst b/Misc/NEWS.d/next/Library/2025-07-02-17-01-17.gh-issue-125862.WgFYj3.rst new file mode 100644 index 00000000000000..2e7b654bf07840 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2025-07-02-17-01-17.gh-issue-125862.WgFYj3.rst @@ -0,0 +1 @@ +Improved ``@contextmanager`` and ``@asynccontextmanager`` to work correctly with generators, coroutine functions and async generators when the wrapped callables are used as decorators From 8535a21d0b44064e205462ff758e2d95ba4e18bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Wed, 2 Jul 2025 20:55:14 +0300 Subject: [PATCH 3/5] Manually iterate coroutines to avoid asyncio use --- Lib/test/test_contextlib.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py index baeb0d254044b5..d4ee315ff8a97c 100644 --- a/Lib/test/test_contextlib.py +++ b/Lib/test/test_contextlib.py @@ -1,5 +1,4 @@ """Unit tests for contextlib.py, and other context managers.""" -import asyncio import io import os import sys @@ -713,7 +712,10 @@ async def test(x): self.assertEqual(state, [1]) state.append(x) - asyncio.run(test('something')) + coro = test('something') + with self.assertRaises(StopIteration): + coro.send(None) + self.assertEqual(state, [1, 'something', 999]) @@ -736,7 +738,12 @@ async def run_test(): async for _ in test("something"): self.assertEqual(state, [1, "something"]) - asyncio.run(run_test()) + agen = test('something') + with self.assertRaises(StopIteration): + agen.asend(None).send(None) + with self.assertRaises(StopAsyncIteration): + agen.asend(None).send(None) + self.assertEqual(state, [1, 'something', "second item", 999]) From 1fd52d5d2aea07932c2bed2ad95630aa406195e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Tue, 9 Dec 2025 16:18:43 +0200 Subject: [PATCH 4/5] Make sure we at least try to close the generators --- Lib/contextlib.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/Lib/contextlib.py b/Lib/contextlib.py index 9fcbc237ca2c5a..b9e64bb3bc48a7 100644 --- a/Lib/contextlib.py +++ b/Lib/contextlib.py @@ -87,8 +87,8 @@ def inner(*args, **kwds): return func(*args, **kwds) def gen_inner(*args, **kwds): - with self._recreate_cm(): - yield from func(*args, **kwds) + with self._recreate_cm(), func(*args, **kwds) as gen: + yield from gen async def async_inner(*args, **kwds): with self._recreate_cm(): @@ -96,8 +96,9 @@ async def async_inner(*args, **kwds): async def asyncgen_inner(*args, **kwds): with self._recreate_cm(): - async for value in func(*args, **kwds): - yield value + async with func(*args, **kwds) as gen: + async for value in gen: + yield value wrapper = wraps(func) if isasyncgenfunction(func): @@ -125,16 +126,17 @@ async def inner(*args, **kwds): async def gen_inner(*args, **kwds): async with self._recreate_cm(): - for value in func(*args, **kwds): - yield value + with func(*args, **kwds) as gen: + for value in func(*args, **kwds): + yield value async def async_inner(*args, **kwds): async with self._recreate_cm(): return await func(*args, **kwds) async def asyncgen_inner(*args, **kwds): - async with self._recreate_cm(): - async for value in func(*args, **kwds): + async with self._recreate_cm(), func(*args, **kwds) as gen: + async for value in gen: yield value wrapper = wraps(func) From 65c8f50c457004399ce5ff9a0c2a6125130be8c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Tue, 9 Dec 2025 16:35:23 +0200 Subject: [PATCH 5/5] Use (a)closing --- Lib/contextlib.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/Lib/contextlib.py b/Lib/contextlib.py index b9e64bb3bc48a7..61d8fb52515035 100644 --- a/Lib/contextlib.py +++ b/Lib/contextlib.py @@ -87,7 +87,7 @@ def inner(*args, **kwds): return func(*args, **kwds) def gen_inner(*args, **kwds): - with self._recreate_cm(), func(*args, **kwds) as gen: + with self._recreate_cm(), closing(func(*args, **kwds)) as gen: yield from gen async def async_inner(*args, **kwds): @@ -96,7 +96,7 @@ async def async_inner(*args, **kwds): async def asyncgen_inner(*args, **kwds): with self._recreate_cm(): - async with func(*args, **kwds) as gen: + async with aclosing(func(*args, **kwds)) as gen: async for value in gen: yield value @@ -126,8 +126,8 @@ async def inner(*args, **kwds): async def gen_inner(*args, **kwds): async with self._recreate_cm(): - with func(*args, **kwds) as gen: - for value in func(*args, **kwds): + with closing(func(*args, **kwds)) as gen: + for value in gen: yield value async def async_inner(*args, **kwds): @@ -135,7 +135,10 @@ async def async_inner(*args, **kwds): return await func(*args, **kwds) async def asyncgen_inner(*args, **kwds): - async with self._recreate_cm(), func(*args, **kwds) as gen: + async with ( + self._recreate_cm(), + aclosing(func(*args, **kwds)) as gen + ): async for value in gen: yield value