Support Snowflake MATCH_RECOGNIZE syntax (#1222)

This commit is contained in:
Joey Hain 2024-04-22 13:17:50 -07:00 committed by GitHub
parent bf89b7d808
commit 39980e8976
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 847 additions and 9 deletions

View file

@ -33,8 +33,8 @@ use sqlparser::keywords::ALL_KEYWORDS;
use sqlparser::parser::{Parser, ParserError, ParserOptions};
use sqlparser::tokenizer::Tokenizer;
use test_utils::{
all_dialects, all_dialects_where, alter_table_op, assert_eq_vec, expr_from_projection, join,
number, only, table, table_alias, TestedDialects,
all_dialects, all_dialects_where, alter_table_op, assert_eq_vec, call, expr_from_projection,
join, number, only, table, table_alias, TestedDialects,
};
#[macro_use]
@ -8887,6 +8887,299 @@ fn parse_map_access_expr() {
}
}
/// Parses a full `MATCH_RECOGNIZE` clause (partitioning, ordering, measures,
/// rows-per-match, after-match-skip, pattern and symbol definitions) and compares
/// the resulting table factor with a hand-built `TableFactor::MatchRecognize`.
/// Afterwards, a handful of additional sample queries are passed through
/// `verified_query` on every dialect that reports `supports_match_recognize()`.
#[test]
fn test_match_recognize() {
    use MatchRecognizePattern::*;
    use MatchRecognizeSymbol::*;
    use RepetitionQuantifier::*;

    /// Wraps `options` in `SELECT * FROM my_table MATCH_RECOGNIZE(...)`, parses it
    /// with `verified_only_select`, and asserts that the first `FROM` relation
    /// equals `expect`.
    fn check(options: &str, expect: TableFactor) {
        let sql = format!("SELECT * FROM my_table MATCH_RECOGNIZE({options})");
        let dialects = all_dialects_where(|d| d.supports_match_recognize());
        let select = dialects.verified_only_select(&sql);
        assert_eq!(&select.from[0].relation, &expect);
    }

    // Small builders for the pieces that repeat throughout the expected tree.
    let ident = |name: &str| Expr::Identifier(Ident::new(name));
    let measure = |expr: Expr, alias: &str| Measure {
        expr,
        alias: Ident::new(alias),
    };
    let named = |name: &str| Symbol(Named(Ident::new(name)));
    let one_or_more = |name: &str| Repetition(Box::new(named(name)), OneOrMore);
    // `price <op> LAG(price)`, used by both symbol definitions.
    let price_vs_lag = |op: BinaryOperator| Expr::BinaryOp {
        left: Box::new(ident("price")),
        op,
        right: Box::new(call("LAG", [ident("price")])),
    };

    // The plain table that the MATCH_RECOGNIZE clause is applied to.
    let base_table = TableFactor::Table {
        name: ObjectName(vec![Ident::new("my_table")]),
        alias: None,
        args: None,
        with_hints: vec![],
        version: None,
        partitions: vec![],
    };

    let options = concat!(
        "PARTITION BY company ",
        "ORDER BY price_date ",
        "MEASURES ",
        "MATCH_NUMBER() AS match_number, ",
        "FIRST(price_date) AS start_date, ",
        "LAST(price_date) AS end_date ",
        "ONE ROW PER MATCH ",
        "AFTER MATCH SKIP TO LAST row_with_price_increase ",
        "PATTERN (row_before_decrease row_with_price_decrease+ row_with_price_increase+) ",
        "DEFINE ",
        "row_with_price_decrease AS price < LAG(price), ",
        "row_with_price_increase AS price > LAG(price)"
    );

    let expected = TableFactor::MatchRecognize {
        table: Box::new(base_table),
        partition_by: vec![ident("company")],
        order_by: vec![OrderByExpr {
            expr: ident("price_date"),
            asc: None,
            nulls_first: None,
        }],
        measures: vec![
            measure(call("MATCH_NUMBER", []), "match_number"),
            measure(call("FIRST", [ident("price_date")]), "start_date"),
            measure(call("LAST", [ident("price_date")]), "end_date"),
        ],
        rows_per_match: Some(RowsPerMatch::OneRow),
        after_match_skip: Some(AfterMatchSkip::ToLast(Ident::new(
            "row_with_price_increase",
        ))),
        pattern: Concat(vec![
            named("row_before_decrease"),
            one_or_more("row_with_price_decrease"),
            one_or_more("row_with_price_increase"),
        ]),
        symbols: vec![
            SymbolDefinition {
                symbol: Ident::new("row_with_price_decrease"),
                definition: price_vs_lag(BinaryOperator::Lt),
            },
            SymbolDefinition {
                symbol: Ident::new("row_with_price_increase"),
                definition: price_vs_lag(BinaryOperator::Gt),
            },
        ],
        alias: None,
    };

    check(options, expected);

    // Further sample queries; these are only passed through `verified_query`,
    // no expected AST is spelled out for them.
    #[rustfmt::skip]
    let examples = [
        concat!(
            "SELECT * ",
            "FROM login_attempts ",
            "MATCH_RECOGNIZE(",
                "PARTITION BY user_id ",
                "ORDER BY timestamp ",
                "PATTERN (failed_attempt{3,}) ",
                "DEFINE ",
                    "failed_attempt AS status = 'failure'",
            ")",
        ),
        concat!(
            "SELECT * ",
            "FROM stock_transactions ",
            "MATCH_RECOGNIZE(",
                "PARTITION BY symbol ",
                "ORDER BY timestamp ",
                "MEASURES ",
                    "FIRST(price) AS start_price, ",
                    "LAST(price) AS end_price, ",
                    "MATCH_NUMBER() AS match_num ",
                "ALL ROWS PER MATCH ",
                "PATTERN (STRT UP+) ",
                "DEFINE ",
                    "UP AS price > PREV(price)",
            ")",
        ),
        concat!(
            "SELECT * ",
            "FROM event_log ",
            "MATCH_RECOGNIZE(",
                "MEASURES ",
                    "FIRST(event_type) AS start_event, ",
                    "LAST(event_type) AS end_event, ",
                    "COUNT(*) AS error_count ",
                "ALL ROWS PER MATCH ",
                "PATTERN (STRT ERROR+ END) ",
                "DEFINE ",
                    "STRT AS event_type = 'START', ",
                    "ERROR AS event_type = 'ERROR', ",
                    "END AS event_type = 'END'",
            ")",
        )
    ];

    let dialects = all_dialects_where(|d| d.supports_match_recognize());
    for sql in examples {
        dialects.verified_query(sql);
    }
}
/// Exercises the `PATTERN (...)` sub-grammar of `MATCH_RECOGNIZE`: each case
/// embeds a pattern string in a minimal query and compares the parsed
/// `MatchRecognizePattern` with a hand-built expectation.
#[test]
fn test_match_recognize_patterns() {
    use MatchRecognizePattern::*;
    use MatchRecognizeSymbol::*;
    use RepetitionQuantifier::*;

    /// Parses `pattern` inside a minimal `MATCH_RECOGNIZE` clause (with a dummy
    /// `DEFINE`) and asserts that the resulting pattern equals `expect`.
    /// Panics if the first `FROM` relation is not a `MatchRecognize` table factor.
    fn check(pattern: &str, expect: MatchRecognizePattern) {
        let sql = format!(
            "SELECT * FROM my_table MATCH_RECOGNIZE(PATTERN ({pattern}) DEFINE DUMMY AS true)"
        );
        let dialects = all_dialects_where(|d| d.supports_match_recognize());
        let select = dialects.verified_only_select(&sql);
        match &select.from[0].relation {
            TableFactor::MatchRecognize {
                pattern: actual, ..
            } => assert_eq!(actual, &expect),
            _ => panic!("expected match_recognize table factor"),
        }
    }

    // Builders for the recurring node shapes.
    let sym = |name: &str| Named(Ident::new(name));
    let named = |name: &str| Symbol(sym(name));
    let repeat = |inner: MatchRecognizePattern, quantifier: RepetitionQuantifier| {
        Repetition(Box::new(inner), quantifier)
    };
    let group = |inner: MatchRecognizePattern| Group(Box::new(inner));

    // a single named symbol
    check("FOO", named("FOO"));

    // a symbol between the start (`^`) and end (`$`) anchors
    check(
        "^ FOO $",
        Concat(vec![Symbol(Start), named("FOO"), Symbol(End)]),
    );

    // exclusion
    check("{- FOO -}", Exclude(sym("FOO")));

    // permutation
    check(
        "PERMUTE(A, B, C)",
        Permute(vec![sym("A"), sym("B"), sym("C")]),
    );

    // various identifiers (bare, double-quoted, with digits)
    check(
        "FOO | \"BAR\" | baz42",
        Alternation(vec![
            named("FOO"),
            Symbol(Named(Ident::with_quote('"', "BAR"))),
            named("baz42"),
        ]),
    );

    // concatenated basic quantifiers
    check(
        "S1* S2+ S3?",
        Concat(vec![
            repeat(named("S1"), ZeroOrMore),
            repeat(named("S2"), OneOrMore),
            repeat(named("S3"), AtMostOne),
        ]),
    );

    // double repetition: `*?` is expected as a repetition of a repetition
    check(
        "S2*?",
        repeat(repeat(named("S2"), ZeroOrMore), AtMostOne),
    );

    // range quantifiers in an alternation
    check(
        "S1{1} | S2{2,3} | S3{4,} | S4{,5}",
        Alternation(vec![
            repeat(named("S1"), Exactly(1)),
            repeat(named("S2"), Range(2, 3)),
            repeat(named("S3"), AtLeast(4)),
            repeat(named("S4"), AtMost(5)),
        ]),
    );

    // grouping case 1: a group following a symbol
    check(
        "S1 ( S2 )",
        Concat(vec![named("S1"), group(named("S2"))]),
    );

    // grouping case 2: a quantified group containing a concatenation
    check(
        "( {- S3 -} S4 )+",
        repeat(
            group(Concat(vec![Exclude(sym("S3")), named("S4")])),
            OneOrMore,
        ),
    );

    // the grand finale (example taken from snowflake docs)
    let left_branch = Concat(vec![
        Symbol(Start),
        named("S1"),
        repeat(repeat(named("S2"), ZeroOrMore), AtMostOne),
        repeat(
            group(Concat(vec![Exclude(sym("S3")), named("S4")])),
            OneOrMore,
        ),
    ]);
    let right_branch = Concat(vec![
        repeat(Permute(vec![sym("S1"), sym("S2")]), Range(1, 2)),
        Symbol(End),
    ]);
    check(
        "^ S1 S2*? ( {- S3 -} S4 )+ | PERMUTE(S1, S2){1,2} $",
        Alternation(vec![left_branch, right_branch]),
    );
}
#[test]
fn test_select_wildcard_with_replace() {
let sql = r#"SELECT * REPLACE (lower(city) AS city) FROM addresses"#;