diff --git a/changelog.md b/changelog.md index 04387011..d08ba03c 100644 --- a/changelog.md +++ b/changelog.md @@ -3,7 +3,8 @@ Upcoming (TBD) Features -------- -* Update query processing functions to allow automatic show_warnings to work for more code paths like DDL +* Update query processing functions to allow automatic show_warnings to work for more code paths like DDL. +* Rework reconnect logic to actually create a new connection instead of simply changing the database (#746). Bug Fixes -------- diff --git a/mycli/main.py b/mycli/main.py index 86dcc5c4..30d80bfb 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -213,8 +213,8 @@ def close(self) -> None: def register_special_commands(self) -> None: special.register_special_command(self.change_db, "use", "\\u", "Change to a new database.", aliases=["\\u"]) special.register_special_command( - self.change_db, - "connect", + self.manual_reconnect, + "reconnect", "\\r", "Reconnect to the database. Optional database argument.", aliases=["\\r"], @@ -260,6 +260,14 @@ def register_special_commands(self) -> None: self.change_prompt_format, "prompt", "\\R", "Change prompt format.", aliases=["\\R"], case_sensitive=True ) + def manual_reconnect(self, arg: str = "", **_) -> Generator[tuple, None, None]: + """ + wrapper function to use for the \r command so that the real function + may be cleanly used elsewhere + """ + self.reconnect(arg) + yield (None, None, None, None) + def enable_show_warnings(self, **_) -> Generator[tuple, None, None]: self.show_warnings = True msg = "Show warnings enabled." @@ -912,18 +920,11 @@ def one_iteration(text: str | None = None) -> None: special.unset_once_if_written(self.post_redirect_command) special.flush_pipe_once_if_written(self.post_redirect_command) except err.InterfaceError: - logger.debug("Attempting to reconnect.") - self.echo("Reconnecting...", fg="yellow") - try: - sqlexecute.connect() - logger.debug("Reconnected successfully.") - one_iteration(text) - return # OK to just return, cuz the recursion call runs to the end. - except OperationalError as e2: - logger.debug("Reconnect failed. e: %r", e2) - self.echo(str(e2), err=True, fg="red") - # If reconnection failed, don't proceed further. + # attempt to reconnect + if not self.reconnect(): return + one_iteration(text) + return # OK to just return, cuz the recursion call runs to the end. except EOFError as e: raise e except KeyboardInterrupt: @@ -957,18 +958,11 @@ def one_iteration(text: str | None = None) -> None: except OperationalError as e1: logger.debug("Exception: %r", e1) if e1.args[0] in (2003, 2006, 2013): - logger.debug("Attempting to reconnect.") - self.echo("Reconnecting...", fg="yellow") - try: - sqlexecute.connect() - logger.debug("Reconnected successfully.") - one_iteration(text) - return # OK to just return, cuz the recursion call runs to the end. - except OperationalError as e2: - logger.debug("Reconnect failed. e: %r", e2) - self.echo(str(e2), err=True, fg="red") - # If reconnection failed, don't proceed further. + # attempt to reconnect + if not self.reconnect(): return + one_iteration(text) + return # OK to just return, cuz the recursion call runs to the end. else: logger.error("sql: %r, error: %r", text, e1) logger.error("traceback: %r", traceback.format_exc()) @@ -1040,6 +1034,29 @@ def one_iteration(text: str | None = None) -> None: if not self.less_chatty: self.echo("Goodbye!") + def reconnect(self, database: str = "") -> bool: + """ + Attempt to reconnect to the database. Return True if successful, + False if unsuccessful. + """ + assert self.sqlexecute is not None + self.logger.debug("Attempting to reconnect.") + self.echo("Reconnecting...", fg="yellow") + try: + self.sqlexecute.connect() + except OperationalError as e: + self.logger.debug("Reconnect failed. e: %r", e) + self.echo(str(e), err=True, fg="red") + return False + self.logger.debug("Reconnected successfully.") + self.echo("Reconnected successfully.\n", fg="yellow") + if database and self.sqlexecute.dbname != database: + for result in self.change_db(database): + self.echo(result[3]) + elif database: + self.echo(f'You are already connected to database "{self.sqlexecute.dbname}" as user "{self.sqlexecute.user}"') + return True + def log_output(self, output: str) -> None: """Log the output in the audit log, if it's enabled.""" if isinstance(self.logfile, TextIOWrapper): diff --git a/test/features/fixture_data/help_commands.txt b/test/features/fixture_data/help_commands.txt index 9cb21324..c38aeb4e 100644 --- a/test/features/fixture_data/help_commands.txt +++ b/test/features/fixture_data/help_commands.txt @@ -13,15 +13,16 @@ | \once | \o [-o] filename | Append next result to an output file (overwrite using -o). | | \pipe_once | \| command | Send next result to a subprocess. | | \timing | \t | Toggle timing of commands. | -| connect | \r | Reconnect to the database. Optional database argument. | | delimiter | | Change SQL delimiter. | | exit | \q | Exit. | | help | \? | Show this help. | | nopager | \n | Disable pager, print to stdout. | | notee | notee | Stop writing results to an output file. | +| nowarnings | \w | Disable automatic warnings display. | | pager | \P [command] | Set PAGER. Print the query results via PAGER. | | prompt | \R | Change prompt format. | | quit | \q | Quit. | +| reconnect | \r | Reconnect to the database. Optional database argument. | | redirectformat | \Tr | Change the table format used to output redirected results. | | rehash | \# | Refresh auto-completions. | | source | \. filename | Execute commands from file. | @@ -30,5 +31,6 @@ | tableformat | \T | Change the table format used to output results. | | tee | tee [-o] filename | Append all results to an output file (overwrite using -o). | | use | \u | Change to a new database. | +| warnings | \W | Enable automatic warnings display. | | watch | watch [seconds] [-c] query | Executes the query every [seconds] seconds (by default 5). | +----------------+----------------------------+------------------------------------------------------------+ diff --git a/test/test_main.py b/test/test_main.py index 3d6baaec..565b61fa 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -37,6 +37,35 @@ ] +@dbtest +def test_reconnect_no_database(executor): + runner = CliRunner() + sql = "\\r" + result = runner.invoke(cli, args=CLI_ARGS, input=sql) + expected = "Reconnecting...\nReconnected successfully.\n\n" + assert expected in result.output + + +@dbtest +def test_reconnect_with_different_database(executor): + runner = CliRunner() + database = "mysql" + sql = f"\\r {database}" + result = runner.invoke(cli, args=CLI_ARGS, input=sql) + expected = f'Reconnecting...\nReconnected successfully.\n\nYou are now connected to database "{database}" as user "{USER}"\n' + assert expected in result.output + + +@dbtest +def test_reconnect_with_same_database(executor): + runner = CliRunner() + database = "mysql" + sql = f"\\u {database}; \\r {database}" + result = runner.invoke(cli, args=CLI_ARGS, input=sql) + expected = f'Reconnecting...\nReconnected successfully.\n\nYou are already connected to database "{database}" as user "{USER}"\n' + assert expected in result.output + + @dbtest def test_prompt_no_host_only_socket(executor): mycli = MyCli()