diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 8da6bbe7..619730e2 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -344,12 +344,14 @@ impl fmt::Display for ObjectName { #[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] pub enum ObjectNamePart { Identifier(Ident), + Function(ObjectNamePartFunction), } impl ObjectNamePart { pub fn as_ident(&self) -> Option<&Ident> { match self { ObjectNamePart::Identifier(ident) => Some(ident), + ObjectNamePart::Function(_) => None, } } } @@ -358,10 +360,30 @@ impl fmt::Display for ObjectNamePart { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { ObjectNamePart::Identifier(ident) => write!(f, "{ident}"), + ObjectNamePart::Function(func) => write!(f, "{func}"), } } } +/// An object name part that consists of a function that dynamically +/// constructs identifiers. +/// +/// - [Snowflake](https://docs.snowflake.com/en/sql-reference/identifier-literal) +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))] +pub struct ObjectNamePartFunction { + pub name: Ident, + pub args: Vec, +} + +impl fmt::Display for ObjectNamePartFunction { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}(", self.name)?; + write!(f, "{})", display_comma_separated(&self.args)) + } +} + /// Represents an Array Expression, either /// `ARRAY[..]`, or `[..]` #[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)] diff --git a/src/ast/spans.rs b/src/ast/spans.rs index 144de592..a1b2e4e0 100644 --- a/src/ast/spans.rs +++ b/src/ast/spans.rs @@ -1671,6 +1671,10 @@ impl Spanned for ObjectNamePart { fn span(&self) -> Span { match self { ObjectNamePart::Identifier(ident) => ident.span, + ObjectNamePart::Function(func) => func + .name + .span + .union(&union_spans(func.args.iter().map(|i| i.span()))), } } } diff --git a/src/dialect/mod.rs b/src/dialect/mod.rs index bc3c5555..88d63f27 
100644 --- a/src/dialect/mod.rs +++ b/src/dialect/mod.rs @@ -49,7 +49,7 @@ pub use self::postgresql::PostgreSqlDialect; pub use self::redshift::RedshiftSqlDialect; pub use self::snowflake::SnowflakeDialect; pub use self::sqlite::SQLiteDialect; -use crate::ast::{ColumnOption, Expr, GranteesType, Statement}; +use crate::ast::{ColumnOption, Expr, GranteesType, Ident, Statement}; pub use crate::keywords; use crate::keywords::Keyword; use crate::parser::{Parser, ParserError}; @@ -1076,6 +1076,15 @@ pub trait Dialect: Debug + Any { fn supports_comma_separated_drop_column_list(&self) -> bool { false } + + /// Returns true if the dialect considers the specified ident as a function + /// that returns an identifier. Typically used to generate identifiers + /// programmatically. + /// + /// - [Snowflake](https://docs.snowflake.com/en/sql-reference/identifier-literal) + fn is_identifier_generating_function_name(&self, _ident: &Ident) -> bool { + false + } } /// This represents the operators for which precedence must be defined diff --git a/src/dialect/snowflake.rs b/src/dialect/snowflake.rs index 212cf217..12ccd647 100644 --- a/src/dialect/snowflake.rs +++ b/src/dialect/snowflake.rs @@ -367,6 +367,16 @@ impl Dialect for SnowflakeDialect { fn supports_comma_separated_drop_column_list(&self) -> bool { true } + + // Only an unquoted `IDENTIFIER` (compared ASCII case-insensitively, without allocating) counts; a quoted name is an ordinary identifier. + fn is_identifier_generating_function_name(&self, ident: &Ident) -> bool { + ident.quote_style.is_none() && ident.value.eq_ignore_ascii_case("identifier") + } + + // For example: `SELECT IDENTIFIER('alias1').* FROM tbl AS alias1` + fn supports_select_expr_star(&self) -> bool { + true + } } fn parse_file_staging_command(kw: Keyword, parser: &mut Parser) -> Result { diff --git a/src/parser/mod.rs b/src/parser/mod.rs index c4d0508d..1be6d637 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -10353,49 +10353,14 @@ impl<'a> Parser<'a> { } } - /// Parse a possibly qualified, possibly quoted identifier, optionally allowing for wildcards, - /// e.g. 
*, *.*, `foo`.*, or "foo"."bar" - fn parse_object_name_with_wildcards( - &mut self, - in_table_clause: bool, - allow_wildcards: bool, - ) -> Result { - let mut idents = vec![]; - - if dialect_of!(self is BigQueryDialect) && in_table_clause { - loop { - let (ident, end_with_period) = self.parse_unquoted_hyphenated_identifier()?; - idents.push(ident); - if !self.consume_token(&Token::Period) && !end_with_period { - break; - } - } - } else { - loop { - let ident = if allow_wildcards && self.peek_token().token == Token::Mul { - let span = self.next_token().span; - Ident { - value: Token::Mul.to_string(), - quote_style: None, - span, - } - } else { - if self.dialect.supports_object_name_double_dot_notation() - && idents.len() == 1 - && self.consume_token(&Token::Period) - { - // Empty string here means default schema - idents.push(Ident::new("")); - } - self.parse_identifier()? - }; - idents.push(ident); - if !self.consume_token(&Token::Period) { - break; - } - } - } - Ok(ObjectName::from(idents)) + /// Parse a possibly qualified, possibly quoted identifier, e.g. + /// `foo` or `myschema."table"` + /// + /// The `in_table_clause` parameter indicates whether the object name is a table in a FROM, JOIN, + /// or similar table clause. Currently, this is used only to support unquoted hyphenated identifiers + /// in this context on BigQuery. + pub fn parse_object_name(&mut self, in_table_clause: bool) -> Result { + self.parse_object_name_inner(in_table_clause, false) } /// Parse a possibly qualified, possibly quoted identifier, e.g. @@ -10404,19 +10369,62 @@ impl<'a> Parser<'a> { /// The `in_table_clause` parameter indicates whether the object name is a table in a FROM, JOIN, /// or similar table clause. Currently, this is used only to support unquoted hyphenated identifiers /// in this context on BigQuery. 
- pub fn parse_object_name(&mut self, in_table_clause: bool) -> Result { - let ObjectName(mut idents) = - self.parse_object_name_with_wildcards(in_table_clause, false)?; + /// + /// The `allow_wildcards` parameter indicates whether to allow for wildcards in the object name + /// e.g. *, *.*, `foo`.*, or "foo"."bar" + fn parse_object_name_inner( + &mut self, + in_table_clause: bool, + allow_wildcards: bool, + ) -> Result { + let mut parts = vec![]; + if dialect_of!(self is BigQueryDialect) && in_table_clause { + loop { + let (ident, end_with_period) = self.parse_unquoted_hyphenated_identifier()?; + parts.push(ObjectNamePart::Identifier(ident)); + if !self.consume_token(&Token::Period) && !end_with_period { + break; + } + } + } else { + loop { + if allow_wildcards && self.peek_token().token == Token::Mul { + let span = self.next_token().span; + parts.push(ObjectNamePart::Identifier(Ident { + value: Token::Mul.to_string(), + quote_style: None, + span, + })); + } else if let Some(func_part) = + self.maybe_parse(|parser| parser.parse_object_name_function_part())? + { + parts.push(ObjectNamePart::Function(func_part)); + } else if self.dialect.supports_object_name_double_dot_notation() + && parts.len() == 1 + && matches!(self.peek_token().token, Token::Period) + { + // Empty string here means default schema + parts.push(ObjectNamePart::Identifier(Ident::new(""))); + } else { + let ident = self.parse_identifier()?; + parts.push(ObjectNamePart::Identifier(ident)); + } + + if !self.consume_token(&Token::Period) { + break; + } + } + } // BigQuery accepts any number of quoted identifiers of a table name. 
// https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#quoted_identifiers if dialect_of!(self is BigQueryDialect) - && idents.iter().any(|part| { + && parts.iter().any(|part| { part.as_ident() .is_some_and(|ident| ident.value.contains('.')) }) { - idents = idents + parts = parts .into_iter() .flat_map(|part| match part.as_ident() { Some(ident) => ident @@ -10435,7 +10443,24 @@ impl<'a> Parser<'a> { .collect() } - Ok(ObjectName(idents)) + Ok(ObjectName(parts)) + } + + /// Parse a call to a dialect-specific identifier-generating function (e.g. Snowflake `IDENTIFIER('tbl')`) as an object name part. + fn parse_object_name_function_part(&mut self) -> Result { + let name = self.parse_identifier()?; + if self.dialect.is_identifier_generating_function_name(&name) { + self.expect_token(&Token::LParen)?; + let args: Vec = + self.parse_comma_separated0(Self::parse_function_args, Token::RParen)?; + self.expect_token(&Token::RParen)?; + Ok(ObjectNamePartFunction { name, args }) + } else { + self.expected( + "dialect specific identifier-generating function", + self.peek_token(), + ) + } } /// Parse identifiers @@ -14006,25 +14031,25 @@ impl<'a> Parser<'a> { schemas: self.parse_comma_separated(|p| p.parse_object_name(false))?, }) } else if self.parse_keywords(&[Keyword::RESOURCE, Keyword::MONITOR]) { - Some(GrantObjects::ResourceMonitors(self.parse_comma_separated( - |p| p.parse_object_name_with_wildcards(false, true), - )?)) + Some(GrantObjects::ResourceMonitors( + self.parse_comma_separated(|p| p.parse_object_name(false))?, + )) } else if self.parse_keywords(&[Keyword::COMPUTE, Keyword::POOL]) { - Some(GrantObjects::ComputePools(self.parse_comma_separated( - |p| p.parse_object_name_with_wildcards(false, true), - )?)) + Some(GrantObjects::ComputePools( + self.parse_comma_separated(|p| p.parse_object_name(false))?, + )) } else if self.parse_keywords(&[Keyword::FAILOVER, Keyword::GROUP]) { - Some(GrantObjects::FailoverGroup(self.parse_comma_separated( - |p| p.parse_object_name_with_wildcards(false, true), - )?)) + Some(GrantObjects::FailoverGroup( + self.parse_comma_separated(|p| 
p.parse_object_name(false))?, + )) } else if self.parse_keywords(&[Keyword::REPLICATION, Keyword::GROUP]) { - Some(GrantObjects::ReplicationGroup(self.parse_comma_separated( - |p| p.parse_object_name_with_wildcards(false, true), - )?)) + Some(GrantObjects::ReplicationGroup( + self.parse_comma_separated(|p| p.parse_object_name(false))?, + )) } else if self.parse_keywords(&[Keyword::EXTERNAL, Keyword::VOLUME]) { - Some(GrantObjects::ExternalVolumes(self.parse_comma_separated( - |p| p.parse_object_name_with_wildcards(false, true), - )?)) + Some(GrantObjects::ExternalVolumes( + self.parse_comma_separated(|p| p.parse_object_name(false))?, + )) } else { let object_type = self.parse_one_of_keywords(&[ Keyword::SEQUENCE, @@ -14041,7 +14066,7 @@ impl<'a> Parser<'a> { Keyword::CONNECTION, ]); let objects = - self.parse_comma_separated(|p| p.parse_object_name_with_wildcards(false, true)); + self.parse_comma_separated(|p| p.parse_object_name_inner(false, true)); match object_type { Some(Keyword::DATABASE) => Some(GrantObjects::Databases(objects?)), Some(Keyword::SCHEMA) => Some(GrantObjects::Schemas(objects?)), diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 27fe09c7..1de1d93f 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -1232,7 +1232,6 @@ fn parse_select_expr_star() { "SELECT 2. 
* 3 FROM T", ); dialects.verified_only_select("SELECT myfunc().* FROM T"); - dialects.verified_only_select("SELECT myfunc().* EXCEPT (foo) FROM T"); // Invalid let res = dialects.parse_sql_statements("SELECT foo.*.* FROM T"); @@ -1240,6 +1239,11 @@ fn parse_select_expr_star() { ParserError::ParserError("Expected: end of statement, found: .".to_string()), res.unwrap_err() ); + + let dialects = all_dialects_where(|d| { + d.supports_select_expr_star() && d.supports_select_wildcard_except() + }); + dialects.verified_only_select("SELECT myfunc().* EXCEPT (foo) FROM T"); } #[test] diff --git a/tests/sqlparser_snowflake.rs b/tests/sqlparser_snowflake.rs index e7393d3f..a164db5b 100644 --- a/tests/sqlparser_snowflake.rs +++ b/tests/sqlparser_snowflake.rs @@ -4232,3 +4232,122 @@ fn test_snowflake_create_view_with_composite_policy_name() { r#"CREATE VIEW X (COL WITH MASKING POLICY foo.bar.baz) AS SELECT * FROM Y"#; snowflake().verified_stmt(create_view_with_tag); } + +#[test] +fn test_snowflake_identifier_function() { + // Using IDENTIFIER to reference a column + match &snowflake() + .verified_only_select("SELECT identifier('email') FROM customers") + .projection[0] + { + SelectItem::UnnamedExpr(Expr::Function(Function { name, args, .. })) => { + assert_eq!(*name, ObjectName::from(vec![Ident::new("identifier")])); + assert_eq!( + *args, + FunctionArguments::List(FunctionArgumentList { + args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( + Value::SingleQuotedString("email".to_string()).into() + )))], + clauses: vec![], + duplicate_treatment: None + }) + ); + } + _ => unreachable!(), + } + + // Using IDENTIFIER to reference a case-sensitive column + match &snowflake() + .verified_only_select(r#"SELECT identifier('"Email"') FROM customers"#) + .projection[0] + { + SelectItem::UnnamedExpr(Expr::Function(Function { name, args, .. 
})) => { + assert_eq!(*name, ObjectName::from(vec![Ident::new("identifier")])); + assert_eq!( + *args, + FunctionArguments::List(FunctionArgumentList { + args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( + Value::SingleQuotedString("\"Email\"".to_string()).into() + )))], + clauses: vec![], + duplicate_treatment: None + }) + ); + } + _ => unreachable!(), + } + + // Using IDENTIFIER to reference an alias of a table + match &snowflake() + .verified_only_select("SELECT identifier('alias1').* FROM tbl AS alias1") + .projection[0] + { + SelectItem::QualifiedWildcard( + SelectItemQualifiedWildcardKind::Expr(Expr::Function(Function { name, args, .. })), + _, + ) => { + assert_eq!(*name, ObjectName::from(vec![Ident::new("identifier")])); + assert_eq!( + *args, + FunctionArguments::List(FunctionArgumentList { + args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( + Value::SingleQuotedString("alias1".to_string()).into() + )))], + clauses: vec![], + duplicate_treatment: None + }) + ); + } + _ => unreachable!(), + } + + // Using IDENTIFIER to reference a database + match snowflake().verified_stmt("CREATE DATABASE IDENTIFIER('tbl')") { + Statement::CreateDatabase { db_name, .. } => { + assert_eq!( + db_name, + ObjectName(vec![ObjectNamePart::Function(ObjectNamePartFunction { + name: Ident::new("IDENTIFIER"), + args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( + Value::SingleQuotedString("tbl".to_string()).into() + )))] + })]) + ); + } + _ => unreachable!(), + } + + // Using IDENTIFIER to reference a schema + match snowflake().verified_stmt("CREATE SCHEMA IDENTIFIER('db1.sc1')") { + Statement::CreateSchema { schema_name, .. 
} => { + assert_eq!( + schema_name, + SchemaName::Simple(ObjectName(vec![ObjectNamePart::Function( + ObjectNamePartFunction { + name: Ident::new("IDENTIFIER"), + args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( + Value::SingleQuotedString("db1.sc1".to_string()).into() + )))] + } + )])) + ); + } + _ => unreachable!(), + } + + // Using IDENTIFIER to reference a table + match snowflake().verified_stmt("CREATE TABLE IDENTIFIER('tbl') (id INT)") { + Statement::CreateTable(CreateTable { name, .. }) => { + assert_eq!( + name, + ObjectName(vec![ObjectNamePart::Function(ObjectNamePartFunction { + name: Ident::new("IDENTIFIER"), + args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value( + Value::SingleQuotedString("tbl".to_string()).into() + )))] + })]) + ); + } + _ => unreachable!(), + } +}