diff --git a/Cargo.lock b/Cargo.lock index 4977fc561b..3601b9a10b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2272,6 +2272,7 @@ dependencies = [ "clap 4.0.15", "codegen", "itertools", + "libcst", "ruff", "rustpython-ast", "rustpython-common", diff --git a/resources/test/fixtures/C413.py b/resources/test/fixtures/C413.py index 27736b9836..73ffd315ac 100644 --- a/resources/test/fixtures/C413.py +++ b/resources/test/fixtures/C413.py @@ -1,4 +1,6 @@ x = [2, 3, 1] +list(x) list(sorted(x)) reversed(sorted(x)) +reversed(sorted(x, key=lambda e: e)) reversed(sorted(x, reverse=True)) diff --git a/ruff_dev/Cargo.toml b/ruff_dev/Cargo.toml index b5e7cdd6e6..7b7365639f 100644 --- a/ruff_dev/Cargo.toml +++ b/ruff_dev/Cargo.toml @@ -8,6 +8,7 @@ anyhow = { version = "1.0.60" } clap = { version = "4.0.1", features = ["derive"] } codegen = { version = "0.2.0" } itertools = { version = "0.10.5" } +libcst = { git = "https://github.com/charliermarsh/LibCST", rev = "32a044c127668df44582f85699358e67803b0d73" } ruff = { path = ".." } rustpython-ast = { features = ["unparse"], git = "https://github.com/RustPython/RustPython.git", rev = "27bf82a2251d7e6ac6cd75e6ad51be12a53d84bb" } rustpython-common = { git = "https://github.com/RustPython/RustPython.git", rev = "27bf82a2251d7e6ac6cd75e6ad51be12a53d84bb" } diff --git a/ruff_dev/src/lib.rs b/ruff_dev/src/lib.rs index 46c90ce233..da10cfc7d0 100644 --- a/ruff_dev/src/lib.rs +++ b/ruff_dev/src/lib.rs @@ -2,4 +2,5 @@ pub mod generate_check_code_prefix; pub mod generate_rules_table; pub mod generate_source_code; pub mod print_ast; +pub mod print_cst; pub mod print_tokens; diff --git a/ruff_dev/src/main.rs b/ruff_dev/src/main.rs index 146bd00444..630b0a6dd7 100644 --- a/ruff_dev/src/main.rs +++ b/ruff_dev/src/main.rs @@ -1,7 +1,8 @@ use anyhow::Result; use clap::{Parser, Subcommand}; use ruff_dev::{ - generate_check_code_prefix, generate_rules_table, generate_source_code, print_ast, print_tokens, + generate_check_code_prefix, generate_rules_table, generate_source_code, print_ast, print_cst, + print_tokens, }; #[derive(Parser)] @@ -22,6 +23,8 @@ enum Commands { GenerateSourceCode(generate_source_code::Cli), /// Print the AST for a given Python file. PrintAST(print_ast::Cli), + /// Print the LibCST CST for a given Python file. + PrintCST(print_cst::Cli), /// Print the token stream for a given Python file. PrintTokens(print_tokens::Cli), } @@ -33,6 +36,7 @@ fn main() -> Result<()> { Commands::GenerateRulesTable(args) => generate_rules_table::main(args)?, Commands::GenerateSourceCode(args) => generate_source_code::main(args)?, Commands::PrintAST(args) => print_ast::main(args)?, + Commands::PrintCST(args) => print_cst::main(args)?, Commands::PrintTokens(args) => print_tokens::main(args)?, } Ok(()) diff --git a/ruff_dev/src/print_cst.rs b/ruff_dev/src/print_cst.rs new file mode 100644 index 0000000000..d22f1685b7 --- /dev/null +++ b/ruff_dev/src/print_cst.rs @@ -0,0 +1,25 @@ +//! Print the LibCST CST for a given Python file. + +use std::fs; +use std::path::PathBuf; + +use anyhow::Result; +use clap::Args; + +#[derive(Args)] +pub struct Cli { + /// Python file for which to generate the CST. + #[arg(required = true)] + file: PathBuf, +} + +pub fn main(cli: &Cli) -> Result<()> { + let contents = fs::read_to_string(&cli.file)?; + match libcst_native::parse_module(&contents, None) { + Ok(python_cst) => { + println!("{:#?}", python_cst); + Ok(()) + } + Err(_) => Err(anyhow::anyhow!("Failed to parse CST")), + } +} diff --git a/src/check_ast.rs b/src/check_ast.rs index ff175e3bee..6377dcb86b 100644 --- a/src/check_ast.rs +++ b/src/check_ast.rs @@ -1220,8 +1220,11 @@ where if self.settings.enabled.contains(&CheckCode::C413) { if let Some(check) = flake8_comprehensions::checks::unnecessary_call_around_sorted( + expr, func, args, + self.locator, + self.patch(), Range::from_located(expr), ) { diff --git a/src/flake8_comprehensions/checks.rs b/src/flake8_comprehensions/checks.rs index 9defdb380e..86c26bdbdf 100644 --- a/src/flake8_comprehensions/checks.rs +++ b/src/flake8_comprehensions/checks.rs @@ -355,8 +355,11 @@ pub fn unnecessary_list_call( /// C413 pub fn unnecessary_call_around_sorted( + expr: &Expr, func: &Expr, args: &[Expr], + locator: &SourceCodeLocator, + fix: bool, location: Range, ) -> Option { let outer = function_name(func)?; @@ -365,10 +368,17 @@ pub fn unnecessary_call_around_sorted( } if let ExprKind::Call { func, .. } = &args.first()?.node { if function_name(func)? == "sorted" { - return Some(Check::new( + let mut check = Check::new( CheckKind::UnnecessaryCallAroundSorted(outer.to_string()), location, - )); + ); + if fix { + match fixes::fix_unnecessary_call_around_sorted(locator, expr) { + Ok(fix) => check.amend(fix), + Err(e) => error!("Failed to generate fix: {}", e), + } + } + return Some(check); } } None diff --git a/src/flake8_comprehensions/fixes.rs b/src/flake8_comprehensions/fixes.rs index b13643c331..3dd9c7cb3d 100644 --- a/src/flake8_comprehensions/fixes.rs +++ b/src/flake8_comprehensions/fixes.rs @@ -1,8 +1,9 @@ use anyhow::Result; use libcst_native::{ - Arg, Call, Codegen, Dict, DictComp, DictElement, Element, Expr, Expression, LeftCurlyBrace, - LeftParen, LeftSquareBracket, List, ListComp, Name, ParenthesizableWhitespace, RightCurlyBrace, - RightParen, RightSquareBracket, Set, SetComp, SimpleString, SimpleWhitespace, Tuple, + Arg, AssignEqual, Call, Codegen, Dict, DictComp, DictElement, Element, Expr, Expression, + LeftCurlyBrace, LeftParen, LeftSquareBracket, List, ListComp, Name, ParenthesizableWhitespace, + RightCurlyBrace, RightParen, RightSquareBracket, Set, SetComp, SimpleString, SimpleWhitespace, + Tuple, }; use crate::ast::types::Range; @@ -649,6 +650,92 @@ pub fn fix_unnecessary_list_call( )) } +/// (C413) Convert `list(sorted([2, 3, 1]))` to `sorted([2, 3, 1])`. +/// (C413) Convert `reversed(sorted([2, 3, 1]))` to `sorted([2, 3, 1], +/// reverse=True)`. +pub fn fix_unnecessary_call_around_sorted( + locator: &SourceCodeLocator, + expr: &rustpython_ast::Expr, +) -> Result { + let module_text = locator.slice_source_code_range(&Range::from_located(expr)); + let mut tree = match_module(&module_text)?; + let mut body = match_expr(&mut tree)?; + let outer_call = match_call(body)?; + let inner_call = match &outer_call.args[..] { + [arg] => { + if let Expression::Call(call) = &arg.value { + call + } else { + return Err(anyhow::anyhow!("Expected node to be: Expression::Call ")); + } + } + _ => { + return Err(anyhow::anyhow!( + "Expected one argument in outer function call" + )) + } + }; + + if let Expression::Name(outer_name) = &*outer_call.func { + if outer_name.value == "list" { + body.value = Expression::Call(inner_call.clone()); + } else { + let args = if inner_call.args.iter().any(|arg| { + matches!( + arg.keyword, + Some(Name { + value: "reverse", + .. + }) + ) + }) { + inner_call.args.clone() + } else { + let mut args = inner_call.args.clone(); + args.push(Arg { + value: Expression::Name(Box::new(Name { + value: "True", + lpar: Default::default(), + rpar: Default::default(), + })), + keyword: Some(Name { + value: "reverse", + lpar: Default::default(), + rpar: Default::default(), + }), + equal: Some(AssignEqual { + whitespace_before: Default::default(), + whitespace_after: Default::default(), + }), + comma: Default::default(), + star: Default::default(), + whitespace_after_star: Default::default(), + whitespace_after_arg: Default::default(), + }); + args + }; + + body.value = Expression::Call(Box::new(Call { + func: inner_call.func.clone(), + args, + lpar: inner_call.lpar.clone(), + rpar: inner_call.rpar.clone(), + whitespace_after_func: inner_call.whitespace_after_func.clone(), + whitespace_before_args: inner_call.whitespace_before_args.clone(), + })) + } + } + + let mut state = Default::default(); + tree.codegen(&mut state); + + Ok(Fix::replacement( + state.to_string(), + expr.location, + expr.end_location.unwrap(), + )) +} + /// (C416) Convert `[i for i in x]` to `list(x)`. pub fn fix_unnecessary_comprehension( locator: &SourceCodeLocator, diff --git a/src/snapshots/ruff__linter__tests__C413_C413.py.snap b/src/snapshots/ruff__linter__tests__C413_C413.py.snap index 50ca6c787c..df802b53e7 100644 --- a/src/snapshots/ruff__linter__tests__C413_C413.py.snap +++ b/src/snapshots/ruff__linter__tests__C413_C413.py.snap @@ -5,28 +5,73 @@ expression: checks - kind: UnnecessaryCallAroundSorted: list location: - row: 2 + row: 3 column: 0 end_location: - row: 2 + row: 3 column: 15 - fix: ~ + fix: + patch: + content: sorted(x) + location: + row: 3 + column: 0 + end_location: + row: 3 + column: 15 + applied: false - kind: UnnecessaryCallAroundSorted: reversed location: - row: 3 + row: 4 column: 0 end_location: - row: 3 + row: 4 column: 19 - fix: ~ + fix: + patch: + content: "sorted(x, reverse=True)" + location: + row: 4 + column: 0 + end_location: + row: 4 + column: 19 + applied: false - kind: UnnecessaryCallAroundSorted: reversed location: - row: 4 + row: 5 column: 0 end_location: - row: 4 + row: 5 + column: 36 + fix: + patch: + content: "sorted(x, key=lambda e: e, reverse=True)" + location: + row: 5 + column: 0 + end_location: + row: 5 + column: 36 + applied: false +- kind: + UnnecessaryCallAroundSorted: reversed + location: + row: 6 + column: 0 + end_location: + row: 6 column: 33 - fix: ~ + fix: + patch: + content: "sorted(x, reverse=True)" + location: + row: 6 + column: 0 + end_location: + row: 6 + column: 33 + applied: false