Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ repos:
# Run the linter.
- id: ruff
args: [ --fix ]
# Run the formatter. TODO: uncomment when the rest of the code is ruff-formatted
# - id: ruff-format
# Run the formatter.
- id: ruff-format
1 change: 1 addition & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ get this running in a development setup.
https://github.com/dbcli/pgcli/blob/main/CONTRIBUTING.rst

Please feel free to reach out to us if you need help.

* Amjith, pgcli author: amjith.r@gmail.com, Twitter: `@amjithr <http://twitter.com/amjithr>`_
* Irina, pgcli maintainer: i.chernyavska@gmail.com, Twitter: `@irinatruong <http://twitter.com/irinatruong>`_

Expand Down
8 changes: 2 additions & 6 deletions pgcli/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ def keyring_initialize(keyring_enabled, *, logger):

try:
keyring = importlib.import_module("keyring")
except (
ModuleNotFoundError
) as e: # ImportError for Python 2, ModuleNotFoundError for Python 3
except ModuleNotFoundError as e: # ImportError for Python 2, ModuleNotFoundError for Python 3
logger.warning("import keyring failed: %r.", e)


Expand All @@ -40,9 +38,7 @@ def keyring_get_password(key):
passwd = keyring.get_password("pgcli", key) or ""
except Exception as e:
click.secho(
keyring_error_message.format(
"Load your password from keyring returned:", str(e)
),
keyring_error_message.format("Load your password from keyring returned:", str(e)),
err=True,
fg="red",
)
Expand Down
8 changes: 2 additions & 6 deletions pgcli/completion_refresher.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,14 @@ def refresh(self, executor, special, callbacks, history=None, settings=None):
)
self._completer_thread.daemon = True
self._completer_thread.start()
return [
(None, None, None, "Auto-completion refresh started in the background.")
]
return [(None, None, None, "Auto-completion refresh started in the background.")]

def is_refreshing(self):
return self._completer_thread and self._completer_thread.is_alive()

def _bg_refresh(self, pgexecute, special, callbacks, history=None, settings=None):
settings = settings or {}
completer = PGCompleter(
smart_completion=True, pgspecial=special, settings=settings
)
completer = PGCompleter(smart_completion=True, pgspecial=special, settings=settings)

if settings.get("single_connection"):
executor = pgexecute
Expand Down
3 changes: 1 addition & 2 deletions pgcli/key_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@ def _(event):
# history search, and one of several conditions are True
@kb.add(
"enter",
filter=~(completion_is_selected | is_searching)
& buffer_should_be_handled(pgcli),
filter=~(completion_is_selected | is_searching) & buffer_should_be_handled(pgcli),
)
def _(event):
_logger.debug("Detected enter key.")
Expand Down
13 changes: 3 additions & 10 deletions pgcli/packages/formatter/sqlformatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,16 @@ def adapter(data, headers, table_format=None, **kwargs):
yield 'UPDATE "{}" SET'.format(table_name)
prefix = " "
for i, v in enumerate(d[keys:], keys):
yield '{}"{}" = {}'.format(
prefix, headers[i], escape_for_sql_statement(v)
)
yield '{}"{}" = {}'.format(prefix, headers[i], escape_for_sql_statement(v))
if prefix == " ":
prefix = ", "
f = '"{}" = {}'
where = (
f.format(headers[i], escape_for_sql_statement(d[i]))
for i in range(keys)
)
where = (f.format(headers[i], escape_for_sql_statement(d[i])) for i in range(keys))
yield "WHERE {};".format(" AND ".join(where))


def register_new_formatter(TabularOutputFormatter):
global formatter
formatter = TabularOutputFormatter
for sql_format in supported_formats:
TabularOutputFormatter.register_new_formatter(
sql_format, adapter, preprocessors, {"table_format": sql_format}
)
TabularOutputFormatter.register_new_formatter(sql_format, adapter, preprocessors, {"table_format": sql_format})
4 changes: 1 addition & 3 deletions pgcli/packages/parseutils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ def is_destructive(queries, keywords):
for query in sqlparse.split(queries):
if query:
formatted_sql = sqlparse.format(query.lower(), strip_comments=True).strip()
if "unconditional_update" in keywords and query_is_unconditional_update(
formatted_sql
):
if "unconditional_update" in keywords and query_is_unconditional_update(formatted_sql):
return True
if query_starts_with(formatted_sql, keywords):
return True
Expand Down
10 changes: 2 additions & 8 deletions pgcli/packages/parseutils/meta.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from collections import namedtuple

_ColumnMetadata = namedtuple(
"ColumnMetadata", ["name", "datatype", "foreignkeys", "default", "has_default"]
)
_ColumnMetadata = namedtuple("ColumnMetadata", ["name", "datatype", "foreignkeys", "default", "has_default"])


def ColumnMetadata(name, datatype, foreignkeys=None, default=None, has_default=False):
Expand Down Expand Up @@ -143,11 +141,7 @@ def arg(name, typ, num):
num_args = len(args)
num_defaults = len(self.arg_defaults)
has_default = num + num_defaults >= num_args
default = (
self.arg_defaults[num - num_args + num_defaults]
if has_default
else None
)
default = self.arg_defaults[num - num_args + num_defaults] if has_default else None
return ColumnMetadata(name, typ, [], default, has_default)

return [arg(name, typ, num) for num, (name, typ) in enumerate(args)]
Expand Down
25 changes: 5 additions & 20 deletions pgcli/packages/parseutils/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,9 @@
from sqlparse.sql import IdentifierList, Identifier, Function
from sqlparse.tokens import Keyword, DML, Punctuation

TableReference = namedtuple(
"TableReference", ["schema", "name", "alias", "is_function"]
)
TableReference = namedtuple("TableReference", ["schema", "name", "alias", "is_function"])
TableReference.ref = property(
lambda self: self.alias
or (
self.name
if self.name.islower() or self.name[0] == '"'
else '"' + self.name + '"'
)
lambda self: self.alias or (self.name if self.name.islower() or self.name[0] == '"' else '"' + self.name + '"')
)


Expand Down Expand Up @@ -53,11 +46,7 @@ def extract_from_part(parsed, stop_at_punctuation=True):
# Also 'SELECT * FROM abc JOIN def' will trigger this elif
# condition. So we need to ignore the keyword JOIN and its variants
# INNER JOIN, FULL OUTER JOIN, etc.
elif (
item.ttype is Keyword
and (not item.value.upper() == "FROM")
and (not item.value.upper().endswith("JOIN"))
):
elif item.ttype is Keyword and (not item.value.upper() == "FROM") and (not item.value.upper().endswith("JOIN")):
tbl_prefix_seen = False
else:
yield item
Expand Down Expand Up @@ -116,15 +105,11 @@ def parse_identifier(item):
try:
schema_name = identifier.get_parent_name()
real_name = identifier.get_real_name()
is_function = allow_functions and _identifier_is_function(
identifier
)
is_function = allow_functions and _identifier_is_function(identifier)
except AttributeError:
continue
if real_name:
yield TableReference(
schema_name, real_name, identifier.get_alias(), is_function
)
yield TableReference(schema_name, real_name, identifier.get_alias(), is_function)
elif isinstance(item, Identifier):
schema_name, real_name, alias = parse_identifier(item)
is_function = allow_functions and _identifier_is_function(item)
Expand Down
4 changes: 1 addition & 3 deletions pgcli/packages/parseutils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,7 @@ def find_prev_keyword(sql, n_skip=0):
logical_operators = ("AND", "OR", "NOT", "BETWEEN")

for t in reversed(flattened):
if t.value == "(" or (
t.is_keyword and (t.value.upper() not in logical_operators)
):
if t.value == "(" or (t.is_keyword and (t.value.upper() not in logical_operators)):
# Find the location of token t in the original parsed statement
# We can't use parsed.token_index(t) because t may be a child token
# inside a TokenList, in which case token_index throws an error
Expand Down
31 changes: 7 additions & 24 deletions pgcli/pgexecute.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@

_logger = logging.getLogger(__name__)

ViewDef = namedtuple(
"ViewDef", "nspname relname relkind viewdef reloptions checkoption"
)
ViewDef = namedtuple("ViewDef", "nspname relname relkind viewdef reloptions checkoption")


# we added this funcion to strip beginning comments
Expand Down Expand Up @@ -51,9 +49,7 @@ def register_typecasters(connection):
"json",
"jsonb",
]:
connection.adapters.register_loader(
forced_text_type, psycopg.types.string.TextLoader
)
connection.adapters.register_loader(forced_text_type, psycopg.types.string.TextLoader)


