mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-27 20:42:10 +00:00
Respect insertion location when importing symbols (#4258)
This commit is contained in:
parent
a95bafefb0
commit
e66fdb83d0
14 changed files with 134 additions and 40 deletions
5
crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_7.py
vendored
Normal file
5
crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_7.py
vendored
Normal file
|
@ -0,0 +1,5 @@
|
|||
def main():
|
||||
exit(0)
|
||||
|
||||
|
||||
import functools
|
5
crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_8.py
vendored
Normal file
5
crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_8.py
vendored
Normal file
|
@ -0,0 +1,5 @@
|
|||
from sys import argv
|
||||
|
||||
|
||||
def main():
|
||||
exit(0)
|
5
crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_9.py
vendored
Normal file
5
crates/ruff/resources/test/fixtures/pylint/sys_exit_alias_9.py
vendored
Normal file
|
@ -0,0 +1,5 @@
|
|||
def main():
|
||||
exit(0)
|
||||
|
||||
|
||||
from sys import argv
|
|
@ -435,6 +435,7 @@ pub fn remove_argument(
|
|||
pub fn get_or_import_symbol(
|
||||
module: &str,
|
||||
member: &str,
|
||||
at: TextSize,
|
||||
context: &Context,
|
||||
importer: &Importer,
|
||||
locator: &Locator,
|
||||
|
@ -462,7 +463,7 @@ pub fn get_or_import_symbol(
|
|||
Edit::range_replacement(locator.slice(source.range()).to_string(), source.range());
|
||||
Ok((import_edit, binding))
|
||||
} else {
|
||||
if let Some(stmt) = importer.get_import_from(module) {
|
||||
if let Some(stmt) = importer.find_import_from(module, at) {
|
||||
// Case 1: `from functools import lru_cache` is in scope, and we're trying to reference
|
||||
// `functools.cache`; thus, we add `cache` to the import, and return `"cache"` as the
|
||||
// bound name.
|
||||
|
@ -485,7 +486,8 @@ pub fn get_or_import_symbol(
|
|||
.find_binding(module)
|
||||
.map_or(true, |binding| binding.kind.is_builtin())
|
||||
{
|
||||
let import_edit = importer.add_import(&AnyImport::Import(Import::module(module)));
|
||||
let import_edit =
|
||||
importer.add_import(&AnyImport::Import(Import::module(module)), at);
|
||||
Ok((import_edit, format!("{module}.{member}")))
|
||||
} else {
|
||||
bail!(
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
use anyhow::Result;
|
||||
use libcst_native::{Codegen, CodegenState, ImportAlias, Name, NameOrAttribute};
|
||||
use ruff_text_size::TextSize;
|
||||
use rustc_hash::FxHashMap;
|
||||
use rustpython_parser::ast::{Stmt, StmtKind, Suite};
|
||||
use rustpython_parser::{lexer, Mode, Tok};
|
||||
|
||||
|
@ -18,10 +17,7 @@ pub struct Importer<'a> {
|
|||
python_ast: &'a Suite,
|
||||
locator: &'a Locator<'a>,
|
||||
stylist: &'a Stylist<'a>,
|
||||
/// A map from module name to top-level `StmtKind::ImportFrom` statements.
|
||||
import_from_map: FxHashMap<&'a str, &'a Stmt>,
|
||||
/// The last top-level import statement.
|
||||
trailing_import: Option<&'a Stmt>,
|
||||
ordered_imports: Vec<&'a Stmt>,
|
||||
}
|
||||
|
||||
impl<'a> Importer<'a> {
|
||||
|
@ -30,34 +26,21 @@ impl<'a> Importer<'a> {
|
|||
python_ast,
|
||||
locator,
|
||||
stylist,
|
||||
import_from_map: FxHashMap::default(),
|
||||
trailing_import: None,
|
||||
ordered_imports: Vec::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Visit a top-level import statement.
|
||||
pub fn visit_import(&mut self, import: &'a Stmt) {
|
||||
// Store a reference to the import statement in the appropriate map.
|
||||
match &import.node {
|
||||
StmtKind::Import { .. } => {
|
||||
// Nothing to do here, we don't extend top-level `import` statements at all, so
|
||||
// no need to track them.
|
||||
}
|
||||
StmtKind::ImportFrom { module, level, .. } => {
|
||||
// Store a reverse-map from module name to `import ... from` statement.
|
||||
if level.map_or(true, |level| level == 0) {
|
||||
if let Some(module) = module {
|
||||
self.import_from_map.insert(module.as_str(), import);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
panic!("Expected StmtKind::Import | StmtKind::ImportFrom");
|
||||
}
|
||||
}
|
||||
self.ordered_imports.push(import);
|
||||
}
|
||||
|
||||
// Store a reference to the last top-level import statement.
|
||||
self.trailing_import = Some(import);
|
||||
/// Return the import statement that precedes the given position, if any.
|
||||
fn preceding_import(&self, at: TextSize) -> Option<&Stmt> {
|
||||
self.ordered_imports
|
||||
.partition_point(|stmt| stmt.start() < at)
|
||||
.checked_sub(1)
|
||||
.map(|idx| self.ordered_imports[idx])
|
||||
}
|
||||
|
||||
/// Add an import statement to import the given module.
|
||||
|
@ -65,9 +48,9 @@ impl<'a> Importer<'a> {
|
|||
/// If there are no existing imports, the new import will be added at the top
|
||||
/// of the file. Otherwise, it will be added after the most recent top-level
|
||||
/// import statement.
|
||||
pub fn add_import(&self, import: &AnyImport) -> Edit {
|
||||
pub fn add_import(&self, import: &AnyImport, at: TextSize) -> Edit {
|
||||
let required_import = import.to_string();
|
||||
if let Some(stmt) = self.trailing_import {
|
||||
if let Some(stmt) = self.preceding_import(at) {
|
||||
// Insert after the last top-level import.
|
||||
let Insertion {
|
||||
prefix,
|
||||
|
@ -88,10 +71,28 @@ impl<'a> Importer<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Return the top-level [`Stmt`] that imports the given module using `StmtKind::ImportFrom`.
|
||||
/// if it exists.
|
||||
pub fn get_import_from(&self, module: &str) -> Option<&Stmt> {
|
||||
self.import_from_map.get(module).copied()
|
||||
/// Return the top-level [`Stmt`] that imports the given module using `StmtKind::ImportFrom`
|
||||
/// preceding the given position, if any.
|
||||
pub fn find_import_from(&self, module: &str, at: TextSize) -> Option<&Stmt> {
|
||||
let mut import_from = None;
|
||||
for stmt in &self.ordered_imports {
|
||||
if stmt.start() >= at {
|
||||
break;
|
||||
}
|
||||
if let StmtKind::ImportFrom {
|
||||
module: name,
|
||||
level,
|
||||
..
|
||||
} = &stmt.node
|
||||
{
|
||||
if level.map_or(true, |level| level == 0)
|
||||
&& name.as_ref().map_or(false, |name| name == module)
|
||||
{
|
||||
import_from = Some(*stmt);
|
||||
}
|
||||
}
|
||||
}
|
||||
import_from
|
||||
}
|
||||
|
||||
/// Add the given member to an existing `StmtKind::ImportFrom` statement.
|
||||
|
@ -240,11 +241,11 @@ fn top_of_file_insertion(body: &[Stmt], locator: &Locator, stylist: &Stylist) ->
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
use ruff_python_ast::newlines::LineEnding;
|
||||
use ruff_text_size::TextSize;
|
||||
use rustpython_parser as parser;
|
||||
use rustpython_parser::lexer::LexResult;
|
||||
|
||||
use ruff_python_ast::newlines::LineEnding;
|
||||
use ruff_python_ast::source_code::{Locator, Stylist};
|
||||
|
||||
use crate::importer::{top_of_file_insertion, Insertion};
|
||||
|
|
|
@ -85,8 +85,14 @@ fn fix_abstractmethod_missing(
|
|||
stmt: &Stmt,
|
||||
) -> Result<Fix> {
|
||||
let indent = indentation(locator, stmt).ok_or(anyhow!("Unable to detect indentation"))?;
|
||||
let (import_edit, binding) =
|
||||
get_or_import_symbol("abc", "abstractmethod", context, importer, locator)?;
|
||||
let (import_edit, binding) = get_or_import_symbol(
|
||||
"abc",
|
||||
"abstractmethod",
|
||||
stmt.start(),
|
||||
context,
|
||||
importer,
|
||||
locator,
|
||||
)?;
|
||||
let reference_edit = Edit::insertion(
|
||||
format!(
|
||||
"@{binding}{line_ending}{indent}",
|
||||
|
|
|
@ -94,6 +94,7 @@ pub fn suppressible_exception(
|
|||
let (import_edit, binding) = get_or_import_symbol(
|
||||
"contextlib",
|
||||
"suppress",
|
||||
stmt.start(),
|
||||
&checker.ctx,
|
||||
&checker.importer,
|
||||
checker.locator,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use log::error;
|
||||
use ruff_text_size::TextRange;
|
||||
use ruff_text_size::{TextRange, TextSize};
|
||||
use rustpython_parser as parser;
|
||||
use rustpython_parser::ast::{StmtKind, Suite};
|
||||
|
||||
|
@ -120,7 +120,10 @@ fn add_required_import(
|
|||
TextRange::default(),
|
||||
);
|
||||
if autofix.into() && settings.rules.should_fix(Rule::MissingRequiredImport) {
|
||||
diagnostic.set_fix(Importer::new(python_ast, locator, stylist).add_import(required_import));
|
||||
diagnostic.set_fix(
|
||||
Importer::new(python_ast, locator, stylist)
|
||||
.add_import(required_import, TextSize::default()),
|
||||
);
|
||||
}
|
||||
Some(diagnostic)
|
||||
}
|
||||
|
|
|
@ -37,6 +37,9 @@ mod tests {
|
|||
#[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_4.py"); "PLR1722_4")]
|
||||
#[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_5.py"); "PLR1722_5")]
|
||||
#[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_6.py"); "PLR1722_6")]
|
||||
#[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_7.py"); "PLR1722_7")]
|
||||
#[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_8.py"); "PLR1722_8")]
|
||||
#[test_case(Rule::SysExitAlias, Path::new("sys_exit_alias_9.py"); "PLR1722_9")]
|
||||
#[test_case(Rule::ContinueInFinally, Path::new("continue_in_finally.py"); "PLE0116")]
|
||||
#[test_case(Rule::GlobalStatement, Path::new("global_statement.py"); "PLW0603")]
|
||||
#[test_case(Rule::GlobalVariableNotAssigned, Path::new("global_variable_not_assigned.py"); "PLW0602")]
|
||||
|
|
|
@ -49,6 +49,7 @@ pub fn sys_exit_alias(checker: &mut Checker, func: &Expr) {
|
|||
let (import_edit, binding) = get_or_import_symbol(
|
||||
"sys",
|
||||
"exit",
|
||||
func.start(),
|
||||
&checker.ctx,
|
||||
&checker.importer,
|
||||
checker.locator,
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
---
|
||||
source: crates/ruff/src/rules/pylint/mod.rs
|
||||
---
|
||||
sys_exit_alias_7.py:2:5: PLR1722 [*] Use `sys.exit()` instead of `exit`
|
||||
|
|
||||
2 | def main():
|
||||
3 | exit(0)
|
||||
| ^^^^ PLR1722
|
||||
|
|
||||
= help: Replace `exit` with `sys.exit()`
|
||||
|
||||
ℹ Suggested fix
|
||||
1 |+import sys
|
||||
1 2 | def main():
|
||||
2 |- exit(0)
|
||||
3 |+ sys.exit(0)
|
||||
3 4 |
|
||||
4 5 |
|
||||
5 6 | import functools
|
||||
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
---
|
||||
source: crates/ruff/src/rules/pylint/mod.rs
|
||||
---
|
||||
sys_exit_alias_8.py:5:5: PLR1722 [*] Use `sys.exit()` instead of `exit`
|
||||
|
|
||||
5 | def main():
|
||||
6 | exit(0)
|
||||
| ^^^^ PLR1722
|
||||
|
|
||||
= help: Replace `exit` with `sys.exit()`
|
||||
|
||||
ℹ Suggested fix
|
||||
1 |-from sys import argv
|
||||
1 |+from sys import argv, exit
|
||||
2 2 |
|
||||
3 3 |
|
||||
4 4 | def main():
|
||||
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
---
|
||||
source: crates/ruff/src/rules/pylint/mod.rs
|
||||
---
|
||||
sys_exit_alias_9.py:2:5: PLR1722 [*] Use `sys.exit()` instead of `exit`
|
||||
|
|
||||
2 | def main():
|
||||
3 | exit(0)
|
||||
| ^^^^ PLR1722
|
||||
|
|
||||
= help: Replace `exit` with `sys.exit()`
|
||||
|
||||
ℹ Suggested fix
|
||||
1 |+import sys
|
||||
1 2 | def main():
|
||||
2 |- exit(0)
|
||||
3 |+ sys.exit(0)
|
||||
3 4 |
|
||||
4 5 |
|
||||
5 6 | from sys import argv
|
||||
|
||||
|
|
@ -62,6 +62,7 @@ pub fn lru_cache_with_maxsize_none(checker: &mut Checker, decorator_list: &[Expr
|
|||
let (import_edit, binding) = get_or_import_symbol(
|
||||
"functools",
|
||||
"cache",
|
||||
expr.start(),
|
||||
&checker.ctx,
|
||||
&checker.importer,
|
||||
checker.locator,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue