feat: add fixed size list support (#1231)

This commit is contained in:
universalmind303 2024-04-23 16:53:03 -05:00 committed by GitHub
parent 39980e8976
commit ce85084deb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 36 additions and 10 deletions

View file

@ -349,7 +349,8 @@ impl fmt::Display for DataType {
DataType::Bytea => write!(f, "BYTEA"),
DataType::Array(ty) => match ty {
ArrayElemTypeDef::None => write!(f, "ARRAY"),
ArrayElemTypeDef::SquareBracket(t) => write!(f, "{t}[]"),
ArrayElemTypeDef::SquareBracket(t, None) => write!(f, "{t}[]"),
ArrayElemTypeDef::SquareBracket(t, Some(size)) => write!(f, "{t}[{size}]"),
ArrayElemTypeDef::AngleBracket(t) => write!(f, "ARRAY<{t}>"),
},
DataType::Custom(ty, modifiers) => {
@ -592,6 +593,6 @@ pub enum ArrayElemTypeDef {
None,
/// `ARRAY<INT>`
AngleBracket(Box<DataType>),
/// `[]INT`
SquareBracket(Box<DataType>),
/// `INT[]` or `INT[2]`
SquareBracket(Box<DataType>, Option<u64>),
}

View file

@ -6360,7 +6360,7 @@ impl<'a> Parser<'a> {
&mut self,
) -> Result<(DataType, MatchedTrailingBracket), ParserError> {
let next_token = self.next_token();
let mut trailing_bracket = false.into();
let mut trailing_bracket: MatchedTrailingBracket = false.into();
let mut data = match next_token.token {
Token::Word(w) => match w.keyword {
Keyword::BOOLEAN => Ok(DataType::Boolean),
@ -6580,8 +6580,13 @@ impl<'a> Parser<'a> {
// Parse array data types. Note: this is postgresql-specific and different from
// Keyword::ARRAY syntax from above
while self.consume_token(&Token::LBracket) {
let size = if dialect_of!(self is GenericDialect | DuckDbDialect | PostgreSqlDialect) {
self.maybe_parse(|p| p.parse_literal_uint())
} else {
None
};
self.expect_token(&Token::RBracket)?;
data = DataType::Array(ArrayElemTypeDef::SquareBracket(Box::new(data)))
data = DataType::Array(ArrayElemTypeDef::SquareBracket(Box::new(data), size))
}
Ok((data, trailing_bracket))
}

View file

@ -3132,7 +3132,7 @@ fn parse_create_table_hive_array() {
let expected = if angle_bracket_syntax {
ArrayElemTypeDef::AngleBracket(expected)
} else {
ArrayElemTypeDef::SquareBracket(expected)
ArrayElemTypeDef::SquareBracket(expected, None)
};
match dialects.one_statement_parses_to(sql.as_str(), sql.as_str()) {
@ -9257,3 +9257,21 @@ fn test_select_wildcard_with_replace() {
});
assert_eq!(expected, select.projection[0]);
}
#[test]
fn parse_sized_list() {
let dialects = TestedDialects {
dialects: vec![
Box::new(GenericDialect {}),
Box::new(PostgreSqlDialect {}),
Box::new(DuckDbDialect {}),
],
options: None,
};
let sql = r#"CREATE TABLE embeddings (data FLOAT[1536])"#;
dialects.verified_stmt(sql);
let sql = r#"CREATE TABLE embeddings (data FLOAT[1536][3])"#;
dialects.verified_stmt(sql);
let sql = r#"SELECT data::FLOAT[1536] FROM embeddings"#;
dialects.verified_stmt(sql);
}

View file

@ -1917,11 +1917,13 @@ fn parse_array_index_expr() {
})],
named: true,
})),
data_type: DataType::Array(ArrayElemTypeDef::SquareBracket(Box::new(
DataType::Array(ArrayElemTypeDef::SquareBracket(Box::new(DataType::Int(
data_type: DataType::Array(ArrayElemTypeDef::SquareBracket(
Box::new(DataType::Array(ArrayElemTypeDef::SquareBracket(
Box::new(DataType::Int(None)),
None
))))
))),
))),
None
)),
format: None,
}))),
indexes: vec![num[1].clone(), num[2].clone()],