Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 116 additions & 34 deletions src/prompt_toolkit/shortcuts/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,6 @@ class PromptSession(Generic[_T]):
"is_password",
"editing_mode",
"key_bindings",
"is_password",
"bottom_toolbar",
"style",
"style_transformation",
Expand Down Expand Up @@ -951,93 +950,125 @@ class itself. For these, passing in ``None`` will keep the current
pressed (for abort) and ``EOFError`` when control-d has been pressed
(for exit).
"""
# NOTE: We used to create a backup of the PromptSession attributes and
# restore them after exiting the prompt. This code has been
# removed, because it was confusing and didn't really serve a use
# case. (People were changing `Application.editing_mode`
# dynamically and surprised that it was reset after every call.)

# NOTE 2: YES, this is a lot of repeation below...
# However, it is a very convenient for a user to accept all
# these parameters in this `prompt` method as well. We could
# use `locals()` and `setattr` to avoid the repetition, but
# then we loose the advantage of mypy and pyflakes to be able
# to verify the code.
# Save the original values of any attributes that are being
# overridden, so we can restore them after the prompt finishes.
# This ensures that per-call overrides (like is_password=True) don't
# permanently change the session state (#967).
overrides: dict[str, object] = {}

if message is not None:
overrides["message"] = self.message
self.message = message
if editing_mode is not None:
overrides["editing_mode"] = self.editing_mode
self.editing_mode = editing_mode
if refresh_interval is not None:
overrides["refresh_interval"] = self.refresh_interval
self.refresh_interval = refresh_interval
if vi_mode:
overrides.setdefault("editing_mode", self.editing_mode)
self.editing_mode = EditingMode.VI
if lexer is not None:
overrides["lexer"] = self.lexer
self.lexer = lexer
if completer is not None:
overrides["completer"] = self.completer
self.completer = completer
if complete_in_thread is not None:
overrides["complete_in_thread"] = self.complete_in_thread
self.complete_in_thread = complete_in_thread
if is_password is not None:
overrides["is_password"] = self.is_password
self.is_password = is_password
if key_bindings is not None:
overrides["key_bindings"] = self.key_bindings
self.key_bindings = key_bindings
if bottom_toolbar is not None:
overrides["bottom_toolbar"] = self.bottom_toolbar
self.bottom_toolbar = bottom_toolbar
if style is not None:
overrides["style"] = self.style
self.style = style
if color_depth is not None:
overrides["color_depth"] = self.color_depth
self.color_depth = color_depth
if cursor is not None:
overrides["cursor"] = self.cursor
self.cursor = cursor
if include_default_pygments_style is not None:
overrides["include_default_pygments_style"] = self.include_default_pygments_style
self.include_default_pygments_style = include_default_pygments_style
if style_transformation is not None:
overrides["style_transformation"] = self.style_transformation
self.style_transformation = style_transformation
if swap_light_and_dark_colors is not None:
overrides["swap_light_and_dark_colors"] = self.swap_light_and_dark_colors
self.swap_light_and_dark_colors = swap_light_and_dark_colors
if rprompt is not None:
overrides["rprompt"] = self.rprompt
self.rprompt = rprompt
if multiline is not None:
overrides["multiline"] = self.multiline
self.multiline = multiline
if prompt_continuation is not None:
overrides["prompt_continuation"] = self.prompt_continuation
self.prompt_continuation = prompt_continuation
if wrap_lines is not None:
overrides["wrap_lines"] = self.wrap_lines
self.wrap_lines = wrap_lines
if enable_history_search is not None:
overrides["enable_history_search"] = self.enable_history_search
self.enable_history_search = enable_history_search
if search_ignore_case is not None:
overrides["search_ignore_case"] = self.search_ignore_case
self.search_ignore_case = search_ignore_case
if complete_while_typing is not None:
overrides["complete_while_typing"] = self.complete_while_typing
self.complete_while_typing = complete_while_typing
if validate_while_typing is not None:
overrides["validate_while_typing"] = self.validate_while_typing
self.validate_while_typing = validate_while_typing
if complete_style is not None:
overrides["complete_style"] = self.complete_style
self.complete_style = complete_style
if auto_suggest is not None:
overrides["auto_suggest"] = self.auto_suggest
self.auto_suggest = auto_suggest
if validator is not None:
overrides["validator"] = self.validator
self.validator = validator
if clipboard is not None:
overrides["clipboard"] = self.clipboard
self.clipboard = clipboard
if mouse_support is not None:
overrides["mouse_support"] = self.mouse_support
self.mouse_support = mouse_support
if input_processors is not None:
overrides["input_processors"] = self.input_processors
self.input_processors = input_processors
if placeholder is not None:
overrides["placeholder"] = self.placeholder
self.placeholder = placeholder
if reserve_space_for_menu is not None:
overrides["reserve_space_for_menu"] = self.reserve_space_for_menu
self.reserve_space_for_menu = reserve_space_for_menu
if enable_system_prompt is not None:
overrides["enable_system_prompt"] = self.enable_system_prompt
self.enable_system_prompt = enable_system_prompt
if enable_suspend is not None:
overrides["enable_suspend"] = self.enable_suspend
self.enable_suspend = enable_suspend
if enable_open_in_editor is not None:
overrides["enable_open_in_editor"] = self.enable_open_in_editor
self.enable_open_in_editor = enable_open_in_editor
if tempfile_suffix is not None:
overrides["tempfile_suffix"] = self.tempfile_suffix
self.tempfile_suffix = tempfile_suffix
if tempfile is not None:
overrides["tempfile"] = self.tempfile
self.tempfile = tempfile
if show_frame is not None:
overrides["show_frame"] = self.show_frame
self.show_frame = show_frame

self._add_pre_run_callables(pre_run, accept_default)
Expand All @@ -1046,18 +1077,24 @@ class itself. For these, passing in ``None`` will keep the current
)
self.app.refresh_interval = self.refresh_interval # This is not reactive.

# If we are using the default output, and have a dumb terminal. Use the
# dumb prompt.
if self._output is None and is_dumb_terminal():
with self._dumb_prompt(self.message) as dump_app:
return dump_app.run(in_thread=in_thread, handle_sigint=handle_sigint)

return self.app.run(
set_exception_handler=set_exception_handler,
in_thread=in_thread,
handle_sigint=handle_sigint,
inputhook=inputhook,
)
try:
# If we are using the default output, and have a dumb terminal. Use the
# dumb prompt.
if self._output is None and is_dumb_terminal():
with self._dumb_prompt(self.message) as dump_app:
return dump_app.run(in_thread=in_thread, handle_sigint=handle_sigint)

return self.app.run(
set_exception_handler=set_exception_handler,
in_thread=in_thread,
handle_sigint=handle_sigint,
inputhook=inputhook,
)
finally:
# Restore overridden attributes so per-call kwargs don't
# permanently change the session state.
for attr, value in overrides.items():
setattr(self, attr, value)

@contextmanager
def _dumb_prompt(self, message: AnyFormattedText = "") -> Iterator[Application[_T]]:
Expand Down Expand Up @@ -1160,81 +1197,122 @@ async def prompt_async(
set_exception_handler: bool = True,
handle_sigint: bool = True,
) -> _T:
# Save and restore overridden attributes (same as prompt(), see #967).
overrides: dict[str, object] = {}

if message is not None:
overrides["message"] = self.message
self.message = message
if editing_mode is not None:
overrides["editing_mode"] = self.editing_mode
self.editing_mode = editing_mode
if refresh_interval is not None:
overrides["refresh_interval"] = self.refresh_interval
self.refresh_interval = refresh_interval
if vi_mode:
overrides.setdefault("editing_mode", self.editing_mode)
self.editing_mode = EditingMode.VI
if lexer is not None:
overrides["lexer"] = self.lexer
self.lexer = lexer
if completer is not None:
overrides["completer"] = self.completer
self.completer = completer
if complete_in_thread is not None:
overrides["complete_in_thread"] = self.complete_in_thread
self.complete_in_thread = complete_in_thread
if is_password is not None:
overrides["is_password"] = self.is_password
self.is_password = is_password
if key_bindings is not None:
overrides["key_bindings"] = self.key_bindings
self.key_bindings = key_bindings
if bottom_toolbar is not None:
overrides["bottom_toolbar"] = self.bottom_toolbar
self.bottom_toolbar = bottom_toolbar
if style is not None:
overrides["style"] = self.style
self.style = style
if color_depth is not None:
overrides["color_depth"] = self.color_depth
self.color_depth = color_depth
if cursor is not None:
overrides["cursor"] = self.cursor
self.cursor = cursor
if include_default_pygments_style is not None:
overrides["include_default_pygments_style"] = self.include_default_pygments_style
self.include_default_pygments_style = include_default_pygments_style
if style_transformation is not None:
overrides["style_transformation"] = self.style_transformation
self.style_transformation = style_transformation
if swap_light_and_dark_colors is not None:
overrides["swap_light_and_dark_colors"] = self.swap_light_and_dark_colors
self.swap_light_and_dark_colors = swap_light_and_dark_colors
if rprompt is not None:
overrides["rprompt"] = self.rprompt
self.rprompt = rprompt
if multiline is not None:
overrides["multiline"] = self.multiline
self.multiline = multiline
if prompt_continuation is not None:
overrides["prompt_continuation"] = self.prompt_continuation
self.prompt_continuation = prompt_continuation
if wrap_lines is not None:
overrides["wrap_lines"] = self.wrap_lines
self.wrap_lines = wrap_lines
if enable_history_search is not None:
overrides["enable_history_search"] = self.enable_history_search
self.enable_history_search = enable_history_search
if search_ignore_case is not None:
overrides["search_ignore_case"] = self.search_ignore_case
self.search_ignore_case = search_ignore_case
if complete_while_typing is not None:
overrides["complete_while_typing"] = self.complete_while_typing
self.complete_while_typing = complete_while_typing
if validate_while_typing is not None:
overrides["validate_while_typing"] = self.validate_while_typing
self.validate_while_typing = validate_while_typing
if complete_style is not None:
overrides["complete_style"] = self.complete_style
self.complete_style = complete_style
if auto_suggest is not None:
overrides["auto_suggest"] = self.auto_suggest
self.auto_suggest = auto_suggest
if validator is not None:
overrides["validator"] = self.validator
self.validator = validator
if clipboard is not None:
overrides["clipboard"] = self.clipboard
self.clipboard = clipboard
if mouse_support is not None:
overrides["mouse_support"] = self.mouse_support
self.mouse_support = mouse_support
if input_processors is not None:
overrides["input_processors"] = self.input_processors
self.input_processors = input_processors
if placeholder is not None:
overrides["placeholder"] = self.placeholder
self.placeholder = placeholder
if reserve_space_for_menu is not None:
overrides["reserve_space_for_menu"] = self.reserve_space_for_menu
self.reserve_space_for_menu = reserve_space_for_menu
if enable_system_prompt is not None:
overrides["enable_system_prompt"] = self.enable_system_prompt
self.enable_system_prompt = enable_system_prompt
if enable_suspend is not None:
overrides["enable_suspend"] = self.enable_suspend
self.enable_suspend = enable_suspend
if enable_open_in_editor is not None:
overrides["enable_open_in_editor"] = self.enable_open_in_editor
self.enable_open_in_editor = enable_open_in_editor
if tempfile_suffix is not None:
overrides["tempfile_suffix"] = self.tempfile_suffix
self.tempfile_suffix = tempfile_suffix
if tempfile is not None:
overrides["tempfile"] = self.tempfile
self.tempfile = tempfile
if show_frame is not None:
overrides["show_frame"] = self.show_frame
self.show_frame = show_frame

self._add_pre_run_callables(pre_run, accept_default)
Expand All @@ -1243,15 +1321,19 @@ async def prompt_async(
)
self.app.refresh_interval = self.refresh_interval # This is not reactive.

# If we are using the default output, and have a dumb terminal. Use the
# dumb prompt.
if self._output is None and is_dumb_terminal():
with self._dumb_prompt(self.message) as dump_app:
return await dump_app.run_async(handle_sigint=handle_sigint)

return await self.app.run_async(
set_exception_handler=set_exception_handler, handle_sigint=handle_sigint
)
try:
# If we are using the default output, and have a dumb terminal. Use the
# dumb prompt.
if self._output is None and is_dumb_terminal():
with self._dumb_prompt(self.message) as dump_app:
return await dump_app.run_async(handle_sigint=handle_sigint)

return await self.app.run_async(
set_exception_handler=set_exception_handler, handle_sigint=handle_sigint
)
finally:
for attr, value in overrides.items():
setattr(self, attr, value)

def _add_pre_run_callables(
self, pre_run: Callable[[], None] | None, accept_default: bool
Expand Down
34 changes: 33 additions & 1 deletion tests/test_shortcuts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

from unittest.mock import patch

from prompt_toolkit.shortcuts import print_container
from prompt_toolkit.shortcuts.prompt import _split_multiline_prompt
from prompt_toolkit.shortcuts.prompt import PromptSession, _split_multiline_prompt
from prompt_toolkit.widgets import Frame, TextArea


Expand Down Expand Up @@ -55,6 +57,36 @@ def test_split_multiline_prompt():
assert first_input_line() == [("class:testclass", "a"), ("class:testclass", "b")]


def test_prompt_per_call_override_restore():
"""Per-call kwargs to prompt() should not permanently change session state (#967)."""
session = PromptSession()

# Verify defaults.
assert session.is_password is False
assert session.multiline is False
assert session.wrap_lines is True
assert session.message == ""

# Call prompt() with overrides. app.run will raise (no terminal), but the
# finally block should still restore the original values.
with patch.object(session.app, "run", side_effect=EOFError):
try:
session.prompt(
"test> ",
is_password=True,
multiline=True,
wrap_lines=False,
)
except EOFError:
pass

# All overridden attributes should be restored.
assert session.is_password is False
assert session.multiline is False
assert session.wrap_lines is True
assert session.message == ""


def test_print_container(tmpdir):
# Call `print_container`, render to a dummy file.
f = tmpdir.join("output")
Expand Down