Replace type_id() by trait method to allow wrapping dialects (#1065)

This commit is contained in:
Joris Bayer 2023-12-19 21:54:48 +01:00 committed by GitHub
parent e027b3cad2
commit 29b4ce81c1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -89,6 +89,15 @@ macro_rules! dialect_of {
///
/// [module level documentation]: crate
pub trait Dialect: Debug + Any {
/// Determine the [`TypeId`] of this dialect.
///
/// By default, return the same [`TypeId`] as [`Any::type_id`]. Can be overriden
/// by dialects that behave like other dialects
/// (for example when wrapping a dialect).
fn dialect(&self) -> TypeId {
self.type_id()
}
/// Determine if a character starts a quoted identifier. The default
/// implementation, accepting "double quoted" ids is both ANSI-compliant
/// and appropriate for most dialects (with the notable exception of
@ -164,7 +173,7 @@ impl dyn Dialect {
#[inline]
pub fn is<T: Dialect>(&self) -> bool {
// borrowed from `Any` implementation
TypeId::of::<T>() == self.type_id()
TypeId::of::<T>() == self.dialect()
}
}
@ -248,4 +257,98 @@ mod tests {
fn parse_dialect(v: &str) -> Box<dyn Dialect> {
dialect_from_str(v).unwrap()
}
#[test]
fn parse_with_wrapped_dialect() {
/// Wrapper for a dialect. In a real-world example, this wrapper
/// would tweak the behavior of the dialect. For the test case,
/// it wraps all methods unaltered.
#[derive(Debug)]
struct WrappedDialect(MySqlDialect);
impl Dialect for WrappedDialect {
fn dialect(&self) -> std::any::TypeId {
self.0.dialect()
}
fn is_identifier_start(&self, ch: char) -> bool {
self.0.is_identifier_start(ch)
}
fn is_delimited_identifier_start(&self, ch: char) -> bool {
self.0.is_delimited_identifier_start(ch)
}
fn is_proper_identifier_inside_quotes(
&self,
chars: std::iter::Peekable<std::str::Chars<'_>>,
) -> bool {
self.0.is_proper_identifier_inside_quotes(chars)
}
fn supports_filter_during_aggregation(&self) -> bool {
self.0.supports_filter_during_aggregation()
}
fn supports_within_after_array_aggregation(&self) -> bool {
self.0.supports_within_after_array_aggregation()
}
fn supports_group_by_expr(&self) -> bool {
self.0.supports_group_by_expr()
}
fn supports_substring_from_for_expr(&self) -> bool {
self.0.supports_substring_from_for_expr()
}
fn supports_in_empty_list(&self) -> bool {
self.0.supports_in_empty_list()
}
fn convert_type_before_value(&self) -> bool {
self.0.convert_type_before_value()
}
fn parse_prefix(
&self,
parser: &mut sqlparser::parser::Parser,
) -> Option<Result<Expr, sqlparser::parser::ParserError>> {
self.0.parse_prefix(parser)
}
fn parse_infix(
&self,
parser: &mut sqlparser::parser::Parser,
expr: &Expr,
precedence: u8,
) -> Option<Result<Expr, sqlparser::parser::ParserError>> {
self.0.parse_infix(parser, expr, precedence)
}
fn get_next_precedence(
&self,
parser: &sqlparser::parser::Parser,
) -> Option<Result<u8, sqlparser::parser::ParserError>> {
self.0.get_next_precedence(parser)
}
fn parse_statement(
&self,
parser: &mut sqlparser::parser::Parser,
) -> Option<Result<Statement, sqlparser::parser::ParserError>> {
self.0.parse_statement(parser)
}
fn is_identifier_part(&self, ch: char) -> bool {
self.0.is_identifier_part(ch)
}
}
let statement = r#"SELECT 'Wayne\'s World'"#;
let res1 = Parser::parse_sql(&MySqlDialect {}, statement);
let res2 = Parser::parse_sql(&WrappedDialect(MySqlDialect {}), statement);
assert!(res1.is_ok());
assert_eq!(res1, res2);
}
}