# pg3: I don't know what is this
Expand Down Expand Up @@ -219,9 +215,7 @@ def connect(
new_params = {"dsn": new_params["dsn"], "password": new_params["password"]}

if new_params["password"]:
new_params["dsn"] = make_conninfo(
new_params["dsn"], password=new_params.pop("password")
)
new_params["dsn"] = make_conninfo(new_params["dsn"], password=new_params.pop("password"))

conn_params.update({k: v for k, v in new_params.items() if v})

Expand Down Expand Up @@ -262,11 +256,7 @@ def connect(
self.extra_args = kwargs

if not self.host:
self.host = (
"pgbouncer"
if self.is_virtual_database()
else self.get_socket_directory()
)
self.host = "pgbouncer" if self.is_virtual_database() else self.get_socket_directory()

self.pid = conn.info.backend_pid
self.superuser = conn.info.parameter_status("is_superuser") in ("on", "1")
Expand Down Expand Up @@ -306,10 +296,7 @@ def failed_transaction(self):

def valid_transaction(self):
status = self.conn.info.transaction_status
return (
status == psycopg.pq.TransactionStatus.ACTIVE
or status == psycopg.pq.TransactionStatus.INTRANS
)
return status == psycopg.pq.TransactionStatus.ACTIVE or status == psycopg.pq.TransactionStatus.INTRANS

def run(
self,
Expand Down Expand Up @@ -649,9 +636,7 @@ def is_protocol_error(self):

def get_socket_directory(self):
with self.conn.cursor() as cur:
_logger.debug(
"Socket directory Query. sql: %r", self.socket_directory_query
)
_logger.debug("Socket directory Query. sql: %r", self.socket_directory_query)
cur.execute(self.socket_directory_query)
result = cur.fetchone()
return result[0] if result else ""
Expand Down Expand Up @@ -889,8 +874,6 @@ def get_timezone(self) -> str:
return cur.fetchone()[0]

def set_timezone(self, timezone: str):
query = psycopg.sql.SQL("set time zone {}").format(
psycopg.sql.Identifier(timezone)
)
query = psycopg.sql.SQL("set time zone {}").format(psycopg.sql.Identifier(timezone))
with self.conn.cursor() as cur:
cur.execute(query)
4 changes: 1 addition & 3 deletions pgcli/pgstyle.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,7 @@ def style_factory(name, cli_style):
prompt_styles.append((token, cli_style[token]))

override_style = Style([("bottom-toolbar", "noreverse")])
return merge_styles(
[style_from_pygments_cls(style), override_style, Style(prompt_styles)]
)
return merge_styles([style_from_pygments_cls(style), override_style, Style(prompt_styles)])


def style_factory_output(name, cli_style):
Expand Down
16 changes: 4 additions & 12 deletions pgcli/pgtoolbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,10 @@ def get_toolbar_tokens():
if pgcli.multiline_mode == "safe":
result.append(("class:bottom-toolbar", " ([Esc] [Enter] to execute]) "))
else:
result.append(
("class:bottom-toolbar", " (Semi-colon [;] will end the line) ")
)
result.append(("class:bottom-toolbar", " (Semi-colon [;] will end the line) "))

if pgcli.vi_mode:
result.append(
("class:bottom-toolbar", "[F4] Vi-mode (" + _get_vi_mode() + ") ")
)
result.append(("class:bottom-toolbar", "[F4] Vi-mode (" + _get_vi_mode() + ") "))
else:
result.append(("class:bottom-toolbar", "[F4] Emacs-mode "))

Expand All @@ -54,14 +50,10 @@ def get_toolbar_tokens():
result.append(("class:bottom-toolbar", "[F5] Explain: OFF "))

if pgcli.pgexecute.failed_transaction():
result.append(
("class:bottom-toolbar.transaction.failed", " Failed transaction")
)
result.append(("class:bottom-toolbar.transaction.failed", " Failed transaction"))

if pgcli.pgexecute.valid_transaction():
result.append(
("class:bottom-toolbar.transaction.valid", " Transaction")
)
result.append(("class:bottom-toolbar.transaction.valid", " Transaction"))

if pgcli.completion_refresher.is_refreshing():
result.append(("class:bottom-toolbar", " Refreshing completions..."))
Expand Down
Loading
Loading