Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Features
--------
* Options to limit size of LLM prompts; cache LLM prompt data.
* Add startup usage tips.
* When columns are given in a SELECT query, suggest the tables/views that contain those columns first.
* Move `main.ssl_mode` config option to `connection.default_ssl_mode`.
* Add "unsupported" and "deprecated" `--checkup` sections.

Expand Down
81 changes: 81 additions & 0 deletions mycli/packages/parseutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,42 @@ def is_subselect(parsed: TokenList) -> bool:
return False


def get_last_select(parsed: TokenList) -> TokenList:
    """
    Return the tokens of the last select found in a parsed sql statement.

    This exists to support table suggestions that are driven by the columns
    the user has already typed: only the columns of the most recent select
    are of interest, whether the input is a single select or contains one
    or more sub queries.

    sqlparse helpers such as is_subselect only recognise complete
    statements, for example:

    * select c1 from t1;

    While the user is still typing, the input is only a partial select:

    * select c1
    * select c1 from (select c2)

    so the position of the last select is located by hand here.

    Returns a TokenList starting at the last select token and running to the
    end of ``parsed``; an empty TokenList when ``parsed`` holds no select.
    """
    # Token.match is case insensitive, so "SELECT" and "select" both count.
    select_positions = [parsed.token_index(tok) for tok in parsed if tok.match(DML, "select")]

    if not select_positions:
        return TokenList()

    return TokenList(parsed[select_positions[-1] :])


def extract_from_part(parsed: TokenList, stop_at_punctuation: bool = True) -> Generator[Any, None, None]:
tbl_prefix_seen = False
for item in parsed.tokens:
Expand Down Expand Up @@ -185,6 +221,51 @@ def extract_tables(sql: str) -> list[tuple[str | None, str, str]]:
return list(extract_table_identifiers(stream))


def extract_columns_from_select(sql: str) -> list[str]:
    """
    Extract the column names from a select SQL statement.

    Only the first statement parsed from ``sql`` is inspected, and within it
    only the last select (as returned by get_last_select), so for a partially
    typed sub query the columns of the most recent select are used.

    Aliased columns are reported by their real name (get_real_name()).
    Select-list entries that have no usable name -- e.g. ``*`` or literals,
    which may show up as plain tokens without get_real_name() -- are skipped
    instead of raising or adding ``None`` to the result.

    Returns a list of columns; empty when there is no select or no named
    column could be found.
    """
    parsed = sqlparse.parse(sql)
    if not parsed:
        return []

    statement = get_last_select(parsed[0])

    # If there is no select, skip checking for columns. The token list is
    # checked explicitly rather than relying on the truthiness of the
    # TokenList object itself.
    if not statement.tokens:
        return []

    columns: list[str] = []

    # Walk the top-level tokens of the select. Once the SELECT keyword has
    # been seen, the first token that yields at least one named column ends
    # the search; any other keyword (e.g. FROM) ends it as well.
    found_select = False
    for token in statement.tokens:
        if token.ttype is DML and token.value.upper() == 'SELECT':
            found_select = True
            continue
        if not found_select:
            continue

        if isinstance(token, IdentifierList):
            # multiple columns
            candidates = list(token.get_identifiers())
        elif isinstance(token, Identifier):
            # single column
            candidates = [token]
        elif token.ttype is Keyword:
            break
        else:
            candidates = []

        for candidate in candidates:
            # Not every entry of an identifier list offers get_real_name();
            # ignore those that do not, and those without a resolvable name.
            get_real_name = getattr(candidate, "get_real_name", None)
            if get_real_name is None:
                continue
            column = get_real_name()
            if column:
                columns.append(column)

        if columns:
            break
    return columns


