diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 149365c47..7edf8fc60 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -1818,7 +1818,14 @@ impl<'a> Parser<'a> { } else if let Some(lambda) = self.try_parse_lambda()? { return Ok(lambda); } else { - let exprs = self.parse_comma_separated(Parser::parse_expr)?; + // Parentheses create a normal expression context. + // This ensures that e.g. `NOT NULL` inside parens is parsed + // as `IS NOT NULL` (for dialects that support it), while + // `NOT NULL` outside parens in a column definition context + // remains a column constraint. + let exprs = self.with_state(ParserState::Normal, |p| { + p.parse_comma_separated(Parser::parse_expr) + })?; match exprs.len() { 0 => return Err(ParserError::ParserError( "Internal parser error: parse_comma_separated returned empty list" @@ -8786,19 +8793,15 @@ impl<'a> Parser<'a> { } else if self.parse_keyword(Keyword::NULL) { Ok(Some(ColumnOption::Null)) } else if self.parse_keyword(Keyword::DEFAULT) { - Ok(Some(ColumnOption::Default( - self.parse_column_option_expr()?, - ))) + Ok(Some(ColumnOption::Default(self.parse_expr()?))) } else if dialect_of!(self is ClickHouseDialect| GenericDialect) && self.parse_keyword(Keyword::MATERIALIZED) { - Ok(Some(ColumnOption::Materialized( - self.parse_column_option_expr()?, - ))) + Ok(Some(ColumnOption::Materialized(self.parse_expr()?))) } else if dialect_of!(self is ClickHouseDialect| GenericDialect) && self.parse_keyword(Keyword::ALIAS) { - Ok(Some(ColumnOption::Alias(self.parse_column_option_expr()?))) + Ok(Some(ColumnOption::Alias(self.parse_expr()?))) } else if dialect_of!(self is ClickHouseDialect| GenericDialect) && self.parse_keyword(Keyword::EPHEMERAL) { @@ -8807,9 +8810,7 @@ impl<'a> Parser<'a> { if matches!(self.peek_token().token, Token::Comma | Token::RParen) { Ok(Some(ColumnOption::Ephemeral(None))) } else { - Ok(Some(ColumnOption::Ephemeral(Some( - self.parse_column_option_expr()?, - )))) + Ok(Some(ColumnOption::Ephemeral(Some(self.parse_expr()?)))) } } else if self.parse_keywords(&[Keyword::PRIMARY, Keyword::KEY]) { let characteristics = self.parse_constraint_characteristics()?; @@ -8922,7 +8923,7 @@ impl<'a> Parser<'a> { } else if self.parse_keywords(&[Keyword::ON, Keyword::UPDATE]) && dialect_of!(self is MySqlDialect | GenericDialect) { - let expr = self.parse_column_option_expr()?; + let expr = self.parse_expr()?; Ok(Some(ColumnOption::OnUpdate(expr))) } else if self.parse_keyword(Keyword::GENERATED) { self.parse_optional_column_option_generated() @@ -8940,9 +8941,7 @@ impl<'a> Parser<'a> { } else if self.parse_keyword(Keyword::SRID) && dialect_of!(self is MySqlDialect | GenericDialect) { - Ok(Some(ColumnOption::Srid(Box::new( - self.parse_column_option_expr()?, - )))) + Ok(Some(ColumnOption::Srid(Box::new(self.parse_expr()?)))) } else if self.parse_keyword(Keyword::IDENTITY) && dialect_of!(self is MsSqlDialect | GenericDialect) { @@ -8984,31 +8983,6 @@ impl<'a> Parser<'a> { } } - /// When parsing some column option expressions we need to revert to [ParserState::Normal] since - /// `NOT NULL` is allowed as an alias for `IS NOT NULL`. - /// In those cases we use this helper instead of calling [Parser::parse_expr] directly. - /// - /// For example, consider these `CREATE TABLE` statements: - /// ```sql - /// CREATE TABLE foo (abc BOOL DEFAULT (42 NOT NULL) NOT NULL); - /// ``` - /// vs - /// ```sql - /// CREATE TABLE foo (abc BOOL NOT NULL); - /// ``` - /// - /// In the first we should parse the inner portion of `(42 NOT NULL)` as [Expr::IsNotNull], - /// whereas is both statements that trailing `NOT NULL` should only be parsed as a - /// [ColumnOption::NotNull]. - fn parse_column_option_expr(&mut self) -> Result { - if self.peek_token_ref().token == Token::LParen { - let expr: Expr = self.with_state(ParserState::Normal, |p| p.parse_prefix())?; - Ok(expr) - } else { - Ok(self.parse_expr()?) - } - } - pub(crate) fn parse_tag(&mut self) -> Result { let name = self.parse_object_name(false)?; self.expect_token(&Token::Eq)?; diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index bbbf0d835..9b38ac34a 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -17302,6 +17302,11 @@ fn test_parse_not_null_in_column_options() { ); } +#[test] +fn test_parse_default_expr_with_operators() { + all_dialects().verified_stmt("CREATE TABLE t (c INT DEFAULT (1 + 2) + 3)"); +} + #[test] fn test_parse_default_with_collate_column_option() { let sql = "CREATE TABLE foo (abc TEXT DEFAULT 'foo' COLLATE 'en_US')"; diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 325e3939e..cc872419f 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -512,6 +512,13 @@ fn parse_create_table_with_defaults() { } } +#[test] +fn parse_cast_in_default_expr() { + pg().verified_stmt("CREATE TABLE t (c TEXT DEFAULT (foo())::TEXT)"); + pg().verified_stmt("CREATE TABLE t (c TEXT DEFAULT (foo())::INT::TEXT)"); + pg().verified_stmt("CREATE TABLE t (c TEXT DEFAULT (foo())::TEXT NOT NULL)"); +} + #[test] fn parse_create_table_from_pg_dump() { let sql = "CREATE TABLE public.customer (