Skip to content

Commit 8e87ae9

Browse files
committed
let reconnect preserve session state
Followups to reconnect() refactor: * Before attempting sqlexecute.connect(), try ping(reconnect=True) to do a true reconnect, preserving the connection_id() and other state such as session variables. This is the important part, which the commentary calls the "second pass". * Also, before attempting ping(reconnect=True), try ping(reconnect=False) with fewer feedback messages, which the commentary calls the "first pass". This pass is helpful to keep chatter down when the user habitually chooses the "connect" verb over "use". * Add new explicit feedback around creating a new connection when doing so, including a red tip to the user that session state was lost. * Tweak docstring in manual_reconnect() eg: "real function" -> "utility method" * Move db-change logic out of utility method, into manual_reconnect() and change_db(), keeping the "database" optional argument, as it is still useful for finessing feedback messages. * Silently skip changing the database if it equals "``". * In the usual case, let manual_reconnect() yield the result of change_db(), leaving us directly hooked in to the 4-tuple return- value system (instead of iterating on change_db() internally and manually handling the echo()). * Add an assert on self.sqlexecute.conn before pinging it. * Clarify "database" vs "server" in reconnect() docstring. (Pedantically it could be "cluster" or "endpoint"). * Update changelog, but just piggyback two words onto the previous entry. * Update tests to use mycli.packages.special.execute() rather than CliRunner(). CliRunner() is only capable of testing the first line of output, which is taken up by an initialization statement.
1 parent be626f8 commit 8e87ae9

File tree

4 files changed

+122
-33
lines changed

4 files changed

+122
-33
lines changed

