Enable callers to specify import-style preferences in Importer (#4717)

This commit is contained in:
Charlie Marsh 2023-05-30 12:46:19 -04:00 committed by GitHub
parent ea31229be0
commit f47a517e79
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 116 additions and 34 deletions

View file

@ -8,7 +8,7 @@ use ruff_text_size::TextSize;
use rustpython_parser::ast::{self, Ranged, Stmt, Suite}; use rustpython_parser::ast::{self, Ranged, Stmt, Suite};
use ruff_diagnostics::Edit; use ruff_diagnostics::Edit;
use ruff_python_ast::imports::{AnyImport, Import}; use ruff_python_ast::imports::{AnyImport, Import, ImportFrom};
use ruff_python_ast::source_code::{Locator, Stylist}; use ruff_python_ast::source_code::{Locator, Stylist};
use ruff_python_semantic::model::SemanticModel; use ruff_python_semantic::model::SemanticModel;
@ -79,27 +79,26 @@ impl<'a> Importer<'a> {
/// Attempts to reuse existing imports when possible. /// Attempts to reuse existing imports when possible.
pub(crate) fn get_or_import_symbol( pub(crate) fn get_or_import_symbol(
&self, &self,
module: &str, symbol: &ImportRequest,
member: &str,
at: TextSize, at: TextSize,
semantic_model: &SemanticModel, semantic_model: &SemanticModel,
) -> Result<(Edit, String), ResolutionError> { ) -> Result<(Edit, String), ResolutionError> {
match self.get_symbol(module, member, at, semantic_model) { match self.get_symbol(symbol, at, semantic_model) {
Some(result) => result, Some(result) => result,
None => self.import_symbol(module, member, at, semantic_model), None => self.import_symbol(symbol, at, semantic_model),
} }
} }
/// Return an [`Edit`] to reference an existing symbol, if it's present in the given [`SemanticModel`]. /// Return an [`Edit`] to reference an existing symbol, if it's present in the given [`SemanticModel`].
fn get_symbol( fn get_symbol(
&self, &self,
module: &str, symbol: &ImportRequest,
member: &str,
at: TextSize, at: TextSize,
semantic_model: &SemanticModel, semantic_model: &SemanticModel,
) -> Option<Result<(Edit, String), ResolutionError>> { ) -> Option<Result<(Edit, String), ResolutionError>> {
// If the symbol is already available in the current scope, use it. // If the symbol is already available in the current scope, use it.
let imported_name = semantic_model.resolve_qualified_import_name(module, member)?; let imported_name =
semantic_model.resolve_qualified_import_name(symbol.module, symbol.member)?;
// If the symbol source (i.e., the import statement) comes after the current location, // If the symbol source (i.e., the import statement) comes after the current location,
// abort. For example, we could be generating an edit within a function, and the import // abort. For example, we could be generating an edit within a function, and the import
@ -149,31 +148,58 @@ impl<'a> Importer<'a> {
/// the name on which the `lru_cache` symbol would be made available (`"functools.lru_cache"`). /// the name on which the `lru_cache` symbol would be made available (`"functools.lru_cache"`).
fn import_symbol( fn import_symbol(
&self, &self,
module: &str, symbol: &ImportRequest,
member: &str,
at: TextSize, at: TextSize,
semantic_model: &SemanticModel, semantic_model: &SemanticModel,
) -> Result<(Edit, String), ResolutionError> { ) -> Result<(Edit, String), ResolutionError> {
if let Some(stmt) = self.find_import_from(module, at) { if let Some(stmt) = self.find_import_from(symbol.module, at) {
// Case 1: `from functools import lru_cache` is in scope, and we're trying to reference // Case 1: `from functools import lru_cache` is in scope, and we're trying to reference
// `functools.cache`; thus, we add `cache` to the import, and return `"cache"` as the // `functools.cache`; thus, we add `cache` to the import, and return `"cache"` as the
// bound name. // bound name.
if semantic_model.is_unbound(member) { if semantic_model.is_unbound(symbol.member) {
let Ok(import_edit) = self.add_member(stmt, member) else { let Ok(import_edit) = self.add_member(stmt, symbol.member) else {
return Err(ResolutionError::InvalidEdit); return Err(ResolutionError::InvalidEdit);
}; };
Ok((import_edit, member.to_string())) Ok((import_edit, symbol.member.to_string()))
} else { } else {
Err(ResolutionError::ConflictingName(member.to_string())) Err(ResolutionError::ConflictingName(symbol.member.to_string()))
} }
} else { } else {
// Case 2: No `functools` import is in scope; thus, we add `import functools`, and match symbol.style {
// return `"functools.cache"` as the bound name. ImportStyle::Import => {
if semantic_model.is_unbound(module) { // Case 2a: No `functools` import is in scope; thus, we add `import functools`,
let import_edit = self.add_import(&AnyImport::Import(Import::module(module)), at); // and return `"functools.cache"` as the bound name.
Ok((import_edit, format!("{module}.{member}"))) if semantic_model.is_unbound(symbol.module) {
} else { let import_edit =
Err(ResolutionError::ConflictingName(module.to_string())) self.add_import(&AnyImport::Import(Import::module(symbol.module)), at);
Ok((
import_edit,
format!(
"{module}.{member}",
module = symbol.module,
member = symbol.member
),
))
} else {
Err(ResolutionError::ConflictingName(symbol.module.to_string()))
}
}
ImportStyle::ImportFrom => {
// Case 2b: No `functools` import is in scope; thus, we add
// `from functools import cache`, and return `"cache"` as the bound name.
if semantic_model.is_unbound(symbol.member) {
let import_edit = self.add_import(
&AnyImport::ImportFrom(ImportFrom::member(
symbol.module,
symbol.member,
)),
at,
);
Ok((import_edit, symbol.member.to_string()))
} else {
Err(ResolutionError::ConflictingName(symbol.member.to_string()))
}
}
} }
} }
} }
@ -234,6 +260,47 @@ impl<'a> Importer<'a> {
} }
} }
#[derive(Debug)]
enum ImportStyle {
/// Import the symbol using the `import` statement (e.g. `import foo; foo.bar`).
Import,
/// Import the symbol using the `from` statement (e.g. `from foo import bar; bar`).
ImportFrom,
}
#[derive(Debug)]
pub(crate) struct ImportRequest<'a> {
/// The module from which the symbol can be imported (e.g., `foo`, in `from foo import bar`).
module: &'a str,
/// The member to import (e.g., `bar`, in `from foo import bar`).
member: &'a str,
/// The preferred style to use when importing the symbol (e.g., `import foo` or
/// `from foo import bar`), if it's not already in scope.
style: ImportStyle,
}
impl<'a> ImportRequest<'a> {
/// Create a new `ImportRequest` from a module and member. If not present in the scope,
/// the symbol should be imported using the "import" statement.
pub(crate) fn import(module: &'a str, member: &'a str) -> Self {
Self {
module,
member,
style: ImportStyle::Import,
}
}
/// Create a new `ImportRequest` from a module and member. If not present in the scope,
/// the symbol should be imported using the "import from" statement.
pub(crate) fn import_from(module: &'a str, member: &'a str) -> Self {
Self {
module,
member,
style: ImportStyle::ImportFrom,
}
}
}
/// The result of an [`Importer::get_or_import_symbol`] call. /// The result of an [`Importer::get_or_import_symbol`] call.
#[derive(Debug)] #[derive(Debug)]
pub(crate) enum ResolutionError { pub(crate) enum ResolutionError {

View file

@ -8,6 +8,7 @@ use ruff_python_ast::helpers;
use ruff_python_ast::helpers::has_comments; use ruff_python_ast::helpers::has_comments;
use crate::checkers::ast::Checker; use crate::checkers::ast::Checker;
use crate::importer::ImportRequest;
use crate::registry::AsRule; use crate::registry::AsRule;
#[violation] #[violation]
@ -87,8 +88,7 @@ pub(crate) fn suppressible_exception(
if !has_comments(stmt, checker.locator) { if !has_comments(stmt, checker.locator) {
diagnostic.try_set_fix(|| { diagnostic.try_set_fix(|| {
let (import_edit, binding) = checker.importer.get_or_import_symbol( let (import_edit, binding) = checker.importer.get_or_import_symbol(
"contextlib", &ImportRequest::import("contextlib", "suppress"),
"suppress",
stmt.start(), stmt.start(),
checker.semantic_model(), checker.semantic_model(),
)?; )?;

View file

@ -4,6 +4,7 @@ use ruff_diagnostics::{AutofixKind, Diagnostic, Edit, Fix, Violation};
use ruff_macros::{derive_message_formats, violation}; use ruff_macros::{derive_message_formats, violation};
use crate::checkers::ast::Checker; use crate::checkers::ast::Checker;
use crate::importer::ImportRequest;
use crate::registry::AsRule; use crate::registry::AsRule;
/// ## What it does /// ## What it does
@ -76,8 +77,7 @@ pub(crate) fn sys_exit_alias(checker: &mut Checker, func: &Expr) {
if checker.patch(diagnostic.kind.rule()) { if checker.patch(diagnostic.kind.rule()) {
diagnostic.try_set_fix(|| { diagnostic.try_set_fix(|| {
let (import_edit, binding) = checker.importer.get_or_import_symbol( let (import_edit, binding) = checker.importer.get_or_import_symbol(
"sys", &ImportRequest::import("sys", "exit"),
"exit",
func.start(), func.start(),
checker.semantic_model(), checker.semantic_model(),
)?; )?;

View file

@ -5,6 +5,7 @@ use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic, Edit, Fix};
use ruff_macros::{derive_message_formats, violation}; use ruff_macros::{derive_message_formats, violation};
use crate::checkers::ast::Checker; use crate::checkers::ast::Checker;
use crate::importer::ImportRequest;
use crate::registry::AsRule; use crate::registry::AsRule;
#[violation] #[violation]
@ -65,8 +66,7 @@ pub(crate) fn lru_cache_with_maxsize_none(checker: &mut Checker, decorator_list:
if checker.patch(diagnostic.kind.rule()) { if checker.patch(diagnostic.kind.rule()) {
diagnostic.try_set_fix(|| { diagnostic.try_set_fix(|| {
let (import_edit, binding) = checker.importer.get_or_import_symbol( let (import_edit, binding) = checker.importer.get_or_import_symbol(
"functools", &ImportRequest::import("functools", "cache"),
"cache",
expr.start(), expr.start(),
checker.semantic_model(), checker.semantic_model(),
)?; )?;

View file

@ -6,6 +6,7 @@ use ruff_python_ast::call_path::compose_call_path;
use ruff_python_semantic::analyze::typing::ModuleMember; use ruff_python_semantic::analyze::typing::ModuleMember;
use crate::checkers::ast::Checker; use crate::checkers::ast::Checker;
use crate::importer::ImportRequest;
use crate::registry::AsRule; use crate::registry::AsRule;
#[violation] #[violation]
@ -61,8 +62,7 @@ pub(crate) fn use_pep585_annotation(
// Imported type, like `collections.deque`. // Imported type, like `collections.deque`.
diagnostic.try_set_fix(|| { diagnostic.try_set_fix(|| {
let (import_edit, binding) = checker.importer.get_or_import_symbol( let (import_edit, binding) = checker.importer.get_or_import_symbol(
module, &ImportRequest::import_from(module, member),
member,
expr.start(), expr.start(),
checker.semantic_model(), checker.semantic_model(),
)?; )?;

View file

@ -243,7 +243,7 @@ UP006_0.py:61:10: UP006 [*] Use `collections.deque` instead of `typing.Deque` fo
20 20 | 20 20 |
21 21 | 21 21 |
22 22 | from typing import List as IList 22 22 | from typing import List as IList
23 |+import collections 23 |+from collections import deque
23 24 | 23 24 |
24 25 | 24 25 |
25 26 | def f(x: IList[str]) -> None: 25 26 | def f(x: IList[str]) -> None:
@ -252,7 +252,7 @@ UP006_0.py:61:10: UP006 [*] Use `collections.deque` instead of `typing.Deque` fo
59 60 | 59 60 |
60 61 | 60 61 |
61 |-def f(x: typing.Deque[str]) -> None: 61 |-def f(x: typing.Deque[str]) -> None:
62 |+def f(x: collections.deque[str]) -> None: 62 |+def f(x: deque[str]) -> None:
62 63 | ... 62 63 | ...
63 64 | 63 64 |
64 65 | 64 65 |
@ -269,7 +269,7 @@ UP006_0.py:65:10: UP006 [*] Use `collections.defaultdict` instead of `typing.Def
20 20 | 20 20 |
21 21 | 21 21 |
22 22 | from typing import List as IList 22 22 | from typing import List as IList
23 |+import collections 23 |+from collections import defaultdict
23 24 | 23 24 |
24 25 | 24 25 |
25 26 | def f(x: IList[str]) -> None: 25 26 | def f(x: IList[str]) -> None:
@ -278,7 +278,7 @@ UP006_0.py:65:10: UP006 [*] Use `collections.defaultdict` instead of `typing.Def
63 64 | 63 64 |
64 65 | 64 65 |
65 |-def f(x: typing.DefaultDict[str, str]) -> None: 65 |-def f(x: typing.DefaultDict[str, str]) -> None:
66 |+def f(x: collections.defaultdict[str, str]) -> None: 66 |+def f(x: defaultdict[str, str]) -> None:
66 67 | ... 66 67 | ...

View file

@ -31,6 +31,7 @@ pub struct Alias<'a> {
} }
impl<'a> Import<'a> { impl<'a> Import<'a> {
/// Creates a new `Import` to import the specified module.
pub fn module(name: &'a str) -> Self { pub fn module(name: &'a str) -> Self {
Self { Self {
name: Alias { name: Alias {
@ -41,6 +42,20 @@ impl<'a> Import<'a> {
} }
} }
impl<'a> ImportFrom<'a> {
/// Creates a new `ImportFrom` to import a member from the specified module.
pub fn member(module: &'a str, name: &'a str) -> Self {
Self {
module: Some(module),
name: Alias {
name,
as_name: None,
},
level: None,
}
}
}
impl std::fmt::Display for AnyImport<'_> { impl std::fmt::Display for AnyImport<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self { match self {