diff --git a/src/prompt_toolkit/shortcuts/prompt.py b/src/prompt_toolkit/shortcuts/prompt.py index 68cfeb9aa..b3e71a9fe 100644 --- a/src/prompt_toolkit/shortcuts/prompt.py +++ b/src/prompt_toolkit/shortcuts/prompt.py @@ -342,7 +342,6 @@ class PromptSession(Generic[_T]): "is_password", "editing_mode", "key_bindings", - "is_password", "bottom_toolbar", "style", "style_transformation", @@ -951,93 +950,125 @@ class itself. For these, passing in ``None`` will keep the current pressed (for abort) and ``EOFError`` when control-d has been pressed (for exit). """ - # NOTE: We used to create a backup of the PromptSession attributes and - # restore them after exiting the prompt. This code has been - # removed, because it was confusing and didn't really serve a use - # case. (People were changing `Application.editing_mode` - # dynamically and surprised that it was reset after every call.) - - # NOTE 2: YES, this is a lot of repeation below... - # However, it is a very convenient for a user to accept all - # these parameters in this `prompt` method as well. We could - # use `locals()` and `setattr` to avoid the repetition, but - # then we loose the advantage of mypy and pyflakes to be able - # to verify the code. + # Save the original values of any attributes that are being + # overridden, so we can restore them after the prompt finishes. + # This ensures that per-call overrides (like is_password=True) don't + # permanently change the session state (#967). + overrides: dict[str, object] = {} + if message is not None: + overrides["message"] = self.message self.message = message if editing_mode is not None: + overrides["editing_mode"] = self.editing_mode self.editing_mode = editing_mode if refresh_interval is not None: + overrides["refresh_interval"] = self.refresh_interval self.refresh_interval = refresh_interval if vi_mode: + overrides.setdefault("editing_mode", self.editing_mode) self.editing_mode = EditingMode.VI if lexer is not None: + overrides["lexer"] = self.lexer self.lexer = lexer if completer is not None: + overrides["completer"] = self.completer self.completer = completer if complete_in_thread is not None: + overrides["complete_in_thread"] = self.complete_in_thread self.complete_in_thread = complete_in_thread if is_password is not None: + overrides["is_password"] = self.is_password self.is_password = is_password if key_bindings is not None: + overrides["key_bindings"] = self.key_bindings self.key_bindings = key_bindings if bottom_toolbar is not None: + overrides["bottom_toolbar"] = self.bottom_toolbar self.bottom_toolbar = bottom_toolbar if style is not None: + overrides["style"] = self.style self.style = style if color_depth is not None: + overrides["color_depth"] = self.color_depth self.color_depth = color_depth if cursor is not None: + overrides["cursor"] = self.cursor self.cursor = cursor if include_default_pygments_style is not None: + overrides["include_default_pygments_style"] = self.include_default_pygments_style self.include_default_pygments_style = include_default_pygments_style if style_transformation is not None: + overrides["style_transformation"] = self.style_transformation self.style_transformation = style_transformation if swap_light_and_dark_colors is not None: + overrides["swap_light_and_dark_colors"] = self.swap_light_and_dark_colors self.swap_light_and_dark_colors = swap_light_and_dark_colors if rprompt is not None: + overrides["rprompt"] = self.rprompt self.rprompt = rprompt if multiline is not None: + overrides["multiline"] = self.multiline self.multiline = multiline if prompt_continuation is not None: + overrides["prompt_continuation"] = self.prompt_continuation self.prompt_continuation = prompt_continuation if wrap_lines is not None: + overrides["wrap_lines"] = self.wrap_lines self.wrap_lines = wrap_lines if enable_history_search is not None: + overrides["enable_history_search"] = self.enable_history_search self.enable_history_search = enable_history_search if search_ignore_case is not None: + overrides["search_ignore_case"] = self.search_ignore_case self.search_ignore_case = search_ignore_case if complete_while_typing is not None: + overrides["complete_while_typing"] = self.complete_while_typing self.complete_while_typing = complete_while_typing if validate_while_typing is not None: + overrides["validate_while_typing"] = self.validate_while_typing self.validate_while_typing = validate_while_typing if complete_style is not None: + overrides["complete_style"] = self.complete_style self.complete_style = complete_style if auto_suggest is not None: + overrides["auto_suggest"] = self.auto_suggest self.auto_suggest = auto_suggest if validator is not None: + overrides["validator"] = self.validator self.validator = validator if clipboard is not None: + overrides["clipboard"] = self.clipboard self.clipboard = clipboard if mouse_support is not None: + overrides["mouse_support"] = self.mouse_support self.mouse_support = mouse_support if input_processors is not None: + overrides["input_processors"] = self.input_processors self.input_processors = input_processors if placeholder is not None: + overrides["placeholder"] = self.placeholder self.placeholder = placeholder if reserve_space_for_menu is not None: + overrides["reserve_space_for_menu"] = self.reserve_space_for_menu self.reserve_space_for_menu = reserve_space_for_menu if enable_system_prompt is not None: + overrides["enable_system_prompt"] = self.enable_system_prompt self.enable_system_prompt = enable_system_prompt if enable_suspend is not None: + overrides["enable_suspend"] = self.enable_suspend self.enable_suspend = enable_suspend if enable_open_in_editor is not None: + overrides["enable_open_in_editor"] = self.enable_open_in_editor self.enable_open_in_editor = enable_open_in_editor if tempfile_suffix is not None: + overrides["tempfile_suffix"] = self.tempfile_suffix self.tempfile_suffix = tempfile_suffix if tempfile is not None: + overrides["tempfile"] = self.tempfile self.tempfile = tempfile if show_frame is not None: + overrides["show_frame"] = self.show_frame self.show_frame = show_frame self._add_pre_run_callables(pre_run, accept_default) @@ -1046,18 +1077,24 @@ class itself. For these, passing in ``None`` will keep the current ) self.app.refresh_interval = self.refresh_interval # This is not reactive. - # If we are using the default output, and have a dumb terminal. Use the - # dumb prompt. - if self._output is None and is_dumb_terminal(): - with self._dumb_prompt(self.message) as dump_app: - return dump_app.run(in_thread=in_thread, handle_sigint=handle_sigint) - - return self.app.run( - set_exception_handler=set_exception_handler, - in_thread=in_thread, - handle_sigint=handle_sigint, - inputhook=inputhook, - ) + try: + # If we are using the default output, and have a dumb terminal. Use the + # dumb prompt. + if self._output is None and is_dumb_terminal(): + with self._dumb_prompt(self.message) as dump_app: + return dump_app.run(in_thread=in_thread, handle_sigint=handle_sigint) + + return self.app.run( + set_exception_handler=set_exception_handler, + in_thread=in_thread, + handle_sigint=handle_sigint, + inputhook=inputhook, + ) + finally: + # Restore overridden attributes so per-call kwargs don't + # permanently change the session state. + for attr, value in overrides.items(): + setattr(self, attr, value) @contextmanager def _dumb_prompt(self, message: AnyFormattedText = "") -> Iterator[Application[_T]]: @@ -1160,81 +1197,122 @@ async def prompt_async( set_exception_handler: bool = True, handle_sigint: bool = True, ) -> _T: + # Save and restore overridden attributes (same as prompt(), see #967). + overrides: dict[str, object] = {} + if message is not None: + overrides["message"] = self.message self.message = message if editing_mode is not None: + overrides["editing_mode"] = self.editing_mode self.editing_mode = editing_mode if refresh_interval is not None: + overrides["refresh_interval"] = self.refresh_interval self.refresh_interval = refresh_interval if vi_mode: + overrides.setdefault("editing_mode", self.editing_mode) self.editing_mode = EditingMode.VI if lexer is not None: + overrides["lexer"] = self.lexer self.lexer = lexer if completer is not None: + overrides["completer"] = self.completer self.completer = completer if complete_in_thread is not None: + overrides["complete_in_thread"] = self.complete_in_thread self.complete_in_thread = complete_in_thread if is_password is not None: + overrides["is_password"] = self.is_password self.is_password = is_password if key_bindings is not None: + overrides["key_bindings"] = self.key_bindings self.key_bindings = key_bindings if bottom_toolbar is not None: + overrides["bottom_toolbar"] = self.bottom_toolbar self.bottom_toolbar = bottom_toolbar if style is not None: + overrides["style"] = self.style self.style = style if color_depth is not None: + overrides["color_depth"] = self.color_depth self.color_depth = color_depth if cursor is not None: + overrides["cursor"] = self.cursor self.cursor = cursor if include_default_pygments_style is not None: + overrides["include_default_pygments_style"] = self.include_default_pygments_style self.include_default_pygments_style = include_default_pygments_style if style_transformation is not None: + overrides["style_transformation"] = self.style_transformation self.style_transformation = style_transformation if swap_light_and_dark_colors is not None: + overrides["swap_light_and_dark_colors"] = self.swap_light_and_dark_colors self.swap_light_and_dark_colors = swap_light_and_dark_colors if rprompt is not None: + overrides["rprompt"] = self.rprompt self.rprompt = rprompt if multiline is not None: + overrides["multiline"] = self.multiline self.multiline = multiline if prompt_continuation is not None: + overrides["prompt_continuation"] = self.prompt_continuation self.prompt_continuation = prompt_continuation if wrap_lines is not None: + overrides["wrap_lines"] = self.wrap_lines self.wrap_lines = wrap_lines if enable_history_search is not None: + overrides["enable_history_search"] = self.enable_history_search self.enable_history_search = enable_history_search if search_ignore_case is not None: + overrides["search_ignore_case"] = self.search_ignore_case self.search_ignore_case = search_ignore_case if complete_while_typing is not None: + overrides["complete_while_typing"] = self.complete_while_typing self.complete_while_typing = complete_while_typing if validate_while_typing is not None: + overrides["validate_while_typing"] = self.validate_while_typing self.validate_while_typing = validate_while_typing if complete_style is not None: + overrides["complete_style"] = self.complete_style self.complete_style = complete_style if auto_suggest is not None: + overrides["auto_suggest"] = self.auto_suggest self.auto_suggest = auto_suggest if validator is not None: + overrides["validator"] = self.validator self.validator = validator if clipboard is not None: + overrides["clipboard"] = self.clipboard self.clipboard = clipboard if mouse_support is not None: + overrides["mouse_support"] = self.mouse_support self.mouse_support = mouse_support if input_processors is not None: + overrides["input_processors"] = self.input_processors self.input_processors = input_processors if placeholder is not None: + overrides["placeholder"] = self.placeholder self.placeholder = placeholder if reserve_space_for_menu is not None: + overrides["reserve_space_for_menu"] = self.reserve_space_for_menu self.reserve_space_for_menu = reserve_space_for_menu if enable_system_prompt is not None: + overrides["enable_system_prompt"] = self.enable_system_prompt self.enable_system_prompt = enable_system_prompt if enable_suspend is not None: + overrides["enable_suspend"] = self.enable_suspend self.enable_suspend = enable_suspend if enable_open_in_editor is not None: + overrides["enable_open_in_editor"] = self.enable_open_in_editor self.enable_open_in_editor = enable_open_in_editor if tempfile_suffix is not None: + overrides["tempfile_suffix"] = self.tempfile_suffix self.tempfile_suffix = tempfile_suffix if tempfile is not None: + overrides["tempfile"] = self.tempfile self.tempfile = tempfile if show_frame is not None: + overrides["show_frame"] = self.show_frame self.show_frame = show_frame self._add_pre_run_callables(pre_run, accept_default) @@ -1243,15 +1321,19 @@ async def prompt_async( ) self.app.refresh_interval = self.refresh_interval # This is not reactive. - # If we are using the default output, and have a dumb terminal. Use the - # dumb prompt. - if self._output is None and is_dumb_terminal(): - with self._dumb_prompt(self.message) as dump_app: - return await dump_app.run_async(handle_sigint=handle_sigint) - - return await self.app.run_async( - set_exception_handler=set_exception_handler, handle_sigint=handle_sigint - ) + try: + # If we are using the default output, and have a dumb terminal. Use the + # dumb prompt. + if self._output is None and is_dumb_terminal(): + with self._dumb_prompt(self.message) as dump_app: + return await dump_app.run_async(handle_sigint=handle_sigint) + + return await self.app.run_async( + set_exception_handler=set_exception_handler, handle_sigint=handle_sigint + ) + finally: + for attr, value in overrides.items(): + setattr(self, attr, value) def _add_pre_run_callables( self, pre_run: Callable[[], None] | None, accept_default: bool diff --git a/tests/test_shortcuts.py b/tests/test_shortcuts.py index 287c6d33a..61ab5388c 100644 --- a/tests/test_shortcuts.py +++ b/tests/test_shortcuts.py @@ -1,7 +1,9 @@ from __future__ import annotations +from unittest.mock import patch + from prompt_toolkit.shortcuts import print_container -from prompt_toolkit.shortcuts.prompt import _split_multiline_prompt +from prompt_toolkit.shortcuts.prompt import PromptSession, _split_multiline_prompt from prompt_toolkit.widgets import Frame, TextArea @@ -55,6 +57,36 @@ def test_split_multiline_prompt(): assert first_input_line() == [("class:testclass", "a"), ("class:testclass", "b")] +def test_prompt_per_call_override_restore(): + """Per-call kwargs to prompt() should not permanently change session state (#967).""" + session = PromptSession() + + # Verify defaults. + assert session.is_password is False + assert session.multiline is False + assert session.wrap_lines is True + assert session.message == "" + + # Call prompt() with overrides. app.run will raise (no terminal), but the + # finally block should still restore the original values. + with patch.object(session.app, "run", side_effect=EOFError): + try: + session.prompt( + "test> ", + is_password=True, + multiline=True, + wrap_lines=False, + ) + except EOFError: + pass + + # All overridden attributes should be restored. + assert session.is_password is False + assert session.multiline is False + assert session.wrap_lines is True + assert session.message == "" + + def test_print_container(tmpdir): # Call `print_container`, render to a dummy file. f = tmpdir.join("output")