Add support for Hive's LOAD DATA expr (#1520)

Co-authored-by: Ifeanyi Ubah <ify1992@yahoo.com>
This commit is contained in:
wugeer 2024-11-15 22:53:31 +08:00 committed by GitHub
parent 62eaee62dc
commit 724a1d1aba
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 323 additions and 11 deletions

View file

@ -3347,6 +3347,22 @@ pub enum Statement {
channel: Ident,
payload: Option<String>,
},
/// ```sql
/// LOAD DATA [LOCAL] INPATH 'filepath' [OVERWRITE] INTO TABLE tablename
/// [PARTITION (partcol1=val1, partcol2=val2 ...)]
/// [INPUTFORMAT 'inputformat' SERDE 'serde']
/// ```
/// Loading files into tables
///
/// See Hive <https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=27362036#LanguageManualDML-Loadingfilesintotables>
LoadData {
local: bool,
inpath: String,
overwrite: bool,
table_name: ObjectName,
partitioned: Option<Vec<Expr>>,
table_format: Option<HiveLoadDataFormat>,
},
}
impl fmt::Display for Statement {
@ -3949,6 +3965,36 @@ impl fmt::Display for Statement {
Ok(())
}
Statement::CreateTable(create_table) => create_table.fmt(f),
Statement::LoadData {
local,
inpath,
overwrite,
table_name,
partitioned,
table_format,
} => {
write!(
f,
"LOAD DATA {local}INPATH '{inpath}' {overwrite}INTO TABLE {table_name}",
local = if *local { "LOCAL " } else { "" },
inpath = inpath,
overwrite = if *overwrite { "OVERWRITE " } else { "" },
table_name = table_name,
)?;
if let Some(ref parts) = &partitioned {
if !parts.is_empty() {
write!(f, " PARTITION ({})", display_comma_separated(parts))?;
}
}
if let Some(HiveLoadDataFormat {
serde,
input_format,
}) = &table_format
{
write!(f, " INPUTFORMAT {input_format} SERDE {serde}")?;
}
Ok(())
}
Statement::CreateVirtualTable {
name,
if_not_exists,
@ -5855,6 +5901,14 @@ pub enum HiveRowFormat {
DELIMITED { delimiters: Vec<HiveRowDelimiter> },
}
/// The optional file-format clause of a Hive `LOAD DATA` statement,
/// i.e. `INPUTFORMAT 'inputformat' SERDE 'serde'`.
///
/// See Hive <https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=27362036#LanguageManualDML-Loadingfilesintotables>
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct HiveLoadDataFormat {
    /// The SerDe class, parsed as an expression (typically a string literal).
    pub serde: Expr,
    /// The input-format class, parsed as an expression (typically a string literal).
    pub input_format: Expr,
}
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]

View file

@ -66,4 +66,9 @@ impl Dialect for DuckDbDialect {
fn supports_explain_with_utility_options(&self) -> bool {
true
}
/// DuckDB loads extensions via `LOAD extension_name`.
///
/// See DuckDB <https://duckdb.org/docs/sql/statements/load_and_install.html#load>
fn supports_load_extension(&self) -> bool {
    true
}
}

View file

@ -115,4 +115,8 @@ impl Dialect for GenericDialect {
fn supports_comment_on(&self) -> bool {
true
}
/// The generic dialect accepts the DuckDB-style `LOAD extension_name` statement.
fn supports_load_extension(&self) -> bool {
    true
}
}

View file

@ -56,4 +56,9 @@ impl Dialect for HiveDialect {
fn supports_bang_not_operator(&self) -> bool {
true
}
/// Hive supports `LOAD DATA [LOCAL] INPATH 'filepath' [OVERWRITE] INTO TABLE ...`.
///
/// See Hive <https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=27362036#LanguageManualDML-Loadingfilesintotables>
fn supports_load_data(&self) -> bool {
    true
}
}

View file

@ -620,6 +620,16 @@ pub trait Dialect: Debug + Any {
false
}
/// Returns true if the dialect supports the `LOAD DATA` statement
/// (Hive-style loading of files into tables). Defaults to `false`.
fn supports_load_data(&self) -> bool {
    false
}
/// Returns true if the dialect supports the `LOAD extension` statement
/// (DuckDB-style extension loading). Defaults to `false`.
fn supports_load_extension(&self) -> bool {
    false
}
/// Returns true if this dialect expects the `TOP` option
/// before the `ALL`/`DISTINCT` options in a `SELECT` statement.
fn supports_top_before_distinct(&self) -> bool {

View file

@ -389,6 +389,7 @@ define_keywords!(
INITIALLY,
INNER,
INOUT,
INPATH,
INPUT,
INPUTFORMAT,
INSENSITIVE,

View file

@ -543,10 +543,7 @@ impl<'a> Parser<'a> {
Keyword::INSTALL if dialect_of!(self is DuckDbDialect | GenericDialect) => {
self.parse_install()
}
// `LOAD` is duckdb specific https://duckdb.org/docs/extensions/overview
Keyword::LOAD if dialect_of!(self is DuckDbDialect | GenericDialect) => {
self.parse_load()
}
Keyword::LOAD => self.parse_load(),
// `OPTIMIZE` is clickhouse specific https://clickhouse.tech/docs/en/sql-reference/statements/optimize/
Keyword::OPTIMIZE if dialect_of!(self is ClickHouseDialect | GenericDialect) => {
self.parse_optimize_table()
@ -11222,6 +11219,22 @@ impl<'a> Parser<'a> {
}
}
/// Parse the optional `INPUTFORMAT 'inputformat' SERDE 'serde'` suffix of a
/// Hive `LOAD DATA` statement.
///
/// Returns `Ok(None)` (consuming nothing) when the next keyword is not
/// `INPUTFORMAT`; otherwise both clauses are required.
pub fn parse_load_data_table_format(
    &mut self,
) -> Result<Option<HiveLoadDataFormat>, ParserError> {
    // Guard clause: the whole clause is optional and keyed on INPUTFORMAT.
    if !self.parse_keyword(Keyword::INPUTFORMAT) {
        return Ok(None);
    }
    let input_format = self.parse_expr()?;
    // SERDE is mandatory once INPUTFORMAT has been seen.
    self.expect_keyword(Keyword::SERDE)?;
    let serde = self.parse_expr()?;
    Ok(Some(HiveLoadDataFormat {
        input_format,
        serde,
    }))
}
/// Parse an UPDATE statement, returning a `Box`ed SetExpr
///
/// This is used to reduce the size of the stack frames in debug builds
@ -12224,10 +12237,35 @@ impl<'a> Parser<'a> {
Ok(Statement::Install { extension_name })
}
/// Parse a SQL `LOAD` statement.
///
/// Two forms exist, selected by dialect capability:
/// - `LOAD extension_name` when the dialect supports loading extensions
///   (DuckDB-style);
/// - `LOAD DATA [LOCAL] INPATH 'filepath' [OVERWRITE] INTO TABLE tablename
///   [PARTITION (...)] [INPUTFORMAT 'fmt' SERDE 'serde']` when the dialect
///   supports Hive-style `LOAD DATA`.
pub fn parse_load(&mut self) -> Result<Statement, ParserError> {
    if self.dialect.supports_load_extension() {
        let extension_name = self.parse_identifier(false)?;
        Ok(Statement::Load { extension_name })
    // NOTE: `parse_keyword` consumes the `DATA` token even when
    // `supports_load_data()` is false; the fall-through error below then
    // reports the token *after* `DATA`. The error-message tests rely on
    // this ordering, so do not swap the operands of `&&`.
    } else if self.parse_keyword(Keyword::DATA) && self.dialect.supports_load_data() {
        let local = self.parse_one_of_keywords(&[Keyword::LOCAL]).is_some();
        self.expect_keyword(Keyword::INPATH)?;
        // The file path is a quoted string literal, not an identifier.
        let inpath = self.parse_literal_string()?;
        let overwrite = self.parse_one_of_keywords(&[Keyword::OVERWRITE]).is_some();
        self.expect_keyword(Keyword::INTO)?;
        self.expect_keyword(Keyword::TABLE)?;
        let table_name = self.parse_object_name(false)?;
        // Reuses the INSERT-style `PARTITION (col=val, ...)` parser.
        let partitioned = self.parse_insert_partition()?;
        // Optional `INPUTFORMAT ... SERDE ...` clause.
        let table_format = self.parse_load_data_table_format()?;
        Ok(Statement::LoadData {
            local,
            inpath,
            overwrite,
            table_name,
            partitioned,
            table_format,
        })
    } else {
        self.expected(
            "`DATA` or an extension name after `LOAD`",
            self.peek_token(),
        )
    }
}
/// ```sql

View file

@ -11583,13 +11583,208 @@ fn parse_notify_channel() {
dialects.parse_sql_statements(sql).unwrap_err(),
ParserError::ParserError("Expected: an SQL statement, found: NOTIFY".to_string())
);
assert_eq!(
dialects.parse_sql_statements(sql).unwrap_err(),
ParserError::ParserError("Expected: an SQL statement, found: NOTIFY".to_string())
);
}
}
/// Exercises parsing and round-trip display of Hive's `LOAD DATA` statement
/// across three dialect groups: those supporting `LOAD DATA`, those supporting
/// only `LOAD extension_name`, and those supporting neither form.
#[test]
fn parse_load_data() {
    let dialects = all_dialects_where(|d| d.supports_load_data());
    let only_supports_load_extension_dialects =
        all_dialects_where(|d| !d.supports_load_data() && d.supports_load_extension());
    let not_supports_load_dialects =
        all_dialects_where(|d| !d.supports_load_data() && !d.supports_load_extension());

    // minimal form: no LOCAL / OVERWRITE / PARTITION / format clauses
    let sql = "LOAD DATA INPATH '/local/path/to/data.txt' INTO TABLE test.my_table";
    match dialects.verified_stmt(sql) {
        Statement::LoadData {
            local,
            inpath,
            overwrite,
            table_name,
            partitioned,
            table_format,
        } => {
            assert_eq!(false, local);
            assert_eq!("/local/path/to/data.txt", inpath);
            assert_eq!(false, overwrite);
            assert_eq!(
                ObjectName(vec![Ident::new("test"), Ident::new("my_table")]),
                table_name
            );
            assert_eq!(None, partitioned);
            assert_eq!(None, table_format);
        }
        _ => unreachable!(),
    };

    // with OVERWRITE keyword
    let sql = "LOAD DATA INPATH '/local/path/to/data.txt' OVERWRITE INTO TABLE my_table";
    match dialects.verified_stmt(sql) {
        Statement::LoadData {
            local,
            inpath,
            overwrite,
            table_name,
            partitioned,
            table_format,
        } => {
            assert_eq!(false, local);
            assert_eq!("/local/path/to/data.txt", inpath);
            assert_eq!(true, overwrite);
            assert_eq!(ObjectName(vec![Ident::new("my_table")]), table_name);
            assert_eq!(None, partitioned);
            assert_eq!(None, table_format);
        }
        _ => unreachable!(),
    };

    // extension-only dialects consume `DATA` as the extension name, then
    // fail on the next token (INPATH)
    assert_eq!(
        only_supports_load_extension_dialects
            .parse_sql_statements(sql)
            .unwrap_err(),
        ParserError::ParserError("Expected: end of statement, found: INPATH".to_string())
    );
    // dialects without either form reject the statement after `LOAD DATA`
    assert_eq!(
        not_supports_load_dialects
            .parse_sql_statements(sql)
            .unwrap_err(),
        ParserError::ParserError(
            "Expected: `DATA` or an extension name after `LOAD`, found: INPATH".to_string()
        )
    );

    // with LOCAL keyword
    let sql = "LOAD DATA LOCAL INPATH '/local/path/to/data.txt' INTO TABLE test.my_table";
    match dialects.verified_stmt(sql) {
        Statement::LoadData {
            local,
            inpath,
            overwrite,
            table_name,
            partitioned,
            table_format,
        } => {
            assert_eq!(true, local);
            assert_eq!("/local/path/to/data.txt", inpath);
            assert_eq!(false, overwrite);
            assert_eq!(
                ObjectName(vec![Ident::new("test"), Ident::new("my_table")]),
                table_name
            );
            assert_eq!(None, partitioned);
            assert_eq!(None, table_format);
        }
        _ => unreachable!(),
    };
    assert_eq!(
        only_supports_load_extension_dialects
            .parse_sql_statements(sql)
            .unwrap_err(),
        ParserError::ParserError("Expected: end of statement, found: LOCAL".to_string())
    );
    assert_eq!(
        not_supports_load_dialects
            .parse_sql_statements(sql)
            .unwrap_err(),
        ParserError::ParserError(
            "Expected: `DATA` or an extension name after `LOAD`, found: LOCAL".to_string()
        )
    );

    // with PARTITION clause
    let sql = "LOAD DATA LOCAL INPATH '/local/path/to/data.txt' INTO TABLE my_table PARTITION (year = 2024, month = 11)";
    match dialects.verified_stmt(sql) {
        Statement::LoadData {
            local,
            inpath,
            overwrite,
            table_name,
            partitioned,
            table_format,
        } => {
            assert_eq!(true, local);
            assert_eq!("/local/path/to/data.txt", inpath);
            assert_eq!(false, overwrite);
            assert_eq!(ObjectName(vec![Ident::new("my_table")]), table_name);
            assert_eq!(
                Some(vec![
                    Expr::BinaryOp {
                        left: Box::new(Expr::Identifier(Ident::new("year"))),
                        op: BinaryOperator::Eq,
                        right: Box::new(Expr::Value(Value::Number("2024".parse().unwrap(), false))),
                    },
                    Expr::BinaryOp {
                        left: Box::new(Expr::Identifier(Ident::new("month"))),
                        op: BinaryOperator::Eq,
                        right: Box::new(Expr::Value(Value::Number("11".parse().unwrap(), false))),
                    }
                ]),
                partitioned
            );
            assert_eq!(None, table_format);
        }
        _ => unreachable!(),
    };

    // with PARTITION clause plus INPUTFORMAT/SERDE table format
    let sql = "LOAD DATA LOCAL INPATH '/local/path/to/data.txt' OVERWRITE INTO TABLE good.my_table PARTITION (year = 2024, month = 11) INPUTFORMAT 'org.apache.hadoop.mapred.TextInputFormat' SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde'";
    match dialects.verified_stmt(sql) {
        Statement::LoadData {
            local,
            inpath,
            overwrite,
            table_name,
            partitioned,
            table_format,
        } => {
            assert_eq!(true, local);
            assert_eq!("/local/path/to/data.txt", inpath);
            assert_eq!(true, overwrite);
            assert_eq!(
                ObjectName(vec![Ident::new("good"), Ident::new("my_table")]),
                table_name
            );
            assert_eq!(
                Some(vec![
                    Expr::BinaryOp {
                        left: Box::new(Expr::Identifier(Ident::new("year"))),
                        op: BinaryOperator::Eq,
                        right: Box::new(Expr::Value(Value::Number("2024".parse().unwrap(), false))),
                    },
                    Expr::BinaryOp {
                        left: Box::new(Expr::Identifier(Ident::new("month"))),
                        op: BinaryOperator::Eq,
                        right: Box::new(Expr::Value(Value::Number("11".parse().unwrap(), false))),
                    }
                ]),
                partitioned
            );
            assert_eq!(
                Some(HiveLoadDataFormat {
                    serde: Expr::Value(Value::SingleQuotedString(
                        "org.apache.hadoop.hive.serde2.OpenCSVSerde".to_string()
                    )),
                    input_format: Expr::Value(Value::SingleQuotedString(
                        "org.apache.hadoop.mapred.TextInputFormat".to_string()
                    ))
                }),
                table_format
            );
        }
        _ => unreachable!(),
    };

    // negative test case: `DATA2` is neither DATA nor a valid continuation
    let sql = "LOAD DATA2 LOCAL INPATH '/local/path/to/data.txt' INTO TABLE test.my_table";
    assert_eq!(
        dialects.parse_sql_statements(sql).unwrap_err(),
        ParserError::ParserError(
            "Expected: `DATA` or an extension name after `LOAD`, found: DATA2".to_string()
        )
    );
}
#[test]
fn test_select_top() {
let dialects = all_dialects_where(|d| d.supports_top_before_distinct());