Skip to content
Open
59 changes: 55 additions & 4 deletions Lib/contextlib.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
"""Utilities for with-statement contexts. See PEP 343."""
from inspect import isasyncgenfunction, iscoroutinefunction, \
isgeneratorfunction

import abc
import os
import sys
Expand Down Expand Up @@ -79,11 +82,33 @@ def _recreate_cm(self):
return self

def __call__(self, func):
@wraps(func)
def inner(*args, **kwds):
with self._recreate_cm():
return func(*args, **kwds)
return inner

def gen_inner(*args, **kwds):
with self._recreate_cm(), closing(func(*args, **kwds)) as gen:
yield from gen

async def async_inner(*args, **kwds):
with self._recreate_cm():
return await func(*args, **kwds)

async def asyncgen_inner(*args, **kwds):
with self._recreate_cm():
async with aclosing(func(*args, **kwds)) as gen:
async for value in gen:
yield value

wrapper = wraps(func)
if isasyncgenfunction(func):
return wrapper(asyncgen_inner)
elif iscoroutinefunction(func):
return wrapper(async_inner)
elif isgeneratorfunction(func):
return wrapper(gen_inner)
else:
return wrapper(inner)


class AsyncContextDecorator(object):
Expand All @@ -95,11 +120,37 @@ def _recreate_cm(self):
return self

def __call__(self, func):
@wraps(func)
async def inner(*args, **kwds):
async with self._recreate_cm():
return func(*args, **kwds)

async def gen_inner(*args, **kwds):
async with self._recreate_cm():
with closing(func(*args, **kwds)) as gen:
for value in gen:
yield value

async def async_inner(*args, **kwds):
async with self._recreate_cm():
return await func(*args, **kwds)
return inner

async def asyncgen_inner(*args, **kwds):
async with (
self._recreate_cm(),
aclosing(func(*args, **kwds)) as gen
):
async for value in gen:
yield value

wrapper = wraps(func)
if isasyncgenfunction(func):
return wrapper(asyncgen_inner)
elif iscoroutinefunction(func):
return wrapper(async_inner)
elif isgeneratorfunction(func):
return wrapper(gen_inner)
else:
return wrapper(inner)


class _GeneratorContextManagerBase:
Expand Down
69 changes: 68 additions & 1 deletion Lib/test/test_contextlib.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Unit tests for contextlib.py, and other context managers."""

import io
import os
import sys
Expand Down Expand Up @@ -680,6 +679,74 @@ def test(x):
self.assertEqual(state, [1, 'something else', 999])


def test_contextmanager_decorate_generator_function(self):
@contextmanager
def woohoo(y):
state.append(y)
yield
state.append(999)

state = []
@woohoo(1)
def test(x):
self.assertEqual(state, [1])
state.append(x)
yield
state.append("second item")

for _ in test("something"):
self.assertEqual(state, [1, "something"])
self.assertEqual(state, [1, "something", "second item", 999])


def test_contextmanager_decorate_coroutine_function(self):
@contextmanager
def woohoo(y):
state.append(y)
yield
state.append(999)

state = []
@woohoo(1)
async def test(x):
self.assertEqual(state, [1])
state.append(x)

coro = test('something')
with self.assertRaises(StopIteration):
coro.send(None)

self.assertEqual(state, [1, 'something', 999])


def test_contextmanager_decorate_asyncgen_function(self):
@contextmanager
def woohoo(y):
state.append(y)
yield
state.append(999)

state = []
@woohoo(1)
async def test(x):
self.assertEqual(state, [1])
state.append(x)
yield
state.append("second item")

async def run_test():
async for _ in test("something"):
self.assertEqual(state, [1, "something"])

agen = test('something')
with self.assertRaises(StopIteration):
agen.asend(None).send(None)
with self.assertRaises(StopAsyncIteration):
agen.asend(None).send(None)

self.assertEqual(state, [1, 'something', "second item", 999])


class TestBaseExitStack:
exit_stack = None

Expand Down
57 changes: 57 additions & 0 deletions Lib/test/test_contextlib_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,63 @@ async def test():
await test()
self.assertFalse(entered)

@_async_test
async def test_decorator_decorate_sync_function(self):
@asynccontextmanager
async def context():
state.append(1)
yield
state.append(999)

state = []
@context()
def test(x):
self.assertEqual(state, [1])
state.append(x)

await test("something")
self.assertEqual(state, [1, "something", 999])

@_async_test
async def test_decorator_decorate_generator_function(self):
@asynccontextmanager
async def context():
state.append(1)
yield
state.append(999)

state = []
@context()
def test(x):
self.assertEqual(state, [1])
state.append(x)
yield
state.append("second item")

async for _ in test("something"):
self.assertEqual(state, [1, "something"])
self.assertEqual(state, [1, "something", "second item", 999])

@_async_test
async def test_decorator_decorate_asyncgen_function(self):
@asynccontextmanager
async def context():
state.append(1)
yield
state.append(999)

state = []
@context()
async def test(x):
self.assertEqual(state, [1])
state.append(x)
yield
state.append("second item")

async for _ in test("something"):
self.assertEqual(state, [1, "something"])
self.assertEqual(state, [1, "something", "second item", 999])

@_async_test
async def test_decorator_with_exception(self):
entered = False
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improved ``@contextmanager`` and ``@asynccontextmanager`` to work correctly with generators, coroutine functions and async generators when the wrapped callables are used as decorators
Loading