Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
308 changes: 301 additions & 7 deletions crates/squawk_ide/src/code_actions.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use rowan::TextSize;
use squawk_linter::Edit;
use squawk_syntax::{
SyntaxKind,
SyntaxKind, SyntaxNode,
ast::{self, AstNode},
};

use crate::{generated::keywords::RESERVED_KEYWORDS, offsets::token_from_offset};

#[derive(Debug, Clone)]
pub enum ActionKind {
QuickFix,
Expand All @@ -25,6 +27,8 @@ pub fn code_actions(file: ast::SourceFile, offset: TextSize) -> Option<Vec<CodeA
remove_else_clause(&mut actions, &file, offset);
rewrite_table_as_select(&mut actions, &file, offset);
rewrite_select_as_table(&mut actions, &file, offset);
quote_identifier(&mut actions, &file, offset);
unquote_identifier(&mut actions, &file, offset);
Some(actions)
}

Expand Down Expand Up @@ -162,8 +166,8 @@ fn rewrite_table_as_select(
file: &ast::SourceFile,
offset: TextSize,
) -> Option<()> {
let node = file.syntax().token_at_offset(offset).left_biased()?;
let table = node.parent_ancestors().find_map(ast::Table::cast)?;
let token = token_from_offset(file, offset)?;
let table = token.parent_ancestors().find_map(ast::Table::cast)?;

let relation_name = table.relation_name()?;
let table_name = relation_name.syntax().text();
Expand All @@ -184,8 +188,8 @@ fn rewrite_select_as_table(
file: &ast::SourceFile,
offset: TextSize,
) -> Option<()> {
let node = file.syntax().token_at_offset(offset).left_biased()?;
let select = node.parent_ancestors().find_map(ast::Select::cast)?;
let token = token_from_offset(file, offset)?;
let select = token.parent_ancestors().find_map(ast::Select::cast)?;

if !can_transform_select_to_table(&select) {
return None;
Expand Down Expand Up @@ -293,6 +297,107 @@ fn can_transform_select_to_table(select: &ast::Select) -> bool {
from_item.name_ref().is_some() || from_item.field_expr().is_some()
}

fn quote_identifier(
actions: &mut Vec<CodeAction>,
file: &ast::SourceFile,
offset: TextSize,
) -> Option<()> {
let token = token_from_offset(file, offset)?;
let parent = token.parent()?;

let name_node = if let Some(name) = ast::Name::cast(parent.clone()) {
name.syntax().clone()
} else if let Some(name_ref) = ast::NameRef::cast(parent) {
name_ref.syntax().clone()
} else {
return None;
};

let text = name_node.text().to_string();

if text.starts_with('"') {
return None;
}

let quoted = format!(r#""{}""#, text.to_lowercase());

actions.push(CodeAction {
title: "Quote identifier".to_owned(),
edits: vec![Edit::replace(name_node.text_range(), quoted)],
kind: ActionKind::RefactorRewrite,
});

Some(())
}

fn unquote_identifier(
actions: &mut Vec<CodeAction>,
file: &ast::SourceFile,
offset: TextSize,
) -> Option<()> {
let token = token_from_offset(file, offset)?;
let parent = token.parent()?;

let name_node = if let Some(name) = ast::Name::cast(parent.clone()) {
name.syntax().clone()
} else if let Some(name_ref) = ast::NameRef::cast(parent) {
name_ref.syntax().clone()
} else {
return None;
};

let unquoted = unquote(&name_node)?;

actions.push(CodeAction {
title: "Unquote identifier".to_owned(),
edits: vec![Edit::replace(name_node.text_range(), unquoted)],
kind: ActionKind::RefactorRewrite,
});

Some(())
}

fn unquote(node: &SyntaxNode) -> Option<String> {
let text = node.text().to_string();

if !text.starts_with('"') || !text.ends_with('"') {
return None;
}

let text = &text[1..text.len() - 1];

if is_reserved_word(text) {
return None;
}

if text.is_empty() {
return None;
}

let mut chars = text.chars();

// see: https://www.postgresql.org/docs/18/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
match chars.next() {
Some(c) if c.is_lowercase() || c == '_' => {}
_ => return None,
}

for c in chars {
if c.is_lowercase() || c.is_ascii_digit() || c == '_' || c == '$' {
continue;
}
return None;
}

Some(text.to_string())
}

fn is_reserved_word(text: &str) -> bool {
RESERVED_KEYWORDS
.binary_search(&text.to_lowercase().as_str())
.is_ok()
}

#[cfg(test)]
mod test {
use super::*;
Expand All @@ -305,11 +410,13 @@ mod test {
f: impl Fn(&mut Vec<CodeAction>, &ast::SourceFile, TextSize) -> Option<()>,
sql: &str,
) -> String {
let (offset, sql) = fixture(sql);
let (mut offset, sql) = fixture(sql);
let parse = ast::SourceFile::parse(&sql);
assert_eq!(parse.errors(), vec![]);
let file: ast::SourceFile = parse.tree();

offset = offset.checked_sub(1.into()).unwrap_or_default();

let mut actions = vec![];
f(&mut actions, &file, offset);

Expand Down Expand Up @@ -388,7 +495,7 @@ mod test {
fn remove_else_clause_before_token() {
assert_snapshot!(apply_code_action(
remove_else_clause,
"select case x when true then 1 $0else 2 end;"),
"select case x when true then 1 e$0lse 2 end;"),
@"select case x when true then 1 end;"
);
}
Expand Down Expand Up @@ -639,4 +746,191 @@ mod test {
"table foo$0;"
));
}

#[test]
fn quote_identifier_on_name_ref() {
assert_snapshot!(apply_code_action(
quote_identifier,
"select x$0 from t;"),
@r#"select "x" from t;"#
);
}

#[test]
fn quote_identifier_on_name() {
assert_snapshot!(apply_code_action(
quote_identifier,
"create table T(X$0 int);"),
@r#"create table T("x" int);"#
);
}

#[test]
fn quote_identifier_lowercases() {
assert_snapshot!(apply_code_action(
quote_identifier,
"create table T(COL$0 int);"),
@r#"create table T("col" int);"#
);
}

#[test]
fn quote_identifier_not_applicable_when_already_quoted() {
assert!(code_action_not_applicable(
quote_identifier,
r#"select "x"$0 from t;"#
));
}

#[test]
fn quote_identifier_not_applicable_on_select_keyword() {
assert!(code_action_not_applicable(
quote_identifier,
"sel$0ect x from t;"
));
}

#[test]
fn quote_identifier_on_keyword_column_name() {
assert_snapshot!(apply_code_action(
quote_identifier,
"select te$0xt from t;"),
@r#"select "text" from t;"#
);
}

#[test]
fn quote_identifier_example_select() {
assert_snapshot!(apply_code_action(
quote_identifier,
"select x$0 from t;"),
@r#"select "x" from t;"#
);
}

#[test]
fn quote_identifier_example_create_table() {
assert_snapshot!(apply_code_action(
quote_identifier,
"create table T(X$0 int);"),
@r#"create table T("x" int);"#
);
}

#[test]
fn unquote_identifier_simple() {
assert_snapshot!(apply_code_action(
unquote_identifier,
r#"select "x"$0 from t;"#),
@"select x from t;"
);
}

#[test]
fn unquote_identifier_with_underscore() {
assert_snapshot!(apply_code_action(
unquote_identifier,
r#"select "user_id"$0 from t;"#),
@"select user_id from t;"
);
}

#[test]
fn unquote_identifier_with_digits() {
assert_snapshot!(apply_code_action(
unquote_identifier,
r#"select "x123"$0 from t;"#),
@"select x123 from t;"
);
}

#[test]
fn unquote_identifier_with_dollar() {
assert_snapshot!(apply_code_action(
unquote_identifier,
r#"select "my_table$1"$0 from t;"#),
@"select my_table$1 from t;"
);
}

#[test]
fn unquote_identifier_starts_with_underscore() {
assert_snapshot!(apply_code_action(
unquote_identifier,
r#"select "_col"$0 from t;"#),
@"select _col from t;"
);
}

#[test]
fn unquote_identifier_starts_with_unicode() {
assert_snapshot!(apply_code_action(
unquote_identifier,
r#"select "é"$0 from t;"#),
@"select é from t;"
);
}

#[test]
fn unquote_identifier_not_applicable() {
// upper case
assert!(code_action_not_applicable(
unquote_identifier,
r#"select "X"$0 from t;"#
));
// upper case
assert!(code_action_not_applicable(
unquote_identifier,
r#"select "Foo"$0 from t;"#
));
// dash
assert!(code_action_not_applicable(
unquote_identifier,
r#"select "my-col"$0 from t;"#
));
// leading digits
assert!(code_action_not_applicable(
unquote_identifier,
r#"select "123"$0 from t;"#
));
// space
assert!(code_action_not_applicable(
unquote_identifier,
r#"select "foo bar"$0 from t;"#
));
// quotes
assert!(code_action_not_applicable(
unquote_identifier,
r#"select "foo""bar"$0 from t;"#
));
// already unquoted
assert!(code_action_not_applicable(
unquote_identifier,
"select x$0 from t;"
));
// brackets
assert!(code_action_not_applicable(
unquote_identifier,
r#"select "my[col]"$0 from t;"#
));
// curly brackets
assert!(code_action_not_applicable(
unquote_identifier,
r#"select "my{}"$0 from t;"#
));
// reserved word
assert!(code_action_not_applicable(
unquote_identifier,
r#"select "select"$0 from t;"#
));
}

#[test]
fn unquote_identifier_on_name() {
assert_snapshot!(apply_code_action(
unquote_identifier,
r#"create table T("x"$0 int);"#),
@"create table T(x int);"
);
}
}
Loading
Loading