Remove Result from SourceCodeGenerator signature (#1677)

We populate this buffer ourselves, so I believe it's fine for us to use
an unchecked UTF-8 cast here. It _dramatically_ simplifies so much
downstream code.
This commit is contained in:
Charlie Marsh 2023-01-05 21:41:26 -05:00 committed by GitHub
parent ee4cae97d5
commit 8caa73df6a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 274 additions and 473 deletions

View file

@ -28,6 +28,6 @@ pub fn main(cli: &Cli) -> Result<()> {
stylist.line_ending(), stylist.line_ending(),
); );
generator.unparse_suite(&python_ast); generator.unparse_suite(&python_ast);
println!("{}", generator.generate()?); println!("{}", generator.generate());
Ok(()) Ok(())
} }

View file

@ -1,4 +1,3 @@
use log::error;
use rustpython_ast::{Constant, Expr, ExprContext, ExprKind, Location, Stmt, StmtKind}; use rustpython_ast::{Constant, Expr, ExprContext, ExprKind, Location, Stmt, StmtKind};
use crate::ast::types::Range; use crate::ast::types::Range;
@ -54,16 +53,11 @@ pub fn assert_false(checker: &mut Checker, stmt: &Stmt, test: &Expr, msg: Option
checker.style.line_ending(), checker.style.line_ending(),
); );
generator.unparse_stmt(&assertion_error(msg)); generator.unparse_stmt(&assertion_error(msg));
match generator.generate() { check.amend(Fix::replacement(
Ok(content) => { generator.generate(),
check.amend(Fix::replacement( stmt.location,
content, stmt.end_location.unwrap(),
stmt.location, ));
stmt.end_location.unwrap(),
));
}
Err(e) => error!("Failed to rewrite `assert False`: {e}"),
};
} }
checker.add_check(check); checker.add_check(check);
} }

View file

@ -1,5 +1,4 @@
use itertools::Itertools; use itertools::Itertools;
use log::error;
use rustc_hash::{FxHashMap, FxHashSet}; use rustc_hash::{FxHashMap, FxHashSet};
use rustpython_ast::{Excepthandler, ExcepthandlerKind, Expr, ExprContext, ExprKind, Location}; use rustpython_ast::{Excepthandler, ExcepthandlerKind, Expr, ExprContext, ExprKind, Location};
@ -65,16 +64,11 @@ fn duplicate_handler_exceptions<'a>(
} else { } else {
generator.unparse_expr(&type_pattern(unique_elts), 0); generator.unparse_expr(&type_pattern(unique_elts), 0);
} }
match generator.generate() { check.amend(Fix::replacement(
Ok(content) => { generator.generate(),
check.amend(Fix::replacement( expr.location,
content, expr.end_location.unwrap(),
expr.location, ));
expr.end_location.unwrap(),
));
}
Err(e) => error!("Failed to remove duplicate exceptions: {e}"),
}
} }
checker.add_check(check); checker.add_check(check);
} }

View file

@ -1,4 +1,3 @@
use log::error;
use rustpython_ast::{Constant, Expr, ExprContext, ExprKind, Location}; use rustpython_ast::{Constant, Expr, ExprContext, ExprKind, Location};
use crate::ast::types::Range; use crate::ast::types::Range;
@ -53,16 +52,11 @@ pub fn getattr_with_constant(checker: &mut Checker, expr: &Expr, func: &Expr, ar
checker.style.line_ending(), checker.style.line_ending(),
); );
generator.unparse_expr(&attribute(obj, value), 0); generator.unparse_expr(&attribute(obj, value), 0);
match generator.generate() { check.amend(Fix::replacement(
Ok(content) => { generator.generate(),
check.amend(Fix::replacement( expr.location,
content, expr.end_location.unwrap(),
expr.location, ));
expr.end_location.unwrap(),
));
}
Err(e) => error!("Failed to rewrite `getattr`: {e}"),
}
} }
checker.add_check(check); checker.add_check(check);
} }

View file

@ -1,4 +1,3 @@
use log::error;
use rustpython_ast::{Excepthandler, ExcepthandlerKind, ExprKind}; use rustpython_ast::{Excepthandler, ExcepthandlerKind, ExprKind};
use crate::ast::types::Range; use crate::ast::types::Range;
@ -30,16 +29,11 @@ pub fn redundant_tuple_in_exception_handler(checker: &mut Checker, handlers: &[E
checker.style.line_ending(), checker.style.line_ending(),
); );
generator.unparse_expr(elt, 0); generator.unparse_expr(elt, 0);
match generator.generate() { check.amend(Fix::replacement(
Ok(content) => { generator.generate(),
check.amend(Fix::replacement( type_.location,
content, type_.end_location.unwrap(),
type_.location, ));
type_.end_location.unwrap(),
));
}
Err(e) => error!("Failed to remove redundant tuple: {e}"),
}
} }
checker.add_check(check); checker.add_check(check);
} }

View file

@ -1,5 +1,3 @@
use anyhow::Result;
use log::error;
use rustpython_ast::{Constant, Expr, ExprContext, ExprKind, Location, Stmt, StmtKind}; use rustpython_ast::{Constant, Expr, ExprContext, ExprKind, Location, Stmt, StmtKind};
use crate::ast::types::Range; use crate::ast::types::Range;
@ -11,12 +9,7 @@ use crate::registry::{Check, CheckKind};
use crate::source_code_generator::SourceCodeGenerator; use crate::source_code_generator::SourceCodeGenerator;
use crate::source_code_style::SourceCodeStyleDetector; use crate::source_code_style::SourceCodeStyleDetector;
fn assignment( fn assignment(obj: &Expr, name: &str, value: &Expr, stylist: &SourceCodeStyleDetector) -> String {
obj: &Expr,
name: &str,
value: &Expr,
stylist: &SourceCodeStyleDetector,
) -> Result<String> {
let stmt = Stmt::new( let stmt = Stmt::new(
Location::default(), Location::default(),
Location::default(), Location::default(),
@ -40,7 +33,7 @@ fn assignment(
stylist.line_ending(), stylist.line_ending(),
); );
generator.unparse_stmt(&stmt); generator.unparse_stmt(&stmt);
generator.generate().map_err(std::convert::Into::into) generator.generate()
} }
/// B010 /// B010
@ -73,16 +66,11 @@ pub fn setattr_with_constant(checker: &mut Checker, expr: &Expr, func: &Expr, ar
if expr == child.as_ref() { if expr == child.as_ref() {
let mut check = Check::new(CheckKind::SetAttrWithConstant, Range::from_located(expr)); let mut check = Check::new(CheckKind::SetAttrWithConstant, Range::from_located(expr));
if checker.patch(check.kind.code()) { if checker.patch(check.kind.code()) {
match assignment(obj, name, value, checker.style) { check.amend(Fix::replacement(
Ok(content) => { assignment(obj, name, value, checker.style),
check.amend(Fix::replacement( expr.location,
content, expr.end_location.unwrap(),
expr.location, ));
expr.end_location.unwrap(),
));
}
Err(e) => error!("Failed to fix invalid comparison: {e}"),
};
} }
checker.add_check(check); checker.add_check(check);
} }

View file

@ -1,4 +1,3 @@
use log::error;
use rustpython_ast::{Constant, Expr, ExprContext, ExprKind}; use rustpython_ast::{Constant, Expr, ExprContext, ExprKind};
use super::helpers::is_pytest_parametrize; use super::helpers::is_pytest_parametrize;
@ -36,7 +35,6 @@ fn elts_to_csv(elts: &[Expr], checker: &Checker) -> Option<String> {
checker.style.quote(), checker.style.quote(),
checker.style.line_ending(), checker.style.line_ending(),
); );
generator.unparse_expr( generator.unparse_expr(
&create_expr(ExprKind::Constant { &create_expr(ExprKind::Constant {
value: Constant::Str(elts.iter().fold(String::new(), |mut acc, elt| { value: Constant::Str(elts.iter().fold(String::new(), |mut acc, elt| {
@ -56,17 +54,7 @@ fn elts_to_csv(elts: &[Expr], checker: &Checker) -> Option<String> {
}), }),
0, 0,
); );
Some(generator.generate())
match generator.generate() {
Ok(s) => Some(s),
Err(e) => {
error!(
"Failed to generate CSV string from sequence of names: {}",
e
);
None
}
}
} }
/// PT006 /// PT006
@ -120,19 +108,11 @@ fn check_names(checker: &mut Checker, expr: &Expr) {
}), }),
1, 1,
); );
match generator.generate() { check.amend(Fix::replacement(
Ok(content) => { generator.generate(),
check.amend(Fix::replacement( expr.location,
content, expr.end_location.unwrap(),
expr.location, ));
expr.end_location.unwrap(),
));
}
Err(e) => error!(
"Failed to fix wrong name(s) type in \
`@pytest.mark.parametrize`: {e}"
),
};
} }
checker.add_check(check); checker.add_check(check);
} }
@ -162,19 +142,11 @@ fn check_names(checker: &mut Checker, expr: &Expr) {
}), }),
0, 0,
); );
match generator.generate() { check.amend(Fix::replacement(
Ok(content) => { generator.generate(),
check.amend(Fix::replacement( expr.location,
content, expr.end_location.unwrap(),
expr.location, ));
expr.end_location.unwrap(),
));
}
Err(e) => error!(
"Failed to fix wrong name(s) type in \
`@pytest.mark.parametrize`: {e}"
),
};
} }
checker.add_check(check); checker.add_check(check);
} }
@ -208,19 +180,11 @@ fn check_names(checker: &mut Checker, expr: &Expr) {
}), }),
0, 0,
); );
match generator.generate() { check.amend(Fix::replacement(
Ok(content) => { generator.generate(),
check.amend(Fix::replacement( expr.location,
content, expr.end_location.unwrap(),
expr.location, ));
expr.end_location.unwrap(),
));
}
Err(e) => error!(
"Failed to fix wrong name(s) type in \
`@pytest.mark.parametrize`: {e}"
),
};
} }
checker.add_check(check); checker.add_check(check);
} }
@ -269,19 +233,11 @@ fn check_names(checker: &mut Checker, expr: &Expr) {
}), }),
1, // so tuple is generated with parentheses 1, // so tuple is generated with parentheses
); );
match generator.generate() { check.amend(Fix::replacement(
Ok(content) => { generator.generate(),
check.amend(Fix::replacement( expr.location,
content, expr.end_location.unwrap(),
expr.location, ));
expr.end_location.unwrap(),
));
}
Err(e) => error!(
"Failed to fix wrong name(s) type in \
`@pytest.mark.parametrize`: {e}"
),
};
} }
checker.add_check(check); checker.add_check(check);
} }
@ -353,16 +309,11 @@ fn handle_single_name(checker: &mut Checker, expr: &Expr, value: &Expr) {
checker.style.line_ending(), checker.style.line_ending(),
); );
generator.unparse_expr(&create_expr(value.node.clone()), 0); generator.unparse_expr(&create_expr(value.node.clone()), 0);
match generator.generate() { check.amend(Fix::replacement(
Ok(content) => { generator.generate(),
check.amend(Fix::replacement( expr.location,
content, expr.end_location.unwrap(),
expr.location, ));
expr.end_location.unwrap(),
));
}
Err(e) => error!("Failed to fix wrong name(s) type in `@pytest.mark.parametrize`: {e}"),
};
} }
checker.add_check(check); checker.add_check(check);
} }

View file

@ -30,7 +30,7 @@ fn to_source(expr: &Expr, stylist: &SourceCodeStyleDetector) -> String {
stylist.line_ending(), stylist.line_ending(),
); );
generator.unparse_expr(expr, 0); generator.unparse_expr(expr, 0);
generator.generate().unwrap() generator.generate()
} }
/// SIM101 /// SIM101

View file

@ -1,4 +1,3 @@
use log::error;
use rustpython_ast::{ use rustpython_ast::{
Comprehension, Constant, Expr, ExprContext, ExprKind, Stmt, StmtKind, Unaryop, Comprehension, Constant, Expr, ExprContext, ExprKind, Stmt, StmtKind, Unaryop,
}; };
@ -83,7 +82,7 @@ fn return_stmt(
target: &Expr, target: &Expr,
iter: &Expr, iter: &Expr,
stylist: &SourceCodeStyleDetector, stylist: &SourceCodeStyleDetector,
) -> Option<String> { ) -> String {
let mut generator = SourceCodeGenerator::new( let mut generator = SourceCodeGenerator::new(
stylist.indentation(), stylist.indentation(),
stylist.quote(), stylist.quote(),
@ -107,13 +106,7 @@ fn return_stmt(
keywords: vec![], keywords: vec![],
}))), }))),
})); }));
match generator.generate() { generator.generate()
Ok(test) => Some(test),
Err(e) => {
error!("Failed to generate source code: {}", e);
None
}
}
} }
/// SIM110, SIM111 /// SIM110, SIM111
@ -121,26 +114,25 @@ pub fn convert_loop_to_any_all(checker: &mut Checker, stmt: &Stmt, sibling: &Stm
if let Some(loop_info) = return_values(stmt, sibling) { if let Some(loop_info) = return_values(stmt, sibling) {
if loop_info.return_value && !loop_info.next_return_value { if loop_info.return_value && !loop_info.next_return_value {
if checker.settings.enabled.contains(&CheckCode::SIM110) { if checker.settings.enabled.contains(&CheckCode::SIM110) {
if let Some(content) = return_stmt( let content = return_stmt(
"any", "any",
loop_info.test, loop_info.test,
loop_info.target, loop_info.target,
loop_info.iter, loop_info.iter,
checker.style, checker.style,
) { );
let mut check = Check::new( let mut check = Check::new(
CheckKind::ConvertLoopToAny(content.clone()), CheckKind::ConvertLoopToAny(content.clone()),
Range::from_located(stmt), Range::from_located(stmt),
); );
if checker.patch(&CheckCode::SIM110) { if checker.patch(&CheckCode::SIM110) {
check.amend(Fix::replacement( check.amend(Fix::replacement(
content, content,
stmt.location, stmt.location,
sibling.end_location.unwrap(), sibling.end_location.unwrap(),
)); ));
}
checker.add_check(check);
} }
checker.add_check(check);
} }
} }
@ -161,26 +153,25 @@ pub fn convert_loop_to_any_all(checker: &mut Checker, stmt: &Stmt, sibling: &Stm
}) })
} }
}; };
if let Some(content) = return_stmt( let content = return_stmt(
"all", "all",
&test, &test,
loop_info.target, loop_info.target,
loop_info.iter, loop_info.iter,
checker.style, checker.style,
) { );
let mut check = Check::new( let mut check = Check::new(
CheckKind::ConvertLoopToAll(content.clone()), CheckKind::ConvertLoopToAll(content.clone()),
Range::from_located(stmt), Range::from_located(stmt),
); );
if checker.patch(&CheckCode::SIM111) { if checker.patch(&CheckCode::SIM111) {
check.amend(Fix::replacement( check.amend(Fix::replacement(
content, content,
stmt.location, stmt.location,
sibling.end_location.unwrap(), sibling.end_location.unwrap(),
)); ));
}
checker.add_check(check);
} }
checker.add_check(check);
} }
} }
} }

