From 1086965086b1eb5c938b584bd09c777c3a7b67d6 Mon Sep 17 00:00:00 2001 From: Steve Dignam Date: Wed, 17 Dec 2025 22:15:33 -0500 Subject: [PATCH] ide: support `set search_path` --- crates/squawk_ide/src/binder.rs | 105 ++++++++++- crates/squawk_ide/src/goto_definition.rs | 163 ++++++++++++++++++ crates/squawk_ide/src/resolve.rs | 10 +- crates/squawk_parser/src/grammar.rs | 8 +- .../snapshots/tests__alter_database_ok.snap | 6 +- .../snapshots/tests__alter_procedure_ok.snap | 6 +- .../snapshots/tests__create_function_ok.snap | 6 +- .../tests/snapshots/tests__schemas_ok.snap | 15 +- .../squawk_syntax/src/ast/generated/nodes.rs | 84 ++++++++- crates/squawk_syntax/src/postgresql.ungram | 18 +- 10 files changed, 394 insertions(+), 27 deletions(-) diff --git a/crates/squawk_ide/src/binder.rs b/crates/squawk_ide/src/binder.rs index 073fb854..bceb030b 100644 --- a/crates/squawk_ide/src/binder.rs +++ b/crates/squawk_ide/src/binder.rs @@ -1,14 +1,21 @@ /// Loosely based on TypeScript's binder /// see: typescript-go/internal/binder/binder.go use la_arena::Arena; +use rowan::TextSize; use squawk_syntax::{SyntaxNodePtr, ast, ast::AstNode}; use crate::scope::{Scope, ScopeId}; use crate::symbols::{Name, Schema, Symbol, SymbolKind}; +pub(crate) struct SearchPathChange { + position: TextSize, + search_path: Vec, +} + pub(crate) struct Binder { pub(crate) scopes: Arena, pub(crate) symbols: Arena, + pub(crate) search_path_changes: Vec, } impl Binder { @@ -18,6 +25,10 @@ impl Binder { Binder { scopes, symbols: Arena::new(), + search_path_changes: vec![SearchPathChange { + position: TextSize::from(0), + search_path: vec![Schema::new("pg_temp"), Schema::new("public")], + }], } } @@ -28,6 +39,18 @@ impl Binder { .map(|(id, _)| id) .expect("root scope must exist") } + + pub(crate) fn search_path_at(&self, position: TextSize) -> &[Schema] { + // We're assuming people don't actually use `set search_path` that much, + // so linear search is fine + for change in self.search_path_changes.iter().rev() { + if change.position <= position { + return &change.search_path; + } + } + // default search path + &self.search_path_changes[0].search_path + } } pub(crate) fn bind(file: &ast::SourceFile) -> Binder { @@ -45,8 +68,10 @@ fn bind_file(b: &mut Binder, file: &ast::SourceFile) { } fn bind_stmt(b: &mut Binder, stmt: ast::Stmt) { - if let ast::Stmt::CreateTable(create_table) = stmt { - bind_create_table(b, create_table) + match stmt { + ast::Stmt::CreateTable(create_table) => bind_create_table(b, create_table), + ast::Stmt::Set(set) => bind_set(b, set), + _ => {} } } @@ -113,3 +138,79 @@ fn schema_name(path: &ast::Path, is_temp: bool) -> Schema { Schema(schema_name) } + +fn bind_set(b: &mut Binder, set: ast::Set) { + let position = set.syntax().text_range().start(); + + // `set schema` is an alternative to `set search_path` + if set.schema_token().is_some() { + if let Some(literal) = set.literal() { + if let Some(string_value) = extract_string_literal(&literal) { + b.search_path_changes.push(SearchPathChange { + position, + search_path: vec![Schema::new(string_value)], + }); + } + } + return; + } + + let Some(path) = set.path() else { return }; + + if path.qualifier().is_some() { + return; + } + + let Some(segment) = path.segment() else { + return; + }; + + let param_name = if let Some(name_ref) = segment.name_ref() { + name_ref.syntax().text().to_string() + } else { + return; + }; + + if !param_name.eq_ignore_ascii_case("search_path") { + return; + } + + // `set search_path` + if set.default_token().is_some() { + b.search_path_changes.push(SearchPathChange { + position, + search_path: vec![Schema::new("pg_temp"), Schema::new("public")], + }); + } else { + let mut search_path = vec![]; + for config_value in set.config_values() { + match config_value { + ast::ConfigValue::Literal(literal) => { + if let Some(string_value) = extract_string_literal(&literal) { + if !string_value.is_empty() { + search_path.push(Schema::new(string_value)); + } + } + } + ast::ConfigValue::NameRef(name_ref) => { + let schema_name = name_ref.syntax().text().to_string(); + search_path.push(Schema::new(schema_name)); + } + } + } + b.search_path_changes.push(SearchPathChange { + position, + search_path, + }); + } +} + +fn extract_string_literal(literal: &ast::Literal) -> Option { + let text = literal.syntax().text().to_string(); + + if text.starts_with('\'') && text.ends_with('\'') && text.len() >= 2 { + Some(text[1..text.len() - 1].to_string()) + } else { + None + } +} diff --git a/crates/squawk_ide/src/goto_definition.rs b/crates/squawk_ide/src/goto_definition.rs index afe83d47..0d567216 100644 --- a/crates/squawk_ide/src/goto_definition.rs +++ b/crates/squawk_ide/src/goto_definition.rs @@ -440,4 +440,167 @@ commit; ╰╴────── 2. destination "); } + + #[test] + fn goto_with_search_path() { + assert_snapshot!(goto(r#" +set search_path to "foo", public; +create table foo.t(); +drop table t$0; +"#), @r" + ╭▸ + 3 │ create table foo.t(); + │ ─ 2. destination + 4 │ drop table t; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_with_search_path_like_variable() { + // not actually search path + goto_not_found( + " +set bar.search_path to foo, public; +create table foo.t(); +drop table t$0; +", + ) + } + + #[test] + fn goto_with_search_path_second_schema() { + assert_snapshot!(goto(" +set search_path to foo, bar, public; +create table bar.t(); +drop table t$0; +"), @r" + ╭▸ + 3 │ create table bar.t(); + │ ─ 2. destination + 4 │ drop table t; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_with_search_path_skips_first() { + assert_snapshot!(goto(" +set search_path to foo, bar, public; +create table foo.t(); +create table bar.t(); +drop table t$0; +"), @r" + ╭▸ + 3 │ create table foo.t(); + │ ─ 2. destination + 4 │ create table bar.t(); + 5 │ drop table t; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_without_search_path_uses_default() { + assert_snapshot!(goto(" +create table foo.t(); +create table public.t(); +drop table t$0; +"), @r" + ╭▸ + 3 │ create table public.t(); + │ ─ 2. destination + 4 │ drop table t; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_with_set_schema() { + assert_snapshot!(goto(" +set schema 'myschema'; +create table myschema.t(); +drop table t$0; +"), @r" + ╭▸ + 3 │ create table myschema.t(); + │ ─ 2. destination + 4 │ drop table t; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_with_set_schema_ignores_other_schemas() { + assert_snapshot!(goto(" +set schema 'myschema'; +create table public.t(); +create table myschema.t(); +drop table t$0; +"), @r" + ╭▸ + 4 │ create table myschema.t(); + │ ─ 2. destination + 5 │ drop table t; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_with_search_path_changed_twice() { + assert_snapshot!(goto(" +set search_path to foo; +create table foo.t(); +set search_path to bar; +create table bar.t(); +drop table t$0; +"), @r" + ╭▸ + 5 │ create table bar.t(); + │ ─ 2. destination + 6 │ drop table t; + ╰╴ ─ 1. source + "); + + assert_snapshot!(goto(" +set search_path to foo; +create table foo.t(); +drop table t$0; +set search_path to bar; +create table bar.t(); +drop table t; +"), @r" + ╭▸ + 3 │ create table foo.t(); + │ ─ 2. destination + 4 │ drop table t; + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_with_empty_search_path() { + goto_not_found( + " +set search_path to ''; +create table public.t(); +drop table t$0; +", + ) + } + + #[test] + fn goto_with_search_path_uppercase() { + assert_snapshot!(goto(" +SET SEARCH_PATH TO foo; +create table foo.t(); +drop table t$0; +"), @r" + ╭▸ + 3 │ create table foo.t(); + │ ─ 2. destination + 4 │ drop table t; + ╰╴ ─ 1. source + "); + } } diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index ab576269..4c0c4212 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -1,3 +1,4 @@ +use rowan::TextSize; use squawk_syntax::{ SyntaxNodePtr, ast::{self, AstNode}, @@ -19,7 +20,8 @@ pub(crate) fn resolve_name_ref(binder: &Binder, name_ref: &ast::NameRef) -> Opti let path = find_containing_path(name_ref)?; let table_name = extract_table_name(&path)?; let schema = extract_schema_name(&path); - resolve_table(binder, &table_name, &schema) + let position = name_ref.syntax().text_range().start(); + resolve_table(binder, &table_name, &schema, position) } } } @@ -38,6 +40,7 @@ fn resolve_table( binder: &Binder, table_name: &Name, schema: &Option, + position: TextSize, ) -> Option { let symbols = binder.scopes[binder.root_scope()].get(table_name)?; @@ -48,10 +51,11 @@ fn resolve_table( })?; return Some(binder.symbols[symbol_id].ptr); } else { - for search_schema in [Schema::new("pg_temp"), Schema::new("public")] { + let search_path = binder.search_path_at(position); + for search_schema in search_path { if let Some(symbol_id) = symbols.iter().copied().find(|id| { let symbol = &binder.symbols[*id]; - symbol.kind == SymbolKind::Table && symbol.schema == search_schema + symbol.kind == SymbolKind::Table && &symbol.schema == search_schema }) { return Some(binder.symbols[symbol_id].ptr); } diff --git a/crates/squawk_parser/src/grammar.rs b/crates/squawk_parser/src/grammar.rs index 70a01dcb..b1e22089 100644 --- a/crates/squawk_parser/src/grammar.rs +++ b/crates/squawk_parser/src/grammar.rs @@ -13535,14 +13535,10 @@ fn config_value(p: &mut Parser<'_>) -> bool { while !p.at(EOF) { if opt_string_literal(p).is_none() && opt_numeric_literal(p).is_none() - && !opt_ident(p) + && opt_name_ref(p).is_none() && !opt_bool_literal(p) { - if p.at_ts(BARE_LABEL_KEYWORDS) { - p.bump_any(); - } else { - break; - } + break; } found_value = true; if !p.eat(COMMA) { diff --git a/crates/squawk_parser/tests/snapshots/tests__alter_database_ok.snap b/crates/squawk_parser/tests/snapshots/tests__alter_database_ok.snap index eb3a24e2..cfea44cb 100644 --- a/crates/squawk_parser/tests/snapshots/tests__alter_database_ok.snap +++ b/crates/squawk_parser/tests/snapshots/tests__alter_database_ok.snap @@ -116,7 +116,8 @@ SOURCE_FILE WHITESPACE " " TO_KW "to" WHITESPACE " " - IDENT "v" + NAME_REF + IDENT "v" SEMICOLON ";" WHITESPACE "\n" ALTER_DATABASE @@ -137,7 +138,8 @@ SOURCE_FILE WHITESPACE " " EQ "=" WHITESPACE " " - IDENT "v" + NAME_REF + IDENT "v" SEMICOLON ";" WHITESPACE "\n" ALTER_DATABASE diff --git a/crates/squawk_parser/tests/snapshots/tests__alter_procedure_ok.snap b/crates/squawk_parser/tests/snapshots/tests__alter_procedure_ok.snap index 4d019e9c..e0c65765 100644 --- a/crates/squawk_parser/tests/snapshots/tests__alter_procedure_ok.snap +++ b/crates/squawk_parser/tests/snapshots/tests__alter_procedure_ok.snap @@ -260,7 +260,8 @@ SOURCE_FILE WHITESPACE " " TO_KW "to" WHITESPACE " " - IDENT "v" + NAME_REF + IDENT "v" SEMICOLON ";" WHITESPACE "\n" ALTER_PROCEDURE @@ -286,7 +287,8 @@ SOURCE_FILE WHITESPACE " " EQ "=" WHITESPACE " " - IDENT "v" + NAME_REF + IDENT "v" SEMICOLON ";" WHITESPACE "\n" ALTER_PROCEDURE diff --git a/crates/squawk_parser/tests/snapshots/tests__create_function_ok.snap b/crates/squawk_parser/tests/snapshots/tests__create_function_ok.snap index b02c9f8d..9e408655 100644 --- a/crates/squawk_parser/tests/snapshots/tests__create_function_ok.snap +++ b/crates/squawk_parser/tests/snapshots/tests__create_function_ok.snap @@ -142,10 +142,12 @@ SOURCE_FILE WHITESPACE " " EQ "=" WHITESPACE " " - ADMIN_KW "admin" + NAME_REF + ADMIN_KW "admin" COMMA "," WHITESPACE " " - IDENT "pg_temp" + NAME_REF + IDENT "pg_temp" SEMICOLON ";" WHITESPACE "\n\n" COMMENT "-- create_function_with_percent_type" diff --git a/crates/squawk_parser/tests/snapshots/tests__schemas_ok.snap b/crates/squawk_parser/tests/snapshots/tests__schemas_ok.snap index 93ed484d..59a07719 100644 --- a/crates/squawk_parser/tests/snapshots/tests__schemas_ok.snap +++ b/crates/squawk_parser/tests/snapshots/tests__schemas_ok.snap @@ -409,9 +409,11 @@ SOURCE_FILE WHITESPACE " " TO_KW "to" WHITESPACE " " - IDENT "myschema" + NAME_REF + IDENT "myschema" COMMA "," - IDENT "public" + NAME_REF + IDENT "public" SEMICOLON ";" WHITESPACE "\n\n" SET @@ -424,7 +426,8 @@ SOURCE_FILE WHITESPACE " " TO_KW "to" WHITESPACE " " - IDENT "myschema" + NAME_REF + IDENT "myschema" SEMICOLON ";" WHITESPACE "\n\n" SET @@ -437,7 +440,8 @@ SOURCE_FILE WHITESPACE " " EQ "=" WHITESPACE " " - IDENT "bar" + NAME_REF + IDENT "bar" SEMICOLON ";" WHITESPACE "\n\n" SET @@ -564,7 +568,8 @@ SOURCE_FILE WHITESPACE " " TO_KW "to" WHITESPACE " " - IDENT "a" + NAME_REF + IDENT "a" COMMA "," WHITESPACE " " LITERAL diff --git a/crates/squawk_syntax/src/ast/generated/nodes.rs b/crates/squawk_syntax/src/ast/generated/nodes.rs index 15100b1f..f0354a27 100644 --- a/crates/squawk_syntax/src/ast/generated/nodes.rs +++ b/crates/squawk_syntax/src/ast/generated/nodes.rs @@ -13768,7 +13768,15 @@ pub struct Set { } impl Set { #[inline] - pub fn expr(&self) -> Option { + pub fn config_value(&self) -> Option { + support::child(&self.syntax) + } + #[inline] + pub fn config_values(&self) -> AstChildren { + support::children(&self.syntax) + } + #[inline] + pub fn literal(&self) -> Option { support::child(&self.syntax) } #[inline] @@ -13780,14 +13788,42 @@ impl Set { support::token(&self.syntax, SyntaxKind::EQ) } #[inline] + pub fn catalog_token(&self) -> Option { + support::token(&self.syntax, SyntaxKind::CATALOG_KW) + } + #[inline] + pub fn content_token(&self) -> Option { + support::token(&self.syntax, SyntaxKind::CONTENT_KW) + } + #[inline] + pub fn current_token(&self) -> Option { + support::token(&self.syntax, SyntaxKind::CURRENT_KW) + } + #[inline] pub fn default_token(&self) -> Option { support::token(&self.syntax, SyntaxKind::DEFAULT_KW) } #[inline] + pub fn document_token(&self) -> Option { + support::token(&self.syntax, SyntaxKind::DOCUMENT_KW) + } + #[inline] + pub fn from_token(&self) -> Option { + support::token(&self.syntax, SyntaxKind::FROM_KW) + } + #[inline] pub fn local_token(&self) -> Option { support::token(&self.syntax, SyntaxKind::LOCAL_KW) } #[inline] + pub fn option_token(&self) -> Option { + support::token(&self.syntax, SyntaxKind::OPTION_KW) + } + #[inline] + pub fn schema_token(&self) -> Option { + support::token(&self.syntax, SyntaxKind::SCHEMA_KW) + } + #[inline] pub fn session_token(&self) -> Option { support::token(&self.syntax, SyntaxKind::SESSION_KW) } @@ -13804,6 +13840,10 @@ impl Set { support::token(&self.syntax, SyntaxKind::TO_KW) } #[inline] + pub fn xml_token(&self) -> Option { + support::token(&self.syntax, SyntaxKind::XML_KW) + } + #[inline] pub fn zone_token(&self) -> Option { support::token(&self.syntax, SyntaxKind::ZONE_KW) } @@ -16436,6 +16476,12 @@ pub enum ColumnConstraint { UniqueConstraint(UniqueConstraint), } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum ConfigValue { + Literal(Literal), + NameRef(NameRef), +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum ConflictAction { ConflictDoNothing(ConflictDoNothing), @@ -29106,6 +29152,42 @@ impl From for ColumnConstraint { ColumnConstraint::UniqueConstraint(node) } } +impl AstNode for ConfigValue { + #[inline] + fn can_cast(kind: SyntaxKind) -> bool { + matches!(kind, SyntaxKind::LITERAL | SyntaxKind::NAME_REF) + } + #[inline] + fn cast(syntax: SyntaxNode) -> Option { + let res = match syntax.kind() { + SyntaxKind::LITERAL => ConfigValue::Literal(Literal { syntax }), + SyntaxKind::NAME_REF => ConfigValue::NameRef(NameRef { syntax }), + _ => { + return None; + } + }; + Some(res) + } + #[inline] + fn syntax(&self) -> &SyntaxNode { + match self { + ConfigValue::Literal(it) => &it.syntax, + ConfigValue::NameRef(it) => &it.syntax, + } + } +} +impl From for ConfigValue { + #[inline] + fn from(node: Literal) -> ConfigValue { + ConfigValue::Literal(node) + } +} +impl From for ConfigValue { + #[inline] + fn from(node: NameRef) -> ConfigValue { + ConfigValue::NameRef(node) + } +} impl AstNode for ConflictAction { #[inline] fn can_cast(kind: SyntaxKind) -> bool { diff --git a/crates/squawk_syntax/src/postgresql.ungram b/crates/squawk_syntax/src/postgresql.ungram index 892c3b90..40e00537 100644 --- a/crates/squawk_syntax/src/postgresql.ungram +++ b/crates/squawk_syntax/src/postgresql.ungram @@ -2742,10 +2742,20 @@ CreateExtension = 'create' 'extension' IfNotExists? Name Set = - 'set' - ('session' | 'local')? - ('time' 'zone' | (Path ('to' | '='))) - (Expr | 'local' | 'default') + 'set' + ('session' | 'local')? + ( 'xml' 'option' ('document' | 'content') + | 'time' 'zone' (ConfigValue | 'default' | 'local')? + | ('catalog' | 'schema') Literal + | Path ( + 'from' 'current' + | (('to' | '=') (ConfigValue* | 'default') ) + ) + ) + +ConfigValue = + Literal +| NameRef Show = 'show'