From adc5c67816061a60687dc31e1aeddb0d469413d1 Mon Sep 17 00:00:00 2001 From: Scott Nemes Date: Mon, 22 Dec 2025 18:50:05 -0800 Subject: [PATCH 1/4] Moved reconnect logic to a separate function. Made a wrapper function for use by the command \r to call the new reconnect function. Updated help output in tests to match the change. --- mycli/main.py | 64 ++++++++++++-------- test/features/fixture_data/help_commands.txt | 4 +- 2 files changed, 43 insertions(+), 25 deletions(-) diff --git a/mycli/main.py b/mycli/main.py index 6c227fcf..0ca3ab28 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -213,8 +213,8 @@ def close(self) -> None: def register_special_commands(self) -> None: special.register_special_command(self.change_db, "use", "\\u", "Change to a new database.", aliases=["\\u"]) special.register_special_command( - self.change_db, - "connect", + self.manual_reconnect, + "reconnect", "\\r", "Reconnect to the database. Optional database argument.", aliases=["\\r"], @@ -260,6 +260,14 @@ def register_special_commands(self) -> None: self.change_prompt_format, "prompt", "\\R", "Change prompt format.", aliases=["\\R"], case_sensitive=True ) + def manual_reconnect(self, arg: str = None, **_) -> Generator[tuple, None, None]: + """ + wrapper function to use for the \r command so that the real function + may be cleanly used elsewhere + """ + _ = self.reconnect(arg) + yield (None, None, None, None) + def enable_show_warnings(self, **_) -> Generator[tuple, None, None]: self.show_warnings = True msg = "Show warnings enabled." @@ -912,18 +920,11 @@ def one_iteration(text: str | None = None) -> None: special.unset_once_if_written(self.post_redirect_command) special.flush_pipe_once_if_written(self.post_redirect_command) except err.InterfaceError: - logger.debug("Attempting to reconnect.") - self.echo("Reconnecting...", fg="yellow") - try: - sqlexecute.connect() - logger.debug("Reconnected successfully.") - one_iteration(text) - return # OK to just return, cuz the recursion call runs to the end. - except OperationalError as e2: - logger.debug("Reconnect failed. e: %r", e2) - self.echo(str(e2), err=True, fg="red") - # If reconnection failed, don't proceed further. + # attempt to reconnect + if not self.reconnect(): return + one_iteration(text) + return # OK to just return, cuz the recursion call runs to the end. except EOFError as e: raise e except KeyboardInterrupt: @@ -957,18 +958,11 @@ def one_iteration(text: str | None = None) -> None: except OperationalError as e1: logger.debug("Exception: %r", e1) if e1.args[0] in (2003, 2006, 2013): - logger.debug("Attempting to reconnect.") - self.echo("Reconnecting...", fg="yellow") - try: - sqlexecute.connect() - logger.debug("Reconnected successfully.") - one_iteration(text) - return # OK to just return, cuz the recursion call runs to the end. - except OperationalError as e2: - logger.debug("Reconnect failed. e: %r", e2) - self.echo(str(e2), err=True, fg="red") - # If reconnection failed, don't proceed further. + # attempt to reconnect + if not self.reconnect(): return + one_iteration(text) + return # OK to just return, cuz the recursion call runs to the end. else: logger.error("sql: %r, error: %r", text, e1) logger.error("traceback: %r", traceback.format_exc()) @@ -1040,6 +1034,28 @@ def one_iteration(text: str | None = None) -> None: if not self.less_chatty: self.echo("Goodbye!") + def reconnect(self, database: str = None) -> bool: + """ + Attempt to reconnect to the database. Return True if successful, + False if unsuccessful. + """ + self.logger.debug("Attempting to reconnect.") + self.echo("Reconnecting...", fg="yellow") + try: + self.sqlexecute.connect() + except OperationalError as e: + self.logger.debug("Reconnect failed. e: %r", e) + self.echo(str(e), err=True, fg="red") + return False + self.logger.debug("Reconnected successfully.") + self.echo("Reconnected successfully.\n", fg="yellow") + if database and self.sqlexecute.dbname != database: + for result in self.change_db(database): + self.echo(result[3]) + elif database: + self.echo(f'You are already connected to database "{self.sqlexecute.dbname}" as user "{self.sqlexecute.user}"') + return True + def log_output(self, output: str) -> None: """Log the output in the audit log, if it's enabled.""" if isinstance(self.logfile, TextIOWrapper): diff --git a/test/features/fixture_data/help_commands.txt b/test/features/fixture_data/help_commands.txt index 9cb21324..c38aeb4e 100644 --- a/test/features/fixture_data/help_commands.txt +++ b/test/features/fixture_data/help_commands.txt @@ -13,15 +13,16 @@ | \once | \o [-o] filename | Append next result to an output file (overwrite using -o). | | \pipe_once | \| command | Send next result to a subprocess. | | \timing | \t | Toggle timing of commands. | -| connect | \r | Reconnect to the database. Optional database argument. | | delimiter | | Change SQL delimiter. | | exit | \q | Exit. | | help | \? | Show this help. | | nopager | \n | Disable pager, print to stdout. | | notee | notee | Stop writing results to an output file. | +| nowarnings | \w | Disable automatic warnings display. | | pager | \P [command] | Set PAGER. Print the query results via PAGER. | | prompt | \R | Change prompt format. | | quit | \q | Quit. | +| reconnect | \r | Reconnect to the database. Optional database argument. | | redirectformat | \Tr | Change the table format used to output redirected results. | | rehash | \# | Refresh auto-completions. | | source | \. filename | Execute commands from file. | @@ -30,5 +31,6 @@ | tableformat | \T | Change the table format used to output results. | | tee | tee [-o] filename | Append all results to an output file (overwrite using -o). | | use | \u | Change to a new database. | +| warnings | \W | Enable automatic warnings display. | | watch | watch [seconds] [-c] query | Executes the query every [seconds] seconds (by default 5). | +----------------+----------------------------+------------------------------------------------------------+ From 438fef4ea9f1d1ce25835407d97a7edef8ffe7d8 Mon Sep 17 00:00:00 2001 From: Scott Nemes Date: Mon, 22 Dec 2025 18:50:05 -0800 Subject: [PATCH 2/4] Moved reconnect logic to a separate function. Made a wrapper function for use by the command \r to call the new reconnect function. Updated help output in tests to match the change. --- mycli/main.py | 64 ++++++++++++-------- test/features/fixture_data/help_commands.txt | 4 +- 2 files changed, 43 insertions(+), 25 deletions(-) diff --git a/mycli/main.py b/mycli/main.py index 6f9965b5..2ff87b99 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -213,8 +213,8 @@ def close(self) -> None: def register_special_commands(self) -> None: special.register_special_command(self.change_db, "use", "\\u", "Change to a new database.", aliases=["\\u"]) special.register_special_command( - self.change_db, - "connect", + self.manual_reconnect, + "reconnect", "\\r", "Reconnect to the database. Optional database argument.", aliases=["\\r"], @@ -260,6 +260,14 @@ def register_special_commands(self) -> None: self.change_prompt_format, "prompt", "\\R", "Change prompt format.", aliases=["\\R"], case_sensitive=True ) + def manual_reconnect(self, arg: str = None, **_) -> Generator[tuple, None, None]: + """ + wrapper function to use for the \r command so that the real function + may be cleanly used elsewhere + """ + _ = self.reconnect(arg) + yield (None, None, None, None) + def enable_show_warnings(self, **_) -> Generator[tuple, None, None]: self.show_warnings = True msg = "Show warnings enabled." @@ -912,18 +920,11 @@ def one_iteration(text: str | None = None) -> None: special.unset_once_if_written(self.post_redirect_command) special.flush_pipe_once_if_written(self.post_redirect_command) except err.InterfaceError: - logger.debug("Attempting to reconnect.") - self.echo("Reconnecting...", fg="yellow") - try: - sqlexecute.connect() - logger.debug("Reconnected successfully.") - one_iteration(text) - return # OK to just return, cuz the recursion call runs to the end. - except OperationalError as e2: - logger.debug("Reconnect failed. e: %r", e2) - self.echo(str(e2), err=True, fg="red") - # If reconnection failed, don't proceed further. + # attempt to reconnect + if not self.reconnect(): return + one_iteration(text) + return # OK to just return, cuz the recursion call runs to the end. except EOFError as e: raise e except KeyboardInterrupt: @@ -957,18 +958,11 @@ def one_iteration(text: str | None = None) -> None: except OperationalError as e1: logger.debug("Exception: %r", e1) if e1.args[0] in (2003, 2006, 2013): - logger.debug("Attempting to reconnect.") - self.echo("Reconnecting...", fg="yellow") - try: - sqlexecute.connect() - logger.debug("Reconnected successfully.") - one_iteration(text) - return # OK to just return, cuz the recursion call runs to the end. - except OperationalError as e2: - logger.debug("Reconnect failed. e: %r", e2) - self.echo(str(e2), err=True, fg="red") - # If reconnection failed, don't proceed further. + # attempt to reconnect + if not self.reconnect(): return + one_iteration(text) + return # OK to just return, cuz the recursion call runs to the end. else: logger.error("sql: %r, error: %r", text, e1) logger.error("traceback: %r", traceback.format_exc()) @@ -1040,6 +1034,28 @@ def one_iteration(text: str | None = None) -> None: if not self.less_chatty: self.echo("Goodbye!") + def reconnect(self, database: str = None) -> bool: + """ + Attempt to reconnect to the database. Return True if successful, + False if unsuccessful. + """ + self.logger.debug("Attempting to reconnect.") + self.echo("Reconnecting...", fg="yellow") + try: + self.sqlexecute.connect() + except OperationalError as e: + self.logger.debug("Reconnect failed. e: %r", e) + self.echo(str(e), err=True, fg="red") + return False + self.logger.debug("Reconnected successfully.") + self.echo("Reconnected successfully.\n", fg="yellow") + if database and self.sqlexecute.dbname != database: + for result in self.change_db(database): + self.echo(result[3]) + elif database: + self.echo(f'You are already connected to database "{self.sqlexecute.dbname}" as user "{self.sqlexecute.user}"') + return True + def log_output(self, output: str) -> None: """Log the output in the audit log, if it's enabled.""" if isinstance(self.logfile, TextIOWrapper): diff --git a/test/features/fixture_data/help_commands.txt b/test/features/fixture_data/help_commands.txt index 9cb21324..c38aeb4e 100644 --- a/test/features/fixture_data/help_commands.txt +++ b/test/features/fixture_data/help_commands.txt @@ -13,15 +13,16 @@ | \once | \o [-o] filename | Append next result to an output file (overwrite using -o). | | \pipe_once | \| command | Send next result to a subprocess. | | \timing | \t | Toggle timing of commands. | -| connect | \r | Reconnect to the database. Optional database argument. | | delimiter | | Change SQL delimiter. | | exit | \q | Exit. | | help | \? | Show this help. | | nopager | \n | Disable pager, print to stdout. | | notee | notee | Stop writing results to an output file. | +| nowarnings | \w | Disable automatic warnings display. | | pager | \P [command] | Set PAGER. Print the query results via PAGER. | | prompt | \R | Change prompt format. | | quit | \q | Quit. | +| reconnect | \r | Reconnect to the database. Optional database argument. | | redirectformat | \Tr | Change the table format used to output redirected results. | | rehash | \# | Refresh auto-completions. | | source | \. filename | Execute commands from file. | @@ -30,5 +31,6 @@ | tableformat | \T | Change the table format used to output results. | | tee | tee [-o] filename | Append all results to an output file (overwrite using -o). | | use | \u | Change to a new database. | +| warnings | \W | Enable automatic warnings display. | | watch | watch [seconds] [-c] query | Executes the query every [seconds] seconds (by default 5). | +----------------+----------------------------+------------------------------------------------------------+ From 7b6b4312e547c6dedd80ae26cd78ec9fc19f37f1 Mon Sep 17 00:00:00 2001 From: Scott Nemes Date: Mon, 22 Dec 2025 20:23:58 -0800 Subject: [PATCH 3/4] Added reconnect related tests. Fixed some items to make mypy happy. --- mycli/main.py | 7 ++++--- test/test_main.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/mycli/main.py b/mycli/main.py index 2ff87b99..0710150e 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -260,12 +260,12 @@ def register_special_commands(self) -> None: self.change_prompt_format, "prompt", "\\R", "Change prompt format.", aliases=["\\R"], case_sensitive=True ) - def manual_reconnect(self, arg: str = None, **_) -> Generator[tuple, None, None]: + def manual_reconnect(self, arg: str = "", **_) -> Generator[tuple, None, None]: """ wrapper function to use for the \r command so that the real function may be cleanly used elsewhere """ - _ = self.reconnect(arg) + self.reconnect(arg) yield (None, None, None, None) def enable_show_warnings(self, **_) -> Generator[tuple, None, None]: @@ -1034,11 +1034,12 @@ def one_iteration(text: str | None = None) -> None: if not self.less_chatty: self.echo("Goodbye!") - def reconnect(self, database: str = None) -> bool: + def reconnect(self, database: str = "") -> bool: """ Attempt to reconnect to the database. Return True if successful, False if unsuccessful. """ + assert self.sqlexecute is not None self.logger.debug("Attempting to reconnect.") self.echo("Reconnecting...", fg="yellow") try: diff --git a/test/test_main.py b/test/test_main.py index 34cbde66..aea56481 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -37,6 +37,35 @@ ] +@dbtest +def test_reconnect_no_database(executor): + runner = CliRunner() + sql = "\\r" + result = runner.invoke(cli, args=CLI_ARGS, input=sql) + expected = "Reconnecting...\nReconnected successfully.\n\n" + assert expected in result.output + + +@dbtest +def test_reconnect_with_different_database(executor): + runner = CliRunner() + database = "mysql" + sql = f"\\r {database}" + result = runner.invoke(cli, args=CLI_ARGS, input=sql) + expected = f'Reconnecting...\nReconnected successfully.\n\nYou are now connected to database "{database}" as user "{USER}"\n' + assert expected in result.output + + +@dbtest +def test_reconnect_with_same_database(executor): + runner = CliRunner() + database = "mysql" + sql = f"\\u {database}; \\r {database}" + result = runner.invoke(cli, args=CLI_ARGS, input=sql) + expected = f'Reconnecting...\nReconnected successfully.\n\nYou are already connected to database "{database}" as user "{USER}"\n' + assert expected in result.output + + @dbtest def test_prompt_no_host_only_socket(executor): mycli = MyCli() From 9f3893b9dfa31e58386e92ae3fd8abc7c8777eb2 Mon Sep 17 00:00:00 2001 From: Scott Nemes Date: Mon, 22 Dec 2025 20:31:30 -0800 Subject: [PATCH 4/4] Updated changelog --- changelog.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/changelog.md b/changelog.md index f2132346..3bec270a 100644 --- a/changelog.md +++ b/changelog.md @@ -1,6 +1,10 @@ -1.42.0 (2025/12/20) +Upcoming (TBD) ============== +Features +-------- +* Rework reconnect logic to actually create a new connection instead of simply changing the database (#746) + Bug Fixes -------- * Update the prompt display logic to handle an edge case where a socket is used without