View file

@ -1,6 +1,4 @@
use anyhow::Result;
use itertools::izip; use itertools::izip;
use log::error;
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use rustpython_ast::{Arguments, Location, StmtKind}; use rustpython_ast::{Arguments, Location, StmtKind};
use rustpython_parser::ast::{Cmpop, Constant, Expr, ExprKind, Stmt, Unaryop}; use rustpython_parser::ast::{Cmpop, Constant, Expr, ExprKind, Stmt, Unaryop};
@ -20,7 +18,7 @@ fn compare(
ops: &[Cmpop], ops: &[Cmpop],
comparators: &[Expr], comparators: &[Expr],
stylist: &SourceCodeStyleDetector, stylist: &SourceCodeStyleDetector,
) -> Option<String> { ) -> String {
let cmp = Expr::new( let cmp = Expr::new(
Location::default(), Location::default(),
Location::default(), Location::default(),
@ -36,7 +34,7 @@ fn compare(
stylist.line_ending(), stylist.line_ending(),
); );
generator.unparse_expr(&cmp, 0); generator.unparse_expr(&cmp, 0);
generator.generate().ok() generator.generate()
} }
/// E711, E712 /// E711, E712
@ -204,14 +202,13 @@ pub fn literal_comparisons(
.map(|(idx, op)| bad_ops.get(&idx).unwrap_or(op)) .map(|(idx, op)| bad_ops.get(&idx).unwrap_or(op))
.cloned() .cloned()
.collect::<Vec<_>>(); .collect::<Vec<_>>();
if let Some(content) = compare(left, &ops, comparators, checker.style) { let content = compare(left, &ops, comparators, checker.style);
for check in &mut checks { for check in &mut checks {
check.amend(Fix::replacement( check.amend(Fix::replacement(
content.to_string(), content.to_string(),
expr.location, expr.location,
expr.end_location.unwrap(), expr.end_location.unwrap(),
)); ));
}
} }
} }
@ -243,15 +240,11 @@ pub fn not_tests(
let mut check = let mut check =
Check::new(CheckKind::NotInTest, Range::from_located(operand)); Check::new(CheckKind::NotInTest, Range::from_located(operand));
if checker.patch(check.kind.code()) && should_fix { if checker.patch(check.kind.code()) && should_fix {
if let Some(content) = check.amend(Fix::replacement(
compare(left, &[Cmpop::NotIn], comparators, checker.style) compare(left, &[Cmpop::NotIn], comparators, checker.style),
{ expr.location,
check.amend(Fix::replacement( expr.end_location.unwrap(),
content, ));
expr.location,
expr.end_location.unwrap(),
));
}
} }
checker.add_check(check); checker.add_check(check);
} }
@ -261,15 +254,11 @@ pub fn not_tests(
let mut check = let mut check =
Check::new(CheckKind::NotIsTest, Range::from_located(operand)); Check::new(CheckKind::NotIsTest, Range::from_located(operand));
if checker.patch(check.kind.code()) && should_fix { if checker.patch(check.kind.code()) && should_fix {
if let Some(content) = check.amend(Fix::replacement(
compare(left, &[Cmpop::IsNot], comparators, checker.style) compare(left, &[Cmpop::IsNot], comparators, checker.style),
{ expr.location,
check.amend(Fix::replacement( expr.end_location.unwrap(),
content, ));
expr.location,
expr.end_location.unwrap(),
));
}
} }
checker.add_check(check); checker.add_check(check);
} }
@ -286,7 +275,7 @@ fn function(
args: &Arguments, args: &Arguments,
body: &Expr, body: &Expr,
stylist: &SourceCodeStyleDetector, stylist: &SourceCodeStyleDetector,
) -> Result<String> { ) -> String {
let body = Stmt::new( let body = Stmt::new(
Location::default(), Location::default(),
Location::default(), Location::default(),
@ -312,7 +301,7 @@ fn function(
stylist.line_ending(), stylist.line_ending(),
); );
generator.unparse_stmt(&func); generator.unparse_stmt(&func);
Ok(generator.generate()?) generator.generate()
} }
/// E731 /// E731
@ -327,31 +316,26 @@ pub fn do_not_assign_lambda(checker: &mut Checker, target: &Expr, value: &Expr,
if !match_leading_content(stmt, checker.locator) if !match_leading_content(stmt, checker.locator)
&& !match_trailing_content(stmt, checker.locator) && !match_trailing_content(stmt, checker.locator)
{ {
match function(id, args, body, checker.style) { let first_line = checker.locator.slice_source_code_range(&Range::new(
Ok(content) => { Location::new(stmt.location.row(), 0),
let first_line = checker.locator.slice_source_code_range(&Range::new( Location::new(stmt.location.row() + 1, 0),
Location::new(stmt.location.row(), 0), ));
Location::new(stmt.location.row() + 1, 0), let indentation = &leading_space(&first_line);
)); let mut indented = String::new();
let indentation = &leading_space(&first_line); for (idx, line) in function(id, args, body, checker.style).lines().enumerate() {
let mut indented = String::new(); if idx == 0 {
for (idx, line) in content.lines().enumerate() { indented.push_str(line);
if idx == 0 { } else {
indented.push_str(line); indented.push('\n');
} else { indented.push_str(indentation);
indented.push('\n'); indented.push_str(line);
indented.push_str(indentation);
indented.push_str(line);
}
}
check.amend(Fix::replacement(
indented,
stmt.location,
stmt.end_location.unwrap(),
));
} }
Err(e) => error!("Failed to generate fix: {e}"),
} }
check.amend(Fix::replacement(
indented,
stmt.location,
stmt.end_location.unwrap(),
));
} }
} }
checker.add_check(check); checker.add_check(check);

View file

@ -165,19 +165,18 @@ fn convert_to_class(
body: Vec<Stmt>, body: Vec<Stmt>,
base_class: &ExprKind, base_class: &ExprKind,
stylist: &SourceCodeStyleDetector, stylist: &SourceCodeStyleDetector,
) -> Result<Fix> { ) -> Fix {
let mut generator = SourceCodeGenerator::new( let mut generator = SourceCodeGenerator::new(
stylist.indentation(), stylist.indentation(),
stylist.quote(), stylist.quote(),
stylist.line_ending(), stylist.line_ending(),
); );
generator.unparse_stmt(&create_class_def_stmt(typename, body, base_class)); generator.unparse_stmt(&create_class_def_stmt(typename, body, base_class));
let content = generator.generate()?; Fix::replacement(
Ok(Fix::replacement( generator.generate(),
content,
stmt.location, stmt.location,
stmt.end_location.unwrap(), stmt.end_location.unwrap(),
)) )
} }
/// UP014 /// UP014
@ -200,12 +199,13 @@ pub fn convert_named_tuple_functional_to_class(
Range::from_located(stmt), Range::from_located(stmt),
); );
if checker.patch(check.kind.code()) { if checker.patch(check.kind.code()) {
match convert_to_class(stmt, typename, properties, base_class, checker.style) { check.amend(convert_to_class(
Ok(fix) => { stmt,
check.amend(fix); typename,
} properties,
Err(err) => error!("Failed to convert `NamedTuple`: {err}"), base_class,
} checker.style,
));
} }
checker.add_check(check); checker.add_check(check);
} }

