From c49352f39492e5d0dc1a4435e35638665fa6ae9b Mon Sep 17 00:00:00 2001 From: Nikhil Benesch Date: Mon, 27 May 2019 01:55:45 -0400 Subject: [PATCH] Implement Hash on all AST nodes It is convenient for downstream libraries to be able to stash bits of ASTs into hash maps, e.g., for performing simple common subexpression elimination. The only downside to this change is that it requires that the f64 in the Value enum be wrapped in an OrderedFloat, which provides the necessary equality semantics to allow Hash to be drived. The reason f64 doesn't implement Hash by default is because NaN is typically not equal to itself, so it's not clear what it should hash to. That's less of a concern in a SQL context, because every SQL database I've looked at treats NaN as equal to itself, in violation of the IEEE standard, in order to permit indexing and sorting of float columns. --- Cargo.toml | 1 + src/sqlast/ddl.rs | 4 ++-- src/sqlast/mod.rs | 24 ++++++++++++------------ src/sqlast/query.rs | 24 ++++++++++++------------ src/sqlast/sql_operator.rs | 2 +- src/sqlast/sqltype.rs | 2 +- src/sqlast/value.rs | 6 ++++-- src/sqlparser.rs | 2 +- 8 files changed, 34 insertions(+), 31 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 15376f2c..0cc52a2b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ path = "src/lib.rs" [dependencies] log = "0.4.5" +ordered-float = "1.0.2" [dev-dependencies] simple_logger = "1.0.1" diff --git a/src/sqlast/ddl.rs b/src/sqlast/ddl.rs index 86bb1567..38cd5adf 100644 --- a/src/sqlast/ddl.rs +++ b/src/sqlast/ddl.rs @@ -3,7 +3,7 @@ use super::{ASTNode, SQLIdent, SQLObjectName}; /// An `ALTER TABLE` (`SQLStatement::SQLAlterTable`) operation -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub enum AlterTableOperation { /// `ADD ` AddConstraint(TableConstraint), @@ -22,7 +22,7 @@ impl ToString for AlterTableOperation { /// A table-level constraint, specified in a `CREATE TABLE` or an /// `ALTER TABLE ADD ` statement. -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub enum TableConstraint { /// `[ CONSTRAINT ] { PRIMARY KEY | UNIQUE } ()` Unique { diff --git a/src/sqlast/mod.rs b/src/sqlast/mod.rs index 93e11a51..dc0a3235 100644 --- a/src/sqlast/mod.rs +++ b/src/sqlast/mod.rs @@ -46,7 +46,7 @@ pub type SQLIdent = String; /// The parser does not distinguish between expressions of different types /// (e.g. boolean vs string), so the caller must handle expressions of /// inappropriate type, like `WHERE 1` or `SELECT 1=1`, as necessary. -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub enum ASTNode { /// Identifier e.g. table name or column name SQLIdentifier(SQLIdent), @@ -214,7 +214,7 @@ impl ToString for ASTNode { } /// A window specification (i.e. `OVER (PARTITION BY .. ORDER BY .. etc.)`) -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub struct SQLWindowSpec { pub partition_by: Vec, pub order_by: Vec, @@ -258,7 +258,7 @@ impl ToString for SQLWindowSpec { /// Specifies the data processed by a window function, e.g. /// `RANGE UNBOUNDED PRECEDING` or `ROWS BETWEEN 5 PRECEDING AND CURRENT ROW`. -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub struct SQLWindowFrame { pub units: SQLWindowFrameUnits, pub start_bound: SQLWindowFrameBound, @@ -267,7 +267,7 @@ pub struct SQLWindowFrame { // TBD: EXCLUDE } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub enum SQLWindowFrameUnits { Rows, Range, @@ -300,7 +300,7 @@ impl FromStr for SQLWindowFrameUnits { } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub enum SQLWindowFrameBound { /// "CURRENT ROW" CurrentRow, @@ -325,7 +325,7 @@ impl ToString for SQLWindowFrameBound { /// A top-level statement (SELECT, INSERT, CREATE, etc.) #[allow(clippy::large_enum_variant)] -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub enum SQLStatement { /// SELECT SQLQuery(Box), @@ -527,7 +527,7 @@ impl ToString for SQLStatement { } /// A name of a table, view, custom type, etc., possibly multi-part, i.e. db.schema.obj -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub struct SQLObjectName(pub Vec); impl ToString for SQLObjectName { @@ -537,7 +537,7 @@ impl ToString for SQLObjectName { } /// SQL assignment `foo = expr` as used in SQLUpdate -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub struct SQLAssignment { id: SQLIdent, value: ASTNode, @@ -550,7 +550,7 @@ impl ToString for SQLAssignment { } /// SQL column definition -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub struct SQLColumnDef { pub name: SQLIdent, pub data_type: SQLType, @@ -580,7 +580,7 @@ impl ToString for SQLColumnDef { } /// SQL function -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub struct SQLFunction { pub name: SQLObjectName, pub args: Vec, @@ -605,7 +605,7 @@ impl ToString for SQLFunction { } /// External table's available file format -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub enum FileFormat { TEXTFILE, SEQUENCEFILE, @@ -654,7 +654,7 @@ impl FromStr for FileFormat { } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub enum SQLObjectType { Table, View, diff --git a/src/sqlast/query.rs b/src/sqlast/query.rs index 297d6e51..3b16cd4c 100644 --- a/src/sqlast/query.rs +++ b/src/sqlast/query.rs @@ -2,7 +2,7 @@ use super::*; /// The most complete variant of a `SELECT` query expression, optionally /// including `WITH`, `UNION` / other set operations, and `ORDER BY`. -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub struct SQLQuery { /// WITH (common table expressions, or CTEs) pub ctes: Vec, @@ -44,7 +44,7 @@ impl ToString for SQLQuery { /// A node in a tree, representing a "query body" expression, roughly: /// `SELECT ... [ {UNION|EXCEPT|INTERSECT} SELECT ...]` -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub enum SQLSetExpr { /// Restricted SELECT .. FROM .. HAVING (no ORDER BY or set operations) Select(Box), @@ -85,7 +85,7 @@ impl ToString for SQLSetExpr { } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub enum SQLSetOperator { Union, Except, @@ -105,7 +105,7 @@ impl ToString for SQLSetOperator { /// A restricted variant of `SELECT` (without CTEs/`ORDER BY`), which may /// appear either as the only body item of an `SQLQuery`, or as an operand /// to a set operation like `UNION`. -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub struct SQLSelect { pub distinct: bool, /// projection expressions @@ -152,7 +152,7 @@ impl ToString for SQLSelect { /// The names in the column list before `AS`, when specified, replace the names /// of the columns returned by the query. The parser does not validate that the /// number of columns in the query matches the number of columns in the query. -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub struct Cte { pub alias: SQLIdent, pub query: SQLQuery, @@ -170,7 +170,7 @@ impl ToString for Cte { } /// One item of the comma-separated list following `SELECT` -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub enum SQLSelectItem { /// Any expression, not followed by `[ AS ] alias` UnnamedExpression(ASTNode), @@ -196,7 +196,7 @@ impl ToString for SQLSelectItem { } /// A table name or a parenthesized subquery with an optional alias -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub enum TableFactor { Table { name: SQLObjectName, @@ -255,7 +255,7 @@ impl ToString for TableFactor { } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub struct Join { pub relation: TableFactor, pub join_operator: JoinOperator, @@ -307,7 +307,7 @@ impl ToString for Join { } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub enum JoinOperator { Inner(JoinConstraint), LeftOuter(JoinConstraint), @@ -317,7 +317,7 @@ pub enum JoinOperator { Cross, } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub enum JoinConstraint { On(ASTNode), Using(Vec), @@ -325,7 +325,7 @@ pub enum JoinConstraint { } /// SQL ORDER BY expression -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub struct SQLOrderByExpr { pub expr: ASTNode, pub asc: Option, @@ -341,7 +341,7 @@ impl ToString for SQLOrderByExpr { } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub struct Fetch { pub with_ties: bool, pub percent: bool, diff --git a/src/sqlast/sql_operator.rs b/src/sqlast/sql_operator.rs index 9cd14fd0..173bc80d 100644 --- a/src/sqlast/sql_operator.rs +++ b/src/sqlast/sql_operator.rs @@ -1,5 +1,5 @@ /// SQL Operator -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub enum SQLOperator { Plus, Minus, diff --git a/src/sqlast/sqltype.rs b/src/sqlast/sqltype.rs index 880a6aa4..ba339d07 100644 --- a/src/sqlast/sqltype.rs +++ b/src/sqlast/sqltype.rs @@ -1,7 +1,7 @@ use super::SQLObjectName; /// SQL data types -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub enum SQLType { /// Fixed-length character type e.g. CHAR(10) Char(Option), diff --git a/src/sqlast/value.rs b/src/sqlast/value.rs index 0cb7fab4..3c52e434 100644 --- a/src/sqlast/value.rs +++ b/src/sqlast/value.rs @@ -1,10 +1,12 @@ +use ordered_float::OrderedFloat; + /// Primitive SQL values such as number and string -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Hash)] pub enum Value { /// Unsigned integer value Long(u64), /// Unsigned floating point value - Double(f64), + Double(OrderedFloat), /// 'string value' SingleQuotedString(String), /// N'string value' diff --git a/src/sqlparser.rs b/src/sqlparser.rs index 24b832d1..eafcd14f 100644 --- a/src/sqlparser.rs +++ b/src/sqlparser.rs @@ -1022,7 +1022,7 @@ impl Parser { } }, Token::Number(ref n) if n.contains('.') => match n.parse::() { - Ok(n) => Ok(Value::Double(n)), + Ok(n) => Ok(Value::Double(n.into())), Err(e) => parser_err!(format!("Could not parse '{}' as f64: {}", n, e)), }, Token::Number(ref n) => match n.parse::() {