diff --git a/crates/squawk_ide/src/column_name.rs b/crates/squawk_ide/src/column_name.rs new file mode 100644 index 00000000..399a5612 --- /dev/null +++ b/crates/squawk_ide/src/column_name.rs @@ -0,0 +1,409 @@ +use squawk_syntax::{ + SyntaxKind, SyntaxNode, + ast::{self, AstNode}, +}; + +#[derive(Clone, Debug, PartialEq)] +pub(crate) enum ColumnName { + Column(String), + /// There's a fallback mechanism that we need to propagate through the + /// expressions/types. + // + /// We can see this with: + /// ```sql + /// select case when true then 'a' else now()::text end; + /// -- column named `now`, propagating the function name + /// -- vs + /// select case when true then 'a' else 'b' end; + /// -- column named `case` + /// ``` + UnknownColumn(Option), + Star, +} + +impl ColumnName { + #[allow(dead_code)] + pub(crate) fn from_target(target: ast::Target) -> Option<(ColumnName, SyntaxNode)> { + if let Some(as_name) = target.as_name() + && let Some(name_node) = as_name.name() + { + return Some(( + ColumnName::Column(name_node.text().to_string()), + name_node.syntax().clone(), + )); + } else if let Some(expr) = target.expr() + && let Some(name) = name_from_expr(expr, false) + { + return Some(name); + } else if target.star_token().is_some() { + return Some((ColumnName::Star, target.syntax().clone())); + } + None + } + + fn new(name: String, unknown_column: bool) -> ColumnName { + if unknown_column { + ColumnName::UnknownColumn(Some(name)) + } else { + ColumnName::Column(name) + } + } +} + +fn name_from_type(ty: ast::Type, unknown_column: bool) -> Option<(ColumnName, SyntaxNode)> { + match ty { + ast::Type::PathType(path_type) => { + if let Some(name_ref) = path_type + .path() + .and_then(|x| x.segment()) + .and_then(|x| x.name_ref()) + { + return name_from_name_ref(name_ref, true).map(|(column, node)| { + let column = match column { + ColumnName::Column(c) => ColumnName::new(c, unknown_column), + _ => column, + }; + (column, node) + }); + } + } + ast::Type::BitType(bit_type) => { + let name = if bit_type.varying_token().is_some() { + "varbit" + } else { + "bit" + }; + return Some(( + ColumnName::new(name.to_string(), unknown_column), + bit_type.syntax().clone(), + )); + } + ast::Type::CharType(char_type) => { + let name = if char_type.varchar_token().is_some() || char_type.varying_token().is_some() + { + "varchar" + } else { + "bpchar" + }; + return Some(( + ColumnName::new(name.to_string(), unknown_column), + char_type.syntax().clone(), + )); + } + ast::Type::DoubleType(double_type) => { + return Some(( + ColumnName::new("float8".to_string(), unknown_column), + double_type.syntax().clone(), + )); + } + ast::Type::IntervalType(interval_type) => { + return Some(( + ColumnName::new("interval".to_string(), unknown_column), + interval_type.syntax().clone(), + )); + } + ast::Type::TimeType(time_type) => { + let mut name = if time_type.timestamp_token().is_some() { + "timestamp".to_owned() + } else { + "time".to_owned() + }; + if let Some(ast::Timezone::WithTimezone(_)) = time_type.timezone() { + // time -> timetz + // timestamp -> timestamptz + name.push_str("tz"); + }; + return Some(( + ColumnName::new(name.to_string(), unknown_column), + time_type.syntax().clone(), + )); + } + ast::Type::ArrayType(array_type) => { + if let Some(inner_ty) = array_type.ty() { + return name_from_type(inner_ty, unknown_column); + } + } + // we shouldn't ever hit this since the following isn't valid syntax: + // select cast('foo' as t.a%TYPE); + ast::Type::PercentType(_) => return None, + ast::Type::ExprType(expr_type) => { + if let Some(expr) = expr_type.expr() { + return name_from_expr(expr, true).map(|(column, node)| { + let column = match column { + ColumnName::Column(c) => ColumnName::new(c, unknown_column), + _ => column, + }; + (column, node) + }); + } + } + } + None +} + +fn name_from_name_ref(name_ref: ast::NameRef, in_type: bool) -> Option<(ColumnName, SyntaxNode)> { + if in_type { + for node in name_ref.syntax().children_with_tokens() { + match node.kind() { + SyntaxKind::BIGINT_KW => { + return Some(( + ColumnName::Column("int8".to_owned()), + name_ref.syntax().clone(), + )); + } + SyntaxKind::INT_KW | SyntaxKind::INTEGER_KW => { + return Some(( + ColumnName::Column("int4".to_owned()), + name_ref.syntax().clone(), + )); + } + SyntaxKind::SMALLINT_KW => { + return Some(( + ColumnName::Column("int2".to_owned()), + name_ref.syntax().clone(), + )); + } + _ => (), + } + } + } + return Some(( + ColumnName::Column(name_ref.text().to_string()), + name_ref.syntax().clone(), + )); +} + +/* +TODO: + +unnest(anyarray, anyarray [, ... ]) → setof anyelement, anyelement [, ... ] + +select * from unnest(ARRAY[1,2], ARRAY['foo','bar','baz']) → + unnset | unnset +--------+----- + 1 | foo + 2 | bar + | baz +*/ + +// NOTE: we have to have this in_type param because we parse some casts as exprs +// instead of types. +fn name_from_expr(expr: ast::Expr, in_type: bool) -> Option<(ColumnName, SyntaxNode)> { + let node = expr.syntax().clone(); + match expr { + ast::Expr::ArrayExpr(_) => { + return Some((ColumnName::Column("array".to_string()), node)); + } + ast::Expr::BetweenExpr(_) | ast::Expr::BinExpr(_) => { + return Some((ColumnName::UnknownColumn(None), node)); + } + ast::Expr::CallExpr(call_expr) => { + if let Some(func_name) = call_expr.expr() { + match func_name { + ast::Expr::ArrayExpr(_) + | ast::Expr::BetweenExpr(_) + | ast::Expr::ParenExpr(_) + | ast::Expr::BinExpr(_) + | ast::Expr::CallExpr(_) + | ast::Expr::CaseExpr(_) + | ast::Expr::CastExpr(_) + | ast::Expr::Literal(_) + | ast::Expr::PostfixExpr(_) + | ast::Expr::PrefixExpr(_) + | ast::Expr::TupleExpr(_) + | ast::Expr::IndexExpr(_) + | ast::Expr::SliceExpr(_) => unreachable!("not possible in the grammar"), + ast::Expr::FieldExpr(field_expr) => { + if let Some(name_ref) = field_expr.field() { + return name_from_name_ref(name_ref, in_type); + } + } + ast::Expr::NameRef(name_ref) => { + return name_from_name_ref(name_ref, in_type); + } + } + } + } + ast::Expr::CaseExpr(case) => { + if let Some(else_clause) = case.else_clause() + && let Some(expr) = else_clause.expr() + && let Some((column, node)) = name_from_expr(expr, in_type) + { + if !matches!(column, ColumnName::UnknownColumn(_)) { + return Some((column, node)); + } + } + return Some((ColumnName::Column("case".to_string()), node)); + } + ast::Expr::CastExpr(cast_expr) => { + let mut unknown_column = false; + if let Some(expr) = cast_expr.expr() + && let Some((column, node)) = name_from_expr(expr, in_type) + { + match column { + ColumnName::Column(_) => return Some((column, node)), + ColumnName::UnknownColumn(_) => unknown_column = true, + ColumnName::Star => (), + } + } + if let Some(ty) = cast_expr.ty() { + return name_from_type(ty, unknown_column); + } + } + ast::Expr::FieldExpr(field_expr) => { + if let Some(name_ref) = field_expr.field() { + return name_from_name_ref(name_ref, in_type); + } + } + ast::Expr::IndexExpr(index_expr) => { + if let Some(base) = index_expr.base() { + return name_from_expr(base, in_type); + } + } + ast::Expr::SliceExpr(slice_expr) => { + if let Some(base) = slice_expr.base() { + return name_from_expr(base, in_type); + } + } + ast::Expr::Literal(_) | ast::Expr::PrefixExpr(_) | ast::Expr::PostfixExpr(_) => { + return Some((ColumnName::UnknownColumn(None), node)); + } + ast::Expr::NameRef(name_ref) => { + return name_from_name_ref(name_ref, in_type); + } + ast::Expr::ParenExpr(paren_expr) => { + if let Some(expr) = paren_expr.expr() { + return name_from_expr(expr, in_type); + } else if let Some(select) = paren_expr.select() + && let Some(mut targets) = select + .select_clause() + .and_then(|x| x.target_list()) + .map(|x| x.targets()) + && let Some(target) = targets.next() + { + return ColumnName::from_target(target); + } + } + ast::Expr::TupleExpr(_) => { + return Some((ColumnName::Column("row".to_string()), node)); + } + } + None +} + +#[test] +fn examples() { + use insta::assert_snapshot; + + // array + assert_snapshot!(name("array(select 1)"), @"array"); + assert_snapshot!(name("array[1, 2, 3]"), @"array"); + + // unknown columns + assert_snapshot!(name("1 between 0 and 10"), @"?column?"); + assert_snapshot!(name("1 + 2"), @"?column?"); + assert_snapshot!(name("42"), @"?column?"); + assert_snapshot!(name("'string'"), @"?column?"); + // prefix + assert_snapshot!(name("-42"), @"?column?"); + assert_snapshot!(name("|/ 42"), @"?column?"); + // postfix + assert_snapshot!(name("x is null"), @"?column?"); + assert_snapshot!(name("x is not null"), @"?column?"); + // paren expr + assert_snapshot!(name("(1 * 2)"), @"?column?"); + assert_snapshot!(name("(select 1 as a)"), @"a"); + + // func + assert_snapshot!(name("count(*)"), @"count"); + assert_snapshot!(name("schema.func_name(1)"), @"func_name"); + + // index + assert_snapshot!(name("foo[bar]"), @"foo"); + assert_snapshot!(name("foo[1]"), @"foo"); + + // column + assert_snapshot!(name("database.schema.table.column"), @"column"); + assert_snapshot!(name("t.a"), @"a"); + assert_snapshot!(name("col_name"), @"col_name"); + assert_snapshot!(name("(c)"), @"c"); + + // case + assert_snapshot!(name("case when true then 'foo' end"), @"case"); + assert_snapshot!(name("case when true then 'foo' else now()::text end"), @"now"); + assert_snapshot!(name("case when true then 'foo' else 'bar' end"), @"case"); + assert_snapshot!(name("case when true then 'foo' else '1'::bigint::text end"), @"case"); + + // casts + assert_snapshot!(name("now()::text"), @"now"); + assert_snapshot!(name("cast(col_name as text)"), @"col_name"); + assert_snapshot!(name("col_name::text"), @"col_name"); + assert_snapshot!(name("col_name::int::text"), @"col_name"); + assert_snapshot!(name("'1'::bigint"), @"int8"); + assert_snapshot!(name("'1'::int"), @"int4"); + assert_snapshot!(name("'1'::smallint"), @"int2"); + assert_snapshot!(name("'{{1, 2}, {3, 4}}'::bigint[][]"), @"int8"); + assert_snapshot!(name("'{{1, 2}, {3, 4}}'::int[][]"), @"int4"); + assert_snapshot!(name("'{{1, 2}, {3, 4}}'::smallint[]"), @"int2"); + assert_snapshot!(name("pg_catalog.varchar(100) '{1}'"), @"varchar"); + assert_snapshot!(name("'{1}'::integer[];"), @"int4"); + assert_snapshot!(name("'{1}'::pg_catalog.varchar(1)[]::integer[];"), @"int4"); + assert_snapshot!(name("'1'::bigint::smallint"), @"int2"); + + // tuple + assert_snapshot!(name("(1, 2, 3)"), @"row"); + + // array types + assert_snapshot!(name("'{{1, 2}, {3, 4}}'::int[]"), @"int4"); + assert_snapshot!(name("cast('{foo}' as text[])"), @"text"); + + // bit types + assert_snapshot!(name("cast('1010' as bit varying(10))"), @"varbit"); + + // char types + assert_snapshot!(name("cast('hello' as character varying(10))"), @"varchar"); + assert_snapshot!(name("cast('hello' as char varying(5))"), @"varchar"); + assert_snapshot!(name("cast('hello' as char(5))"), @"bpchar"); + + // double types + assert_snapshot!(name("cast(1.5 as double precision)"), @"float8"); + + // interval types + assert_snapshot!(name("cast('1 hour' as interval hour to minute)"), @"interval"); + + // percent types + assert_snapshot!(name("cast(foo as schema.%TYPE)"), @"foo"); + + // time types + assert_snapshot!(name("cast('12:00:00' as time(6) without time zone)"), @"time"); + assert_snapshot!(name("cast('12:00:00' as time(6) with time zone)"), @"timetz"); + assert_snapshot!(name("cast('2024-01-01 12:00:00' as timestamp(6) with time zone)"), @"timestamptz"); + assert_snapshot!(name("cast('2024-01-01 12:00:00' as timestamp(6) without time zone)"), @"timestamp"); + + #[track_caller] + fn name(sql: &str) -> String { + let sql = "select ".to_string() + sql; + let parse = squawk_syntax::SourceFile::parse(&sql); + assert_eq!(parse.errors(), vec![]); + let file = parse.tree(); + + let stmt = file.stmts().next().unwrap(); + let ast::Stmt::Select(select) = stmt else { + unreachable!() + }; + + let target = select + .select_clause() + .and_then(|sc| sc.target_list()) + .and_then(|tl| tl.targets().next()) + .unwrap(); + + ColumnName::from_target(target) + .map(|x| match x.0 { + ColumnName::Column(string) => string, + ColumnName::Star => unreachable!(), + ColumnName::UnknownColumn(c) => c.unwrap_or_else(|| "?column?".to_string()), + }) + .unwrap() + } +} diff --git a/crates/squawk_ide/src/lib.rs b/crates/squawk_ide/src/lib.rs index d9a05edd..70b81988 100644 --- a/crates/squawk_ide/src/lib.rs +++ b/crates/squawk_ide/src/lib.rs @@ -1,4 +1,5 @@ pub mod code_actions; +pub mod column_name; pub mod expand_selection; pub mod goto_definition; #[cfg(test)]