diff --git a/changelog.md b/changelog.md index d82416e0..29776b9a 100644 --- a/changelog.md +++ b/changelog.md @@ -4,7 +4,7 @@ Upcoming (TBD) Features -------- * Update query processing functions to allow automatic show_warnings to work for more code paths like DDL. -* Rework reconnect logic to actually create a new connection instead of simply changing the database (#746). +* Rework reconnect logic to actually reconnect or create a new connection instead of simply changing the database (#746). Bug Fixes diff --git a/mycli/main.py b/mycli/main.py index d30f286b..d062e05a 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -262,11 +262,15 @@ def register_special_commands(self) -> None: def manual_reconnect(self, arg: str = "", **_) -> Generator[tuple, None, None]: """ - wrapper function to use for the \r command so that the real function - may be cleanly used elsewhere + Interactive method to use for the \r command, so that the utility method + may be cleanly used elsewhere. """ - self.reconnect(arg) - yield (None, None, None, None) + if not self.reconnect(database=arg): + yield (None, None, None, "Not connected") + elif not arg or arg == '``': + yield (None, None, None, None) + else: + yield self.change_db(arg).send(None) def enable_show_warnings(self, **_) -> Generator[tuple, None, None]: self.show_warnings = True @@ -308,13 +312,18 @@ def change_db(self, arg: str, **_) -> Generator[tuple, None, None]: return assert isinstance(self.sqlexecute, SQLExecute) - self.sqlexecute.change_db(arg) + + if self.sqlexecute.dbname == arg: + msg = f'You are already connected to database "{self.sqlexecute.dbname}" as user "{self.sqlexecute.user}"' + else: + self.sqlexecute.change_db(arg) + msg = f'You are now connected to database "{self.sqlexecute.dbname}" as user "{self.sqlexecute.user}"' yield ( None, None, None, - f'You are now connected to database "{self.sqlexecute.dbname}" as user "{self.sqlexecute.user}"', + msg, ) def execute_from_file(self, arg: str, **_) -> Iterable[tuple]: @@ -1036,26 +1045,55 @@ def one_iteration(text: str | None = None) -> None: def reconnect(self, database: str = "") -> bool: """ - Attempt to reconnect to the database. Return True if successful, + Attempt to reconnect to the server. Return True if successful, False if unsuccessful. + + The "database" argument is used only to improve messages. """ assert self.sqlexecute is not None - self.logger.debug("Attempting to reconnect.") - self.echo("Reconnecting...", fg="yellow") + assert self.sqlexecute.conn is not None + + # First pass with ping(reconnect=False) and minimal feedback levels. This definitely + # works as expected, and is a good idea especially when "connect" was used as a + # synonym for "use". + try: + self.sqlexecute.conn.ping(reconnect=False) + if not database: + self.echo("Already connected.", fg="yellow") + return True + except pymysql.err.Error: + pass + + # Second pass with ping(reconnect=True). It is not demonstrated that this pass ever + # gives the benefit it is looking for, _ie_ preserves session state. We need to test + # this with connection pooling. + try: + old_connection_id = self.sqlexecute.connection_id + self.logger.debug("Attempting to reconnect.") + self.echo("Reconnecting...", fg="yellow") + self.sqlexecute.conn.ping(reconnect=True) + self.logger.debug("Reconnected successfully.") + self.echo("Reconnected successfully.", fg="yellow") + self.sqlexecute.reset_connection_id() + if old_connection_id != self.sqlexecute.connection_id: + self.echo("Any session state was reset.", fg="red") + return True + except pymysql.err.Error: + pass + + # Third pass with sqlexecute.connect() should always work, but always resets session state. try: + self.logger.debug("Creating new connection") + self.echo("Creating new connection...", fg="yellow") self.sqlexecute.connect() + self.logger.debug("New connection created successfully.") + self.echo("New connection created successfully.", fg="yellow") + self.echo("Any session state was reset.", fg="red") + return True except pymysql.OperationalError as e: self.logger.debug("Reconnect failed. e: %r", e) self.echo(str(e), err=True, fg="red") return False - self.logger.debug("Reconnected successfully.") - self.echo("Reconnected successfully.\n", fg="yellow") - if database and self.sqlexecute.dbname != database: - for result in self.change_db(database): - self.echo(result[3]) - elif database: - self.echo(f'You are already connected to database "{self.sqlexecute.dbname}" as user "{self.sqlexecute.user}"') - return True def log_output(self, output: str) -> None: """Log the output in the audit log, if it's enabled.""" diff --git a/test/features/steps/crud_database.py b/test/features/steps/crud_database.py index 0e1726f5..01f36db1 100644 --- a/test/features/steps/crud_database.py +++ b/test/features/steps/crud_database.py @@ -108,6 +108,6 @@ def step_see_db_dropped_no_default(context): @then("we see database connected") def step_see_db_connected(context): """Wait to see drop database output.""" - wrappers.expect_exact(context, 'You are now connected to database "', timeout=2) + wrappers.expect_exact(context, 'connected to database "', timeout=2) wrappers.expect_exact(context, '"', timeout=2) wrappers.expect_exact(context, f' as user "{context.conf["user"]}"', timeout=2) diff --git a/test/test_main.py b/test/test_main.py index 565b61fa..04ac5c18 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -10,6 +10,7 @@ from click.testing import CliRunner from mycli.main import MyCli, cli, thanks_picker +import mycli.packages.special from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS from mycli.sqlexecute import ServerInfo, SQLExecute from test.utils import HOST, PASSWORD, PORT, USER, dbtest, run @@ -38,32 +39,92 @@ @dbtest -def test_reconnect_no_database(executor): - runner = CliRunner() +def test_reconnect_no_database(executor, capsys): + m = MyCli() + m.register_special_commands() + m.sqlexecute = SQLExecute( + None, + USER, + PASSWORD, + HOST, + PORT, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) sql = "\\r" - result = runner.invoke(cli, args=CLI_ARGS, input=sql) - expected = "Reconnecting...\nReconnected successfully.\n\n" - assert expected in result.output + result = next(mycli.packages.special.execute(executor, sql)) + stdout, _stderr = capsys.readouterr() + assert result[-1] is None + assert "Already connected" in stdout @dbtest def test_reconnect_with_different_database(executor): - runner = CliRunner() - database = "mysql" - sql = f"\\r {database}" - result = runner.invoke(cli, args=CLI_ARGS, input=sql) - expected = f'Reconnecting...\nReconnected successfully.\n\nYou are now connected to database "{database}" as user "{USER}"\n' - assert expected in result.output + m = MyCli() + m.register_special_commands() + m.sqlexecute = SQLExecute( + None, + USER, + PASSWORD, + HOST, + PORT, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + database_1 = "mycli_test_db" + database_2 = "mysql" + sql_1 = f"use {database_1}" + sql_2 = f"\\r {database_2}" + _result_1 = next(mycli.packages.special.execute(executor, sql_1)) + result_2 = next(mycli.packages.special.execute(executor, sql_2)) + expected = f'You are now connected to database "{database_2}" as user "{USER}"' + assert expected in result_2[-1] @dbtest def test_reconnect_with_same_database(executor): - runner = CliRunner() + m = MyCli() + m.register_special_commands() + m.sqlexecute = SQLExecute( + None, + USER, + PASSWORD, + HOST, + PORT, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) database = "mysql" - sql = f"\\u {database}; \\r {database}" - result = runner.invoke(cli, args=CLI_ARGS, input=sql) - expected = f'Reconnecting...\nReconnected successfully.\n\nYou are already connected to database "{database}" as user "{USER}"\n' - assert expected in result.output + sql = f"\\u {database}" + result = next(mycli.packages.special.execute(executor, sql)) + sql = f"\\r {database}" + result = next(mycli.packages.special.execute(executor, sql)) + expected = f'You are already connected to database "{database}" as user "{USER}"' + assert expected in result[-1] @dbtest