diff --git a/crates/ruff/resources/test/fixtures/flake8_tidy_imports/TID253.py b/crates/ruff/resources/test/fixtures/flake8_tidy_imports/TID253.py new file mode 100644 index 0000000000..d0d3d14845 --- /dev/null +++ b/crates/ruff/resources/test/fixtures/flake8_tidy_imports/TID253.py @@ -0,0 +1,31 @@ +## Banned modules ## +import torch + +from torch import * + +from tensorflow import a, b, c + +import torch as torch_wearing_a_trenchcoat + +# this should count as module level +x = 1; import tensorflow + +# banning a module also bans any submodules +import torch.foo.bar + +from tensorflow.foo import bar + +from torch.foo.bar import * + +# unlike TID251, inline imports are *not* banned +def my_cool_function(): + import tensorflow.foo.bar + +def another_cool_function(): + from torch.foo import bar + +def import_alias(): + from torch.foo import bar + +if TYPE_CHECKING: + import torch diff --git a/crates/ruff/src/checkers/ast/analyze/statement.rs b/crates/ruff/src/checkers/ast/analyze/statement.rs index a5c56291ee..81a36c87d6 100644 --- a/crates/ruff/src/checkers/ast/analyze/statement.rs +++ b/crates/ruff/src/checkers/ast/analyze/statement.rs @@ -566,6 +566,15 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) { alias, ); } + + if checker.enabled(Rule::BannedModuleLevelImports) { + flake8_tidy_imports::rules::name_or_parent_is_banned_at_module_level( + checker, + &alias.name, + alias.range(), + ); + } + if !checker.source_type.is_stub() { if checker.enabled(Rule::UselessImportAlias) { pylint::rules::useless_import_alias(checker, alias); @@ -734,6 +743,28 @@ pub(crate) fn statement(stmt: &Stmt, checker: &mut Checker) { } } } + if checker.enabled(Rule::BannedModuleLevelImports) { + if let Some(module) = + helpers::resolve_imported_module_path(level, module, checker.module_path) + { + flake8_tidy_imports::rules::name_or_parent_is_banned_at_module_level( + checker, + &module, + stmt.range(), + ); + + for alias in names { + if &alias.name == "*" { + continue; + } + flake8_tidy_imports::rules::name_is_banned_at_module_level( + checker, + &format!("{module}.{}", alias.name), + alias.range(), + ); + } + } + } if checker.enabled(Rule::PytestIncorrectPytestImport) { if let Some(diagnostic) = flake8_pytest_style::rules::import_from(stmt, module, level) diff --git a/crates/ruff/src/codes.rs b/crates/ruff/src/codes.rs index d7bb7cf7a5..ba86706c6a 100644 --- a/crates/ruff/src/codes.rs +++ b/crates/ruff/src/codes.rs @@ -311,6 +311,7 @@ pub fn code_to_rule(linter: Linter, code: &str) -> Option<(RuleGroup, Rule)> { // flake8-tidy-imports (Flake8TidyImports, "251") => (RuleGroup::Unspecified, rules::flake8_tidy_imports::rules::BannedApi), (Flake8TidyImports, "252") => (RuleGroup::Unspecified, rules::flake8_tidy_imports::rules::RelativeImports), + (Flake8TidyImports, "253") => (RuleGroup::Unspecified, rules::flake8_tidy_imports::rules::BannedModuleLevelImports), // flake8-return (Flake8Return, "501") => (RuleGroup::Unspecified, rules::flake8_return::rules::UnnecessaryReturnNone), diff --git a/crates/ruff/src/rules/flake8_tidy_imports/mod.rs b/crates/ruff/src/rules/flake8_tidy_imports/mod.rs index 8686f3971c..d136abbb1d 100644 --- a/crates/ruff/src/rules/flake8_tidy_imports/mod.rs +++ b/crates/ruff/src/rules/flake8_tidy_imports/mod.rs @@ -124,4 +124,23 @@ mod tests { assert_messages!(diagnostics); Ok(()) } + + #[test] + fn banned_module_level_imports() -> Result<()> { + let diagnostics = test_path( + Path::new("flake8_tidy_imports/TID253.py"), + &Settings { + flake8_tidy_imports: flake8_tidy_imports::settings::Settings { + banned_module_level_imports: vec![ + "torch".to_string(), + "tensorflow".to_string(), + ], + ..Default::default() + }, + ..Settings::for_rules(vec![Rule::BannedModuleLevelImports]) + }, + )?; + assert_messages!(diagnostics); + Ok(()) + } } diff --git a/crates/ruff/src/rules/flake8_tidy_imports/options.rs b/crates/ruff/src/rules/flake8_tidy_imports/options.rs index 6f3eb99fcb..a2bc84c300 100644 --- a/crates/ruff/src/rules/flake8_tidy_imports/options.rs +++ b/crates/ruff/src/rules/flake8_tidy_imports/options.rs @@ -41,6 +41,19 @@ pub struct Options { /// Note that this rule is only meant to flag accidental uses, /// and can be circumvented via `eval` or `importlib`. pub banned_api: Option>, + #[option( + default = r#"[]"#, + value_type = r#"list[str]"#, + example = r#" + # Ban certain modules from being imported at module level, instead requiring + # that they're imported lazily (e.g., within a function definition). + banned-module-level-imports = ["torch", "tensorflow"] + "# + )] + /// List of specific modules that may not be imported at module level, and should instead be + /// imported lazily (e.g., within a function definition, or an `if TYPE_CHECKING:` + /// block, or some other nested context). + pub banned_module_level_imports: Option>, } impl From for Settings { @@ -48,6 +61,7 @@ impl From for Settings { Self { ban_relative_imports: options.ban_relative_imports.unwrap_or(Strictness::Parents), banned_api: options.banned_api.unwrap_or_default(), + banned_module_level_imports: options.banned_module_level_imports.unwrap_or_default(), } } } @@ -57,6 +71,7 @@ impl From for Options { Self { ban_relative_imports: Some(settings.ban_relative_imports), banned_api: Some(settings.banned_api), + banned_module_level_imports: Some(settings.banned_module_level_imports), } } } diff --git a/crates/ruff/src/rules/flake8_tidy_imports/rules/banned_module_level_imports.rs b/crates/ruff/src/rules/flake8_tidy_imports/rules/banned_module_level_imports.rs new file mode 100644 index 0000000000..0efa1a3164 --- /dev/null +++ b/crates/ruff/src/rules/flake8_tidy_imports/rules/banned_module_level_imports.rs @@ -0,0 +1,111 @@ +use ruff_diagnostics::{Diagnostic, Violation}; +use ruff_macros::{derive_message_formats, violation}; +use ruff_text_size::TextRange; + +use crate::checkers::ast::Checker; + +/// ## What it does +/// Checks for module-level imports that should instead be imported lazily +/// (e.g., within a function definition, or an `if TYPE_CHECKING:` block, or +/// some other nested context). +/// +/// ## Why is this bad? +/// Some modules are expensive to import. For example, importing `torch` or +/// `tensorflow` can introduce a noticeable delay in the startup time of a +/// Python program. +/// +/// In such cases, you may want to enforce that the module is imported lazily +/// as needed, rather than at the top of the file. This could involve inlining +/// the import into the function that uses it, rather than importing it +/// unconditionally, to ensure that the module is only imported when necessary. +/// +/// ## Example +/// ```python +/// import tensorflow as tf +/// +/// +/// def show_version(): +/// print(tf.__version__) +/// ``` +/// +/// Use instead: +/// ```python +/// def show_version(): +/// import tensorflow as tf +/// +/// print(tf.__version__) +/// ``` +/// +/// ## Options +/// - `flake8-tidy-imports.banned-module-level-imports` +#[violation] +pub struct BannedModuleLevelImports { + name: String, +} + +impl Violation for BannedModuleLevelImports { + #[derive_message_formats] + fn message(&self) -> String { + let BannedModuleLevelImports { name } = self; + format!("`{name}` is banned at the module level") + } +} + +/// TID253 +pub(crate) fn name_is_banned_at_module_level( + checker: &mut Checker, + name: &str, + text_range: TextRange, +) { + banned_at_module_level_with_policy(checker, name, text_range, &NameMatchPolicy::ExactOnly); +} + +/// TID253 +pub(crate) fn name_or_parent_is_banned_at_module_level( + checker: &mut Checker, + name: &str, + text_range: TextRange, +) { + banned_at_module_level_with_policy(checker, name, text_range, &NameMatchPolicy::ExactOrParents); +} + +#[derive(Debug)] +enum NameMatchPolicy { + /// Only match an exact module name (e.g., given `import foo.bar`, only match `foo.bar`). + ExactOnly, + /// Match an exact module name or any of its parents (e.g., given `import foo.bar`, match + /// `foo.bar` or `foo`). + ExactOrParents, +} + +fn banned_at_module_level_with_policy( + checker: &mut Checker, + name: &str, + text_range: TextRange, + policy: &NameMatchPolicy, +) { + if !checker.semantic().at_top_level() { + return; + } + let banned_module_level_imports = &checker + .settings + .flake8_tidy_imports + .banned_module_level_imports; + for banned_module_name in banned_module_level_imports { + let name_is_banned = match policy { + NameMatchPolicy::ExactOnly => name == banned_module_name, + NameMatchPolicy::ExactOrParents => { + name == banned_module_name || name.starts_with(&format!("{banned_module_name}.")) + } + }; + if name_is_banned { + checker.diagnostics.push(Diagnostic::new( + BannedModuleLevelImports { + name: banned_module_name.to_string(), + }, + text_range, + )); + return; + } + } +} diff --git a/crates/ruff/src/rules/flake8_tidy_imports/rules/mod.rs b/crates/ruff/src/rules/flake8_tidy_imports/rules/mod.rs index 660116d718..a9c8e631d9 100644 --- a/crates/ruff/src/rules/flake8_tidy_imports/rules/mod.rs +++ b/crates/ruff/src/rules/flake8_tidy_imports/rules/mod.rs @@ -1,5 +1,7 @@ pub(crate) use banned_api::*; +pub(crate) use banned_module_level_imports::*; pub(crate) use relative_imports::*; mod banned_api; +mod banned_module_level_imports; mod relative_imports; diff --git a/crates/ruff/src/rules/flake8_tidy_imports/settings.rs b/crates/ruff/src/rules/flake8_tidy_imports/settings.rs index 90b2843280..a1267fbc9b 100644 --- a/crates/ruff/src/rules/flake8_tidy_imports/settings.rs +++ b/crates/ruff/src/rules/flake8_tidy_imports/settings.rs @@ -26,4 +26,5 @@ pub enum Strictness { pub struct Settings { pub ban_relative_imports: Strictness, pub banned_api: FxHashMap, + pub banned_module_level_imports: Vec, } diff --git a/crates/ruff/src/rules/flake8_tidy_imports/snapshots/ruff__rules__flake8_tidy_imports__tests__banned_module_level_imports.snap b/crates/ruff/src/rules/flake8_tidy_imports/snapshots/ruff__rules__flake8_tidy_imports__tests__banned_module_level_imports.snap new file mode 100644 index 0000000000..40bb5c9f58 --- /dev/null +++ b/crates/ruff/src/rules/flake8_tidy_imports/snapshots/ruff__rules__flake8_tidy_imports__tests__banned_module_level_imports.snap @@ -0,0 +1,81 @@ +--- +source: crates/ruff/src/rules/flake8_tidy_imports/mod.rs +--- +TID253.py:2:8: TID253 `torch` is banned at the module level + | +1 | ## Banned modules ## +2 | import torch + | ^^^^^ TID253 +3 | +4 | from torch import * + | + +TID253.py:4:1: TID253 `torch` is banned at the module level + | +2 | import torch +3 | +4 | from torch import * + | ^^^^^^^^^^^^^^^^^^^ TID253 +5 | +6 | from tensorflow import a, b, c + | + +TID253.py:6:1: TID253 `tensorflow` is banned at the module level + | +4 | from torch import * +5 | +6 | from tensorflow import a, b, c + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ TID253 +7 | +8 | import torch as torch_wearing_a_trenchcoat + | + +TID253.py:8:8: TID253 `torch` is banned at the module level + | + 6 | from tensorflow import a, b, c + 7 | + 8 | import torch as torch_wearing_a_trenchcoat + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ TID253 + 9 | +10 | # this should count as module level + | + +TID253.py:11:15: TID253 `tensorflow` is banned at the module level + | +10 | # this should count as module level +11 | x = 1; import tensorflow + | ^^^^^^^^^^ TID253 +12 | +13 | # banning a module also bans any submodules + | + +TID253.py:14:8: TID253 `torch` is banned at the module level + | +13 | # banning a module also bans any submodules +14 | import torch.foo.bar + | ^^^^^^^^^^^^^ TID253 +15 | +16 | from tensorflow.foo import bar + | + +TID253.py:16:1: TID253 `tensorflow` is banned at the module level + | +14 | import torch.foo.bar +15 | +16 | from tensorflow.foo import bar + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ TID253 +17 | +18 | from torch.foo.bar import * + | + +TID253.py:18:1: TID253 `torch` is banned at the module level + | +16 | from tensorflow.foo import bar +17 | +18 | from torch.foo.bar import * + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ TID253 +19 | +20 | # unlike TID251, inline imports are *not* banned + | + + diff --git a/ruff.schema.json b/ruff.schema.json index 09cb719a63..95cda80152 100644 --- a/ruff.schema.json +++ b/ruff.schema.json @@ -1062,6 +1062,16 @@ "additionalProperties": { "$ref": "#/definitions/ApiBan" } + }, + "banned-module-level-imports": { + "description": "List of specific modules that may not be imported at module level, and should instead be imported lazily (e.g., within a function definition, or an `if TYPE_CHECKING:` block, or some other nested context).", + "type": [ + "array", + "null" + ], + "items": { + "type": "string" + } } }, "additionalProperties": false @@ -2610,6 +2620,7 @@ "TID25", "TID251", "TID252", + "TID253", "TRY", "TRY0", "TRY00",