mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-27 20:42:10 +00:00
Respect insertion location when importing symbols (#4258)
This commit is contained in:
parent
a95bafefb0
commit
e66fdb83d0
14 changed files with 134 additions and 40 deletions
5
crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_7.py
vendored
Normal file
5
crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_7.py
vendored
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
def main():
|
||||||
|
exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
import functools
|
5
crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_8.py
vendored
Normal file
5
crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_8.py
vendored
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
from sys import argv
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
exit(0)
|
5
crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_9.py
vendored
Normal file
5
crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_9.py
vendored
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
def main():
|
||||||
|
exit(0)
|
||||||
|
|
||||||
|
|
||||||
|
from sys import argv
|
|
@ -435,6 +435,7 @@ pub fn remove_argument(
|
||||||
pub fn get_or_import_symbol(
|
pub fn get_or_import_symbol(
|
||||||
module: &str,
|
module: &str,
|
||||||
member: &str,
|
member: &str,
|
||||||
|
at: TextSize,
|
||||||
context: &Context,
|
context: &Context,
|
||||||
importer: &Importer,
|
importer: &Importer,
|
||||||
locator: &Locator,
|
locator: &Locator,
|
||||||
|
@ -462,7 +463,7 @@ pub fn get_or_import_symbol(
|
||||||
Edit::range_replacement(locator.slice(source.range()).to_string(), source.range());
|
Edit::range_replacement(locator.slice(source.range()).to_string(), source.range());
|
||||||
Ok((import_edit, binding))
|
Ok((import_edit, binding))
|
||||||
} else {
|
} else {
|
||||||
if let Some(stmt) = importer.get_import_from(module) {
|
if let Some(stmt) = importer.find_import_from(module, at) {
|
||||||
// Case 1: `from functools import lru_cache` is in scope, and we're trying to reference
|
// Case 1: `from functools import lru_cache` is in scope, and we're trying to reference
|
||||||
// `functools.cache`; thus, we add `cache` to the import, and return `"cache"` as the
|
// `functools.cache`; thus, we add `cache` to the import, and return `"cache"` as the
|
||||||
// bound name.
|
// bound name.
|
||||||
|
@ -485,7 +486,8 @@ pub fn get_or_import_symbol(
|
||||||
.find_binding(module)
|
.find_binding(module)
|
||||||
.map_or(true, |binding| binding.kind.is_builtin())
|
.map_or(true, |binding| binding.kind.is_builtin())
|
||||||
{
|
{
|
||||||
let import_edit = importer.add_import(&AnyImport::Import(Import::module(module)));
|
let import_edit =
|
||||||
|
importer.add_import(&AnyImport::Import(Import::module(module)), at);
|
||||||
Ok((import_edit, format!("{module}.{member}")))
|
Ok((import_edit, format!("{module}.{member}")))
|
||||||
} else {
|
} else {
|
||||||
bail!(
|
bail!(
|
||||||
|
|
|
@ -3,7 +3,6 @@
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use libcst_native::{Codegen, CodegenState, ImportAlias, Name, NameOrAttribute};
|
use libcst_native::{Codegen, CodegenState, ImportAlias, Name, NameOrAttribute};
|
||||||
use ruff_text_size::TextSize;
|
use ruff_text_size::TextSize;
|
||||||
use rustc_hash::FxHashMap;
|
|
||||||
use rustpython_parser::ast::{Stmt, StmtKind, Suite};
|
use rustpython_parser::ast::{Stmt, StmtKind, Suite};
|
||||||
use rustpython_parser::{lexer, Mode, Tok};
|
use rustpython_parser::{lexer, Mode, Tok};
|
||||||
|
|
||||||
|
@ -18,10 +17,7 @@ pub struct Importer<'a> {
|
||||||
python_ast: &'a Suite,
|
python_ast: &'a Suite,
|
||||||
locator: &'a Locator<'a>,
|
locator: &'a Locator<'a>,
|
||||||
stylist: &'a Stylist<'a>,
|
stylist: &'a Stylist<'a>,
|
||||||
/// A map from module name to top-level `StmtKind::ImportFrom` statements.
|
ordered_imports: Vec<&'a Stmt>,
|
||||||
import_from_map: FxHashMap<&'a str, &'a Stmt>,
|
|
||||||
/// The last top-level import statement.
|
|
||||||
trailing_import: Option<&'a Stmt>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> Importer<'a> {
|
impl<'a> Importer<'a> {
|
||||||
|
@ -30,34 +26,21 @@ impl<'a> Importer<'a> {
|
||||||
python_ast,
|
python_ast,
|
||||||
locator,
|
locator,
|
||||||
stylist,
|
stylist,
|
||||||
import_from_map: FxHashMap::default(),
|
ordered_imports: Vec::default(),
|
||||||
trailing_import: None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Visit a top-level import statement.
|
/// Visit a top-level import statement.
|
||||||
pub fn visit_import(&mut self, import: &'a Stmt) {
|
pub fn visit_import(&mut self, import: &'a Stmt) {
|
||||||
// Store a reference to the import statement in the appropriate map.
|
self.ordered_imports.push(import);
|
||||||
match &import.node {
|
}
|
||||||
StmtKind::Import { .. } => {
|
|
||||||
// Nothing to do here, we don't extend top-level `import` statements at all, so
|
|
||||||
// no need to track them.
|
|
||||||
}
|
|
||||||
StmtKind::ImportFrom { module, level, .. } => {
|
|
||||||
// Store a reverse-map from module name to `import ... from` statement.
|
|
||||||
if level.map_or(true, |level| level == 0) {
|
|
||||||
if let Some(module) = module {
|
|
||||||
self.import_from_map.insert(module.as_str(), import);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
panic!("Expected StmtKind::Import | StmtKind::ImportFrom");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Store a reference to the last top-level import statement.
|
/// Return the import statement that precedes the given position, if any.
|
||||||
self.trailing_import = Some(import);
|
fn preceding_import(&self, at: TextSize) -> Option<&Stmt> {
|
||||||
|
self.ordered_imports
|
||||||
|
.partition_point(|stmt| stmt.start() < at)
|
||||||
|
.checked_sub(1)
|
||||||
|
.map(|idx| self.ordered_imports[idx])
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add an import statement to import the given module.
|
/// Add an import statement to import the given module.
|
||||||
|
@ -65,9 +48,9 @@ impl<'a> Importer<'a> {
|
||||||
/// If there are no existing imports, the new import will be added at the top
|
/// If there are no existing imports, the new import will be added at the top
|
||||||
/// of the file. Otherwise, it will be added after the most recent top-level
|
/// of the file. Otherwise, it will be added after the most recent top-level
|
||||||
/// import statement.
|
/// import statement.
|
||||||
pub fn add_import(&self, import: &AnyImport) -> Edit {
|
pub fn add_import(&self, import: &AnyImport, at: TextSize) -> Edit {
|
||||||
let required_import = import.to_string();
|
let required_import = import.to_string();
|
||||||
if let Some(stmt) = self.trailing_import {
|
if let Some(stmt) = self.preceding_import(at) {
|
||||||
// Insert after the last top-level import.
|
// Insert after the last top-level import.
|
||||||
let Insertion {
|
let Insertion {
|
||||||
prefix,
|
prefix,
|
||||||
|
@ -88,10 +71,28 @@ impl<'a> Importer<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return the top-level [`Stmt`] that imports the given module using `StmtKind::ImportFrom`.
|
/// Return the top-level [`Stmt`] that imports the given module using `StmtKind::ImportFrom`
|
||||||
/// if it exists.
|
/// preceding the given position, if any.
|
||||||
pub fn get_import_from(&self, module: &str) -> Option<&Stmt> {
|
pub fn find_import_from(&self, module: &str, at: TextSize) -> Option<&Stmt> {
|
||||||
self.import_from_map.get(module).copied()
|
let mut import_from = None;
|
||||||
|
for stmt in &self.ordered_imports {
|
||||||
|
if stmt.start() >= at {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if let StmtKind::ImportFrom {
|
||||||
|
module: name,
|
||||||
|
level,
|
||||||
|
..
|
||||||
|
} = &stmt.node
|
||||||
|
{
|
||||||
|
if level.map_or(true, |level| level == 0)
|
||||||
|
&& name.as_ref().map_or(false, |name| name == module)
|
||||||
|
{
|
||||||
|
import_from = Some(*stmt);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
import_from
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Add the given member to an existing `StmtKind::ImportFrom` statement.
|
/// Add the given member to an existing `StmtKind::ImportFrom` statement.
|
||||||
|
@ -240,11 +241,11 @@ fn top_of_file_insertion(body: &[Stmt], locator: &Locator, stylist: &Stylist) ->
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use ruff_python_ast::newlines::LineEnding;
|
|
||||||
use ruff_text_size::TextSize;
|
use ruff_text_size::TextSize;
|
||||||
use rustpython_parser as parser;
|
use rustpython_parser as parser;
|
||||||
use rustpython_parser::lexer::LexResult;
|
use rustpython_parser::lexer::LexResult;
|
||||||
|
|
||||||
|
use ruff_python_ast::newlines::LineEnding;
|
||||||
use ruff_python_ast::source_code::{Locator, Stylist};
|
use ruff_python_ast::source_code::{Locator, Stylist};
|
||||||
|
|
||||||
use crate::importer::{top_of_file_insertion, Insertion};
|
use crate::importer::{top_of_file_insertion, Insertion};
|
||||||
|
|
|
@ -85,8 +85,14 @@ fn fix_abstractmethod_missing(
|
||||||
stmt: &Stmt,
|
stmt: &Stmt,
|
||||||
) -> Result<Fix> {
|
) -> Result<Fix> {
|
||||||
let indent = indentation(locator, stmt).ok_or(anyhow!("Unable to detect indentation"))?;
|
let indent = indentation(locator, stmt).ok_or(anyhow!("Unable to detect indentation"))?;
|
||||||
let (import_edit, binding) =
|
let (import_edit, binding) = get_or_import_symbol(
|
||||||
get_or_import_symbol("abc", "abstractmethod", context, importer, locator)?;
|
"abc",
|
||||||
|
"abstractmethod",
|
||||||
|
stmt.start(),
|
||||||
|
context,
|
||||||
|
importer,
|
||||||
|
locator,
|
||||||
|
)?;
|
||||||
let reference_edit = Edit::insertion(
|
let reference_edit = Edit::insertion(
|
||||||
format!(
|
format!(
|
||||||
"@{binding}{line_ending}{indent}",
|
"@{binding}{line_ending}{indent}",
|
||||||
|
|
|
@ -94,6 +94,7 @@ pub fn suppressible_exception(
|
||||||
let (import_edit, binding) = get_or_import_symbol(
|
let (import_edit, binding) = get_or_import_symbol(
|
||||||
"contextlib",
|
"contextlib",
|
||||||
"suppress",
|
"suppress",
|
||||||
|
stmt.start(),
|
||||||
&checker.ctx,
|
&checker.ctx,
|
||||||
&checker.importer,
|
&checker.importer,
|
||||||
checker.locator,
|
checker.locator,
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
use log::error;
|
use log::error;
|
||||||
use ruff_text_size::TextRange;
|
use ruff_text_size::{TextRange, TextSize};
|
||||||
use rustpython_parser as parser;
|
use rustpython_parser as parser;
|
||||||
use rustpython_parser::ast::{StmtKind, Suite};
|
use rustpython_parser::ast::{StmtKind, Suite};
|
||||||
|
|
||||||
|
@ -120,7 +120,10 @@ fn add_required_import(
|
||||||
TextRange::default(),
|
TextRange::default(),
|
||||||
);
|
);
|
||||||
if autofix.into() && settings.rules.should_fix(Rule::MissingRequiredImport) {
|
if autofix.into() && settings.rules.should_fix(Rule::MissingRequiredImport) {
|
||||||
diagnostic.set_fix(Importer::new(python_ast, locator, stylist).add_import(required_import));
|
diagnostic.set_fix(
|
||||||
|
Importer::new(python_ast, locator, stylist)
|
||||||
|
.add_import(required_import, TextSize::default()),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
Some(diagnostic)
|
Some(diagnostic)
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,6 +37,9 @@ mod tests {
|
||||||
#[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_4.py"); "PLR1722_4")]
|
#[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_4.py"); "PLR1722_4")]
|
||||||
#[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_5.py"); "PLR1722_5")]
|
#[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_5.py"); "PLR1722_5")]
|
||||||
#[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_6.py"); "PLR1722_6")]
|
#[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_6.py"); "PLR1722_6")]
|
||||||
|
#[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_7.py"); "PLR1722_7")]
|
||||||
|
#[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_8.py"); "PLR1722_8")]
|
||||||
|
#[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_9.py"); "PLR1722_9")]
|
||||||
#[test_case(Rule::ContinueInFinally, Path::new("continue_in_finally.py"); "PLE0116")]
|
#[test_case(Rule::ContinueInFinally, Path::new("continue_in_finally.py"); "PLE0116")]
|
||||||
#[test_case(Rule::GlobalStatement, Path::new("global_statement.py"); "PLW0603")]
|
#[test_case(Rule::GlobalStatement, Path::new("global_statement.py"); "PLW0603")]
|
||||||
#[test_case(Rule::GlobalVariableNotAssigned, Path::new("global_variable_not_assigned.py"); "PLW0602")]
|
#[test_case(Rule::GlobalVariableNotAssigned, Path::new("global_variable_not_assigned.py"); "PLW0602")]
|
||||||
|
|
|
@ -49,6 +49,7 @@ pub fn sys_exit_alias(checker: &mut Checker, func: &Expr) {
|
||||||
let (import_edit, binding) = get_or_import_symbol(
|
let (import_edit, binding) = get_or_import_symbol(
|
||||||
"sys",
|
"sys",
|
||||||
"exit",
|
"exit",
|
||||||
|
func.start(),
|
||||||
&checker.ctx,
|
&checker.ctx,
|
||||||
&checker.importer,
|
&checker.importer,
|
||||||
checker.locator,
|
checker.locator,
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
---
|
||||||
|
source: crates/ruff/src/rules/pylint/mod.rs
|
||||||
|
---
|
||||||
|
sys_exit_alias_7.py:2:5: PLR1722 [*] Use `sys.exit()` instead of `exit`
|
||||||
|
|
|
||||||
|
2 | def main():
|
||||||
|
3 | exit(0)
|
||||||
|
| ^^^^ PLR1722
|
||||||
|
|
|
||||||
|
= help: Replace `exit` with `sys.exit()`
|
||||||
|
|
||||||
|
ℹ Suggested fix
|
||||||
|
1 |+import sys
|
||||||
|
1 2 | def main():
|
||||||
|
2 |- exit(0)
|
||||||
|
3 |+ sys.exit(0)
|
||||||
|
3 4 |
|
||||||
|
4 5 |
|
||||||
|
5 6 | import functools
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
---
|
||||||
|
source: crates/ruff/src/rules/pylint/mod.rs
|
||||||
|
---
|
||||||
|
sys_exit_alias_8.py:5:5: PLR1722 [*] Use `sys.exit()` instead of `exit`
|
||||||
|
|
|
||||||
|
5 | def main():
|
||||||
|
6 | exit(0)
|
||||||
|
| ^^^^ PLR1722
|
||||||
|
|
|
||||||
|
= help: Replace `exit` with `sys.exit()`
|
||||||
|
|
||||||
|
ℹ Suggested fix
|
||||||
|
1 |-from sys import argv
|
||||||
|
1 |+from sys import argv, exit
|
||||||
|
2 2 |
|
||||||
|
3 3 |
|
||||||
|
4 4 | def main():
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
---
|
||||||
|
source: crates/ruff/src/rules/pylint/mod.rs
|
||||||
|
---
|
||||||
|
sys_exit_alias_9.py:2:5: PLR1722 [*] Use `sys.exit()` instead of `exit`
|
||||||
|
|
|
||||||
|
2 | def main():
|
||||||
|
3 | exit(0)
|
||||||
|
| ^^^^ PLR1722
|
||||||
|
|
|
||||||
|
= help: Replace `exit` with `sys.exit()`
|
||||||
|
|
||||||
|
ℹ Suggested fix
|
||||||
|
1 |+import sys
|
||||||
|
1 2 | def main():
|
||||||
|
2 |- exit(0)
|
||||||
|
3 |+ sys.exit(0)
|
||||||
|
3 4 |
|
||||||
|
4 5 |
|
||||||
|
5 6 | from sys import argv
|
||||||
|
|
||||||
|
|
|
@ -62,6 +62,7 @@ pub fn lru_cache_with_maxsize_none(checker: &mut Checker, decorator_list: &[Expr
|
||||||
let (import_edit, binding) = get_or_import_symbol(
|
let (import_edit, binding) = get_or_import_symbol(
|
||||||
"functools",
|
"functools",
|
||||||
"cache",
|
"cache",
|
||||||
|
expr.start(),
|
||||||
&checker.ctx,
|
&checker.ctx,
|
||||||
&checker.importer,
|
&checker.importer,
|
||||||
checker.locator,
|
checker.locator,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue