mirror of
https://github.com/astral-sh/ruff.git
synced 2025-10-01 06:11:21 +00:00
[red-knot] resolve base class types (#11178)
## Summary Resolve base class types, as long as they are simple names. ## Test Plan cargo test
This commit is contained in:
parent
04a922866a
commit
ce030a467f
3 changed files with 129 additions and 35 deletions
|
@ -95,7 +95,7 @@ impl Symbol {
|
|||
// TODO storing TypedNodeKey for definitions means we have to search to find them again in the AST;
|
||||
// this is at best O(log n). If looking up definitions is a bottleneck we should look for
|
||||
// alternatives here.
|
||||
#[derive(Debug)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) enum Definition {
|
||||
// For the import cases, we don't need reference to any arbitrary AST subtrees (annotations,
|
||||
// RHS), and referencing just the import statement node is imprecise (a single import statement
|
||||
|
@ -110,12 +110,12 @@ pub(crate) enum Definition {
|
|||
// TODO with statements, except handlers, function args...
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct ImportDefinition {
|
||||
pub(crate) module: ModuleName,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct ImportFromDefinition {
|
||||
pub(crate) module: Option<ModuleName>,
|
||||
pub(crate) name: Name,
|
||||
|
@ -160,6 +160,7 @@ impl SymbolTable {
|
|||
let mut builder = SymbolTableBuilder {
|
||||
table: SymbolTable::new(),
|
||||
scopes: vec![root_scope_id],
|
||||
current_definition: None,
|
||||
};
|
||||
builder.visit_body(&module.body);
|
||||
builder.table
|
||||
|
@ -386,6 +387,8 @@ where
|
|||
struct SymbolTableBuilder {
|
||||
table: SymbolTable,
|
||||
scopes: Vec<ScopeId>,
|
||||
/// the definition whose target(s) we are currently walking
|
||||
current_definition: Option<Definition>,
|
||||
}
|
||||
|
||||
impl SymbolTableBuilder {
|
||||
|
@ -448,8 +451,13 @@ impl SymbolTableBuilder {
|
|||
|
||||
impl PreorderVisitor<'_> for SymbolTableBuilder {
|
||||
fn visit_expr(&mut self, expr: &ast::Expr) {
|
||||
if let ast::Expr::Name(ast::ExprName { id, .. }) = expr {
|
||||
if let ast::Expr::Name(ast::ExprName { id, ctx, .. }) = expr {
|
||||
self.add_symbol(id);
|
||||
if matches!(ctx, ast::ExprContext::Store | ast::ExprContext::Del) {
|
||||
if let Some(curdef) = self.current_definition.clone() {
|
||||
self.add_symbol_with_def(id, curdef);
|
||||
}
|
||||
}
|
||||
}
|
||||
ast::visitor::preorder::walk_expr(self, expr);
|
||||
}
|
||||
|
@ -532,6 +540,13 @@ impl PreorderVisitor<'_> for SymbolTableBuilder {
|
|||
|
||||
self.table.dependencies.push(dependency);
|
||||
}
|
||||
ast::Stmt::Assign(node) => {
|
||||
debug_assert!(self.current_definition.is_none());
|
||||
self.current_definition =
|
||||
Some(Definition::Assignment(TypedNodeKey::from_node(node)));
|
||||
ast::visitor::preorder::walk_stmt(self, stmt);
|
||||
self.current_definition = None;
|
||||
}
|
||||
_ => {
|
||||
ast::visitor::preorder::walk_stmt(self, stmt);
|
||||
}
|
||||
|
@ -649,6 +664,19 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn assign() {
|
||||
let parsed = parse("x = foo");
|
||||
let table = SymbolTable::from_ast(parsed.ast());
|
||||
assert_eq!(names(table.root_symbols()), vec!["foo", "x"]);
|
||||
assert_eq!(
|
||||
table
|
||||
.definitions(table.root_symbol_id_by_name("x").unwrap())
|
||||
.len(),
|
||||
1
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn class_scope() {
|
||||
let parsed = parse(
|
||||
|
|
|
@ -53,12 +53,6 @@ impl From<FunctionTypeId> for Type {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<ClassTypeId> for Type {
|
||||
fn from(id: ClassTypeId) -> Self {
|
||||
Type::Class(id)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<UnionTypeId> for Type {
|
||||
fn from(id: UnionTypeId) -> Self {
|
||||
Type::Union(id)
|
||||
|
@ -129,8 +123,8 @@ impl TypeStore {
|
|||
self.add_or_get_module(file_id).add_function(name)
|
||||
}
|
||||
|
||||
fn add_class(&self, file_id: FileId, name: &str) -> ClassTypeId {
|
||||
self.add_or_get_module(file_id).add_class(name)
|
||||
fn add_class(&self, file_id: FileId, name: &str, bases: Vec<Type>) -> ClassTypeId {
|
||||
self.add_or_get_module(file_id).add_class(name, bases)
|
||||
}
|
||||
|
||||
fn add_union(&mut self, file_id: FileId, elems: &[Type]) -> UnionTypeId {
|
||||
|
@ -322,9 +316,11 @@ impl ModuleTypeStore {
|
|||
}
|
||||
}
|
||||
|
||||
fn add_class(&mut self, name: &str) -> ClassTypeId {
|
||||
fn add_class(&mut self, name: &str, bases: Vec<Type>) -> ClassTypeId {
|
||||
let class_id = self.classes.push(ClassType {
|
||||
name: Name::new(name),
|
||||
// TODO: if no bases are given, that should imply [object]
|
||||
bases,
|
||||
});
|
||||
ClassTypeId {
|
||||
file_id: self.file_id,
|
||||
|
@ -408,12 +404,17 @@ impl std::fmt::Display for DisplayType<'_> {
|
|||
#[derive(Debug)]
|
||||
pub(crate) struct ClassType {
|
||||
name: Name,
|
||||
bases: Vec<Type>,
|
||||
}
|
||||
|
||||
impl ClassType {
|
||||
fn name(&self) -> &str {
|
||||
self.name.as_str()
|
||||
}
|
||||
|
||||
fn bases(&self) -> &[Type] {
|
||||
self.bases.as_slice()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -497,7 +498,7 @@ mod tests {
|
|||
let store = TypeStore::default();
|
||||
let files = Files::default();
|
||||
let file_id = files.intern(Path::new("/foo"));
|
||||
let id = store.add_class(file_id, "C");
|
||||
let id = store.add_class(file_id, "C", Vec::new());
|
||||
assert_eq!(store.get_class(id).name(), "C");
|
||||
let inst = Type::Instance(id);
|
||||
assert_eq!(format!("{}", inst.display(&store)), "C");
|
||||
|
@ -519,8 +520,8 @@ mod tests {
|
|||
let mut store = TypeStore::default();
|
||||
let files = Files::default();
|
||||
let file_id = files.intern(Path::new("/foo"));
|
||||
let c1 = store.add_class(file_id, "C1");
|
||||
let c2 = store.add_class(file_id, "C2");
|
||||
let c1 = store.add_class(file_id, "C1", Vec::new());
|
||||
let c2 = store.add_class(file_id, "C2", Vec::new());
|
||||
let elems = vec![Type::Instance(c1), Type::Instance(c2)];
|
||||
let id = store.add_union(file_id, &elems);
|
||||
assert_eq!(
|
||||
|
@ -536,9 +537,9 @@ mod tests {
|
|||
let mut store = TypeStore::default();
|
||||
let files = Files::default();
|
||||
let file_id = files.intern(Path::new("/foo"));
|
||||
let c1 = store.add_class(file_id, "C1");
|
||||
let c2 = store.add_class(file_id, "C2");
|
||||
let c3 = store.add_class(file_id, "C3");
|
||||
let c1 = store.add_class(file_id, "C1", Vec::new());
|
||||
let c2 = store.add_class(file_id, "C2", Vec::new());
|
||||
let c3 = store.add_class(file_id, "C3", Vec::new());
|
||||
let pos = vec![Type::Instance(c1), Type::Instance(c2)];
|
||||
let neg = vec![Type::Instance(c3)];
|
||||
let id = store.add_intersection(file_id, &pos, &neg);
|
||||
|
|
|
@ -7,6 +7,7 @@ use crate::module::ModuleName;
|
|||
use crate::symbols::{Definition, ImportFromDefinition, SymbolId};
|
||||
use crate::types::Type;
|
||||
use crate::FileId;
|
||||
use ruff_python_ast as ast;
|
||||
|
||||
// FIXME: Figure out proper dead-lock free synchronisation now that this takes `&db` instead of `&mut db`.
|
||||
#[tracing::instrument(level = "trace", skip(db))]
|
||||
|
@ -58,8 +59,14 @@ where
|
|||
let ast = parsed.ast();
|
||||
let node = node_key.resolve_unwrap(ast.as_any_node_ref());
|
||||
|
||||
let bases: Vec<_> = node
|
||||
.bases()
|
||||
.iter()
|
||||
.map(|base_expr| infer_expr_type(db, file_id, base_expr))
|
||||
.collect();
|
||||
|
||||
let store = &db.jar().type_store;
|
||||
let ty: Type = store.add_class(file_id, &node.name.id).into();
|
||||
let ty = Type::Class(store.add_class(file_id, &node.name.id, bases));
|
||||
store.cache_node_type(file_id, *node_key.erased(), ty);
|
||||
ty
|
||||
}),
|
||||
|
@ -79,6 +86,13 @@ where
|
|||
store.cache_node_type(file_id, *node_key.erased(), ty);
|
||||
ty
|
||||
}),
|
||||
Definition::Assignment(node_key) => {
|
||||
let parsed = db.parse(file_id);
|
||||
let ast = parsed.ast();
|
||||
let node = node_key.resolve_unwrap(ast.as_any_node_ref());
|
||||
// TODO handle unpacking assignment correctly
|
||||
infer_expr_type(db, file_id, &node.value)
|
||||
}
|
||||
_ => todo!("other kinds of definitions"),
|
||||
};
|
||||
|
||||
|
@ -89,6 +103,24 @@ where
|
|||
ty
|
||||
}
|
||||
|
||||
fn infer_expr_type<Db>(db: &Db, file_id: FileId, expr: &ast::Expr) -> Type
|
||||
where
|
||||
Db: SemanticDb + HasJar<SemanticJar>,
|
||||
{
|
||||
// TODO cache the resolution of the type on the node
|
||||
let symbols = db.symbol_table(file_id);
|
||||
match expr {
|
||||
ast::Expr::Name(name) => {
|
||||
if let Some(symbol_id) = symbols.root_symbol_id_by_name(&name.id) {
|
||||
db.infer_symbol_type(file_id, symbol_id)
|
||||
} else {
|
||||
Type::Unknown
|
||||
}
|
||||
}
|
||||
_ => todo!("full expression type resolution"),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::db::tests::TestDb;
|
||||
|
@ -123,32 +155,65 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn follow_import_to_class() -> std::io::Result<()> {
|
||||
let TestCase {
|
||||
src,
|
||||
db,
|
||||
temp_dir: _temp_dir,
|
||||
} = create_test()?;
|
||||
let case = create_test()?;
|
||||
let db = &case.db;
|
||||
|
||||
let a_path = src.path().join("a.py");
|
||||
let b_path = src.path().join("b.py");
|
||||
std::fs::write(a_path, "from b import C as D")?;
|
||||
let a_path = case.src.path().join("a.py");
|
||||
let b_path = case.src.path().join("b.py");
|
||||
std::fs::write(a_path, "from b import C as D; E = D")?;
|
||||
std::fs::write(b_path, "class C: pass")?;
|
||||
let a_file = db
|
||||
.resolve_module(ModuleName::new("a"))
|
||||
.expect("module should be found")
|
||||
.path(&db)
|
||||
.path(db)
|
||||
.file();
|
||||
let a_syms = db.symbol_table(a_file);
|
||||
let d_sym = a_syms
|
||||
.root_symbol_id_by_name("D")
|
||||
.expect("D symbol should be found");
|
||||
let e_sym = a_syms
|
||||
.root_symbol_id_by_name("E")
|
||||
.expect("E symbol should be found");
|
||||
|
||||
let ty = db.infer_symbol_type(a_file, d_sym);
|
||||
|
||||
let jar = HasJar::<SemanticJar>::jar(&db);
|
||||
let ty = db.infer_symbol_type(a_file, e_sym);
|
||||
|
||||
let jar = HasJar::<SemanticJar>::jar(db);
|
||||
assert!(matches!(ty, Type::Class(_)));
|
||||
assert_eq!(format!("{}", ty.display(&jar.type_store)), "Literal[C]");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_base_class_by_name() -> std::io::Result<()> {
|
||||
let case = create_test()?;
|
||||
let db = &case.db;
|
||||
|
||||
let path = case.src.path().join("mod.py");
|
||||
std::fs::write(path, "class Base: pass\nclass Sub(Base): pass")?;
|
||||
let file = db
|
||||
.resolve_module(ModuleName::new("mod"))
|
||||
.expect("module should be found")
|
||||
.path(db)
|
||||
.file();
|
||||
let syms = db.symbol_table(file);
|
||||
let sym = syms
|
||||
.root_symbol_id_by_name("Sub")
|
||||
.expect("Sub symbol should be found");
|
||||
|
||||
let ty = db.infer_symbol_type(file, sym);
|
||||
|
||||
let Type::Class(class_id) = ty else {
|
||||
panic!("Sub is not a Class")
|
||||
};
|
||||
let jar = HasJar::<SemanticJar>::jar(db);
|
||||
let base_names: Vec<_> = jar
|
||||
.type_store
|
||||
.get_class(class_id)
|
||||
.bases()
|
||||
.iter()
|
||||
.map(|base_ty| format!("{}", base_ty.display(&jar.type_store)))
|
||||
.collect();
|
||||
|
||||
assert_eq!(base_names, vec!["Literal[Base]"]);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue