From ae896ada861c4ab986a3fd43dce44c3f90970ffa Mon Sep 17 00:00:00 2001 From: Roland Walker Date: Fri, 3 Apr 2026 14:14:10 -0400 Subject: [PATCH] move --list-ssh-config out of main.py The added tests are not 100% equivalent to the removed tests, but the execution path is also deprecated. --- changelog.md | 1 + mycli/main.py | 39 +----- mycli/main_modes/list_ssh_config.py | 26 ++++ mycli/packages/ssh_utils.py | 27 ++++ .../test_main_modes_list_ssh_config.py | 87 +++++++++++++ test/pytests/test_main_regression.py | 121 +----------------- test/pytests/test_ssh_utils.py | 68 ++++++++++ 7 files changed, 217 insertions(+), 152 deletions(-) create mode 100644 mycli/main_modes/list_ssh_config.py create mode 100644 mycli/packages/ssh_utils.py create mode 100644 test/pytests/test_main_modes_list_ssh_config.py create mode 100644 test/pytests/test_ssh_utils.py diff --git a/changelog.md b/changelog.md index 96602a63..5aac047d 100644 --- a/changelog.md +++ b/changelog.md @@ -38,6 +38,7 @@ Internal * Move `--checkup` logic to the new `main_modes` with `--batch`. * Move `--execute` logic to the new `main_modes` with `--batch`. * Move `--list-dsn` logic to the new `main_modes` with `--batch`. +* Move `--list-ssh-config` logic to the new `main_modes` with `--batch`. * Sort coverage report in tox suite. * Skip more tests when a database connection is not present. diff --git a/mycli/main.py b/mycli/main.py index 7d377a62..4092ddc1 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -88,6 +88,7 @@ from mycli.main_modes.checkup import main_checkup from mycli.main_modes.execute import main_execute_from_cli from mycli.main_modes.list_dsn import main_list_dsn +from mycli.main_modes.list_ssh_config import main_list_ssh_config from mycli.packages import special from mycli.packages.filepaths import dir_path_exists, guess_socket_location from mycli.packages.hybrid_redirection import get_redirect_components, is_redirect_command @@ -98,16 +99,12 @@ from mycli.packages.special.main import ArgType from mycli.packages.special.utils import format_uptime, get_ssl_version, get_uptime, get_warning_count from mycli.packages.sqlresult import SQLResult +from mycli.packages.ssh_utils import read_ssh_config from mycli.packages.string_utils import sanitize_terminal_title from mycli.packages.tabular_output import sql_format from mycli.sqlcompleter import SQLCompleter from mycli.sqlexecute import FIELD_TYPES, SQLExecute -try: - import paramiko -except ImportError: - from mycli.packages.paramiko_stub import paramiko # type: ignore[no-redef] - sqlparse.engine.grouping.MAX_GROUPING_DEPTH = None # type: ignore[assignment] sqlparse.engine.grouping.MAX_GROUPING_TOKENS = None # type: ignore[assignment] @@ -2316,19 +2313,7 @@ def get_password_from_file(password_file: str | None) -> str | None: sys.exit(main_list_dsn(mycli, cli_args)) if cli_args.list_ssh_config: - ssh_config = read_ssh_config(cli_args.ssh_config_path) - try: - host_entries = ssh_config.get_hostnames() - except KeyError: - click.secho('Error reading ssh config', err=True, fg="red") - sys.exit(1) - for host_entry in host_entries: - if cli_args.verbose: - host_config = ssh_config.lookup(host_entry) - click.secho(f"{host_entry} : {host_config.get('hostname')}") - else: - click.secho(host_entry) - sys.exit(0) + sys.exit(main_list_ssh_config(mycli, cli_args)) if 'MYSQL_UNIX_PORT' in os.environ: # deprecated 2026-03 @@ -2761,24 +2746,6 @@ def edit_and_execute(event: KeyPressEvent) -> None: buff.open_in_editor(validate_and_handle=False) -def read_ssh_config(ssh_config_path: str): - ssh_config = paramiko.config.SSHConfig() - try: - with open(ssh_config_path) as f: - ssh_config.parse(f) - except FileNotFoundError as e: - click.secho(str(e), err=True, fg="red") - sys.exit(1) - # Paramiko prior to version 2.7 raises Exception on parse errors. - # In 2.7 it has become paramiko.ssh_exception.SSHException, - # but let's catch everything for compatibility - except Exception as err: - click.secho(f"Could not parse SSH configuration file {ssh_config_path}:\n{err} ", err=True, fg="red") - sys.exit(1) - else: - return ssh_config - - def filtered_sys_argv() -> list[str]: args = sys.argv[1:] if args == ['-h']: diff --git a/mycli/main_modes/list_ssh_config.py b/mycli/main_modes/list_ssh_config.py new file mode 100644 index 00000000..8c27a011 --- /dev/null +++ b/mycli/main_modes/list_ssh_config.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import click + +from mycli.packages.ssh_utils import read_ssh_config + +if TYPE_CHECKING: + from mycli.main import CliArgs, MyCli + + +def main_list_ssh_config(mycli: 'MyCli', cli_args: 'CliArgs') -> int: + ssh_config = read_ssh_config(cli_args.ssh_config_path) + try: + host_entries = ssh_config.get_hostnames() + except KeyError: + click.secho('Error reading ssh config', err=True, fg="red") + return 1 + for host_entry in host_entries: + if cli_args.verbose: + host_config = ssh_config.lookup(host_entry) + click.secho(f"{host_entry} : {host_config.get('hostname')}") + else: + click.secho(host_entry) + return 0 diff --git a/mycli/packages/ssh_utils.py b/mycli/packages/ssh_utils.py new file mode 100644 index 00000000..1b81384a --- /dev/null +++ b/mycli/packages/ssh_utils.py @@ -0,0 +1,27 @@ +import sys + +import click + +try: + import paramiko +except ImportError: + from mycli.packages.paramiko_stub import paramiko # type: ignore[no-redef] + + +# it isn't cool that this utility function can exit(), but it is slated to be removed anyway +def read_ssh_config(ssh_config_path: str): + ssh_config = paramiko.config.SSHConfig() + try: + with open(ssh_config_path) as f: + ssh_config.parse(f) + except FileNotFoundError as e: + click.secho(str(e), err=True, fg="red") + sys.exit(1) + # Paramiko prior to version 2.7 raises Exception on parse errors. + # In 2.7 it has become paramiko.ssh_exception.SSHException, + # but let's catch everything for compatibility + except Exception as err: + click.secho(f"Could not parse SSH configuration file {ssh_config_path}:\n{err} ", err=True, fg="red") + sys.exit(1) + else: + return ssh_config diff --git a/test/pytests/test_main_modes_list_ssh_config.py b/test/pytests/test_main_modes_list_ssh_config.py new file mode 100644 index 00000000..287ed1f2 --- /dev/null +++ b/test/pytests/test_main_modes_list_ssh_config.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, cast + +import mycli.main_modes.list_ssh_config as list_ssh_config_mode + + +@dataclass +class DummyCliArgs: + ssh_config_path: str = 'ssh_config' + verbose: bool = False + + +class DummySSHConfig: + def __init__(self, hostnames: list[str] | Exception, lookups: dict[str, dict[str, str]] | None = None) -> None: + self.hostnames = hostnames + self.lookups = lookups or {} + + def get_hostnames(self) -> list[str]: + if isinstance(self.hostnames, Exception): + raise self.hostnames + return self.hostnames + + def lookup(self, hostname: str) -> dict[str, str]: + return self.lookups[hostname] + + +def main_list_ssh_config(cli_args: DummyCliArgs) -> int: + return list_ssh_config_mode.main_list_ssh_config(cast(Any, object()), cast(Any, cli_args)) + + +def test_main_list_ssh_config_lists_hostnames(monkeypatch) -> None: + secho_calls: list[tuple[str, bool | None, str | None]] = [] + ssh_config = DummySSHConfig(['prod', 'staging']) + + monkeypatch.setattr(list_ssh_config_mode, 'read_ssh_config', lambda _path: ssh_config) + monkeypatch.setattr( + list_ssh_config_mode.click, + 'secho', + lambda message, err=None, fg=None: secho_calls.append((message, err, fg)), + ) + + result = main_list_ssh_config(DummyCliArgs(verbose=False)) + + assert result == 0 + assert secho_calls == [ + ('prod', None, None), + ('staging', None, None), + ] + + +def test_main_list_ssh_config_lists_verbose_host_details(monkeypatch) -> None: + secho_calls: list[tuple[str, bool | None, str | None]] = [] + ssh_config = DummySSHConfig( + ['prod'], + lookups={'prod': {'hostname': 'db.example.com'}}, + ) + + monkeypatch.setattr(list_ssh_config_mode, 'read_ssh_config', lambda _path: ssh_config) + monkeypatch.setattr( + list_ssh_config_mode.click, + 'secho', + lambda message, err=None, fg=None: secho_calls.append((message, err, fg)), + ) + + result = main_list_ssh_config(DummyCliArgs(verbose=True)) + + assert result == 0 + assert secho_calls == [('prod : db.example.com', None, None)] + + +def test_main_list_ssh_config_reports_host_lookup_errors(monkeypatch) -> None: + secho_calls: list[tuple[str, bool | None, str | None]] = [] + ssh_config = DummySSHConfig(KeyError('bad ssh config')) + + monkeypatch.setattr(list_ssh_config_mode, 'read_ssh_config', lambda _path: ssh_config) + monkeypatch.setattr( + list_ssh_config_mode.click, + 'secho', + lambda message, err=None, fg=None: secho_calls.append((message, err, fg)), + ) + + result = main_list_ssh_config(DummyCliArgs()) + + assert result == 1 + assert secho_calls == [('Error reading ssh config', True, 'red')] diff --git a/test/pytests/test_main_regression.py b/test/pytests/test_main_regression.py index 33f9a6c2..f2251f3b 100644 --- a/test/pytests/test_main_regression.py +++ b/test/pytests/test_main_regression.py @@ -252,7 +252,7 @@ def make_bare_mycli() -> Any: return cli -def load_main_variant(monkeypatch: pytest.MonkeyPatch, *, fail_pwd: bool = False, fail_paramiko: bool = False) -> ModuleType: +def load_main_variant(monkeypatch: pytest.MonkeyPatch, *, fail_pwd: bool = False) -> ModuleType: import builtins original_import = builtins.__import__ @@ -260,12 +260,10 @@ def load_main_variant(monkeypatch: pytest.MonkeyPatch, *, fail_pwd: bool = False def fake_import(name: str, globals: Any = None, locals: Any = None, fromlist: Any = (), level: int = 0) -> Any: # noqa: A002 if fail_pwd and name == 'pwd': raise ImportError('no pwd') - if fail_paramiko and name == 'paramiko': - raise ImportError('no paramiko') return original_import(name, globals, locals, fromlist, level) monkeypatch.setattr(builtins, '__import__', fake_import) - module_name = f'mycli_main_variant_{int(fail_pwd)}_{int(fail_paramiko)}' + module_name = f'mycli_main_variant_{int(fail_pwd)}' spec = importlib.util.spec_from_file_location(module_name, Path(main.__file__)) assert spec is not None assert spec.loader is not None @@ -322,10 +320,9 @@ def call_click_entrypoint_direct(cli_args: main.CliArgs) -> None: cast(Any, main.click_entrypoint.callback).__wrapped__(cli_args) -def test_import_fallbacks_for_pwd_and_paramiko(monkeypatch: pytest.MonkeyPatch) -> None: - module = load_main_variant(monkeypatch, fail_pwd=True, fail_paramiko=True) +def test_import_fallbacks_for_pwd(monkeypatch: pytest.MonkeyPatch) -> None: + module = load_main_variant(monkeypatch, fail_pwd=True) - assert hasattr(module, 'paramiko') assert module.Query('sql', True, False).query == 'sql' @@ -1487,7 +1484,7 @@ def test_filtered_sys_argv_covers_help_and_passthrough(monkeypatch: pytest.Monke assert main.need_completion_refresh('') is False -def test_completion_helpers_title_helpers_thanks_tips_and_read_ssh_config(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: +def test_completion_helpers_title_helpers_thanks_tips(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: cli = make_bare_mycli() cli.completer = cast(Any, SimpleNamespace(keyword_casing='auto', get_completions=lambda document, event: ['done'])) entered_lock = {'count': 0} @@ -1575,32 +1572,6 @@ def joinpath(self, name: str) -> 'FakeResource': monkeypatch.setattr(main.resources, 'files', lambda package: SponsorResource(None)) assert main.thanks_picker() == 'Sponsor Person' - class FakeSSHConfig: - def __init__(self) -> None: - self.parsed = False - - def parse(self, file_obj: Any) -> None: - self.parsed = True - - monkeypatch.setattr(main.paramiko.config, 'SSHConfig', FakeSSHConfig) - ssh_file = tmp_path / 'ssh.conf' - ssh_file.write_text('Host prod\n', encoding='utf-8') - ssh_config = main.read_ssh_config(str(ssh_file)) - assert ssh_config.parsed is True - - missing_errs: list[str] = [] - monkeypatch.setattr(click, 'secho', lambda message, **kwargs: missing_errs.append(str(message))) - with pytest.raises(SystemExit): - main.read_ssh_config(str(tmp_path / 'missing.conf')) - - class BadSSHConfig(FakeSSHConfig): - def parse(self, file_obj: Any) -> None: - raise Exception('bad parse') - - monkeypatch.setattr(main.paramiko.config, 'SSHConfig', BadSSHConfig) - with pytest.raises(SystemExit): - main.read_ssh_config(str(ssh_file)) - def test_main_wrapper_and_edit_and_execute(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(main, 'filtered_sys_argv', lambda: ['--help']) @@ -1694,28 +1665,6 @@ def test_click_entrypoint_branches_with_dummy_mycli(monkeypatch: pytest.MonkeyPa assert result.exit_code == 1 assert 'Invalid DSNs found' in result.output - class FakeSSHLookup: - def get_hostnames(self) -> list[str]: - return ['prod'] - - def lookup(self, host: str) -> dict[str, str]: - return {'hostname': 'db.example'} - - monkeypatch.setattr(main, 'read_ssh_config', lambda path: FakeSSHLookup()) - monkeypatch.setattr(main, 'MyCli', make_dummy_mycli_class()) - result = runner.invoke(main.click_entrypoint, ['--list-ssh-config', '--verbose']) - assert result.exit_code == 0 - assert 'prod : db.example' in result.output - - class BadSSHLookup: - def get_hostnames(self) -> list[str]: - raise KeyError() - - monkeypatch.setattr(main, 'read_ssh_config', lambda path: BadSSHLookup()) - result = runner.invoke(main.click_entrypoint, ['--list-ssh-config']) - assert result.exit_code == 1 - assert 'Error reading ssh config' in result.output - monkeypatch.setenv('MYSQL_UNIX_PORT', '/tmp/mysql.sock') monkeypatch.setenv('DSN', 'mysql://user:pw@host/db') monkeypatch.setattr(main, 'MyCli', make_dummy_mycli_class()) @@ -1924,15 +1873,8 @@ def test_click_entrypoint_callback_covers_dsn_params_init_commands_and_keyring(m monkeypatch.setattr(click, 'secho', lambda message='', **kwargs: click_lines.append(str(message))) monkeypatch.setattr(click, 'echo', lambda message='', **kwargs: click_lines.append(str(message))) - class SSHConfig: - def lookup(self, host: str) -> dict[str, Any]: - return {'hostname': 'ssh.example', 'user': 'sshuser', 'port': '2200', 'identityfile': ['/tmp/id_rsa']} - - monkeypatch.setattr(main, 'read_ssh_config', lambda path: SSHConfig()) cli_args = main.CliArgs() cli_args.database = 'prod' - cli_args.ssh_config_host = 'edge' - cli_args.ssh_port = 2201 cli_args.init_command = 'set e=5' cli_args.use_keyring = 'reset' call_click_entrypoint_direct(cli_args) @@ -1943,10 +1885,6 @@ def lookup(self, host: str) -> dict[str, Any]: assert connect_kwargs['database'] == 'prod_db' assert connect_kwargs['user'] == 'user' assert connect_kwargs['passwd'] == 'pw' - assert connect_kwargs['ssh_host'] == 'ssh.example' - assert connect_kwargs['ssh_user'] == 'sshuser' - assert connect_kwargs['ssh_port'] == 2201 - assert connect_kwargs['ssh_key_filename'] == '/tmp/id_rsa' assert connect_kwargs['ssl'] is None assert connect_kwargs['character_set'] == 'utf8mb4' assert connect_kwargs['keepalive_ticks'] == 9 @@ -1980,21 +1918,6 @@ def test_click_entrypoint_callback_covers_database_dsn_and_verbose_lists(monkeyp click_lines.clear() - class SSHConfig: - def get_hostnames(self) -> list[str]: - return ['prod'] - - def lookup(self, host: str) -> dict[str, str]: - return {'hostname': 'db.example'} - - monkeypatch.setattr(main, 'read_ssh_config', lambda path: SSHConfig()) - cli_args = main.CliArgs() - cli_args.list_ssh_config = True - cli_args.ssh_warning_off = True - with pytest.raises(SystemExit): - call_click_entrypoint_direct(cli_args) - assert click_lines == ['prod'] - dummy_class = make_dummy_mycli_class( config={ 'main': {'use_keyring': 'false', 'my_cnf_transition_done': 'true'}, @@ -2111,40 +2034,6 @@ def failing_run_query(self: Any, query: str, checkpoint: Any = None, new_line: b assert any('execute failed' in line for line in click_lines) -def test_click_entrypoint_callback_covers_ssh_default_port_alias_list_and_transition_underscore(monkeypatch: pytest.MonkeyPatch) -> None: - click_lines: list[str] = [] - monkeypatch.setattr(click, 'secho', lambda message='', **kwargs: click_lines.append(str(message))) - monkeypatch.setattr(main.sys, 'stdin', SimpleNamespace(isatty=lambda: True)) - monkeypatch.setattr(main.sys.stderr, 'isatty', lambda: False) - - dummy_class = make_dummy_mycli_class( - config={ - 'main': {'use_keyring': 'false', 'my_cnf_transition_done': 'false'}, - 'connection': {'default_keepalive_ticks': 0}, - 'alias_dsn': {'prod': 'mysql://u:p@h/db'}, - 'alias_dsn.init-commands': {'prod': ['set list=1']}, - }, - my_cnf={'client': {}, 'mysqld': {'loose_local_infile': '1'}}, - config_without_package_defaults={'connection': {}}, - ) - monkeypatch.setattr(main, 'MyCli', dummy_class) - - class SSHConfig: - def lookup(self, host: str) -> dict[str, Any]: - return {'hostname': 'ssh.example', 'user': 'sshuser', 'port': '2200', 'identityfile': ['/tmp/id_rsa']} - - monkeypatch.setattr(main, 'read_ssh_config', lambda path: SSHConfig()) - cli_args = main.CliArgs() - cli_args.database = 'prod' - cli_args.ssh_config_host = 'edge' - call_click_entrypoint_direct(cli_args) - dummy = dummy_class.last_instance - assert dummy is not None - assert dummy.connect_calls[-1]['ssh_port'] == 2200 - assert dummy.connect_calls[-1]['init_command'] == 'set list=1' - assert any('Reading configuration from my.cnf files is deprecated.' in line for line in click_lines) - - def test_configure_pager_and_refresh_completions(monkeypatch: pytest.MonkeyPatch) -> None: cli = make_bare_mycli() cli.my_cnf = {'client': {}, 'mysqld': {}} diff --git a/test/pytests/test_ssh_utils.py b/test/pytests/test_ssh_utils.py new file mode 100644 index 00000000..dadf5412 --- /dev/null +++ b/test/pytests/test_ssh_utils.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from pathlib import Path +from typing import TextIO + +import pytest + +from mycli.packages import ssh_utils + + +class FakeSSHConfig: + def __init__(self, parse_error: Exception | None = None) -> None: + self.parse_error = parse_error + self.parsed_text: str | None = None + + def parse(self, handle: TextIO) -> None: + if self.parse_error is not None: + raise self.parse_error + self.parsed_text = handle.read() + + +def test_read_ssh_config_parses_and_returns_config(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + config_path = tmp_path / 'ssh_config' + config_path.write_text('Host demo\n HostName example.com\n', encoding='utf-8') + fake_ssh_config = FakeSSHConfig() + + monkeypatch.setattr(ssh_utils.paramiko.config, 'SSHConfig', lambda: fake_ssh_config) + + result = ssh_utils.read_ssh_config(str(config_path)) + + assert result is fake_ssh_config + assert fake_ssh_config.parsed_text == 'Host demo\n HostName example.com\n' + + +def test_read_ssh_config_reports_missing_file_and_exits(monkeypatch: pytest.MonkeyPatch) -> None: + secho_calls: list[tuple[str, bool, str]] = [] + + monkeypatch.setattr( + ssh_utils.click, + 'secho', + lambda message, err, fg: secho_calls.append((message, err, fg)), + ) + + with pytest.raises(SystemExit) as excinfo: + ssh_utils.read_ssh_config('/definitely/missing/ssh_config') + + assert excinfo.value.code == 1 + assert secho_calls == [("[Errno 2] No such file or directory: '/definitely/missing/ssh_config'", True, 'red')] + + +def test_read_ssh_config_reports_parse_errors_and_exits(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + config_path = tmp_path / 'ssh_config' + config_path.write_text('Host broken\n', encoding='utf-8') + fake_ssh_config = FakeSSHConfig(parse_error=RuntimeError('bad config')) + secho_calls: list[tuple[str, bool, str]] = [] + + monkeypatch.setattr(ssh_utils.paramiko.config, 'SSHConfig', lambda: fake_ssh_config) + monkeypatch.setattr( + ssh_utils.click, + 'secho', + lambda message, err, fg: secho_calls.append((message, err, fg)), + ) + + with pytest.raises(SystemExit) as excinfo: + ssh_utils.read_ssh_config(str(config_path)) + + assert excinfo.value.code == 1 + assert secho_calls == [(f'Could not parse SSH configuration file {config_path}:\nbad config ', True, 'red')]