diff --git a/.coveragerc b/.coveragerc
index b2713c796..e2ccc2cae 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -1,3 +1,2 @@
[run]
-parallel=True
source=pgcli
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
index 35e8486bf..52c903d80 100644
--- a/.github/PULL_REQUEST_TEMPLATE.md
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -8,5 +8,5 @@
- [ ] I've added this contribution to the `changelog.rst`.
- [ ] I've added my name to the `AUTHORS` file (or it's already there).
-- [ ] I installed pre-commit hooks (`pip install pre-commit && pre-commit install`), and ran `black` on my code.
+- [ ] I installed pre-commit hooks (`pip install pre-commit && pre-commit install`).
- [x] Please squash merge this pull request (uncheck if you'd like us to merge as multiple commits)
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 6ea35faa1..ac5b3dae1 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -18,7 +18,7 @@ jobs:
services:
postgres:
- image: postgres:9.6
+ image: postgres:10
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
@@ -31,10 +31,14 @@ jobs:
--health-retries 5
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ - uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+ with:
+ version: "latest"
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v4
+ uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
with:
python-version: ${{ matrix.python-version }}
@@ -68,14 +72,10 @@ jobs:
psql -h localhost -U postgres -p 6432 pgbouncer -c 'show help'
- name: Install requirements
- run: |
- pip install -U pip setuptools
- pip install --no-cache-dir ".[sshtunnel]"
- pip install -r requirements-dev.txt
- pip install keyrings.alt>=3.1
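+ # uv sync creates a virtualenv and installs the project plus all optional extras.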
+ run: uv sync --all-extras -p ${{ matrix.python-version }}
- name: Run unit tests
- run: coverage run --source pgcli -m pytest
+ run: uv run tox -e py${{ matrix.python-version }}
- name: Run integration tests
env:
@@ -83,17 +83,10 @@ jobs:
PGPASSWORD: postgres
TERM: xterm
- run: behave tests/features --no-capture
+ run: uv run tox -e integration
- name: Check changelog for ReST compliance
- run: docutils --halt=warning changelog.rst >/dev/null
+ run: uv run tox -e rest
- - name: Run Black
- run: black --check .
- if: matrix.python-version == '3.8'
-
- - name: Coverage
- run: |
- coverage combine
- coverage report
- codecov
+ - name: Run style checks
+ run: uv run tox -e style
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
new file mode 100644
index 000000000..8b9d5728e
--- /dev/null
+++ b/.github/workflows/publish.yml
@@ -0,0 +1,97 @@
+name: Publish Python Package
+
+on:
+ release:
+ types: [created]
+
+permissions:
+ contents: read
+
+jobs:
+ test:
+ runs-on: ubuntu-latest
+
+ strategy:
+ matrix:
+ python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
+
+ services:
+ postgres:
+ image: postgres:10
+ env:
+ POSTGRES_USER: postgres
+ POSTGRES_PASSWORD: postgres
+ ports:
+ - 5432:5432
+ options: >-
+ --health-cmd pg_isready
+ --health-interval 10s
+ --health-timeout 5s
+ --health-retries 5
+
+ steps:
+ - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ - uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+ with:
+ version: "latest"
+
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Install dependencies
+ run: uv sync --all-extras -p ${{ matrix.python-version }}
+
+ - name: Run unit tests
+ env:
+ LANG: en_US.UTF-8
+ run: uv run tox -e py${{ matrix.python-version }}
+
+ - name: Run Style Checks
+ run: uv run tox -e style
+
+ build:
+ runs-on: ubuntu-latest
+ needs: [test]
+
+ steps:
+ - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+ - uses: astral-sh/setup-uv@c7f87aa956e4c323abf06d5dec078e358f6b4d04 # v6.0.0
+ with:
+ version: "latest"
+
+ - name: Set up Python
+ uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
+ with:
+ python-version: '3.13'
+
+ - name: Install dependencies
+ run: uv sync --all-extras -p 3.13
+
+ - name: Build
+ run: uv build
+
+ - name: Store the distribution packages
+ uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
+ with:
+ name: python-packages
+ path: dist/
+
+ publish:
+ name: Publish to PyPI
+ runs-on: ubuntu-latest
+ if: startsWith(github.ref, 'refs/tags/')
+ needs: [build]
+ environment: release
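+ # The id-token permission enables PyPI trusted publishing via OIDC,
+ # so no API token needs to be stored in repository secrets.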
+ permissions:
+ id-token: write
+ steps:
+ - name: Download distribution packages
+ uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0
+ with:
+ name: python-packages
+ path: dist/
+ - name: Publish to PyPI
+ uses: pypa/gh-action-pypi-publish@76f52bc884231f62b9a034ebfe128415bbaabdfc # v1.12.4
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 7a3386796..1437096ab 100644
--- a/.gitignore
+++ b/.gitignore
@@ -72,4 +72,5 @@ target/
venv/
.ropeproject/
+uv.lock
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 8462cc2ca..f44dd5c09 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,5 +1,10 @@
repos:
-- repo: https://github.com/psf/black
- rev: 23.3.0
- hooks:
- - id: black
+- repo: https://github.com/astral-sh/ruff-pre-commit
+ # Ruff version.
+ rev: v0.11.7
+ hooks:
+ # Run the linter.
+ - id: ruff
+ args: [ --fix ]
+ # Run the formatter. TODO: uncomment when the rest of the code is ruff-formatted
+ # - id: ruff-format
diff --git a/DEVELOP.rst b/CONTRIBUTING.rst
similarity index 82%
rename from DEVELOP.rst
rename to CONTRIBUTING.rst
index aed2cf8a5..ad7eb5bdc 100644
--- a/DEVELOP.rst
+++ b/CONTRIBUTING.rst
@@ -23,8 +23,8 @@ repo.
$ git remote add upstream git@github.com:dbcli/pgcli.git
Once the 'upstream' end point is added you can then periodically do a ``git
-pull upstream master`` to update your local copy and then do a ``git push
-origin master`` to keep your own fork up to date.
+pull upstream main`` to update your local copy and then do a ``git push
+origin main`` to keep your own fork up to date.
Check Github's `Understanding the GitHub flow guide
<https://guides.github.com/introduction/flow/>`_ for a more detailed
@@ -38,30 +38,23 @@ pgcli. If you're developing pgcli, you'll need to install it in a slightly
different way so you can see the effects of your changes right away without
having to go through the install cycle every time you change the code.
-It is highly recommended to use virtualenv for development. If you don't know
-what a virtualenv is, `this guide `_
-will help you get started.
-
-Create a virtualenv (let's call it pgcli-dev). Activate it:
+Set up `uv <https://docs.astral.sh/uv/getting-started/installation/>`_ for development:
::
+ cd pgcli
+ uv venv
- source ./pgcli-dev/bin/activate
+ source .venv/bin/activate
- or
-
- .\pgcli-dev\scripts\activate (for Windows)
-
-Once the virtualenv is activated, `cd` into the local clone of pgcli folder
-and install pgcli using pip as follows:
+Once the virtualenv is activated, install pgcli with ``uv pip`` as follows:
::
- $ pip install --editable .
+ $ uv pip install --editable .
or
- $ pip install -e .
+ $ uv pip install -e .
This will install the necessary dependencies as well as install pgcli from the
working folder into the virtualenv. By installing it using `pip install -e`
@@ -165,9 +158,7 @@ in the ``tests`` directory. An example::
First, install the requirements for testing:
::
- $ pip install -U pip setuptools
- $ pip install --no-cache-dir ".[sshtunnel]"
- $ pip install -r requirements-dev.txt
+ $ uv pip install ".[dev]"
Ensure that the database user has permissions to create and drop test databases
by checking your ``pg_hba.conf`` file. The default user should be ``postgres``
@@ -180,20 +171,14 @@ service for the changes to take effect.
# ONLY IF YOU MADE CHANGES TO YOUR pg_hba.conf FILE
$ sudo service postgresql restart
-After that, tests in the ``/pgcli/tests`` directory can be run with:
-(Note that these ``behave`` tests do not currently work when developing on Windows due to pexpect incompatibility.)
+After that, the ``behave`` tests can be run with:
::
- # on directory /pgcli/tests
+ $ cd pgcli/tests
$ behave
-And on the ``/pgcli`` directory:
-
-::
-
- # on directory /pgcli
- $ py.test
+Note that these ``behave`` tests do not currently work when developing on Windows due to pexpect incompatibility.
To see stdout/stderr, use the following command:
@@ -209,10 +194,21 @@ Troubleshooting the integration tests
- Check `this issue `_ for relevant information.
- `File an issue `_.
+Running the unit tests
+----------------------
+
+The unit tests can be run with pytest:
+
+::
+
+ $ cd pgcli
+ $ pytest
+
+
Coding Style
------------
-``pgcli`` uses `black <https://github.com/psf/black>`_ to format the source code. Make sure to install black.
+``pgcli`` uses `ruff <https://docs.astral.sh/ruff/>`_ to format the source code.
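+
+To check your changes locally, run the linter with auto-fix enabled (mirroring the pre-commit hook's ``--fix`` argument):
+
+::
+
+ $ uv run ruff check --fix .
+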
Releases
--------
diff --git a/README.rst b/README.rst
index b7c222fea..8b7ea0854 100644
--- a/README.rst
+++ b/README.rst
@@ -155,7 +155,7 @@ If you're interested in contributing to this project, first of all I would like
to extend my heartfelt gratitude. I've written a small doc to describe how to
get this running in a development setup.
-https://github.com/dbcli/pgcli/blob/master/DEVELOP.rst
+https://github.com/dbcli/pgcli/blob/main/CONTRIBUTING.rst
Please feel free to reach out to us if you need help.
* Amjith, pgcli author: amjith.r@gmail.com, Twitter: `@amjithr <https://twitter.com/amjithr>`_
@@ -362,12 +362,12 @@ Thanks to all the beta testers and contributors for your time and patience. :)
.. |Build Status| image:: https://github.com/dbcli/pgcli/actions/workflows/ci.yml/badge.svg?branch=main
:target: https://github.com/dbcli/pgcli/actions/workflows/ci.yml
-.. |CodeCov| image:: https://codecov.io/gh/dbcli/pgcli/branch/master/graph/badge.svg
+.. |CodeCov| image:: https://codecov.io/gh/dbcli/pgcli/branch/main/graph/badge.svg
:target: https://codecov.io/gh/dbcli/pgcli
:alt: Code coverage report
-.. |Landscape| image:: https://landscape.io/github/dbcli/pgcli/master/landscape.svg?style=flat
- :target: https://landscape.io/github/dbcli/pgcli/master
+.. |Landscape| image:: https://landscape.io/github/dbcli/pgcli/main/landscape.svg?style=flat
+ :target: https://landscape.io/github/dbcli/pgcli/main
:alt: Code Health
.. |PyPI| image:: https://img.shields.io/pypi/v/pgcli.svg
diff --git a/RELEASES.md b/RELEASES.md
index 526c260e8..d5bc64035 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -1,24 +1,6 @@
Releasing pgcli
---------------
-You have been made the maintainer of `pgcli`? Congratulations! We have a release script to help you:
+You have been made the maintainer of `pgcli`? Congratulations!
-```sh
-> python release.py --help
-Usage: release.py [options]
-
-Options:
- -h, --help show this help message and exit
- -c, --confirm-steps Confirm every step. If the step is not confirmed, it
- will be skipped.
- -d, --dry-run Print out, but not actually run any steps.
-```
-
-The script can be run with `-c` to confirm or skip steps. There's also a `--dry-run` option that only prints out the steps.
-
-To release a new version of the package:
-
-* Create and merge a PR to bump the version in the changelog ([example PR](https://github.com/dbcli/pgcli/pull/1325)).
-* Pull `main` and bump the version number inside `pgcli/__init__.py`. Do not check in - the release script will do that.
-* Make sure you have the dev requirements installed: `pip install -r requirements-dev.txt -U --upgrade-strategy only-if-needed`.
-* Finally, run the release script: `python release.py`.
+To release a new version of the package, [create a new release](https://github.com/dbcli/pgcli/releases) on GitHub. This will trigger a GitHub Action which runs all the tests, builds the distribution packages, and uploads them to PyPI.
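+
+For example, with the [GitHub CLI](https://cli.github.com/) (the version tag below is illustrative):
+
+```sh
+gh release create v4.3.1 --title "v4.3.1" --generate-notes
+```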
\ No newline at end of file
diff --git a/changelog.rst b/changelog.rst
index 123451153..8cab8c158 100644
--- a/changelog.rst
+++ b/changelog.rst
@@ -8,6 +8,15 @@ Features:
* Provide `init-command` in the config file
* Support dsn specific init-command in the config file
+Internal:
+---------
+
+* Modernize the repository
+ * Use uv instead of pip
+ * Use GitHub trusted publishing for PyPI releases
+ * Update dev requirements, moving them from requirements-dev.txt into a ``dev`` extra in pyproject.toml
+ * Use ruff instead of black
+
4.3.0 (2025-03-22)
==================
diff --git a/pgcli/config.py b/pgcli/config.py
index 22f08dc07..2b44a7bb7 100644
--- a/pgcli/config.py
+++ b/pgcli/config.py
@@ -1,4 +1,3 @@
-import errno
import shutil
import os
import platform
diff --git a/pgcli/main.py b/pgcli/main.py
index 4e8f4a768..61bc277bc 100644
--- a/pgcli/main.py
+++ b/pgcli/main.py
@@ -139,9 +139,7 @@ class PgCliQuitError(Exception):
def notify_callback(notify: Notify):
click.secho(
- 'Notification received on channel "{}" (PID {}):\n{}'.format(
- notify.channel, notify.pid, notify.payload
- ),
+ 'Notification received on channel "{}" (PID {}):\n{}'.format(notify.channel, notify.pid, notify.payload),
fg="green",
)
@@ -155,9 +153,7 @@ def set_default_pager(self, config):
os_environ_pager = os.environ.get("PAGER")
if configured_pager:
- self.logger.info(
- 'Default pager found in config file: "%s"', configured_pager
- )
+ self.logger.info('Default pager found in config file: "%s"', configured_pager)
os.environ["PAGER"] = configured_pager
elif os_environ_pager:
self.logger.info(
@@ -166,9 +162,7 @@ def set_default_pager(self, config):
)
os.environ["PAGER"] = os_environ_pager
else:
- self.logger.info(
- "No default pager found in environment. Using os default pager"
- )
+ self.logger.info("No default pager found in environment. Using os default pager")
# Set default set of less recommended options, if they are not already set.
# They are ignored if pager is different than less.
@@ -219,9 +213,7 @@ def __init__(
self.multiline_mode = c["main"].get("multi_line_mode", "psql")
self.vi_mode = c["main"].as_bool("vi")
self.auto_expand = auto_vertical_output or c["main"].as_bool("auto_expand")
- self.auto_retry_closed_connection = c["main"].as_bool(
- "auto_retry_closed_connection"
- )
+ self.auto_retry_closed_connection = c["main"].as_bool("auto_retry_closed_connection")
self.expanded_output = c["main"].as_bool("expand")
self.pgspecial.timing_enabled = c["main"].as_bool("timing")
if row_limit is not None:
@@ -247,26 +239,14 @@ def __init__(
self.syntax_style = c["main"]["syntax_style"]
self.cli_style = c["colors"]
self.wider_completion_menu = c["main"].as_bool("wider_completion_menu")
- self.destructive_warning = parse_destructive_warning(
- warn or c["main"].as_list("destructive_warning")
- )
- self.destructive_warning_restarts_connection = c["main"].as_bool(
- "destructive_warning_restarts_connection"
- )
- self.destructive_statements_require_transaction = c["main"].as_bool(
- "destructive_statements_require_transaction"
- )
+ self.destructive_warning = parse_destructive_warning(warn or c["main"].as_list("destructive_warning"))
+ self.destructive_warning_restarts_connection = c["main"].as_bool("destructive_warning_restarts_connection")
+ self.destructive_statements_require_transaction = c["main"].as_bool("destructive_statements_require_transaction")
self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty")
- self.verbose_errors = "verbose_errors" in c["main"] and c["main"].as_bool(
- "verbose_errors"
- )
+ self.verbose_errors = "verbose_errors" in c["main"] and c["main"].as_bool("verbose_errors")
self.null_string = c["main"].get("null_string", "")
- self.prompt_format = (
- prompt
- if prompt is not None
- else c["main"].get("prompt", self.default_prompt)
- )
+ self.prompt_format = prompt if prompt is not None else c["main"].get("prompt", self.default_prompt)
self.prompt_dsn_format = prompt_dsn
self.on_error = c["main"]["on_error"].upper()
self.decimal_format = c["data_formats"]["decimal"]
@@ -275,9 +255,7 @@ def __init__(
auth.keyring_initialize(c["main"].as_bool("keyring"), logger=self.logger)
self.show_bottom_toolbar = c["main"].as_bool("show_bottom_toolbar")
- self.pgspecial.pset_pager(
- self.config["main"].as_bool("enable_pager") and "on" or "off"
- )
+ self.pgspecial.pset_pager(self.config["main"].as_bool("enable_pager") and "on" or "off")
self.style_output = style_factory_output(self.syntax_style, c["colors"])
@@ -290,9 +268,7 @@ def __init__(
# Initialize completer
smart_completion = c["main"].as_bool("smart_completion")
keyword_casing = c["main"]["keyword_casing"]
- single_connection = single_connection or c["main"].as_bool(
- "always_use_single_connection"
- )
+ single_connection = single_connection or c["main"].as_bool("always_use_single_connection")
self.settings = {
"casing_file": get_casing_file(c),
"generate_casing_file": c["main"].as_bool("generate_casing_file"),
@@ -307,9 +283,7 @@ def __init__(
"alias_map_file": c["main"]["alias_map_file"] or None,
}
- completer = PGCompleter(
- smart_completion, pgspecial=self.pgspecial, settings=self.settings
- )
+ completer = PGCompleter(smart_completion, pgspecial=self.pgspecial, settings=self.settings)
self.completer = completer
self._completer_lock = threading.Lock()
self.register_special_commands()
@@ -341,7 +315,8 @@ def register_special_commands(self):
aliases=("use", "\\connect", "USE"),
)
- refresh_callback = lambda: self.refresh_completions(persist_priorities="all")
+ def refresh_callback():
+ return self.refresh_completions(persist_priorities="all")
self.pgspecial.register(
self.quit,
@@ -375,9 +350,7 @@ def register_special_commands(self):
"Refresh auto-completions.",
arg_type=NO_QUERY,
)
- self.pgspecial.register(
- self.execute_from_file, "\\i", "\\i filename", "Execute commands from file."
- )
+ self.pgspecial.register(self.execute_from_file, "\\i", "\\i filename", "Execute commands from file.")
self.pgspecial.register(
self.write_to_file,
"\\o",
@@ -390,9 +363,7 @@ def register_special_commands(self):
"\\log-file [filename]",
"Log all query results to a logfile, in addition to the normal output destination.",
)
- self.pgspecial.register(
- self.info_connection, "\\conninfo", "\\conninfo", "Get connection details"
- )
+ self.pgspecial.register(self.info_connection, "\\conninfo", "\\conninfo", "Get connection details")
self.pgspecial.register(
self.change_table_format,
"\\T",
@@ -461,8 +432,7 @@ def info_connection(self, **_):
None,
None,
'You are connected to database "%s" as user '
- '"%s" on %s at port "%s".'
- % (self.pgexecute.dbname, self.pgexecute.user, host, self.pgexecute.port),
+ '"%s" on %s at port "%s".' % (self.pgexecute.dbname, self.pgexecute.user, host, self.pgexecute.port),
)
def change_db(self, pattern, **_):
@@ -470,7 +440,7 @@ def change_db(self, pattern, **_):
# Get all the parameters in pattern, handling double quotes if any.
infos = re.findall(r'"[^"]*"|[^"\'\s]+', pattern)
# Now removing quotes.
- list(map(lambda s: s.strip('"'), infos))
+ infos = [s.strip('"') for s in infos]
infos.extend([None] * (4 - len(infos)))
db, user, host, port = infos
@@ -492,8 +462,7 @@ def change_db(self, pattern, **_):
None,
None,
None,
- 'You are now connected to database "%s" as '
- 'user "%s"' % (self.pgexecute.dbname, self.pgexecute.user),
+ 'You are now connected to database "%s" as user "%s"' % (self.pgexecute.dbname, self.pgexecute.user),
)
def execute_from_file(self, pattern, **_):
@@ -514,9 +483,7 @@ def execute_from_file(self, pattern, **_):
):
message = "Destructive statements must be run within a transaction. Command execution stopped."
return [(None, None, None, message)]
- destroy = confirm_destructive_query(
- query, self.destructive_warning, self.dsn_alias
- )
+ destroy = confirm_destructive_query(query, self.destructive_warning, self.dsn_alias)
if destroy is False:
message = "Wise choice. Command execution stopped."
return [(None, None, None, message)]
@@ -591,10 +558,7 @@ def initialize_logging(self):
log_level = level_map[log_level.upper()]
- formatter = logging.Formatter(
- "%(asctime)s (%(process)d/%(threadName)s) "
- "%(name)s %(levelname)s - %(message)s"
- )
+ formatter = logging.Formatter("%(asctime)s (%(process)d/%(threadName)s) %(name)s %(levelname)s - %(message)s")
handler.setFormatter(formatter)
@@ -615,9 +579,7 @@ def connect_dsn(self, dsn, **kwargs):
def connect_service(self, service, user):
service_config, file = parse_service_info(service)
if service_config is None:
- click.secho(
- f"service '{service}' was not found in {file}", err=True, fg="red"
- )
+ click.secho(f"service '{service}' was not found in {file}", err=True, fg="red")
sys.exit(1)
self.connect(
database=service_config.get("dbname"),
@@ -633,9 +595,7 @@ def connect_uri(self, uri):
kwargs = {remap.get(k, k): v for k, v in kwargs.items()}
self.connect(**kwargs)
- def connect(
- self, database="", host="", user="", port="", passwd="", dsn="", **kwargs
- ):
+ def connect(self, database="", host="", user="", port="", passwd="", dsn="", **kwargs):
# Connect to the database.
if not user:
@@ -657,9 +617,7 @@ def connect(
# If we successfully parsed a password from a URI, there's no need to
# prompt for it, even with the -W flag
if self.force_passwd_prompt and not passwd:
- passwd = click.prompt(
- "Password for %s" % user, hide_input=True, show_default=False, type=str
- )
+ passwd = click.prompt("Password for %s" % user, hide_input=True, show_default=False, type=str)
key = f"{user}@{host}"
@@ -825,13 +783,9 @@ def execute_command(self, text, handle_closed_connection=True):
and not self.pgexecute.valid_transaction()
and is_destructive(text, self.destructive_warning)
):
- click.secho(
- "Destructive statements must be run within a transaction."
- )
+ click.secho("Destructive statements must be run within a transaction.")
raise KeyboardInterrupt
- destroy = confirm_destructive_query(
- text, self.destructive_warning, self.dsn_alias
- )
+ destroy = confirm_destructive_query(text, self.destructive_warning, self.dsn_alias)
if destroy is False:
click.secho("Wise choice!")
raise KeyboardInterrupt
@@ -844,9 +798,7 @@ def execute_command(self, text, handle_closed_connection=True):
# Restart connection to the database
self.pgexecute.connect()
logger.debug("cancelled query and restarted connection, sql: %r", text)
- click.secho(
- "cancelled query and restarted connection", err=True, fg="red"
- )
+ click.secho("cancelled query and restarted connection", err=True, fg="red")
else:
logger.debug("cancelled query, sql: %r", text)
click.secho("cancelled query", err=True, fg="red")
@@ -866,9 +818,7 @@ def execute_command(self, text, handle_closed_connection=True):
click.secho(str(e), err=True, fg="red")
else:
try:
- if self.output_file and not text.startswith(
- ("\\o ", "\\log-file", "\\? ", "\\echo ")
- ):
+ if self.output_file and not text.startswith(("\\o ", "\\log-file", "\\? ", "\\echo ")):
try:
with open(self.output_file, "a", encoding="utf-8") as f:
click.echo(text, file=f)
@@ -881,16 +831,10 @@ def execute_command(self, text, handle_closed_connection=True):
self.echo_via_pager("\n".join(output))
# Log to file in addition to normal output
- if (
- self.log_file
- and not text.startswith(("\\o ", "\\log-file", "\\? ", "\\echo "))
- and not text.strip() == ""
- ):
+ if self.log_file and not text.startswith(("\\o ", "\\log-file", "\\? ", "\\echo ")) and text.strip() != "":
try:
with open(self.log_file, "a", encoding="utf-8") as f:
- click.echo(
- dt.datetime.now().isoformat(), file=f
- ) # timestamp log
+ click.echo(dt.datetime.now().isoformat(), file=f) # timestamp log
click.echo(text, file=f)
click.echo("\n".join(output), file=f)
click.echo("", file=f) # extra newline
@@ -1018,9 +962,7 @@ def handle_watch_command(self, text):
try:
self.watch_command = self.query_history[-1].query
except IndexError:
- click.secho(
- "\\watch cannot be used with an empty query", err=True, fg="red"
- )
+ click.secho("\\watch cannot be used with an empty query", err=True, fg="red")
self.watch_command = None
# If there's a command to \watch, run it in a loop.
@@ -1050,10 +992,7 @@ def get_message():
prompt = self.get_prompt(prompt_format)
- if (
- prompt_format == self.default_prompt
- and len(prompt) > self.max_len_prompt
- ):
+ if prompt_format == self.default_prompt and len(prompt) > self.max_len_prompt:
prompt = self.get_prompt("\\d> ")
prompt = prompt.replace("\\x1b", "\x1b")
@@ -1116,12 +1055,7 @@ def _should_limit_output(self, sql, cur):
if not is_select(sql):
return False
- return (
- not self._has_limit(sql)
- and self.row_limit != 0
- and cur
- and cur.rowcount > self.row_limit
- )
+ return not self._has_limit(sql) and self.row_limit != 0 and cur and cur.rowcount > self.row_limit
def _has_limit(self, sql):
if not sql:
@@ -1191,18 +1125,12 @@ def _evaluate_command(self, text):
missingval=self.null_string,
expanded=expanded,
max_width=max_width,
- case_function=(
- self.completer.case
- if self.settings["case_column_headers"]
- else lambda x: x
- ),
+ case_function=(self.completer.case if self.settings["case_column_headers"] else lambda x: x),
style_output=self.style_output,
max_field_width=self.max_field_width,
)
execution = time() - start
- formatted = format_output(
- title, cur, headers, status, settings, self.explain_mode
- )
+ formatted = format_output(title, cur, headers, status, settings, self.explain_mode)
output.extend(formatted)
total = time() - start
@@ -1241,9 +1169,7 @@ def _handle_server_closed_connection(self, text):
click.secho("Reconnect Failed", fg="red")
click.secho(str(e), err=True, fg="red")
else:
- retry = self.auto_retry_closed_connection or confirm(
- "Run the query from before reconnecting?"
- )
+ retry = self.auto_retry_closed_connection or confirm("Run the query from before reconnecting?")
if retry:
click.secho("Running query...", fg="green")
# Don't get stuck in a retry loop
@@ -1258,9 +1184,7 @@ def refresh_completions(self, history=None, persist_priorities="all"):
:param persist_priorities: 'all' or 'keywords'
"""
- callback = functools.partial(
- self._on_completions_refreshed, persist_priorities=persist_priorities
- )
+ callback = functools.partial(self._on_completions_refreshed, persist_priorities=persist_priorities)
return self.completion_refresher.refresh(
self.pgexecute,
self.pgspecial,
@@ -1311,9 +1235,7 @@ def _swap_completer_objects(self, new_completer, persist_priorities):
def get_completions(self, text, cursor_positition):
with self._completer_lock:
- return self.completer.get_completions(
- Document(text=text, cursor_position=cursor_positition), None
- )
+ return self.completer.get_completions(Document(text=text, cursor_position=cursor_positition), None)
def get_prompt(self, string):
# should be before replacing \\d
@@ -1340,10 +1262,7 @@ def is_too_wide(self, line):
"""Will this line be too wide to fit into terminal?"""
if not self.prompt_app:
return False
- return (
- len(COLOR_CODE_REGEX.sub("", line))
- > self.prompt_app.output.get_size().columns
- )
+ return len(COLOR_CODE_REGEX.sub("", line)) > self.prompt_app.output.get_size().columns
def is_too_tall(self, lines):
"""Are there too many lines to fit into terminal?"""
@@ -1354,10 +1273,7 @@ def is_too_tall(self, lines):
def echo_via_pager(self, text, color=None):
if self.pgspecial.pager_config == PAGER_OFF or self.watch_command:
click.echo(text, color=color)
- elif (
- self.pgspecial.pager_config == PAGER_LONG_OUTPUT
- and self.table_format != "csv"
- ):
+ elif self.pgspecial.pager_config == PAGER_LONG_OUTPUT and self.table_format != "csv":
lines = text.split("\n")
# The last 4 lines are reserved for the pgcli menu and padding
@@ -1382,7 +1298,7 @@ def echo_via_pager(self, text, color=None):
"-p",
"--port",
default=5432,
- help="Port number at which the " "postgres instance is listening.",
+ help="Port number at which the postgres instance is listening.",
envvar="PGPORT",
type=click.INT,
)
@@ -1392,9 +1308,7 @@ def echo_via_pager(self, text, color=None):
"username_opt",
help="Username to connect to the postgres database.",
)
-@click.option(
- "-u", "--user", "username_opt", help="Username to connect to the postgres database."
-)
+@click.option("-u", "--user", "username_opt", help="Username to connect to the postgres database.")
@click.option(
"-W",
"--password",
@@ -1560,10 +1474,9 @@ def cli(
for alias in cfg["alias_dsn"]:
click.secho(alias + " : " + cfg["alias_dsn"][alias])
sys.exit(0)
- except Exception as err:
+ except Exception:
click.secho(
- "Invalid DSNs found in the config file. "
- 'Please check the "[alias_dsn]" section in pgclirc.',
+ "Invalid DSNs found in the config file. Please check the \"[alias_dsn]\" section in pgclirc.",
err=True,
fg="red",
)
@@ -1615,16 +1528,14 @@ def cli(
dsn_config = cfg["alias_dsn"][dsn]
except KeyError:
click.secho(
- f"Could not find a DSN with alias {dsn}. "
- 'Please check the "[alias_dsn]" section in pgclirc.',
+ f"Could not find a DSN with alias {dsn}. Please check the \"[alias_dsn]\" section in pgclirc.",
err=True,
fg="red",
)
sys.exit(1)
except Exception:
click.secho(
- "Invalid DSNs found in the config file. "
- 'Please check the "[alias_dsn]" section in pgclirc.',
+ "Invalid DSNs found in the config file. Please check the \"[alias_dsn]\" section in pgclirc.",
err=True,
fg="red",
)
@@ -1640,9 +1551,7 @@ def cli(
else:
pgcli.connect(database, host, user, port)
- if "use_local_timezone" not in cfg["main"] or cfg["main"].as_bool(
- "use_local_timezone"
- ):
+ if "use_local_timezone" not in cfg["main"] or cfg["main"].as_bool("use_local_timezone"):
server_tz = pgcli.pgexecute.get_timezone()
def echo_error(msg: str):
@@ -1741,7 +1650,7 @@ def echo_error(msg: str):
sys.exit(0)
pgcli.logger.debug(
- "Launch Params: \n" "\tdatabase: %r" "\tuser: %r" "\thost: %r" "\tport: %r",
+ "Launch Params: \n\tdatabase: %r\tuser: %r\thost: %r\tport: %r",
database,
user,
host,
@@ -1759,9 +1668,7 @@ def obfuscate_process_password():
if "://" in process_title:
process_title = re.sub(r":(.*):(.*)@", r":\1:xxxx@", process_title)
elif "=" in process_title:
- process_title = re.sub(
- r"password=(.+?)((\s[a-zA-Z]+=)|$)", r"password=xxxx\2", process_title
- )
+ process_title = re.sub(r"password=(.+?)((\s[a-zA-Z]+=)|$)", r"password=xxxx\2", process_title)
setproctitle.setproctitle(process_title)
@@ -1901,9 +1808,7 @@ def format_array(val):
def format_arrays(data, headers, **_):
data = list(data)
for row in data:
- row[:] = [
- format_array(val) if isinstance(val, list) else val for val in row
- ]
+ row[:] = [format_array(val) if isinstance(val, list) else val for val in row]
return data, headers
@@ -1968,13 +1873,7 @@ def format_status(cur, status):
formatted = iter(formatted.splitlines())
first_line = next(formatted)
formatted = itertools.chain([first_line], formatted)
- if (
- not explain_mode
- and not expanded
- and max_width
- and len(strip_ansi(first_line)) > max_width
- and headers
- ):
+ if not explain_mode and not expanded and max_width and len(strip_ansi(first_line)) > max_width and headers:
formatted = formatter.format_output(
cur,
headers,
diff --git a/pgcli/packages/parseutils/ctes.py b/pgcli/packages/parseutils/ctes.py
index e1f908850..a6a364a02 100644
--- a/pgcli/packages/parseutils/ctes.py
+++ b/pgcli/packages/parseutils/ctes.py
@@ -17,7 +17,7 @@ def isolate_query_ctes(full_text, text_before_cursor):
"""Simplify a query by converting CTEs into table metadata objects"""
if not full_text or not full_text.strip():
- return full_text, text_before_cursor, tuple()
+ return full_text, text_before_cursor, ()
ctes, remainder = extract_ctes(full_text)
if not ctes:
diff --git a/pgcli/packages/sqlcompletion.py b/pgcli/packages/sqlcompletion.py
index b78edd6d9..9eb3ca858 100644
--- a/pgcli/packages/sqlcompletion.py
+++ b/pgcli/packages/sqlcompletion.py
@@ -1,4 +1,3 @@
-import sys
import re
import sqlparse
from collections import namedtuple
@@ -27,16 +26,16 @@
Function = namedtuple("Function", ["schema", "table_refs", "usage"])
# For convenience, don't require the `usage` argument in Function constructor
-Function.__new__.__defaults__ = (None, tuple(), None)
-Table.__new__.__defaults__ = (None, tuple(), tuple())
-View.__new__.__defaults__ = (None, tuple())
-FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple())
+Function.__new__.__defaults__ = (None, (), None)
+Table.__new__.__defaults__ = (None, (), ())
+View.__new__.__defaults__ = (None, ())
+FromClauseItem.__new__.__defaults__ = (None, (), ())
Column = namedtuple(
"Column",
["table_refs", "require_last_table", "local_tables", "qualifiable", "context"],
)
-Column.__new__.__defaults__ = (None, None, tuple(), False, None)
+Column.__new__.__defaults__ = (None, None, (), False, None)
Keyword = namedtuple("Keyword", ["last_token"])
Keyword.__new__.__defaults__ = (None,)
@@ -50,15 +49,11 @@
class SqlStatement:
def __init__(self, full_text, text_before_cursor):
self.identifier = None
- self.word_before_cursor = word_before_cursor = last_word(
- text_before_cursor, include="many_punctuations"
- )
+ self.word_before_cursor = word_before_cursor = last_word(text_before_cursor, include="many_punctuations")
full_text = _strip_named_query(full_text)
text_before_cursor = _strip_named_query(text_before_cursor)
- full_text, text_before_cursor, self.local_tables = isolate_query_ctes(
- full_text, text_before_cursor
- )
+ full_text, text_before_cursor, self.local_tables = isolate_query_ctes(full_text, text_before_cursor)
self.text_before_cursor_including_last_word = text_before_cursor
@@ -78,9 +73,7 @@ def __init__(self, full_text, text_before_cursor):
else:
parsed = sqlparse.parse(text_before_cursor)
- full_text, text_before_cursor, parsed = _split_multiple_statements(
- full_text, text_before_cursor, parsed
- )
+ full_text, text_before_cursor, parsed = _split_multiple_statements(full_text, text_before_cursor, parsed)
self.full_text = full_text
self.text_before_cursor = text_before_cursor
@@ -98,9 +91,7 @@ def get_tables(self, scope="full"):
If 'before', only tables before the cursor are returned.
If not 'insert' and the stmt is an insert, the first table is skipped.
"""
- tables = extract_tables(
- self.full_text if scope == "full" else self.text_before_cursor
- )
+ tables = extract_tables(self.full_text if scope == "full" else self.text_before_cursor)
if scope == "insert":
tables = tables[:1]
elif self.is_insert():
@@ -119,9 +110,7 @@ def get_identifier_schema(self):
return schema
def reduce_to_prev_keyword(self, n_skip=0):
- prev_keyword, self.text_before_cursor = find_prev_keyword(
- self.text_before_cursor, n_skip=n_skip
- )
+ prev_keyword, self.text_before_cursor = find_prev_keyword(self.text_before_cursor, n_skip=n_skip)
return prev_keyword
@@ -222,9 +211,7 @@ def _split_multiple_statements(full_text, text_before_cursor, parsed):
token1_idx = statement.token_index(token1)
token2 = statement.token_next(token1_idx)[1]
if token2 and token2.value.upper() == "FUNCTION":
- full_text, text_before_cursor, statement = _statement_from_function(
- full_text, text_before_cursor, statement
- )
+ full_text, text_before_cursor, statement = _statement_from_function(full_text, text_before_cursor, statement)
return full_text, text_before_cursor, statement
@@ -361,11 +348,7 @@ def suggest_based_on_last_token(token, stmt):
# Get the token before the parens
prev_tok = p.token_prev(len(p.tokens) - 1)[1]
- if (
- prev_tok
- and prev_tok.value
- and prev_tok.value.lower().split(" ")[-1] == "using"
- ):
+ if prev_tok and prev_tok.value and prev_tok.value.lower().split(" ")[-1] == "using":
# tbl1 INNER JOIN tbl2 USING (col1, col2)
tables = stmt.get_tables("before")
@@ -395,9 +378,7 @@ def suggest_based_on_last_token(token, stmt):
elif token_v == "as":
# Don't suggest anything for aliases
return ()
- elif (token_v.endswith("join") and token.is_keyword) or (
- token_v in ("copy", "from", "update", "into", "describe", "truncate")
- ):
+ elif (token_v.endswith("join") and token.is_keyword) or (token_v in ("copy", "from", "update", "into", "describe", "truncate")):
schema = stmt.get_identifier_schema()
tables = extract_tables(stmt.text_before_cursor)
is_join = token_v.endswith("join") and token.is_keyword
@@ -411,11 +392,7 @@ def suggest_based_on_last_token(token, stmt):
suggest.insert(0, Schema())
if token_v == "from" or is_join:
- suggest.append(
- FromClauseItem(
- schema=schema, table_refs=tables, local_tables=stmt.local_tables
- )
- )
+ suggest.append(FromClauseItem(schema=schema, table_refs=tables, local_tables=stmt.local_tables))
elif token_v == "truncate":
suggest.append(Table(schema))
else:
@@ -447,7 +424,7 @@ def suggest_based_on_last_token(token, stmt):
except ValueError:
pass
- return tuple()
+ return ()
elif token_v in ("table", "view"):
# E.g. 'ALTER TABLE '
@@ -553,14 +530,10 @@ def _suggest_expression(token_v, stmt):
)
-def identifies(id, ref):
+def identifies(table_id, ref):
"""Returns true if string `id` matches TableReference `ref`"""
- return (
- id == ref.alias
- or id == ref.name
- or (ref.schema and (id == ref.schema + "." + ref.name))
- )
+ return table_id == ref.alias or table_id == ref.name or (ref.schema and (table_id == ref.schema + "." + ref.name))
def _allow_join_condition(statement):
diff --git a/pgcli/pgbuffer.py b/pgcli/pgbuffer.py
index c236c133a..aba180c8f 100644
--- a/pgcli/pgbuffer.py
+++ b/pgcli/pgbuffer.py
@@ -25,9 +25,7 @@ def _is_complete(sql):
def safe_multi_line_mode(pgcli):
@Condition
def cond():
- _logger.debug(
- 'Multi-line mode state: "%s" / "%s"', pgcli.multi_line, pgcli.multiline_mode
- )
+ _logger.debug('Multi-line mode state: "%s" / "%s"', pgcli.multi_line, pgcli.multiline_mode)
return pgcli.multi_line and (pgcli.multiline_mode == "safe")
return cond
@@ -48,14 +46,13 @@ def cond():
text = doc.text.strip()
return (
- text.startswith("\\") # Special Command
- or text.endswith(r"\e") # Special Command
- or text.endswith(r"\G") # Ended with \e which should launch the editor
- or _is_complete(text) # A complete SQL command
- or (text == "exit") # Exit doesn't need semi-colon
- or (text == "quit") # Quit doesn't need semi-colon
- or (text == ":q") # To all the vim fans out there
- or (text == "") # Just a plain enter without any text
+ text.startswith("\\")
+ or text.endswith((r"\e", r"\G"))
+ or _is_complete(text)
+ or text == "exit"
+ or text == "quit"
+ or text == ":q"
+ or text == "" # Just a plain enter without any text
)
return cond
diff --git a/pgcli/pgcompleter.py b/pgcli/pgcompleter.py
index 8df2958e0..ced0f1687 100644
--- a/pgcli/pgcompleter.py
+++ b/pgcli/pgcompleter.py
@@ -1,7 +1,7 @@
import json
import logging
import re
-from itertools import count, repeat, chain
+from itertools import count, chain
import operator
from collections import namedtuple, defaultdict, OrderedDict
from cli_helpers.tabular_output import TabularOutputFormatter
@@ -32,7 +32,6 @@
from .packages.parseutils.tables import TableReference
from .packages.pgliterals.main import get_literals
from .packages.prioritization import PrevalenceCounter
-from .config import load_config, config_location
_logger = logging.getLogger(__name__)
@@ -48,18 +47,16 @@ def SchemaObject(name, schema=None, meta=None):
_Candidate = namedtuple("Candidate", "completion prio meta synonyms prio2 display")
-def Candidate(
- completion, prio=None, meta=None, synonyms=None, prio2=None, display=None
-):
- return _Candidate(
- completion, prio, meta, synonyms or [completion], prio2, display or completion
- )
+def Candidate(completion, prio=None, meta=None, synonyms=None, prio2=None, display=None):
+ return _Candidate(completion, prio, meta, synonyms or [completion], prio2, display or completion)
# Used to strip trailing '::some_type' from default-value expressions
arg_default_type_strip_regex = re.compile(r"::[\w\.]+(\[\])?$")
-normalize_ref = lambda ref: ref if ref[0] == '"' else '"' + ref.lower() + '"'
+
+def normalize_ref(ref):
+ return ref if ref[0] == '"' else '"' + ref.lower() + '"'
def generate_alias(tbl, alias_map=None):
@@ -77,10 +74,7 @@ def generate_alias(tbl, alias_map=None):
"""
if alias_map and tbl in alias_map:
return alias_map[tbl]
- return "".join(
- [l for l in tbl if l.isupper()]
- or [l for l, prev in zip(tbl, "_" + tbl) if prev == "_" and l != "_"]
- )
+ return "".join([l for l in tbl if l.isupper()] or [l for l, prev in zip(tbl, "_" + tbl) if prev == "_" and l != "_"])
class InvalidMapFile(ValueError):
@@ -92,9 +86,7 @@ def load_alias_map_file(path):
with open(path) as fo:
alias_map = json.load(fo)
except FileNotFoundError as err:
- raise InvalidMapFile(
- f"Cannot read alias_map_file - {err.filename} does not exist"
- )
+ raise InvalidMapFile(f"Cannot read alias_map_file - {err.filename} does not exist")
except json.JSONDecodeError:
raise InvalidMapFile(f"Cannot read alias_map_file - {path} is not valid json")
else:
@@ -116,15 +108,9 @@ def __init__(self, smart_completion=True, pgspecial=None, settings=None):
self.pgspecial = pgspecial
self.prioritizer = PrevalenceCounter()
settings = settings or {}
- self.signature_arg_style = settings.get(
- "signature_arg_style", "{arg_name} {arg_type}"
- )
- self.call_arg_style = settings.get(
- "call_arg_style", "{arg_name: <{max_arg_len}} := {arg_default}"
- )
- self.call_arg_display_style = settings.get(
- "call_arg_display_style", "{arg_name}"
- )
+ self.signature_arg_style = settings.get("signature_arg_style", "{arg_name} {arg_type}")
+ self.call_arg_style = settings.get("call_arg_style", "{arg_name: <{max_arg_len}} := {arg_default}")
+ self.call_arg_display_style = settings.get("call_arg_display_style", "{arg_name}")
self.call_arg_oneliner_max = settings.get("call_arg_oneliner_max", 2)
self.search_path_filter = settings.get("search_path_filter")
self.generate_aliases = settings.get("generate_aliases")
@@ -135,16 +121,11 @@ def __init__(self, smart_completion=True, pgspecial=None, settings=None):
self.alias_map = None
self.casing_file = settings.get("casing_file")
self.insert_col_skip_patterns = [
- re.compile(pattern)
- for pattern in settings.get(
- "insert_col_skip_patterns", [r"^now\(\)$", r"^nextval\("]
- )
+ re.compile(pattern) for pattern in settings.get("insert_col_skip_patterns", [r"^now\(\)$", r"^nextval\("])
]
self.generate_casing_file = settings.get("generate_casing_file")
self.qualify_columns = settings.get("qualify_columns", "if_more_than_one_table")
- self.asterisk_column_order = settings.get(
- "asterisk_column_order", "table_order"
- )
+ self.asterisk_column_order = settings.get("asterisk_column_order", "table_order")
keyword_casing = settings.get("keyword_casing", "upper").lower()
if keyword_casing not in ("upper", "lower", "auto"):
@@ -160,11 +141,7 @@ def __init__(self, smart_completion=True, pgspecial=None, settings=None):
self.all_completions = set(self.keywords + self.functions)
def escape_name(self, name):
- if name and (
- (not self.name_pattern.match(name))
- or (name.upper() in self.reserved_words)
- or (name.upper() in self.functions)
- ):
+ if name and ((not self.name_pattern.match(name)) or (name.upper() in self.reserved_words) or (name.upper() in self.functions)):
name = '"%s"' % name
return name
@@ -230,9 +207,7 @@ def extend_relations(self, data, kind):
try:
metadata[schema][relname] = OrderedDict()
except KeyError:
- _logger.error(
- "%r %r listed in unrecognized schema %r", kind, relname, schema
- )
+ _logger.error("%r %r listed in unrecognized schema %r", kind, relname, schema)
self.all_completions.add(relname)
def extend_columns(self, column_data, kind):
@@ -306,9 +281,7 @@ def extend_foreignkeys(self, fk_data):
childcol, parcol = e([fk.childcolumn, fk.parentcolumn])
childcolmeta = meta[childschema][childtable][childcol]
parcolmeta = meta[parentschema][parenttable][parcol]
- fk = ForeignKey(
- parentschema, parenttable, parcol, childschema, childtable, childcol
- )
+ fk = ForeignKey(parentschema, parenttable, parcol, childschema, childtable, childcol)
childcolmeta.foreignkeys.append(fk)
parcolmeta.foreignkeys.append(fk)
@@ -452,12 +425,7 @@ def _match(item):
# We also use the unescape_name to make sure quoted names have
# the same priority as unquoted names.
lexical_priority = (
- tuple(
- 0 if c in " _" else -ord(c)
- for c in self.unescape_name(item.lower())
- )
- + (1,)
- + tuple(c for c in item)
+ tuple(0 if c in " _" else -ord(c) for c in self.unescape_name(item.lower())) + (1,) + tuple(c for c in item)
)
item = self.case(item)
@@ -495,9 +463,7 @@ def get_completions(self, document, complete_event, smart_completion=None):
# If smart_completion is off then match any word that starts with
# 'word_before_cursor'.
if not smart_completion:
- matches = self.find_matches(
- word_before_cursor, self.all_completions, mode="strict"
- )
+ matches = self.find_matches(word_before_cursor, self.all_completions, mode="strict")
completions = [m.completion for m in matches]
return sorted(completions, key=operator.attrgetter("text"))
@@ -528,9 +494,10 @@ def get_column_matches(self, suggestion, word_before_cursor):
"if_more_than_one_table": len(tables) > 1,
}[self.qualify_columns]
)
- qualify = lambda col, tbl: (
- (tbl + "." + self.case(col)) if do_qualify else self.case(col)
- )
+
+ def qualify(col, tbl):
+ return (tbl + "." + self.case(col)) if do_qualify else self.case(col)
+
_logger.debug("Completion column scope: %r", tables)
scoped_cols = self.populate_scoped_cols(tables, suggestion.local_tables)
@@ -539,61 +506,38 @@ def make_cand(name, ref):
return Candidate(qualify(name, ref), 0, "column", synonyms)
def flat_cols():
- return [
- make_cand(c.name, t.ref)
- for t, cols in scoped_cols.items()
- for c in cols
- ]
+ return [make_cand(c.name, t.ref) for t, cols in scoped_cols.items() for c in cols]
if suggestion.require_last_table:
# require_last_table is used for 'tb11 JOIN tbl2 USING (...' which should
# suggest only columns that appear in the last table and one more
ltbl = tables[-1].ref
- other_tbl_cols = {
- c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs
- }
- scoped_cols = {
- t: [col for col in cols if col.name in other_tbl_cols]
- for t, cols in scoped_cols.items()
- if t.ref == ltbl
- }
+ other_tbl_cols = {c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs}
+ scoped_cols = {t: [col for col in cols if col.name in other_tbl_cols] for t, cols in scoped_cols.items() if t.ref == ltbl}
lastword = last_word(word_before_cursor, include="most_punctuations")
if lastword == "*":
if suggestion.context == "insert":
- def filter(col):
+ def _filter(col):
if not col.has_default:
return True
- return not any(
- p.match(col.default) for p in self.insert_col_skip_patterns
- )
+ return not any(p.match(col.default) for p in self.insert_col_skip_patterns)
- scoped_cols = {
- t: [col for col in cols if filter(col)]
- for t, cols in scoped_cols.items()
- }
+ scoped_cols = {t: [col for col in cols if _filter(col)] for t, cols in scoped_cols.items()}
if self.asterisk_column_order == "alphabetic":
for cols in scoped_cols.values():
cols.sort(key=operator.attrgetter("name"))
- if (
- lastword != word_before_cursor
- and len(tables) == 1
- and word_before_cursor[-len(lastword) - 1] == "."
- ):
+ if lastword != word_before_cursor and len(tables) == 1 and word_before_cursor[-len(lastword) - 1] == ".":
# User typed x.*; replicate "x." for all columns except the
# first, which gets the original (as we only replace the "*"")
sep = ", " + word_before_cursor[:-1]
collist = sep.join(self.case(c.completion) for c in flat_cols())
else:
- collist = ", ".join(
- qualify(c.name, t.ref) for t, cs in scoped_cols.items() for c in cs
- )
+ collist = ", ".join(qualify(c.name, t.ref) for t, cs in scoped_cols.items() for c in cs)
return [
Match(
- completion=Completion(
- collist, -1, display_meta="columns", display="*"
- ),
+ completion=Completion(collist, -1, display_meta="columns", display="*"),
priority=(1, 1, 1),
)
]
@@ -627,12 +571,7 @@ def get_join_matches(self, suggestion, word_before_cursor):
other_tbls = {(t.schema, t.name) for t in list(cols)[:-1]}
joins = []
# Iterate over FKs in existing tables to find potential joins
- fks = (
- (fk, rtbl, rcol)
- for rtbl, rcols in cols.items()
- for rcol in rcols
- for fk in rcol.foreignkeys
- )
+ fks = ((fk, rtbl, rcol) for rtbl, rcols in cols.items() for rcol in rcols for fk in rcol.foreignkeys)
col = namedtuple("col", "schema tbl col")
for fk, rtbl, rcol in fks:
right = col(rtbl.schema, rtbl.name, rcol.name)
@@ -644,31 +583,21 @@ def get_join_matches(self, suggestion, word_before_cursor):
c = self.case
if self.generate_aliases or normalize_ref(left.tbl) in refs:
lref = self.alias(left.tbl, suggestion.table_refs)
- join = "{0} {4} ON {4}.{1} = {2}.{3}".format(
- c(left.tbl), c(left.col), rtbl.ref, c(right.col), lref
- )
+ join = "{0} {4} ON {4}.{1} = {2}.{3}".format(c(left.tbl), c(left.col), rtbl.ref, c(right.col), lref)
else:
- join = "{0} ON {0}.{1} = {2}.{3}".format(
- c(left.tbl), c(left.col), rtbl.ref, c(right.col)
- )
+ join = "{0} ON {0}.{1} = {2}.{3}".format(c(left.tbl), c(left.col), rtbl.ref, c(right.col))
alias = generate_alias(self.case(left.tbl), alias_map=self.alias_map)
synonyms = [
join,
- "{0} ON {0}.{1} = {2}.{3}".format(
- alias, c(left.col), rtbl.ref, c(right.col)
- ),
+ "{0} ON {0}.{1} = {2}.{3}".format(alias, c(left.col), rtbl.ref, c(right.col)),
]
# Schema-qualify if (1) new table in same schema as old, and old
# is schema-qualified, or (2) new in other schema, except public
if not suggestion.schema and (
- qualified[normalize_ref(rtbl.ref)]
- and left.schema == right.schema
- or left.schema not in (right.schema, "public")
+ qualified[normalize_ref(rtbl.ref)] and left.schema == right.schema or left.schema not in (right.schema, "public")
):
join = left.schema + "." + join
- prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + (
- 0 if (left.schema, left.tbl) in other_tbls else 1
- )
+ prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + (0 if (left.schema, left.tbl) in other_tbls else 1)
joins.append(Candidate(join, prio, "join", synonyms=synonyms))
return self.find_matches(word_before_cursor, joins, meta="join")
@@ -701,9 +630,7 @@ def list_dict(pairs): # Turns [(a, b), (a, c)] into {a: [b, c]}
# Tables that are closer to the cursor get higher prio
ref_prio = {tbl.ref: num for num, tbl in enumerate(suggestion.table_refs)}
# Map (schema, table, col) to tables
- coldict = list_dict(
- ((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref
- )
+ coldict = list_dict(((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref)
# For each fk from the left table, generate a join condition if
# the other table is also in the scope
fks = ((fk, lcol.name) for lcol in lcols for fk in lcol.foreignkeys)
@@ -734,24 +661,16 @@ def filt(f):
not f.is_aggregate
and not f.is_window
and not f.is_extension
- and (
- f.is_public
- or f.schema_name in self.search_path
- or f.schema_name == suggestion.schema
- )
+ and (f.is_public or f.schema_name in self.search_path or f.schema_name == suggestion.schema)
)
else:
alias = False
def filt(f):
- return not f.is_extension and (
- f.is_public or f.schema_name == suggestion.schema
- )
+ return not f.is_extension and (f.is_public or f.schema_name == suggestion.schema)
- arg_mode = {"signature": "signature", "special": None}.get(
- suggestion.usage, "call"
- )
+ arg_mode = {"signature": "signature", "special": None}.get(suggestion.usage, "call")
# Function overloading means we way have multiple functions of the same
# name at this point, so keep unique names only
@@ -762,9 +681,7 @@ def filt(f):
if not suggestion.schema and not suggestion.usage:
# also suggest hardcoded functions using startswith matching
- predefined_funcs = self.find_matches(
- word_before_cursor, self.functions, mode="strict", meta="function"
- )
+ predefined_funcs = self.find_matches(word_before_cursor, self.functions, mode="strict", meta="function")
matches.extend(predefined_funcs)
return matches
@@ -815,10 +732,7 @@ def _arg_list(self, func, usage):
return "()"
multiline = usage == "call" and len(args) > self.call_arg_oneliner_max
max_arg_len = max(len(a.name) for a in args) if multiline else 0
- args = (
- self._format_arg(template, arg, arg_num + 1, max_arg_len)
- for arg_num, arg in enumerate(args)
- )
+ args = (self._format_arg(template, arg, arg_num + 1, max_arg_len) for arg_num, arg in enumerate(args))
if multiline:
return "(" + ",".join("\n " + a for a in args if a) + "\n)"
else:
@@ -917,15 +831,11 @@ def get_keyword_matches(self, suggestion, word_before_cursor):
else:
keywords = [k.lower() for k in keywords]
- return self.find_matches(
- word_before_cursor, keywords, mode="strict", meta="keyword"
- )
+ return self.find_matches(word_before_cursor, keywords, mode="strict", meta="keyword")
def get_path_matches(self, _, word_before_cursor):
completer = PathCompleter(expanduser=True)
- document = Document(
- text=word_before_cursor, cursor_position=len(word_before_cursor)
- )
+ document = Document(text=word_before_cursor, cursor_position=len(word_before_cursor))
for c in completer.get_completions(document, None):
yield Match(completion=c, priority=(0,))
@@ -946,18 +856,12 @@ def get_datatype_matches(self, suggestion, word_before_cursor):
if not suggestion.schema:
# Also suggest hardcoded types
- matches.extend(
- self.find_matches(
- word_before_cursor, self.datatypes, mode="strict", meta="datatype"
- )
- )
+ matches.extend(self.find_matches(word_before_cursor, self.datatypes, mode="strict", meta="datatype"))
return matches
def get_namedquery_matches(self, _, word_before_cursor):
- return self.find_matches(
- word_before_cursor, NamedQueries.instance.list(), meta="named query"
- )
+ return self.find_matches(word_before_cursor, NamedQueries.instance.list(), meta="named query")
suggestion_matchers = {
FromClauseItem: get_from_clause_item_matches,
@@ -1047,9 +951,7 @@ def populate_schema_objects(self, schema, obj_type):
"""
return [
- SchemaObject(
- name=obj, schema=(self._maybe_schema(schema=sch, parent=schema))
- )
+ SchemaObject(name=obj, schema=(self._maybe_schema(schema=sch, parent=schema)))
for sch in self._get_schemas(obj_type, schema)
for obj in self.dbmetadata[obj_type][sch].keys()
]
diff --git a/pyproject.toml b/pyproject.toml
index 04087114d..f8458291a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -51,11 +51,25 @@ pgcli = "pgcli.main:cli"
[project.optional-dependencies]
keyring = ["keyring >= 12.2.0"]
sshtunnel = ["sshtunnel >= 0.4.0"]
+dev = [
+ "behave>=1.2.4",
+ "coverage>=7.2.7",
+ "docutils>=0.13.1",
+ "keyrings.alt>=3.1",
+ "pexpect>=4.9.0; platform_system != 'Windows'",
+ "pytest>=7.4.4",
+ "pytest-cov>=4.1.0",
+ "ruff>=0.11.7",
+ "sshtunnel>=0.4.0",
+ "tox>=1.9.2",
+]
[build-system]
-requires = ["setuptools>=61.2"]
+requires = ["setuptools>=64.0", "setuptools-scm>=8"]
build-backend = "setuptools.build_meta"
+[tool.setuptools_scm]
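+# setuptools-scm derives the package version from git tags at build time.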
+
[tool.setuptools]
include-package-data = false
@@ -68,24 +82,55 @@ find = { namespaces = false }
[tool.setuptools.package-data]
pgcli = ["pgclirc", "packages/pgliterals/pgliterals.json"]
-[tool.black]
-line-length = 88
-target-version = ['py38']
-include = '\.pyi?$'
-exclude = '''
-/(
- \.eggs
- | \.git
- | \.hg
- | \.mypy_cache
- | \.tox
- | \.venv
- | \.cache
- | \.pytest_cache
- | _build
- | buck-out
- | build
- | dist
- | tests/data
-)/
-'''
+[tool.ruff]
+target-version = 'py39'
+line-length = 140
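+# Run the linter with `ruff check .`; the formatter (`ruff format`) is not yet
+# enabled repo-wide (see the TODO in .pre-commit-config.yaml).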
+
+[tool.ruff.lint]
+select = [
+ 'A',
+# 'I', # TODO: enable import sorting
+ 'E',
+ 'W',
+ 'F',
+ 'C4',
+ 'PIE',
+ 'TID',
+]
+ignore = [
+ 'E401', # Multiple imports on one line
+ 'E402', # Module level import not at top of file
+ 'PIE808', # range() starting with 0
+ # https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules
+ 'E111', # indentation-with-invalid-multiple
+ 'E114', # indentation-with-invalid-multiple-comment
+ 'E117', # over-indented
+ 'W191', # tab-indentation
+ 'E741', # ambiguous-variable-name
+ # TODO: enable once duplicate Enum values are cleaned up
+ 'PIE796', # Enum contains duplicate value
+]
+exclude = [
+ 'pgcli/magic.py',
+ 'pgcli/pyev.py',
+]
+
+[tool.ruff.lint.isort]
+force-sort-within-sections = true
+known-first-party = [
+ 'pgcli',
+ 'tests',
+ 'steps',
+]
+
+[tool.ruff.format]
+preview = true
+quote-style = 'preserve'
+exclude = [
+ 'build',
+]
+
+[tool.pytest.ini_options]
+minversion = "6.0"
+addopts = "--capture=sys --showlocals -rxs"
+testpaths = ["tests"]
\ No newline at end of file
diff --git a/requirements-dev.txt b/requirements-dev.txt
deleted file mode 100644
index 8a0141a75..000000000
--- a/requirements-dev.txt
+++ /dev/null
@@ -1,14 +0,0 @@
-pytest>=2.7.0
-tox>=1.9.2
-behave>=1.2.4
-black>=23.3.0
-pexpect==3.3; platform_system != "Windows"
-pre-commit>=1.16.0
-coverage>=5.0.4
-codecov>=1.5.1
-docutils>=0.13.1
-autopep8>=1.3.3
-twine>=1.11.0
-wheel>=0.33.6
-sshtunnel>=0.4.0
-build<0.10.0
\ No newline at end of file
diff --git a/tests/features/db_utils.py b/tests/features/db_utils.py
index 595c6c2c3..db7f017f1 100644
--- a/tests/features/db_utils.py
+++ b/tests/features/db_utils.py
@@ -1,9 +1,7 @@
from psycopg import connect
-def create_db(
- hostname="localhost", username=None, password=None, dbname=None, port=None
-):
+def create_db(hostname="localhost", username=None, password=None, dbname=None, port=None):
"""Create test database.
:param hostname: string
@@ -36,9 +34,7 @@ def create_cn(hostname, password, username, dbname, port):
:param dbname: string
     :return: psycopg.Connection
"""
- cn = connect(
- host=hostname, user=username, dbname=dbname, password=password, port=port
- )
+ cn = connect(host=hostname, user=username, dbname=dbname, password=password, port=port)
print(f"Created connection: {cn.info.get_parameters()}.")
return cn
@@ -49,7 +45,7 @@ def pgbouncer_available(hostname="localhost", password=None, username="postgres"
try:
cn = create_cn(hostname, password, username, "pgbouncer", 6432)
return True
- except:
+ except Exception:
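+        # catch ordinary errors only; a bare except (ruff E722) would also
+        # swallow SystemExit and KeyboardInterrupt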
print("Pgbouncer is not available.")
finally:
if cn:
diff --git a/tests/features/environment.py b/tests/features/environment.py
index 50ac5faf0..a6cde7021 100644
--- a/tests/features/environment.py
+++ b/tests/features/environment.py
@@ -1,13 +1,13 @@
import copy
import os
+import shutil
+import signal
import sys
+import tempfile
+
import db_utils as dbutils
import fixture_utils as fixutils
import pexpect
-import tempfile
-import shutil
-import signal
-
from steps import wrappers
@@ -22,17 +22,13 @@ def before_all(context):
os.environ["VISUAL"] = "ex"
os.environ["PROMPT_TOOLKIT_NO_CPR"] = "1"
- context.package_root = os.path.abspath(
- os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
- )
+ context.package_root = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
fixture_dir = os.path.join(context.package_root, "tests/features/fixture_data")
print("package root:", context.package_root)
print("fixture dir:", fixture_dir)
- os.environ["COVERAGE_PROCESS_START"] = os.path.join(
- context.package_root, ".coveragerc"
- )
+ os.environ["COVERAGE_PROCESS_START"] = os.path.join(context.package_root, ".coveragerc")
context.exit_sent = False
@@ -42,30 +38,20 @@ def before_all(context):
# Store get params from config.
context.conf = {
- "host": context.config.userdata.get(
- "pg_test_host", os.getenv("PGHOST", "localhost")
- ),
- "user": context.config.userdata.get(
- "pg_test_user", os.getenv("PGUSER", "postgres")
- ),
- "pass": context.config.userdata.get(
- "pg_test_pass", os.getenv("PGPASSWORD", None)
- ),
- "port": context.config.userdata.get(
- "pg_test_port", os.getenv("PGPORT", "5432")
- ),
+ "host": context.config.userdata.get("pg_test_host", os.getenv("PGHOST", "localhost")),
+ "user": context.config.userdata.get("pg_test_user", os.getenv("PGUSER", "postgres")),
+ "pass": context.config.userdata.get("pg_test_pass", os.getenv("PGPASSWORD", None)),
+ "port": context.config.userdata.get("pg_test_port", os.getenv("PGPORT", "5432")),
"cli_command": (
context.config.userdata.get("pg_cli_command", None)
or '{python} -c "{startup}"'.format(
python=sys.executable,
- startup="; ".join(
- [
- "import coverage",
- "coverage.process_startup()",
- "import pgcli.main",
- "pgcli.main.cli(auto_envvar_prefix='BEHAVE')",
- ]
- ),
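+                # start coverage in the subprocess before importing pgcli,
+                # so the CLI run under behave is measured too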
+ startup="; ".join([
+ "import coverage",
+ "coverage.process_startup()",
+ "import pgcli.main",
+ "pgcli.main.cli(auto_envvar_prefix='BEHAVE')",
+ ]),
)
),
"dbname": db_name_full,
@@ -165,15 +151,16 @@ def before_step(context, _):
def is_known_problem(scenario):
- """TODO: why is this not working in 3.12?"""
- if sys.version_info >= (3, 12):
- return scenario.name in (
- 'interrupt current query via "ctrl + c"',
- "run the cli with --username",
- "run the cli with --user",
- "run the cli with --port",
- )
- return False
+ """TODO: can we fix this?"""
+ return scenario.name in (
+ 'interrupt current query via "ctrl + c"',
+ "run the cli with --username",
+ "run the cli with --user",
+ "run the cli with --port",
+ "confirm exit when a transaction is ongoing",
+ "cancel exit when a transaction is ongoing",
+ "run the cli and exit",
+ )
def before_scenario(context, scenario):
diff --git a/tests/features/steps/wrappers.py b/tests/features/steps/wrappers.py
index 3ebcc92c4..2162b8b96 100644
--- a/tests/features/steps/wrappers.py
+++ b/tests/features/steps/wrappers.py
@@ -1,6 +1,5 @@
import re
import pexpect
-from pgcli.main import COLOR_CODE_REGEX
import textwrap
from io import StringIO
@@ -37,10 +36,7 @@ def expect_exact(context, expected, timeout):
def expect_pager(context, expected, timeout):
formatted = expected if isinstance(expected, list) else [expected]
- formatted = [
- f"{context.conf['pager_boundary']}\r\n{t}{context.conf['pager_boundary']}\r\n"
- for t in formatted
- ]
+ formatted = [f"{context.conf['pager_boundary']}\r\n{t}{context.conf['pager_boundary']}\r\n" for t in formatted]
expect_exact(
context,
diff --git a/tests/formatter/test_sqlformatter.py b/tests/formatter/test_sqlformatter.py
index 016ed956b..78ff5dc95 100644
--- a/tests/formatter/test_sqlformatter.py
+++ b/tests/formatter/test_sqlformatter.py
@@ -55,7 +55,7 @@ def test_output_sql_insert():
}
formatter.query = 'SELECT * FROM "user";'
output = adapter(data, header, table_format=table_format, **kwargs)
- output_list = [l for l in output]
+ output_list = list(output)
expected = [
'INSERT INTO "user" ("id", "name", "email", "phone", "description", "created_at", "updated_at") VALUES',
" ('1', 'Jackson', 'jackson_test@gmail.com', '132454789', NULL, "
@@ -96,7 +96,7 @@ def test_output_sql_update():
}
formatter.query = 'SELECT * FROM "user";'
output = adapter(data, header, table_format=table_format, **kwargs)
- output_list = [l for l in output]
+ output_list = list(output)
print(output_list)
expected = [
'UPDATE "user" SET',
diff --git a/tests/parseutils/test_parseutils.py b/tests/parseutils/test_parseutils.py
index 349cbd021..8a62cdbc5 100644
--- a/tests/parseutils/test_parseutils.py
+++ b/tests/parseutils/test_parseutils.py
@@ -172,7 +172,7 @@ def test_subselect_tables():
@pytest.mark.parametrize("text", ["SELECT * FROM foo.", "SELECT 123 AS foo"])
def test_extract_no_tables(text):
tables = extract_tables(text)
- assert tables == tuple()
+ assert tables == ()
@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
diff --git a/tests/test_main.py b/tests/test_main.py
index 102ebcd3e..b893d2c9d 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -61,9 +61,7 @@ def test_obfuscate_process_password():
def test_format_output():
settings = OutputSettings(table_format="psql", dcmlfmt="d", floatfmt="g")
- results = format_output(
- "Title", [("abc", "def")], ["head1", "head2"], "test status", settings
- )
+ results = format_output("Title", [("abc", "def")], ["head1", "head2"], "test status", settings)
expected = [
"Title",
"+-------+-------+",
@@ -128,9 +126,7 @@ def test_no_column_date_formats():
def test_format_output_truncate_on():
- settings = OutputSettings(
- table_format="psql", dcmlfmt="d", floatfmt="g", max_field_width=10
- )
+ settings = OutputSettings(table_format="psql", dcmlfmt="d", floatfmt="g", max_field_width=10)
results = format_output(
None,
[("first field value", "second field value")],
@@ -149,9 +145,7 @@ def test_format_output_truncate_on():
def test_format_output_truncate_off():
- settings = OutputSettings(
- table_format="psql", dcmlfmt="d", floatfmt="g", max_field_width=None
- )
+ settings = OutputSettings(table_format="psql", dcmlfmt="d", floatfmt="g", max_field_width=None)
long_field_value = ("first field " * 100).strip()
results = format_output(None, [(long_field_value,)], ["head1"], None, settings)
lines = list(results)
@@ -207,12 +201,8 @@ def test_format_array_output_expanded(executor):
def test_format_output_auto_expand():
- settings = OutputSettings(
- table_format="psql", dcmlfmt="d", floatfmt="g", max_width=100
- )
- table_results = format_output(
- "Title", [("abc", "def")], ["head1", "head2"], "test status", settings
- )
+ settings = OutputSettings(table_format="psql", dcmlfmt="d", floatfmt="g", max_width=100)
+ table_results = format_output("Title", [("abc", "def")], ["head1", "head2"], "test status", settings)
table = [
"Title",
"+-------+-------+",
@@ -269,18 +259,18 @@ def test_format_output_auto_expand():
def pset_pager_mocks():
cli = PGCli()
cli.watch_command = None
- with mock.patch("pgcli.main.click.echo") as mock_echo, mock.patch(
- "pgcli.main.click.echo_via_pager"
- ) as mock_echo_via_pager, mock.patch.object(cli, "prompt_app") as mock_app:
+ with (
+ mock.patch("pgcli.main.click.echo") as mock_echo,
+ mock.patch("pgcli.main.click.echo_via_pager") as mock_echo_via_pager,
+ mock.patch.object(cli, "prompt_app") as mock_app,
+ ):
yield cli, mock_echo, mock_echo_via_pager, mock_app
@pytest.mark.parametrize("term_height,term_width,text", test_data, ids=test_ids)
def test_pset_pager_off(term_height, term_width, text, pset_pager_mocks):
cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks
- mock_cli.output.get_size.return_value = termsize(
- rows=term_height, columns=term_width
- )
+ mock_cli.output.get_size.return_value = termsize(rows=term_height, columns=term_width)
with mock.patch.object(cli.pgspecial, "pager_config", PAGER_OFF):
cli.echo_via_pager(text)
@@ -292,9 +282,7 @@ def test_pset_pager_off(term_height, term_width, text, pset_pager_mocks):
@pytest.mark.parametrize("term_height,term_width,text", test_data, ids=test_ids)
def test_pset_pager_always(term_height, term_width, text, pset_pager_mocks):
cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks
- mock_cli.output.get_size.return_value = termsize(
- rows=term_height, columns=term_width
- )
+ mock_cli.output.get_size.return_value = termsize(rows=term_height, columns=term_width)
with mock.patch.object(cli.pgspecial, "pager_config", PAGER_ALWAYS):
cli.echo_via_pager(text)
@@ -306,14 +294,10 @@ def test_pset_pager_always(term_height, term_width, text, pset_pager_mocks):
pager_on_test_data = [l + (r,) for l, r in zip(test_data, use_pager_when_on)]
-@pytest.mark.parametrize(
- "term_height,term_width,text,use_pager", pager_on_test_data, ids=test_ids
-)
+@pytest.mark.parametrize("term_height,term_width,text,use_pager", pager_on_test_data, ids=test_ids)
def test_pset_pager_on(term_height, term_width, text, use_pager, pset_pager_mocks):
cli, mock_echo, mock_echo_via_pager, mock_cli = pset_pager_mocks
- mock_cli.output.get_size.return_value = termsize(
- rows=term_height, columns=term_width
- )
+ mock_cli.output.get_size.return_value = termsize(rows=term_height, columns=term_width)
with mock.patch.object(cli.pgspecial, "pager_config", PAGER_LONG_OUTPUT):
cli.echo_via_pager(text)
@@ -330,15 +314,14 @@ def test_pset_pager_on(term_height, term_width, text, use_pager, pset_pager_mock
"text,expected_length",
[
(
- "22200K .......\u001b[0m\u001b[91m... .......... ...\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... .........\u001b[0m\u001b[91m.\u001b[0m\u001b[91m \u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... 50% 28.6K 12m55s",
+ "22200K .......\u001b[0m\u001b[91m... .......... ...\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... .........\u001b[0m\u001b[91m.\u001b[0m\u001b[91m \u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m.\u001b[0m\u001b[91m...... 50% 28.6K 12m55s", # noqa: E501
78,
),
("=\u001b[m=", 2),
("-\u001b]23\u0007-", 2),
],
)
-def test_color_pattern(text, expected_length, pset_pager_mocks):
- cli = pset_pager_mocks[0]
+def test_color_pattern(text, expected_length):
assert len(COLOR_CODE_REGEX.sub("", text)) == expected_length
@@ -405,34 +388,24 @@ def test_logfile_unwriteable_file(executor):
cli = PGCli(pgexecute=executor)
statement = r"\log-file forbidden.log"
with mock.patch("builtins.open") as mock_open:
- mock_open.side_effect = PermissionError(
- "[Errno 13] Permission denied: 'forbidden.log'"
- )
+ mock_open.side_effect = PermissionError("[Errno 13] Permission denied: 'forbidden.log'")
result = run(executor, statement, pgspecial=cli.pgspecial)
- assert result == [
- "[Errno 13] Permission denied: 'forbidden.log'\nLogfile capture disabled"
- ]
+ assert result == ["[Errno 13] Permission denied: 'forbidden.log'\nLogfile capture disabled"]
@dbtest
def test_watch_works(executor):
cli = PGCli(pgexecute=executor)
- def run_with_watch(
- query, target_call_count=1, expected_output="", expected_timing=None
- ):
+ def run_with_watch(query, target_call_count=1, expected_output="", expected_timing=None):
"""
:param query: Input to the CLI
:param target_call_count: Number of times the user lets the command run before Ctrl-C
:param expected_output: Substring expected to be found for each executed query
:param expected_timing: value `time.sleep` expected to be called with on every invocation
"""
- with mock.patch.object(cli, "echo_via_pager") as mock_echo, mock.patch(
- "pgcli.main.sleep"
- ) as mock_sleep:
- mock_sleep.side_effect = [None] * (target_call_count - 1) + [
- KeyboardInterrupt
- ]
+ with mock.patch.object(cli, "echo_via_pager") as mock_echo, mock.patch("pgcli.main.sleep") as mock_sleep:
+ mock_sleep.side_effect = [None] * (target_call_count - 1) + [KeyboardInterrupt]
cli.handle_watch_command(query)
# Validate that sleep was called with the right timing
for i in range(target_call_count - 1):
@@ -446,16 +419,11 @@ def run_with_watch(
with mock.patch("pgcli.main.click.secho") as mock_secho:
cli.handle_watch_command(r"\watch 2")
mock_secho.assert_called()
- assert (
- r"\watch cannot be used with an empty query"
- in mock_secho.call_args_list[0][0][0]
- )
+ assert r"\watch cannot be used with an empty query" in mock_secho.call_args_list[0][0][0]
# Usage 1: Run a query and then re-run it with \watch across two prompts.
run_with_watch("SELECT 111", expected_output="111")
- run_with_watch(
- "\\watch 10", target_call_count=2, expected_output="111", expected_timing=10
- )
+ run_with_watch("\\watch 10", target_call_count=2, expected_output="111", expected_timing=10)
# Usage 2: Run a query and \watch via the same prompt.
run_with_watch(
@@ -466,9 +434,7 @@ def run_with_watch(
)
# Usage 3: Re-run the last watched command with a new timing
- run_with_watch(
- "\\watch 5", target_call_count=4, expected_output="222", expected_timing=5
- )
+ run_with_watch("\\watch 5", target_call_count=4, expected_output="222", expected_timing=5)
def test_missing_rc_dir(tmpdir):
@@ -482,9 +448,7 @@ def test_quoted_db_uri(tmpdir):
with mock.patch.object(PGCli, "connect") as mock_connect:
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
cli.connect_uri("postgres://bar%5E:%5Dfoo@baz.com/testdb%5B")
- mock_connect.assert_called_with(
- database="testdb[", host="baz.com", user="bar^", passwd="]foo"
- )
+ mock_connect.assert_called_with(database="testdb[", host="baz.com", user="bar^", passwd="]foo")
def test_pg_service_file(tmpdir):
@@ -544,8 +508,7 @@ def test_ssl_db_uri(tmpdir):
with mock.patch.object(PGCli, "connect") as mock_connect:
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
cli.connect_uri(
- "postgres://bar%5E:%5Dfoo@baz.com/testdb%5B?"
- "sslmode=verify-full&sslcert=m%79.pem&sslkey=my-key.pem&sslrootcert=c%61.pem"
+ "postgres://bar%5E:%5Dfoo@baz.com/testdb%5B?sslmode=verify-full&sslcert=m%79.pem&sslkey=my-key.pem&sslrootcert=c%61.pem"
)
mock_connect.assert_called_with(
database="testdb[",
@@ -563,17 +526,13 @@ def test_port_db_uri(tmpdir):
with mock.patch.object(PGCli, "connect") as mock_connect:
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
cli.connect_uri("postgres://bar:foo@baz.com:2543/testdb")
- mock_connect.assert_called_with(
- database="testdb", host="baz.com", user="bar", passwd="foo", port="2543"
- )
+ mock_connect.assert_called_with(database="testdb", host="baz.com", user="bar", passwd="foo", port="2543")
def test_multihost_db_uri(tmpdir):
with mock.patch.object(PGCli, "connect") as mock_connect:
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
- cli.connect_uri(
- "postgres://bar:foo@baz1.com:2543,baz2.com:2543,baz3.com:2543/testdb"
- )
+ cli.connect_uri("postgres://bar:foo@baz1.com:2543,baz2.com:2543,baz3.com:2543/testdb")
mock_connect.assert_called_with(
database="testdb",
host="baz1.com,baz2.com,baz3.com",
@@ -588,9 +547,7 @@ def test_application_name_db_uri(tmpdir):
mock_pgexecute.return_value = None
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
cli.connect_uri("postgres://bar@baz.com/?application_name=cow")
- mock_pgexecute.assert_called_with(
- "bar", "bar", "", "baz.com", "", "", notify_callback, application_name="cow"
- )
+ mock_pgexecute.assert_called_with("bar", "bar", "", "baz.com", "", "", notify_callback, application_name="cow")
@pytest.mark.parametrize(
diff --git a/tests/test_pgexecute.py b/tests/test_pgexecute.py
index f1cadfd68..2b8e87cc0 100644
--- a/tests/test_pgexecute.py
+++ b/tests/test_pgexecute.py
@@ -90,8 +90,8 @@ def test_expanded_slash_G(executor, pgspecial):
# Tests whether we reset the expanded output after a \G.
run(executor, """create table test(a boolean)""")
run(executor, """insert into test values(True)""")
- results = run(executor, r"""select * from test \G""", pgspecial=pgspecial)
- assert pgspecial.expanded_output == False
+ run(executor, r"""select * from test \G""", pgspecial=pgspecial)
+ assert pgspecial.expanded_output is False
@dbtest
@@ -132,9 +132,7 @@ def test_schemata_table_views_and_columns_query(executor):
# views
assert set(executor.views()) >= {("public", "d")}
- assert set(executor.view_columns()) >= {
- ("public", "d", "e", "integer", False, None)
- }
+ assert set(executor.view_columns()) >= {("public", "d", "e", "integer", False, None)}
@dbtest
@@ -147,9 +145,7 @@ def test_foreign_key_query(executor):
"create table schema2.child(childid int PRIMARY KEY, motherid int REFERENCES schema1.parent)",
)
- assert set(executor.foreignkeys()) >= {
- ("schema1", "parent", "parentid", "schema2", "child", "motherid")
- }
+ assert set(executor.foreignkeys()) >= {("schema1", "parent", "parentid", "schema2", "child", "motherid")}
@dbtest
@@ -198,9 +194,7 @@ def test_functions_query(executor):
return_type="integer",
is_set_returning=True,
),
- function_meta_data(
- schema_name="schema1", func_name="func2", return_type="integer"
- ),
+ function_meta_data(schema_name="schema1", func_name="func2", return_type="integer"),
}
@@ -251,9 +245,7 @@ def test_invalid_syntax_verbose(executor):
@dbtest
def test_invalid_column_name(executor, exception_formatter):
- result = run(
- executor, "select invalid command", exception_formatter=exception_formatter
- )
+ result = run(executor, "select invalid command", exception_formatter=exception_formatter)
assert 'column "invalid" does not exist' in result[0]
@@ -268,9 +260,7 @@ def test_unicode_support_in_output(executor, expanded):
run(executor, "insert into unicodechars (t) values ('é')")
# See issue #24, this raises an exception without proper handling
- assert "é" in run(
- executor, "select * from unicodechars", join=True, expanded=expanded
- )
+ assert "é" in run(executor, "select * from unicodechars", join=True, expanded=expanded)
@dbtest
@@ -279,8 +269,8 @@ def test_not_is_special(executor, pgspecial):
query = "select 1"
result = list(executor.run(query, pgspecial=pgspecial))
success, is_special = result[0][5:]
- assert success == True
- assert is_special == False
+ assert success is True
+ assert is_special is False
@dbtest
@@ -289,8 +279,8 @@ def test_execute_from_file_no_arg(executor, pgspecial):
result = list(executor.run(r"\i", pgspecial=pgspecial))
status, sql, success, is_special = result[0][3:]
assert "missing required argument" in status
- assert success == False
- assert is_special == True
+ assert success is False
+ assert is_special is True
@dbtest
@@ -304,14 +294,12 @@ def test_execute_from_file_io_error(os, executor, pgspecial):
result = list(executor.run(r"\i test", pgspecial=pgspecial))
status, sql, success, is_special = result[0][3:]
assert status == "test"
- assert success == False
- assert is_special == True
+ assert success is False
+ assert is_special is True
@dbtest
-def test_execute_from_commented_file_that_executes_another_file(
- executor, pgspecial, tmpdir
-):
+def test_execute_from_commented_file_that_executes_another_file(executor, pgspecial, tmpdir):
# https://github.com/dbcli/pgcli/issues/1336
sqlfile1 = tmpdir.join("test01.sql")
sqlfile1.write("-- asdf \n\\h")
@@ -321,10 +309,10 @@ def test_execute_from_commented_file_that_executes_another_file(
rcfile = str(tmpdir.join("rcfile"))
print(rcfile)
cli = PGCli(pgexecute=executor, pgclirc_file=rcfile)
- assert cli != None
+ assert cli is not None
statement = "--comment\n\\h"
result = run(executor, statement, pgspecial=cli.pgspecial)
- assert result != None
+ assert result is not None
assert result[0].find("ALTER TABLE")
@@ -333,38 +321,38 @@ def test_execute_commented_first_line_and_special(executor, pgspecial, tmpdir):
# just some base cases that should work also
statement = "--comment\nselect now();"
result = run(executor, statement, pgspecial=pgspecial)
- assert result != None
+ assert result is not None
assert result[1].find("now") >= 0
statement = "/*comment*/\nselect now();"
result = run(executor, statement, pgspecial=pgspecial)
- assert result != None
+ assert result is not None
assert result[1].find("now") >= 0
# https://github.com/dbcli/pgcli/issues/1362
statement = "--comment\n\\h"
result = run(executor, statement, pgspecial=pgspecial)
- assert result != None
+ assert result is not None
assert result[1].find("ALTER") >= 0
assert result[1].find("ABORT") >= 0
statement = "--comment1\n--comment2\n\\h"
result = run(executor, statement, pgspecial=pgspecial)
- assert result != None
+ assert result is not None
assert result[1].find("ALTER") >= 0
assert result[1].find("ABORT") >= 0
- statement = "/*comment*/\n\h;"
+ statement = "/*comment*/\n\\h;"
result = run(executor, statement, pgspecial=pgspecial)
- assert result != None
+ assert result is not None
assert result[1].find("ALTER") >= 0
assert result[1].find("ABORT") >= 0
- statement = """/*comment1
+ statement = r"""/*comment1
comment2*/
\h"""
result = run(executor, statement, pgspecial=pgspecial)
- assert result != None
+ assert result is not None
assert result[1].find("ALTER") >= 0
assert result[1].find("ABORT") >= 0
@@ -374,43 +362,43 @@ def test_execute_commented_first_line_and_special(executor, pgspecial, tmpdir):
comment4*/
\\h"""
result = run(executor, statement, pgspecial=pgspecial)
- assert result != None
+ assert result is not None
assert result[1].find("ALTER") >= 0
assert result[1].find("ABORT") >= 0
- statement = " /*comment*/\n\h;"
+ statement = " /*comment*/\n\\h;"
result = run(executor, statement, pgspecial=pgspecial)
- assert result != None
+ assert result is not None
assert result[1].find("ALTER") >= 0
assert result[1].find("ABORT") >= 0
- statement = "/*comment\ncomment line2*/\n\h;"
+ statement = "/*comment\ncomment line2*/\n\\h;"
result = run(executor, statement, pgspecial=pgspecial)
- assert result != None
+ assert result is not None
assert result[1].find("ALTER") >= 0
assert result[1].find("ABORT") >= 0
- statement = " /*comment\ncomment line2*/\n\h;"
+ statement = " /*comment\ncomment line2*/\n\\h;"
result = run(executor, statement, pgspecial=pgspecial)
- assert result != None
+ assert result is not None
assert result[1].find("ALTER") >= 0
assert result[1].find("ABORT") >= 0
statement = """\\h /*comment4 */"""
result = run(executor, statement, pgspecial=pgspecial)
print(result)
- assert result != None
+ assert result is not None
assert result[0].find("No help") >= 0
# TODO: we probably don't want to do this but sqlparse is not parsing things well
     # we really want it to find help, but right now sqlparse isn't dropping the /*comment*/
     # style comments after the command
- statement = """/*comment1*/
+ statement = r"""/*comment1*/
\h
/*comment4 */"""
result = run(executor, statement, pgspecial=pgspecial)
- assert result != None
+ assert result is not None
assert result[0].find("No help") >= 0
# TODO: same for this one
@@ -422,7 +410,7 @@ def test_execute_commented_first_line_and_special(executor, pgspecial, tmpdir):
comment5
comment6*/"""
result = run(executor, statement, pgspecial=pgspecial)
- assert result != None
+ assert result is not None
assert result[0].find("No help") >= 0
@@ -433,12 +421,12 @@ def test_execute_commented_first_line_and_normal(executor, pgspecial, tmpdir):
# just some base cases that should work also
statement = "--comment\nselect now();"
result = run(executor, statement, pgspecial=pgspecial)
- assert result != None
+ assert result is not None
assert result[1].find("now") >= 0
statement = "/*comment*/\nselect now();"
result = run(executor, statement, pgspecial=pgspecial)
- assert result != None
+ assert result is not None
assert result[1].find("now") >= 0
# this simulates the original error (1403) without having to add/drop tables
@@ -448,26 +436,26 @@ def test_execute_commented_first_line_and_normal(executor, pgspecial, tmpdir):
# test that the statement works
statement = """VALUES (1, 'one'), (2, 'two'), (3, 'three');"""
result = run(executor, statement, pgspecial=pgspecial)
- assert result != None
+ assert result is not None
assert result[5].find("three") >= 0
# test the statement with a \n in the middle
statement = """VALUES (1, 'one'),\n (2, 'two'), (3, 'three');"""
result = run(executor, statement, pgspecial=pgspecial)
- assert result != None
+ assert result is not None
assert result[5].find("three") >= 0
# test the statement with a newline in the middle
statement = """VALUES (1, 'one'),
(2, 'two'), (3, 'three');"""
result = run(executor, statement, pgspecial=pgspecial)
- assert result != None
+ assert result is not None
assert result[5].find("three") >= 0
# now add a single comment line
statement = """--comment\nVALUES (1, 'one'), (2, 'two'), (3, 'three');"""
result = run(executor, statement, pgspecial=pgspecial)
- assert result != None
+ assert result is not None
assert result[5].find("three") >= 0
# doing without special char \n
@@ -475,13 +463,13 @@ def test_execute_commented_first_line_and_normal(executor, pgspecial, tmpdir):
VALUES (1,'one'),
(2, 'two'), (3, 'three');"""
result = run(executor, statement, pgspecial=pgspecial)
- assert result != None
+ assert result is not None
assert result[5].find("three") >= 0
# two comment lines
statement = """--comment\n--comment2\nVALUES (1,'one'), (2, 'two'), (3, 'three');"""
result = run(executor, statement, pgspecial=pgspecial)
- assert result != None
+ assert result is not None
assert result[5].find("three") >= 0
# doing without special char \n
@@ -490,7 +478,7 @@ def test_execute_commented_first_line_and_normal(executor, pgspecial, tmpdir):
VALUES (1,'one'), (2, 'two'), (3, 'three');
"""
result = run(executor, statement, pgspecial=pgspecial)
- assert result != None
+ assert result is not None
assert result[5].find("three") >= 0
# multiline comment + newline in middle of the statement
@@ -500,7 +488,7 @@ def test_execute_commented_first_line_and_normal(executor, pgspecial, tmpdir):
VALUES (1,'one'),
(2, 'two'), (3, 'three');"""
result = run(executor, statement, pgspecial=pgspecial)
- assert result != None
+ assert result is not None
assert result[5].find("three") >= 0
# multiline comment + newline in middle of the statement
@@ -513,7 +501,7 @@ def test_execute_commented_first_line_and_normal(executor, pgspecial, tmpdir):
--comment4
--comment5"""
result = run(executor, statement, pgspecial=pgspecial)
- assert result != None
+ assert result is not None
assert result[5].find("three") >= 0
@@ -582,9 +570,7 @@ def test_unicode_support_in_enum_type(executor):
def test_json_renders_without_u_prefix(executor, expanded):
run(executor, "create table jsontest(d json)")
run(executor, """insert into jsontest (d) values ('{"name": "Éowyn"}')""")
- result = run(
- executor, "SELECT d FROM jsontest LIMIT 1", join=True, expanded=expanded
- )
+ result = run(executor, "SELECT d FROM jsontest LIMIT 1", join=True, expanded=expanded)
assert '{"name": "Éowyn"}' in result
@@ -593,9 +579,7 @@ def test_json_renders_without_u_prefix(executor, expanded):
def test_jsonb_renders_without_u_prefix(executor, expanded):
run(executor, "create table jsonbtest(d jsonb)")
run(executor, """insert into jsonbtest (d) values ('{"name": "Éowyn"}')""")
- result = run(
- executor, "SELECT d FROM jsonbtest LIMIT 1", join=True, expanded=expanded
- )
+ result = run(executor, "SELECT d FROM jsonbtest LIMIT 1", join=True, expanded=expanded)
assert '{"name": "Éowyn"}' in result
@@ -603,28 +587,10 @@ def test_jsonb_renders_without_u_prefix(executor, expanded):
@dbtest
def test_date_time_types(executor):
run(executor, "SET TIME ZONE UTC")
- assert (
- run(executor, "SELECT (CAST('00:00:00' AS time))", join=True).split("\n")[3]
- == "| 00:00:00 |"
- )
- assert (
- run(executor, "SELECT (CAST('00:00:00+14:59' AS timetz))", join=True).split(
- "\n"
- )[3]
- == "| 00:00:00+14:59 |"
- )
- assert (
- run(executor, "SELECT (CAST('4713-01-01 BC' AS date))", join=True).split("\n")[
- 3
- ]
- == "| 4713-01-01 BC |"
- )
- assert (
- run(
- executor, "SELECT (CAST('4713-01-01 00:00:00 BC' AS timestamp))", join=True
- ).split("\n")[3]
- == "| 4713-01-01 00:00:00 BC |"
- )
+ assert run(executor, "SELECT (CAST('00:00:00' AS time))", join=True).split("\n")[3] == "| 00:00:00 |"
+ assert run(executor, "SELECT (CAST('00:00:00+14:59' AS timetz))", join=True).split("\n")[3] == "| 00:00:00+14:59 |"
+ assert run(executor, "SELECT (CAST('4713-01-01 BC' AS date))", join=True).split("\n")[3] == "| 4713-01-01 BC |"
+ assert run(executor, "SELECT (CAST('4713-01-01 00:00:00 BC' AS timestamp))", join=True).split("\n")[3] == "| 4713-01-01 00:00:00 BC |"
assert (
run(
executor,
@@ -634,10 +600,7 @@ def test_date_time_types(executor):
== "| 4713-01-01 00:00:00+00 BC |"
)
assert (
- run(
- executor, "SELECT (CAST('-123456789 days 12:23:56' AS interval))", join=True
- ).split("\n")[3]
- == "| -123456789 days, 12:23:56 |"
+ run(executor, "SELECT (CAST('-123456789 days 12:23:56' AS interval))", join=True).split("\n")[3] == "| -123456789 days, 12:23:56 |"
)
@@ -670,20 +633,14 @@ def test_raises_with_no_formatter(executor, sql):
@dbtest
def test_on_error_resume(executor, exception_formatter):
sql = "select 1; error; select 1;"
- result = list(
- executor.run(sql, on_error_resume=True, exception_formatter=exception_formatter)
- )
+ result = list(executor.run(sql, on_error_resume=True, exception_formatter=exception_formatter))
assert len(result) == 3
@dbtest
def test_on_error_stop(executor, exception_formatter):
sql = "select 1; error; select 1;"
- result = list(
- executor.run(
- sql, on_error_resume=False, exception_formatter=exception_formatter
- )
- )
+ result = list(executor.run(sql, on_error_resume=False, exception_formatter=exception_formatter))
assert len(result) == 2
@@ -697,7 +654,7 @@ def test_on_error_stop(executor, exception_formatter):
@dbtest
def test_nonexistent_function_definition(executor):
with pytest.raises(RuntimeError):
- result = executor.view_definition("there_is_no_such_function")
+ executor.view_definition("there_is_no_such_function")
@dbtest
@@ -713,7 +670,7 @@ def test_function_definition(executor):
$function$
""",
)
- result = executor.function_definition("the_number_three")
+ executor.function_definition("the_number_three")
@dbtest
@@ -764,9 +721,9 @@ def test_view_definition(executor):
@dbtest
def test_nonexistent_view_definition(executor):
with pytest.raises(RuntimeError):
- result = executor.view_definition("there_is_no_such_view")
+ executor.view_definition("there_is_no_such_view")
with pytest.raises(RuntimeError):
- result = executor.view_definition("mvw1")
+ executor.view_definition("mvw1")
@dbtest
@@ -775,9 +732,7 @@ def test_short_host(executor):
assert executor.short_host == "localhost"
with patch.object(executor, "host", "localhost.example.org"):
assert executor.short_host == "localhost"
- with patch.object(
- executor, "host", "localhost1.example.org,localhost2.example.org"
- ):
+ with patch.object(executor, "host", "localhost1.example.org,localhost2.example.org"):
assert executor.short_host == "localhost1"
with patch.object(executor, "host", "ec2-11-222-333-444.compute-1.amazonaws.com"):
assert executor.short_host == "ec2-11-222-333-444"
@@ -814,9 +769,7 @@ def test_exit_without_active_connection(executor):
aliases=(":q",),
)
- with patch.object(
- executor.conn, "cursor", side_effect=psycopg.InterfaceError("I'm broken!")
- ):
+ with patch.object(executor.conn, "cursor", side_effect=psycopg.InterfaceError("I'm broken!")):
# we should be able to quit the app, even without active connection
run(executor, "\\q", pgspecial=pgspecial)
quit_handler.assert_called_once()
diff --git a/tests/test_smart_completion_multiple_schemata.py b/tests/test_smart_completion_multiple_schemata.py
index 5c9c9af48..98feb02db 100644
--- a/tests/test_smart_completion_multiple_schemata.py
+++ b/tests/test_smart_completion_multiple_schemata.py
@@ -11,7 +11,6 @@
wildcard_expansion,
column,
get_result,
- result_set,
qual,
no_qual,
parametrize,
@@ -125,9 +124,7 @@
@parametrize("table", ["users", '"users"'])
def test_suggested_column_names_from_shadowed_visible_table(completer, table):
result = get_result(completer, "SELECT FROM " + table, len("SELECT "))
- assert completions_to_set(result) == completions_to_set(
- testdata.columns_functions_and_keywords("users")
- )
+ assert completions_to_set(result) == completions_to_set(testdata.columns_functions_and_keywords("users"))
@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
@@ -140,18 +137,14 @@ def test_suggested_column_names_from_shadowed_visible_table(completer, table):
)
def test_suggested_column_names_from_qualified_shadowed_table(completer, text):
result = get_result(completer, text, position=text.find(" ") + 1)
- assert completions_to_set(result) == completions_to_set(
- testdata.columns_functions_and_keywords("users", "custom")
- )
+ assert completions_to_set(result) == completions_to_set(testdata.columns_functions_and_keywords("users", "custom"))
@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
@parametrize("text", ["WITH users as (SELECT 1 AS foo) SELECT from users"])
def test_suggested_column_names_from_cte(completer, text):
result = completions_to_set(get_result(completer, text, text.find(" ") + 1))
- assert result == completions_to_set(
- [column("foo")] + testdata.functions_and_keywords()
- )
+ assert result == completions_to_set([column("foo")] + testdata.functions_and_keywords())
@parametrize("completer", completers(casing=False))
@@ -166,14 +159,12 @@ def test_suggested_column_names_from_cte(completer, text):
)
def test_suggested_join_conditions(completer, text):
result = get_result(completer, text)
- assert completions_to_set(result) == completions_to_set(
- [
- alias("users"),
- alias("shipments"),
- name_join("shipments.id = users.id"),
- fk_join("shipments.user_id = users.id"),
- ]
- )
+ assert completions_to_set(result) == completions_to_set([
+ alias("users"),
+ alias("shipments"),
+ name_join("shipments.id = users.id"),
+ fk_join("shipments.user_id = users.id"),
+ ])
@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
@@ -192,17 +183,14 @@ def test_suggested_join_conditions(completer, text):
def test_suggested_joins(completer, query, tbl):
result = get_result(completer, query.format(tbl))
assert completions_to_set(result) == completions_to_set(
- testdata.schemas_and_from_clause_items()
- + [join(f"custom.shipments ON shipments.user_id = {tbl}.id")]
+ testdata.schemas_and_from_clause_items() + [join(f"custom.shipments ON shipments.user_id = {tbl}.id")]
)
@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
def test_suggested_column_names_from_schema_qualifed_table(completer):
result = get_result(completer, "SELECT from custom.products", len("SELECT "))
- assert completions_to_set(result) == completions_to_set(
- testdata.columns_functions_and_keywords("products", "custom")
- )
+ assert completions_to_set(result) == completions_to_set(testdata.columns_functions_and_keywords("products", "custom"))
@parametrize(
@@ -216,19 +204,13 @@ def test_suggested_column_names_from_schema_qualifed_table(completer):
)
@parametrize("completer", completers(filtr=True, casing=False))
def test_suggested_columns_with_insert(completer, text):
- assert completions_to_set(get_result(completer, text)) == completions_to_set(
- testdata.columns("orders")
- )
+ assert completions_to_set(get_result(completer, text)) == completions_to_set(testdata.columns("orders"))
@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
def test_suggested_column_names_in_function(completer):
- result = get_result(
- completer, "SELECT MAX( from custom.products", len("SELECT MAX(")
- )
- assert completions_to_set(result) == completions_to_set(
- testdata.columns_functions_and_keywords("products", "custom")
- )
+ result = get_result(completer, "SELECT MAX( from custom.products", len("SELECT MAX("))
+ assert completions_to_set(result) == completions_to_set(testdata.columns_functions_and_keywords("products", "custom"))
@parametrize("completer", completers(casing=False, aliasing=False))
@@ -237,9 +219,7 @@ def test_suggested_column_names_in_function(completer):
["SELECT * FROM Custom.", "SELECT * FROM custom.", 'SELECT * FROM "custom".'],
)
@parametrize("use_leading_double_quote", [False, True])
-def test_suggested_table_names_with_schema_dot(
- completer, text, use_leading_double_quote
-):
+def test_suggested_table_names_with_schema_dot(completer, text, use_leading_double_quote):
if use_leading_double_quote:
text += '"'
start_position = -1
@@ -247,17 +227,13 @@ def test_suggested_table_names_with_schema_dot(
start_position = 0
result = get_result(completer, text)
- assert completions_to_set(result) == completions_to_set(
- testdata.from_clause_items("custom", start_position)
- )
+ assert completions_to_set(result) == completions_to_set(testdata.from_clause_items("custom", start_position))
@parametrize("completer", completers(casing=False, aliasing=False))
@parametrize("text", ['SELECT * FROM "Custom".'])
@parametrize("use_leading_double_quote", [False, True])
-def test_suggested_table_names_with_schema_dot2(
- completer, text, use_leading_double_quote
-):
+def test_suggested_table_names_with_schema_dot2(completer, text, use_leading_double_quote):
if use_leading_double_quote:
text += '"'
start_position = -1
@@ -265,37 +241,25 @@ def test_suggested_table_names_with_schema_dot2(
start_position = 0
result = get_result(completer, text)
- assert completions_to_set(result) == completions_to_set(
- testdata.from_clause_items("Custom", start_position)
- )
+ assert completions_to_set(result) == completions_to_set(testdata.from_clause_items("Custom", start_position))
@parametrize("completer", completers(filtr=True, casing=False))
def test_suggested_column_names_with_qualified_alias(completer):
result = get_result(completer, "SELECT p. from custom.products p", len("SELECT p."))
- assert completions_to_set(result) == completions_to_set(
- testdata.columns("products", "custom")
- )
+ assert completions_to_set(result) == completions_to_set(testdata.columns("products", "custom"))
@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
def test_suggested_multiple_column_names(completer):
- result = get_result(
- completer, "SELECT id, from custom.products", len("SELECT id, ")
- )
- assert completions_to_set(result) == completions_to_set(
- testdata.columns_functions_and_keywords("products", "custom")
- )
+ result = get_result(completer, "SELECT id, from custom.products", len("SELECT id, "))
+ assert completions_to_set(result) == completions_to_set(testdata.columns_functions_and_keywords("products", "custom"))
@parametrize("completer", completers(filtr=True, casing=False))
def test_suggested_multiple_column_names_with_alias(completer):
- result = get_result(
- completer, "SELECT p.id, p. from custom.products p", len("SELECT u.id, u.")
- )
- assert completions_to_set(result) == completions_to_set(
- testdata.columns("products", "custom")
- )
+ result = get_result(completer, "SELECT p.id, p. from custom.products p", len("SELECT u.id, u."))
+ assert completions_to_set(result) == completions_to_set(testdata.columns("products", "custom"))
@parametrize("completer", completers(filtr=True, casing=False))
@@ -307,19 +271,15 @@ def test_suggested_multiple_column_names_with_alias(completer):
],
)
def test_suggestions_after_on(completer, text):
- position = len(
- "SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON "
- )
+ position = len("SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON ")
result = get_result(completer, text, position)
- assert completions_to_set(result) == completions_to_set(
- [
- alias("x"),
- alias("y"),
- name_join("y.price = x.price"),
- name_join("y.product_name = x.product_name"),
- name_join("y.id = x.id"),
- ]
- )
+ assert completions_to_set(result) == completions_to_set([
+ alias("x"),
+ alias("y"),
+ name_join("y.price = x.price"),
+ name_join("y.product_name = x.product_name"),
+ name_join("y.id = x.id"),
+ ])
@parametrize("completer", completers())
@@ -333,32 +293,26 @@ def test_suggested_aliases_after_on_right_side(completer):
def test_table_names_after_from(completer):
text = "SELECT * FROM "
result = get_result(completer, text)
- assert completions_to_set(result) == completions_to_set(
- testdata.schemas_and_from_clause_items()
- )
+ assert completions_to_set(result) == completions_to_set(testdata.schemas_and_from_clause_items())
@parametrize("completer", completers(filtr=True, casing=False))
def test_schema_qualified_function_name(completer):
text = "SELECT custom.func"
result = get_result(completer, text)
- assert completions_to_set(result) == completions_to_set(
- [
- function("func3()", -len("func")),
- function("set_returning_func()", -len("func")),
- ]
- )
+ assert completions_to_set(result) == completions_to_set([
+ function("func3()", -len("func")),
+ function("set_returning_func()", -len("func")),
+ ])
@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
def test_schema_qualified_function_name_after_from(completer):
text = "SELECT * FROM custom.set_r"
result = get_result(completer, text)
- assert completions_to_set(result) == completions_to_set(
- [
- function("set_returning_func()", -len("func")),
- ]
- )
+ assert completions_to_set(result) == completions_to_set([
+ function("set_returning_func()", -len("func")),
+ ])
@parametrize("completer", completers(filtr=True, casing=False, aliasing=False))
@@ -373,11 +327,9 @@ def test_unqualified_function_name_in_search_path(completer):
completer.search_path = ["public", "custom"]
text = "SELECT * FROM set_r"
result = get_result(completer, text)
- assert completions_to_set(result) == completions_to_set(
- [
- function("set_returning_func()", -len("func")),
- ]
- )
+ assert completions_to_set(result) == completions_to_set([
+ function("set_returning_func()", -len("func")),
+ ])
@parametrize("completer", completers(filtr=True, casing=False))
@@ -397,12 +349,8 @@ def test_schema_qualified_type_name(completer, text):
@parametrize("completer", completers(filtr=True, casing=False))
def test_suggest_columns_from_aliased_set_returning_function(completer):
- result = get_result(
- completer, "select f. from custom.set_returning_func() f", len("select f.")
- )
- assert completions_to_set(result) == completions_to_set(
- testdata.columns("set_returning_func", "custom", "functions")
- )
+ result = get_result(completer, "select f. from custom.set_returning_func() f", len("select f."))
+ assert completions_to_set(result) == completions_to_set(testdata.columns("set_returning_func", "custom", "functions"))
@parametrize("completer", completers(filtr=True, casing=False, qualify=no_qual))
@@ -499,10 +447,7 @@ def test_wildcard_column_expansion_with_two_tables(completer):
completions = get_result(completer, text, position)
- cols = (
- '"select".id, "select"."localtime", "select"."ABC", '
- "users.id, users.phone_number"
- )
+ cols = '"select".id, "select"."localtime", "select"."ABC", users.id, users.phone_number'
expected = [wildcard_expansion(cols)]
assert completions == expected
@@ -535,21 +480,15 @@ def test_wildcard_column_expansion_with_two_tables_and_parent(completer):
def test_suggest_columns_from_unquoted_table(completer, text):
position = len("SELECT U.")
result = get_result(completer, text, position)
- assert completions_to_set(result) == completions_to_set(
- testdata.columns("users", "custom")
- )
+ assert completions_to_set(result) == completions_to_set(testdata.columns("users", "custom"))
@parametrize("completer", completers(filtr=True, casing=False))
-@parametrize(
- "text", ['SELECT U. FROM custom."Users" U', 'SELECT U. FROM "custom"."Users" U']
-)
+@parametrize("text", ['SELECT U. FROM custom."Users" U', 'SELECT U. FROM "custom"."Users" U'])
def test_suggest_columns_from_quoted_table(completer, text):
position = len("SELECT U.")
result = get_result(completer, text, position)
- assert completions_to_set(result) == completions_to_set(
- testdata.columns("Users", "custom")
- )
+ assert completions_to_set(result) == completions_to_set(testdata.columns("Users", "custom"))
texts = ["SELECT * FROM ", "SELECT * FROM public.Orders O CROSS JOIN "]
@@ -559,9 +498,7 @@ def test_suggest_columns_from_quoted_table(completer, text):
@parametrize("text", texts)
def test_schema_or_visible_table_completion(completer, text):
result = get_result(completer, text)
- assert completions_to_set(result) == completions_to_set(
- testdata.schemas_and_from_clause_items()
- )
+ assert completions_to_set(result) == completions_to_set(testdata.schemas_and_from_clause_items())
@parametrize("completer", completers(aliasing=True, casing=False, filtr=True))
@@ -703,9 +640,7 @@ def test_column_alias_search(completer):
@parametrize("completer", completers(casing=True))
def test_column_alias_search_qualified(completer):
- result = get_result(
- completer, "SELECT E.ei FROM blog.Entries E", len("SELECT E.ei")
- )
+ result = get_result(completer, "SELECT E.ei FROM blog.Entries E", len("SELECT E.ei"))
cols = ("EntryID", "EntryTitle")
assert result[:3] == [column(c, -2) for c in cols]
@@ -713,9 +648,7 @@ def test_column_alias_search_qualified(completer):
@parametrize("completer", completers(casing=False, filtr=False, aliasing=False))
def test_schema_object_order(completer):
result = get_result(completer, "SELECT * FROM u")
- assert result[:3] == [
- table(t, pos=-1) for t in ("users", 'custom."Users"', "custom.users")
- ]
+ assert result[:3] == [table(t, pos=-1) for t in ("users", 'custom."Users"', "custom.users")]
@parametrize("completer", completers(casing=False, filtr=False, aliasing=False))
@@ -723,8 +656,7 @@ def test_all_schema_objects(completer):
text = "SELECT * FROM "
result = get_result(completer, text)
assert completions_to_set(result) >= completions_to_set(
- [table(x) for x in ("orders", '"select"', "custom.shipments")]
- + [function(x + "()") for x in ("func2",)]
+ [table(x) for x in ("orders", '"select"', "custom.shipments")] + [function(x + "()") for x in ("func2",)]
)
@@ -733,8 +665,7 @@ def test_all_schema_objects_with_casing(completer):
text = "SELECT * FROM "
result = get_result(completer, text)
assert completions_to_set(result) >= completions_to_set(
- [table(x) for x in ("Orders", '"select"', "CUSTOM.shipments")]
- + [function(x + "()") for x in ("func2",)]
+ [table(x) for x in ("Orders", '"select"', "CUSTOM.shipments")] + [function(x + "()") for x in ("func2",)]
)
@@ -743,8 +674,7 @@ def test_all_schema_objects_with_aliases(completer):
text = "SELECT * FROM "
result = get_result(completer, text)
assert completions_to_set(result) >= completions_to_set(
- [table(x) for x in ("orders o", '"select" s', "custom.shipments s")]
- + [function(x) for x in ("func2() f",)]
+ [table(x) for x in ("orders o", '"select" s', "custom.shipments s")] + [function(x) for x in ("func2() f",)]
)
@@ -752,6 +682,4 @@ def test_all_schema_objects_with_aliases(completer):
def test_set_schema(completer):
text = "SET SCHEMA "
result = get_result(completer, text)
- assert completions_to_set(result) == completions_to_set(
- [schema("'blog'"), schema("'Custom'"), schema("'custom'"), schema("'public'")]
- )
+ assert completions_to_set(result) == completions_to_set([schema("'blog'"), schema("'Custom'"), schema("'custom'"), schema("'public'")])
diff --git a/tests/test_smart_completion_public_schema_only.py b/tests/test_smart_completion_public_schema_only.py
index db1fe0a39..92bfff765 100644
--- a/tests/test_smart_completion_public_schema_only.py
+++ b/tests/test_smart_completion_public_schema_only.py
@@ -12,7 +12,6 @@
column,
wildcard_expansion,
get_result,
- result_set,
qual,
no_qual,
parametrize,
@@ -68,19 +67,11 @@
]
cased_tbls = ["Users", "Orders"]
cased_views = ["User_Emails", "Functions"]
-casing = (
- ["SELECT", "PUBLIC"]
- + cased_func_names
- + cased_tbls
- + cased_views
- + cased_users_col_names
- + cased_users2_col_names
-)
+casing = ["SELECT", "PUBLIC"] + cased_func_names + cased_tbls + cased_views + cased_users_col_names + cased_users2_col_names
# Lists for use in assertions
-cased_funcs = [
- function(f)
- for f in ("Custom_Fun()", "_custom_fun()", "Custom_Func1()", "custom_func2()")
-] + [function("set_returning_func(x := , y := )", display="set_returning_func(x, y)")]
+cased_funcs = [function(f) for f in ("Custom_Fun()", "_custom_fun()", "Custom_Func1()", "custom_func2()")] + [
+ function("set_returning_func(x := , y := )", display="set_returning_func(x, y)")
+]
cased_tbls = [table(t) for t in (cased_tbls + ['"Users"', '"select"'])]
cased_rels = [view(t) for t in cased_views] + cased_funcs + cased_tbls
cased_users_cols = [column(c) for c in cased_users_col_names]
@@ -132,25 +123,19 @@ def test_function_column_name(completer):
len("SELECT * FROM Functions WHERE function:"),
len("SELECT * FROM Functions WHERE function:text") + 1,
):
- assert [] == get_result(
- completer, "SELECT * FROM Functions WHERE function:text"[:l]
- )
+ assert [] == get_result(completer, "SELECT * FROM Functions WHERE function:text"[:l])
@parametrize("action", ["ALTER", "DROP", "CREATE", "CREATE OR REPLACE"])
@parametrize("completer", completers())
def test_drop_alter_function(completer, action):
- assert get_result(completer, action + " FUNCTION set_ret") == [
- function("set_returning_func(x integer, y integer)", -len("set_ret"))
- ]
+ assert get_result(completer, action + " FUNCTION set_ret") == [function("set_returning_func(x integer, y integer)", -len("set_ret"))]
@parametrize("completer", completers())
def test_empty_string_completion(completer):
result = get_result(completer, "")
- assert completions_to_set(
- testdata.keywords() + testdata.specials()
- ) == completions_to_set(result)
+ assert completions_to_set(testdata.keywords() + testdata.specials()) == completions_to_set(result)
@parametrize("completer", completers())
@@ -162,19 +147,17 @@ def test_select_keyword_completion(completer):
@parametrize("completer", completers())
def test_builtin_function_name_completion(completer):
result = get_result(completer, "SELECT MA")
- assert completions_to_set(result) == completions_to_set(
- [
- function("MAKE_DATE", -2),
- function("MAKE_INTERVAL", -2),
- function("MAKE_TIME", -2),
- function("MAKE_TIMESTAMP", -2),
- function("MAKE_TIMESTAMPTZ", -2),
- function("MASKLEN", -2),
- function("MAX", -2),
- keyword("MAXEXTENTS", -2),
- keyword("MATERIALIZED VIEW", -2),
- ]
- )
+ assert completions_to_set(result) == completions_to_set([
+ function("MAKE_DATE", -2),
+ function("MAKE_INTERVAL", -2),
+ function("MAKE_TIME", -2),
+ function("MAKE_TIMESTAMP", -2),
+ function("MAKE_TIMESTAMPTZ", -2),
+ function("MASKLEN", -2),
+ function("MAX", -2),
+ keyword("MAXEXTENTS", -2),
+ keyword("MATERIALIZED VIEW", -2),
+ ])
@parametrize("completer", completers())
@@ -189,58 +172,47 @@ def test_builtin_function_matches_only_at_start(completer):
@parametrize("completer", completers(casing=False, aliasing=False))
def test_user_function_name_completion(completer):
result = get_result(completer, "SELECT cu")
- assert completions_to_set(result) == completions_to_set(
- [
- function("custom_fun()", -2),
- function("_custom_fun()", -2),
- function("custom_func1()", -2),
- function("custom_func2()", -2),
- function("CURRENT_DATE", -2),
- function("CURRENT_TIMESTAMP", -2),
- function("CUME_DIST", -2),
- function("CURRENT_TIME", -2),
- keyword("CURRENT", -2),
- ]
- )
+ assert completions_to_set(result) == completions_to_set([
+ function("custom_fun()", -2),
+ function("_custom_fun()", -2),
+ function("custom_func1()", -2),
+ function("custom_func2()", -2),
+ function("CURRENT_DATE", -2),
+ function("CURRENT_TIMESTAMP", -2),
+ function("CUME_DIST", -2),
+ function("CURRENT_TIME", -2),
+ keyword("CURRENT", -2),
+ ])
@parametrize("completer", completers(casing=False, aliasing=False))
def test_user_function_name_completion_matches_anywhere(completer):
result = get_result(completer, "SELECT om")
- assert completions_to_set(result) == completions_to_set(
- [
- function("custom_fun()", -2),
- function("_custom_fun()", -2),
- function("custom_func1()", -2),
- function("custom_func2()", -2),
- ]
- )
+ assert completions_to_set(result) == completions_to_set([
+ function("custom_fun()", -2),
+ function("_custom_fun()", -2),
+ function("custom_func1()", -2),
+ function("custom_func2()", -2),
+ ])
@parametrize("completer", completers(casing=True))
def test_list_functions_for_special(completer):
result = get_result(completer, r"\df ")
- assert completions_to_set(result) == completions_to_set(
- [schema("PUBLIC")] + [function(f) for f in cased_func_names]
- )
+ assert completions_to_set(result) == completions_to_set([schema("PUBLIC")] + [function(f) for f in cased_func_names])
@parametrize("completer", completers(casing=False, qualify=no_qual))
def test_suggested_column_names_from_visible_table(completer):
result = get_result(completer, "SELECT from users", len("SELECT "))
- assert completions_to_set(result) == completions_to_set(
- testdata.columns_functions_and_keywords("users")
- )
+ assert completions_to_set(result) == completions_to_set(testdata.columns_functions_and_keywords("users"))
@parametrize("completer", completers(casing=True, qualify=no_qual))
def test_suggested_cased_column_names(completer):
result = get_result(completer, "SELECT from users", len("SELECT "))
assert completions_to_set(result) == completions_to_set(
- cased_funcs
- + cased_users_cols
- + testdata.builtin_functions()
- + testdata.keywords()
+ cased_funcs + cased_users_cols + testdata.builtin_functions() + testdata.keywords()
)
@@ -250,9 +222,7 @@ def test_suggested_auto_qualified_column_names(text, completer):
position = text.index(" ") + 1
cols = [column(c.lower()) for c in cased_users_col_names]
result = get_result(completer, text, position)
- assert completions_to_set(result) == completions_to_set(
- cols + testdata.functions_and_keywords()
- )
+ assert completions_to_set(result) == completions_to_set(cols + testdata.functions_and_keywords())
@parametrize("completer", completers(casing=False, qualify=qual))
@@ -268,9 +238,7 @@ def test_suggested_auto_qualified_column_names_two_tables(text, completer):
cols = [column("U." + c.lower()) for c in cased_users_col_names]
cols += [column('"Users".' + c.lower()) for c in cased_users2_col_names]
result = get_result(completer, text, position)
- assert completions_to_set(result) == completions_to_set(
- cols + testdata.functions_and_keywords()
- )
+ assert completions_to_set(result) == completions_to_set(cols + testdata.functions_and_keywords())
@parametrize("completer", completers(casing=True, qualify=["always"]))
@@ -287,17 +255,13 @@ def test_suggested_cased_always_qualified_column_names(completer):
position = len("SELECT ")
cols = [column("users." + c) for c in cased_users_col_names]
result = get_result(completer, text, position)
- assert completions_to_set(result) == completions_to_set(
- cased_funcs + cols + testdata.builtin_functions() + testdata.keywords()
- )
+ assert completions_to_set(result) == completions_to_set(cased_funcs + cols + testdata.builtin_functions() + testdata.keywords())
@parametrize("completer", completers(casing=False, qualify=no_qual))
def test_suggested_column_names_in_function(completer):
result = get_result(completer, "SELECT MAX( from users", len("SELECT MAX("))
- assert completions_to_set(result) == completions_to_set(
- testdata.columns_functions_and_keywords("users")
- )
+ assert completions_to_set(result) == completions_to_set(testdata.columns_functions_and_keywords("users"))
@parametrize("completer", completers(casing=False))
@@ -315,24 +279,18 @@ def test_suggested_column_names_with_alias(completer):
@parametrize("completer", completers(casing=False, qualify=no_qual))
def test_suggested_multiple_column_names(completer):
result = get_result(completer, "SELECT id, from users u", len("SELECT id, "))
- assert completions_to_set(result) == completions_to_set(
- testdata.columns_functions_and_keywords("users")
- )
+ assert completions_to_set(result) == completions_to_set(testdata.columns_functions_and_keywords("users"))
@parametrize("completer", completers(casing=False))
def test_suggested_multiple_column_names_with_alias(completer):
- result = get_result(
- completer, "SELECT u.id, u. from users u", len("SELECT u.id, u.")
- )
+ result = get_result(completer, "SELECT u.id, u. from users u", len("SELECT u.id, u."))
assert completions_to_set(result) == completions_to_set(testdata.columns("users"))
 
@parametrize("completer", completers(casing=True))
def test_suggested_cased_column_names_with_alias(completer):
- result = get_result(
- completer, "SELECT u.id, u. from users u", len("SELECT u.id, u.")
- )
+ result = get_result(completer, "SELECT u.id, u. from users u", len("SELECT u.id, u."))
assert completions_to_set(result) == completions_to_set(cased_users_cols)
@@ -378,18 +336,14 @@ def test_suggest_columns_after_three_way_join(completer):
@parametrize("text", join_condition_texts)
def test_suggested_join_conditions(completer, text):
result = get_result(completer, text)
- assert completions_to_set(result) == completions_to_set(
- [alias("U"), alias("U2"), fk_join("U2.userid = U.id")]
- )
+ assert completions_to_set(result) == completions_to_set([alias("U"), alias("U2"), fk_join("U2.userid = U.id")])
 
@parametrize("completer", completers(casing=True))
@parametrize("text", join_condition_texts)
def test_cased_join_conditions(completer, text):
result = get_result(completer, text)
- assert completions_to_set(result) == completions_to_set(
- [alias("U"), alias("U2"), fk_join("U2.UserID = U.ID")]
- )
+ assert completions_to_set(result) == completions_to_set([alias("U"), alias("U2"), fk_join("U2.UserID = U.ID")])
 
@parametrize("completer", completers(casing=False))
@@ -435,9 +389,7 @@ def test_suggested_join_conditions_with_invalid_qualifier(completer, text):
)
def test_suggested_join_conditions_with_invalid_table(completer, text, ref):
result = get_result(completer, text)
- assert completions_to_set(result) == completions_to_set(
- [alias("users"), alias(ref)]
- )
+ assert completions_to_set(result) == completions_to_set([alias("users"), alias(ref)])
 
@parametrize("completer", completers(casing=False, aliasing=False))
@@ -531,8 +483,7 @@ def test_aliased_joins(completer, text):
def test_suggested_joins_quoted_schema_qualified_table(completer, text):
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set(
- testdata.schemas_and_from_clause_items()
- + [join('public.users ON users.id = "Users".userid')]
+ testdata.schemas_and_from_clause_items() + [join('public.users ON users.id = "Users".userid')]
)
@@ -547,14 +498,12 @@ def test_suggested_joins_quoted_schema_qualified_table(completer, text):
def test_suggested_aliases_after_on(completer, text):
position = len("SELECT u.name, o.id FROM users u JOIN orders o ON ")
result = get_result(completer, text, position)
- assert completions_to_set(result) == completions_to_set(
- [
- alias("u"),
- name_join("o.id = u.id"),
- name_join("o.email = u.email"),
- alias("o"),
- ]
- )
+ assert completions_to_set(result) == completions_to_set([
+ alias("u"),
+ name_join("o.id = u.id"),
+ name_join("o.email = u.email"),
+ alias("o"),
+ ])
 
@parametrize("completer", completers())
@@ -582,14 +531,12 @@ def test_suggested_aliases_after_on_right_side(completer, text):
def test_suggested_tables_after_on(completer, text):
position = len("SELECT users.name, orders.id FROM users JOIN orders ON ")
result = get_result(completer, text, position)
- assert completions_to_set(result) == completions_to_set(
- [
- name_join("orders.id = users.id"),
- name_join("orders.email = users.email"),
- alias("users"),
- alias("orders"),
- ]
- )
+ assert completions_to_set(result) == completions_to_set([
+ name_join("orders.id = users.id"),
+ name_join("orders.email = users.email"),
+ alias("users"),
+ alias("orders"),
+ ])
 
@parametrize("completer", completers(casing=False))
@@ -601,13 +548,9 @@ def test_suggested_tables_after_on(completer, text):
],
)
def test_suggested_tables_after_on_right_side(completer, text):
- position = len(
- "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = "
- )
+ position = len("SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ")
result = get_result(completer, text, position)
- assert completions_to_set(result) == completions_to_set(
- [alias("users"), alias("orders")]
- )
+ assert completions_to_set(result) == completions_to_set([alias("users"), alias("orders")])
 
@parametrize("completer", completers(casing=False))
@@ -620,9 +563,7 @@ def test_suggested_tables_after_on_right_side(completer, text):
)
def test_join_using_suggests_common_columns(completer, text):
result = get_result(completer, text)
- assert completions_to_set(result) == completions_to_set(
- [column("id"), column("email")]
- )
+ assert completions_to_set(result) == completions_to_set([column("id"), column("email")])
 
@parametrize("completer", completers(casing=False))
@@ -638,9 +579,7 @@ def test_join_using_suggests_common_columns(completer, text):
def test_join_using_suggests_from_last_table(completer, text):
position = text.index("()") + 1
result = get_result(completer, text, position)
- assert completions_to_set(result) == completions_to_set(
- [column("id"), column("email")]
- )
+ assert completions_to_set(result) == completions_to_set([column("id"), column("email")])
 
@parametrize("completer", completers(casing=False))
@@ -653,9 +592,7 @@ def test_join_using_suggests_from_last_table(completer, text):
)
def test_join_using_suggests_columns_after_first_column(completer, text):
result = get_result(completer, text)
- assert completions_to_set(result) == completions_to_set(
- [column("id"), column("email")]
- )
+ assert completions_to_set(result) == completions_to_set([column("id"), column("email")])
 
@parametrize("completer", completers(casing=False, aliasing=False))
@@ -669,9 +606,7 @@ def test_join_using_suggests_columns_after_first_column(completer, text):
)
def test_table_names_after_from(completer, text):
result = get_result(completer, text)
- assert completions_to_set(result) == completions_to_set(
- testdata.schemas_and_from_clause_items()
- )
+ assert completions_to_set(result) == completions_to_set(testdata.schemas_and_from_clause_items())
assert [c.text for c in result] == [
"public",
"orders",
@@ -691,9 +626,7 @@ def test_table_names_after_from(completer, text):
@parametrize("completer", completers(casing=False, qualify=no_qual))
def test_auto_escaped_col_names(completer):
result = get_result(completer, 'SELECT from "select"', len("SELECT "))
- assert completions_to_set(result) == completions_to_set(
- testdata.columns_functions_and_keywords("select")
- )
+ assert completions_to_set(result) == completions_to_set(testdata.columns_functions_and_keywords("select"))
 
@parametrize("completer", completers(aliasing=False))
@@ -717,9 +650,7 @@ def test_allow_leading_double_quote_in_last_word(completer):
)
def test_suggest_datatype(text, completer):
result = get_result(completer, text)
- assert completions_to_set(result) == completions_to_set(
- testdata.schemas() + testdata.types() + testdata.builtin_datatypes()
- )
+ assert completions_to_set(result) == completions_to_set(testdata.schemas() + testdata.types() + testdata.builtin_datatypes())
 
@parametrize("completer", completers(casing=False))
@@ -731,19 +662,13 @@ def test_suggest_columns_from_escaped_table_alias(completer):
@parametrize("completer", completers(casing=False, qualify=no_qual))
def test_suggest_columns_from_set_returning_function(completer):
result = get_result(completer, "select from set_returning_func()", len("select "))
- assert completions_to_set(result) == completions_to_set(
- testdata.columns_functions_and_keywords("set_returning_func", typ="functions")
- )
+ assert completions_to_set(result) == completions_to_set(testdata.columns_functions_and_keywords("set_returning_func", typ="functions"))
 
@parametrize("completer", completers(casing=False))
def test_suggest_columns_from_aliased_set_returning_function(completer):
- result = get_result(
- completer, "select f. from set_returning_func() f", len("select f.")
- )
- assert completions_to_set(result) == completions_to_set(
- testdata.columns("set_returning_func", typ="functions")
- )
+ result = get_result(completer, "select f. from set_returning_func() f", len("select f."))
+ assert completions_to_set(result) == completions_to_set(testdata.columns("set_returning_func", typ="functions"))
 
@parametrize("completer", completers(casing=False))
@@ -751,9 +676,7 @@ def test_join_functions_using_suggests_common_columns(completer):
text = """SELECT * FROM set_returning_func() f1
INNER JOIN set_returning_func() f2 USING ("""
result = get_result(completer, text)
- assert completions_to_set(result) == completions_to_set(
- testdata.columns("set_returning_func", typ="functions")
- )
+ assert completions_to_set(result) == completions_to_set(testdata.columns("set_returning_func", typ="functions"))
 
@parametrize("completer", completers(casing=False))
@@ -762,8 +685,7 @@ def test_join_functions_on_suggests_columns_and_join_conditions(completer):
INNER JOIN set_returning_func() f2 ON f1."""
result = get_result(completer, text)
assert completions_to_set(result) == completions_to_set(
- [name_join("y = f2.y"), name_join("x = f2.x")]
- + testdata.columns("set_returning_func", typ="functions")
+ [name_join("y = f2.y"), name_join("x = f2.x")] + testdata.columns("set_returning_func", typ="functions")
)
@@ -880,10 +802,7 @@ def test_wildcard_column_expansion_with_two_tables(completer):
completions = get_result(completer, text, position)
- cols = (
- '"select".id, "select".insert, "select"."ABC", '
- "u.id, u.parentid, u.email, u.first_name, u.last_name"
- )
+ cols = '"select".id, "select".insert, "select"."ABC", u.id, u.parentid, u.email, u.first_name, u.last_name'
expected = [wildcard_expansion(cols)]
assert completions == expected
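The cols rewrite above is purely cosmetic: Python concatenates adjacent string literals at compile time, so the old parenthesized pair was already a single string. A quick self-contained check:

    # Adjacent string literals concatenate at compile time, so the old
    # two-literal spelling and the new single literal are equal strings.
    old = '"select".id, "select".insert, ' "u.id, u.parentid"
    new = '"select".id, "select".insert, u.id, u.parentid'
    assert old == new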
@@ -922,18 +841,14 @@ def test_suggest_columns_from_quoted_table(completer):
@parametrize("text", ["SELECT * FROM ", "SELECT * FROM Orders o CROSS JOIN "])
def test_schema_or_visible_table_completion(completer, text):
result = get_result(completer, text)
- assert completions_to_set(result) == completions_to_set(
- testdata.schemas_and_from_clause_items()
- )
+ assert completions_to_set(result) == completions_to_set(testdata.schemas_and_from_clause_items())
 
@parametrize("completer", completers(casing=False, aliasing=True))
@parametrize("text", ["SELECT * FROM "])
def test_table_aliases(completer, text):
result = get_result(completer, text)
- assert completions_to_set(result) == completions_to_set(
- testdata.schemas() + aliased_rels
- )
+ assert completions_to_set(result) == completions_to_set(testdata.schemas() + aliased_rels)
 
@parametrize("completer", completers(casing=False, aliasing=True))
@@ -965,43 +880,37 @@ def test_duplicate_table_aliases(completer, text):
@parametrize("text", ["SELECT * FROM Orders o CROSS JOIN "])
def test_duplicate_aliases_with_casing(completer, text):
result = get_result(completer, text)
- assert completions_to_set(result) == completions_to_set(
- [
- schema("PUBLIC"),
- table("Orders O2"),
- table("Users U"),
- table('"Users" U'),
- table('"select" s'),
- view("User_Emails UE"),
- view("Functions F"),
- function("_custom_fun() cf"),
- function("Custom_Fun() CF"),
- function("Custom_Func1() CF"),
- function("custom_func2() cf"),
- function(
- "set_returning_func(x := , y := ) srf",
- display="set_returning_func(x, y) srf",
- ),
- ]
- )
+ assert completions_to_set(result) == completions_to_set([
+ schema("PUBLIC"),
+ table("Orders O2"),
+ table("Users U"),
+ table('"Users" U'),
+ table('"select" s'),
+ view("User_Emails UE"),
+ view("Functions F"),
+ function("_custom_fun() cf"),
+ function("Custom_Fun() CF"),
+ function("Custom_Func1() CF"),
+ function("custom_func2() cf"),
+ function(
+ "set_returning_func(x := , y := ) srf",
+ display="set_returning_func(x, y) srf",
+ ),
+ ])
 
@parametrize("completer", completers(casing=True, aliasing=True))
@parametrize("text", ["SELECT * FROM "])
def test_aliases_with_casing(completer, text):
result = get_result(completer, text)
- assert completions_to_set(result) == completions_to_set(
- [schema("PUBLIC")] + cased_aliased_rels
- )
+ assert completions_to_set(result) == completions_to_set([schema("PUBLIC")] + cased_aliased_rels)
 
@parametrize("completer", completers(casing=True, aliasing=False))
@parametrize("text", ["SELECT * FROM "])
def test_table_casing(completer, text):
result = get_result(completer, text)
- assert completions_to_set(result) == completions_to_set(
- [schema("PUBLIC")] + cased_rels
- )
+ assert completions_to_set(result) == completions_to_set([schema("PUBLIC")] + cased_rels)
 
@parametrize("completer", completers(casing=False))
@@ -1028,12 +937,10 @@ def test_suggest_cte_names(completer):
SELECT * FROM
"""
result = get_result(completer, text)
- expected = completions_to_set(
- [
- Completion("cte1", 0, display_meta="table"),
- Completion("cte2", 0, display_meta="table"),
- ]
- )
+ expected = completions_to_set([
+ Completion("cte1", 0, display_meta="table"),
+ Completion("cte2", 0, display_meta="table"),
+ ])
assert expected <= completions_to_set(result)
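Unlike the equality assertions elsewhere in this file, the CTE test keeps its <= comparison, which on sets means "is a subset of": both CTE names must appear among the suggestions, while extra completions (keywords, functions) are tolerated. A minimal illustration:

    # On sets, <= is the subset operator, so the assertion checks
    # containment rather than exact equality.
    expected = {("cte1", "table"), ("cte2", "table")}
    result = {("cte1", "table"), ("cte2", "table"), ("users", "table")}
    assert expected <= result        # every expected pair is present
    assert not (result <= expected)  # the reverse direction would fail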
@@ -1101,12 +1008,10 @@ def test_set_schema(completer):
@parametrize("completer", completers())
def test_special_name_completion(completer):
result = get_result(completer, "\\t")
- assert completions_to_set(result) == completions_to_set(
- [
- Completion(
- text="\\timing",
- start_position=-2,
- display_meta="Toggle timing of commands.",
- )
- ]
- )
+ assert completions_to_set(result) == completions_to_set([
+ Completion(
+ text="\\timing",
+ start_position=-2,
+ display_meta="Toggle timing of commands.",
+ )
+ ])
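All of the reshaped assertions in this file go through the completions_to_set helper (reformatted in the tests/utils.py hunk further down), which reduces each completion to a (display_text, display_meta_text) pair so that comparisons ignore ordering. A minimal sketch of the idea, using a namedtuple as a hypothetical stand-in for the real Completion object:

    from collections import namedtuple

    # Hypothetical stand-in; only the two attributes the helper reads
    # are modeled here.
    Completion = namedtuple("Completion", ["display_text", "display_meta_text"])

    def completions_to_set(completions):
        # Same shape as the helper in tests/utils.py: hashable pairs in a set.
        return {(c.display_text, c.display_meta_text) for c in completions}

    got = [Completion("email", "column"), Completion("id", "column")]
    want = [Completion("id", "column"), Completion("email", "column")]
    assert completions_to_set(got) == completions_to_set(want)  # order-insensitive

One consequence of the set conversion is that duplicates collapse, so these assertions verify membership rather than multiplicity.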
diff --git a/tests/test_sqlcompletion.py b/tests/test_sqlcompletion.py
index 1034bbe2a..f53e4bea6 100644
--- a/tests/test_sqlcompletion.py
+++ b/tests/test_sqlcompletion.py
@@ -46,7 +46,7 @@ def test_select_suggests_cols_with_qualified_table_scope():
def test_cte_does_not_crash():
sql = "WITH CTE AS (SELECT F.* FROM Foo F WHERE F.Bar > 23) SELECT C.* FROM CTE C WHERE C.FooID BETWEEN 123 AND 234;"
for i in range(len(sql)):
- suggestions = suggest_type(sql[: i + 1], sql[: i + 1])
+ suggest_type(sql[: i + 1], sql[: i + 1])
 
@pytest.mark.parametrize("expression", ['SELECT * FROM "tabl" WHERE '])
@@ -140,7 +140,7 @@ def test_suggest_tables_views_schemas_and_functions(expression):
)
def test_suggest_after_join_with_two_tables(expression):
suggestions = suggest_type(expression, expression)
- tables = tuple([(None, "foo", None, False), (None, "bar", None, False)])
+ tables = ((None, "foo", None, False), (None, "bar", None, False))
assert set(suggestions) == {
FromClauseItem(schema=None, table_refs=tables),
Join(tables, None),
@@ -193,7 +193,7 @@ def test_suggest_qualified_tables_views_and_functions(expression):
@pytest.mark.parametrize("expression", ["SELECT * FROM foo JOIN sch."])
def test_suggest_qualified_tables_views_functions_and_joins(expression):
suggestions = suggest_type(expression, expression)
- tbls = tuple([(None, "foo", None, False)])
+ tbls = ((None, "foo", None, False),)
assert set(suggestions) == {
FromClauseItem(schema="sch", table_refs=tbls),
Join(tbls, "sch"),
@@ -452,7 +452,7 @@ def test_sub_select_table_name_completion(expression):
)
def test_sub_select_table_name_completion_with_outer_table(expression):
suggestion = suggest_type(expression, expression)
- tbls = tuple([(None, "foo", None, False)])
+ tbls = ((None, "foo", None, False),)
assert set(suggestion) == {FromClauseItem(schema=None, table_refs=tbls), Schema()}
@@ -492,7 +492,7 @@ def test_sub_select_dot_col_name_completion():
def test_join_suggests_tables_and_schemas(tbl_alias, join_type):
text = f"SELECT * FROM abc {tbl_alias} {join_type} JOIN "
suggestion = suggest_type(text, text)
- tbls = tuple([(None, "abc", tbl_alias or None, False)])
+ tbls = ((None, "abc", tbl_alias or None, False),)
assert set(suggestion) == {
FromClauseItem(schema=None, table_refs=tbls),
Schema(),
@@ -505,7 +505,7 @@ def test_left_join_with_comma():
suggestions = suggest_type(text, text)
# tbls should also include (None, 'bar', 'b', False)
# but there's a bug with commas
- tbls = tuple([(None, "foo", "f", False)])
+ tbls = ((None, "foo", "f", False),)
assert set(suggestions) == {FromClauseItem(schema=None, table_refs=tbls), Schema()}
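The tuple([...]) → literal rewrites above are behavior-preserving; the only subtlety is that a one-element tuple literal needs its trailing comma. A quick check of the equivalence:

    # tuple([x]) and (x,) build the same one-element tuple; without the
    # trailing comma, the parentheses are plain grouping, not a tuple.
    item = (None, "foo", None, False)
    assert tuple([item]) == (item,)
    assert (item) == item  # no comma: still just the inner tuple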
diff --git a/tests/utils.py b/tests/utils.py
index 67d769fd4..e6dad62a7 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -27,7 +27,7 @@ def db_connection(dbname=None):
SERVER_VERSION = conn.info.parameter_status("server_version")
JSON_AVAILABLE = True
JSONB_AVAILABLE = True
-except Exception as x:
+except Exception:
CAN_CONNECT_TO_DB = JSON_AVAILABLE = JSONB_AVAILABLE = False
SERVER_VERSION = 0
@@ -38,21 +38,17 @@ def db_connection(dbname=None):
)
-requires_json = pytest.mark.skipif(
- not JSON_AVAILABLE, reason="Postgres server unavailable or json type not defined"
-)
+requires_json = pytest.mark.skipif(not JSON_AVAILABLE, reason="Postgres server unavailable or json type not defined")
 
-requires_jsonb = pytest.mark.skipif(
- not JSONB_AVAILABLE, reason="Postgres server unavailable or jsonb type not defined"
-)
+requires_jsonb = pytest.mark.skipif(not JSONB_AVAILABLE, reason="Postgres server unavailable or jsonb type not defined")
 
def create_db(dbname):
with db_connection().cursor() as cur:
try:
cur.execute("""CREATE DATABASE _test_db""")
- except:
+ except Exception:
pass
@@ -67,16 +63,12 @@ def drop_tables(conn):
)
-def run(
- executor, sql, join=False, expanded=False, pgspecial=None, exception_formatter=None
-):
+def run(executor, sql, join=False, expanded=False, pgspecial=None, exception_formatter=None):
"Return string output for the sql to be run"
results = executor.run(sql, pgspecial, exception_formatter)
formatted = []
- settings = OutputSettings(
- table_format="psql", dcmlfmt="d", floatfmt="g", expanded=expanded
- )
+ settings = OutputSettings(table_format="psql", dcmlfmt="d", floatfmt="g", expanded=expanded)
for title, rows, headers, status, sql, success, is_special in results:
formatted.extend(format_output(title, rows, headers, status, settings))
if join:
@@ -86,7 +78,4 @@ def run(
def completions_to_set(completions):
- return {
- (completion.display_text, completion.display_meta_text)
- for completion in completions
- }
+ return {(completion.display_text, completion.display_meta_text) for completion in completions}
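Two hunks above tighten exception handling: except Exception as x: drops the unused binding, and the bare except: around CREATE DATABASE becomes except Exception:. The distinction matters because a bare except also traps BaseException subclasses such as KeyboardInterrupt, so a Ctrl-C during database setup would be silently swallowed. A small demonstration:

    # KeyboardInterrupt derives from BaseException, not Exception, so the
    # first clause does not trap it; only the broader form does (as a bare
    # except: would, which is why it was replaced).
    try:
        raise KeyboardInterrupt
    except Exception:
        print("not reached")
    except BaseException:
        print("reached")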
diff --git a/tox.ini b/tox.ini
index 554d66d8f..786738460 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,14 +1,32 @@
[tox]
-envlist = py39, py310, py311, py312, py313
+envlist = py
+
[testenv]
-deps = pytest>=2.7.0,<=3.0.7
- mock>=1.0.1
- behave>=1.2.4
- pexpect==3.3
- sshtunnel>=0.4.0
-commands = py.test
- behave tests/features
+skip_install = true
+deps = uv
+commands = uv pip install -e .[dev]
+ coverage run -m pytest -v tests
+ coverage report -m
passenv = PGHOST
PGPORT
PGUSER
PGPASSWORD
+
+[testenv:style]
+skip_install = true
+deps = ruff
+commands = ruff check
+# TODO: Uncomment the following line to enable ruff formatting
+# ruff format --diff
+
+[testenv:integration]
+skip_install = true
+deps = uv
+commands = uv pip install -e .[dev]
+ behave tests/features --no-capture
+
+[testenv:rest]
+skip_install = true
+deps = uv
+commands = uv pip install -e .[dev]
+ docutils --halt=warning changelog.rst