From 8bae90ab47cfb902b2c15533fc1120a26e687c85 Mon Sep 17 00:00:00 2001 From: Roland Walker Date: Sat, 16 May 2026 11:50:49 -0400 Subject: [PATCH] migrate state, cli_args, output out of main.py Create a mycli/app_state.py, mycli/cli_args.py, and mycli/output.py, factoring logic out of mycli/main.py, with no functional change. --- changelog.md | 4 + mycli/app_state.py | 107 ++++ mycli/cli_args.py | 373 +++++++++++++ mycli/main.py | 750 ++------------------------- mycli/output.py | 291 +++++++++++ mycli/packages/special/main.py | 6 +- test/pytests/test_app_state.py | 146 ++++++ test/pytests/test_cli_args.py | 175 +++++++ test/pytests/test_main.py | 8 +- test/pytests/test_main_regression.py | 26 +- test/pytests/test_output.py | 232 +++++++++ test/utils.py | 7 +- 12 files changed, 1390 insertions(+), 735 deletions(-) create mode 100644 mycli/app_state.py create mode 100644 mycli/cli_args.py create mode 100644 mycli/output.py create mode 100644 test/pytests/test_app_state.py create mode 100644 test/pytests/test_cli_args.py create mode 100644 test/pytests/test_output.py diff --git a/changelog.md b/changelog.md index 0a086618..13f7a0f5 100644 --- a/changelog.md +++ b/changelog.md @@ -6,6 +6,10 @@ Bug Fixes * Respect `history_file` setting in the `[main]` section of `~/.myclirc`. * Adapt test suite to pygments v2.20.0. +Internal +--------- +* Factor `app_state.py`, `cli_args.py`, and `output.py` out of `main.py`. + 1.72.1 (2026/05/11) ============== diff --git a/mycli/app_state.py b/mycli/app_state.py new file mode 100644 index 00000000..0aaad28a --- /dev/null +++ b/mycli/app_state.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +from collections import defaultdict +import re +from typing import TYPE_CHECKING, Any + +from configobj import ConfigObj + +from mycli.config import str_to_bool, strip_matching_quotes + +if TYPE_CHECKING: + from mycli.main import MyCli + + +def normalize_ssl_mode(config: ConfigObj) -> tuple[str | None, str | None]: + ssl_mode = config['main'].get('ssl_mode', None) or config['connection'].get('default_ssl_mode', None) + if ssl_mode not in ('auto', 'on', 'off', None): + return None, f'Invalid config option provided for ssl_mode ({ssl_mode}); ignoring.' + return ssl_mode, None + + +def ensure_my_cnf_sections(my_cnf: ConfigObj) -> None: + if not my_cnf.get('client'): + my_cnf['client'] = {} + if not my_cnf.get('mysqld'): + my_cnf['mysqld'] = {} + + +def configure_prompt_state( + mycli: MyCli, + config: ConfigObj, + prompt: str | None, + prompt_cnf: str | None, + toolbar_format: str | None, +) -> None: + mycli.prompt_format = prompt or prompt_cnf or config['main']['prompt'] or mycli.default_prompt + mycli.prompt_lines = 0 + mycli.multiline_continuation_char = config['main']['prompt_continuation'] + mycli.toolbar_format = toolbar_format or config['main']['toolbar'] + mycli.terminal_tab_title_format = config['main']['terminal_tab_title'] + mycli.terminal_window_title_format = config['main']['terminal_window_title'] + mycli.multiplex_window_title_format = config['main']['multiplex_window_title'] + mycli.multiplex_pane_title_format = config['main']['multiplex_pane_title'] + + +def destructive_keywords_from_config(config: ConfigObj) -> list[str]: + keywords = config['main'].get('destructive_keywords', 'DROP SHUTDOWN DELETE TRUNCATE ALTER UPDATE') + return [keyword for keyword in keywords.split(' ') if keyword] + + +def llm_prompt_truncation(config: ConfigObj) -> tuple[int, int]: + if 'llm' in config and re.match(r'^\d+$', config['llm'].get('prompt_field_truncate', '')): + field_truncate = int(config['llm'].get('prompt_field_truncate')) + else: + field_truncate = 0 + if 'llm' in config and re.match(r'^\d+$', config['llm'].get('prompt_section_truncate', '')): + section_truncate = int(config['llm'].get('prompt_section_truncate')) + else: + section_truncate = 0 + return field_truncate, section_truncate + + +class AppStateMixin: + defaults_suffix: str | None + login_path: str | None + + def read_my_cnf(self, cnf: ConfigObj, keys: list[str]) -> dict[str, Any]: + sections = ['client', 'mysqld'] + key_transformations = { + 'mysqld': { + 'socket': 'default_socket', + 'port': 'default_port', + 'user': 'default_user', + }, + } + + if self.login_path and self.login_path != 'client': + sections.append(self.login_path) + + if self.defaults_suffix: + sections.extend([sect + self.defaults_suffix for sect in sections]) + + configuration: dict[str, Any] = defaultdict(lambda: None) + for key in keys: + for section in cnf: + if section not in sections or key not in cnf[section]: + continue + new_key = key_transformations.get(section, {}).get(key) or key + configuration[new_key] = strip_matching_quotes(cnf[section][key]) + + return configuration + + def merge_ssl_with_cnf(self, ssl: dict[str, Any], cnf: dict[str, Any]) -> dict[str, Any]: + merged = {} + merged.update(ssl) + prefix = 'ssl-' + for key, value in cnf.items(): + if not key.startswith(prefix): + continue + if value is None: + continue + if key == 'ssl-verify-server-cert': + merged['check_hostname'] = str_to_bool(value) + else: + merged[key[len(prefix) :]] = value + + return merged diff --git a/mycli/cli_args.py b/mycli/cli_args.py new file mode 100644 index 00000000..bf95f59d --- /dev/null +++ b/mycli/cli_args.py @@ -0,0 +1,373 @@ +from __future__ import annotations + +from dataclasses import dataclass +from io import TextIOWrapper +import os +import sys +from typing import Callable + +import click +import clickdc + +EMPTY_PASSWORD_FLAG_SENTINEL = -1 +DEFAULT_PROMPT = "\\t \\u@\\h:\\d> " + + +class IntOrStringClickParamType(click.ParamType): + name = 'text' # display as TEXT in helpdoc + + def convert(self, value, param, ctx): + if isinstance(value, int): + return value + elif isinstance(value, str): + return value + elif value is None: + return value + else: + self.fail('Not a valid password string', param, ctx) + + +INT_OR_STRING_CLICK_TYPE = IntOrStringClickParamType() + + +@dataclass(slots=True) +class CliArgs: + database: str | None = clickdc.argument( + type=str, + default=None, + nargs=1, + ) + host: str | None = clickdc.option( + '-h', + '--hostname', + 'host', + type=str, + envvar='MYSQL_HOST', + help='Host address of the database.', + ) + port: int | None = clickdc.option( + '-P', + type=int, + envvar='MYSQL_TCP_PORT', + help='Port number to use for connection. Honors $MYSQL_TCP_PORT.', + ) + user: str | None = clickdc.option( + '-u', + '--user', + '--username', + 'user', + type=str, + envvar='MYSQL_USER', + help='User name to connect to the database.', + ) + socket: str | None = clickdc.option( + '-S', + type=str, + envvar='MYSQL_UNIX_SOCKET', + help='The socket file to use for connection.', + ) + password: int | str | None = clickdc.option( + '-p', + '--pass', + '--password', + 'password', + type=INT_OR_STRING_CLICK_TYPE, + is_flag=False, + flag_value=EMPTY_PASSWORD_FLAG_SENTINEL, + help='Prompt for (or pass in cleartext) the password to connect to the database.', + ) + password_file: str | None = clickdc.option( + type=click.Path(), + help='File or FIFO path containing the password to connect to the db if not specified otherwise.', + ) + ssh_user: str | None = clickdc.option( + type=str, + help='User name to connect to ssh server.', + ) + ssh_host: str | None = clickdc.option( + type=str, + help='Host name to connect to ssh server.', + ) + ssh_port: int = clickdc.option( + type=int, + default=22, + help='Port to connect to ssh server.', + ) + ssh_password: str | None = clickdc.option( + type=str, + help='Password to connect to ssh server.', + ) + ssh_key_filename: str | None = clickdc.option( + type=str, + help='Private key filename (identify file) for the ssh connection.', + ) + ssh_config_path: str = clickdc.option( + type=str, + help='Path to ssh configuration.', + default=os.path.expanduser('~') + '/.ssh/config', + ) + ssh_config_host: str | None = clickdc.option( + type=str, + help='Host to connect to ssh server reading from ssh configuration.', + ) + list_ssh_config: bool = clickdc.option( + is_flag=True, + help='list ssh configurations in the ssh config (requires paramiko).', + ) + ssh_warning_off: bool = clickdc.option( + is_flag=True, + help='Suppress the SSH deprecation notice.', + ) + ssl_mode: str = clickdc.option( + type=click.Choice(['auto', 'on', 'off']), + help='Set desired SSL behavior. auto=preferred if TCP/IP, on=required, off=off.', + ) + deprecated_ssl: bool | None = clickdc.option( + '--ssl/--no-ssl', + 'deprecated_ssl', + default=None, + clickdc=None, + help='Enable SSL for connection (automatically enabled with other flags).', + ) + ssl_ca: str | None = clickdc.option( + type=click.Path(exists=True), + help='CA file in PEM format.', + ) + ssl_capath: str | None = clickdc.option( + type=click.Path(exists=True, file_okay=False, dir_okay=True), + help='CA directory.', + ) + ssl_cert: str | None = clickdc.option( + type=click.Path(exists=True), + help='X509 cert in PEM format.', + ) + ssl_key: str | None = clickdc.option( + type=click.Path(exists=True), + help='X509 key in PEM format.', + ) + ssl_cipher: str | None = clickdc.option( + type=str, + help='SSL cipher to use.', + ) + tls_version: str | None = clickdc.option( + type=click.Choice(['TLSv1', 'TLSv1.1', 'TLSv1.2', 'TLSv1.3'], case_sensitive=False), + help='TLS protocol version for secure connection.', + ) + ssl_verify_server_cert: bool = clickdc.option( + is_flag=True, + help=("""Verify server's "Common Name" in its cert against hostname used when connecting. This option is disabled by default."""), + ) + verbose: int = clickdc.option( + '-v', + count=True, + help='More verbose output and feedback. Can be given multiple times.', + ) + quiet: bool = clickdc.option( + '-q', + is_flag=True, + help='Less verbose output and feedback.', + ) + dbname: str | None = clickdc.option( + '-D', + '--database', + 'dbname', + type=str, + clickdc=None, + help='Database or DSN to use for the connection.', + ) + dsn: str = clickdc.option( + '-d', + type=str, + default='', + envvar='MYSQL_DSN', + help='DSN alias configured in the ~/.myclirc file, or a full DSN.', + ) + list_dsn: bool = clickdc.option( + is_flag=True, + help='Show list of DSN aliases configured in the [alias_dsn] section of ~/.myclirc.', + ) + prompt: str | None = clickdc.option( + '-R', + type=str, + help=f'Prompt format (Default: "{DEFAULT_PROMPT}").', + ) + toolbar: str | None = clickdc.option( + type=str, + help='Toolbar format.', + ) + logfile: TextIOWrapper | None = clickdc.option( + '-l', + type=click.File(mode='a', encoding='utf-8'), + help='Log every query and its results to a file.', + ) + checkpoint: TextIOWrapper | None = clickdc.option( + type=click.File(mode='a', encoding='utf-8'), + help='In batch or --execute mode, log successful queries to a file, and skipped with --resume.', + ) + resume: bool = clickdc.option( + '--resume', + is_flag=True, + help='In batch mode, resume after replaying statements in the --checkpoint file.', + ) + defaults_group_suffix: str | None = clickdc.option( + type=str, + help='Read MySQL config groups with the specified suffix.', + ) + defaults_file: str | None = clickdc.option( + type=click.Path(), + help='Only read MySQL options from the given file.', + ) + myclirc: str = clickdc.option( + type=click.Path(), + default='~/.myclirc', + help='Location of myclirc file.', + ) + auto_vertical_output: bool = clickdc.option( + is_flag=True, + help='Automatically switch to vertical output mode if the result is wider than the terminal width.', + ) + show_warnings: bool | None = clickdc.option( + '--show-warnings/--no-show-warnings', + is_flag=True, + default=None, + clickdc=None, + help='Automatically show warnings after executing a SQL statement.', + ) + table: bool = clickdc.option( + '-t', + is_flag=True, + help='Shorthand for --format=table.', + ) + csv: bool = clickdc.option( + is_flag=True, + help='Shorthand for --format=csv.', + ) + warn: bool | None = clickdc.option( + '--warn/--no-warn', + default=None, + clickdc=None, + help='Warn before running a destructive query.', + ) + local_infile: bool | None = clickdc.option( + type=bool, + is_flag=False, + default=None, + help='Enable/disable LOAD DATA LOCAL INFILE.', + ) + login_path: str | None = clickdc.option( + '-g', + type=str, + help='Read this path from the login file.', + ) + execute: str | None = clickdc.option( + '-e', + type=str, + help='Execute command and quit.', + ) + init_command: str | None = clickdc.option( + type=str, + help='SQL statement to execute after connecting.', + ) + unbuffered: bool | None = clickdc.option( + is_flag=True, + help='Instead of copying every row of data into a buffer, fetch rows as needed, to save memory.', + ) + character_set: str | None = clickdc.option( + '--charset', + '--character-set', + 'character_set', + type=str, + help='Character set for MySQL session.', + ) + batch: str | None = clickdc.option( + type=str, + help='SQL script to execute in batch mode.', + ) + noninteractive: bool = clickdc.option( + is_flag=True, + help="Don't prompt during batch input. Recommended.", + ) + format: str | None = clickdc.option( + type=click.Choice(['default', 'csv', 'tsv', 'table']), + help='Format for batch or --execute output.', + ) + throttle: float = clickdc.option( + type=float, + default=0.0, + help='Pause in seconds between queries in batch mode.', + ) + progress: bool = clickdc.option( + is_flag=True, + help='Show progress on the standard error with --batch.', + ) + use_keyring: str | None = clickdc.option( + type=click.Choice(['true', 'false', 'reset']), + default=None, + help='Store and retrieve passwords from the system keyring: true/false/reset.', + ) + keepalive_ticks: int | None = clickdc.option( + type=int, + help='Send regular keepalive pings to the connection, roughly every seconds.', + ) + checkup: bool = clickdc.option( + is_flag=True, + help='Run a checkup on your configuration.', + ) + + +def get_password_from_file(password_file: str | None) -> str | None: + if not password_file: + return None + try: + with open(password_file) as fp: + return fp.readline().removesuffix('\n') + except FileNotFoundError: + click.secho(f"Password file '{password_file}' not found", err=True, fg='red') + sys.exit(1) + except PermissionError: + click.secho(f"Permission denied reading password file '{password_file}'", err=True, fg='red') + sys.exit(1) + except IsADirectoryError: + click.secho(f"Path '{password_file}' is a directory, not a file", err=True, fg='red') + sys.exit(1) + except Exception as e: + click.secho(f"Error reading password file '{password_file}': {str(e)}", err=True, fg='red') + sys.exit(1) + + +def preprocess_cli_args( + cli_args: CliArgs, + is_valid_connection_scheme: Callable[[str], tuple[bool, str | None]], +) -> int: + if cli_args.database is None and isinstance(cli_args.password, str) and '://' in cli_args.password: + is_valid_scheme, scheme = is_valid_connection_scheme(cli_args.password) + if not is_valid_scheme: + click.secho(f'Error: Unknown connection scheme provided for DSN URI ({scheme}://)', err=True, fg='red') + sys.exit(1) + cli_args.database = cli_args.password + cli_args.password = EMPTY_PASSWORD_FLAG_SENTINEL + + if cli_args.password is None and cli_args.password_file: + password_from_file = get_password_from_file(cli_args.password_file) + if password_from_file is not None: + cli_args.password = password_from_file + + if cli_args.password is None and os.environ.get('MYSQL_PWD') is not None: + cli_args.password = os.environ.get('MYSQL_PWD') + + if cli_args.resume and not cli_args.checkpoint: + click.secho('Error: --resume requires a --checkpoint file.', err=True, fg='red') + sys.exit(1) + + if cli_args.resume and not cli_args.batch: + click.secho('Error: --resume requires a --batch file.', err=True, fg='red') + sys.exit(1) + + if cli_args.verbose and cli_args.quiet: + click.secho('Error: --verbose and --quiet are incompatible.', err=True, fg='red') + sys.exit(1) + elif cli_args.verbose: + return int(cli_args.verbose) + elif cli_args.quiet: + return -1 + return 0 diff --git a/mycli/main.py b/mycli/main.py index fb2ffd4f..bbe8f5d4 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -1,13 +1,9 @@ from __future__ import annotations -from collections import defaultdict -from dataclasses import dataclass -from decimal import Decimal from io import TextIOWrapper import logging import os import re -import shutil import sys import threading import traceback @@ -17,26 +13,16 @@ from pwd import getpwuid except ImportError: pass -from datetime import datetime -import itertools from textwrap import dedent from urllib.parse import parse_qs, unquote, urlparse -from cli_helpers.tabular_output import TabularOutputFormatter, preprocessors -from cli_helpers.tabular_output.output_formatter import MISSING_VALUE as DEFAULT_MISSING_VALUE -from cli_helpers.utils import strip_ansi +from cli_helpers.tabular_output import TabularOutputFormatter +from cli_helpers.tabular_output.output_formatter import MISSING_VALUE as _DEFAULT_MISSING_VALUE import click import clickdc -from configobj import ConfigObj import keyring -from prompt_toolkit import print_formatted_text from prompt_toolkit.formatted_text import ( - ANSI, - HTML, - AnyFormattedText, - FormattedText, to_formatted_text, - to_plain_text, ) from prompt_toolkit.shortcuts import PromptSession import pymysql @@ -46,16 +32,28 @@ import sqlparse import mycli as mycli_package +from mycli.app_state import ( + AppStateMixin, + configure_prompt_state, + destructive_keywords_from_config, + ensure_my_cnf_sections, + llm_prompt_truncation, + normalize_ssl_mode, +) +from mycli.cli_args import ( + DEFAULT_PROMPT, + EMPTY_PASSWORD_FLAG_SENTINEL, + CliArgs, + preprocess_cli_args, +) from mycli.clistyle import style_factory_helpers, style_factory_ptoolkit from mycli.compat import WIN from mycli.completion_refresher import CompletionRefresher -from mycli.config import get_mylogin_cnf_path, open_mylogin_cnf, read_config_files, str_to_bool, strip_matching_quotes, write_default_config +from mycli.config import get_mylogin_cnf_path, open_mylogin_cnf, read_config_files, str_to_bool, write_default_config from mycli.constants import ( DEFAULT_CHARSET, - DEFAULT_HEIGHT, DEFAULT_HOST, DEFAULT_PORT, - DEFAULT_WIDTH, ER_MUST_CHANGE_PASSWORD_LOGIN, ISSUES_URL, REPO_URL, @@ -70,7 +68,8 @@ from mycli.main_modes.execute import main_execute_from_cli from mycli.main_modes.list_dsn import main_list_dsn from mycli.main_modes.list_ssh_config import main_list_ssh_config -from mycli.main_modes.repl import main_repl, render_prompt_string, set_all_external_titles +from mycli.main_modes.repl import main_repl, set_all_external_titles +from mycli.output import OutputMixin from mycli.packages import special from mycli.packages.cli_utils import filtered_sys_argv, is_valid_connection_scheme from mycli.packages.filepaths import dir_path_exists, guess_socket_location @@ -82,37 +81,21 @@ from mycli.packages.tabular_output import sql_format from mycli.schema_prefetcher import SchemaPrefetcher from mycli.sqlcompleter import SQLCompleter -from mycli.sqlexecute import FIELD_TYPES, SQLExecute +from mycli.sqlexecute import SQLExecute from mycli.types import Query sqlparse.engine.grouping.MAX_GROUPING_DEPTH = None # type: ignore[assignment] sqlparse.engine.grouping.MAX_GROUPING_TOKENS = None # type: ignore[assignment] -EMPTY_PASSWORD_FLAG_SENTINEL = -1 - - -class IntOrStringClickParamType(click.ParamType): - name = 'text' # display as TEXT in helpdoc - - def convert(self, value, param, ctx): - if isinstance(value, int): - return value - elif isinstance(value, str): - return value - elif value is None: - return value - else: - self.fail('Not a valid password string', param, ctx) - - -INT_OR_STRING_CLICK_TYPE = IntOrStringClickParamType() +DEFAULT_MISSING_VALUE = _DEFAULT_MISSING_VALUE -class MyCli: - default_prompt = "\\t \\u@\\h:\\d> " +class MyCli(AppStateMixin, OutputMixin): + default_prompt = DEFAULT_PROMPT default_prompt_splitln = "\\u@\\h\\n(\\t):\\d>" max_len_prompt = 45 defaults_suffix = None + prompt_lines: int # In order of being loaded. Files lower in list override earlier ones. cnf_files: list[str | IO[str]] = [ @@ -211,22 +194,11 @@ def __init__( self.null_string = c['main'].get('null_string') self.numeric_alignment = c['main'].get('numeric_alignment', 'right') self.binary_display = c['main'].get('binary_display') - if 'llm' in c and re.match(r'^\d+$', c['llm'].get('prompt_field_truncate', '')): - self.llm_prompt_field_truncate = int(c['llm'].get('prompt_field_truncate')) - else: - self.llm_prompt_field_truncate = 0 - if 'llm' in c and re.match(r'^\d+$', c['llm'].get('prompt_section_truncate', '')): - self.llm_prompt_section_truncate = int(c['llm'].get('prompt_section_truncate')) - else: - self.llm_prompt_section_truncate = 0 + self.llm_prompt_field_truncate, self.llm_prompt_section_truncate = llm_prompt_truncation(c) - # set ssl_mode if a valid option is provided in a config file, otherwise None - ssl_mode = c["main"].get("ssl_mode", None) or c["connection"].get("default_ssl_mode", None) - if ssl_mode not in ("auto", "on", "off", None): - self.echo(f"Invalid config option provided for ssl_mode ({ssl_mode}); ignoring.", err=True, fg="red") - self.ssl_mode = None - else: - self.ssl_mode = ssl_mode + self.ssl_mode, ssl_mode_error = normalize_ssl_mode(c) + if ssl_mode_error: + self.echo(ssl_mode_error, err=True, fg="red") # read from cli argument or user config file self.auto_vertical_output = auto_vertical_output or c["main"].as_bool("auto_vertical_output") @@ -286,23 +258,11 @@ def __init__( print("Error: Unable to read login path file.") self.my_cnf = read_config_files(self.cnf_files, list_values=False) - if not self.my_cnf.get('client'): - self.my_cnf['client'] = {} - if not self.my_cnf.get('mysqld'): - self.my_cnf['mysqld'] = {} + ensure_my_cnf_sections(self.my_cnf) prompt_cnf = self.read_my_cnf(self.my_cnf, ["prompt"])["prompt"] - self.prompt_format = prompt or prompt_cnf or c["main"]["prompt"] or self.default_prompt - self.prompt_lines = 0 - self.multiline_continuation_char = c["main"]["prompt_continuation"] - self.toolbar_format = toolbar_format or c['main']['toolbar'] - self.terminal_tab_title_format = c['main']['terminal_tab_title'] - self.terminal_window_title_format = c['main']['terminal_window_title'] - self.multiplex_window_title_format = c['main']['multiplex_window_title'] - self.multiplex_pane_title_format = c['main']['multiplex_pane_title'] + configure_prompt_state(self, c, prompt, prompt_cnf, toolbar_format) self.prompt_session = None - self.destructive_keywords = [ - keyword for keyword in c["main"].get("destructive_keywords", "DROP SHUTDOWN DELETE TRUNCATE ALTER UPDATE").split(' ') if keyword - ] + self.destructive_keywords = destructive_keywords_from_config(c) special.set_destructive_keywords(self.destructive_keywords) def close(self) -> None: @@ -486,62 +446,6 @@ def initialize_logging(self) -> None: root_logger.debug("Initializing mycli logging.") root_logger.debug("Log file %r.", log_file) - def read_my_cnf(self, cnf: ConfigObj, keys: list[str]) -> dict[str, Any]: - """ - Retrieves some keys from a configuration, applies transformations, returns a new configuration. - :param cnf: configuration to read - :param keys: list of keys to retrieve - :returns: tuple, with None for missing keys. - """ - - sections = ["client", "mysqld"] - key_transformations = { - "mysqld": { - "socket": "default_socket", - "port": "default_port", - "user": "default_user", - }, - } - - if self.login_path and self.login_path != "client": - sections.append(self.login_path) - - if self.defaults_suffix: - sections.extend([sect + self.defaults_suffix for sect in sections]) - - configuration: dict[str, Any] = defaultdict(lambda: None) - for key in keys: - for section in cnf: - if section not in sections or key not in cnf[section]: - continue - new_key = key_transformations.get(section, {}).get(key) or key - configuration[new_key] = strip_matching_quotes(cnf[section][key]) - - return configuration - - def merge_ssl_with_cnf(self, ssl: dict[str, Any], cnf: dict[str, Any]) -> dict[str, Any]: - """Merge SSL configuration dict with cnf dict""" - - merged = {} - merged.update(ssl) - prefix = "ssl-" - for k, v in cnf.items(): - # skip unrelated options - if not k.startswith(prefix): - continue - if v is None: - continue - # special case because PyMySQL argument is significantly different - # from commandline - if k == "ssl-verify-server-cert": - merged["check_hostname"] = str_to_bool(v) - else: - # use argument name just strip "ssl-" prefix - arg = k[len(prefix) :] - merged[arg] = v - - return merged - def connect( self, database: str | None = "", @@ -830,13 +734,6 @@ def _connect( self.echo(str(e), err=True, fg="red") sys.exit(1) - def output_timing(self, timing: str, is_warnings_style: bool = False) -> None: - self.log_output(timing) - add_style = 'class:warnings.timing' if is_warnings_style else 'class:output.timing' - formatted_timing = FormattedText([('', timing)]) - styled_timing = to_formatted_text(formatted_timing, style=add_style) - print_formatted_text(styled_timing, style=self.ptoolkit_style) - def run_cli(self) -> None: main_repl(self) @@ -895,146 +792,6 @@ def reconnect(self, database: str = "") -> bool: self.echo(str(e), err=True, fg="red") return False - def log_query(self, query: str) -> None: - if isinstance(self.logfile, TextIOWrapper): - self.logfile.write(f"\n# {datetime.now()}\n") - self.logfile.write(query) - self.logfile.write("\n") - - def log_output(self, output: str | AnyFormattedText) -> None: - """Log the output in the audit log, if it's enabled.""" - if isinstance(output, (ANSI, HTML, FormattedText)): - output = to_plain_text(output) - if isinstance(self.logfile, TextIOWrapper): - click.echo(output, file=self.logfile) - - def echo(self, s: str, **kwargs) -> None: - """Print a message to stdout. - - The message will be logged in the audit log, if enabled. - - All keyword arguments are passed to click.echo(). - - """ - self.log_output(s) - click.secho(s, **kwargs) - - def get_output_margin(self, status: str | None = None) -> int: - """Get the output margin (number of rows for the prompt, footer and - timing message.""" - if not self.prompt_lines: - if self.prompt_session and self.prompt_session.app: - render_counter = self.prompt_session.app.render_counter - else: - render_counter = 0 - # todo: this jump back to render_prompt_string() in repl.py is a sign that separation is incomplete - prompt_string = render_prompt_string(self, self.prompt_format, render_counter) - self.prompt_lines = to_plain_text(prompt_string).count('\n') + 1 - margin = self.get_reserved_space() + self.prompt_lines - if special.is_timing_enabled(): - margin += 1 - if status: - margin += 1 + status.count("\n") - - return margin - - def output( - self, - output: itertools.chain[str], - result: SQLResult, - is_warnings_style: bool = False, - ) -> None: - """Output text to stdout or a pager command. - - The status text is not outputted to pager or files. - - The message will be logged in the audit log, if enabled. The - message will be written to the tee file, if enabled. The - message will be written to the output file, if enabled. - - """ - if output: - if self.prompt_session is not None: - size = self.prompt_session.output.get_size() - size_columns = size.columns - size_rows = size.rows - else: - size_columns = DEFAULT_WIDTH - size_rows = DEFAULT_HEIGHT - - margin = self.get_output_margin(result.status_plain) - - fits = True - buf = [] - output_via_pager = self.explicit_pager and special.is_pager_enabled() - for i, line in enumerate(output, 1): - self.log_output(line) - special.write_tee(line) - special.write_once(line) - special.write_pipe_once(line) - - if special.is_redirected(): - pass - elif fits or output_via_pager: - # buffering - buf.append(line) - if len(line) > size_columns or i > (size_rows - margin): - fits = False - if not self.explicit_pager and special.is_pager_enabled(): - # doesn't fit, use pager - output_via_pager = True - - if not output_via_pager: - # doesn't fit, flush buffer - for buf_line in buf: - click.secho(buf_line) - buf = [] - else: - click.secho(line) - - if buf: - if output_via_pager: - - def newlinewrapper(text: list[str]) -> Generator[str, None, None]: - for line in text: - yield line + "\n" - - click.echo_via_pager(newlinewrapper(buf)) - else: - for line in buf: - click.secho(line) - - if result.status: - self.log_output(result.status_plain) - add_style = 'class:warnings.status' if is_warnings_style else 'class:output.status' - if isinstance(result.status, FormattedText): - status = result.status - else: - status = FormattedText([('', result.status_plain)]) - styled_status = to_formatted_text(status, style=add_style) - print_formatted_text(styled_status, style=self.ptoolkit_style) - - def configure_pager(self) -> None: - # Provide sane defaults for less if they are empty. - if not os.environ.get("LESS"): - os.environ["LESS"] = "-RXF" - - cnf = self.read_my_cnf(self.my_cnf, ["pager", "skip-pager"]) - cnf_pager = cnf["pager"] or self.config["main"]["pager"] - - # help Windows users who haven't edited the default myclirc - if WIN and cnf_pager == 'less' and not shutil.which(cnf_pager): - cnf_pager = 'more' - - if cnf_pager: - special.set_pager(cnf_pager) - self.explicit_pager = True - else: - self.explicit_pager = False - - if cnf["skip-pager"] or not self.config["main"].as_bool("enable_pager"): - special.disable_pager() - def refresh_completions(self, reset: bool = False) -> list[SQLResult]: # Cancel any in-flight schema prefetch before the completer is # replaced. Loaded-schema bookkeeping is intentionally preserved @@ -1119,395 +876,11 @@ def run_query( checkpoint.write(query.rstrip('\n') + '\n') checkpoint.flush() - def format_sqlresult( - self, - result, - is_expanded: bool = False, - is_redirected: bool = False, - null_string: str | None = None, - numeric_alignment: str = 'right', - binary_display: str | None = None, - max_width: int | None = None, - is_warnings_style: bool = False, - ) -> itertools.chain[str]: - if is_redirected: - use_formatter = self.redirect_formatter - else: - use_formatter = self.main_formatter - - is_expanded = is_expanded or use_formatter.format_name == "vertical" - output: itertools.chain[str] = itertools.chain() - - output_kwargs = { - "dialect": "unix", - "disable_numparse": True, - "preserve_whitespace": True, - "style": self.helpers_warnings_style if is_warnings_style else self.helpers_style, - } - default_kwargs = use_formatter._output_formats[use_formatter.format_name].formatter_args - - if null_string is not None and default_kwargs.get('missing_value') == DEFAULT_MISSING_VALUE: - output_kwargs['missing_value'] = null_string - - if use_formatter.format_name not in sql_format.supported_formats and binary_display != 'utf8': - # will run before preprocessors defined as part of the format in cli_helpers - output_kwargs["preprocessors"] = (preprocessors.convert_to_undecoded_string,) - - if result.preamble: - output = itertools.chain(output, [result.preamble]) - - if result.header or (result.rows and result.preamble): - column_types = None - colalign = None - if isinstance(result.rows, Cursor): - - def get_col_type(col) -> type: - col_type = FIELD_TYPES.get(col[1], str) - return col_type if type(col_type) is type else str - - if result.rows.rowcount > 0: - column_types = [get_col_type(tup) for tup in result.rows.description] - colalign = [numeric_alignment if x in (int, float, Decimal) else 'left' for x in column_types] - else: - column_types, colalign = [], [] - - if max_width is not None and isinstance(result.rows, Cursor): - result_rows = list(result.rows) - else: - result_rows = result.rows - - formatted = use_formatter.format_output( - result_rows, - result.header or [], - format_name="vertical" if is_expanded else None, - column_types=column_types, - colalign=colalign, - **output_kwargs, - ) - - if isinstance(formatted, str): - formatted = formatted.splitlines() - formatted = iter(formatted) - - if not is_expanded and max_width and result.header and result_rows: - first_line = next(formatted) - if len(strip_ansi(first_line)) > max_width: - formatted = use_formatter.format_output( - result_rows, - result.header, - format_name="vertical", - column_types=column_types, - **output_kwargs, - ) - if isinstance(formatted, str): - formatted = iter(formatted.splitlines()) - else: - formatted = itertools.chain([first_line], formatted) - - output = itertools.chain(output, formatted) - - if result.postamble: - output = itertools.chain(output, [result.postamble]) - - return output - - def get_reserved_space(self) -> int: - """Get the number of lines to reserve for the completion menu.""" - reserved_space_ratio = 0.45 - max_reserved_space = 8 - _, height = shutil.get_terminal_size() - return min(int(round(height * reserved_space_ratio)), max_reserved_space) - def get_last_query(self) -> str | None: """Get the last query executed or None.""" return self.query_history[-1][0] if self.query_history else None -@dataclass(slots=True) -class CliArgs: - database: str | None = clickdc.argument( - type=str, - default=None, - nargs=1, - ) - host: str | None = clickdc.option( - '-h', - '--hostname', - 'host', - type=str, - envvar='MYSQL_HOST', - help='Host address of the database.', - ) - port: int | None = clickdc.option( - '-P', - type=int, - envvar='MYSQL_TCP_PORT', - help='Port number to use for connection. Honors $MYSQL_TCP_PORT.', - ) - user: str | None = clickdc.option( - '-u', - '--user', - '--username', - 'user', - type=str, - envvar='MYSQL_USER', - help='User name to connect to the database.', - ) - socket: str | None = clickdc.option( - '-S', - type=str, - envvar='MYSQL_UNIX_SOCKET', - help='The socket file to use for connection.', - ) - password: int | str | None = clickdc.option( - '-p', - '--pass', - '--password', - 'password', - type=INT_OR_STRING_CLICK_TYPE, - is_flag=False, - flag_value=EMPTY_PASSWORD_FLAG_SENTINEL, - help='Prompt for (or pass in cleartext) the password to connect to the database.', - ) - password_file: str | None = clickdc.option( - type=click.Path(), - help='File or FIFO path containing the password to connect to the db if not specified otherwise.', - ) - ssh_user: str | None = clickdc.option( - type=str, - help='User name to connect to ssh server.', - ) - ssh_host: str | None = clickdc.option( - type=str, - help='Host name to connect to ssh server.', - ) - ssh_port: int = clickdc.option( - type=int, - default=22, - help='Port to connect to ssh server.', - ) - ssh_password: str | None = clickdc.option( - type=str, - help='Password to connect to ssh server.', - ) - ssh_key_filename: str | None = clickdc.option( - type=str, - help='Private key filename (identify file) for the ssh connection.', - ) - ssh_config_path: str = clickdc.option( - type=str, - help='Path to ssh configuration.', - default=os.path.expanduser('~') + '/.ssh/config', - ) - ssh_config_host: str | None = clickdc.option( - type=str, - help='Host to connect to ssh server reading from ssh configuration.', - ) - list_ssh_config: bool = clickdc.option( - is_flag=True, - help='list ssh configurations in the ssh config (requires paramiko).', - ) - ssh_warning_off: bool = clickdc.option( - is_flag=True, - help='Suppress the SSH deprecation notice.', - ) - ssl_mode: str = clickdc.option( - type=click.Choice(['auto', 'on', 'off']), - help='Set desired SSL behavior. auto=preferred if TCP/IP, on=required, off=off.', - ) - deprecated_ssl: bool | None = clickdc.option( - '--ssl/--no-ssl', - 'deprecated_ssl', - default=None, - clickdc=None, - help='Enable SSL for connection (automatically enabled with other flags).', - ) - ssl_ca: str | None = clickdc.option( - type=click.Path(exists=True), - help='CA file in PEM format.', - ) - ssl_capath: str | None = clickdc.option( - type=click.Path(exists=True, file_okay=False, dir_okay=True), - help='CA directory.', - ) - ssl_cert: str | None = clickdc.option( - type=click.Path(exists=True), - help='X509 cert in PEM format.', - ) - ssl_key: str | None = clickdc.option( - type=click.Path(exists=True), - help='X509 key in PEM format.', - ) - ssl_cipher: str | None = clickdc.option( - type=str, - help='SSL cipher to use.', - ) - tls_version: str | None = clickdc.option( - type=click.Choice(['TLSv1', 'TLSv1.1', 'TLSv1.2', 'TLSv1.3'], case_sensitive=False), - help='TLS protocol version for secure connection.', - ) - ssl_verify_server_cert: bool = clickdc.option( - is_flag=True, - help=("""Verify server's "Common Name" in its cert against hostname used when connecting. This option is disabled by default."""), - ) - verbose: int = clickdc.option( - '-v', - count=True, - help='More verbose output and feedback. Can be given multiple times.', - ) - quiet: bool = clickdc.option( - '-q', - is_flag=True, - help='Less verbose output and feedback.', - ) - dbname: str | None = clickdc.option( - '-D', - '--database', - 'dbname', - type=str, - clickdc=None, - help='Database or DSN to use for the connection.', - ) - dsn: str = clickdc.option( - '-d', - type=str, - default='', - envvar='MYSQL_DSN', - help='DSN alias configured in the ~/.myclirc file, or a full DSN.', - ) - list_dsn: bool = clickdc.option( - is_flag=True, - help='Show list of DSN aliases configured in the [alias_dsn] section of ~/.myclirc.', - ) - prompt: str | None = clickdc.option( - '-R', - type=str, - help=f'Prompt format (Default: "{MyCli.default_prompt}").', - ) - toolbar: str | None = clickdc.option( - type=str, - help='Toolbar format.', - ) - logfile: TextIOWrapper | None = clickdc.option( - '-l', - type=click.File(mode='a', encoding='utf-8'), - help='Log every query and its results to a file.', - ) - checkpoint: TextIOWrapper | None = clickdc.option( - type=click.File(mode='a', encoding='utf-8'), - help='In batch or --execute mode, log successful queries to a file, and skipped with --resume.', - ) - resume: bool = clickdc.option( - '--resume', - is_flag=True, - help='In batch mode, resume after replaying statements in the --checkpoint file.', - ) - defaults_group_suffix: str | None = clickdc.option( - type=str, - help='Read MySQL config groups with the specified suffix.', - ) - defaults_file: str | None = clickdc.option( - type=click.Path(), - help='Only read MySQL options from the given file.', - ) - myclirc: str = clickdc.option( - type=click.Path(), - default='~/.myclirc', - help='Location of myclirc file.', - ) - auto_vertical_output: bool = clickdc.option( - is_flag=True, - help='Automatically switch to vertical output mode if the result is wider than the terminal width.', - ) - show_warnings: bool | None = clickdc.option( - '--show-warnings/--no-show-warnings', - is_flag=True, - default=None, - clickdc=None, - help='Automatically show warnings after executing a SQL statement.', - ) - table: bool = clickdc.option( - '-t', - is_flag=True, - help='Shorthand for --format=table.', - ) - csv: bool = clickdc.option( - is_flag=True, - help='Shorthand for --format=csv.', - ) - warn: bool | None = clickdc.option( - '--warn/--no-warn', - default=None, - clickdc=None, - help='Warn before running a destructive query.', - ) - local_infile: bool | None = clickdc.option( - type=bool, - is_flag=False, - default=None, - help='Enable/disable LOAD DATA LOCAL INFILE.', - ) - login_path: str | None = clickdc.option( - '-g', - type=str, - help='Read this path from the login file.', - ) - execute: str | None = clickdc.option( - '-e', - type=str, - help='Execute command and quit.', - ) - init_command: str | None = clickdc.option( - type=str, - help='SQL statement to execute after connecting.', - ) - unbuffered: bool | None = clickdc.option( - is_flag=True, - help='Instead of copying every row of data into a buffer, fetch rows as needed, to save memory.', - ) - character_set: str | None = clickdc.option( - '--charset', - '--character-set', - 'character_set', - type=str, - help='Character set for MySQL session.', - ) - batch: str | None = clickdc.option( - type=str, - help='SQL script to execute in batch mode.', - ) - noninteractive: bool = clickdc.option( - is_flag=True, - help="Don't prompt during batch input. Recommended.", - ) - format: str | None = clickdc.option( - type=click.Choice(['default', 'csv', 'tsv', 'table']), - help='Format for batch or --execute output.', - ) - throttle: float = clickdc.option( - type=float, - default=0.0, - help='Pause in seconds between queries in batch mode.', - ) - progress: bool = clickdc.option( - is_flag=True, - help='Show progress on the standard error with --batch.', - ) - use_keyring: str | None = clickdc.option( - type=click.Choice(['true', 'false', 'reset']), - default=None, - help='Store and retrieve passwords from the system keyring: true/false/reset.', - ) - keepalive_ticks: int | None = clickdc.option( - type=int, - help='Send regular keepalive pings to the connection, roughly every seconds.', - ) - checkup: bool = clickdc.option( - is_flag=True, - help='Run a checkup on your configuration.', - ) - - @click.command() @clickdc.adddc('cli_args', CliArgs) @click.version_option(mycli_package.__version__, '--version', '-V', help="Output mycli's version.") @@ -1524,66 +897,7 @@ def click_entrypoint( """ - def get_password_from_file(password_file: str | None) -> str | None: - if not password_file: - return None - try: - with open(password_file) as fp: - password = fp.readline().removesuffix('\n') - return password - except FileNotFoundError: - click.secho(f"Password file '{password_file}' not found", err=True, fg="red") - sys.exit(1) - except PermissionError: - click.secho(f"Permission denied reading password file '{password_file}'", err=True, fg="red") - sys.exit(1) - except IsADirectoryError: - click.secho(f"Path '{password_file}' is a directory, not a file", err=True, fg="red") - sys.exit(1) - except Exception as e: - click.secho(f"Error reading password file '{password_file}': {str(e)}", err=True, fg="red") - sys.exit(1) - - # if the password value looks like a DSN, treat it as such and - # prompt for password - if cli_args.database is None and isinstance(cli_args.password, str) and "://" in cli_args.password: - # check if the scheme is valid. We do not actually have any logic for these, but - # it will most usefully catch the case where we erroneously catch someone's - # password, and give them an easy error message to follow / report - is_valid_scheme, scheme = is_valid_connection_scheme(cli_args.password) - if not is_valid_scheme: - click.secho(f"Error: Unknown connection scheme provided for DSN URI ({scheme}://)", err=True, fg="red") - sys.exit(1) - cli_args.database = cli_args.password - cli_args.password = EMPTY_PASSWORD_FLAG_SENTINEL - - # if the password is not specified try to set it using the password_file option - if cli_args.password is None and cli_args.password_file: - password_from_file = get_password_from_file(cli_args.password_file) - if password_from_file is not None: - cli_args.password = password_from_file - - # getting the envvar ourselves because the envvar from a click - # option cannot be an empty string, but a password can be - if cli_args.password is None and os.environ.get("MYSQL_PWD") is not None: - cli_args.password = os.environ.get("MYSQL_PWD") - - if cli_args.resume and not cli_args.checkpoint: - click.secho('Error: --resume requires a --checkpoint file.', err=True, fg='red') - sys.exit(1) - - if cli_args.resume and not cli_args.batch: - click.secho('Error: --resume requires a --batch file.', err=True, fg='red') - sys.exit(1) - - cli_verbosity = 0 - if cli_args.verbose and cli_args.quiet: - click.secho('Error: --verbose and --quiet are incompatible.', err=True, fg='red') - sys.exit(1) - elif cli_args.verbose: - cli_verbosity = int(cli_args.verbose) - elif cli_args.quiet: - cli_verbosity = -1 + cli_verbosity = preprocess_cli_args(cli_args, is_valid_connection_scheme) mycli = MyCli( prompt=cli_args.prompt, diff --git a/mycli/output.py b/mycli/output.py new file mode 100644 index 00000000..eee1021a --- /dev/null +++ b/mycli/output.py @@ -0,0 +1,291 @@ +from __future__ import annotations + +from datetime import datetime +from decimal import Decimal +from io import TextIOWrapper +import itertools +import os +import shutil +from typing import Any, Generator, Literal, Protocol + +from cli_helpers.tabular_output import TabularOutputFormatter, preprocessors +from cli_helpers.tabular_output.output_formatter import MISSING_VALUE as DEFAULT_MISSING_VALUE +from cli_helpers.utils import strip_ansi +import click +from configobj import ConfigObj +import prompt_toolkit +from prompt_toolkit.formatted_text import ( + ANSI, + HTML, + AnyFormattedText, + FormattedText, + to_formatted_text, + to_plain_text, +) +from prompt_toolkit.shortcuts import PromptSession +from prompt_toolkit.styles.style import _MergedStyle +from pygments.style import Style as PygmentsStyle +from pymysql.cursors import Cursor + +from mycli.compat import WIN +from mycli.constants import DEFAULT_HEIGHT, DEFAULT_WIDTH +import mycli.main_modes.repl as repl_mode +from mycli.packages import special +from mycli.packages.sqlresult import SQLResult +from mycli.packages.tabular_output import sql_format +from mycli.sqlexecute import FIELD_TYPES + + +class MyCliState(Protocol): + # Provided by AppStateMixin. + def read_my_cnf(self, cnf: ConfigObj, keys: list[str]) -> dict[str, Any]: ... + + # Provided by OutputMixin itself; declared so cross-method calls type-check. + def log_output(self, output: str | AnyFormattedText) -> None: ... + def get_output_margin(self, status: str | None = None) -> int: ... + def get_reserved_space(self) -> int: ... + + +class OutputMixin(MyCliState): + prompt_lines: int + multiline_continuation_char: str + multiplex_pane_title_format: str + multiplex_window_title_format: str + terminal_tab_title_format: str + terminal_window_title_format: str + toolbar_format: str + redirect_formatter: TabularOutputFormatter + config: ConfigObj + my_cnf: ConfigObj + logfile: TextIOWrapper | Literal[False] | None + prompt_session: PromptSession | None + prompt_format: str + explicit_pager: bool + ptoolkit_style: _MergedStyle + helpers_style: PygmentsStyle + helpers_warnings_style: PygmentsStyle + main_formatter: TabularOutputFormatter + + def output_timing(self, timing: str, is_warnings_style: bool = False) -> None: + self.log_output(timing) + add_style = 'class:warnings.timing' if is_warnings_style else 'class:output.timing' + formatted_timing = FormattedText([('', timing)]) + styled_timing = to_formatted_text(formatted_timing, style=add_style) + prompt_toolkit.print_formatted_text(styled_timing, style=self.ptoolkit_style) + + def log_query(self, query: str) -> None: + if isinstance(self.logfile, TextIOWrapper): + self.logfile.write(f"\n# {datetime.now()}\n") + self.logfile.write(query) + self.logfile.write("\n") + + def log_output(self, output: str | AnyFormattedText) -> None: + """Log the output in the audit log, if it's enabled.""" + if isinstance(output, (ANSI, HTML, FormattedText)): + output = to_plain_text(output) + if isinstance(self.logfile, TextIOWrapper): + click.echo(output, file=self.logfile) + + def echo(self, s: str, **kwargs) -> None: + """Print a message to stdout.""" + self.log_output(s) + click.secho(s, **kwargs) + + def get_output_margin(self, status: str | None = None) -> int: + """Get the output margin for prompt, footer, timing, and status.""" + if not self.prompt_lines: + if self.prompt_session and self.prompt_session.app: + render_counter = self.prompt_session.app.render_counter + else: + render_counter = 0 + prompt_string = repl_mode.render_prompt_string(self, self.prompt_format, render_counter) + self.prompt_lines = to_plain_text(prompt_string).count('\n') + 1 + margin = self.get_reserved_space() + self.prompt_lines + if special.is_timing_enabled(): + margin += 1 + if status: + margin += 1 + status.count("\n") + + return margin + + def output( + self, + output: itertools.chain[str], + result: SQLResult, + is_warnings_style: bool = False, + ) -> None: + """Output text to stdout or a pager command.""" + if output: + if self.prompt_session is not None: + size = self.prompt_session.output.get_size() + size_columns = size.columns + size_rows = size.rows + else: + size_columns = DEFAULT_WIDTH + size_rows = DEFAULT_HEIGHT + + margin = self.get_output_margin(result.status_plain) + + fits = True + buf = [] + output_via_pager = self.explicit_pager and special.is_pager_enabled() + for i, line in enumerate(output, 1): + self.log_output(line) + special.write_tee(line) + special.write_once(line) + special.write_pipe_once(line) + + if special.is_redirected(): + pass + elif fits or output_via_pager: + buf.append(line) + if len(line) > size_columns or i > (size_rows - margin): + fits = False + if not self.explicit_pager and special.is_pager_enabled(): + output_via_pager = True + + if not output_via_pager: + for buf_line in buf: + click.secho(buf_line) + buf = [] + else: + click.secho(line) + + if buf: + if output_via_pager: + + def newlinewrapper(text: list[str]) -> Generator[str, None, None]: + for line in text: + yield line + "\n" + + click.echo_via_pager(newlinewrapper(buf)) + else: + for line in buf: + click.secho(line) + + if result.status: + self.log_output(result.status_plain) + add_style = 'class:warnings.status' if is_warnings_style else 'class:output.status' + if isinstance(result.status, FormattedText): + status = result.status + else: + status = FormattedText([('', result.status_plain)]) + styled_status = to_formatted_text(status, style=add_style) + prompt_toolkit.print_formatted_text(styled_status, style=self.ptoolkit_style) + + def configure_pager(self) -> None: + if not os.environ.get("LESS"): + os.environ["LESS"] = "-RXF" + + cnf = self.read_my_cnf(self.my_cnf, ["pager", "skip-pager"]) + cnf_pager = cnf["pager"] or self.config["main"]["pager"] + + if WIN and cnf_pager == 'less' and not shutil.which(cnf_pager): + cnf_pager = 'more' + + if cnf_pager: + special.set_pager(cnf_pager) + self.explicit_pager = True + else: + self.explicit_pager = False + + if cnf["skip-pager"] or not self.config["main"].as_bool("enable_pager"): + special.disable_pager() + + def format_sqlresult( + self, + result, + is_expanded: bool = False, + is_redirected: bool = False, + null_string: str | None = None, + numeric_alignment: str = 'right', + binary_display: str | None = None, + max_width: int | None = None, + is_warnings_style: bool = False, + ) -> itertools.chain[str]: + if is_redirected: + use_formatter = self.redirect_formatter + else: + use_formatter = self.main_formatter + + is_expanded = is_expanded or use_formatter.format_name == "vertical" + output: itertools.chain[str] = itertools.chain() + + output_kwargs = { + "dialect": "unix", + "disable_numparse": True, + "preserve_whitespace": True, + "style": self.helpers_warnings_style if is_warnings_style else self.helpers_style, + } + default_kwargs = use_formatter._output_formats[use_formatter.format_name].formatter_args + + if null_string is not None and default_kwargs.get('missing_value') == DEFAULT_MISSING_VALUE: + output_kwargs['missing_value'] = null_string + + if use_formatter.format_name not in sql_format.supported_formats and binary_display != 'utf8': + output_kwargs["preprocessors"] = (preprocessors.convert_to_undecoded_string,) + + if result.preamble: + output = itertools.chain(output, [result.preamble]) + + if result.header or (result.rows and result.preamble): + column_types = None + colalign = None + if isinstance(result.rows, Cursor): + + def get_col_type(col) -> type: + col_type = FIELD_TYPES.get(col[1], str) + return col_type if type(col_type) is type else str + + if result.rows.rowcount > 0: + column_types = [get_col_type(tup) for tup in result.rows.description] + colalign = [numeric_alignment if x in (int, float, Decimal) else 'left' for x in column_types] + else: + column_types, colalign = [], [] + + if max_width is not None and isinstance(result.rows, Cursor): + result_rows = list(result.rows) + else: + result_rows = result.rows + + formatted = use_formatter.format_output( + result_rows, + result.header or [], + format_name="vertical" if is_expanded else None, + column_types=column_types, + colalign=colalign, + **output_kwargs, + ) + + if isinstance(formatted, str): + formatted = formatted.splitlines() + formatted = iter(formatted) + + if not is_expanded and max_width and result.header and result_rows: + first_line = next(formatted) + if len(strip_ansi(first_line)) > max_width: + formatted = use_formatter.format_output( + result_rows, + result.header, + format_name="vertical", + column_types=column_types, + **output_kwargs, + ) + if isinstance(formatted, str): + formatted = iter(formatted.splitlines()) + else: + formatted = itertools.chain([first_line], formatted) + + output = itertools.chain(output, formatted) + + if result.postamble: + output = itertools.chain(output, [result.postamble]) + + return output + + def get_reserved_space(self) -> int: + """Get the number of lines to reserve for the completion menu.""" + reserved_space_ratio = 0.45 + max_reserved_space = 8 + _, height = shutil.get_terminal_size() + return min(int(round(height * reserved_space_ratio)), max_reserved_space) diff --git a/mycli/packages/special/main.py b/mycli/packages/special/main.py index 3c6e3741..12a6c7de 100644 --- a/mycli/packages/special/main.py +++ b/mycli/packages/special/main.py @@ -21,9 +21,9 @@ logger = logging.getLogger(__name__) -COMMANDS = {} -CASE_SENSITIVE_COMMANDS = set() -CASE_INSENSITIVE_COMMANDS = set() +COMMANDS: dict[str, 'SpecialCommand'] = {} +CASE_SENSITIVE_COMMANDS: set[str] = set() +CASE_INSENSITIVE_COMMANDS: set[str] = set() class ArgType(Enum): diff --git a/test/pytests/test_app_state.py b/test/pytests/test_app_state.py new file mode 100644 index 00000000..c1f61aca --- /dev/null +++ b/test/pytests/test_app_state.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from typing import Any + +from configobj import ConfigObj +import pytest + +from mycli.app_state import ( + AppStateMixin, + destructive_keywords_from_config, + ensure_my_cnf_sections, + llm_prompt_truncation, + normalize_ssl_mode, +) + + +class AppState(AppStateMixin): + def __init__(self, defaults_suffix: str | None = None, login_path: str | None = None) -> None: + self.defaults_suffix = defaults_suffix + self.login_path = login_path + + +@pytest.mark.parametrize('ssl_mode', ['auto', 'on', 'off']) +def test_normalize_ssl_mode_accepts_known_values(ssl_mode: str) -> None: + config = ConfigObj({'main': {'ssl_mode': ssl_mode}, 'connection': {'default_ssl_mode': 'off'}}) + + assert normalize_ssl_mode(config) == (ssl_mode, None) + + +def test_normalize_ssl_mode_falls_back_to_connection_default() -> None: + config = ConfigObj({'main': {'ssl_mode': ''}, 'connection': {'default_ssl_mode': 'on'}}) + + assert normalize_ssl_mode(config) == ('on', None) + + +def test_normalize_ssl_mode_reports_invalid_values() -> None: + config = ConfigObj({'main': {'ssl_mode': 'required'}, 'connection': {'default_ssl_mode': 'off'}}) + + ssl_mode, warning = normalize_ssl_mode(config) + + assert ssl_mode is None + assert warning == 'Invalid config option provided for ssl_mode (required); ignoring.' + + +def test_ensure_my_cnf_sections_adds_missing_sections() -> None: + config = ConfigObj({'client': {'user': 'alice'}, 'extra': {'port': '3307'}}) + + ensure_my_cnf_sections(config) + + assert config['client'] == {'user': 'alice'} + assert config['mysqld'] == {} + assert config['extra'] == {'port': '3307'} + + +def test_destructive_keywords_from_config_splits_non_empty_words() -> None: + config = ConfigObj({'main': {'destructive_keywords': 'DROP DELETE UPDATE'}}) + + assert destructive_keywords_from_config(config) == ['DROP', 'DELETE', 'UPDATE'] + + +def test_destructive_keywords_from_config_uses_default() -> None: + config = ConfigObj({'main': {}}) + + assert destructive_keywords_from_config(config) == ['DROP', 'SHUTDOWN', 'DELETE', 'TRUNCATE', 'ALTER', 'UPDATE'] + + +@pytest.mark.parametrize( + ('llm_config', 'expected'), + [ + ({'prompt_field_truncate': '12', 'prompt_section_truncate': '34'}, (12, 34)), + ({'prompt_field_truncate': 'abc', 'prompt_section_truncate': '-1'}, (0, 0)), + ({}, (0, 0)), + ], +) +def test_llm_prompt_truncation_reads_positive_integer_strings( + llm_config: dict[str, str], + expected: tuple[int, int], +) -> None: + config = ConfigObj({'main': {}, 'llm': llm_config}) + + assert llm_prompt_truncation(config) == expected + + +def test_llm_prompt_truncation_handles_missing_llm_section() -> None: + assert llm_prompt_truncation(ConfigObj({'main': {}})) == (0, 0) + + +def test_read_my_cnf_reads_allowed_sections_and_strips_quotes() -> None: + app_state = AppState() + cnf = ConfigObj({ + 'client': {'host': '"db.example.com"', 'socket': '/tmp/client.sock'}, + 'mysqld': {'socket': "'/tmp/mysql.sock'", 'port': '3307', 'user': 'mysql'}, + 'ignored': {'host': 'ignored.example.com'}, + }) + + configuration = app_state.read_my_cnf(cnf, ['host', 'socket', 'port', 'user', 'password']) + + assert configuration == { + 'host': 'db.example.com', + 'socket': '/tmp/client.sock', + 'default_socket': '/tmp/mysql.sock', + 'default_port': '3307', + 'default_user': 'mysql', + } + assert configuration['password'] is None + + +def test_read_my_cnf_includes_login_path_and_suffix_sections() -> None: + app_state = AppState(defaults_suffix='test', login_path='work') + cnf = ConfigObj({ + 'client': {'user': 'client-user'}, + 'work': {'password': 'work-pass'}, + 'clienttest': {'host': 'client-test-host'}, + 'worktest': {'database': 'work-test-db'}, + }) + + configuration = app_state.read_my_cnf(cnf, ['user', 'password', 'host', 'database']) + + assert configuration == { + 'user': 'client-user', + 'password': 'work-pass', + 'host': 'client-test-host', + 'database': 'work-test-db', + } + + +def test_merge_ssl_with_cnf_keeps_existing_ssl_and_adds_cnf_values() -> None: + app_state = AppState() + ssl: dict[str, Any] = {'ca': 'existing-ca.pem', 'cert': 'existing-cert.pem'} + cnf = { + 'ssl-ca': 'cnf-ca.pem', + 'ssl-key': 'client-key.pem', + 'ssl-verify-server-cert': 'ON', + 'ssl-empty': None, + 'host': 'db.example.com', + } + + merged = app_state.merge_ssl_with_cnf(ssl, cnf) + + assert merged == { + 'ca': 'cnf-ca.pem', + 'cert': 'existing-cert.pem', + 'key': 'client-key.pem', + 'check_hostname': True, + } + assert ssl == {'ca': 'existing-ca.pem', 'cert': 'existing-cert.pem'} diff --git a/test/pytests/test_cli_args.py b/test/pytests/test_cli_args.py new file mode 100644 index 00000000..f9171bdc --- /dev/null +++ b/test/pytests/test_cli_args.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +import builtins +from pathlib import Path +from typing import Any + +import click +import pytest + +from mycli import cli_args as cli_args_module +from mycli.cli_args import ( + EMPTY_PASSWORD_FLAG_SENTINEL, + INT_OR_STRING_CLICK_TYPE, + CliArgs, + get_password_from_file, + preprocess_cli_args, +) + + +def valid_connection_scheme(value: str) -> tuple[bool, str | None]: + scheme, _, _ = value.partition('://') + return scheme == 'mysql', scheme or None + + +def test_int_or_string_click_type_accepts_int_string_and_none() -> None: + assert INT_OR_STRING_CLICK_TYPE.convert(7, None, None) == 7 + assert INT_OR_STRING_CLICK_TYPE.convert('secret', None, None) == 'secret' + assert INT_OR_STRING_CLICK_TYPE.convert(None, None, None) is None + + +def test_int_or_string_click_type_rejects_other_values() -> None: + with pytest.raises(click.BadParameter, match='Not a valid password string'): + INT_OR_STRING_CLICK_TYPE.convert(object(), None, None) + + +def test_get_password_from_file_reads_first_line_without_trailing_newline(tmp_path: Path) -> None: + password_file = tmp_path / 'password.txt' + password_file.write_text('secret\nignored\n', encoding='utf8') + + assert get_password_from_file(str(password_file)) == 'secret' + + +def test_get_password_from_file_returns_none_for_missing_path() -> None: + assert get_password_from_file(None) is None + assert get_password_from_file('') is None + + +@pytest.mark.parametrize( + ('exception', 'expected'), + [ + (FileNotFoundError(), "Password file 'secret.txt' not found"), + (PermissionError(), "Permission denied reading password file 'secret.txt'"), + (IsADirectoryError(), "Path 'secret.txt' is a directory, not a file"), + (RuntimeError('boom'), "Error reading password file 'secret.txt': boom"), + ], +) +def test_get_password_from_file_exits_with_error_for_read_failures( + monkeypatch: pytest.MonkeyPatch, + capsys: pytest.CaptureFixture[str], + exception: Exception, + expected: str, +) -> None: + def raise_error(*_args: Any, **_kwargs: Any) -> None: + raise exception + + monkeypatch.setattr(builtins, 'open', raise_error) + + with pytest.raises(SystemExit) as excinfo: + get_password_from_file('secret.txt') + + assert excinfo.value.code == 1 + assert expected in capsys.readouterr().err + + +def test_preprocess_cli_args_moves_dsn_from_password_to_database() -> None: + cli_args = CliArgs() + cli_args.password = 'mysql://user:pass@host/db' + + verbosity = preprocess_cli_args(cli_args, valid_connection_scheme) + + assert verbosity == 0 + assert cli_args.database == 'mysql://user:pass@host/db' + assert cli_args.password == EMPTY_PASSWORD_FLAG_SENTINEL # type: ignore[comparison-overlap] + + +def test_preprocess_cli_args_rejects_unknown_dsn_scheme(capsys: pytest.CaptureFixture[str]) -> None: + cli_args = CliArgs() + cli_args.password = 'postgres://user:pass@host/db' + + with pytest.raises(SystemExit) as excinfo: + preprocess_cli_args(cli_args, valid_connection_scheme) + + assert excinfo.value.code == 1 + assert 'Unknown connection scheme provided for DSN URI (postgres://)' in capsys.readouterr().err + + +def test_preprocess_cli_args_reads_password_file_when_password_missing( + monkeypatch: pytest.MonkeyPatch, +) -> None: + cli_args = CliArgs() + cli_args.password_file = 'secret.txt' + monkeypatch.setattr(cli_args_module, 'get_password_from_file', lambda password_file: f'from:{password_file}') + + assert preprocess_cli_args(cli_args, valid_connection_scheme) == 0 + assert cli_args.password == 'from:secret.txt' + + +def test_preprocess_cli_args_uses_mysql_pwd_when_password_and_file_missing(monkeypatch: pytest.MonkeyPatch) -> None: + cli_args = CliArgs() + monkeypatch.setenv('MYSQL_PWD', 'env-secret') + + assert preprocess_cli_args(cli_args, valid_connection_scheme) == 0 + assert cli_args.password == 'env-secret' + + +def test_preprocess_cli_args_prefers_existing_password_over_mysql_pwd(monkeypatch: pytest.MonkeyPatch) -> None: + cli_args = CliArgs() + cli_args.password = 'cli-secret' + monkeypatch.setenv('MYSQL_PWD', 'env-secret') + + assert preprocess_cli_args(cli_args, valid_connection_scheme) == 0 + assert cli_args.password == 'cli-secret' + + +@pytest.mark.parametrize( + ('checkpoint', 'batch', 'expected'), + [ + (None, 'batch.sql', 'Error: --resume requires a --checkpoint file.'), + (object(), None, 'Error: --resume requires a --batch file.'), + ], +) +def test_preprocess_cli_args_validates_resume_requirements( + capsys: pytest.CaptureFixture[str], + checkpoint: object | None, + batch: str | None, + expected: str, +) -> None: + cli_args = CliArgs() + cli_args.resume = True + cli_args.checkpoint = checkpoint # type: ignore[assignment] + cli_args.batch = batch + + with pytest.raises(SystemExit) as excinfo: + preprocess_cli_args(cli_args, valid_connection_scheme) + + assert excinfo.value.code == 1 + assert expected in capsys.readouterr().err + + +def test_preprocess_cli_args_rejects_verbose_and_quiet(capsys: pytest.CaptureFixture[str]) -> None: + cli_args = CliArgs() + cli_args.verbose = 1 + cli_args.quiet = True + + with pytest.raises(SystemExit) as excinfo: + preprocess_cli_args(cli_args, valid_connection_scheme) + + assert excinfo.value.code == 1 + assert 'Error: --verbose and --quiet are incompatible.' in capsys.readouterr().err + + +@pytest.mark.parametrize( + ('verbose', 'quiet', 'expected'), + [ + (2, False, 2), + (0, True, -1), + (0, False, 0), + ], +) +def test_preprocess_cli_args_returns_cli_verbosity(verbose: int, quiet: bool, expected: int) -> None: + cli_args = CliArgs() + cli_args.verbose = verbose + cli_args.quiet = quiet + + assert preprocess_cli_args(cli_args, valid_connection_scheme) == expected diff --git a/test/pytests/test_main.py b/test/pytests/test_main.py index d7b660c7..8541f808 100644 --- a/test/pytests/test_main.py +++ b/test/pytests/test_main.py @@ -13,6 +13,7 @@ import click from click.testing import CliRunner +import prompt_toolkit from prompt_toolkit.formatted_text import ( FormattedText, to_formatted_text, @@ -32,6 +33,7 @@ ) from mycli.main import EMPTY_PASSWORD_FLAG_SENTINEL, MyCli, click_entrypoint import mycli.main_modes.repl as repl_mode +import mycli.output as output_module import mycli.packages.special from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS from mycli.packages.sqlresult import SQLResult @@ -2245,7 +2247,7 @@ def test_output_timing_logs_and_prints_with_warning_style(monkeypatch: pytest.Mo timings_logged: list[str] = [] cli.log_output = lambda text: timings_logged.append(text) # type: ignore[assignment] printed: list[tuple[Any, Any]] = [] - monkeypatch.setattr(main, 'print_formatted_text', lambda text, style=None: printed.append((text, style))) + monkeypatch.setattr(prompt_toolkit, 'print_formatted_text', lambda text, style=None: printed.append((text, style))) main.MyCli.output_timing(cli, 'Time: 1.000s', is_warnings_style=True) assert timings_logged == ['Time: 1.000s'] assert printed[-1][1] == cli.ptoolkit_style @@ -2273,7 +2275,7 @@ def fake_render_prompt_string(mycli: Any, string: str, render_counter: int) -> F render_counters.append(render_counter) return to_formatted_text('line1\nline2') - monkeypatch.setattr(main, 'render_prompt_string', fake_render_prompt_string) + monkeypatch.setattr(repl_mode, 'render_prompt_string', fake_render_prompt_string) monkeypatch.setattr(main.special, 'is_timing_enabled', lambda: False) assert main.MyCli.get_output_margin(cli, 'ok') == 5 assert render_counters == [7] @@ -2404,7 +2406,7 @@ def test_format_sqlresult_materializes_cursor_rows_when_width_is_limited(monkeyp cli = make_bare_mycli() cli.main_formatter = DummyFormatter() rows = FakeCursorBase(rows=[(1,)], rowcount=1, description=[('id', 3)]) - monkeypatch.setattr(main, 'Cursor', FakeCursorBase) + monkeypatch.setattr(output_module, 'Cursor', FakeCursorBase) result = SQLResult(header=['id'], rows=cast(Any, rows), status='ok') list(main.MyCli.format_sqlresult(cli, result, max_width=100)) diff --git a/test/pytests/test_main_regression.py b/test/pytests/test_main_regression.py index 017fab0d..1712115a 100644 --- a/test/pytests/test_main_regression.py +++ b/test/pytests/test_main_regression.py @@ -21,6 +21,7 @@ import itertools import os from pathlib import Path +import shutil import sys from types import ModuleType, SimpleNamespace from typing import Any, cast @@ -28,11 +29,18 @@ import click from click.testing import CliRunner from configobj import ConfigObj +import prompt_toolkit +from prompt_toolkit.formatted_text import ( + ANSI, + FormattedText, +) import pymysql import pytest from mycli import main +from mycli.cli_args import IntOrStringClickParamType import mycli.key_bindings +import mycli.output as output_module from mycli.packages.sqlresult import SQLResult from test.utils import ( # type: ignore[attr-defined] DummyFormatter, @@ -302,7 +310,7 @@ def __init__(self) -> None: def test_int_or_string_click_param_type_accepts_and_rejects_values() -> None: - param_type = main.IntOrStringClickParamType() + param_type = IntOrStringClickParamType() assert param_type.convert(1, None, None) == 1 assert param_type.convert('pw', None, None) == 'pw' @@ -827,7 +835,7 @@ def failing_connect() -> None: with logfile.open('w+', encoding='utf-8') as handle: cli.logfile = handle main.MyCli.log_query(cli, 'select 1') - main.MyCli.log_output(cli, main.ANSI('\x1b[31mhello\x1b[0m')) + main.MyCli.log_output(cli, ANSI('\x1b[31mhello\x1b[0m')) handle.seek(0) contents = handle.read() assert 'select 1' in contents @@ -842,7 +850,7 @@ def failing_connect() -> None: monkeypatch.setattr(main.special, 'is_pager_enabled', lambda: False) monkeypatch.setattr(main.MyCli, 'get_output_margin', lambda self, status=None: 1) monkeypatch.setattr(click, 'secho', lambda line, **kwargs: echoed_lines.append(str(line))) - monkeypatch.setattr(main, 'print_formatted_text', lambda text, style=None: printed_status.append((text, style))) + monkeypatch.setattr(prompt_toolkit, 'print_formatted_text', lambda text, style=None: printed_status.append((text, style))) main.MyCli.output(cli, itertools.chain(['row 1']), SQLResult(status='status')) assert echoed_lines == [] assert printed_status @@ -930,7 +938,7 @@ def test_output_uses_stdout_and_pager_paths(monkeypatch: pytest.MonkeyPatch) -> paged_lines: list[str] = [] monkeypatch.setattr(click, 'secho', lambda line, **kwargs: printed_lines.append(str(line))) monkeypatch.setattr(click, 'echo_via_pager', lambda gen: paged_lines.extend(list(gen))) - monkeypatch.setattr(main, 'print_formatted_text', lambda text, style=None: None) + monkeypatch.setattr(prompt_toolkit, 'print_formatted_text', lambda text, style=None: None) main.MyCli.output(cli, itertools.chain(['a' * 81, 'tail']), SQLResult(status='ok')) assert printed_lines[:2] == ['a' * 81, 'tail'] @@ -947,13 +955,13 @@ def test_format_sqlresult_output_covers_extra_branches(monkeypatch: pytest.Monke cli.main_formatter = DummyFormatter() cli.redirect_formatter = DummyFormatter() cli.get_reserved_space = lambda: 1 # type: ignore[assignment] - monkeypatch.setattr(main, 'Cursor', FakeCursorBase) + monkeypatch.setattr(output_module, 'Cursor', FakeCursorBase) rows = FakeCursorBase(rows=[], rowcount=0, description=[('id', 3, None, None, None, None, None)]) result = SQLResult( header=['id'], rows=cast(Any, rows), preamble='preamble', - status=main.FormattedText([('', 'formatted-status')]), + status=FormattedText([('', 'formatted-status')]), ) formatted = list(main.MyCli.format_sqlresult(cli, result, null_string='NULL')) assert 'preamble' in formatted @@ -973,7 +981,7 @@ def test_format_sqlresult_output_covers_extra_branches(monkeypatch: pytest.Monke monkeypatch.setattr(main.MyCli, 'get_output_margin', lambda self, status=None: 1) monkeypatch.setattr(click, 'echo_via_pager', lambda gen: paged_lines.extend(list(gen))) monkeypatch.setattr(click, 'secho', lambda line, **kwargs: printed_lines.append(str(line))) - monkeypatch.setattr(main, 'print_formatted_text', lambda text, style=None: status_prints.append(text)) + monkeypatch.setattr(prompt_toolkit, 'print_formatted_text', lambda text, style=None: status_prints.append(text)) cli.log_output = lambda text: None # type: ignore[assignment] cli.explicit_pager = False main.MyCli.output(cli, itertools.chain(['x' * 81]), result) @@ -1447,8 +1455,8 @@ def test_configure_pager_and_refresh_completions(monkeypatch: pytest.MonkeyPatch monkeypatch.delenv('LESS', raising=False) monkeypatch.setattr(main.special, 'set_pager', lambda pager: set_pager_calls.append(pager)) monkeypatch.setattr(main.special, 'disable_pager', lambda: disable_calls.append(True)) - monkeypatch.setattr(main, 'WIN', True) - monkeypatch.setattr(main.shutil, 'which', lambda name: None) + monkeypatch.setattr(output_module, 'WIN', True) + monkeypatch.setattr(shutil, 'which', lambda name: None) main.MyCli.configure_pager(cli) assert os.environ['LESS'] == '-RXF' assert set_pager_calls == ['more'] diff --git a/test/pytests/test_output.py b/test/pytests/test_output.py new file mode 100644 index 00000000..47f7e0f5 --- /dev/null +++ b/test/pytests/test_output.py @@ -0,0 +1,232 @@ +from __future__ import annotations + +import itertools +import shutil +from typing import Any, cast + +import click +from configobj import ConfigObj +import prompt_toolkit +from prompt_toolkit.formatted_text import ANSI, FormattedText, to_plain_text +import pytest + +from mycli import output as output_module +from mycli.output import OutputMixin +from mycli.packages.sqlresult import SQLResult +from test.utils import DummyFormatter, FakeCursorBase, make_bare_mycli # type: ignore[attr-defined] + + +def test_output_timing_logs_and_prints_with_default_style(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + logged: list[Any] = [] + printed: list[tuple[Any, Any]] = [] + cli.log_output = lambda value: logged.append(value) # type: ignore[assignment] + monkeypatch.setattr(prompt_toolkit, 'print_formatted_text', lambda text, style=None: printed.append((text, style))) + + OutputMixin.output_timing(cli, '0.12 sec') + + assert logged == ['0.12 sec'] + assert to_plain_text(printed[0][0]) == '0.12 sec' + assert list(printed[0][0])[0][0].strip() == 'class:output.timing' + assert printed[0][1] == cli.ptoolkit_style + + +def test_output_timing_uses_warning_style(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.log_output = lambda value: None # type: ignore[assignment] + printed: list[Any] = [] + monkeypatch.setattr(prompt_toolkit, 'print_formatted_text', lambda text, style=None: printed.append(text)) + + OutputMixin.output_timing(cli, '0.34 sec', is_warnings_style=True) + + assert list(printed[0])[0][0].strip() == 'class:warnings.timing' + + +def test_log_query_and_log_output_write_plain_text(tmp_path) -> None: + cli = make_bare_mycli() + logfile = tmp_path / 'audit.log' + + with logfile.open('w+', encoding='utf-8') as handle: + cli.logfile = handle + OutputMixin.log_query(cli, 'select 1') + OutputMixin.log_output(cli, ANSI('\x1b[31mhello\x1b[0m')) + handle.seek(0) + contents = handle.read() + + assert 'select 1' in contents + assert 'hello' in contents + assert '\x1b[31m' not in contents + + +def test_log_output_ignores_missing_logfile() -> None: + cli = make_bare_mycli() + cli.logfile = None + + OutputMixin.log_output(cli, 'nothing to write') + + +def test_echo_logs_and_prints(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + logged: list[str] = [] + printed: list[tuple[str, dict[str, Any]]] = [] + cli.log_output = lambda value: logged.append(value) # type: ignore[assignment] + monkeypatch.setattr(click, 'secho', lambda value, **kwargs: printed.append((value, kwargs))) + + OutputMixin.echo(cli, 'message', fg='red') + + assert logged == ['message'] + assert printed == [('message', {'fg': 'red'})] + + +def test_get_output_margin_renders_prompt_once_and_counts_status_lines(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.prompt_lines = 0 + cli.prompt_format = 'ignored' + cli.prompt_session = None + cli.get_reserved_space = lambda: 2 # type: ignore[assignment] + monkeypatch.setattr(output_module.repl_mode, 'render_prompt_string', lambda *_args: FormattedText([('', 'one\ntwo')])) + monkeypatch.setattr(output_module.special, 'is_timing_enabled', lambda: True) + + margin = OutputMixin.get_output_margin(cli, 'ok\nwarning') + + assert margin == 7 + assert cli.prompt_lines == 2 + + +def test_output_writes_lines_sinks_and_status(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.prompt_session = None + cli.explicit_pager = False + cli.get_output_margin = lambda status=None: 1 # type: ignore[assignment] + logged: list[Any] = [] + tee: list[str] = [] + once: list[str] = [] + pipe_once: list[str] = [] + printed_lines: list[str] = [] + printed_status: list[Any] = [] + cli.log_output = lambda value: logged.append(value) # type: ignore[assignment] + monkeypatch.setattr(output_module.special, 'write_tee', lambda value: tee.append(value)) + monkeypatch.setattr(output_module.special, 'write_once', lambda value: once.append(value)) + monkeypatch.setattr(output_module.special, 'write_pipe_once', lambda value: pipe_once.append(value)) + monkeypatch.setattr(output_module.special, 'is_redirected', lambda: False) + monkeypatch.setattr(output_module.special, 'is_pager_enabled', lambda: False) + monkeypatch.setattr(click, 'secho', lambda value, **_kwargs: printed_lines.append(value)) + monkeypatch.setattr(prompt_toolkit, 'print_formatted_text', lambda text, style=None: printed_status.append(text)) + + OutputMixin.output(cli, itertools.chain(['row 1', 'row 2']), SQLResult(status='done')) + + assert logged == ['row 1', 'row 2', 'done'] + assert tee == ['row 1', 'row 2'] + assert once == ['row 1', 'row 2'] + assert pipe_once == ['row 1', 'row 2'] + assert printed_lines == ['row 1', 'row 2'] + assert to_plain_text(printed_status[0]) == 'done' + assert list(printed_status[0])[0][0].strip() == 'class:output.status' + + +def test_output_uses_warning_status_style(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.log_output = lambda value: None # type: ignore[assignment] + cli.get_output_margin = lambda status=None: 1 # type: ignore[assignment] + printed_status: list[Any] = [] + monkeypatch.setattr(prompt_toolkit, 'print_formatted_text', lambda text, style=None: printed_status.append(text)) + + OutputMixin.output(cli, itertools.chain([]), SQLResult(status='warning'), is_warnings_style=True) + + assert list(printed_status[0])[0][0].strip() == 'class:warnings.status' + + +def test_output_sends_buffer_to_pager_when_pager_is_explicit(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.prompt_session = None + cli.explicit_pager = True + cli.log_output = lambda value: None # type: ignore[assignment] + cli.get_output_margin = lambda status=None: 1 # type: ignore[assignment] + paged_lines: list[str] = [] + monkeypatch.setattr(output_module.special, 'write_tee', lambda value: None) + monkeypatch.setattr(output_module.special, 'write_once', lambda value: None) + monkeypatch.setattr(output_module.special, 'write_pipe_once', lambda value: None) + monkeypatch.setattr(output_module.special, 'is_redirected', lambda: False) + monkeypatch.setattr(output_module.special, 'is_pager_enabled', lambda: True) + monkeypatch.setattr(click, 'echo_via_pager', lambda values: paged_lines.extend(list(values))) + monkeypatch.setattr(prompt_toolkit, 'print_formatted_text', lambda text, style=None: None) + + OutputMixin.output(cli, itertools.chain(['row 1', 'row 2']), SQLResult()) + + assert paged_lines == ['row 1\n', 'row 2\n'] + + +def test_configure_pager_prefers_my_cnf_pager_and_sets_less(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.my_cnf = ConfigObj({'client': {'pager': 'my-pager'}}) + cli.config = ConfigObj({'main': {'pager': 'config-pager', 'enable_pager': 'True'}}) + cli.read_my_cnf = lambda cnf, keys: {'pager': 'my-pager', 'skip-pager': None} # type: ignore[assignment] + pager_calls: list[str] = [] + disabled: list[bool] = [] + monkeypatch.delenv('LESS', raising=False) + monkeypatch.setattr(output_module.special, 'set_pager', lambda value: pager_calls.append(value)) + monkeypatch.setattr(output_module.special, 'disable_pager', lambda: disabled.append(True)) + + OutputMixin.configure_pager(cli) + + assert pager_calls == ['my-pager'] + assert disabled == [] + assert cli.explicit_pager is True + assert output_module.os.environ['LESS'] == '-RXF' + + +def test_configure_pager_disables_when_skip_pager_is_set(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.my_cnf = ConfigObj({'client': {}}) + cli.config = ConfigObj({'main': {'pager': '', 'enable_pager': 'True'}}) + cli.read_my_cnf = lambda cnf, keys: {'pager': None, 'skip-pager': '1'} # type: ignore[assignment] + disabled: list[bool] = [] + monkeypatch.setattr(output_module.special, 'set_pager', lambda value: None) + monkeypatch.setattr(output_module.special, 'disable_pager', lambda: disabled.append(True)) + + OutputMixin.configure_pager(cli) + + assert cli.explicit_pager is False + assert disabled == [True] + + +def test_format_sqlresult_uses_redirect_formatter_and_appends_preamble_postamble() -> None: + cli = make_bare_mycli() + cli.main_formatter = DummyFormatter() + cli.redirect_formatter = DummyFormatter() + result = SQLResult(preamble='before', header=['id'], rows=[(1,)], postamble='after') + + formatted = list(OutputMixin.format_sqlresult(cli, result, is_redirected=True)) + + assert formatted == ['before', 'plain output', 'after'] + assert cli.main_formatter.calls == [] + assert cli.redirect_formatter.calls + + +def test_format_sqlresult_for_cursor_sets_column_types_and_alignment(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.main_formatter = DummyFormatter() + monkeypatch.setattr(output_module, 'Cursor', FakeCursorBase) + rows = FakeCursorBase(rows=[(1, 'name')], rowcount=1, description=[('id', 3), ('name', 253)]) + result = SQLResult(header=['id', 'name'], rows=cast(Any, rows)) + + assert list(OutputMixin.format_sqlresult(cli, result, numeric_alignment='left')) == ['plain output'] + + _, kwargs = cli.main_formatter.calls[-1] + assert kwargs['column_types'] == [int, str] + assert kwargs['colalign'] == ['left', 'left'] + + +def test_format_sqlresult_switches_to_vertical_when_first_line_is_too_wide() -> None: + cli = make_bare_mycli() + cli.main_formatter = DummyFormatter() + result = SQLResult(header=['id'], rows=[(1,)]) + + assert list(OutputMixin.format_sqlresult(cli, result, max_width=2)) == ['vertical output'] + + +def test_get_reserved_space_caps_ratio(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + monkeypatch.setattr(shutil, 'get_terminal_size', lambda *args, **kwargs: (120, 40)) + + assert OutputMixin.get_reserved_space(cli) == 8 diff --git a/test/utils.py b/test/utils.py index cc0f9702..5bda3d3d 100644 --- a/test/utils.py +++ b/test/utils.py @@ -10,6 +10,9 @@ from typing import Any, Callable, Literal, cast from packaging.version import Version +from prompt_toolkit.formatted_text import ( + ANSI, +) import pygments import pymysql import pytest @@ -145,8 +148,8 @@ def make_bare_mycli() -> Any: cli.query_history = [] cli.toolbar_error_message = None cli.prompt_session = None - cli.last_prompt_message = main.ANSI('') - cli.last_custom_toolbar_message = main.ANSI('') + cli.last_prompt_message = ANSI('') + cli.last_custom_toolbar_message = ANSI('') cli.prompt_lines = 0 cli.prompt_format = main.MyCli.default_prompt cli.multiline_continuation_char = '>'