diff --git a/newsfragments/47.feature.rst b/newsfragments/47.feature.rst new file mode 100644 index 0000000..af49370 --- /dev/null +++ b/newsfragments/47.feature.rst @@ -0,0 +1,5 @@ +Add ``unwrap_and_destroy`` method to remove references to +the wrapped exception or value to prevent issues where +values not being garbage collected when they are no longer +needed, or worse problems with exceptions leaving a +reference cycle. diff --git a/src/outcome/_impl.py b/src/outcome/_impl.py index 004b72d..cd6faaf 100644 --- a/src/outcome/_impl.py +++ b/src/outcome/_impl.py @@ -138,6 +138,23 @@ def unwrap(self) -> ValueT: x = fn(*args) x = outcome.capture(fn, *args).unwrap() + Note: this leaves a reference to the contained value or exception + alive which may result in values not being garbage collected or + exceptions leaving a reference cycle. If this is an issue it's + recommended to call the ``unwrap_and_destroy()`` method + + """ + + @abc.abstractmethod + def unwrap_and_destroy(self) -> ValueT: + """Return or raise the contained value or exception, remove the + reference to the contained value or exception. + + These two lines of code are equivalent:: + + x = fn(*args) + x = outcome.capture(fn, *args).unwrap_and_destroy() + """ @abc.abstractmethod @@ -174,12 +191,21 @@ class Value(Outcome[ValueT], Generic[ValueT]): """The contained value.""" def __repr__(self) -> str: - return f'Value({self.value!r})' + try: + return f'Value({self.value!r})' + except AttributeError: + return 'Value()' def unwrap(self) -> ValueT: self._set_unwrapped() return self.value + def unwrap_and_destroy(self) -> ValueT: + self._set_unwrapped() + v = self.value + object.__delattr__(self, "value") + return v + def send(self, gen: Generator[ResultT, ValueT, object]) -> ResultT: self._set_unwrapped() return gen.send(self.value) @@ -202,7 +228,10 @@ class Error(Outcome[NoReturn]): """The contained exception object.""" def __repr__(self) -> str: - return f'Error({self.error!r})' + try: + return f'Error({self.error!r})' + except AttributeError: + return 'Error()' def unwrap(self) -> NoReturn: self._set_unwrapped() @@ -226,6 +255,29 @@ def unwrap(self) -> NoReturn: # __traceback__ from indirectly referencing 'captured_error'. del captured_error, self + def unwrap_and_destroy(self) -> NoReturn: + self._set_unwrapped() + # Tracebacks show the 'raise' line below out of context, so let's give + # this variable a name that makes sense out of context. + captured_error = self.error + object.__delattr__(self, "error") + try: + raise captured_error + finally: + # We want to avoid creating a reference cycle here. Python does + # collect cycles just fine, so it wouldn't be the end of the world + # if we did create a cycle, but the cyclic garbage collector adds + # latency to Python programs, and the more cycles you create, the + # more often it runs, so it's nicer to avoid creating them in the + # first place. For more details see: + # + # https://github.com/python-trio/trio/issues/1770 + # + # In particuar, by deleting this local variables from the 'unwrap' + # methods frame, we avoid the 'captured_error' object's + # __traceback__ from indirectly referencing 'captured_error'. + del captured_error, self + def send(self, gen: Generator[ResultT, NoReturn, object]) -> ResultT: self._set_unwrapped() return gen.throw(self.error) diff --git a/tests/test_async.py b/tests/test_async.py index 5ff95fd..6e86408 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -1,5 +1,11 @@ import asyncio +import contextlib +import gc +import platform +import sys import traceback +import types +import weakref import pytest @@ -58,3 +64,56 @@ async def raise_ValueError(x): frames = traceback.extract_tb(exc_info.value.__traceback__) functions = [function for _, _, function, _ in frames] assert functions[-2:] == ['unwrap', 'raise_ValueError'] + + +@types.coroutine +def async_yield(v): + return (yield v) + + +async def test_unwrap_leaves_a_refcycle(): + class MyException(Exception): + pass + + async def network_operation(): + return (await async_yield("network operation")).unwrap() + + async def coro_fn(): + try: + await network_operation() + except MyException as e: + wr_e = weakref.ref(e) + del e + + if platform.python_implementation() == "PyPy": + gc.collect() + assert isinstance(wr_e(), MyException) + + with contextlib.closing(coro_fn()) as coro: + assert coro.send(None) == "network operation" + with pytest.raises(StopIteration): + coro.send(outcome.Error(MyException())) + + +async def test_unwrap_and_destroy_does_not_leave_a_refcycle(): + class MyException(Exception): + pass + + async def network_operation(): + return (await async_yield("network operation")).unwrap_and_destroy() + + async def coro_fn(): + try: + await network_operation() + except MyException as e: + wr_e = weakref.ref(e) + del e + + if platform.python_implementation() == "PyPy": + gc.collect() + assert wr_e() is None + + with contextlib.closing(coro_fn()) as coro: + assert coro.send(None) == "network operation" + with pytest.raises(StopIteration): + coro.send(outcome.Error(MyException())) diff --git a/tests/test_sync.py b/tests/test_sync.py index 855d776..809063c 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -16,7 +16,11 @@ def test_Outcome(): with pytest.raises(AlreadyUsedError): v.unwrap() - v = Value(1) + v = Value(2) + assert v.unwrap_and_destroy() == 2 + assert repr(v) == "Value()" + with pytest.raises(AlreadyUsedError): + v.unwrap_and_destroy() exc = RuntimeError("oops") e = Error(exc) @@ -33,12 +37,20 @@ def test_Outcome(): with pytest.raises(TypeError): Error(RuntimeError) + e2 = Error(exc) + with pytest.raises(RuntimeError): + e2.unwrap_and_destroy() + with pytest.raises(AlreadyUsedError): + e2.unwrap_and_destroy() + assert repr(e2) == "Error()" + def expect_1(): assert (yield) == 1 yield "ok" it = iter(expect_1()) next(it) + v = Value(1) assert v.send(it) == "ok" with pytest.raises(AlreadyUsedError): v.send(it)