mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-30 22:01:18 +00:00
[red-knot] resolve base class types (#11178)
## Summary Resolve base class types, as long as they are simple names. ## Test Plan cargo test
This commit is contained in:
parent
04a922866a
commit
ce030a467f
3 changed files with 129 additions and 35 deletions
|
@ -95,7 +95,7 @@ impl Symbol {
|
||||||
// TODO storing TypedNodeKey for definitions means we have to search to find them again in the AST;
|
// TODO storing TypedNodeKey for definitions means we have to search to find them again in the AST;
|
||||||
// this is at best O(log n). If looking up definitions is a bottleneck we should look for
|
// this is at best O(log n). If looking up definitions is a bottleneck we should look for
|
||||||
// alternatives here.
|
// alternatives here.
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub(crate) enum Definition {
|
pub(crate) enum Definition {
|
||||||
// For the import cases, we don't need reference to any arbitrary AST subtrees (annotations,
|
// For the import cases, we don't need reference to any arbitrary AST subtrees (annotations,
|
||||||
// RHS), and referencing just the import statement node is imprecise (a single import statement
|
// RHS), and referencing just the import statement node is imprecise (a single import statement
|
||||||
|
@ -110,12 +110,12 @@ pub(crate) enum Definition {
|
||||||
// TODO with statements, except handlers, function args...
|
// TODO with statements, except handlers, function args...
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub(crate) struct ImportDefinition {
|
pub(crate) struct ImportDefinition {
|
||||||
pub(crate) module: ModuleName,
|
pub(crate) module: ModuleName,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub(crate) struct ImportFromDefinition {
|
pub(crate) struct ImportFromDefinition {
|
||||||
pub(crate) module: Option<ModuleName>,
|
pub(crate) module: Option<ModuleName>,
|
||||||
pub(crate) name: Name,
|
pub(crate) name: Name,
|
||||||
|
@ -160,6 +160,7 @@ impl SymbolTable {
|
||||||
let mut builder = SymbolTableBuilder {
|
let mut builder = SymbolTableBuilder {
|
||||||
table: SymbolTable::new(),
|
table: SymbolTable::new(),
|
||||||
scopes: vec![root_scope_id],
|
scopes: vec![root_scope_id],
|
||||||
|
current_definition: None,
|
||||||
};
|
};
|
||||||
builder.visit_body(&module.body);
|
builder.visit_body(&module.body);
|
||||||
builder.table
|
builder.table
|
||||||
|
@ -386,6 +387,8 @@ where
|
||||||
struct SymbolTableBuilder {
|
struct SymbolTableBuilder {
|
||||||
table: SymbolTable,
|
table: SymbolTable,
|
||||||
scopes: Vec<ScopeId>,
|
scopes: Vec<ScopeId>,
|
||||||
|
/// the definition whose target(s) we are currently walking
|
||||||
|
current_definition: Option<Definition>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SymbolTableBuilder {
|
impl SymbolTableBuilder {
|
||||||
|
@ -448,8 +451,13 @@ impl SymbolTableBuilder {
|
||||||
|
|
||||||
impl PreorderVisitor<'_> for SymbolTableBuilder {
|
impl PreorderVisitor<'_> for SymbolTableBuilder {
|
||||||
fn visit_expr(&mut self, expr: &ast::Expr) {
|
fn visit_expr(&mut self, expr: &ast::Expr) {
|
||||||
if let ast::Expr::Name(ast::ExprName { id, .. }) = expr {
|
if let ast::Expr::Name(ast::ExprName { id, ctx, .. }) = expr {
|
||||||
self.add_symbol(id);
|
self.add_symbol(id);
|
||||||
|
if matches!(ctx, ast::ExprContext::Store | ast::ExprContext::Del) {
|
||||||
|
if let Some(curdef) = self.current_definition.clone() {
|
||||||
|
self.add_symbol_with_def(id, curdef);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
ast::visitor::preorder::walk_expr(self, expr);
|
ast::visitor::preorder::walk_expr(self, expr);
|
||||||
}
|
}
|
||||||
|
@ -532,6 +540,13 @@ impl PreorderVisitor<'_> for SymbolTableBuilder {
|
||||||
|
|
||||||
self.table.dependencies.push(dependency);
|
self.table.dependencies.push(dependency);
|
||||||
}
|
}
|
||||||
|
ast::Stmt::Assign(node) => {
|
||||||
|
debug_assert!(self.current_definition.is_none());
|
||||||
|
self.current_definition =
|
||||||
|
Some(Definition::Assignment(TypedNodeKey::from_node(node)));
|
||||||
|
ast::visitor::preorder::walk_stmt(self, stmt);
|
||||||
|
self.current_definition = None;
|
||||||
|
}
|
||||||
_ => {
|
_ => {
|
||||||
ast::visitor::preorder::walk_stmt(self, stmt);
|
ast::visitor::preorder::walk_stmt(self, stmt);
|
||||||
}
|
}
|
||||||
|
@ -649,6 +664,19 @@ mod tests {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn assign() {
|
||||||
|
let parsed = parse("x = foo");
|
||||||
|
let table = SymbolTable::from_ast(parsed.ast());
|
||||||
|
assert_eq!(names(table.root_symbols()), vec!["foo", "x"]);
|
||||||
|
assert_eq!(
|
||||||
|
table
|
||||||
|
.definitions(table.root_symbol_id_by_name("x").unwrap())
|
||||||
|
.len(),
|
||||||
|
1
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn class_scope() {
|
fn class_scope() {
|
||||||
let parsed = parse(
|
let parsed = parse(
|
||||||
|
|
|
@ -53,12 +53,6 @@ impl From<FunctionTypeId> for Type {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<ClassTypeId> for Type {
|
|
||||||
fn from(id: ClassTypeId) -> Self {
|
|
||||||
Type::Class(id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<UnionTypeId> for Type {
|
impl From<UnionTypeId> for Type {
|
||||||
fn from(id: UnionTypeId) -> Self {
|
fn from(id: UnionTypeId) -> Self {
|
||||||
Type::Union(id)
|
Type::Union(id)
|
||||||
|
@ -129,8 +123,8 @@ impl TypeStore {
|
||||||
self.add_or_get_module(file_id).add_function(name)
|
self.add_or_get_module(file_id).add_function(name)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_class(&self, file_id: FileId, name: &str) -> ClassTypeId {
|
fn add_class(&self, file_id: FileId, name: &str, bases: Vec<Type>) -> ClassTypeId {
|
||||||
self.add_or_get_module(file_id).add_class(name)
|
self.add_or_get_module(file_id).add_class(name, bases)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_union(&mut self, file_id: FileId, elems: &[Type]) -> UnionTypeId {
|
fn add_union(&mut self, file_id: FileId, elems: &[Type]) -> UnionTypeId {
|
||||||
|
@ -322,9 +316,11 @@ impl ModuleTypeStore {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_class(&mut self, name: &str) -> ClassTypeId {
|
fn add_class(&mut self, name: &str, bases: Vec<Type>) -> ClassTypeId {
|
||||||
let class_id = self.classes.push(ClassType {
|
let class_id = self.classes.push(ClassType {
|
||||||
name: Name::new(name),
|
name: Name::new(name),
|
||||||
|
// TODO: if no bases are given, that should imply [object]
|
||||||
|
bases,
|
||||||
});
|
});
|
||||||
ClassTypeId {
|
ClassTypeId {
|
||||||
file_id: self.file_id,
|
file_id: self.file_id,
|
||||||
|
@ -408,12 +404,17 @@ impl std::fmt::Display for DisplayType<'_> {
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub(crate) struct ClassType {
|
pub(crate) struct ClassType {
|
||||||
name: Name,
|
name: Name,
|
||||||
|
bases: Vec<Type>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ClassType {
|
impl ClassType {
|
||||||
fn name(&self) -> &str {
|
fn name(&self) -> &str {
|
||||||
self.name.as_str()
|
self.name.as_str()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn bases(&self) -> &[Type] {
|
||||||
|
self.bases.as_slice()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
@ -497,7 +498,7 @@ mod tests {
|
||||||
let store = TypeStore::default();
|
let store = TypeStore::default();
|
||||||
let files = Files::default();
|
let files = Files::default();
|
||||||
let file_id = files.intern(Path::new("/foo"));
|
let file_id = files.intern(Path::new("/foo"));
|
||||||
let id = store.add_class(file_id, "C");
|
let id = store.add_class(file_id, "C", Vec::new());
|
||||||
assert_eq!(store.get_class(id).name(), "C");
|
assert_eq!(store.get_class(id).name(), "C");
|
||||||
let inst = Type::Instance(id);
|
let inst = Type::Instance(id);
|
||||||
assert_eq!(format!("{}", inst.display(&store)), "C");
|
assert_eq!(format!("{}", inst.display(&store)), "C");
|
||||||
|
@ -519,8 +520,8 @@ mod tests {
|
||||||
let mut store = TypeStore::default();
|
let mut store = TypeStore::default();
|
||||||
let files = Files::default();
|
let files = Files::default();
|
||||||
let file_id = files.intern(Path::new("/foo"));
|
let file_id = files.intern(Path::new("/foo"));
|
||||||
let c1 = store.add_class(file_id, "C1");
|
let c1 = store.add_class(file_id, "C1", Vec::new());
|
||||||
let c2 = store.add_class(file_id, "C2");
|
let c2 = store.add_class(file_id, "C2", Vec::new());
|
||||||
let elems = vec![Type::Instance(c1), Type::Instance(c2)];
|
let elems = vec![Type::Instance(c1), Type::Instance(c2)];
|
||||||
let id = store.add_union(file_id, &elems);
|
let id = store.add_union(file_id, &elems);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -536,9 +537,9 @@ mod tests {
|
||||||
let mut store = TypeStore::default();
|
let mut store = TypeStore::default();
|
||||||
let files = Files::default();
|
let files = Files::default();
|
||||||
let file_id = files.intern(Path::new("/foo"));
|
let file_id = files.intern(Path::new("/foo"));
|
||||||
let c1 = store.add_class(file_id, "C1");
|
let c1 = store.add_class(file_id, "C1", Vec::new());
|
||||||
let c2 = store.add_class(file_id, "C2");
|
let c2 = store.add_class(file_id, "C2", Vec::new());
|
||||||
let c3 = store.add_class(file_id, "C3");
|
let c3 = store.add_class(file_id, "C3", Vec::new());
|
||||||
let pos = vec![Type::Instance(c1), Type::Instance(c2)];
|
let pos = vec![Type::Instance(c1), Type::Instance(c2)];
|
||||||
let neg = vec![Type::Instance(c3)];
|
let neg = vec![Type::Instance(c3)];
|
||||||
let id = store.add_intersection(file_id, &pos, &neg);
|
let id = store.add_intersection(file_id, &pos, &neg);
|
||||||
|
|
|
@ -7,6 +7,7 @@ use crate::module::ModuleName;
|
||||||
use crate::symbols::{Definition, ImportFromDefinition, SymbolId};
|
use crate::symbols::{Definition, ImportFromDefinition, SymbolId};
|
||||||
use crate::types::Type;
|
use crate::types::Type;
|
||||||
use crate::FileId;
|
use crate::FileId;
|
||||||
|
use ruff_python_ast as ast;
|
||||||
|
|
||||||
// FIXME: Figure out proper dead-lock free synchronisation now that this takes `&db` instead of `&mut db`.
|
// FIXME: Figure out proper dead-lock free synchronisation now that this takes `&db` instead of `&mut db`.
|
||||||
#[tracing::instrument(level = "trace", skip(db))]
|
#[tracing::instrument(level = "trace", skip(db))]
|
||||||
|
@ -58,8 +59,14 @@ where
|
||||||
let ast = parsed.ast();
|
let ast = parsed.ast();
|
||||||
let node = node_key.resolve_unwrap(ast.as_any_node_ref());
|
let node = node_key.resolve_unwrap(ast.as_any_node_ref());
|
||||||
|
|
||||||
|
let bases: Vec<_> = node
|
||||||
|
.bases()
|
||||||
|
.iter()
|
||||||
|
.map(|base_expr| infer_expr_type(db, file_id, base_expr))
|
||||||
|
.collect();
|
||||||
|
|
||||||
let store = &db.jar().type_store;
|
let store = &db.jar().type_store;
|
||||||
let ty: Type = store.add_class(file_id, &node.name.id).into();
|
let ty = Type::Class(store.add_class(file_id, &node.name.id, bases));
|
||||||
store.cache_node_type(file_id, *node_key.erased(), ty);
|
store.cache_node_type(file_id, *node_key.erased(), ty);
|
||||||
ty
|
ty
|
||||||
}),
|
}),
|
||||||
|
@ -79,6 +86,13 @@ where
|
||||||
store.cache_node_type(file_id, *node_key.erased(), ty);
|
store.cache_node_type(file_id, *node_key.erased(), ty);
|
||||||
ty
|
ty
|
||||||
}),
|
}),
|
||||||
|
Definition::Assignment(node_key) => {
|
||||||
|
let parsed = db.parse(file_id);
|
||||||
|
let ast = parsed.ast();
|
||||||
|
let node = node_key.resolve_unwrap(ast.as_any_node_ref());
|
||||||
|
// TODO handle unpacking assignment correctly
|
||||||
|
infer_expr_type(db, file_id, &node.value)
|
||||||
|
}
|
||||||
_ => todo!("other kinds of definitions"),
|
_ => todo!("other kinds of definitions"),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -89,6 +103,24 @@ where
|
||||||
ty
|
ty
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn infer_expr_type<Db>(db: &Db, file_id: FileId, expr: &ast::Expr) -> Type
|
||||||
|
where
|
||||||
|
Db: SemanticDb + HasJar<SemanticJar>,
|
||||||
|
{
|
||||||
|
// TODO cache the resolution of the type on the node
|
||||||
|
let symbols = db.symbol_table(file_id);
|
||||||
|
match expr {
|
||||||
|
ast::Expr::Name(name) => {
|
||||||
|
if let Some(symbol_id) = symbols.root_symbol_id_by_name(&name.id) {
|
||||||
|
db.infer_symbol_type(file_id, symbol_id)
|
||||||
|
} else {
|
||||||
|
Type::Unknown
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => todo!("full expression type resolution"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::db::tests::TestDb;
|
use crate::db::tests::TestDb;
|
||||||
|
@ -123,32 +155,65 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn follow_import_to_class() -> std::io::Result<()> {
|
fn follow_import_to_class() -> std::io::Result<()> {
|
||||||
let TestCase {
|
let case = create_test()?;
|
||||||
src,
|
let db = &case.db;
|
||||||
db,
|
|
||||||
temp_dir: _temp_dir,
|
|
||||||
} = create_test()?;
|
|
||||||
|
|
||||||
let a_path = src.path().join("a.py");
|
let a_path = case.src.path().join("a.py");
|
||||||
let b_path = src.path().join("b.py");
|
let b_path = case.src.path().join("b.py");
|
||||||
std::fs::write(a_path, "from b import C as D")?;
|
std::fs::write(a_path, "from b import C as D; E = D")?;
|
||||||
std::fs::write(b_path, "class C: pass")?;
|
std::fs::write(b_path, "class C: pass")?;
|
||||||
let a_file = db
|
let a_file = db
|
||||||
.resolve_module(ModuleName::new("a"))
|
.resolve_module(ModuleName::new("a"))
|
||||||
.expect("module should be found")
|
.expect("module should be found")
|
||||||
.path(&db)
|
.path(db)
|
||||||
.file();
|
.file();
|
||||||
let a_syms = db.symbol_table(a_file);
|
let a_syms = db.symbol_table(a_file);
|
||||||
let d_sym = a_syms
|
let e_sym = a_syms
|
||||||
.root_symbol_id_by_name("D")
|
.root_symbol_id_by_name("E")
|
||||||
.expect("D symbol should be found");
|
.expect("E symbol should be found");
|
||||||
|
|
||||||
let ty = db.infer_symbol_type(a_file, d_sym);
|
let ty = db.infer_symbol_type(a_file, e_sym);
|
||||||
|
|
||||||
let jar = HasJar::<SemanticJar>::jar(&db);
|
|
||||||
|
|
||||||
|
let jar = HasJar::<SemanticJar>::jar(db);
|
||||||
assert!(matches!(ty, Type::Class(_)));
|
assert!(matches!(ty, Type::Class(_)));
|
||||||
assert_eq!(format!("{}", ty.display(&jar.type_store)), "Literal[C]");
|
assert_eq!(format!("{}", ty.display(&jar.type_store)), "Literal[C]");
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resolve_base_class_by_name() -> std::io::Result<()> {
|
||||||
|
let case = create_test()?;
|
||||||
|
let db = &case.db;
|
||||||
|
|
||||||
|
let path = case.src.path().join("mod.py");
|
||||||
|
std::fs::write(path, "class Base: pass\nclass Sub(Base): pass")?;
|
||||||
|
let file = db
|
||||||
|
.resolve_module(ModuleName::new("mod"))
|
||||||
|
.expect("module should be found")
|
||||||
|
.path(db)
|
||||||
|
.file();
|
||||||
|
let syms = db.symbol_table(file);
|
||||||
|
let sym = syms
|
||||||
|
.root_symbol_id_by_name("Sub")
|
||||||
|
.expect("Sub symbol should be found");
|
||||||
|
|
||||||
|
let ty = db.infer_symbol_type(file, sym);
|
||||||
|
|
||||||
|
let Type::Class(class_id) = ty else {
|
||||||
|
panic!("Sub is not a Class")
|
||||||
|
};
|
||||||
|
let jar = HasJar::<SemanticJar>::jar(db);
|
||||||
|
let base_names: Vec<_> = jar
|
||||||
|
.type_store
|
||||||
|
.get_class(class_id)
|
||||||
|
.bases()
|
||||||
|
.iter()
|
||||||
|
.map(|base_ty| format!("{}", base_ty.display(&jar.type_store)))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
assert_eq!(base_names, vec!["Literal[Base]"]);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue