Don't skip over imports and other nodes containing nested statements in import collector (#13521)

This commit is contained in:
Micha Reiser 2024-09-26 13:57:05 +02:00 committed by GitHub
parent 9442cd8fae
commit ff2d214e11
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 97 additions and 33 deletions

View file

@ -367,3 +367,58 @@ fn wildcard() -> Result<()> {
Ok(())
}
#[test]
fn nested_imports() -> Result<()> {
let tempdir = TempDir::new()?;
let root = ChildPath::new(tempdir.path());
root.child("ruff").child("__init__.py").write_str("")?;
root.child("ruff")
.child("a.py")
.write_str(indoc::indoc! {r#"
match x:
case 1:
import ruff.b
"#})?;
root.child("ruff")
.child("b.py")
.write_str(indoc::indoc! {r#"
try:
import ruff.c
except ImportError as e:
import ruff.d
"#})?;
root.child("ruff")
.child("c.py")
.write_str(indoc::indoc! {r#"def c(): ..."#})?;
root.child("ruff")
.child("d.py")
.write_str(indoc::indoc! {r#"def d(): ..."#})?;
insta::with_settings!({
filters => INSTA_FILTERS.to_vec(),
}, {
assert_cmd_snapshot!(command().current_dir(&root), @r#"
success: true
exit_code: 0
----- stdout -----
{
"ruff/__init__.py": [],
"ruff/a.py": [
"ruff/b.py"
],
"ruff/b.py": [
"ruff/c.py",
"ruff/d.py"
],
"ruff/c.py": [],
"ruff/d.py": []
}
----- stderr -----
"#);
});
Ok(())
}

View file

@ -1,8 +1,8 @@
use red_knot_python_semantic::ModuleName;
use ruff_python_ast::visitor::source_order::{
walk_expr, walk_module, walk_stmt, SourceOrderVisitor, TraversalSignal,
walk_expr, walk_module, walk_stmt, SourceOrderVisitor,
};
use ruff_python_ast::{self as ast, AnyNodeRef, Expr, Mod, Stmt};
use ruff_python_ast::{self as ast, Expr, Mod, Stmt};
/// Collect all imports for a given Python file.
#[derive(Default, Debug)]
@ -32,28 +32,6 @@ impl<'a> Collector<'a> {
}
impl<'ast> SourceOrderVisitor<'ast> for Collector<'_> {
fn enter_node(&mut self, node: AnyNodeRef<'ast>) -> TraversalSignal {
// If string detection is enabled, we have to visit everything. Otherwise, we should only
// visit compounds statements, which can contain import statements.
if self.string_imports
|| matches!(
node,
AnyNodeRef::ModModule(_)
| AnyNodeRef::StmtFunctionDef(_)
| AnyNodeRef::StmtClassDef(_)
| AnyNodeRef::StmtWhile(_)
| AnyNodeRef::StmtFor(_)
| AnyNodeRef::StmtWith(_)
| AnyNodeRef::StmtIf(_)
| AnyNodeRef::StmtTry(_)
)
{
TraversalSignal::Traverse
} else {
TraversalSignal::Skip
}
}
fn visit_stmt(&mut self, stmt: &'ast Stmt) {
match stmt {
Stmt::ImportFrom(ast::StmtImportFrom {
@ -107,9 +85,38 @@ impl<'ast> SourceOrderVisitor<'ast> for Collector<'_> {
}
}
}
_ => {
Stmt::FunctionDef(_)
| Stmt::ClassDef(_)
| Stmt::While(_)
| Stmt::If(_)
| Stmt::With(_)
| Stmt::Match(_)
| Stmt::Try(_)
| Stmt::For(_) => {
// Always traverse into compound statements.
walk_stmt(self, stmt);
}
Stmt::Return(_)
| Stmt::Delete(_)
| Stmt::Assign(_)
| Stmt::AugAssign(_)
| Stmt::AnnAssign(_)
| Stmt::TypeAlias(_)
| Stmt::Raise(_)
| Stmt::Assert(_)
| Stmt::Global(_)
| Stmt::Nonlocal(_)
| Stmt::Expr(_)
| Stmt::Pass(_)
| Stmt::Break(_)
| Stmt::Continue(_)
| Stmt::IpyEscapeCommand(_) => {
// Only traverse simple statements when string imports is enabled.
if self.string_imports {
walk_stmt(self, stmt);
}
}
}
}

View file

@ -1,13 +1,15 @@
use std::collections::{BTreeMap, BTreeSet};
use anyhow::Result;
use ruff_db::system::{SystemPath, SystemPathBuf};
use ruff_python_ast::helpers::to_module_path;
use ruff_python_parser::{parse, Mode};
use crate::collector::Collector;
pub use crate::db::ModuleDb;
use crate::resolver::Resolver;
pub use crate::settings::{AnalyzeSettings, Direction};
use anyhow::Result;
use ruff_db::system::{SystemPath, SystemPathBuf};
use ruff_python_ast::helpers::to_module_path;
use ruff_python_parser::{parse, Mode};
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, BTreeSet};
mod collector;
mod db;
@ -15,7 +17,7 @@ mod resolver;
mod settings;
#[derive(Debug, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ModuleImports(BTreeSet<SystemPathBuf>);
impl ModuleImports {
@ -90,7 +92,7 @@ impl ModuleImports {
}
#[derive(Debug, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ImportMap(BTreeMap<SystemPathBuf, ModuleImports>);
impl ImportMap {