Merge pull request #89 from benesch/sqlfunction-struct

Extract a SQLFunction struct
This commit is contained in:
Nikhil Benesch 2019-06-03 11:09:27 -04:00 committed by GitHub
commit a3aaa49a7e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 36 deletions

View file

@ -108,13 +108,7 @@ pub enum ASTNode {
/// SQLValue
SQLValue(Value),
/// Scalar function call e.g. `LEFT(foo, 5)`
SQLFunction {
name: SQLObjectName,
args: Vec<ASTNode>,
over: Option<SQLWindowSpec>,
// aggregate functions may specify eg `COUNT(DISTINCT x)`
distinct: bool,
},
SQLFunction(SQLFunction),
/// CASE [<operand>] WHEN <condition> THEN <result> ... [ELSE <result>] END
/// Note we only recognize a complete single expression as <condition>, not
/// `< 0` nor `1, 2, 3` as allowed in a <simple when clause> per
@ -192,23 +186,7 @@ impl ToString for ASTNode {
format!("{} {}", operator.to_string(), expr.as_ref().to_string())
}
ASTNode::SQLValue(v) => v.to_string(),
ASTNode::SQLFunction {
name,
args,
over,
distinct,
} => {
let mut s = format!(
"{}({}{})",
name.to_string(),
if *distinct { "DISTINCT " } else { "" },
comma_separated_string(args)
);
if let Some(o) = over {
s += &format!(" OVER ({})", o.to_string())
}
s
}
ASTNode::SQLFunction(f) => f.to_string(),
ASTNode::SQLCase {
operand,
conditions,
@ -346,6 +324,7 @@ impl ToString for SQLWindowFrameBound {
}
/// A top-level statement (SELECT, INSERT, CREATE, etc.)
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone, PartialEq)]
pub enum SQLStatement {
/// SELECT
@ -600,6 +579,31 @@ impl ToString for SQLColumnDef {
}
}
/// SQL function
#[derive(Debug, Clone, PartialEq)]
pub struct SQLFunction {
pub name: SQLObjectName,
pub args: Vec<ASTNode>,
pub over: Option<SQLWindowSpec>,
// aggregate functions may specify eg `COUNT(DISTINCT x)`
pub distinct: bool,
}
impl ToString for SQLFunction {
fn to_string(&self) -> String {
let mut s = format!(
"{}({}{})",
self.name.to_string(),
if self.distinct { "DISTINCT " } else { "" },
comma_separated_string(&self.args),
);
if let Some(o) = &self.over {
s += &format!(" OVER ({})", o.to_string())
}
s
}
}
/// External table's available file format
#[derive(Debug, Clone, PartialEq)]
pub enum FileFormat {

View file

@ -301,12 +301,12 @@ impl Parser {
None
};
Ok(ASTNode::SQLFunction {
Ok(ASTNode::SQLFunction(SQLFunction {
name,
args,
over,
distinct,
})
}))
}
pub fn parse_window_frame(&mut self) -> Result<Option<SQLWindowFrame>, ParserError> {

View file

@ -232,12 +232,12 @@ fn parse_select_count_wildcard() {
let sql = "SELECT COUNT(*) FROM customer";
let select = verified_only_select(sql);
assert_eq!(
&ASTNode::SQLFunction {
&ASTNode::SQLFunction(SQLFunction {
name: SQLObjectName(vec!["COUNT".to_string()]),
args: vec![ASTNode::SQLWildcard],
over: None,
distinct: false,
},
}),
expr_from_projection(only(&select.projection))
);
}
@ -247,7 +247,7 @@ fn parse_select_count_distinct() {
let sql = "SELECT COUNT(DISTINCT + x) FROM customer";
let select = verified_only_select(sql);
assert_eq!(
&ASTNode::SQLFunction {
&ASTNode::SQLFunction(SQLFunction {
name: SQLObjectName(vec!["COUNT".to_string()]),
args: vec![ASTNode::SQLUnary {
operator: SQLOperator::Plus,
@ -255,7 +255,7 @@ fn parse_select_count_distinct() {
}],
over: None,
distinct: true,
},
}),
expr_from_projection(only(&select.projection))
);
@ -886,12 +886,12 @@ fn parse_scalar_function_in_projection() {
let sql = "SELECT sqrt(id) FROM foo";
let select = verified_only_select(sql);
assert_eq!(
&ASTNode::SQLFunction {
&ASTNode::SQLFunction(SQLFunction {
name: SQLObjectName(vec!["sqrt".to_string()]),
args: vec![ASTNode::SQLIdentifier("id".to_string())],
over: None,
distinct: false,
},
}),
expr_from_projection(only(&select.projection))
);
}
@ -909,7 +909,7 @@ fn parse_window_functions() {
let select = verified_only_select(sql);
assert_eq!(4, select.projection.len());
assert_eq!(
&ASTNode::SQLFunction {
&ASTNode::SQLFunction(SQLFunction {
name: SQLObjectName(vec!["row_number".to_string()]),
args: vec![],
over: Some(SQLWindowSpec {
@ -921,7 +921,7 @@ fn parse_window_functions() {
window_frame: None,
}),
distinct: false,
},
}),
expr_from_projection(&select.projection[0])
);
}
@ -988,12 +988,12 @@ fn parse_delimited_identifiers() {
expr_from_projection(&select.projection[0]),
);
assert_eq!(
&ASTNode::SQLFunction {
&ASTNode::SQLFunction(SQLFunction {
name: SQLObjectName(vec![r#""myfun""#.to_string()]),
args: vec![],
over: None,
distinct: false,
},
}),
expr_from_projection(&select.projection[1]),
);
match &select.projection[2] {