View file

@ -198,7 +198,7 @@ fn convert_to_class(
total_keyword: Option<KeywordData>, total_keyword: Option<KeywordData>,
base_class: &ExprKind, base_class: &ExprKind,
stylist: &SourceCodeStyleDetector, stylist: &SourceCodeStyleDetector,
) -> Result<Fix> { ) -> Fix {
let mut generator = SourceCodeGenerator::new( let mut generator = SourceCodeGenerator::new(
stylist.indentation(), stylist.indentation(),
stylist.quote(), stylist.quote(),
@ -210,12 +210,11 @@ fn convert_to_class(
total_keyword, total_keyword,
base_class, base_class,
)); ));
let content = generator.generate()?; Fix::replacement(
Ok(Fix::replacement( generator.generate(),
content,
stmt.location, stmt.location,
stmt.end_location.unwrap(), stmt.end_location.unwrap(),
)) )
} }
/// UP013 /// UP013
@ -242,19 +241,14 @@ pub fn convert_typed_dict_functional_to_class(
Range::from_located(stmt), Range::from_located(stmt),
); );
if checker.patch(check.kind.code()) { if checker.patch(check.kind.code()) {
match convert_to_class( check.amend(convert_to_class(
stmt, stmt,
class_name, class_name,
body, body,
total_keyword, total_keyword,
base_class, base_class,
checker.style, checker.style,
) { ));
Ok(fix) => {
check.amend(fix);
}
Err(err) => error!("Failed to convert TypedDict: {err}"),
};
} }
checker.add_check(check); checker.add_check(check);
} }

View file

