Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ Upcoming (TBD)

Features
--------
* Update query processing functions to allow automatic show_warnings to work for more code paths like DDL
* Update query processing functions to allow automatic show_warnings to work for more code paths like DDL.
* Rework reconnect logic to actually create a new connection instead of simply changing the database (#746).

Bug Fixes
--------
Expand Down
65 changes: 41 additions & 24 deletions mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,8 @@ def close(self) -> None:
def register_special_commands(self) -> None:
special.register_special_command(self.change_db, "use", "\\u", "Change to a new database.", aliases=["\\u"])
special.register_special_command(
self.change_db,
"connect",
self.manual_reconnect,
"reconnect",
"\\r",
"Reconnect to the database. Optional database argument.",
aliases=["\\r"],
Expand Down Expand Up @@ -260,6 +260,14 @@ def register_special_commands(self) -> None:
self.change_prompt_format, "prompt", "\\R", "Change prompt format.", aliases=["\\R"], case_sensitive=True
)

def manual_reconnect(self, arg: str = "", **_) -> Generator[tuple, None, None]:
"""
wrapper function to use for the \r command so that the real function
may be cleanly used elsewhere
"""
self.reconnect(arg)
yield (None, None, None, None)

def enable_show_warnings(self, **_) -> Generator[tuple, None, None]:
self.show_warnings = True
msg = "Show warnings enabled."
Expand Down Expand Up @@ -912,18 +920,11 @@ def one_iteration(text: str | None = None) -> None:
special.unset_once_if_written(self.post_redirect_command)
special.flush_pipe_once_if_written(self.post_redirect_command)
except err.InterfaceError:
logger.debug("Attempting to reconnect.")
self.echo("Reconnecting...", fg="yellow")
try:
sqlexecute.connect()
logger.debug("Reconnected successfully.")
one_iteration(text)
return # OK to just return, cuz the recursion call runs to the end.
except OperationalError as e2:
logger.debug("Reconnect failed. e: %r", e2)
self.echo(str(e2), err=True, fg="red")
# If reconnection failed, don't proceed further.
# attempt to reconnect
if not self.reconnect():
return
one_iteration(text)
return # OK to just return, cuz the recursion call runs to the end.
except EOFError as e:
raise e
except KeyboardInterrupt:
Expand Down Expand Up @@ -957,18 +958,11 @@ def one_iteration(text: str | None = None) -> None:
except OperationalError as e1:
logger.debug("Exception: %r", e1)
if e1.args[0] in (2003, 2006, 2013):
logger.debug("Attempting to reconnect.")
self.echo("Reconnecting...", fg="yellow")
try:
sqlexecute.connect()
logger.debug("Reconnected successfully.")
one_iteration(text)
return # OK to just return, cuz the recursion call runs to the end.
except OperationalError as e2:
logger.debug("Reconnect failed. e: %r", e2)
self.echo(str(e2), err=True, fg="red")
# If reconnection failed, don't proceed further.
# attempt to reconnect
if not self.reconnect():
return
one_iteration(text)
return # OK to just return, cuz the recursion call runs to the end.
else:
logger.error("sql: %r, error: %r", text, e1)
logger.error("traceback: %r", traceback.format_exc())
Expand Down Expand Up @@ -1040,6 +1034,29 @@ def one_iteration(text: str | None = None) -> None:
if not self.less_chatty:
self.echo("Goodbye!")

def reconnect(self, database: str = "") -> bool:
"""
Attempt to reconnect to the database. Return True if successful,
False if unsuccessful.
"""
assert self.sqlexecute is not None
self.logger.debug("Attempting to reconnect.")
self.echo("Reconnecting...", fg="yellow")
try:
self.sqlexecute.connect()
except OperationalError as e:
self.logger.debug("Reconnect failed. e: %r", e)
self.echo(str(e), err=True, fg="red")
return False
self.logger.debug("Reconnected successfully.")
self.echo("Reconnected successfully.\n", fg="yellow")
if database and self.sqlexecute.dbname != database:
for result in self.change_db(database):
self.echo(result[3])
elif database:
self.echo(f'You are already connected to database "{self.sqlexecute.dbname}" as user "{self.sqlexecute.user}"')
return True

def log_output(self, output: str) -> None:
"""Log the output in the audit log, if it's enabled."""
if isinstance(self.logfile, TextIOWrapper):
Expand Down
4 changes: 3 additions & 1 deletion test/features/fixture_data/help_commands.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
| \once | \o [-o] filename | Append next result to an output file (overwrite using -o). |
| \pipe_once | \| command | Send next result to a subprocess. |
| \timing | \t | Toggle timing of commands. |
| connect | \r | Reconnect to the database. Optional database argument. |
| delimiter | <null> | Change SQL delimiter. |
| exit | \q | Exit. |
| help | \? | Show this help. |
| nopager | \n | Disable pager, print to stdout. |
| notee | notee | Stop writing results to an output file. |
| nowarnings | \w | Disable automatic warnings display. |
| pager | \P [command] | Set PAGER. Print the query results via PAGER. |
| prompt | \R | Change prompt format. |
| quit | \q | Quit. |
| reconnect | \r | Reconnect to the database. Optional database argument. |
| redirectformat | \Tr | Change the table format used to output redirected results. |
| rehash | \# | Refresh auto-completions. |
| source | \. filename | Execute commands from file. |
Expand All @@ -30,5 +31,6 @@
| tableformat | \T | Change the table format used to output results. |
| tee | tee [-o] filename | Append all results to an output file (overwrite using -o). |
| use | \u | Change to a new database. |
| warnings | \W | Enable automatic warnings display. |
| watch | watch [seconds] [-c] query | Executes the query every [seconds] seconds (by default 5). |
+----------------+----------------------------+------------------------------------------------------------+
29 changes: 29 additions & 0 deletions test/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,35 @@
]


@dbtest
def test_reconnect_no_database(executor):
runner = CliRunner()
sql = "\\r"
result = runner.invoke(cli, args=CLI_ARGS, input=sql)
expected = "Reconnecting...\nReconnected successfully.\n\n"
assert expected in result.output


@dbtest
def test_reconnect_with_different_database(executor):
runner = CliRunner()
database = "mysql"
sql = f"\\r {database}"
result = runner.invoke(cli, args=CLI_ARGS, input=sql)
expected = f'Reconnecting...\nReconnected successfully.\n\nYou are now connected to database "{database}" as user "{USER}"\n'
assert expected in result.output


@dbtest
def test_reconnect_with_same_database(executor):
runner = CliRunner()
database = "mysql"
sql = f"\\u {database}; \\r {database}"
result = runner.invoke(cli, args=CLI_ARGS, input=sql)
expected = f'Reconnecting...\nReconnected successfully.\n\nYou are already connected to database "{database}" as user "{USER}"\n'
assert expected in result.output


@dbtest
def test_prompt_no_host_only_socket(executor):
mycli = MyCli()
Expand Down