def extract_tables_from_complete_statements(sql: str) -> list[tuple[str | None, str, str | None]]:
"""Extract the table names from a complete and valid series of SQL
statements.
Expand Down
39 changes: 33 additions & 6 deletions mycli/sqlcompleter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from mycli.packages.completion_engine import suggest_type
from mycli.packages.filepaths import complete_path, parse_path, suggest_path
from mycli.packages.parseutils import last_word
from mycli.packages.parseutils import extract_columns_from_select, last_word
from mycli.packages.special import llm
from mycli.packages.special.favoritequeries import FavoriteQueries
from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS
Expand Down Expand Up @@ -1131,7 +1131,15 @@ def get_completions(
completions.extend([(*x, rank) for x in procs_m])

elif suggestion["type"] == "table":
tables = self.populate_schema_objects(suggestion["schema"], "tables")
# If this is a select and columns are given, parse the columns and
# then only return tables that have one or more of the given columns.
# If no columns are given (or able to be parsed), return all tables
# as usual.
columns = extract_columns_from_select(document.text)
if columns:
tables = self.populate_schema_objects(suggestion["schema"], "tables", columns)
else:
tables = self.populate_schema_objects(suggestion["schema"], "tables")
tables_m = self.find_matches(word_before_cursor, tables)
completions.extend([(*x, rank) for x in tables_m])

Expand Down Expand Up @@ -1341,15 +1349,34 @@ def _matches_parent(parent: str, schema: str | None, relname: str, alias: str |
def _quote_sql_string(value: str) -> str:
return "'" + value.replace("'", "''") + "'"

def populate_schema_objects(self, schema: str | None, obj_type: str) -> list[str]:
def populate_schema_objects(self, schema: str | None, obj_type: str, columns: list[str] | None = None) -> list[str]:
"""Returns list of tables or functions for a (optional) schema"""
metadata = self.dbmetadata[obj_type]
schema = schema or self.dbname

try:
objects = metadata[schema].keys()
objects = list(metadata[schema].keys())
except KeyError:
# schema doesn't exist
objects = []

return objects
filtered_objects: list[str] = []
remaining_objects: list[str] = []

# If the requested object type is tables and the user already entered
# columns, return a filtered list of tables (or views) that contain
# one or more of the given columns. If a table does not contain the
# given columns, add it to a separate list to add to the end of the
# filtered suggestions.
if obj_type == "tables" and columns and objects:
for obj in objects:
matched = False
for column in metadata[schema][obj]:
if column in columns:
filtered_objects.append(obj)
matched = True
break
if not matched:
remaining_objects.append(obj)
else:
filtered_objects = objects
return filtered_objects + remaining_objects
38 changes: 38 additions & 0 deletions test/test_smart_completion_public_schema_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,44 @@ def test_table_completion(completer, complete_event):
]


def test_select_filtered_table_completion(completer, complete_event):
    """Check the order of table suggestions after a SELECT naming column ABC."""
    text = "SELECT ABC FROM "
    document = Document(text=text, cursor_position=len(text))

    # Expected suggestions, in the exact order the completer must yield them.
    expected_tables = [
        "`select`",
        "`réveillé`",
        "users",
        "orders",
        "time_zone",
        "time_zone_leap_second",
        "time_zone_name",
        "time_zone_transition",
        "time_zone_transition_type",
        "test",
        "`test 2`",
    ]

    completions = completer.get_completions(document, complete_event)
    assert list(completions) == [Completion(text=table, start_position=0) for table in expected_tables]


def test_sub_select_filtered_table_completion(completer, complete_event):
    """Check the order of table suggestions inside a partially typed sub select."""
    text = "SELECT * FROM (SELECT ordered_date FROM "
    document = Document(text=text, cursor_position=len(text))

    # Expected suggestions, in the exact order the completer must yield them.
    expected_tables = [
        "orders",
        "users",
        "`select`",
        "`réveillé`",
        "time_zone",
        "time_zone_leap_second",
        "time_zone_name",
        "time_zone_transition",
        "time_zone_transition_type",
        "test",
        "`test 2`",
    ]

    completions = completer.get_completions(document, complete_event)
    assert list(completions) == [Completion(text=table, start_position=0) for table in expected_tables]


def test_enum_value_completion(completer, complete_event):
text = "SELECT * FROM orders WHERE status = "
position = len(text)
Expand Down