Skip to content

Commit be626f8

Browse files
authored
[feat] Rework reconnect logic to actually create a new connection instead of only changing the database (#746) (#1416)
* Moved reconnect logic to a separate function. Made a wrapper function for use by the command \r to call the new reconnect function. Updated help output in tests to match the change.
1 parent 3683b9f commit be626f8

File tree

4 files changed

+73
-24
lines changed

4 files changed

+73
-24
lines changed

changelog.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ Upcoming (TBD)
33

44
Features
55
--------
6-
* Update query processing functions to allow automatic show_warnings to work for more code paths like DDL
6+
* Update query processing functions to allow automatic show_warnings to work for more code paths like DDL.
7+
* Rework reconnect logic to actually create a new connection instead of simply changing the database (#746).
78

89
Bug Fixes
910
--------

mycli/main.py

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def close(self) -> None:
213213
def register_special_commands(self) -> None:
214214
special.register_special_command(self.change_db, "use", "\\u", "Change to a new database.", aliases=["\\u"])
215215
special.register_special_command(
216-
self.change_db,
216+
self.manual_reconnect,
217217
"connect",
218218
"\\r",
219219
"Reconnect to the database. Optional database argument.",
@@ -260,6 +260,14 @@ def register_special_commands(self) -> None:
260260
self.change_prompt_format, "prompt", "\\R", "Change prompt format.", aliases=["\\R"], case_sensitive=True
261261
)
262262

263+
def manual_reconnect(self, arg: str = "", **_) -> Generator[tuple, None, None]:
264+
"""
265+
wrapper function to use for the \r command so that the real function
266+
may be cleanly used elsewhere
267+
"""
268+
self.reconnect(arg)
269+
yield (None, None, None, None)
270+
263271
def enable_show_warnings(self, **_) -> Generator[tuple, None, None]:
264272
self.show_warnings = True
265273
msg = "Show warnings enabled."
@@ -912,18 +920,11 @@ def one_iteration(text: str | None = None) -> None:
912920
special.unset_once_if_written(self.post_redirect_command)
913921
special.flush_pipe_once_if_written(self.post_redirect_command)
914922
except err.InterfaceError:
915-
logger.debug("Attempting to reconnect.")
916-
self.echo("Reconnecting...", fg="yellow")
917-
try:
918-
sqlexecute.connect()
919-
logger.debug("Reconnected successfully.")
920-
one_iteration(text)
921-
return # OK to just return, cuz the recursion call runs to the end.
922-
except OperationalError as e2:
923-
logger.debug("Reconnect failed. e: %r", e2)
924-
self.echo(str(e2), err=True, fg="red")
925-
# If reconnection failed, don't proceed further.
923+
# attempt to reconnect
924+
if not self.reconnect():
926925
return
926+
one_iteration(text)
927+
return # OK to just return, cuz the recursion call runs to the end.
927928
except EOFError as e:
928929
raise e
929930
except KeyboardInterrupt:
@@ -957,18 +958,11 @@ def one_iteration(text: str | None = None) -> None:
957958
except OperationalError as e1:
958959
logger.debug("Exception: %r", e1)
959960
if e1.args[0] in (2003, 2006, 2013):
960-
logger.debug("Attempting to reconnect.")
961-
self.echo("Reconnecting...", fg="yellow")
962-
try:
963-
sqlexecute.connect()
964-
logger.debug("Reconnected successfully.")
965-
one_iteration(text)
966-
return # OK to just return, cuz the recursion call runs to the end.
967-
except OperationalError as e2:
968-
logger.debug("Reconnect failed. e: %r", e2)
969-
self.echo(str(e2), err=True, fg="red")
970-
# If reconnection failed, don't proceed further.
961+
# attempt to reconnect
962+
if not self.reconnect():
971963
return
964+
one_iteration(text)
965+
return # OK to just return, cuz the recursion call runs to the end.
972966
else:
973967
logger.error("sql: %r, error: %r", text, e1)
974968
logger.error("traceback: %r", traceback.format_exc())
@@ -1040,6 +1034,29 @@ def one_iteration(text: str | None = None) -> None:
10401034
if not self.less_chatty:
10411035
self.echo("Goodbye!")
10421036

1037+
def reconnect(self, database: str = "") -> bool:
1038+
"""
1039+
Attempt to reconnect to the database. Return True if successful,
1040+
False if unsuccessful.
1041+
"""
1042+
assert self.sqlexecute is not None
1043+
self.logger.debug("Attempting to reconnect.")
1044+
self.echo("Reconnecting...", fg="yellow")
1045+
try:
1046+
self.sqlexecute.connect()
1047+
except OperationalError as e:
1048+
self.logger.debug("Reconnect failed. e: %r", e)
1049+
self.echo(str(e), err=True, fg="red")
1050+
return False
1051+
self.logger.debug("Reconnected successfully.")
1052+
self.echo("Reconnected successfully.\n", fg="yellow")
1053+
if database and self.sqlexecute.dbname != database:
1054+
for result in self.change_db(database):
1055+
self.echo(result[3])
1056+
elif database:
1057+
self.echo(f'You are already connected to database "{self.sqlexecute.dbname}" as user "{self.sqlexecute.user}"')
1058+
return True
1059+
10431060
def log_output(self, output: str) -> None:
10441061
"""Log the output in the audit log, if it's enabled."""
10451062
if isinstance(self.logfile, TextIOWrapper):

test/features/fixture_data/help_commands.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
| help | \? | Show this help. |
2020
| nopager | \n | Disable pager, print to stdout. |
2121
| notee | notee | Stop writing results to an output file. |
22+
| nowarnings | \w | Disable automatic warnings display. |
2223
| pager | \P [command] | Set PAGER. Print the query results via PAGER. |
2324
| prompt | \R | Change prompt format. |
2425
| quit | \q | Quit. |
@@ -30,5 +31,6 @@
3031
| tableformat | \T | Change the table format used to output results. |
3132
| tee | tee [-o] filename | Append all results to an output file (overwrite using -o). |
3233
| use | \u | Change to a new database. |
34+
| warnings | \W | Enable automatic warnings display. |
3335
| watch | watch [seconds] [-c] query | Executes the query every [seconds] seconds (by default 5). |
3436
+----------------+----------------------------+------------------------------------------------------------+

test/test_main.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,35 @@
3737
]
3838

3939

40+
@dbtest
41+
def test_reconnect_no_database(executor):
42+
runner = CliRunner()
43+
sql = "\\r"
44+
result = runner.invoke(cli, args=CLI_ARGS, input=sql)
45+
expected = "Reconnecting...\nReconnected successfully.\n\n"
46+
assert expected in result.output
47+
48+
49+
@dbtest
50+
def test_reconnect_with_different_database(executor):
51+
runner = CliRunner()
52+
database = "mysql"
53+
sql = f"\\r {database}"
54+
result = runner.invoke(cli, args=CLI_ARGS, input=sql)
55+
expected = f'Reconnecting...\nReconnected successfully.\n\nYou are now connected to database "{database}" as user "{USER}"\n'
56+
assert expected in result.output
57+
58+
59+
@dbtest
60+
def test_reconnect_with_same_database(executor):
61+
runner = CliRunner()
62+
database = "mysql"
63+
sql = f"\\u {database}; \\r {database}"
64+
result = runner.invoke(cli, args=CLI_ARGS, input=sql)
65+
expected = f'Reconnecting...\nReconnected successfully.\n\nYou are already connected to database "{database}" as user "{USER}"\n'
66+
assert expected in result.output
67+
68+
4069
@dbtest
4170
def test_prompt_no_host_only_socket(executor):
4271
mycli = MyCli()

0 commit comments

Comments
 (0)