diff --git a/changelog.md b/changelog.md index ee247397..beb57a7b 100644 --- a/changelog.md +++ b/changelog.md @@ -5,6 +5,7 @@ Features -------- * Options to limit size of LLM prompts; cache LLM prompt data. * Add startup usage tips. +* Suggest tables/views that contain the given columns first when provided in a SELECT query. * Move `main.ssl_mode` config option to `connection.default_ssl_mode`. * Add "unsupported" and "deprecated" `--checkup` sections. diff --git a/mycli/packages/parseutils.py b/mycli/packages/parseutils.py index 559b5a18..17df81d8 100644 --- a/mycli/packages/parseutils.py +++ b/mycli/packages/parseutils.py @@ -91,6 +91,42 @@ def is_subselect(parsed: TokenList) -> bool: return False +def get_last_select(parsed: TokenList) -> TokenList: + """ + Takes a parsed sql statement and returns the last select query where applicable. + + The intended use case is for when giving table suggestions based on columns, where + we only want to look at the columns from the most recent select. This works for a single + select query, or one or more sub queries (the useful part). + + The custom logic is necessary because the typical sqlparse logic for things like finding + sub selects (i.e. is_subselect) only works on complete statements, such as: + + * select c1 from t1; + + However when suggesting tables based on columns, we only have partial select statements, i.e.: + + * select c1 + * select c1 from (select c2) + + So given the above, we must parse them ourselves as they are not viewed as complete statements. + + Returns a TokenList of the last select statement's tokens. + """ + select_indexes: list[int] = [] + + for token in parsed: + if token.match(DML, "select"): # match is case insensitive + select_indexes.append(parsed.token_index(token)) + + last_select = TokenList() + + if select_indexes: + last_select = TokenList(parsed[select_indexes[-1] :]) + + return last_select + + def extract_from_part(parsed: TokenList, stop_at_punctuation: bool = True) -> Generator[Any, None, None]: tbl_prefix_seen = False for item in parsed.tokens: @@ -185,6 +221,51 @@ def extract_tables(sql: str) -> list[tuple[str | None, str, str]]: return list(extract_table_identifiers(stream)) +def extract_columns_from_select(sql: str) -> list[str]: + """ + Extract the column names from a select SQL statement. + + Returns a list of columns. + """ + parsed = sqlparse.parse(sql) + if not parsed: + return [] + + statement = get_last_select(parsed[0]) + + # if there is no select, skip checking for columns + if not statement: + return [] + + columns = [] + + # Loops through the tokens (pieces) of the SQL statement. + # Once it finds the SELECT token (generally first), it + # will then start looking for columns from that point on. + # The get_real_name() function returns the real column name + # even if an alias is used. + found_select = False + for token in statement.tokens: + if token.ttype is DML and token.value.upper() == 'SELECT': + found_select = True + elif found_select: + if isinstance(token, IdentifierList): + # multiple columns + for identifier in token.get_identifiers(): + column = identifier.get_real_name() + columns.append(column) + elif isinstance(token, Identifier): + # single column + column = token.get_real_name() + columns.append(column) + elif token.ttype is Keyword: + break + + if columns: + break + return columns + + def extract_tables_from_complete_statements(sql: str) -> list[tuple[str | None, str, str | None]]: """Extract the table names from a complete and valid series of SQL statements. diff --git a/mycli/sqlcompleter.py b/mycli/sqlcompleter.py index fe578889..7401958e 100644 --- a/mycli/sqlcompleter.py +++ b/mycli/sqlcompleter.py @@ -13,7 +13,7 @@ from mycli.packages.completion_engine import suggest_type from mycli.packages.filepaths import complete_path, parse_path, suggest_path -from mycli.packages.parseutils import last_word +from mycli.packages.parseutils import extract_columns_from_select, last_word from mycli.packages.special import llm from mycli.packages.special.favoritequeries import FavoriteQueries from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS @@ -1131,7 +1131,15 @@ def get_completions( completions.extend([(*x, rank) for x in procs_m]) elif suggestion["type"] == "table": - tables = self.populate_schema_objects(suggestion["schema"], "tables") + # If this is a select and columns are given, parse the columns and + # then only return tables that have one or more of the given columns. + # If no columns are given (or able to be parsed), return all tables + # as usual. + columns = extract_columns_from_select(document.text) + if columns: + tables = self.populate_schema_objects(suggestion["schema"], "tables", columns) + else: + tables = self.populate_schema_objects(suggestion["schema"], "tables") tables_m = self.find_matches(word_before_cursor, tables) completions.extend([(*x, rank) for x in tables_m]) @@ -1341,15 +1349,34 @@ def _matches_parent(parent: str, schema: str | None, relname: str, alias: str | def _quote_sql_string(value: str) -> str: return "'" + value.replace("'", "''") + "'" - def populate_schema_objects(self, schema: str | None, obj_type: str) -> list[str]: + def populate_schema_objects(self, schema: str | None, obj_type: str, columns: list[str] | None = None) -> list[str]: """Returns list of tables or functions for a (optional) schema""" metadata = self.dbmetadata[obj_type] schema = schema or self.dbname - try: - objects = metadata[schema].keys() + objects = list(metadata[schema].keys()) except KeyError: # schema doesn't exist objects = [] - return objects + filtered_objects: list[str] = [] + remaining_objects: list[str] = [] + + # If the requested object type is tables and the user already entered + # columns, return a filtered list of tables (or views) that contain + # one or more of the given columns. If a table does not contain the + # given columns, add it to a separate list to add to the end of the + # filtered suggestions. + if obj_type == "tables" and columns and objects: + for obj in objects: + matched = False + for column in metadata[schema][obj]: + if column in columns: + filtered_objects.append(obj) + matched = True + break + if not matched: + remaining_objects.append(obj) + else: + filtered_objects = objects + return filtered_objects + remaining_objects diff --git a/test/test_smart_completion_public_schema_only.py b/test/test_smart_completion_public_schema_only.py index 13da35f6..6e6a843e 100644 --- a/test/test_smart_completion_public_schema_only.py +++ b/test/test_smart_completion_public_schema_only.py @@ -144,6 +144,44 @@ def test_table_completion(completer, complete_event): ] +def test_select_filtered_table_completion(completer, complete_event): + text = "SELECT ABC FROM " + position = len(text) + result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) + assert list(result) == [ + Completion(text="`select`", start_position=0), + Completion(text="`réveillé`", start_position=0), + Completion(text="users", start_position=0), + Completion(text="orders", start_position=0), + Completion(text="time_zone", start_position=0), + Completion(text="time_zone_leap_second", start_position=0), + Completion(text="time_zone_name", start_position=0), + Completion(text="time_zone_transition", start_position=0), + Completion(text="time_zone_transition_type", start_position=0), + Completion(text="test", start_position=0), + Completion(text="`test 2`", start_position=0), + ] + + +def test_sub_select_filtered_table_completion(completer, complete_event): + text = "SELECT * FROM (SELECT ordered_date FROM " + position = len(text) + result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) + assert list(result) == [ + Completion(text="orders", start_position=0), + Completion(text="users", start_position=0), + Completion(text="`select`", start_position=0), + Completion(text="`réveillé`", start_position=0), + Completion(text="time_zone", start_position=0), + Completion(text="time_zone_leap_second", start_position=0), + Completion(text="time_zone_name", start_position=0), + Completion(text="time_zone_transition", start_position=0), + Completion(text="time_zone_transition_type", start_position=0), + Completion(text="test", start_position=0), + Completion(text="`test 2`", start_position=0), + ] + + def test_enum_value_completion(completer, complete_event): text = "SELECT * FROM orders WHERE status = " position = len(text)