@@ -213,7 +213,7 @@ def close(self) -> None:
213213 def register_special_commands (self ) -> None :
214214 special .register_special_command (self .change_db , "use" , "\\ u" , "Change to a new database." , aliases = ["\\ u" ])
215215 special .register_special_command (
216- self .change_db ,
216+ self .manual_reconnect ,
217217 "connect" ,
218218 "\\ r" ,
219219 "Reconnect to the database. Optional database argument." ,
@@ -260,6 +260,14 @@ def register_special_commands(self) -> None:
260260 self .change_prompt_format , "prompt" , "\\ R" , "Change prompt format." , aliases = ["\\ R" ], case_sensitive = True
261261 )
262262
263+ def manual_reconnect (self , arg : str = "" , ** _ ) -> Generator [tuple , None , None ]:
264+ """
265+ wrapper function to use for the \r command so that the real function
266+ may be cleanly used elsewhere
267+ """
268+ self .reconnect (arg )
269+ yield (None , None , None , None )
270+
263271 def enable_show_warnings (self , ** _ ) -> Generator [tuple , None , None ]:
264272 self .show_warnings = True
265273 msg = "Show warnings enabled."
@@ -912,18 +920,11 @@ def one_iteration(text: str | None = None) -> None:
912920 special .unset_once_if_written (self .post_redirect_command )
913921 special .flush_pipe_once_if_written (self .post_redirect_command )
914922 except err .InterfaceError :
915- logger .debug ("Attempting to reconnect." )
916- self .echo ("Reconnecting..." , fg = "yellow" )
917- try :
918- sqlexecute .connect ()
919- logger .debug ("Reconnected successfully." )
920- one_iteration (text )
921- return # OK to just return, cuz the recursion call runs to the end.
922- except OperationalError as e2 :
923- logger .debug ("Reconnect failed. e: %r" , e2 )
924- self .echo (str (e2 ), err = True , fg = "red" )
925- # If reconnection failed, don't proceed further.
923+ # attempt to reconnect
924+ if not self .reconnect ():
926925 return
926+ one_iteration (text )
927+ return # OK to just return, cuz the recursion call runs to the end.
927928 except EOFError as e :
928929 raise e
929930 except KeyboardInterrupt :
@@ -957,18 +958,11 @@ def one_iteration(text: str | None = None) -> None:
957958 except OperationalError as e1 :
958959 logger .debug ("Exception: %r" , e1 )
959960 if e1 .args [0 ] in (2003 , 2006 , 2013 ):
960- logger .debug ("Attempting to reconnect." )
961- self .echo ("Reconnecting..." , fg = "yellow" )
962- try :
963- sqlexecute .connect ()
964- logger .debug ("Reconnected successfully." )
965- one_iteration (text )
966- return # OK to just return, cuz the recursion call runs to the end.
967- except OperationalError as e2 :
968- logger .debug ("Reconnect failed. e: %r" , e2 )
969- self .echo (str (e2 ), err = True , fg = "red" )
970- # If reconnection failed, don't proceed further.
961+ # attempt to reconnect
962+ if not self .reconnect ():
971963 return
964+ one_iteration (text )
965+ return # OK to just return, cuz the recursion call runs to the end.
972966 else :
973967 logger .error ("sql: %r, error: %r" , text , e1 )
974968 logger .error ("traceback: %r" , traceback .format_exc ())
@@ -1040,6 +1034,29 @@ def one_iteration(text: str | None = None) -> None:
10401034 if not self .less_chatty :
10411035 self .echo ("Goodbye!" )
10421036
1037+ def reconnect (self , database : str = "" ) -> bool :
1038+ """
1039+ Attempt to reconnect to the database. Return True if successful,
1040+ False if unsuccessful.
1041+ """
1042+ assert self .sqlexecute is not None
1043+ self .logger .debug ("Attempting to reconnect." )
1044+ self .echo ("Reconnecting..." , fg = "yellow" )
1045+ try :
1046+ self .sqlexecute .connect ()
1047+ except OperationalError as e :
1048+ self .logger .debug ("Reconnect failed. e: %r" , e )
1049+ self .echo (str (e ), err = True , fg = "red" )
1050+ return False
1051+ self .logger .debug ("Reconnected successfully." )
1052+ self .echo ("Reconnected successfully.\n " , fg = "yellow" )
1053+ if database and self .sqlexecute .dbname != database :
1054+ for result in self .change_db (database ):
1055+ self .echo (result [3 ])
1056+ elif database :
1057+ self .echo (f'You are already connected to database "{ self .sqlexecute .dbname } " as user "{ self .sqlexecute .user } "' )
1058+ return True
1059+
10431060 def log_output (self , output : str ) -> None :
10441061 """Log the output in the audit log, if it's enabled."""
10451062 if isinstance (self .logfile , TextIOWrapper ):
0 commit comments