Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Upcoming (TBD)
Features
--------
* Update query processing functions to allow automatic show_warnings to work for more code paths like DDL.
* Rework reconnect logic to actually create a new connection instead of simply changing the database (#746).
* Rework reconnect logic to actually reconnect or create a new connection instead of simply changing the database (#746).

Bug Fixes
--------
Expand Down
69 changes: 52 additions & 17 deletions mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,11 +262,15 @@ def register_special_commands(self) -> None:

def manual_reconnect(self, arg: str = "", **_) -> Generator[tuple, None, None]:
"""
wrapper function to use for the \r command so that the real function
may be cleanly used elsewhere
Interactive method to use for the \r command, so that the utility method
may be cleanly used elsewhere.
"""
self.reconnect(arg)
yield (None, None, None, None)
if not self.reconnect(database=arg):
yield (None, None, None, "Not connected")
elif not arg or arg == '``':
yield (None, None, None, None)
else:
yield self.change_db(arg).send(None)

def enable_show_warnings(self, **_) -> Generator[tuple, None, None]:
self.show_warnings = True
Expand Down Expand Up @@ -308,13 +312,18 @@ def change_db(self, arg: str, **_) -> Generator[tuple, None, None]:
return

assert isinstance(self.sqlexecute, SQLExecute)
self.sqlexecute.change_db(arg)

if self.sqlexecute.dbname == arg:
msg = f'You are already connected to database "{self.sqlexecute.dbname}" as user "{self.sqlexecute.user}"'
else:
self.sqlexecute.change_db(arg)
msg = f'You are now connected to database "{self.sqlexecute.dbname}" as user "{self.sqlexecute.user}"'

yield (
None,
None,
None,
f'You are now connected to database "{self.sqlexecute.dbname}" as user "{self.sqlexecute.user}"',
msg,
)

def execute_from_file(self, arg: str, **_) -> Iterable[tuple]:
Expand Down Expand Up @@ -1036,26 +1045,52 @@ def one_iteration(text: str | None = None) -> None:

def reconnect(self, database: str = "") -> bool:
"""
Attempt to reconnect to the database. Return True if successful,
Attempt to reconnect to the server. Return True if successful,
False if unsuccessful.

The "database" argument is used only to improve messages.
"""
assert self.sqlexecute is not None
self.logger.debug("Attempting to reconnect.")
self.echo("Reconnecting...", fg="yellow")
assert self.sqlexecute.conn is not None

# First pass with ping(reconnect=False) is just to have minimal feedback levels,
# especially when "connect" was used as a synonym for "use".
try:
self.sqlexecute.conn.ping(reconnect=False)
if not database:
self.echo("Already connected.", fg="yellow")
return True
except err.Error:
pass

# Second pass with ping(reconnect=True) might reconnect while preserving session state.
try:
old_connection_id = self.sqlexecute.connection_id
self.logger.debug("Attempting to reconnect.")
self.echo("Reconnecting...", fg="yellow")
self.sqlexecute.conn.ping(reconnect=True)
self.logger.debug("Reconnected successfully.")
self.echo("Reconnected successfully.", fg="yellow")
self.sqlexecute.reset_connection_id()
if old_connection_id != self.sqlexecute.connection_id:
self.echo("Any session state was reset.", fg="red")
return True
except err.Error:
pass

# Third pass with sqlexecute.connect() should always work, but always resets session state.
try:
self.logger.debug("Creating new connection")
self.echo("Creating new connection...", fg="yellow")
self.sqlexecute.connect()
self.logger.debug("New connection created successfully.")
self.echo("New connection created successfully.", fg="yellow")
self.echo("Any session state was reset.", fg="red")
return True
except OperationalError as e:
self.logger.debug("Reconnect failed. e: %r", e)
self.echo(str(e), err=True, fg="red")
return False
self.logger.debug("Reconnected successfully.")
self.echo("Reconnected successfully.\n", fg="yellow")
if database and self.sqlexecute.dbname != database:
for result in self.change_db(database):
self.echo(result[3])
elif database:
self.echo(f'You are already connected to database "{self.sqlexecute.dbname}" as user "{self.sqlexecute.user}"')
return True

def log_output(self, output: str) -> None:
"""Log the output in the audit log, if it's enabled."""
Expand Down
2 changes: 1 addition & 1 deletion test/features/steps/crud_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,6 @@ def step_see_db_dropped_no_default(context):
@then("we see database connected")
def step_see_db_connected(context):
"""Wait to see drop database output."""
wrappers.expect_exact(context, 'You are now connected to database "', timeout=2)
wrappers.expect_exact(context, 'connected to database "', timeout=2)
wrappers.expect_exact(context, '"', timeout=2)
wrappers.expect_exact(context, f' as user "{context.conf["user"]}"', timeout=2)
86 changes: 72 additions & 14 deletions test/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from click.testing import CliRunner

from mycli.main import MyCli, cli, thanks_picker
import mycli.packages.special
from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS
from mycli.sqlexecute import ServerInfo, SQLExecute
from test.utils import HOST, PASSWORD, PORT, USER, dbtest, run
Expand Down Expand Up @@ -38,32 +39,89 @@


@dbtest
def test_reconnect_no_database(executor):
runner = CliRunner()
def test_reconnect_no_database(executor, capsys):
m = MyCli()
m.register_special_commands()
m.sqlexecute = SQLExecute(
None,
USER,
PASSWORD,
HOST,
PORT,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
sql = "\\r"
result = runner.invoke(cli, args=CLI_ARGS, input=sql)
expected = "Reconnecting...\nReconnected successfully.\n\n"
assert expected in result.output
result = next(mycli.packages.special.execute(executor, sql))
stdout, _stderr = capsys.readouterr()
assert result[-1] is None
assert "Already connected" in stdout


@dbtest
def test_reconnect_with_different_database(executor):
runner = CliRunner()
m = MyCli()
m.register_special_commands()
m.sqlexecute = SQLExecute(
None,
USER,
PASSWORD,
HOST,
PORT,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
database = "mysql"
sql = f"\\r {database}"
result = runner.invoke(cli, args=CLI_ARGS, input=sql)
expected = f'Reconnecting...\nReconnected successfully.\n\nYou are now connected to database "{database}" as user "{USER}"\n'
assert expected in result.output
result = next(mycli.packages.special.execute(executor, sql))
expected = f'You are now connected to database "{database}" as user "{USER}"'
assert expected in result[-1]


@dbtest
def test_reconnect_with_same_database(executor):
runner = CliRunner()
m = MyCli()
m.register_special_commands()
m.sqlexecute = SQLExecute(
None,
USER,
PASSWORD,
HOST,
PORT,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
database = "mysql"
sql = f"\\u {database}; \\r {database}"
result = runner.invoke(cli, args=CLI_ARGS, input=sql)
expected = f'Reconnecting...\nReconnected successfully.\n\nYou are already connected to database "{database}" as user "{USER}"\n'
assert expected in result.output
sql = f"\\u {database}"
result = next(mycli.packages.special.execute(executor, sql))
sql = f"\\r {database}"
result = next(mycli.packages.special.execute(executor, sql))
expected = f'You are already connected to database "{database}" as user "{USER}"'
assert expected in result[-1]


@dbtest
Expand Down