Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions crates/squawk_ide/src/binder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ fn bind_stmt(b: &mut Binder, stmt: ast::Stmt) {
match stmt {
ast::Stmt::CreateTable(create_table) => bind_create_table(b, create_table),
ast::Stmt::CreateIndex(create_index) => bind_create_index(b, create_index),
ast::Stmt::CreateFunction(create_function) => bind_create_function(b, create_function),
ast::Stmt::Set(set) => bind_set(b, set),
_ => {}
}
Expand Down Expand Up @@ -129,6 +130,31 @@ fn bind_create_index(b: &mut Binder, create_index: ast::CreateIndex) {
b.scopes[root].insert(index_name, index_id);
}

fn bind_create_function(b: &mut Binder, create_function: ast::CreateFunction) {
let Some(path) = create_function.path() else {
return;
};

let Some(function_name) = item_name(&path) else {
return;
};

let name_ptr = path_to_ptr(&path);

let Some(schema) = b.current_search_path().first().cloned() else {
return;
};

let function_id = b.symbols.alloc(Symbol {
kind: SymbolKind::Function,
ptr: name_ptr,
schema,
});

let root = b.root_scope();
b.scopes[root].insert(function_name, function_id);
}

fn item_name(path: &ast::Path) -> Option<Name> {
let segment = path.segment()?;

Expand Down
90 changes: 90 additions & 0 deletions crates/squawk_ide/src/goto_definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -908,4 +908,94 @@ create index idx_email on users(email$0);
╰╴ ─ 1. source
");
}

#[test]
fn goto_drop_function() {
assert_snapshot!(goto("
create function foo() returns int as $$ select 1 $$ language sql;
drop function foo$0();
"), @r"
╭▸
2 │ create function foo() returns int as $$ select 1 $$ language sql;
│ ─── 2. destination
3 │ drop function foo();
╰╴ ─ 1. source
");
}

#[test]
fn goto_drop_function_with_schema() {
assert_snapshot!(goto("
set search_path to public;
create function foo() returns int as $$ select 1 $$ language sql;
drop function public.foo$0();
"), @r"
╭▸
3 │ create function foo() returns int as $$ select 1 $$ language sql;
│ ─── 2. destination
4 │ drop function public.foo();
╰╴ ─ 1. source
");
}

#[test]
fn goto_drop_function_defined_after() {
assert_snapshot!(goto("
drop function foo$0();
create function foo() returns int as $$ select 1 $$ language sql;
"), @r"
╭▸
2 │ drop function foo();
│ ─ 1. source
3 │ create function foo() returns int as $$ select 1 $$ language sql;
╰╴ ─── 2. destination
");
}

#[test]
fn goto_function_definition_returns_self() {
assert_snapshot!(goto("
create function foo$0() returns int as $$ select 1 $$ language sql;
"), @r"
╭▸
2 │ create function foo() returns int as $$ select 1 $$ language sql;
│ ┬─┬
│ │ │
│ │ 1. source
╰╴ 2. destination
");
}

#[test]
fn goto_drop_function_with_search_path() {
assert_snapshot!(goto("
create function foo() returns int as $$ select 1 $$ language sql;
set search_path to bar;
create function foo() returns int as $$ select 1 $$ language sql;
set search_path to default;
drop function foo$0();
"), @r"
╭▸
2 │ create function foo() returns int as $$ select 1 $$ language sql;
│ ─── 2. destination
6 │ drop function foo();
╰╴ ─ 1. source
");
}

#[test]
fn goto_drop_function_multiple() {
assert_snapshot!(goto("
create function foo() returns int as $$ select 1 $$ language sql;
create function bar() returns int as $$ select 1 $$ language sql;
drop function foo(), bar$0();
"), @r"
╭▸
3 │ create function bar() returns int as $$ select 1 $$ language sql;
│ ─── 2. destination
4 │ drop function foo(), bar();
╰╴ ─ 1. source
");
}
}
39 changes: 39 additions & 0 deletions crates/squawk_ide/src/resolve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ enum NameRefContext {
DropTable,
Table,
DropIndex,
DropFunction,
CreateIndex,
CreateIndexColumn,
}
Expand All @@ -36,6 +37,13 @@ pub(crate) fn resolve_name_ref(binder: &Binder, name_ref: &ast::NameRef) -> Opti
let position = name_ref.syntax().text_range().start();
resolve_index(binder, &index_name, &schema, position)
}
NameRefContext::DropFunction => {
let path = find_containing_path(name_ref)?;
let function_name = extract_table_name(&path)?;
let schema = extract_schema_name(&path);
let position = name_ref.syntax().text_range().start();
resolve_function(binder, &function_name, &schema, position)
}
NameRefContext::CreateIndexColumn => resolve_create_index_column(binder, name_ref),
}
}
Expand All @@ -53,6 +61,9 @@ fn classify_name_ref_context(name_ref: &ast::NameRef) -> Option<NameRefContext>
if ast::DropIndex::can_cast(ancestor.kind()) {
return Some(NameRefContext::DropIndex);
}
if ast::DropFunction::can_cast(ancestor.kind()) {
return Some(NameRefContext::DropFunction);
}
if ast::PartitionItem::can_cast(ancestor.kind()) {
in_partition_item = true;
}
Expand Down Expand Up @@ -123,6 +134,34 @@ fn resolve_index(
None
}

fn resolve_function(
binder: &Binder,
function_name: &Name,
schema: &Option<Schema>,
position: TextSize,
) -> Option<SyntaxNodePtr> {
let symbols = binder.scopes[binder.root_scope()].get(function_name)?;

if let Some(schema) = schema {
let symbol_id = symbols.iter().copied().find(|id| {
let symbol = &binder.symbols[*id];
symbol.kind == SymbolKind::Function && &symbol.schema == schema
})?;
return Some(binder.symbols[symbol_id].ptr);
} else {
let search_path = binder.search_path_at(position);
for search_schema in search_path {
if let Some(symbol_id) = symbols.iter().copied().find(|id| {
let symbol = &binder.symbols[*id];
symbol.kind == SymbolKind::Function && &symbol.schema == search_schema
}) {
return Some(binder.symbols[symbol_id].ptr);
}
}
}
None
}

fn resolve_create_index_column(binder: &Binder, name_ref: &ast::NameRef) -> Option<SyntaxNodePtr> {
let column_name = Name::new(name_ref.syntax().text().to_string());

Expand Down
1 change: 1 addition & 0 deletions crates/squawk_ide/src/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ fn normalize_identifier(text: &str) -> SmolStr {
pub(crate) enum SymbolKind {
Table,
Index,
Function,
}

#[derive(Clone, Debug)]
Expand Down
Loading