changelog.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Upcoming (TBD)
44
Features
55
--------
66
* Update query processing functions to allow automatic show_warnings to work for more code paths like DDL.
7-
* Rework reconnect logic to actually create a new connection instead of simply changing the database (#746).
7+
* Rework reconnect logic to actually reconnect or create a new connection instead of simply changing the database (#746).
88

99
Bug Fixes
1010
--------

mycli/main.py

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -262,11 +262,15 @@ def register_special_commands(self) -> None:
262262

263263
def manual_reconnect(self, arg: str = "", **_) -> Generator[tuple, None, None]:
264264
"""
265-
wrapper function to use for the \r command so that the real function
266-
may be cleanly used elsewhere
265+
Interactive method to use for the \r command, so that the utility method
266+
may be cleanly used elsewhere.
267267
"""
268-
self.reconnect(arg)
269-
yield (None, None, None, None)
268+
if not self.reconnect(database=arg):
269+
yield (None, None, None, "Not connected")
270+
elif not arg or arg == '``':
271+
yield (None, None, None, None)
272+
else:
273+
yield self.change_db(arg).send(None)
270274

271275
def enable_show_warnings(self, **_) -> Generator[tuple, None, None]:
272276
self.show_warnings = True
@@ -308,13 +312,18 @@ def change_db(self, arg: str, **_) -> Generator[tuple, None, None]:
308312
return
309313

310314
assert isinstance(self.sqlexecute, SQLExecute)
311-
self.sqlexecute.change_db(arg)
315+
316+
if self.sqlexecute.dbname == arg:
317+
msg = f'You are already connected to database "{self.sqlexecute.dbname}" as user "{self.sqlexecute.user}"'
318+
else:
319+
self.sqlexecute.change_db(arg)
320+
msg = f'You are now connected to database "{self.sqlexecute.dbname}" as user "{self.sqlexecute.user}"'
312321

313322
yield (
314323
None,
315324
None,
316325
None,
317-
f'You are now connected to database "{self.sqlexecute.dbname}" as user "{self.sqlexecute.user}"',
326+
msg,
318327
)
319328

320329
def execute_from_file(self, arg: str, **_) -> Iterable[tuple]:
@@ -1036,26 +1045,48 @@ def one_iteration(text: str | None = None) -> None:
10361045

10371046
def reconnect(self, database: str = "") -> bool:
10381047
"""
1039-
Attempt to reconnect to the database. Return True if successful,
1048+
Attempt to reconnect to the server. Return True if successful,
10401049
False if unsuccessful.
1050+
1051+
The "database" argument is used only to improve messages.
10411052
"""
10421053
assert self.sqlexecute is not None
1043-
self.logger.debug("Attempting to reconnect.")
1044-
self.echo("Reconnecting...", fg="yellow")
1054+
assert self.sqlexecute.conn is not None
1055+
1056+
# First pass with ping(reconnect=False) is just to have minimal feedback levels,
1057+
# especially when "connect" was used as a synonym for "use".
1058+
try:
1059+
self.sqlexecute.conn.ping(reconnect=False)
1060+
if not database:
1061+
self.echo("Already connected.", fg="yellow")
1062+
return True
1063+
except err.Error:
1064+
pass
1065+
1066+
# Second pass with ping(reconnect=True) is to reconnect while preserving session state.
1067+
try:
1068+
self.logger.debug("Attempting to reconnect.")
1069+
self.echo("Reconnecting...", fg="yellow")
1070+
self.sqlexecute.conn.ping(reconnect=True)
1071+
self.logger.debug("Reconnected successfully.")
1072+
self.echo("Reconnected successfully.", fg="yellow")
1073+
return True
1074+
except err.Error:
1075+
pass
1076+
1077+
# Third pass with sqlexecute.connect() should always work, but also resets session state.
10451078
try:
1079+
self.logger.debug("Creating new connection")
1080+
self.echo("Creating new connection...", fg="yellow")
10461081
self.sqlexecute.connect()
1082+
self.logger.debug("New connection created successfully.")
1083+
self.echo("New connection created successfully.", fg="yellow")
1084+
self.echo("Any session state was reset.", fg="red")
1085+
return True
10471086
except OperationalError as e:
10481087
self.logger.debug("Reconnect failed. e: %r", e)
10491088
self.echo(str(e), err=True, fg="red")
10501089
return False
1051-
self.logger.debug("Reconnected successfully.")
1052-
self.echo("Reconnected successfully.\n", fg="yellow")
1053-
if database and self.sqlexecute.dbname != database:
1054-
for result in self.change_db(database):
1055-
self.echo(result[3])
1056-
elif database:
1057-
self.echo(f'You are already connected to database "{self.sqlexecute.dbname}" as user "{self.sqlexecute.user}"')
1058-
return True
10591090

10601091
def log_output(self, output: str) -> None:
10611092
"""Log the output in the audit log, if it's enabled."""

test/features/steps/crud_database.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,6 @@ def step_see_db_dropped_no_default(context):
108108
@then("we see database connected")
109109
def step_see_db_connected(context):
110110
"""Wait to see drop database output."""
111-
wrappers.expect_exact(context, 'You are now connected to database "', timeout=2)
111+
wrappers.expect_exact(context, 'connected to database "', timeout=2)
112112
wrappers.expect_exact(context, '"', timeout=2)
113113
wrappers.expect_exact(context, f' as user "{context.conf["user"]}"', timeout=2)

test/test_main.py

Lines changed: 72 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from click.testing import CliRunner
1111

1212
from mycli.main import MyCli, cli, thanks_picker
13+
import mycli.packages.special
1314
from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS
1415
from mycli.sqlexecute import ServerInfo, SQLExecute
1516
from test.utils import HOST, PASSWORD, PORT, USER, dbtest, run
@@ -38,32 +39,89 @@
3839

3940

4041
@dbtest
41-
def test_reconnect_no_database(executor):
42-
runner = CliRunner()
42+
def test_reconnect_no_database(executor, capsys):
43+
m = MyCli()
44+
m.register_special_commands()
45+
m.sqlexecute = SQLExecute(
46+
None,
47+
USER,
48+
PASSWORD,
49+
HOST,
50+
PORT,
51+
None,
52+
None,
53+
None,
54+
None,
55+
None,
56+
None,
57+
None,
58+
None,
59+
None,
60+
None,
61+
)
4362
sql = "\\r"
44-
result = runner.invoke(cli, args=CLI_ARGS, input=sql)
45-
expected = "Reconnecting...\nReconnected successfully.\n\n"
46-
assert expected in result.output
63+
result = next(mycli.packages.special.execute(executor, sql))
64+
stdout, _stderr = capsys.readouterr()
65+
assert result[-1] is None
66+
assert "Already connected" in stdout
4767

4868

4969
@dbtest
5070
def test_reconnect_with_different_database(executor):
51-
runner = CliRunner()
71+
m = MyCli()
72+
m.register_special_commands()
73+
m.sqlexecute = SQLExecute(
74+
None,
75+
USER,
76+
PASSWORD,
77+
HOST,
78+
PORT,
79+
None,
80+
None,
81+
None,
82+
None,
83+
None,
84+
None,
85+
None,
86+
None,
87+
None,
88+
None,
89+
)
5290
database = "mysql"
5391
sql = f"\\r {database}"
54-
result = runner.invoke(cli, args=CLI_ARGS, input=sql)
55-
expected = f'Reconnecting...\nReconnected successfully.\n\nYou are now connected to database "{database}" as user "{USER}"\n'
56-
assert expected in result.output
92+
result = next(mycli.packages.special.execute(executor, sql))
93+
expected = f'You are now connected to database "{database}" as user "{USER}"'
94+
assert expected in result[-1]
5795

5896

5997
@dbtest
6098
def test_reconnect_with_same_database(executor):
61-
runner = CliRunner()
99+
m = MyCli()
100+
m.register_special_commands()
101+
m.sqlexecute = SQLExecute(
102+
None,
103+
USER,
104+
PASSWORD,
105+
HOST,
106+
PORT,
107+
None,
108+
None,
109+
None,
110+
None,
111+
None,
112+
None,
113+
None,
114+
None,
115+
None,
116+
None,
117+
)
62118
database = "mysql"
63-
sql = f"\\u {database}; \\r {database}"
64-
result = runner.invoke(cli, args=CLI_ARGS, input=sql)
65-
expected = f'Reconnecting...\nReconnected successfully.\n\nYou are already connected to database "{database}" as user "{USER}"\n'
66-
assert expected in result.output
119+
sql = f"\\u {database}"
120+
result = next(mycli.packages.special.execute(executor, sql))
121+
sql = f"\\r {database}"
122+
result = next(mycli.packages.special.execute(executor, sql))
123+
expected = f'You are already connected to database "{database}" as user "{USER}"'
124+
assert expected in result[-1]
67125

68126

69127
@dbtest

0 commit comments

Comments
 (0)