[flake8-simplify] Only trigger SIM401 on known dictionaries (SIM401) (#15995)

## Summary

This change resolves #15814 to ensure that `SIM401` is only triggered on
known dictionary types. Before, the rule was getting triggered even on
types that _resemble_ a dictionary but are not actually a dictionary.

I did this using the `is_known_to_be_of_type_dict(...)` functionality.
The logic for this function was duplicated in a few spots, so I moved
the code to a central location, removed redundant definitions, and
updated existing calls to use the single definition of the function!

## Test Plan

Since this PR only modifies an existing rule, I made changes to the
existing test instead of adding new ones. I made sure that `SIM401` is
triggered on types that are clearly dictionaries and that it's not
triggered on a simple custom dictionary-like type (using a modified
version of [the code in the issue](#15814))

The additional changes to de-duplicate `is_known_to_be_of_type_dict`
don't break any existing tests -- I think this should be fine since the
logic remains the same (please let me know if you think otherwise, I'm
excited to get feedback and work towards a good fix 🙂).

---------

Co-authored-by: Junhson Jean-Baptiste <junhsonjb@naan.mynetworksettings.com>
Co-authored-by: Micha Reiser <micha@reiser.io>
This commit is contained in:
Junhson Jean-Baptiste 2025-02-07 03:25:20 -05:00 committed by GitHub
parent bb979e05ac
commit 349f93389e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 174 additions and 172 deletions

View file

@ -2,6 +2,8 @@
# Positive cases
###
a_dict = {}
# SIM401 (pattern-1)
if key in a_dict:
var = a_dict[key]
@ -26,6 +28,8 @@ if keys[idx] in a_dict:
else:
var = "default"
dicts = {"key": a_dict}
# SIM401 (complex expression in dict)
if key in dicts[idx]:
var = dicts[idx][key]
@ -115,6 +119,28 @@ elif key in a_dict:
else:
vars[idx] = "default"
class NotADictionary:
def __init__(self):
self._dict = {}
def __getitem__(self, key):
return self._dict[key]
def __setitem__(self, key, value):
self._dict[key] = value
def __iter__(self):
return self._dict.__iter__()
not_dict = NotADictionary()
not_dict["key"] = "value"
# OK (type `NotADictionary` is not a known dictionary type)
if "key" in not_dict:
value = not_dict["key"]
else:
value = None
###
# Positive cases (preview)
###

View file

@ -5,7 +5,9 @@ use ruff_python_ast::helpers::contains_effect;
use ruff_python_ast::{
self as ast, Arguments, CmpOp, ElifElseClause, Expr, ExprContext, Identifier, Stmt,
};
use ruff_python_semantic::analyze::typing::{is_sys_version_block, is_type_checking_block};
use ruff_python_semantic::analyze::typing::{
is_known_to_be_of_type_dict, is_sys_version_block, is_type_checking_block,
};
use ruff_text_size::{Ranged, TextRange};
use crate::checkers::ast::Checker;
@ -113,18 +115,27 @@ pub(crate) fn if_else_block_instead_of_dict_get(checker: &mut Checker, stmt_if:
let [orelse_var] = orelse_var.as_slice() else {
return;
};
let Expr::Compare(ast::ExprCompare {
left: test_key,
ops,
comparators: test_dict,
range: _,
}) = test.as_ref()
}) = &**test
else {
return;
};
let [test_dict] = &**test_dict else {
return;
};
if !test_dict
.as_name_expr()
.is_some_and(|dict_name| is_known_to_be_of_type_dict(checker.semantic(), dict_name))
{
return;
}
let (expected_var, expected_value, default_var, default_value) = match ops[..] {
[CmpOp::In] => (body_var, body_value, orelse_var, orelse_value.as_ref()),
[CmpOp::NotIn] => (orelse_var, orelse_value, body_var, body_value.as_ref()),

View file

@ -1,199 +1,173 @@
---
source: crates/ruff_linter/src/rules/flake8_simplify/mod.rs
---
SIM401.py:6:1: SIM401 [*] Use `var = a_dict.get(key, "default1")` instead of an `if` block
SIM401.py:8:1: SIM401 [*] Use `var = a_dict.get(key, "default1")` instead of an `if` block
|
5 | # SIM401 (pattern-1)
6 | / if key in a_dict:
7 | | var = a_dict[key]
8 | | else:
9 | | var = "default1"
7 | # SIM401 (pattern-1)
8 | / if key in a_dict:
9 | | var = a_dict[key]
10 | | else:
11 | | var = "default1"
| |____________________^ SIM401
10 |
11 | # SIM401 (pattern-2)
12 |
13 | # SIM401 (pattern-2)
|
= help: Replace with `var = a_dict.get(key, "default1")`
Unsafe fix
3 3 | ###
4 4 |
5 5 | # SIM401 (pattern-1)
6 |-if key in a_dict:
7 |- var = a_dict[key]
8 |-else:
9 |- var = "default1"
6 |+var = a_dict.get(key, "default1")
10 7 |
11 8 | # SIM401 (pattern-2)
12 9 | if key not in a_dict:
5 5 | a_dict = {}
6 6 |
7 7 | # SIM401 (pattern-1)
8 |-if key in a_dict:
9 |- var = a_dict[key]
10 |-else:
11 |- var = "default1"
8 |+var = a_dict.get(key, "default1")
12 9 |
13 10 | # SIM401 (pattern-2)
14 11 | if key not in a_dict:
SIM401.py:12:1: SIM401 [*] Use `var = a_dict.get(key, "default2")` instead of an `if` block
SIM401.py:14:1: SIM401 [*] Use `var = a_dict.get(key, "default2")` instead of an `if` block
|
11 | # SIM401 (pattern-2)
12 | / if key not in a_dict:
13 | | var = "default2"
14 | | else:
15 | | var = a_dict[key]
13 | # SIM401 (pattern-2)
14 | / if key not in a_dict:
15 | | var = "default2"
16 | | else:
17 | | var = a_dict[key]
| |_____________________^ SIM401
16 |
17 | # OK (default contains effect)
18 |
19 | # OK (default contains effect)
|
= help: Replace with `var = a_dict.get(key, "default2")`
Unsafe fix
9 9 | var = "default1"
10 10 |
11 11 | # SIM401 (pattern-2)
12 |-if key not in a_dict:
13 |- var = "default2"
14 |-else:
15 |- var = a_dict[key]
12 |+var = a_dict.get(key, "default2")
16 13 |
17 14 | # OK (default contains effect)
18 15 | if key in a_dict:
11 11 | var = "default1"
12 12 |
13 13 | # SIM401 (pattern-2)
14 |-if key not in a_dict:
15 |- var = "default2"
16 |-else:
17 |- var = a_dict[key]
14 |+var = a_dict.get(key, "default2")
18 15 |
19 16 | # OK (default contains effect)
20 17 | if key in a_dict:
SIM401.py:24:1: SIM401 [*] Use `var = a_dict.get(keys[idx], "default")` instead of an `if` block
SIM401.py:26:1: SIM401 [*] Use `var = a_dict.get(keys[idx], "default")` instead of an `if` block
|
23 | # SIM401 (complex expression in key)
24 | / if keys[idx] in a_dict:
25 | | var = a_dict[keys[idx]]
26 | | else:
27 | | var = "default"
25 | # SIM401 (complex expression in key)
26 | / if keys[idx] in a_dict:
27 | | var = a_dict[keys[idx]]
28 | | else:
29 | | var = "default"
| |___________________^ SIM401
28 |
29 | # SIM401 (complex expression in dict)
30 |
31 | dicts = {"key": a_dict}
|
= help: Replace with `var = a_dict.get(keys[idx], "default")`
Unsafe fix
21 21 | var = val1 + val2
22 22 |
23 23 | # SIM401 (complex expression in key)
24 |-if keys[idx] in a_dict:
25 |- var = a_dict[keys[idx]]
26 |-else:
27 |- var = "default"
24 |+var = a_dict.get(keys[idx], "default")
28 25 |
29 26 | # SIM401 (complex expression in dict)
30 27 | if key in dicts[idx]:
23 23 | var = val1 + val2
24 24 |
25 25 | # SIM401 (complex expression in key)
26 |-if keys[idx] in a_dict:
27 |- var = a_dict[keys[idx]]
28 |-else:
29 |- var = "default"
26 |+var = a_dict.get(keys[idx], "default")
30 27 |
31 28 | dicts = {"key": a_dict}
32 29 |
SIM401.py:30:1: SIM401 [*] Use `var = dicts[idx].get(key, "default")` instead of an `if` block
SIM401.py:40:1: SIM401 [*] Use `vars[idx] = a_dict.get(key, "defaultß9💣26789ß9💣26789ß9💣26789ß9💣26789ß9💣26789")` instead of an `if` block
|
29 | # SIM401 (complex expression in dict)
30 | / if key in dicts[idx]:
31 | | var = dicts[idx][key]
32 | | else:
33 | | var = "default"
| |___________________^ SIM401
34 |
35 | # SIM401 (complex expression in var)
|
= help: Replace with `var = dicts[idx].get(key, "default")`
Unsafe fix
27 27 | var = "default"
28 28 |
29 29 | # SIM401 (complex expression in dict)
30 |-if key in dicts[idx]:
31 |- var = dicts[idx][key]
32 |-else:
33 |- var = "default"
30 |+var = dicts[idx].get(key, "default")
34 31 |
35 32 | # SIM401 (complex expression in var)
36 33 | if key in a_dict:
SIM401.py:36:1: SIM401 [*] Use `vars[idx] = a_dict.get(key, "defaultß9💣26789ß9💣26789ß9💣26789ß9💣26789ß9💣26789")` instead of an `if` block
|
35 | # SIM401 (complex expression in var)
36 | / if key in a_dict:
37 | | vars[idx] = a_dict[key]
38 | | else:
39 | | vars[idx] = "defaultß9💣26789ß9💣26789ß9💣26789ß9💣26789ß9💣26789"
39 | # SIM401 (complex expression in var)
40 | / if key in a_dict:
41 | | vars[idx] = a_dict[key]
42 | | else:
43 | | vars[idx] = "defaultß9💣26789ß9💣26789ß9💣26789ß9💣26789ß9💣26789"
| |___________________________________________________________________________^ SIM401
40 |
41 | # SIM401
44 |
45 | # SIM401
|
= help: Replace with `vars[idx] = a_dict.get(key, "defaultß9💣26789ß9💣26789ß9💣26789ß9💣26789ß9💣26789")`
Unsafe fix
33 33 | var = "default"
34 34 |
35 35 | # SIM401 (complex expression in var)
36 |-if key in a_dict:
37 |- vars[idx] = a_dict[key]
38 |-else:
39 |- vars[idx] = "defaultß9💣26789ß9💣26789ß9💣26789ß9💣26789ß9💣26789"
36 |+vars[idx] = a_dict.get(key, "defaultß9💣26789ß9💣26789ß9💣26789ß9💣26789ß9💣26789")
40 37 |
41 38 | # SIM401
42 39 | if foo():
37 37 | var = "default"
38 38 |
39 39 | # SIM401 (complex expression in var)
40 |-if key in a_dict:
41 |- vars[idx] = a_dict[key]
42 |-else:
43 |- vars[idx] = "defaultß9💣26789ß9💣26789ß9💣26789ß9💣26789ß9💣26789"
40 |+vars[idx] = a_dict.get(key, "defaultß9💣26789ß9💣26789ß9💣26789ß9💣26789ß9💣26789")
44 41 |
45 42 | # SIM401
46 43 | if foo():
SIM401.py:45:5: SIM401 [*] Use `vars[idx] = a_dict.get(key, "default")` instead of an `if` block
SIM401.py:49:5: SIM401 [*] Use `vars[idx] = a_dict.get(key, "default")` instead of an `if` block
|
43 | pass
44 | else:
45 | / if key in a_dict:
46 | | vars[idx] = a_dict[key]
47 | | else:
48 | | vars[idx] = "default"
47 | pass
48 | else:
49 | / if key in a_dict:
50 | | vars[idx] = a_dict[key]
51 | | else:
52 | | vars[idx] = "default"
| |_____________________________^ SIM401
49 |
50 | ###
53 |
54 | ###
|
= help: Replace with `vars[idx] = a_dict.get(key, "default")`
Unsafe fix
42 42 | if foo():
43 43 | pass
44 44 | else:
45 |- if key in a_dict:
46 |- vars[idx] = a_dict[key]
47 |- else:
48 |- vars[idx] = "default"
45 |+ vars[idx] = a_dict.get(key, "default")
49 46 |
50 47 | ###
51 48 | # Negative cases
46 46 | if foo():
47 47 | pass
48 48 | else:
49 |- if key in a_dict:
50 |- vars[idx] = a_dict[key]
51 |- else:
52 |- vars[idx] = "default"
49 |+ vars[idx] = a_dict.get(key, "default")
53 50 |
54 51 | ###
55 52 | # Negative cases
SIM401.py:123:7: SIM401 [*] Use `a_dict.get(key, "default3")` instead of an `if` block
SIM401.py:149:7: SIM401 [*] Use `a_dict.get(key, "default3")` instead of an `if` block
|
122 | # SIM401
123 | var = a_dict[key] if key in a_dict else "default3"
148 | # SIM401
149 | var = a_dict[key] if key in a_dict else "default3"
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ SIM401
124 |
125 | # SIM401
150 |
151 | # SIM401
|
= help: Replace with `a_dict.get(key, "default3")`
Unsafe fix
120 120 | ###
121 121 |
122 122 | # SIM401
123 |-var = a_dict[key] if key in a_dict else "default3"
123 |+var = a_dict.get(key, "default3")
124 124 |
125 125 | # SIM401
126 126 | var = "default-1" if key not in a_dict else a_dict[key]
146 146 | ###
147 147 |
148 148 | # SIM401
149 |-var = a_dict[key] if key in a_dict else "default3"
149 |+var = a_dict.get(key, "default3")
150 150 |
151 151 | # SIM401
152 152 | var = "default-1" if key not in a_dict else a_dict[key]
SIM401.py:126:7: SIM401 [*] Use `a_dict.get(key, "default-1")` instead of an `if` block
SIM401.py:152:7: SIM401 [*] Use `a_dict.get(key, "default-1")` instead of an `if` block
|
125 | # SIM401
126 | var = "default-1" if key not in a_dict else a_dict[key]
151 | # SIM401
152 | var = "default-1" if key not in a_dict else a_dict[key]
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ SIM401
127 |
128 | # OK (default contains effect)
153 |
154 | # OK (default contains effect)
|
= help: Replace with `a_dict.get(key, "default-1")`
Unsafe fix
123 123 | var = a_dict[key] if key in a_dict else "default3"
124 124 |
125 125 | # SIM401
126 |-var = "default-1" if key not in a_dict else a_dict[key]
126 |+var = a_dict.get(key, "default-1")
127 127 |
128 128 | # OK (default contains effect)
129 129 | var = a_dict[key] if key in a_dict else val1 + val2
149 149 | var = a_dict[key] if key in a_dict else "default3"
150 150 |
151 151 | # SIM401
152 |-var = "default-1" if key not in a_dict else a_dict[key]
152 |+var = a_dict.get(key, "default-1")
153 153 |
154 154 | # OK (default contains effect)
155 155 | var = a_dict[key] if key in a_dict else val1 + val2

View file

@ -2,9 +2,8 @@ use crate::checkers::ast::Checker;
use crate::fix::edits::{remove_argument, Parentheses};
use ruff_diagnostics::{AlwaysFixableViolation, Applicability, Diagnostic, Fix};
use ruff_macros::{derive_message_formats, ViolationMetadata};
use ruff_python_ast::{helpers::Truthiness, Expr, ExprAttribute, ExprName};
use ruff_python_ast::{helpers::Truthiness, Expr, ExprAttribute};
use ruff_python_semantic::analyze::typing;
use ruff_python_semantic::SemanticModel;
use ruff_text_size::Ranged;
/// ## What it does
@ -69,7 +68,7 @@ pub(crate) fn falsy_dict_get_fallback(checker: &mut Checker, expr: &Expr) {
// Check if the object is a dictionary using the semantic model
if !value
.as_name_expr()
.is_some_and(|name| is_known_to_be_of_type_dict(semantic, name))
.is_some_and(|name| typing::is_known_to_be_of_type_dict(semantic, name))
{
return;
}
@ -110,11 +109,3 @@ pub(crate) fn falsy_dict_get_fallback(checker: &mut Checker, expr: &Expr) {
checker.diagnostics.push(diagnostic);
}
fn is_known_to_be_of_type_dict(semantic: &SemanticModel, expr: &ExprName) -> bool {
let Some(binding) = semantic.only_binding(expr).map(|id| semantic.binding(id)) else {
return false;
};
typing::is_dict(binding, semantic)
}

View file

@ -3,7 +3,6 @@ use ruff_diagnostics::{AlwaysFixableViolation, Applicability, Diagnostic, Edit,
use ruff_macros::{derive_message_formats, ViolationMetadata};
use ruff_python_ast::{CmpOp, Expr, ExprName, ExprSubscript, Stmt, StmtIf};
use ruff_python_semantic::analyze::typing;
use ruff_python_semantic::SemanticModel;
type Key = Expr;
type Dict = ExprName;
@ -60,7 +59,7 @@ pub(crate) fn if_key_in_dict_del(checker: &mut Checker, stmt: &StmtIf) {
return;
}
if !is_known_to_be_of_type_dict(checker.semantic(), test_dict) {
if !typing::is_known_to_be_of_type_dict(checker.semantic(), test_dict) {
return;
}
@ -127,14 +126,6 @@ fn is_same_dict(test: &Dict, del: &Dict) -> bool {
test.id.as_str() == del.id.as_str()
}
fn is_known_to_be_of_type_dict(semantic: &SemanticModel, dict: &Dict) -> bool {
let Some(binding) = semantic.only_binding(dict).map(|id| semantic.binding(id)) else {
return false;
};
typing::is_dict(binding, semantic)
}
fn replace_with_dict_pop_fix(checker: &Checker, stmt: &StmtIf, dict: &Dict, key: &Key) -> Fix {
let locator = checker.locator();
let dict_expr = locator.slice(dict);

View file

@ -4,7 +4,8 @@ use ruff_python_ast::helpers::{any_over_expr, is_const_false, map_subscript};
use ruff_python_ast::identifier::Identifier;
use ruff_python_ast::name::QualifiedName;
use ruff_python_ast::{
self as ast, Expr, ExprCall, Int, Operator, ParameterWithDefault, Parameters, Stmt, StmtAssign,
self as ast, Expr, ExprCall, ExprName, Int, Operator, ParameterWithDefault, Parameters, Stmt,
StmtAssign,
};
use ruff_python_stdlib::typing::{
as_pep_585_generic, has_pep_585_generic, is_immutable_generic_type,
@ -46,6 +47,14 @@ pub enum SubscriptKind {
TypedDict,
}
pub fn is_known_to_be_of_type_dict(semantic: &SemanticModel, expr: &ExprName) -> bool {
let Some(binding) = semantic.only_binding(expr).map(|id| semantic.binding(id)) else {
return false;
};
is_dict(binding, semantic)
}
pub fn match_annotated_subscript<'a>(
expr: &Expr,
semantic: &SemanticModel,