diff --git a/crates/squawk_ide/src/binder.rs b/crates/squawk_ide/src/binder.rs index f7a4a5c5..ed973464 100644 --- a/crates/squawk_ide/src/binder.rs +++ b/crates/squawk_ide/src/binder.rs @@ -79,6 +79,7 @@ fn bind_stmt(b: &mut Binder, stmt: ast::Stmt) { match stmt { ast::Stmt::CreateTable(create_table) => bind_create_table(b, create_table), ast::Stmt::CreateIndex(create_index) => bind_create_index(b, create_index), + ast::Stmt::CreateFunction(create_function) => bind_create_function(b, create_function), ast::Stmt::Set(set) => bind_set(b, set), _ => {} } @@ -129,6 +130,31 @@ fn bind_create_index(b: &mut Binder, create_index: ast::CreateIndex) { b.scopes[root].insert(index_name, index_id); } +fn bind_create_function(b: &mut Binder, create_function: ast::CreateFunction) { + let Some(path) = create_function.path() else { + return; + }; + + let Some(function_name) = item_name(&path) else { + return; + }; + + let name_ptr = path_to_ptr(&path); + + let Some(schema) = b.current_search_path().first().cloned() else { + return; + }; + + let function_id = b.symbols.alloc(Symbol { + kind: SymbolKind::Function, + ptr: name_ptr, + schema, + }); + + let root = b.root_scope(); + b.scopes[root].insert(function_name, function_id); +} + fn item_name(path: &ast::Path) -> Option { let segment = path.segment()?; diff --git a/crates/squawk_ide/src/goto_definition.rs b/crates/squawk_ide/src/goto_definition.rs index 8f491ecd..c42e3118 100644 --- a/crates/squawk_ide/src/goto_definition.rs +++ b/crates/squawk_ide/src/goto_definition.rs @@ -908,4 +908,94 @@ create index idx_email on users(email$0); ╰╴ ─ 1. source "); } + + #[test] + fn goto_drop_function() { + assert_snapshot!(goto(" +create function foo() returns int as $$ select 1 $$ language sql; +drop function foo$0(); +"), @r" + ╭▸ + 2 │ create function foo() returns int as $$ select 1 $$ language sql; + │ ─── 2. destination + 3 │ drop function foo(); + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_drop_function_with_schema() { + assert_snapshot!(goto(" +set search_path to public; +create function foo() returns int as $$ select 1 $$ language sql; +drop function public.foo$0(); +"), @r" + ╭▸ + 3 │ create function foo() returns int as $$ select 1 $$ language sql; + │ ─── 2. destination + 4 │ drop function public.foo(); + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_drop_function_defined_after() { + assert_snapshot!(goto(" +drop function foo$0(); +create function foo() returns int as $$ select 1 $$ language sql; +"), @r" + ╭▸ + 2 │ drop function foo(); + │ ─ 1. source + 3 │ create function foo() returns int as $$ select 1 $$ language sql; + ╰╴ ─── 2. destination + "); + } + + #[test] + fn goto_function_definition_returns_self() { + assert_snapshot!(goto(" +create function foo$0() returns int as $$ select 1 $$ language sql; +"), @r" + ╭▸ + 2 │ create function foo() returns int as $$ select 1 $$ language sql; + │ ┬─┬ + │ │ │ + │ │ 1. source + ╰╴ 2. destination + "); + } + + #[test] + fn goto_drop_function_with_search_path() { + assert_snapshot!(goto(" +create function foo() returns int as $$ select 1 $$ language sql; +set search_path to bar; +create function foo() returns int as $$ select 1 $$ language sql; +set search_path to default; +drop function foo$0(); +"), @r" + ╭▸ + 2 │ create function foo() returns int as $$ select 1 $$ language sql; + │ ─── 2. destination + ‡ + 6 │ drop function foo(); + ╰╴ ─ 1. source + "); + } + + #[test] + fn goto_drop_function_multiple() { + assert_snapshot!(goto(" +create function foo() returns int as $$ select 1 $$ language sql; +create function bar() returns int as $$ select 1 $$ language sql; +drop function foo(), bar$0(); +"), @r" + ╭▸ + 3 │ create function bar() returns int as $$ select 1 $$ language sql; + │ ─── 2. destination + 4 │ drop function foo(), bar(); + ╰╴ ─ 1. source + "); + } } diff --git a/crates/squawk_ide/src/resolve.rs b/crates/squawk_ide/src/resolve.rs index bcdff05d..4dbfabcb 100644 --- a/crates/squawk_ide/src/resolve.rs +++ b/crates/squawk_ide/src/resolve.rs @@ -14,6 +14,7 @@ enum NameRefContext { DropTable, Table, DropIndex, + DropFunction, CreateIndex, CreateIndexColumn, } @@ -36,6 +37,13 @@ pub(crate) fn resolve_name_ref(binder: &Binder, name_ref: &ast::NameRef) -> Opti let position = name_ref.syntax().text_range().start(); resolve_index(binder, &index_name, &schema, position) } + NameRefContext::DropFunction => { + let path = find_containing_path(name_ref)?; + let function_name = extract_table_name(&path)?; + let schema = extract_schema_name(&path); + let position = name_ref.syntax().text_range().start(); + resolve_function(binder, &function_name, &schema, position) + } NameRefContext::CreateIndexColumn => resolve_create_index_column(binder, name_ref), } } @@ -53,6 +61,9 @@ fn classify_name_ref_context(name_ref: &ast::NameRef) -> Option if ast::DropIndex::can_cast(ancestor.kind()) { return Some(NameRefContext::DropIndex); } + if ast::DropFunction::can_cast(ancestor.kind()) { + return Some(NameRefContext::DropFunction); + } if ast::PartitionItem::can_cast(ancestor.kind()) { in_partition_item = true; } @@ -123,6 +134,34 @@ fn resolve_index( None } +fn resolve_function( + binder: &Binder, + function_name: &Name, + schema: &Option, + position: TextSize, +) -> Option { + let symbols = binder.scopes[binder.root_scope()].get(function_name)?; + + if let Some(schema) = schema { + let symbol_id = symbols.iter().copied().find(|id| { + let symbol = &binder.symbols[*id]; + symbol.kind == SymbolKind::Function && &symbol.schema == schema + })?; + return Some(binder.symbols[symbol_id].ptr); + } else { + let search_path = binder.search_path_at(position); + for search_schema in search_path { + if let Some(symbol_id) = symbols.iter().copied().find(|id| { + let symbol = &binder.symbols[*id]; + symbol.kind == SymbolKind::Function && &symbol.schema == search_schema + }) { + return Some(binder.symbols[symbol_id].ptr); + } + } + } + None +} + fn resolve_create_index_column(binder: &Binder, name_ref: &ast::NameRef) -> Option { let column_name = Name::new(name_ref.syntax().text().to_string()); diff --git a/crates/squawk_ide/src/symbols.rs b/crates/squawk_ide/src/symbols.rs index 2e515d42..0c71a9d9 100644 --- a/crates/squawk_ide/src/symbols.rs +++ b/crates/squawk_ide/src/symbols.rs @@ -41,6 +41,7 @@ fn normalize_identifier(text: &str) -> SmolStr { pub(crate) enum SymbolKind { Table, Index, + Function, } #[derive(Clone, Debug)]