diff --git a/changelog.md b/changelog.md index cb6118f9..c6f87f56 100644 --- a/changelog.md +++ b/changelog.md @@ -1,6 +1,13 @@ Upcoming (TBD) ============== +Features +-------- +* Add support for the automatic displaying of warnings after a SQL statement is executed. + May be set with the commands \W and \w, in the config file with show_warnings, or + with --show-warnings/--no-show-warnings on the command line. + + Internal -------- * Improve robustness for flaky tests when publishing. diff --git a/mycli/AUTHORS b/mycli/AUTHORS index d39b3e4f..fc4cc4d3 100644 --- a/mycli/AUTHORS +++ b/mycli/AUTHORS @@ -111,6 +111,7 @@ Contributors: * keltaklo * 924060929 * tmijieux + * Scott Nemes Created by: diff --git a/mycli/main.py b/mycli/main.py index a54b80e1..6c227fcf 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -109,6 +109,7 @@ def __init__( defaults_file: str | None = None, login_path: str | None = None, auto_vertical_output: bool = False, + show_warnings: bool = False, warn: bool | None = None, myclirc: str = "~/.myclirc", ) -> None: @@ -155,6 +156,7 @@ def __init__( # read from cli argument or user config file self.auto_vertical_output = auto_vertical_output or c["main"].as_bool("auto_vertical_output") + self.show_warnings = show_warnings or c["main"].as_bool("show_warnings") # Write user config if system config wasn't the last config loaded. if c.filename not in self.system_config_files and not os.path.exists(myclirc): @@ -237,11 +239,37 @@ def register_special_commands(self) -> None: aliases=["\\Tr"], case_sensitive=True, ) + special.register_special_command( + self.disable_show_warnings, + "nowarnings", + "\\w", + "Disable automatic warnings display.", + aliases=["\\w"], + case_sensitive=True, + ) + special.register_special_command( + self.enable_show_warnings, + "warnings", + "\\W", + "Enable automatic warnings display.", + aliases=["\\W"], + case_sensitive=True, + ) special.register_special_command(self.execute_from_file, "source", "\\. filename", "Execute commands from file.", aliases=["\\."]) special.register_special_command( self.change_prompt_format, "prompt", "\\R", "Change prompt format.", aliases=["\\R"], case_sensitive=True ) + def enable_show_warnings(self, **_) -> Generator[tuple, None, None]: + self.show_warnings = True + msg = "Show warnings enabled." + yield (None, None, None, msg) + + def disable_show_warnings(self, **_) -> Generator[tuple, None, None]: + self.show_warnings = False + msg = "Show warnings disabled." + yield (None, None, None, msg) + def change_table_format(self, arg: str, **_) -> Generator[tuple, None, None]: try: self.main_formatter.format_name = arg @@ -768,6 +796,21 @@ def output_res(res: Generator[tuple], start: float) -> None: result_count += 1 mutating = mutating or is_mutating(status) + # get and display warnings if enabled + if self.show_warnings and isinstance(cur, Cursor) and cur.warning_count > 0: + warnings = sqlexecute.run("SHOW WARNINGS") + for title, cur, headers, status in warnings: + formatted = self.format_output( + title, + cur, + headers, + special.is_expanded_output(), + special.is_redirected(), + max_width, + ) + self.echo("") + self.output(formatted, status) + def one_iteration(text: str | None = None) -> None: if text is None: try: @@ -1186,6 +1229,20 @@ def run_query(self, query: str, new_line: bool = True) -> None: for line in output: click.echo(line, nl=new_line) + # get and display warnings if enabled + if self.show_warnings and isinstance(cur, Cursor) and cur.warning_count > 0: + warnings = self.sqlexecute.run("SHOW WARNINGS") + for title, cur, headers, _ in warnings: + output = self.format_output( + title, + cur, + headers, + special.is_expanded_output(), + special.is_redirected(), + ) + for line in output: + click.echo(line, nl=new_line) + def format_output( self, title: str | None, @@ -1315,6 +1372,7 @@ def get_last_query(self) -> str | None: is_flag=True, help="Automatically switch to vertical output mode if the result is wider than the terminal width.", ) +@click.option("--show-warnings/--no-show-warnings", is_flag=True, help="Automatically show warnings after executing a SQL statement.") @click.option("-t", "--table", is_flag=True, help="Display batch output in table format.") @click.option("--csv", is_flag=True, help="Display batch output in CSV format.") @click.option("--warn/--no-warn", default=None, help="Warn before running a destructive query.") @@ -1342,6 +1400,7 @@ def cli( defaults_file: str | None, login_path: str | None, auto_vertical_output: bool, + show_warnings: bool, local_infile: bool, ssl_enable: bool, ssl_ca: str | None, @@ -1533,6 +1592,10 @@ def cli( combined_init_cmd = "; ".join(cmd.strip() for cmd in init_cmds if cmd) + # --show-warnings / --no-show-warnings + if show_warnings: + mycli.show_warnings = show_warnings + mycli.connect( database=database, user=user, diff --git a/mycli/myclirc b/mycli/myclirc index 26387860..a9e15808 100644 --- a/mycli/myclirc +++ b/mycli/myclirc @@ -1,6 +1,10 @@ # vi: ft=dosini [main] +# Enable or disable the automatic displaying of warnings ("SHOW WARNINGS") +# after executing a SQL statement when applicable. +show_warnings = False + # Enables context sensitive auto-completion. If this is disabled the all # possible completions will be listed. smart_completion = True diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py index 9794a946..49c41e8a 100644 --- a/mycli/sqlexecute.py +++ b/mycli/sqlexecute.py @@ -208,10 +208,10 @@ def connect( ) conv = conversions.copy() conv.update({ - FIELD_TYPE.TIMESTAMP: lambda obj: (convert_datetime(obj) or obj), - FIELD_TYPE.DATETIME: lambda obj: (convert_datetime(obj) or obj), - FIELD_TYPE.TIME: lambda obj: (convert_timedelta(obj) or obj), - FIELD_TYPE.DATE: lambda obj: (convert_date(obj) or obj), + FIELD_TYPE.TIMESTAMP: lambda obj: convert_datetime(obj) or obj, + FIELD_TYPE.DATETIME: lambda obj: convert_datetime(obj) or obj, + FIELD_TYPE.TIME: lambda obj: convert_timedelta(obj) or obj, + FIELD_TYPE.DATE: lambda obj: convert_date(obj) or obj, }) defer_connect = False @@ -342,15 +342,18 @@ def get_result(self, cursor: Cursor) -> tuple: # cursor.description is not None for queries that return result sets, # e.g. SELECT or SHOW. + plural = '' if cursor.rowcount == 1 else 's' if cursor.description: headers = [x[0] for x in cursor.description] - plural = '' if cursor.rowcount == 1 else 's' status = f'{cursor.rowcount} row{plural} in set' else: _logger.debug("No rows in result.") - plural = '' if cursor.rowcount == 1 else 's' status = f'Query OK, {cursor.rowcount} row{plural} affected' + if cursor.warning_count > 0: + plural = '' if cursor.warning_count == 1 else 's' + status = f'{status}, {cursor.warning_count} warning{plural}' + return (title, cursor if cursor.description else None, headers, status) def tables(self) -> Generator[tuple[str], None, None]: diff --git a/test/myclirc b/test/myclirc index a2bb8dd5..a19a34ba 100644 --- a/test/myclirc +++ b/test/myclirc @@ -1,6 +1,10 @@ # vi: ft=dosini [main] +# Enable or disable the automatic displaying of warnings ("SHOW WARNINGS") +# after executing a SQL statement when applicable. +show_warnings = False + # Enables context sensitive auto-completion. If this is disabled the all # possible completions will be listed. smart_completion = True diff --git a/test/test_main.py b/test/test_main.py index d4ef6862..159c1ba7 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -37,6 +37,75 @@ ] +@dbtest +def test_enable_show_warnings(executor): + mycli = MyCli() + mycli.register_special_commands() + sql = "\\W" + result = run(executor, sql) + assert result[0]["status"] == "Show warnings enabled." + + +@dbtest +def test_disable_show_warnings(executor): + mycli = MyCli() + mycli.register_special_commands() + sql = "\\w" + result = run(executor, sql) + assert result[0]["status"] == "Show warnings disabled." + + +@dbtest +def test_output_with_warning_and_show_warnings_enabled(executor): + runner = CliRunner() + sql = "SELECT 1 + '0 foo'" + result = runner.invoke(cli, args=CLI_ARGS + ["--show-warnings"], input=sql) + expected = "1 + '0 foo'\n1.0\nLevel\tCode\tMessage\nWarning\t1292\tTruncated incorrect DOUBLE value: '0 foo'\n" + assert expected in result.output + + +@dbtest +def test_output_with_warning_and_show_warnings_disabled(executor): + runner = CliRunner() + sql = "SELECT 1 + '0 foo'" + result = runner.invoke(cli, args=CLI_ARGS + ["--no-show-warnings"], input=sql) + expected = "1 + '0 foo'\n1.0\nLevel\tCode\tMessage\nWarning\t1292\tTruncated incorrect DOUBLE value: '0 foo'\n" + assert expected not in result.output + + +@dbtest +def test_output_with_multiple_warnings_in_single_statement(executor): + runner = CliRunner() + sql = "SELECT 1 + '0 foo', 2 + '0 foo'" + result = runner.invoke(cli, args=CLI_ARGS + ["--show-warnings"], input=sql) + expected = ( + "1 + '0 foo'\t2 + '0 foo'\n" + "1.0\t2.0\n" + "Level\tCode\tMessage\n" + "Warning\t1292\tTruncated incorrect DOUBLE value: '0 foo'\n" + "Warning\t1292\tTruncated incorrect DOUBLE value: '0 foo'\n" + ) + assert expected in result.output + + +@dbtest +def test_output_with_multiple_warnings_in_multiple_statements(executor): + runner = CliRunner() + sql = "SELECT 1 + '0 foo'; SELECT 2 + '0 foo'" + result = runner.invoke(cli, args=CLI_ARGS + ["--show-warnings"], input=sql) + expected = ( + "1 + '0 foo'\n" + "1.0\n" + "Level\tCode\tMessage\n" + "Warning\t1292\tTruncated incorrect DOUBLE value: '0 foo'\n" + "2 + '0 foo'\n" + "2.0\n" + "Level\tCode\tMessage\n" + "Warning\t1292\tTruncated incorrect DOUBLE value: '0 foo'\n" + ) + assert expected in result.output + + @dbtest def test_execute_arg(executor): run(executor, "create table test (a text)") diff --git a/test/test_sqlexecute.py b/test/test_sqlexecute.py index d1d97478..a0e91e48 100644 --- a/test/test_sqlexecute.py +++ b/test/test_sqlexecute.py @@ -25,6 +25,20 @@ def assert_result_equal(result, title=None, rows=None, headers=None, status=None assert result == [fields] +@dbtest +def test_get_result_status_without_warning(executor): + sql = "select 1" + result = run(executor, sql) + assert result[0]["status"] == "1 row in set" + + +@dbtest +def test_get_result_status_with_warning(executor): + sql = "SELECT 1 + '0 foo'" + result = run(executor, sql) + assert result[0]["status"] == "1 row in set, 1 warning" + + @dbtest def test_conn(executor): run(executor, """create table test(a text)""")