diff --git a/src/ast/data_type.rs b/src/ast/data_type.rs index 6b082a1f..2cc21870 100644 --- a/src/ast/data_type.rs +++ b/src/ast/data_type.rs @@ -349,7 +349,8 @@ impl fmt::Display for DataType { DataType::Bytea => write!(f, "BYTEA"), DataType::Array(ty) => match ty { ArrayElemTypeDef::None => write!(f, "ARRAY"), - ArrayElemTypeDef::SquareBracket(t) => write!(f, "{t}[]"), + ArrayElemTypeDef::SquareBracket(t, None) => write!(f, "{t}[]"), + ArrayElemTypeDef::SquareBracket(t, Some(size)) => write!(f, "{t}[{size}]"), ArrayElemTypeDef::AngleBracket(t) => write!(f, "ARRAY<{t}>"), }, DataType::Custom(ty, modifiers) => { @@ -592,6 +593,6 @@ pub enum ArrayElemTypeDef { None, /// `ARRAY` AngleBracket(Box), - /// `[]INT` - SquareBracket(Box), + /// `INT[]` or `INT[2]` + SquareBracket(Box, Option), } diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 5ace8da7..1ec530b7 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -6360,7 +6360,7 @@ impl<'a> Parser<'a> { &mut self, ) -> Result<(DataType, MatchedTrailingBracket), ParserError> { let next_token = self.next_token(); - let mut trailing_bracket = false.into(); + let mut trailing_bracket: MatchedTrailingBracket = false.into(); let mut data = match next_token.token { Token::Word(w) => match w.keyword { Keyword::BOOLEAN => Ok(DataType::Boolean), @@ -6580,8 +6580,13 @@ impl<'a> Parser<'a> { // Parse array data types. Note: this is postgresql-specific and different from // Keyword::ARRAY syntax from above while self.consume_token(&Token::LBracket) { + let size = if dialect_of!(self is GenericDialect | DuckDbDialect | PostgreSqlDialect) { + self.maybe_parse(|p| p.parse_literal_uint()) + } else { + None + }; self.expect_token(&Token::RBracket)?; - data = DataType::Array(ArrayElemTypeDef::SquareBracket(Box::new(data))) + data = DataType::Array(ArrayElemTypeDef::SquareBracket(Box::new(data), size)) } Ok((data, trailing_bracket)) } diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 63c0517b..e6081464 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -3132,7 +3132,7 @@ fn parse_create_table_hive_array() { let expected = if angle_bracket_syntax { ArrayElemTypeDef::AngleBracket(expected) } else { - ArrayElemTypeDef::SquareBracket(expected) + ArrayElemTypeDef::SquareBracket(expected, None) }; match dialects.one_statement_parses_to(sql.as_str(), sql.as_str()) { @@ -9257,3 +9257,21 @@ fn test_select_wildcard_with_replace() { }); assert_eq!(expected, select.projection[0]); } + +#[test] +fn parse_sized_list() { + let dialects = TestedDialects { + dialects: vec![ + Box::new(GenericDialect {}), + Box::new(PostgreSqlDialect {}), + Box::new(DuckDbDialect {}), + ], + options: None, + }; + let sql = r#"CREATE TABLE embeddings (data FLOAT[1536])"#; + dialects.verified_stmt(sql); + let sql = r#"CREATE TABLE embeddings (data FLOAT[1536][3])"#; + dialects.verified_stmt(sql); + let sql = r#"SELECT data::FLOAT[1536] FROM embeddings"#; + dialects.verified_stmt(sql); +} diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 356651af..622f19d0 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -1917,11 +1917,13 @@ fn parse_array_index_expr() { })], named: true, })), - data_type: DataType::Array(ArrayElemTypeDef::SquareBracket(Box::new( - DataType::Array(ArrayElemTypeDef::SquareBracket(Box::new(DataType::Int( + data_type: DataType::Array(ArrayElemTypeDef::SquareBracket( + Box::new(DataType::Array(ArrayElemTypeDef::SquareBracket( + Box::new(DataType::Int(None)), None - )))) - ))), + ))), + None + )), format: None, }))), indexes: vec![num[1].clone(), num[2].clone()],