From f47a517e79d08dca79e6956610c12cac93c745f3 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Tue, 30 May 2023 12:46:19 -0400 Subject: [PATCH] Enable callers to specify import-style preferences in `Importer` (#4717) --- crates/ruff/src/importer/mod.rs | 111 ++++++++++++++---- .../rules/suppressible_exception.rs | 4 +- .../src/rules/pylint/rules/sys_exit_alias.rs | 4 +- .../rules/lru_cache_with_maxsize_none.rs | 4 +- .../pyupgrade/rules/use_pep585_annotation.rs | 4 +- ...__rules__pyupgrade__tests__UP006_0.py.snap | 8 +- crates/ruff_python_ast/src/imports.rs | 15 +++ 7 files changed, 116 insertions(+), 34 deletions(-) diff --git a/crates/ruff/src/importer/mod.rs b/crates/ruff/src/importer/mod.rs index e8818ee063..57a5b9b9e5 100644 --- a/crates/ruff/src/importer/mod.rs +++ b/crates/ruff/src/importer/mod.rs @@ -8,7 +8,7 @@ use ruff_text_size::TextSize; use rustpython_parser::ast::{self, Ranged, Stmt, Suite}; use ruff_diagnostics::Edit; -use ruff_python_ast::imports::{AnyImport, Import}; +use ruff_python_ast::imports::{AnyImport, Import, ImportFrom}; use ruff_python_ast::source_code::{Locator, Stylist}; use ruff_python_semantic::model::SemanticModel; @@ -79,27 +79,26 @@ impl<'a> Importer<'a> { /// Attempts to reuse existing imports when possible. pub(crate) fn get_or_import_symbol( &self, - module: &str, - member: &str, + symbol: &ImportRequest, at: TextSize, semantic_model: &SemanticModel, ) -> Result<(Edit, String), ResolutionError> { - match self.get_symbol(module, member, at, semantic_model) { + match self.get_symbol(symbol, at, semantic_model) { Some(result) => result, - None => self.import_symbol(module, member, at, semantic_model), + None => self.import_symbol(symbol, at, semantic_model), } } /// Return an [`Edit`] to reference an existing symbol, if it's present in the given [`SemanticModel`]. fn get_symbol( &self, - module: &str, - member: &str, + symbol: &ImportRequest, at: TextSize, semantic_model: &SemanticModel, ) -> Option> { // If the symbol is already available in the current scope, use it. - let imported_name = semantic_model.resolve_qualified_import_name(module, member)?; + let imported_name = + semantic_model.resolve_qualified_import_name(symbol.module, symbol.member)?; // If the symbol source (i.e., the import statement) comes after the current location, // abort. For example, we could be generating an edit within a function, and the import @@ -149,31 +148,58 @@ impl<'a> Importer<'a> { /// the name on which the `lru_cache` symbol would be made available (`"functools.lru_cache"`). fn import_symbol( &self, - module: &str, - member: &str, + symbol: &ImportRequest, at: TextSize, semantic_model: &SemanticModel, ) -> Result<(Edit, String), ResolutionError> { - if let Some(stmt) = self.find_import_from(module, at) { + if let Some(stmt) = self.find_import_from(symbol.module, at) { // Case 1: `from functools import lru_cache` is in scope, and we're trying to reference // `functools.cache`; thus, we add `cache` to the import, and return `"cache"` as the // bound name. - if semantic_model.is_unbound(member) { - let Ok(import_edit) = self.add_member(stmt, member) else { + if semantic_model.is_unbound(symbol.member) { + let Ok(import_edit) = self.add_member(stmt, symbol.member) else { return Err(ResolutionError::InvalidEdit); }; - Ok((import_edit, member.to_string())) + Ok((import_edit, symbol.member.to_string())) } else { - Err(ResolutionError::ConflictingName(member.to_string())) + Err(ResolutionError::ConflictingName(symbol.member.to_string())) } } else { - // Case 2: No `functools` import is in scope; thus, we add `import functools`, and - // return `"functools.cache"` as the bound name. - if semantic_model.is_unbound(module) { - let import_edit = self.add_import(&AnyImport::Import(Import::module(module)), at); - Ok((import_edit, format!("{module}.{member}"))) - } else { - Err(ResolutionError::ConflictingName(module.to_string())) + match symbol.style { + ImportStyle::Import => { + // Case 2a: No `functools` import is in scope; thus, we add `import functools`, + // and return `"functools.cache"` as the bound name. + if semantic_model.is_unbound(symbol.module) { + let import_edit = + self.add_import(&AnyImport::Import(Import::module(symbol.module)), at); + Ok(( + import_edit, + format!( + "{module}.{member}", + module = symbol.module, + member = symbol.member + ), + )) + } else { + Err(ResolutionError::ConflictingName(symbol.module.to_string())) + } + } + ImportStyle::ImportFrom => { + // Case 2b: No `functools` import is in scope; thus, we add + // `from functools import cache`, and return `"cache"` as the bound name. + if semantic_model.is_unbound(symbol.member) { + let import_edit = self.add_import( + &AnyImport::ImportFrom(ImportFrom::member( + symbol.module, + symbol.member, + )), + at, + ); + Ok((import_edit, symbol.member.to_string())) + } else { + Err(ResolutionError::ConflictingName(symbol.member.to_string())) + } + } } } } @@ -234,6 +260,47 @@ impl<'a> Importer<'a> { } } +#[derive(Debug)] +enum ImportStyle { + /// Import the symbol using the `import` statement (e.g. `import foo; foo.bar`). + Import, + /// Import the symbol using the `from` statement (e.g. `from foo import bar; bar`). + ImportFrom, +} + +#[derive(Debug)] +pub(crate) struct ImportRequest<'a> { + /// The module from which the symbol can be imported (e.g., `foo`, in `from foo import bar`). + module: &'a str, + /// The member to import (e.g., `bar`, in `from foo import bar`). + member: &'a str, + /// The preferred style to use when importing the symbol (e.g., `import foo` or + /// `from foo import bar`), if it's not already in scope. + style: ImportStyle, +} + +impl<'a> ImportRequest<'a> { + /// Create a new `ImportRequest` from a module and member. If not present in the scope, + /// the symbol should be imported using the "import" statement. + pub(crate) fn import(module: &'a str, member: &'a str) -> Self { + Self { + module, + member, + style: ImportStyle::Import, + } + } + + /// Create a new `ImportRequest` from a module and member. If not present in the scope, + /// the symbol should be imported using the "import from" statement. + pub(crate) fn import_from(module: &'a str, member: &'a str) -> Self { + Self { + module, + member, + style: ImportStyle::ImportFrom, + } + } +} + /// The result of an [`Importer::get_or_import_symbol`] call. #[derive(Debug)] pub(crate) enum ResolutionError { diff --git a/crates/ruff/src/rules/flake8_simplify/rules/suppressible_exception.rs b/crates/ruff/src/rules/flake8_simplify/rules/suppressible_exception.rs index 515ea39dcf..78bfd983da 100644 --- a/crates/ruff/src/rules/flake8_simplify/rules/suppressible_exception.rs +++ b/crates/ruff/src/rules/flake8_simplify/rules/suppressible_exception.rs @@ -8,6 +8,7 @@ use ruff_python_ast::helpers; use ruff_python_ast::helpers::has_comments; use crate::checkers::ast::Checker; +use crate::importer::ImportRequest; use crate::registry::AsRule; #[violation] @@ -87,8 +88,7 @@ pub(crate) fn suppressible_exception( if !has_comments(stmt, checker.locator) { diagnostic.try_set_fix(|| { let (import_edit, binding) = checker.importer.get_or_import_symbol( - "contextlib", - "suppress", + &ImportRequest::import("contextlib", "suppress"), stmt.start(), checker.semantic_model(), )?; diff --git a/crates/ruff/src/rules/pylint/rules/sys_exit_alias.rs b/crates/ruff/src/rules/pylint/rules/sys_exit_alias.rs index 7cdfd8ba1b..e9dccdd1cd 100644 --- a/crates/ruff/src/rules/pylint/rules/sys_exit_alias.rs +++ b/crates/ruff/src/rules/pylint/rules/sys_exit_alias.rs @@ -4,6 +4,7 @@ use ruff_diagnostics::{AutofixKind, Diagnostic, Edit, Fix, Violation}; use ruff_macros::{derive_message_formats, violation}; use crate::checkers::ast::Checker; +use crate::importer::ImportRequest; use crate::registry::AsRule; /// ## What it does @@ -76,8 +77,7 @@ pub(crate) fn sys_exit_alias(checker: &mut Checker, func: &Expr) { if checker.patch(diagnostic.kind.rule()) { diagnostic.try_set_fix(|| { let (import_edit, binding) = checker.importer.get_or_import_symbol( - "sys", - "exit", + &ImportRequest::import("sys", "exit"), func.start(), checker.semantic_model(), )?; diff --git a/crates/ruff/src/rules/pyupgrade/rules/lru_cache_with_maxsize_none.rs b/crates/ruff/src/rules/pyupgrade/rules/lru_cache_with_maxsize_none.rs index d8753fd85f..9303ddfd28 100644 --- a/crates/ruff/src/rules/pyupgrade/rules/lru_cache_with_maxsize_none.rs +++ b/crates/ruff/src/rules/pyupgrade/rules/lru_cache_with_maxsize_none.rs @@ -5,6 +5,7 @@ use ruff_diagnostics::{AlwaysAutofixableViolation, Diagnostic, Edit, Fix}; use ruff_macros::{derive_message_formats, violation}; use crate::checkers::ast::Checker; +use crate::importer::ImportRequest; use crate::registry::AsRule; #[violation] @@ -65,8 +66,7 @@ pub(crate) fn lru_cache_with_maxsize_none(checker: &mut Checker, decorator_list: if checker.patch(diagnostic.kind.rule()) { diagnostic.try_set_fix(|| { let (import_edit, binding) = checker.importer.get_or_import_symbol( - "functools", - "cache", + &ImportRequest::import("functools", "cache"), expr.start(), checker.semantic_model(), )?; diff --git a/crates/ruff/src/rules/pyupgrade/rules/use_pep585_annotation.rs b/crates/ruff/src/rules/pyupgrade/rules/use_pep585_annotation.rs index 071b3b3de1..20e09a7528 100644 --- a/crates/ruff/src/rules/pyupgrade/rules/use_pep585_annotation.rs +++ b/crates/ruff/src/rules/pyupgrade/rules/use_pep585_annotation.rs @@ -6,6 +6,7 @@ use ruff_python_ast::call_path::compose_call_path; use ruff_python_semantic::analyze::typing::ModuleMember; use crate::checkers::ast::Checker; +use crate::importer::ImportRequest; use crate::registry::AsRule; #[violation] @@ -61,8 +62,7 @@ pub(crate) fn use_pep585_annotation( // Imported type, like `collections.deque`. diagnostic.try_set_fix(|| { let (import_edit, binding) = checker.importer.get_or_import_symbol( - module, - member, + &ImportRequest::import_from(module, member), expr.start(), checker.semantic_model(), )?; diff --git a/crates/ruff/src/rules/pyupgrade/snapshots/ruff__rules__pyupgrade__tests__UP006_0.py.snap b/crates/ruff/src/rules/pyupgrade/snapshots/ruff__rules__pyupgrade__tests__UP006_0.py.snap index c0d954282b..2e0f6ee588 100644 --- a/crates/ruff/src/rules/pyupgrade/snapshots/ruff__rules__pyupgrade__tests__UP006_0.py.snap +++ b/crates/ruff/src/rules/pyupgrade/snapshots/ruff__rules__pyupgrade__tests__UP006_0.py.snap @@ -243,7 +243,7 @@ UP006_0.py:61:10: UP006 [*] Use `collections.deque` instead of `typing.Deque` fo 20 20 | 21 21 | 22 22 | from typing import List as IList - 23 |+import collections + 23 |+from collections import deque 23 24 | 24 25 | 25 26 | def f(x: IList[str]) -> None: @@ -252,7 +252,7 @@ UP006_0.py:61:10: UP006 [*] Use `collections.deque` instead of `typing.Deque` fo 59 60 | 60 61 | 61 |-def f(x: typing.Deque[str]) -> None: - 62 |+def f(x: collections.deque[str]) -> None: + 62 |+def f(x: deque[str]) -> None: 62 63 | ... 63 64 | 64 65 | @@ -269,7 +269,7 @@ UP006_0.py:65:10: UP006 [*] Use `collections.defaultdict` instead of `typing.Def 20 20 | 21 21 | 22 22 | from typing import List as IList - 23 |+import collections + 23 |+from collections import defaultdict 23 24 | 24 25 | 25 26 | def f(x: IList[str]) -> None: @@ -278,7 +278,7 @@ UP006_0.py:65:10: UP006 [*] Use `collections.defaultdict` instead of `typing.Def 63 64 | 64 65 | 65 |-def f(x: typing.DefaultDict[str, str]) -> None: - 66 |+def f(x: collections.defaultdict[str, str]) -> None: + 66 |+def f(x: defaultdict[str, str]) -> None: 66 67 | ... diff --git a/crates/ruff_python_ast/src/imports.rs b/crates/ruff_python_ast/src/imports.rs index b9d24d74dc..098f75c0d6 100644 --- a/crates/ruff_python_ast/src/imports.rs +++ b/crates/ruff_python_ast/src/imports.rs @@ -31,6 +31,7 @@ pub struct Alias<'a> { } impl<'a> Import<'a> { + /// Creates a new `Import` to import the specified module. pub fn module(name: &'a str) -> Self { Self { name: Alias { @@ -41,6 +42,20 @@ impl<'a> Import<'a> { } } +impl<'a> ImportFrom<'a> { + /// Creates a new `ImportFrom` to import a member from the specified module. + pub fn member(module: &'a str, name: &'a str) -> Self { + Self { + module: Some(module), + name: Alias { + name, + as_name: None, + }, + level: None, + } + } +} + impl std::fmt::Display for AnyImport<'_> { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self {