implement E307 for pylint invalid str return type (#4854)

This commit is contained in:
Ryan Yang 2023-06-05 10:54:15 -07:00 committed by GitHub
parent e6b00f0c4e
commit 72245960a1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 166 additions and 11 deletions

View file

@ -0,0 +1,28 @@
class Str:
def __str__(self):
return 1
class Float:
def __str__(self):
return 3.05
class Int:
def __str__(self):
return 0
class Bool:
def __str__(self):
return False
class Str2:
def __str__(self):
x = "ruff"
return x
# TODO fixme once Ruff has better type checking
def return_int():
return 3
class ComplexReturn:
def __str__(self):
return return_int()

View file

@ -382,6 +382,10 @@ where
}
}
if self.enabled(Rule::InvalidStrReturnType) {
pylint::rules::invalid_str_return(self, name, body);
}
if self.enabled(Rule::InvalidFunctionName) {
if let Some(diagnostic) = pep8_naming::rules::invalid_function_name(
stmt,

View file

@ -167,6 +167,7 @@ pub fn code_to_rule(linter: Linter, code: &str) -> Option<(RuleGroup, Rule)> {
(Pylint, "E0118") => (RuleGroup::Unspecified, rules::pylint::rules::LoadBeforeGlobalDeclaration),
(Pylint, "E0241") => (RuleGroup::Unspecified, rules::pylint::rules::DuplicateBases),
(Pylint, "E0302") => (RuleGroup::Unspecified, rules::pylint::rules::UnexpectedSpecialMethodSignature),
(Pylint, "E0307") => (RuleGroup::Unspecified, rules::pylint::rules::InvalidStrReturnType),
(Pylint, "E0604") => (RuleGroup::Unspecified, rules::pylint::rules::InvalidAllObject),
(Pylint, "E0605") => (RuleGroup::Unspecified, rules::pylint::rules::InvalidAllFormat),
(Pylint, "E1142") => (RuleGroup::Unspecified, rules::pylint::rules::AwaitOutsideAsync),

View file

@ -418,12 +418,14 @@ impl Violation for AnyType {
fn is_none_returning(body: &[Stmt]) -> bool {
let mut visitor = ReturnStatementVisitor::default();
visitor.visit_body(body);
for expr in visitor.returns.into_iter().flatten() {
if !matches!(
expr,
Expr::Constant(ref constant) if constant.value.is_none()
) {
return false;
for stmt in visitor.returns {
if let Some(value) = stmt.value.as_deref() {
if !matches!(
value,
Expr::Constant(constant) if constant.value.is_none()
) {
return false;
}
}
}
true

View file

@ -52,6 +52,7 @@ mod tests {
#[test_case(Rule::ImportSelf, Path::new("import_self/module.py"))]
#[test_case(Rule::InvalidAllFormat, Path::new("invalid_all_format.py"))]
#[test_case(Rule::InvalidAllObject, Path::new("invalid_all_object.py"))]
#[test_case(Rule::InvalidStrReturnType, Path::new("invalid_return_type_str.py"))]
#[test_case(Rule::DuplicateBases, Path::new("duplicate_bases.py"))]
#[test_case(Rule::DuplicateValue, Path::new("duplicate_value.py"))]
#[test_case(Rule::InvalidCharacterBackspace, Path::new("invalid_characters.py"))]

View file

@ -0,0 +1,75 @@
use rustpython_parser::ast::{self, Constant, Expr, Ranged, Stmt};
use ruff_diagnostics::{Diagnostic, Violation};
use ruff_macros::{derive_message_formats, violation};
use ruff_python_ast::{helpers::ReturnStatementVisitor, statement_visitor::StatementVisitor};
use crate::checkers::ast::Checker;
/// ## What it does
/// Checks for `__str__` implementations that return a type other than `str`.
///
/// ## Why is this bad?
/// The `__str__` method should return a `str` object. Returning a different
/// type may cause unexpected behavior.
#[violation]
pub struct InvalidStrReturnType;
impl Violation for InvalidStrReturnType {
#[derive_message_formats]
fn message(&self) -> String {
format!("`__str__` does not return `str`")
}
}
/// E0307
pub(crate) fn invalid_str_return(checker: &mut Checker, name: &str, body: &[Stmt]) {
if name != "__str__" {
return;
}
if !checker.semantic_model().scope().kind.is_class() {
return;
}
let returns = {
let mut visitor = ReturnStatementVisitor::default();
visitor.visit_body(body);
visitor.returns
};
for stmt in returns {
// Disallow implicit `None`.
let Some(value) = stmt.value.as_deref() else {
checker.diagnostics.push(Diagnostic::new(InvalidStrReturnType, stmt.range()));
continue;
};
// Disallow other constants.
if matches!(
value,
Expr::List(_)
| Expr::Dict(_)
| Expr::Set(_)
| Expr::ListComp(_)
| Expr::DictComp(_)
| Expr::SetComp(_)
| Expr::GeneratorExp(_)
| Expr::Constant(ast::ExprConstant {
value: Constant::None
| Constant::Bool(_)
| Constant::Bytes(_)
| Constant::Int(_)
| Constant::Tuple(_)
| Constant::Float(_)
| Constant::Complex { .. }
| Constant::Ellipsis,
..
})
) {
checker
.diagnostics
.push(Diagnostic::new(InvalidStrReturnType, value.range()));
}
}
}

View file

@ -17,6 +17,7 @@ pub(crate) use invalid_all_format::{invalid_all_format, InvalidAllFormat};
pub(crate) use invalid_all_object::{invalid_all_object, InvalidAllObject};
pub(crate) use invalid_envvar_default::{invalid_envvar_default, InvalidEnvvarDefault};
pub(crate) use invalid_envvar_value::{invalid_envvar_value, InvalidEnvvarValue};
pub(crate) use invalid_str_return::{invalid_str_return, InvalidStrReturnType};
pub(crate) use invalid_string_characters::{
invalid_string_characters, InvalidCharacterBackspace, InvalidCharacterEsc, InvalidCharacterNul,
InvalidCharacterSub, InvalidCharacterZeroWidthSpace,
@ -73,6 +74,7 @@ mod invalid_all_format;
mod invalid_all_object;
mod invalid_envvar_default;
mod invalid_envvar_value;
mod invalid_str_return;
mod invalid_string_characters;
mod iteration_over_set;
mod load_before_global_declaration;

View file

@ -0,0 +1,44 @@
---
source: crates/ruff/src/rules/pylint/mod.rs
---
invalid_return_type_str.py:3:16: PLE0307 `__str__` does not return `str`
|
3 | class Str:
4 | def __str__(self):
5 | return 1
| ^ PLE0307
6 |
7 | class Float:
|
invalid_return_type_str.py:7:16: PLE0307 `__str__` does not return `str`
|
7 | class Float:
8 | def __str__(self):
9 | return 3.05
| ^^^^ PLE0307
10 |
11 | class Int:
|
invalid_return_type_str.py:11:16: PLE0307 `__str__` does not return `str`
|
11 | class Int:
12 | def __str__(self):
13 | return 0
| ^ PLE0307
14 |
15 | class Bool:
|
invalid_return_type_str.py:15:16: PLE0307 `__str__` does not return `str`
|
15 | class Bool:
16 | def __str__(self):
17 | return False
| ^^^^^ PLE0307
18 |
19 | class Str2:
|

View file

@ -907,7 +907,7 @@ pub fn resolve_imported_module_path<'a>(
/// A [`StatementVisitor`] that collects all `return` statements in a function or method.
#[derive(Default)]
pub struct ReturnStatementVisitor<'a> {
pub returns: Vec<Option<&'a Expr>>,
pub returns: Vec<&'a ast::StmtReturn>,
}
impl<'a, 'b> StatementVisitor<'b> for ReturnStatementVisitor<'a>
@ -919,10 +919,7 @@ where
Stmt::FunctionDef(_) | Stmt::AsyncFunctionDef(_) => {
// Don't recurse.
}
Stmt::Return(ast::StmtReturn {
value,
range: _range,
}) => self.returns.push(value.as_deref()),
Stmt::Return(stmt) => self.returns.push(stmt),
_ => walk_stmt(self, stmt),
}
}