From 77af346853d2aea52c08dd6b8092e02d169b781c Mon Sep 17 00:00:00 2001 From: Scott Nemes Date: Sat, 7 Feb 2026 22:09:06 -0800 Subject: [PATCH 1/4] Base case of single select handled --- changelog.md | 1 + mycli/packages/parseutils.py | 40 +++++++++++++++++++ mycli/sqlcompleter.py | 28 ++++++++++--- ...est_smart_completion_public_schema_only.py | 12 ++++++ 4 files changed, 76 insertions(+), 5 deletions(-) diff --git a/changelog.md b/changelog.md index 5b5e68a5..4b816d97 100644 --- a/changelog.md +++ b/changelog.md @@ -4,6 +4,7 @@ TBD Features -------- * Options to limit size of LLM prompts; cache LLM prompt data. +* Suggest only tables/views that contain the given columns when provided in a SELECT/DELETE query. Bug Fixes diff --git a/mycli/packages/parseutils.py b/mycli/packages/parseutils.py index b5d0d5b4..e0849294 100644 --- a/mycli/packages/parseutils.py +++ b/mycli/packages/parseutils.py @@ -185,6 +185,46 @@ def extract_tables(sql: str) -> list[tuple[str | None, str, str]]: return list(extract_table_identifiers(stream)) +def extract_columns_from_select(sql: str) -> list[str]: + """ + Extract the column names from a select SQL statement. + + Returns a list of columns. + """ + parsed = sqlparse.parse(sql) + if not parsed: + return [] + + statement = parsed[0] + columns = [] + + # Loops through the tokens (pieces) of the SQL statement. + # Once it finds the SELECT token (generally first), it + # will then start looking for columns from that point on. + # The get_real_name() function returns the real column name + # even if an alias is used. 
+ found_select = False + for token in statement.tokens: + if token.ttype is DML and token.value.upper() == 'SELECT': + found_select = True + elif found_select: + if isinstance(token, IdentifierList): + # multiple columns + for identifier in token.get_identifiers(): + column = identifier.get_real_name() + columns.append(column) + elif isinstance(token, Identifier): + # single column + column = token.get_real_name() + columns.append(column) + elif token.ttype is Keyword: + break + + if columns: + break + return columns + + def extract_tables_from_complete_statements(sql: str) -> list[tuple[str | None, str, str | None]]: """Extract the table names from a complete and valid series of SQL statements. diff --git a/mycli/sqlcompleter.py b/mycli/sqlcompleter.py index fe578889..da34a3be 100644 --- a/mycli/sqlcompleter.py +++ b/mycli/sqlcompleter.py @@ -13,7 +13,7 @@ from mycli.packages.completion_engine import suggest_type from mycli.packages.filepaths import complete_path, parse_path, suggest_path -from mycli.packages.parseutils import last_word +from mycli.packages.parseutils import extract_columns_from_select, last_word from mycli.packages.special import llm from mycli.packages.special.favoritequeries import FavoriteQueries from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS @@ -1131,7 +1131,12 @@ def get_completions( completions.extend([(*x, rank) for x in procs_m]) elif suggestion["type"] == "table": - tables = self.populate_schema_objects(suggestion["schema"], "tables") + #print(f"##{document.text}##") + columns = extract_columns_from_select(document.text) + if columns: + tables = self.populate_schema_objects(suggestion["schema"], "tables", columns) + else: + tables = self.populate_schema_objects(suggestion["schema"], "tables") tables_m = self.find_matches(word_before_cursor, tables) completions.extend([(*x, rank) for x in tables_m]) @@ -1341,15 +1346,28 @@ def _matches_parent(parent: str, schema: str | None, relname: str, alias: str | def 
_quote_sql_string(value: str) -> str: return "'" + value.replace("'", "''") + "'" - def populate_schema_objects(self, schema: str | None, obj_type: str) -> list[str]: + def populate_schema_objects(self, schema: str | None, obj_type: str, columns: list[str] | None = None) -> list[str]: """Returns list of tables or functions for a (optional) schema""" metadata = self.dbmetadata[obj_type] schema = schema or self.dbname - try: objects = metadata[schema].keys() except KeyError: # schema doesn't exist objects = [] - return objects + filtered_objects: list[str] = [] + + # If the requested object type is tables and the user already entered + # columns, return a filtered list of tables (or views) that contain + # one or more of the given columns. + if obj_type == "tables" and columns and objects: + #print(f"##{columns}##") + for obj in objects: + for column in metadata[schema][obj]: + if column in columns: + filtered_objects.append(obj) + break + else: + filtered_objects = objects + return filtered_objects diff --git a/test/test_smart_completion_public_schema_only.py b/test/test_smart_completion_public_schema_only.py index 13da35f6..3c70ae48 100644 --- a/test/test_smart_completion_public_schema_only.py +++ b/test/test_smart_completion_public_schema_only.py @@ -144,6 +144,18 @@ def test_table_completion(completer, complete_event): ] +def test_filtered_table_completion(completer, complete_event): + text = "SELECT ABC FROM " + position = len(text) + result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) + assert list(result) == [ + Completion(text="`select`", start_position=0), + Completion(text="`réveillé`", start_position=0), + Completion(text="test", start_position=0), + Completion(text="`test 2`", start_position=0), + ] + + def test_enum_value_completion(completer, complete_event): text = "SELECT * FROM orders WHERE status = " position = len(text) From 32d97a55d2017b0a05b10e1d7de4acf269efa5a9 Mon Sep 17 00:00:00 2001 From: Scott 
Nemes Date: Sun, 8 Feb 2026 19:02:35 -0800 Subject: [PATCH 2/4] Made it work with sub selects. Added extensive comments to explain it. --- changelog.md | 2 +- mycli/packages/parseutils.py | 43 ++++++++++++++++++- mycli/sqlcompleter.py | 7 ++- ...est_smart_completion_public_schema_only.py | 13 +++++- 4 files changed, 60 insertions(+), 5 deletions(-) diff --git a/changelog.md b/changelog.md index 4b816d97..8ba82aee 100644 --- a/changelog.md +++ b/changelog.md @@ -4,7 +4,7 @@ TBD Features -------- * Options to limit size of LLM prompts; cache LLM prompt data. -* Suggest only tables/views that contain the given columns when provided in a SELECT/DELETE query. +* Suggest only tables/views that contain the given columns when provided in a SELECT query. Bug Fixes diff --git a/mycli/packages/parseutils.py b/mycli/packages/parseutils.py index e0849294..de5676ad 100644 --- a/mycli/packages/parseutils.py +++ b/mycli/packages/parseutils.py @@ -91,6 +91,42 @@ def is_subselect(parsed: TokenList) -> bool: return False +def get_last_select(parsed: TokenList) -> TokenList: + """ + Takes a parsed sql statement and returns the last select query where applicable. + + The intended use case is for when giving table suggestions based on columns, where + we only want to look at the columns from the most recent select. This works for a single + select query, or one or more sub queries (the useful part). + + The custom logic is necessary because the typical sqlparse logic for things like finding + sub selects (i.e. is_subselect) only works on complete statements, such as: + + * select c1 from t1; + + However when suggesting tables based on columns, we only have partial select statements, i.e.: + + * select c1 + * select c1 from (select c2) + + So given the above, we must parse them ourselves as they are not viewed as complete statements. + + Returns a TokenList of the last select statement's tokens. 
+ """ + select_indexes: list[int] = [] + + for token in parsed: + if token.match(DML, "select"): # match is case insensitive + select_indexes.append(parsed.token_index(token)) + + last_select = TokenList() + + if select_indexes: + last_select = TokenList(parsed[select_indexes[-1] :]) + + return last_select + + def extract_from_part(parsed: TokenList, stop_at_punctuation: bool = True) -> Generator[Any, None, None]: tbl_prefix_seen = False for item in parsed.tokens: @@ -195,7 +231,12 @@ def extract_columns_from_select(sql: str) -> list[str]: if not parsed: return [] - statement = parsed[0] + statement = get_last_select(parsed[0]) + + # if there is no select, skip checking for columns + if not statement: + return [] + columns = [] # Loops through the tokens (pieces) of the SQL statement. diff --git a/mycli/sqlcompleter.py b/mycli/sqlcompleter.py index da34a3be..f7aaced5 100644 --- a/mycli/sqlcompleter.py +++ b/mycli/sqlcompleter.py @@ -1131,7 +1131,10 @@ def get_completions( completions.extend([(*x, rank) for x in procs_m]) elif suggestion["type"] == "table": - #print(f"##{document.text}##") + # If this is a select and columns are given, parse the columns and + # then only return tables that have one or more of the given columns. + # If no columns are given (or able to be parsed), return all tables + # as usual. columns = extract_columns_from_select(document.text) if columns: tables = self.populate_schema_objects(suggestion["schema"], "tables", columns) @@ -1362,7 +1365,7 @@ def populate_schema_objects(self, schema: str | None, obj_type: str, columns: li # columns, return a filtered list of tables (or views) that contain # one or more of the given columns. 
if obj_type == "tables" and columns and objects: - #print(f"##{columns}##") + # print(f"##{columns}##") for obj in objects: for column in metadata[schema][obj]: if column in columns: diff --git a/test/test_smart_completion_public_schema_only.py b/test/test_smart_completion_public_schema_only.py index 3c70ae48..20f1a2cb 100644 --- a/test/test_smart_completion_public_schema_only.py +++ b/test/test_smart_completion_public_schema_only.py @@ -144,7 +144,7 @@ def test_table_completion(completer, complete_event): ] -def test_filtered_table_completion(completer, complete_event): +def test_select_filtered_table_completion(completer, complete_event): text = "SELECT ABC FROM " position = len(text) result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) @@ -156,6 +156,17 @@ def test_filtered_table_completion(completer, complete_event): ] +def test_sub_select_filtered_table_completion(completer, complete_event): + text = "SELECT * FROM (SELECT email FROM " + position = len(text) + result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) + assert list(result) == [ + Completion(text="users", start_position=0), + Completion(text="test", start_position=0), + Completion(text="`test 2`", start_position=0), + ] + + def test_enum_value_completion(completer, complete_event): text = "SELECT * FROM orders WHERE status = " position = len(text) From 61a89905049e8cb50dfe1bc8c1b7240228eb2211 Mon Sep 17 00:00:00 2001 From: Scott Nemes Date: Sun, 8 Feb 2026 19:03:56 -0800 Subject: [PATCH 3/4] Removed debug print --- mycli/sqlcompleter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mycli/sqlcompleter.py b/mycli/sqlcompleter.py index f7aaced5..bbbae5ae 100644 --- a/mycli/sqlcompleter.py +++ b/mycli/sqlcompleter.py @@ -1365,7 +1365,6 @@ def populate_schema_objects(self, schema: str | None, obj_type: str, columns: li # columns, return a filtered list of tables (or views) that contain # one or more of the given 
columns. if obj_type == "tables" and columns and objects: - # print(f"##{columns}##") for obj in objects: for column in metadata[schema][obj]: if column in columns: From 9c7db5e4eba091d44971bddbe70352878d360f3c Mon Sep 17 00:00:00 2001 From: Scott Nemes Date: Sun, 8 Feb 2026 19:39:23 -0800 Subject: [PATCH 4/4] Updated to return all tables after the matching tables --- changelog.md | 2 +- mycli/sqlcompleter.py | 13 ++++++++++--- .../test_smart_completion_public_schema_only.py | 17 ++++++++++++++++- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/changelog.md b/changelog.md index 246886b4..1817d52c 100644 --- a/changelog.md +++ b/changelog.md @@ -5,7 +5,7 @@ Features -------- * Options to limit size of LLM prompts; cache LLM prompt data. * Add startup usage tips. -* Suggest only tables/views that contain the given columns when provided in a SELECT query. +* Suggest tables/views that contain the given columns first when provided in a SELECT query. Bug Fixes diff --git a/mycli/sqlcompleter.py b/mycli/sqlcompleter.py index bbbae5ae..7401958e 100644 --- a/mycli/sqlcompleter.py +++ b/mycli/sqlcompleter.py @@ -1354,22 +1354,29 @@ def populate_schema_objects(self, schema: str | None, obj_type: str, columns: li metadata = self.dbmetadata[obj_type] schema = schema or self.dbname try: - objects = metadata[schema].keys() + objects = list(metadata[schema].keys()) except KeyError: # schema doesn't exist objects = [] filtered_objects: list[str] = [] + remaining_objects: list[str] = [] # If the requested object type is tables and the user already entered # columns, return a filtered list of tables (or views) that contain - # one or more of the given columns. + # one or more of the given columns. If a table does not contain the + # given columns, add it to a separate list to add to the end of the + # filtered suggestions. 
if obj_type == "tables" and columns and objects: for obj in objects: + matched = False for column in metadata[schema][obj]: if column in columns: filtered_objects.append(obj) + matched = True break + if not matched: + remaining_objects.append(obj) else: filtered_objects = objects - return filtered_objects + return filtered_objects + remaining_objects diff --git a/test/test_smart_completion_public_schema_only.py b/test/test_smart_completion_public_schema_only.py index 20f1a2cb..6e6a843e 100644 --- a/test/test_smart_completion_public_schema_only.py +++ b/test/test_smart_completion_public_schema_only.py @@ -151,17 +151,32 @@ def test_select_filtered_table_completion(completer, complete_event): assert list(result) == [ Completion(text="`select`", start_position=0), Completion(text="`réveillé`", start_position=0), + Completion(text="users", start_position=0), + Completion(text="orders", start_position=0), + Completion(text="time_zone", start_position=0), + Completion(text="time_zone_leap_second", start_position=0), + Completion(text="time_zone_name", start_position=0), + Completion(text="time_zone_transition", start_position=0), + Completion(text="time_zone_transition_type", start_position=0), Completion(text="test", start_position=0), Completion(text="`test 2`", start_position=0), ] def test_sub_select_filtered_table_completion(completer, complete_event): - text = "SELECT * FROM (SELECT email FROM " + text = "SELECT * FROM (SELECT ordered_date FROM " position = len(text) result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) assert list(result) == [ + Completion(text="orders", start_position=0), Completion(text="users", start_position=0), + Completion(text="`select`", start_position=0), + Completion(text="`réveillé`", start_position=0), + Completion(text="time_zone", start_position=0), + Completion(text="time_zone_leap_second", start_position=0), + Completion(text="time_zone_name", start_position=0), + 
Completion(text="time_zone_transition", start_position=0), + Completion(text="time_zone_transition_type", start_position=0), Completion(text="test", start_position=0), Completion(text="`test 2`", start_position=0), ]