@ -1,5 +1,3 @@
use anyhow::{bail, Result};
use log::error;
use rustpython_ast::{Constant, Expr, ExprContext, ExprKind, Keyword, StmtKind}; use rustpython_ast::{Constant, Expr, ExprContext, ExprKind, Keyword, StmtKind};
use crate::ast::helpers::{collect_call_paths, create_expr, create_stmt, dealias_call_path}; use crate::ast::helpers::{collect_call_paths, create_expr, create_stmt, dealias_call_path};
@ -89,7 +87,7 @@ fn replace_call_on_arg_by_arg_attribute(
expr: &Expr, expr: &Expr,
patch: bool, patch: bool,
stylist: &SourceCodeStyleDetector, stylist: &SourceCodeStyleDetector,
) -> Result<Check> { ) -> Check {
let attribute = ExprKind::Attribute { let attribute = ExprKind::Attribute {
value: Box::new(arg.clone()), value: Box::new(arg.clone()),
attr: attr.to_string(), attr: attr.to_string(),
@ -105,11 +103,10 @@ fn replace_call_on_arg_by_arg_method_call(
expr: &Expr, expr: &Expr,
patch: bool, patch: bool,
stylist: &SourceCodeStyleDetector, stylist: &SourceCodeStyleDetector,
) -> Result<Option<Check>> { ) -> Option<Check> {
if args.is_empty() { if args.is_empty() {
bail!("Expected at least one argument"); None
} } else if let ([arg], other_args) = args.split_at(1) {
if let ([arg], other_args) = args.split_at(1) {
let call = ExprKind::Call { let call = ExprKind::Call {
func: Box::new(create_expr(ExprKind::Attribute { func: Box::new(create_expr(ExprKind::Attribute {
value: Box::new(arg.clone()), value: Box::new(arg.clone()),
@ -122,10 +119,9 @@ fn replace_call_on_arg_by_arg_method_call(
.collect(), .collect(),
keywords: vec![], keywords: vec![],
}; };
let expr = replace_by_expr_kind(call, expr, patch, stylist)?; Some(replace_by_expr_kind(call, expr, patch, stylist))
Ok(Some(expr))
} else { } else {
Ok(None) None
} }
} }
@ -135,7 +131,7 @@ fn replace_by_expr_kind(
expr: &Expr, expr: &Expr,
patch: bool, patch: bool,
stylist: &SourceCodeStyleDetector, stylist: &SourceCodeStyleDetector,
) -> Result<Check> { ) -> Check {
let mut check = Check::new(CheckKind::RemoveSixCompat, Range::from_located(expr)); let mut check = Check::new(CheckKind::RemoveSixCompat, Range::from_located(expr));
if patch { if patch {
let mut generator = SourceCodeGenerator::new( let mut generator = SourceCodeGenerator::new(
@ -144,14 +140,13 @@ fn replace_by_expr_kind(
stylist.line_ending(), stylist.line_ending(),
); );
generator.unparse_expr(&create_expr(node), 0); generator.unparse_expr(&create_expr(node), 0);
let content = generator.generate()?;
check.amend(Fix::replacement( check.amend(Fix::replacement(
content, generator.generate(),
expr.location, expr.location,
expr.end_location.unwrap(), expr.end_location.unwrap(),
)); ));
} }
Ok(check) check
} }
fn replace_by_stmt_kind( fn replace_by_stmt_kind(
@ -159,7 +154,7 @@ fn replace_by_stmt_kind(
expr: &Expr, expr: &Expr,
patch: bool, patch: bool,
stylist: &SourceCodeStyleDetector, stylist: &SourceCodeStyleDetector,
) -> Result<Check> { ) -> Check {
let mut check = Check::new(CheckKind::RemoveSixCompat, Range::from_located(expr)); let mut check = Check::new(CheckKind::RemoveSixCompat, Range::from_located(expr));
if patch { if patch {
let mut generator = SourceCodeGenerator::new( let mut generator = SourceCodeGenerator::new(
@ -168,14 +163,13 @@ fn replace_by_stmt_kind(
stylist.line_ending(), stylist.line_ending(),
); );
generator.unparse_stmt(&create_stmt(node)); generator.unparse_stmt(&create_stmt(node));
let content = generator.generate()?;
check.amend(Fix::replacement( check.amend(Fix::replacement(
content, generator.generate(),
expr.location, expr.location,
expr.end_location.unwrap(), expr.end_location.unwrap(),
)); ));
} }
Ok(check) check
} }
// => `raise exc from cause` // => `raise exc from cause`
@ -185,7 +179,7 @@ fn replace_by_raise_from(
expr: &Expr, expr: &Expr,
patch: bool, patch: bool,
stylist: &SourceCodeStyleDetector, stylist: &SourceCodeStyleDetector,
) -> Result<Check> { ) -> Check {
let stmt_kind = StmtKind::Raise { let stmt_kind = StmtKind::Raise {
exc: exc.map(|exc| Box::new(create_expr(exc))), exc: exc.map(|exc| Box::new(create_expr(exc))),
cause: cause.map(|cause| Box::new(create_expr(cause))), cause: cause.map(|cause| Box::new(create_expr(cause))),
@ -199,7 +193,7 @@ fn replace_by_index_on_arg(
expr: &Expr, expr: &Expr,
patch: bool, patch: bool,
stylist: &SourceCodeStyleDetector, stylist: &SourceCodeStyleDetector,
) -> Result<Check> { ) -> Check {
let index = ExprKind::Subscript { let index = ExprKind::Subscript {
value: Box::new(create_expr(arg.node.clone())), value: Box::new(create_expr(arg.node.clone())),
slice: Box::new(create_expr(index.clone())), slice: Box::new(create_expr(index.clone())),
@ -213,7 +207,7 @@ fn handle_reraise(
expr: &Expr, expr: &Expr,
patch: bool, patch: bool,
stylist: &SourceCodeStyleDetector, stylist: &SourceCodeStyleDetector,
) -> Result<Option<Check>> { ) -> Option<Check> {
if let [_, exc, tb] = args { if let [_, exc, tb] = args {
let check = replace_by_raise_from( let check = replace_by_raise_from(
Some(ExprKind::Call { Some(ExprKind::Call {
@ -229,24 +223,24 @@ fn handle_reraise(
expr, expr,
patch, patch,
stylist, stylist,
)?; );
Ok(Some(check)) Some(check)
} else if let [arg] = args { } else if let [arg] = args {
if let ExprKind::Starred { value, .. } = &arg.node { if let ExprKind::Starred { value, .. } = &arg.node {
if let ExprKind::Call { func, .. } = &value.node { if let ExprKind::Call { func, .. } = &value.node {
if let ExprKind::Attribute { value, attr, .. } = &func.node { if let ExprKind::Attribute { value, attr, .. } = &func.node {
if let ExprKind::Name { id, .. } = &value.node { if let ExprKind::Name { id, .. } = &value.node {
if id == "sys" && attr == "exc_info" { if id == "sys" && attr == "exc_info" {
let check = replace_by_raise_from(None, None, expr, patch, stylist)?; let check = replace_by_raise_from(None, None, expr, patch, stylist);
return Ok(Some(check)); return Some(check);
}; };
}; };
}; };
}; };
}; };
Ok(None) None
} else { } else {
Ok(None) None
} }
} }
@ -258,11 +252,11 @@ fn handle_func(
patch: bool, patch: bool,
stylist: &SourceCodeStyleDetector, stylist: &SourceCodeStyleDetector,
locator: &SourceCodeLocator, locator: &SourceCodeLocator,
) -> Result<Option<Check>> { ) -> Option<Check> {
let func_name = match &func.node { let func_name = match &func.node {
ExprKind::Attribute { attr, .. } => attr, ExprKind::Attribute { attr, .. } => attr,
ExprKind::Name { id, .. } => id, ExprKind::Name { id, .. } => id,
_ => bail!("Unexpected func: {:?}", func), _ => return None,
}; };
let check = match (func_name.as_str(), args, keywords) { let check = match (func_name.as_str(), args, keywords) {
("b", [arg], []) => replace_by_str_literal(arg, true, expr, patch, locator), ("b", [arg], []) => replace_by_str_literal(arg, true, expr, patch, locator),
@ -271,73 +265,67 @@ fn handle_func(
("ensure_str", [arg], []) => replace_by_str_literal(arg, false, expr, patch, locator), ("ensure_str", [arg], []) => replace_by_str_literal(arg, false, expr, patch, locator),
("ensure_text", [arg], []) => replace_by_str_literal(arg, false, expr, patch, locator), ("ensure_text", [arg], []) => replace_by_str_literal(arg, false, expr, patch, locator),
("iteritems", args, []) => { ("iteritems", args, []) => {
replace_call_on_arg_by_arg_method_call("items", args, expr, patch, stylist)? replace_call_on_arg_by_arg_method_call("items", args, expr, patch, stylist)
} }
("viewitems", args, []) => { ("viewitems", args, []) => {
replace_call_on_arg_by_arg_method_call("items", args, expr, patch, stylist)? replace_call_on_arg_by_arg_method_call("items", args, expr, patch, stylist)
} }
("iterkeys", args, []) => { ("iterkeys", args, []) => {
replace_call_on_arg_by_arg_method_call("keys", args, expr, patch, stylist)? replace_call_on_arg_by_arg_method_call("keys", args, expr, patch, stylist)
} }
("viewkeys", args, []) => { ("viewkeys", args, []) => {
replace_call_on_arg_by_arg_method_call("keys", args, expr, patch, stylist)? replace_call_on_arg_by_arg_method_call("keys", args, expr, patch, stylist)
} }
("itervalues", args, []) => { ("itervalues", args, []) => {
replace_call_on_arg_by_arg_method_call("values", args, expr, patch, stylist)? replace_call_on_arg_by_arg_method_call("values", args, expr, patch, stylist)
} }
("viewvalues", args, []) => { ("viewvalues", args, []) => {
replace_call_on_arg_by_arg_method_call("values", args, expr, patch, stylist)? replace_call_on_arg_by_arg_method_call("values", args, expr, patch, stylist)
} }
("get_method_function", [arg], []) => Some(replace_call_on_arg_by_arg_attribute( ("get_method_function", [arg], []) => Some(replace_call_on_arg_by_arg_attribute(
"__func__", arg, expr, patch, stylist, "__func__", arg, expr, patch, stylist,
)?), )),
("get_method_self", [arg], []) => Some(replace_call_on_arg_by_arg_attribute( ("get_method_self", [arg], []) => Some(replace_call_on_arg_by_arg_attribute(
"__self__", arg, expr, patch, stylist, "__self__", arg, expr, patch, stylist,
)?), )),
("get_function_closure", [arg], []) => Some(replace_call_on_arg_by_arg_attribute( ("get_function_closure", [arg], []) => Some(replace_call_on_arg_by_arg_attribute(
"__closure__", "__closure__",
arg, arg,
expr, expr,
patch, patch,
stylist, stylist,
)?), )),
("get_function_code", [arg], []) => Some(replace_call_on_arg_by_arg_attribute( ("get_function_code", [arg], []) => Some(replace_call_on_arg_by_arg_attribute(
"__code__", arg, expr, patch, stylist, "__code__", arg, expr, patch, stylist,
)?), )),
("get_function_defaults", [arg], []) => Some(replace_call_on_arg_by_arg_attribute( ("get_function_defaults", [arg], []) => Some(replace_call_on_arg_by_arg_attribute(
"__defaults__", "__defaults__",
arg, arg,
expr, expr,
patch, patch,
stylist, stylist,
)?), )),
("get_function_globals", [arg], []) => Some(replace_call_on_arg_by_arg_attribute( ("get_function_globals", [arg], []) => Some(replace_call_on_arg_by_arg_attribute(
"__globals__", "__globals__",
arg, arg,
expr, expr,
patch, patch,
stylist, stylist,
)?), )),
("create_unbound_method", [arg, _], _) => Some(replace_by_expr_kind( ("create_unbound_method", [arg, _], _) => {
arg.node.clone(), Some(replace_by_expr_kind(arg.node.clone(), expr, patch, stylist))
expr, }
patch, ("get_unbound_function", [arg], []) => {
stylist, Some(replace_by_expr_kind(arg.node.clone(), expr, patch, stylist))
)?), }
("get_unbound_function", [arg], []) => Some(replace_by_expr_kind(
arg.node.clone(),
expr,
patch,
stylist,
)?),
("assertCountEqual", args, []) => { ("assertCountEqual", args, []) => {
replace_call_on_arg_by_arg_method_call("assertCountEqual", args, expr, patch, stylist)? replace_call_on_arg_by_arg_method_call("assertCountEqual", args, expr, patch, stylist)
} }
("assertRaisesRegex", args, []) => { ("assertRaisesRegex", args, []) => {
replace_call_on_arg_by_arg_method_call("assertRaisesRegex", args, expr, patch, stylist)? replace_call_on_arg_by_arg_method_call("assertRaisesRegex", args, expr, patch, stylist)
} }
("assertRegex", args, []) => { ("assertRegex", args, []) => {
replace_call_on_arg_by_arg_method_call("assertRegex", args, expr, patch, stylist)? replace_call_on_arg_by_arg_method_call("assertRegex", args, expr, patch, stylist)
} }
("raise_from", [exc, cause], []) => Some(replace_by_raise_from( ("raise_from", [exc, cause], []) => Some(replace_by_raise_from(
Some(exc.node.clone()), Some(exc.node.clone()),
@ -345,8 +333,8 @@ fn handle_func(
expr, expr,
patch, patch,
stylist, stylist,
)?), )),
("reraise", args, []) => handle_reraise(args, expr, patch, stylist)?, ("reraise", args, []) => handle_reraise(args, expr, patch, stylist),
("byte2int", [arg], []) => Some(replace_by_index_on_arg( ("byte2int", [arg], []) => Some(replace_by_index_on_arg(
arg, arg,
&ExprKind::Constant { &ExprKind::Constant {
@ -356,14 +344,14 @@ fn handle_func(
expr, expr,
patch, patch,
stylist, stylist,
)?), )),
("indexbytes", [arg, index], []) => Some(replace_by_index_on_arg( ("indexbytes", [arg, index], []) => Some(replace_by_index_on_arg(
arg, arg,
&index.node, &index.node,
expr, expr,
patch, patch,
stylist, stylist,
)?), )),
("int2byte", [arg], []) => Some(replace_by_expr_kind( ("int2byte", [arg], []) => Some(replace_by_expr_kind(
ExprKind::Call { ExprKind::Call {
func: Box::new(create_expr(ExprKind::Name { func: Box::new(create_expr(ExprKind::Name {
@ -379,37 +367,37 @@ fn handle_func(
expr, expr,
patch, patch,
stylist, stylist,
)?), )),
_ => None, _ => None,
}; };
Ok(check) check
} }
fn handle_next_on_six_dict(expr: &Expr, patch: bool, checker: &Checker) -> Result<Option<Check>> { fn handle_next_on_six_dict(expr: &Expr, patch: bool, checker: &Checker) -> Option<Check> {
let ExprKind::Call { func, args, .. } = &expr.node else { let ExprKind::Call { func, args, .. } = &expr.node else {
return Ok(None); return None;
}; };
let ExprKind::Name { id, .. } = &func.node else { let ExprKind::Name { id, .. } = &func.node else {
return Ok(None); return None;
}; };
if id != "next" { if id != "next" {
return Ok(None); return None;
} }
let [arg] = &args[..] else { return Ok(None); }; let [arg] = &args[..] else { return None; };
let call_path = dealias_call_path(collect_call_paths(arg), &checker.import_aliases); let call_path = dealias_call_path(collect_call_paths(arg), &checker.import_aliases);
if !is_module_member(&call_path, "six") { if !is_module_member(&call_path, "six") {
return Ok(None); return None;
} }
let ExprKind::Call { func, args, .. } = &arg.node else {return Ok(None);}; let ExprKind::Call { func, args, .. } = &arg.node else {return None;};
let ExprKind::Attribute { attr, .. } = &func.node else {return Ok(None);}; let ExprKind::Attribute { attr, .. } = &func.node else {return None;};
let [dict_arg] = &args[..] else {return Ok(None);}; let [dict_arg] = &args[..] else {return None;};
let method_name = match attr.as_str() { let method_name = match attr.as_str() {
"iteritems" => "items", "iteritems" => "items",
"iterkeys" => "keys", "iterkeys" => "keys",
"itervalues" => "values", "itervalues" => "values",
_ => return Ok(None), _ => return None,
}; };
match replace_by_expr_kind( Some(replace_by_expr_kind(
ExprKind::Call { ExprKind::Call {
func: Box::new(create_expr(ExprKind::Name { func: Box::new(create_expr(ExprKind::Name {
id: "iter".to_string(), id: "iter".to_string(),
@ -429,25 +417,16 @@ fn handle_next_on_six_dict(expr: &Expr, patch: bool, checker: &Checker) -> Resul
arg, arg,
patch, patch,
checker.style, checker.style,
) { ))
Ok(check) => Ok(Some(check)),
Err(err) => Err(err),
}
} }
/// UP016 /// UP016
pub fn remove_six_compat(checker: &mut Checker, expr: &Expr) { pub fn remove_six_compat(checker: &mut Checker, expr: &Expr) {
match handle_next_on_six_dict(expr, checker.patch(&CheckCode::UP016), checker) { if let Some(check) = handle_next_on_six_dict(expr, checker.patch(&CheckCode::UP016), checker) {
Ok(Some(check)) => { checker.add_check(check);
checker.add_check(check); return;
return; }
}
Ok(None) => (),
Err(err) => {
error!("Error while removing `six` reference: {}", err);
return;
}
};
let call_path = dealias_call_path(collect_call_paths(expr), &checker.import_aliases); let call_path = dealias_call_path(collect_call_paths(expr), &checker.import_aliases);
if is_module_member(&call_path, "six") { if is_module_member(&call_path, "six") {
let patch = checker.patch(&CheckCode::UP016); let patch = checker.patch(&CheckCode::UP016);
@ -456,7 +435,7 @@ pub fn remove_six_compat(checker: &mut Checker, expr: &Expr) {
func, func,
args, args,
keywords, keywords,
} => match handle_func( } => handle_func(
func, func,
args, args,
keywords, keywords,
@ -464,13 +443,7 @@ pub fn remove_six_compat(checker: &mut Checker, expr: &Expr) {
patch, patch,
checker.style, checker.style,
checker.locator, checker.locator,
) { ),
Ok(check) => check,
Err(err) => {
error!("Failed to remove `six` reference: {err}");
return;
}
},
ExprKind::Attribute { attr, .. } => map_name(attr.as_str(), expr, patch), ExprKind::Attribute { attr, .. } => map_name(attr.as_str(), expr, patch),
ExprKind::Name { id, .. } => map_name(id.as_str(), expr, patch), ExprKind::Name { id, .. } => map_name(id.as_str(), expr, patch),
_ => return, _ => return,

View file

@ -1,4 +1,3 @@
use log::error;
use rustpython_ast::{Constant, Expr, ExprKind, Location, Operator}; use rustpython_ast::{Constant, Expr, ExprKind, Location, Operator};
use crate::ast::helpers::{collect_call_paths, dealias_call_path}; use crate::ast::helpers::{collect_call_paths, dealias_call_path};
@ -72,16 +71,11 @@ pub fn use_pep604_annotation(checker: &mut Checker, expr: &Expr, value: &Expr, s
checker.style.line_ending(), checker.style.line_ending(),
); );
generator.unparse_expr(&optional(slice), 0); generator.unparse_expr(&optional(slice), 0);
match generator.generate() { check.amend(Fix::replacement(
Ok(content) => { generator.generate(),
check.amend(Fix::replacement( expr.location,
content, expr.end_location.unwrap(),
expr.location, ));
expr.end_location.unwrap(),
));
}
Err(e) => error!("Failed to rewrite PEP604 annotation: {e}"),
};
} }
checker.add_check(check); checker.add_check(check);
} else if checker.match_typing_call_path(&call_path, "Union") { } else if checker.match_typing_call_path(&call_path, "Union") {
@ -98,16 +92,11 @@ pub fn use_pep604_annotation(checker: &mut Checker, expr: &Expr, value: &Expr, s
checker.style.line_ending(), checker.style.line_ending(),
); );
generator.unparse_expr(&union(elts), 0); generator.unparse_expr(&union(elts), 0);
match generator.generate() { check.amend(Fix::replacement(
Ok(content) => { generator.generate(),
check.amend(Fix::replacement( expr.location,
content, expr.end_location.unwrap(),
expr.location, ));
expr.end_location.unwrap(),
));
}
Err(e) => error!("Failed to rewrite PEP604 annotation: {e}"),
}
} }
_ => { _ => {
// Single argument. // Single argument.
@ -117,16 +106,11 @@ pub fn use_pep604_annotation(checker: &mut Checker, expr: &Expr, value: &Expr, s
checker.style.line_ending(), checker.style.line_ending(),
); );
generator.unparse_expr(slice, 0); generator.unparse_expr(slice, 0);
match generator.generate() { check.amend(Fix::replacement(
Ok(content) => { generator.generate(),
check.amend(Fix::replacement( expr.location,
content, expr.end_location.unwrap(),
expr.location, ));
expr.end_location.unwrap(),
));
}
Err(e) => error!("Failed to rewrite PEP604 annotation: {e}"),
}
} }
} }
} }

View file

@ -182,64 +182,34 @@ impl Settings {
// Plugins // Plugins
flake8_annotations: config flake8_annotations: config
.flake8_annotations .flake8_annotations
.map(std::convert::Into::into) .map(Into::into)
.unwrap_or_default(),
flake8_bandit: config
.flake8_bandit
.map(std::convert::Into::into)
.unwrap_or_default(),
flake8_bugbear: config
.flake8_bugbear
.map(std::convert::Into::into)
.unwrap_or_default(),
flake8_errmsg: config
.flake8_errmsg
.map(std::convert::Into::into)
.unwrap_or_default(), .unwrap_or_default(),
flake8_bandit: config.flake8_bandit.map(Into::into).unwrap_or_default(),
flake8_bugbear: config.flake8_bugbear.map(Into::into).unwrap_or_default(),
flake8_errmsg: config.flake8_errmsg.map(Into::into).unwrap_or_default(),
flake8_import_conventions: config flake8_import_conventions: config
.flake8_import_conventions .flake8_import_conventions
.map(std::convert::Into::into) .map(Into::into)
.unwrap_or_default(), .unwrap_or_default(),
flake8_pytest_style: config flake8_pytest_style: config
.flake8_pytest_style .flake8_pytest_style
.map(std::convert::Into::into) .map(Into::into)
.unwrap_or_default(),
flake8_quotes: config
.flake8_quotes
.map(std::convert::Into::into)
.unwrap_or_default(), .unwrap_or_default(),
flake8_quotes: config.flake8_quotes.map(Into::into).unwrap_or_default(),
flake8_tidy_imports: config flake8_tidy_imports: config
.flake8_tidy_imports .flake8_tidy_imports
.map(std::convert::Into::into) .map(Into::into)
.unwrap_or_default(), .unwrap_or_default(),
flake8_unused_arguments: config flake8_unused_arguments: config
.flake8_unused_arguments .flake8_unused_arguments
.map(std::convert::Into::into) .map(Into::into)
.unwrap_or_default(),
isort: config
.isort
.map(std::convert::Into::into)
.unwrap_or_default(),
mccabe: config
.mccabe
.map(std::convert::Into::into)
.unwrap_or_default(),
pep8_naming: config
.pep8_naming
.map(std::convert::Into::into)
.unwrap_or_default(),
pycodestyle: config
.pycodestyle
.map(std::convert::Into::into)
.unwrap_or_default(),
pydocstyle: config
.pydocstyle
.map(std::convert::Into::into)
.unwrap_or_default(),
pyupgrade: config
.pyupgrade
.map(std::convert::Into::into)
.unwrap_or_default(), .unwrap_or_default(),
isort: config.isort.map(Into::into).unwrap_or_default(),
mccabe: config.mccabe.map(Into::into).unwrap_or_default(),
pep8_naming: config.pep8_naming.map(Into::into).unwrap_or_default(),
pycodestyle: config.pycodestyle.map(Into::into).unwrap_or_default(),
pydocstyle: config.pydocstyle.map(Into::into).unwrap_or_default(),
pyupgrade: config.pyupgrade.map(Into::into).unwrap_or_default(),
}) })
} }
@ -393,7 +363,7 @@ pub fn resolve_globset(patterns: Vec<FilePattern>) -> Result<GlobSet> {
for pattern in patterns { for pattern in patterns {
pattern.add_to(&mut builder)?; pattern.add_to(&mut builder)?;
} }
builder.build().map_err(std::convert::Into::into) builder.build().map_err(Into::into)
} }
/// Given a list of patterns, create a `GlobSet`. /// Given a list of patterns, create a `GlobSet`.

View file

@ -31,13 +31,13 @@ impl Pyproject {
/// Parse a `ruff.toml` file. /// Parse a `ruff.toml` file.
fn parse_ruff_toml<P: AsRef<Path>>(path: P) -> Result<Options> { fn parse_ruff_toml<P: AsRef<Path>>(path: P) -> Result<Options> {
let contents = fs::read_file(path)?; let contents = fs::read_file(path)?;
toml::from_str(&contents).map_err(std::convert::Into::into) toml::from_str(&contents).map_err(Into::into)
} }
/// Parse a `pyproject.toml` file. /// Parse a `pyproject.toml` file.
fn parse_pyproject_toml<P: AsRef<Path>>(path: P) -> Result<Pyproject> { fn parse_pyproject_toml<P: AsRef<Path>>(path: P) -> Result<Pyproject> {
let contents = fs::read_file(path)?; let contents = fs::read_file(path)?;
toml::from_str(&contents).map_err(std::convert::Into::into) toml::from_str(&contents).map_err(Into::into)
} }
/// Return `true` if a `pyproject.toml` contains a `[tool.ruff]` section. /// Return `true` if a `pyproject.toml` contains a `[tool.ruff]` section.

View file

@ -2,9 +2,7 @@
use std::fmt; use std::fmt;
use std::ops::Deref; use std::ops::Deref;
use std::string::FromUtf8Error;
use anyhow::Result;
use rustpython_ast::{Excepthandler, ExcepthandlerKind, Suite, Withitem}; use rustpython_ast::{Excepthandler, ExcepthandlerKind, Suite, Withitem};
use rustpython_parser::ast::{ use rustpython_parser::ast::{
Alias, Arg, Arguments, Boolop, Cmpop, Comprehension, Constant, ConversionFlag, Expr, ExprKind, Alias, Arg, Arguments, Boolop, Cmpop, Comprehension, Constant, ConversionFlag, Expr, ExprKind,
@ -60,8 +58,8 @@ impl<'a> SourceCodeGenerator<'a> {
} }
} }
pub fn generate(self) -> Result<String, FromUtf8Error> { pub fn generate(self) -> String {
String::from_utf8(self.buffer) String::from_utf8(self.buffer).expect("Generated source code is not valid UTF-8")
} }
fn newline(&mut self) { fn newline(&mut self) {
@ -1030,21 +1028,20 @@ impl<'a> SourceCodeGenerator<'a> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use anyhow::Result;
use rustpython_parser::parser; use rustpython_parser::parser;
use crate::source_code_generator::SourceCodeGenerator; use crate::source_code_generator::SourceCodeGenerator;
use crate::source_code_style::{Indentation, LineEnding, Quote}; use crate::source_code_style::{Indentation, LineEnding, Quote};
fn round_trip(contents: &str) -> Result<String> { fn round_trip(contents: &str) -> String {
let indentation = Indentation::default(); let indentation = Indentation::default();
let quote = Quote::default(); let quote = Quote::default();
let line_ending = LineEnding::default(); let line_ending = LineEnding::default();
let program = parser::parse_program(contents, "<filename>")?; let program = parser::parse_program(contents, "<filename>").unwrap();
let stmt = program.first().unwrap(); let stmt = program.first().unwrap();
let mut generator = SourceCodeGenerator::new(&indentation, &quote, &line_ending); let mut generator = SourceCodeGenerator::new(&indentation, &quote, &line_ending);
generator.unparse_stmt(stmt); generator.unparse_stmt(stmt);
generator.generate().map_err(std::convert::Into::into) generator.generate()
} }
fn round_trip_with( fn round_trip_with(
@ -1052,30 +1049,29 @@ mod tests {
quote: &Quote, quote: &Quote,
line_ending: &LineEnding, line_ending: &LineEnding,
contents: &str, contents: &str,
) -> Result<String> { ) -> String {
let program = parser::parse_program(contents, "<filename>")?; let program = parser::parse_program(contents, "<filename>").unwrap();
let stmt = program.first().unwrap(); let stmt = program.first().unwrap();
let mut generator = SourceCodeGenerator::new(indentation, quote, line_ending); let mut generator = SourceCodeGenerator::new(indentation, quote, line_ending);
generator.unparse_stmt(stmt); generator.unparse_stmt(stmt);
generator.generate().map_err(std::convert::Into::into) generator.generate()
} }
#[test] #[test]
fn quote() -> Result<()> { fn quote() {
assert_eq!(round_trip(r#""hello""#)?, r#""hello""#); assert_eq!(round_trip(r#""hello""#), r#""hello""#);
assert_eq!(round_trip(r#"'hello'"#)?, r#""hello""#); assert_eq!(round_trip(r#"'hello'"#), r#""hello""#);
assert_eq!(round_trip(r#"u'hello'"#)?, r#"u"hello""#); assert_eq!(round_trip(r#"u'hello'"#), r#"u"hello""#);
assert_eq!(round_trip(r#"r'hello'"#)?, r#""hello""#); assert_eq!(round_trip(r#"r'hello'"#), r#""hello""#);
assert_eq!(round_trip(r#"b'hello'"#)?, r#"b"hello""#); assert_eq!(round_trip(r#"b'hello'"#), r#"b"hello""#);
assert_eq!(round_trip(r#"("abc" "def" "ghi")"#)?, r#""abcdefghi""#); assert_eq!(round_trip(r#"("abc" "def" "ghi")"#), r#""abcdefghi""#);
assert_eq!(round_trip(r#""he\"llo""#)?, r#"'he"llo'"#); assert_eq!(round_trip(r#""he\"llo""#), r#"'he"llo'"#);
assert_eq!(round_trip(r#"f'abc{"def"}{1}'"#)?, r#"f'abc{"def"}{1}'"#); assert_eq!(round_trip(r#"f'abc{"def"}{1}'"#), r#"f'abc{"def"}{1}'"#);
assert_eq!(round_trip(r#"f"abc{'def'}{1}""#)?, r#"f'abc{"def"}{1}'"#); assert_eq!(round_trip(r#"f"abc{'def'}{1}""#), r#"f'abc{"def"}{1}'"#);
Ok(())
} }
#[test] #[test]
fn indent() -> Result<()> { fn indent() {
assert_eq!( assert_eq!(
round_trip( round_trip(
r#" r#"
@ -1083,25 +1079,24 @@ if True:
pass pass
"# "#
.trim(), .trim(),
)?, ),
r#" r#"
if True: if True:
pass pass
"# "#
.trim() .trim()
); );
Ok(())
} }
#[test] #[test]
fn set_quote() -> Result<()> { fn set_quote() {
assert_eq!( assert_eq!(
round_trip_with( round_trip_with(
&Indentation::default(), &Indentation::default(),
&Quote::Double, &Quote::Double,
&LineEnding::default(), &LineEnding::default(),
r#""hello""# r#""hello""#
)?, ),
r#""hello""# r#""hello""#
); );
assert_eq!( assert_eq!(
@ -1110,7 +1105,7 @@ if True:
&Quote::Single, &Quote::Single,
&LineEnding::default(), &LineEnding::default(),
r#""hello""# r#""hello""#
)?, ),
r#"'hello'"# r#"'hello'"#
); );
assert_eq!( assert_eq!(
@ -1119,7 +1114,7 @@ if True:
&Quote::Double, &Quote::Double,
&LineEnding::default(), &LineEnding::default(),
r#"'hello'"# r#"'hello'"#
)?, ),
r#""hello""# r#""hello""#
); );
assert_eq!( assert_eq!(
@ -1128,14 +1123,13 @@ if True:
&Quote::Single, &Quote::Single,
&LineEnding::default(), &LineEnding::default(),
r#"'hello'"# r#"'hello'"#
)?, ),
r#"'hello'"# r#"'hello'"#
); );
Ok(())
} }
#[test] #[test]
fn set_indent() -> Result<()> { fn set_indent() {
assert_eq!( assert_eq!(
round_trip_with( round_trip_with(
&Indentation::new(" ".to_string()), &Indentation::new(" ".to_string()),
@ -1146,7 +1140,7 @@ if True:
pass pass
"# "#
.trim(), .trim(),
)?, ),
r#" r#"
if True: if True:
pass pass
@ -1163,7 +1157,7 @@ if True:
pass pass
"# "#
.trim(), .trim(),
)?, ),
r#" r#"
if True: if True:
pass pass
@ -1180,26 +1174,24 @@ if True:
pass pass
"# "#
.trim(), .trim(),
)?, ),
r#" r#"
if True: if True:
pass pass
"# "#
.trim() .trim()
); );
Ok(())
} }
#[test] #[test]
fn set_line_ending() -> Result<()> { fn set_line_ending() {
assert_eq!( assert_eq!(
round_trip_with( round_trip_with(
&Indentation::default(), &Indentation::default(),
&Quote::default(), &Quote::default(),
&LineEnding::Lf, &LineEnding::Lf,
"if True:\n print(42)", "if True:\n print(42)",
)?, ),
"if True:\n print(42)", "if True:\n print(42)",
); );
@ -1209,7 +1201,7 @@ if True:
&Quote::default(), &Quote::default(),
&LineEnding::CrLf, &LineEnding::CrLf,
"if True:\n print(42)", "if True:\n print(42)",
)?, ),
"if True:\r\n print(42)", "if True:\r\n print(42)",
); );
@ -1219,10 +1211,8 @@ if True:
&Quote::default(), &Quote::default(),
&LineEnding::Cr, &LineEnding::Cr,
"if True:\n print(42)", "if True:\n print(42)",
)?, ),
"if True:\r print(42)", "if True:\r print(42)",
); );
Ok(())
} }
} }