Make parent non-Optional in traverse_union (#9219)

## Summary

This protects callers from having to pass in `None`, and allows the
callback to operate as if it's always a union member.
This commit is contained in:
Charlie Marsh 2023-12-21 16:10:08 -05:00 committed by GitHub
parent b0ae1199e8
commit e241c1c5df
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 64 additions and 58 deletions

View file

@ -56,7 +56,7 @@ pub(crate) fn duplicate_union_member<'a>(checker: &mut Checker, expr: &'a Expr)
let mut diagnostics: Vec<Diagnostic> = Vec::new();
// Adds a member to `literal_exprs` if it is a `Literal` annotation
let mut check_for_duplicate_members = |expr: &'a Expr, parent: Option<&'a Expr>| {
let mut check_for_duplicate_members = |expr: &'a Expr, parent: &'a Expr| {
// If we've already seen this union member, raise a violation.
if !seen_nodes.insert(expr.into()) {
let mut diagnostic = Diagnostic::new(
@ -69,7 +69,7 @@ pub(crate) fn duplicate_union_member<'a>(checker: &mut Checker, expr: &'a Expr)
// parent without the duplicate.
// If the parent node is not a `BinOp` we will not perform a fix
if let Some(parent @ Expr::BinOp(ast::ExprBinOp { left, right, .. })) = parent {
if let Expr::BinOp(ast::ExprBinOp { left, right, .. }) = parent {
// Replace the parent with its non-duplicate child.
let child = if expr == left.as_ref() { right } else { left };
diagnostic.set_fix(Fix::safe_edit(Edit::range_replacement(
@ -82,12 +82,7 @@ pub(crate) fn duplicate_union_member<'a>(checker: &mut Checker, expr: &'a Expr)
};
// Traverse the union, collect all diagnostic members
traverse_union(
&mut check_for_duplicate_members,
checker.semantic(),
expr,
None,
);
traverse_union(&mut check_for_duplicate_members, checker.semantic(), expr);
// Add all diagnostics to the checker
checker.diagnostics.append(&mut diagnostics);

View file

@ -66,7 +66,7 @@ pub(crate) fn redundant_literal_union<'a>(checker: &mut Checker, union: &'a Expr
// Adds a member to `literal_exprs` for each value in a `Literal`, and any builtin types
// to `builtin_types_in_union`.
let mut func = |expr: &'a Expr, _| {
let mut func = |expr: &'a Expr, _parent: &'a Expr| {
if let Expr::Subscript(ast::ExprSubscript { value, slice, .. }) = expr {
if checker.semantic().match_typing_expr(value, "Literal") {
if let Expr::Tuple(ast::ExprTuple { elts, .. }) = slice.as_ref() {
@ -84,7 +84,7 @@ pub(crate) fn redundant_literal_union<'a>(checker: &mut Checker, union: &'a Expr
builtin_types_in_union.insert(builtin_type);
};
traverse_union(&mut func, checker.semantic(), union, None);
traverse_union(&mut func, checker.semantic(), union);
for typing_literal_expr in typing_literal_exprs {
let Some(literal_type) = match_literal_type(typing_literal_expr) else {

View file

@ -89,7 +89,7 @@ fn check_annotation(checker: &mut Checker, annotation: &Expr) {
let mut has_complex = false;
let mut has_int = false;
let mut func = |expr: &Expr, _parent: Option<&Expr>| {
let mut func = |expr: &Expr, _parent: &Expr| {
let Some(call_path) = checker.semantic().resolve_call_path(expr) else {
return;
};
@ -102,7 +102,7 @@ fn check_annotation(checker: &mut Checker, annotation: &Expr) {
}
};
traverse_union(&mut func, checker.semantic(), annotation, None);
traverse_union(&mut func, checker.semantic(), annotation);
if has_complex {
if has_float {

View file

@ -1,4 +1,4 @@
use ast::{ExprSubscript, Operator};
use ast::Operator;
use ruff_diagnostics::{Diagnostic, Edit, Fix, FixAvailability, Violation};
use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::{self as ast, Expr};
@ -61,7 +61,7 @@ fn concatenate_bin_ors(exprs: Vec<&Expr>) -> Expr {
})
}
fn make_union(subscript: &ExprSubscript, exprs: Vec<&Expr>) -> Expr {
fn make_union(subscript: &ast::ExprSubscript, exprs: Vec<&Expr>) -> Expr {
Expr::Subscript(ast::ExprSubscript {
value: subscript.value.clone(),
slice: Box::new(Expr::Tuple(ast::ExprTuple {
@ -107,7 +107,7 @@ pub(crate) fn unnecessary_literal_union<'a>(checker: &mut Checker, expr: &'a Exp
let mut total_literals = 0;
// Split members into `literal_exprs` if they are a `Literal` annotation and `other_exprs` otherwise
let mut collect_literal_expr = |expr: &'a Expr, _| {
let mut collect_literal_expr = |expr: &'a Expr, _parent: &'a Expr| {
if let Expr::Subscript(ast::ExprSubscript { value, slice, .. }) = expr {
if checker.semantic().match_typing_expr(value, "Literal") {
total_literals += 1;
@ -136,7 +136,7 @@ pub(crate) fn unnecessary_literal_union<'a>(checker: &mut Checker, expr: &'a Exp
};
// Traverse the union, collect all members, split out the literals from the rest.
traverse_union(&mut collect_literal_expr, checker.semantic(), expr, None);
traverse_union(&mut collect_literal_expr, checker.semantic(), expr);
let union_subscript = expr.as_subscript_expr();
if union_subscript.is_some_and(|subscript| {

View file

@ -83,7 +83,7 @@ pub(crate) fn unnecessary_type_union<'a>(checker: &mut Checker, union: &'a Expr)
let mut type_exprs = Vec::new();
let mut other_exprs = Vec::new();
let mut collect_type_exprs = |expr: &'a Expr, _| {
let mut collect_type_exprs = |expr: &'a Expr, _parent: &'a Expr| {
let subscript = expr.as_subscript_expr();
if subscript.is_none() {
@ -102,7 +102,7 @@ pub(crate) fn unnecessary_type_union<'a>(checker: &mut Checker, union: &'a Expr)
}
};
traverse_union(&mut collect_type_exprs, checker.semantic(), union, None);
traverse_union(&mut collect_type_exprs, checker.semantic(), union);
if type_exprs.len() > 1 {
let type_members: Vec<String> = type_exprs

View file

@ -334,55 +334,66 @@ pub fn is_sys_version_block(stmt: &ast::StmtIf, semantic: &SemanticModel) -> boo
}
/// Traverse a "union" type annotation, applying `func` to each union member.
///
/// Supports traversal of `Union` and `|` union expressions.
/// The function is called with each expression in the union (excluding declarations of nested unions)
/// and the parent expression (if any).
pub fn traverse_union<'a, F>(
func: &mut F,
semantic: &SemanticModel,
expr: &'a Expr,
parent: Option<&'a Expr>,
) where
F: FnMut(&'a Expr, Option<&'a Expr>),
///
/// The function is called with each expression in the union (excluding declarations of nested
/// unions) and the parent expression.
pub fn traverse_union<'a, F>(func: &mut F, semantic: &SemanticModel, expr: &'a Expr)
where
F: FnMut(&'a Expr, &'a Expr),
{
// Ex) x | y
if let Expr::BinOp(ast::ExprBinOp {
op: Operator::BitOr,
left,
right,
range: _,
}) = expr
fn inner<'a, F>(
func: &mut F,
semantic: &SemanticModel,
expr: &'a Expr,
parent: Option<&'a Expr>,
) where
F: FnMut(&'a Expr, &'a Expr),
{
// The union data structure usually looks like this:
// a | b | c -> (a | b) | c
//
// However, parenthesized expressions can coerce it into any structure:
// a | (b | c)
//
// So we have to traverse both branches in order (left, then right), to report members
// in the order they appear in the source code.
// Ex) x | y
if let Expr::BinOp(ast::ExprBinOp {
op: Operator::BitOr,
left,
right,
range: _,
}) = expr
{
// The union data structure usually looks like this:
// a | b | c -> (a | b) | c
//
// However, parenthesized expressions can coerce it into any structure:
// a | (b | c)
//
// So we have to traverse both branches in order (left, then right), to report members
// in the order they appear in the source code.
// Traverse the left then right arms
traverse_union(func, semantic, left, Some(expr));
traverse_union(func, semantic, right, Some(expr));
return;
}
// Traverse the left then right arms
inner(func, semantic, left, Some(expr));
inner(func, semantic, right, Some(expr));
return;
}
// Ex) Union[x, y]
if let Expr::Subscript(ast::ExprSubscript { value, slice, .. }) = expr {
if semantic.match_typing_expr(value, "Union") {
if let Expr::Tuple(ast::ExprTuple { elts, .. }) = slice.as_ref() {
// Traverse each element of the tuple within the union recursively to handle cases
// such as `Union[..., Union[...]]
elts.iter()
.for_each(|elt| traverse_union(func, semantic, elt, Some(expr)));
return;
// Ex) Union[x, y]
if let Expr::Subscript(ast::ExprSubscript { value, slice, .. }) = expr {
if semantic.match_typing_expr(value, "Union") {
if let Expr::Tuple(ast::ExprTuple { elts, .. }) = slice.as_ref() {
// Traverse each element of the tuple within the union recursively to handle cases
// such as `Union[..., Union[...]]
elts.iter()
.for_each(|elt| inner(func, semantic, elt, Some(expr)));
return;
}
}
}
// Otherwise, call the function on expression, if it's not the top-level expression.
if let Some(parent) = parent {
func(expr, parent);
}
}
// Otherwise, call the function on expression
func(expr, parent);
inner(func, semantic, expr, None);
}
/// Abstraction for a type checker, conservatively checks for the intended type(s).