diff --git a/Cargo.lock b/Cargo.lock index 561bc7c5bc..9cb626f627 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1957,7 +1957,7 @@ dependencies = [ [[package]] name = "rustpython-ast" version = "0.1.0" -source = "git+https://github.com/charliermarsh/RustPython.git?rev=4f457893efc381ad5c432576b24bcc7e4a08c641#4f457893efc381ad5c432576b24bcc7e4a08c641" +source = "git+https://github.com/charliermarsh/RustPython.git?rev=778ae2aeb521d0438d2a91bd11238bb5c2bf9d4f#778ae2aeb521d0438d2a91bd11238bb5c2bf9d4f" dependencies = [ "num-bigint", "rustpython-common", @@ -1967,7 +1967,7 @@ dependencies = [ [[package]] name = "rustpython-common" version = "0.0.0" -source = "git+https://github.com/charliermarsh/RustPython.git?rev=4f457893efc381ad5c432576b24bcc7e4a08c641#4f457893efc381ad5c432576b24bcc7e4a08c641" +source = "git+https://github.com/charliermarsh/RustPython.git?rev=778ae2aeb521d0438d2a91bd11238bb5c2bf9d4f#778ae2aeb521d0438d2a91bd11238bb5c2bf9d4f" dependencies = [ "ascii", "cfg-if 1.0.0", @@ -1990,7 +1990,7 @@ dependencies = [ [[package]] name = "rustpython-compiler-core" version = "0.1.2" -source = "git+https://github.com/charliermarsh/RustPython.git?rev=4f457893efc381ad5c432576b24bcc7e4a08c641#4f457893efc381ad5c432576b24bcc7e4a08c641" +source = "git+https://github.com/charliermarsh/RustPython.git?rev=778ae2aeb521d0438d2a91bd11238bb5c2bf9d4f#778ae2aeb521d0438d2a91bd11238bb5c2bf9d4f" dependencies = [ "bincode", "bitflags", @@ -2007,7 +2007,7 @@ dependencies = [ [[package]] name = "rustpython-parser" version = "0.1.2" -source = "git+https://github.com/charliermarsh/RustPython.git?rev=4f457893efc381ad5c432576b24bcc7e4a08c641#4f457893efc381ad5c432576b24bcc7e4a08c641" +source = "git+https://github.com/charliermarsh/RustPython.git?rev=778ae2aeb521d0438d2a91bd11238bb5c2bf9d4f#778ae2aeb521d0438d2a91bd11238bb5c2bf9d4f" dependencies = [ "ahash", "anyhow", diff --git a/Cargo.toml b/Cargo.toml index 2d636c6855..616c869993 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,9 +27,9 @@ once_cell = { version = "1.13.1" } path-absolutize = { version = "3.0.13", features = ["once_cell_cache"] } rayon = { version = "1.5.3" } regex = { version = "1.6.0" } -rustpython-ast = { features = ["unparse"], git = "https://github.com/charliermarsh/RustPython.git", rev = "4f457893efc381ad5c432576b24bcc7e4a08c641" } -rustpython-parser = { features = ["lalrpop"], git = "https://github.com/charliermarsh/RustPython.git", rev = "4f457893efc381ad5c432576b24bcc7e4a08c641" } -rustpython-common = { git = "https://github.com/charliermarsh/RustPython.git", rev = "4f457893efc381ad5c432576b24bcc7e4a08c641" } +rustpython-ast = { features = ["unparse"], git = "https://github.com/charliermarsh/RustPython.git", rev = "778ae2aeb521d0438d2a91bd11238bb5c2bf9d4f" } +rustpython-parser = { features = ["lalrpop"], git = "https://github.com/charliermarsh/RustPython.git", rev = "778ae2aeb521d0438d2a91bd11238bb5c2bf9d4f" } +rustpython-common = { git = "https://github.com/charliermarsh/RustPython.git", rev = "778ae2aeb521d0438d2a91bd11238bb5c2bf9d4f" } serde = { version = "1.0.143", features = ["derive"] } serde_json = { version = "1.0.83" } toml = { version = "0.5.9" } diff --git a/src/plugins/use_pep604_annotation.rs b/src/plugins/use_pep604_annotation.rs index 065afaa57a..3c6b270d63 100644 --- a/src/plugins/use_pep604_annotation.rs +++ b/src/plugins/use_pep604_annotation.rs @@ -1,5 +1,4 @@ -use anyhow::{anyhow, Result}; -use rustpython_ast::{Expr, ExprKind}; +use rustpython_ast::{Constant, Expr, ExprKind, Operator}; use crate::ast::helpers::match_name_or_attr; use crate::ast::types::Range; @@ -8,15 +7,50 @@ use crate::check_ast::Checker; use crate::checks::{Check, CheckKind, Fix}; use crate::code_gen::SourceGenerator; +fn optional(expr: &Expr) -> Expr { + Expr::new( + Default::default(), + Default::default(), + ExprKind::BinOp { + left: Box::new(expr.clone()), + op: Operator::BitOr, + right: Box::new(Expr::new( + Default::default(), + Default::default(), + ExprKind::Constant { + value: Constant::None, + kind: None, + }, + )), + }, + ) +} + +fn union(elts: &[Expr]) -> Expr { + if elts.len() == 1 { + elts[0].clone() + } else { + Expr::new( + Default::default(), + Default::default(), + ExprKind::BinOp { + left: Box::new(union(&elts[..elts.len() - 1])), + op: Operator::BitOr, + right: Box::new(elts[elts.len() - 1].clone()), + }, + ) + } +} + pub fn use_pep604_annotation(checker: &mut Checker, expr: &Expr, value: &Expr, slice: &Expr) { if match_name_or_attr(value, "Optional") { let mut check = Check::new(CheckKind::UsePEP604Annotation, Range::from_located(expr)); if matches!(checker.autofix, fixer::Mode::Generate | fixer::Mode::Apply) { let mut generator = SourceGenerator::new(); - if let Ok(()) = generator.unparse_expr(slice, 0) { + if let Ok(()) = generator.unparse_expr(&optional(slice), 0) { if let Ok(content) = generator.generate() { check.amend(Fix { - content: format!("{} | None", content), + content, location: expr.location, end_location: expr.end_location, applied: false, @@ -33,27 +67,16 @@ pub fn use_pep604_annotation(checker: &mut Checker, expr: &Expr, value: &Expr, s // Invalid type annotation. } ExprKind::Tuple { elts, .. } => { - // Multiple arguments. - let parts: Result> = elts - .iter() - .map(|expr| { - let mut generator = SourceGenerator::new(); - generator - .unparse_expr(expr, 0) - .map_err(|_| anyhow!("Failed to parse."))?; - generator - .generate() - .map_err(|_| anyhow!("Failed to generate.")) - }) - .collect(); - if let Ok(parts) = parts { - let content = parts.join(" | "); - check.amend(Fix { - content, - location: expr.location, - end_location: expr.end_location, - applied: false, - }) + let mut generator = SourceGenerator::new(); + if let Ok(()) = generator.unparse_expr(&union(elts), 0) { + if let Ok(content) = generator.generate() { + check.amend(Fix { + content, + location: expr.location, + end_location: expr.end_location, + applied: false, + }) + } } } _ => {